odra_mock_vm/
mock_vm.rs

1use anyhow::Result;
2use odra_types::{
3    casper_types::{
4        account::AccountHash,
5        bytesrepr::{Error, FromBytes, ToBytes},
6        RuntimeArgs, SecretKey, U512
7    },
8    Address, BlockTime, EventData, ExecutionError, PublicKey
9};
10use odra_types::{event::EventError, OdraError, VmError};
11use std::collections::BTreeMap;
12use std::sync::{Arc, RwLock};
13
14use crate::balance::AccountBalance;
15use crate::callstack::{Callstack, CallstackElement, Entrypoint};
16use crate::contract_container::{ContractContainer, EntrypointArgs, EntrypointCall};
17use crate::contract_register::ContractRegister;
18use crate::storage::Storage;
19
20#[derive(Default)]
21pub struct MockVm {
22    state: Arc<RwLock<MockVmState>>,
23    contract_register: Arc<RwLock<ContractRegister>>
24}
25
26impl MockVm {
27    pub fn register_contract(
28        &self,
29        constructor: Option<(String, &RuntimeArgs, EntrypointCall)>,
30        constructors: BTreeMap<String, (EntrypointArgs, EntrypointCall)>,
31        entrypoints: BTreeMap<String, (EntrypointArgs, EntrypointCall)>
32    ) -> Address {
33        // Create a new address.
34        let address = { self.state.write().unwrap().next_contract_address() };
35        // Register new contract under the new address.
36        {
37            let contract_namespace = self.state.read().unwrap().get_contract_namespace();
38            let contract = ContractContainer::new(&contract_namespace, entrypoints, constructors);
39            self.contract_register
40                .write()
41                .unwrap()
42                .add(address, contract);
43            self.state
44                .write()
45                .unwrap()
46                .set_balance(address, U512::zero());
47        }
48
49        // Call constructor if needed.
50        if let Some(constructor) = constructor {
51            let (constructor_name, args, _) = constructor;
52            self.call_constructor(address, &constructor_name, args);
53        }
54        address
55    }
56
57    pub fn call_contract<T: ToBytes + FromBytes>(
58        &self,
59        address: Address,
60        entrypoint: &str,
61        args: &RuntimeArgs,
62        amount: Option<U512>
63    ) -> T {
64        self.prepare_call(address, entrypoint, amount);
65        // Call contract from register.
66        if let Some(amount) = amount {
67            let status = self.checked_transfer_tokens(&self.caller(), &address, &amount);
68            if let Err(err) = status {
69                self.revert(err.clone());
70                panic!("{:?}", err);
71            }
72        }
73
74        let result =
75            self.contract_register
76                .read()
77                .unwrap()
78                .call(&address, String::from(entrypoint), args);
79
80        let result = self.handle_call_result(result);
81        T::from_vec(result).unwrap().0
82    }
83
84    fn call_constructor(&self, address: Address, entrypoint: &str, args: &RuntimeArgs) -> Vec<u8> {
85        self.prepare_call(address, entrypoint, None);
86        // Call contract from register.
87        let register = self.contract_register.read().unwrap();
88        let result = register.call_constructor(&address, String::from(entrypoint), args);
89        self.handle_call_result(result)
90    }
91
92    fn prepare_call(&self, address: Address, entrypoint: &str, amount: Option<U512>) {
93        let mut state = self.state.write().unwrap();
94        // If only one address on the call_stack, record snapshot.
95        if state.is_in_caller_context() {
96            state.take_snapshot();
97            state.clear_error();
98        }
99        // Put the address on stack.
100
101        let element = CallstackElement::Entrypoint(Entrypoint::new(address, entrypoint, amount));
102        state.push_callstack_element(element);
103    }
104
105    fn handle_call_result(&self, result: Result<Vec<u8>, OdraError>) -> Vec<u8> {
106        let mut state = self.state.write().unwrap();
107        let result = match result {
108            Ok(data) => data,
109            Err(err) => {
110                state.set_error(err);
111                vec![]
112            }
113        };
114
115        // Drop the address from stack.
116        state.pop_callstack_element();
117
118        if state.error.is_none() {
119            // If only one address on the call_stack, drop the snapshot
120            if state.is_in_caller_context() {
121                state.drop_snapshot();
122            }
123            result
124        } else {
125            // If only one address on the call_stack an an error occurred, restore the snapshot
126            if state.is_in_caller_context() {
127                state.restore_snapshot();
128            };
129            vec![]
130        }
131    }
132
133    pub fn revert(&self, error: OdraError) {
134        let mut state = self.state.write().unwrap();
135        state.set_error(error);
136        state.clear_callstack();
137        if state.is_in_caller_context() {
138            state.restore_snapshot();
139        }
140    }
141
142    pub fn error(&self) -> Option<OdraError> {
143        self.state.read().unwrap().error()
144    }
145
146    pub fn get_backend_name(&self) -> String {
147        self.state.read().unwrap().get_backend_name()
148    }
149
150    /// Returns the callee, i.e. the currently executing contract.
151    pub fn callee(&self) -> Address {
152        self.state.read().unwrap().callee()
153    }
154
155    pub fn caller(&self) -> Address {
156        self.state.read().unwrap().caller()
157    }
158
159    pub fn callstack_tip(&self) -> CallstackElement {
160        self.state.read().unwrap().callstack_tip().clone()
161    }
162
163    pub fn set_caller(&self, caller: Address) {
164        self.state.write().unwrap().set_caller(caller);
165    }
166
167    pub fn set_var<T: ToBytes>(&self, key: &[u8], value: T) {
168        self.state.write().unwrap().set_var(key, value);
169    }
170
171    pub fn get_var<T: FromBytes>(&self, key: &[u8]) -> Option<T> {
172        let result = { self.state.read().unwrap().get_var(key) };
173        match result {
174            Ok(result) => result,
175            Err(error) => {
176                self.state
177                    .write()
178                    .unwrap()
179                    .set_error(Into::<ExecutionError>::into(error));
180                None
181            }
182        }
183    }
184
185    pub fn set_dict_value<T: ToBytes>(&self, dict: &[u8], key: &[u8], value: T) {
186        self.state.write().unwrap().set_dict_value(dict, key, value);
187    }
188
189    pub fn get_dict_value<T: FromBytes>(&self, dict: &[u8], key: &[u8]) -> Option<T> {
190        let result = { self.state.read().unwrap().get_dict_value(dict, key) };
191        match result {
192            Ok(result) => result,
193            Err(error) => {
194                self.state
195                    .write()
196                    .unwrap()
197                    .set_error(Into::<ExecutionError>::into(error));
198                None
199            }
200        }
201    }
202
203    pub fn emit_event(&self, event_data: &EventData) {
204        self.state.write().unwrap().emit_event(event_data);
205    }
206
207    pub fn get_event(&self, address: Address, index: i32) -> Result<EventData, EventError> {
208        self.state.read().unwrap().get_event(address, index)
209    }
210
211    pub fn get_block_time(&self) -> BlockTime {
212        self.state.read().unwrap().block_time()
213    }
214
215    pub fn advance_block_time_by(&self, milliseconds: BlockTime) {
216        self.state
217            .write()
218            .unwrap()
219            .advance_block_time_by(milliseconds)
220    }
221
222    pub fn attached_value(&self) -> U512 {
223        self.state.read().unwrap().attached_value()
224    }
225
226    pub fn get_account(&self, n: usize) -> Address {
227        self.state.read().unwrap().accounts.get(n).cloned().unwrap()
228    }
229
230    pub fn token_balance(&self, address: Address) -> U512 {
231        self.state.read().unwrap().get_balance(address)
232    }
233
234    pub fn transfer_tokens(&self, from: &Address, to: &Address, amount: &U512) {
235        if amount.is_zero() {
236            return;
237        }
238
239        let mut state = self.state.write().unwrap();
240        if state.reduce_balance(from, amount).is_err() {
241            self.revert(OdraError::VmError(VmError::BalanceExceeded))
242        }
243        if state.increase_balance(to, amount).is_err() {
244            self.revert(OdraError::VmError(VmError::BalanceExceeded))
245        }
246    }
247
248    pub fn checked_transfer_tokens(
249        &self,
250        from: &Address,
251        to: &Address,
252        amount: &U512
253    ) -> Result<(), OdraError> {
254        if amount.is_zero() {
255            return Ok(());
256        }
257
258        let mut state = self.state.write().unwrap();
259        if state.reduce_balance(from, amount).is_err() {
260            return Err(OdraError::VmError(VmError::BalanceExceeded));
261        }
262        if state.increase_balance(to, amount).is_err() {
263            return Err(OdraError::VmError(VmError::BalanceExceeded));
264        }
265        Ok(())
266    }
267
268    pub fn self_balance(&self) -> U512 {
269        let address = self.callee();
270        self.state.read().unwrap().get_balance(address)
271    }
272
273    pub fn public_key(&self, address: &Address) -> PublicKey {
274        self.state.read().unwrap().public_key(address)
275    }
276}
277
278pub struct MockVmState {
279    storage: Storage,
280    callstack: Callstack,
281    events: BTreeMap<Address, Vec<EventData>>,
282    contract_counter: u32,
283    error: Option<OdraError>,
284    block_time: u64,
285    accounts: Vec<Address>,
286    key_pairs: BTreeMap<Address, (SecretKey, PublicKey)>
287}
288
289impl MockVmState {
290    fn get_backend_name(&self) -> String {
291        "MockVM".to_string()
292    }
293
294    fn callee(&self) -> Address {
295        *self.callstack.current().address()
296    }
297
298    fn caller(&self) -> Address {
299        *self.callstack.previous().address()
300    }
301
302    fn callstack_tip(&self) -> &CallstackElement {
303        self.callstack.current()
304    }
305
306    fn set_caller(&mut self, address: Address) {
307        self.pop_callstack_element();
308        self.push_callstack_element(CallstackElement::Account(address));
309    }
310
311    fn set_var<T: ToBytes>(&mut self, key: &[u8], value: T) {
312        let ctx = self.callstack.current().address();
313        if let Err(error) = self.storage.set_value(ctx, key, value) {
314            self.set_error(Into::<ExecutionError>::into(error));
315        }
316    }
317
318    fn get_var<T: FromBytes>(&self, key: &[u8]) -> Result<Option<T>, Error> {
319        let ctx = self.callstack.current().address();
320        self.storage.get_value(ctx, key)
321    }
322
323    fn set_dict_value<T: ToBytes>(&mut self, dict: &[u8], key: &[u8], value: T) {
324        let ctx = self.callstack.current().address();
325        if let Err(error) = self.storage.insert_dict_value(ctx, dict, key, value) {
326            self.set_error(Into::<ExecutionError>::into(error));
327        }
328    }
329
330    fn get_dict_value<T: FromBytes>(&self, dict: &[u8], key: &[u8]) -> Result<Option<T>, Error> {
331        let ctx = &self.callstack.current().address();
332        self.storage.get_dict_value(ctx, dict, key)
333    }
334
335    fn emit_event(&mut self, event_data: &EventData) {
336        let contract_address = self.callstack.current().address();
337        let events = self.events.get_mut(contract_address).map(|events| {
338            events.push(event_data.clone());
339            events
340        });
341        if events.is_none() {
342            self.events
343                .insert(*contract_address, vec![event_data.clone()]);
344        }
345    }
346
347    fn get_event(&self, address: Address, index: i32) -> Result<EventData, EventError> {
348        let events = self.events.get(&address);
349        if events.is_none() {
350            return Err(EventError::IndexOutOfBounds);
351        }
352        let events: &Vec<EventData> = events.unwrap();
353        let event_position = odra_utils::event_absolute_position(events.len(), index)
354            .ok_or(EventError::IndexOutOfBounds)?;
355        Ok(events.get(event_position).unwrap().clone())
356    }
357
358    fn push_callstack_element(&mut self, element: CallstackElement) {
359        self.callstack.push(element);
360    }
361
362    fn pop_callstack_element(&mut self) {
363        self.callstack.pop();
364    }
365
366    fn clear_callstack(&mut self) {
367        let mut element = self.callstack.pop();
368        while element.is_some() {
369            let new_element = self.callstack.pop();
370            if new_element.is_none() {
371                self.callstack.push(element.unwrap());
372                return;
373            }
374            element = new_element;
375        }
376    }
377
378    fn next_contract_address(&mut self) -> Address {
379        self.contract_counter += 1;
380        Address::contract_from_u32(self.contract_counter)
381    }
382
383    fn get_contract_namespace(&self) -> String {
384        self.contract_counter.to_string()
385    }
386
387    fn set_error<E>(&mut self, error: E)
388    where
389        E: Into<OdraError>
390    {
391        if self.error.is_none() {
392            self.error = Some(error.into());
393        }
394    }
395
396    fn attached_value(&self) -> U512 {
397        self.callstack.current_amount()
398    }
399
400    fn clear_error(&mut self) {
401        self.error = None;
402    }
403
404    fn error(&self) -> Option<OdraError> {
405        self.error.clone()
406    }
407
408    fn is_in_caller_context(&self) -> bool {
409        self.callstack.len() == 1
410    }
411
412    fn take_snapshot(&mut self) {
413        self.storage.take_snapshot();
414    }
415
416    fn drop_snapshot(&mut self) {
417        self.storage.drop_snapshot();
418    }
419
420    fn restore_snapshot(&mut self) {
421        self.storage.restore_snapshot();
422    }
423
424    fn block_time(&self) -> u64 {
425        self.block_time
426    }
427
428    fn advance_block_time_by(&mut self, milliseconds: u64) {
429        self.block_time += milliseconds;
430    }
431
432    fn get_balance(&self, address: Address) -> U512 {
433        self.storage
434            .balance_of(&address)
435            .map(|b| b.value())
436            .unwrap_or_default()
437    }
438
439    fn set_balance(&mut self, address: Address, amount: U512) {
440        self.storage
441            .set_balance(address, AccountBalance::new(amount));
442    }
443
444    fn increase_balance(&mut self, address: &Address, amount: &U512) -> Result<()> {
445        self.storage.increase_balance(address, amount)
446    }
447
448    fn reduce_balance(&mut self, address: &Address, amount: &U512) -> Result<()> {
449        self.storage.reduce_balance(address, amount)
450    }
451
452    fn public_key(&self, address: &Address) -> PublicKey {
453        let (_, public_key) = self.key_pairs.get(address).unwrap();
454        public_key.clone()
455    }
456}
457
458impl Default for MockVmState {
459    fn default() -> Self {
460        let mut addresses: Vec<Address> = Vec::new();
461        let mut key_pairs = BTreeMap::<Address, (SecretKey, PublicKey)>::new();
462        for i in 0..20 {
463            // Create keypair.
464            let secret_key = SecretKey::ed25519_from_bytes([i; 32]).unwrap();
465            let public_key = PublicKey::from(&secret_key);
466
467            // Create an AccountHash from a public key.
468            let account_addr = AccountHash::from(&public_key);
469
470            addresses.push(account_addr.try_into().unwrap());
471            key_pairs.insert(account_addr.try_into().unwrap(), (secret_key, public_key));
472        }
473
474        let mut balances = BTreeMap::<Address, AccountBalance>::new();
475        for address in addresses.clone() {
476            balances.insert(address, 100_000_000_000_000u64.into());
477        }
478
479        let mut backend = MockVmState {
480            storage: Storage::new(balances),
481            callstack: Default::default(),
482            events: Default::default(),
483            contract_counter: 0,
484            error: None,
485            block_time: 0,
486            accounts: addresses.clone(),
487            key_pairs
488        };
489        backend.push_callstack_element(CallstackElement::Account(*addresses.first().unwrap()));
490        backend
491    }
492}
493
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use odra_types::casper_types::bytesrepr::FromBytes;
    use odra_types::casper_types::{RuntimeArgs, U512};
    use odra_types::{Address, EventData, ExecutionError, OdraAddress, OdraError, VmError};

    use crate::{callstack::CallstackElement, EntrypointArgs, EntrypointCall};

    use super::MockVm;

    #[test]
    fn contracts_have_different_addresses() {
        // Given a fresh VM and one entrypoint table.
        let vm = MockVm::default();
        let entrypoint_list: Vec<(String, (EntrypointArgs, EntrypointCall))> =
            vec![(String::from("abc"), (vec![], |_, _| vec![]))];
        let entrypoints: BTreeMap<_, _> = entrypoint_list.into_iter().collect();
        let constructors = BTreeMap::new();

        // When the same table is deployed twice.
        let first = vm.register_contract(None, constructors.clone(), entrypoints.clone());
        let second = vm.register_contract(None, constructors, entrypoints);

        // Then each deployment gets its own address.
        assert_ne!(first, second);
    }

    #[test]
    fn addresses_have_different_type() {
        // Given one deployed contract and one pre-funded account.
        let vm = MockVm::default();
        let (contract, _, _) = setup_contract(&vm);
        let account = vm.get_account(0);

        // Then only the contract address reports itself as a contract.
        assert!(contract.is_contract());
        assert!(!account.is_contract());
    }

    #[test]
    fn test_contract_call() {
        // Given a deployed contract with a single entrypoint.
        let vm = MockVm::default();
        let (address, entrypoint, expected) = setup_contract(&vm);

        // When that entrypoint is called.
        let returned: u32 = vm.call_contract(address, &entrypoint, &RuntimeArgs::new(), None);

        // Then the decoded return value matches.
        assert_eq!(returned, expected);
    }

    #[test]
    fn test_call_non_existing_contract() {
        // Given a VM with nothing deployed.
        let vm = MockVm::default();
        let unknown = Address::contract_from_u32(42);

        // When an unknown address is called.
        vm.call_contract::<()>(unknown, "abc", &RuntimeArgs::new(), None);

        // Then the VM records an invalid-address error.
        assert_eq!(
            vm.error(),
            Some(OdraError::VmError(VmError::InvalidContractAddress))
        );
    }

    #[test]
    fn test_call_non_existing_entrypoint() {
        // Given a deployed contract with a single entrypoint.
        let vm = MockVm::default();
        let (address, entrypoint, _) = setup_contract(&vm);

        // When calling a name that is only the first character of the real entrypoint.
        let missing: String = entrypoint.chars().take(1).collect();
        vm.call_contract::<()>(address, &missing, &RuntimeArgs::new(), None);

        // Then the VM records a no-such-method error naming it.
        let expected = OdraError::VmError(VmError::NoSuchMethod(missing));
        assert_eq!(vm.error(), Some(expected));
    }

    #[test]
    fn test_caller_switching() {
        // Given a fresh VM whose bottom callstack frame is replaced by a new account.
        let vm = MockVm::default();
        let caller = Address::account_from_str("ff");
        vm.set_caller(caller);

        // When another frame is pushed, so that the new account is now the previous frame.
        push_address(&vm, &caller);

        // Then it is reported as the caller.
        assert_eq!(vm.caller(), caller);
    }

    #[test]
    fn test_revert() {
        // Given a fresh VM and an execution error.
        let vm = MockVm::default();
        let error: OdraError = ExecutionError::new(1, "err").into();

        // When the VM reverts with it.
        vm.revert(error.clone());

        // Then that error is recorded.
        assert_eq!(vm.error(), Some(error));
    }

    #[test]
    fn test_read_write_value() {
        // Given a fresh VM with one stored value.
        let vm = MockVm::default();
        let (key, value) = (b"key", 32u8);
        vm.set_var(key, value);

        // Then the value reads back, and an unknown key reads as absent.
        assert_eq!(vm.get_var(key), Some(value));
        assert_eq!(vm.get_var::<()>(b"other_key"), None);
    }

    #[test]
    fn test_read_write_dict() {
        // Given a fresh VM with one dictionary entry.
        let vm = MockVm::default();
        let dict = b"dict";
        let key = [1u8, 2];
        let value = 32u8;
        vm.set_dict_value(dict, &key, value);

        // Then the entry reads back.
        assert_eq!(vm.get_dict_value(dict, &key), Some(value));
        // An unknown dictionary reads as absent.
        assert_eq!(vm.get_dict_value::<()>(b"other_dict", &key), None);
        // An unknown key reads as absent.
        assert_eq!(vm.get_dict_value::<()>(dict, &[]), None);
    }

    #[test]
    fn events() {
        // Given a fresh VM.
        let vm = MockVm::default();

        // Events emitted while `first` is on top of the callstack belong to `first`.
        let first = Address::account_from_str("abc");
        push_address(&vm, &first);
        let first_batch: Vec<EventData> = vec![vec![1, 2, 3], vec![4, 5, 6]];
        for event in &first_batch {
            vm.emit_event(event);
        }

        // Events emitted after `second` is pushed belong to `second`.
        let second = Address::account_from_str("bca");
        push_address(&vm, &second);
        let second_batch: Vec<EventData> = vec![vec![7, 8, 9], vec![11, 22, 33]];
        for event in &second_batch {
            vm.emit_event(event);
        }

        // Then each address sees its own events, in emission order.
        for (index, event) in first_batch.iter().enumerate() {
            assert_eq!(vm.get_event(first, index as i32), Ok(event.clone()));
        }
        for (index, event) in second_batch.iter().enumerate() {
            assert_eq!(vm.get_event(second, index as i32), Ok(event.clone()));
        }
    }

    #[test]
    fn test_current_contract_address() {
        // Given a fresh VM with an address pushed onto the callstack.
        let vm = MockVm::default();
        let contract = Address::contract_from_u32(100);
        push_address(&vm, &contract);

        // Then that address is the callee.
        assert_eq!(vm.callee(), contract);
    }

    #[test]
    fn test_call_contract_with_amount() {
        // Given a deployed contract and the full balance of the default caller.
        let vm = MockVm::default();
        let (contract, entrypoint, _) = setup_contract(&vm);
        let caller = vm.get_account(0);
        let everything = vm.token_balance(caller);

        // When the contract is called with that whole balance attached.
        vm.call_contract::<u32>(contract, &entrypoint, &RuntimeArgs::new(), Some(everything));

        // Then all tokens moved from the caller to the contract.
        assert_eq!(vm.token_balance(contract), everything);
        assert_eq!(vm.token_balance(caller), U512::zero());
    }

    #[test]
    #[should_panic(expected = "VmError(BalanceExceeded)")]
    fn test_call_contract_with_amount_exceeding_balance() {
        // Given a deployed contract and the default caller's balance.
        let vm = MockVm::default();
        let (contract, entrypoint, _) = setup_contract(&vm);
        let caller = vm.get_account(0);
        let too_much = vm.token_balance(caller) + 1;

        // When more than that balance is attached, the call panics.
        vm.call_contract::<()>(contract, &entrypoint, &RuntimeArgs::new(), Some(too_much));
    }

    /// Pushes `address` onto the VM callstack as an account frame.
    fn push_address(vm: &MockVm, address: &Address) {
        vm.state
            .write()
            .unwrap()
            .push_callstack_element(CallstackElement::Account(*address));
    }

    /// Deploys a contract with one entrypoint named "abc" that returns `[1, 1, 0, 0]`.
    ///
    /// Returns its address, the entrypoint name, and the `u32` that those bytes decode to.
    fn setup_contract(vm: &MockVm) -> (Address, String, u32) {
        const ENTRYPOINT: &str = "abc";

        let table: Vec<(String, (EntrypointArgs, EntrypointCall))> = vec![(
            ENTRYPOINT.to_string(),
            (vec![], |_, _| vec![1, 1, 0, 0])
        )];
        let address = vm.register_contract(None, BTreeMap::new(), table.into_iter().collect());

        // Decode the same bytes the entrypoint returns.
        let (expected, _) = <u32 as FromBytes>::from_vec(vec![1, 1, 0, 0]).unwrap();
        (address, ENTRYPOINT.to_string(), expected)
    }
}