1use anyhow::Result;
2use odra_types::{
3 casper_types::{
4 account::AccountHash,
5 bytesrepr::{Error, FromBytes, ToBytes},
6 RuntimeArgs, SecretKey, U512
7 },
8 Address, BlockTime, EventData, ExecutionError, PublicKey
9};
10use odra_types::{event::EventError, OdraError, VmError};
11use std::collections::BTreeMap;
12use std::sync::{Arc, RwLock};
13
14use crate::balance::AccountBalance;
15use crate::callstack::{Callstack, CallstackElement, Entrypoint};
16use crate::contract_container::{ContractContainer, EntrypointArgs, EntrypointCall};
17use crate::contract_register::ContractRegister;
18use crate::storage::Storage;
19
/// In-memory mock virtual machine that registers contracts and dispatches calls to them.
///
/// Execution state (storage, callstack, events, recorded error, accounts) is kept in
/// `state`, while the registered contracts are kept in `contract_register`. Both are
/// wrapped in `Arc<RwLock<..>>`, so every method takes `&self` and locks internally.
#[derive(Default)]
pub struct MockVm {
    // Mutable VM state; see `MockVmState`.
    state: Arc<RwLock<MockVmState>>,
    // Contracts registered via `register_contract`, looked up by address on each call.
    contract_register: Arc<RwLock<ContractRegister>>
}
25
26impl MockVm {
27 pub fn register_contract(
28 &self,
29 constructor: Option<(String, &RuntimeArgs, EntrypointCall)>,
30 constructors: BTreeMap<String, (EntrypointArgs, EntrypointCall)>,
31 entrypoints: BTreeMap<String, (EntrypointArgs, EntrypointCall)>
32 ) -> Address {
33 let address = { self.state.write().unwrap().next_contract_address() };
35 {
37 let contract_namespace = self.state.read().unwrap().get_contract_namespace();
38 let contract = ContractContainer::new(&contract_namespace, entrypoints, constructors);
39 self.contract_register
40 .write()
41 .unwrap()
42 .add(address, contract);
43 self.state
44 .write()
45 .unwrap()
46 .set_balance(address, U512::zero());
47 }
48
49 if let Some(constructor) = constructor {
51 let (constructor_name, args, _) = constructor;
52 self.call_constructor(address, &constructor_name, args);
53 }
54 address
55 }
56
57 pub fn call_contract<T: ToBytes + FromBytes>(
58 &self,
59 address: Address,
60 entrypoint: &str,
61 args: &RuntimeArgs,
62 amount: Option<U512>
63 ) -> T {
64 self.prepare_call(address, entrypoint, amount);
65 if let Some(amount) = amount {
67 let status = self.checked_transfer_tokens(&self.caller(), &address, &amount);
68 if let Err(err) = status {
69 self.revert(err.clone());
70 panic!("{:?}", err);
71 }
72 }
73
74 let result =
75 self.contract_register
76 .read()
77 .unwrap()
78 .call(&address, String::from(entrypoint), args);
79
80 let result = self.handle_call_result(result);
81 T::from_vec(result).unwrap().0
82 }
83
84 fn call_constructor(&self, address: Address, entrypoint: &str, args: &RuntimeArgs) -> Vec<u8> {
85 self.prepare_call(address, entrypoint, None);
86 let register = self.contract_register.read().unwrap();
88 let result = register.call_constructor(&address, String::from(entrypoint), args);
89 self.handle_call_result(result)
90 }
91
92 fn prepare_call(&self, address: Address, entrypoint: &str, amount: Option<U512>) {
93 let mut state = self.state.write().unwrap();
94 if state.is_in_caller_context() {
96 state.take_snapshot();
97 state.clear_error();
98 }
99 let element = CallstackElement::Entrypoint(Entrypoint::new(address, entrypoint, amount));
102 state.push_callstack_element(element);
103 }
104
105 fn handle_call_result(&self, result: Result<Vec<u8>, OdraError>) -> Vec<u8> {
106 let mut state = self.state.write().unwrap();
107 let result = match result {
108 Ok(data) => data,
109 Err(err) => {
110 state.set_error(err);
111 vec![]
112 }
113 };
114
115 state.pop_callstack_element();
117
118 if state.error.is_none() {
119 if state.is_in_caller_context() {
121 state.drop_snapshot();
122 }
123 result
124 } else {
125 if state.is_in_caller_context() {
127 state.restore_snapshot();
128 };
129 vec![]
130 }
131 }
132
133 pub fn revert(&self, error: OdraError) {
134 let mut state = self.state.write().unwrap();
135 state.set_error(error);
136 state.clear_callstack();
137 if state.is_in_caller_context() {
138 state.restore_snapshot();
139 }
140 }
141
142 pub fn error(&self) -> Option<OdraError> {
143 self.state.read().unwrap().error()
144 }
145
146 pub fn get_backend_name(&self) -> String {
147 self.state.read().unwrap().get_backend_name()
148 }
149
150 pub fn callee(&self) -> Address {
152 self.state.read().unwrap().callee()
153 }
154
155 pub fn caller(&self) -> Address {
156 self.state.read().unwrap().caller()
157 }
158
159 pub fn callstack_tip(&self) -> CallstackElement {
160 self.state.read().unwrap().callstack_tip().clone()
161 }
162
163 pub fn set_caller(&self, caller: Address) {
164 self.state.write().unwrap().set_caller(caller);
165 }
166
167 pub fn set_var<T: ToBytes>(&self, key: &[u8], value: T) {
168 self.state.write().unwrap().set_var(key, value);
169 }
170
171 pub fn get_var<T: FromBytes>(&self, key: &[u8]) -> Option<T> {
172 let result = { self.state.read().unwrap().get_var(key) };
173 match result {
174 Ok(result) => result,
175 Err(error) => {
176 self.state
177 .write()
178 .unwrap()
179 .set_error(Into::<ExecutionError>::into(error));
180 None
181 }
182 }
183 }
184
185 pub fn set_dict_value<T: ToBytes>(&self, dict: &[u8], key: &[u8], value: T) {
186 self.state.write().unwrap().set_dict_value(dict, key, value);
187 }
188
189 pub fn get_dict_value<T: FromBytes>(&self, dict: &[u8], key: &[u8]) -> Option<T> {
190 let result = { self.state.read().unwrap().get_dict_value(dict, key) };
191 match result {
192 Ok(result) => result,
193 Err(error) => {
194 self.state
195 .write()
196 .unwrap()
197 .set_error(Into::<ExecutionError>::into(error));
198 None
199 }
200 }
201 }
202
203 pub fn emit_event(&self, event_data: &EventData) {
204 self.state.write().unwrap().emit_event(event_data);
205 }
206
207 pub fn get_event(&self, address: Address, index: i32) -> Result<EventData, EventError> {
208 self.state.read().unwrap().get_event(address, index)
209 }
210
211 pub fn get_block_time(&self) -> BlockTime {
212 self.state.read().unwrap().block_time()
213 }
214
215 pub fn advance_block_time_by(&self, milliseconds: BlockTime) {
216 self.state
217 .write()
218 .unwrap()
219 .advance_block_time_by(milliseconds)
220 }
221
222 pub fn attached_value(&self) -> U512 {
223 self.state.read().unwrap().attached_value()
224 }
225
226 pub fn get_account(&self, n: usize) -> Address {
227 self.state.read().unwrap().accounts.get(n).cloned().unwrap()
228 }
229
230 pub fn token_balance(&self, address: Address) -> U512 {
231 self.state.read().unwrap().get_balance(address)
232 }
233
234 pub fn transfer_tokens(&self, from: &Address, to: &Address, amount: &U512) {
235 if amount.is_zero() {
236 return;
237 }
238
239 let mut state = self.state.write().unwrap();
240 if state.reduce_balance(from, amount).is_err() {
241 self.revert(OdraError::VmError(VmError::BalanceExceeded))
242 }
243 if state.increase_balance(to, amount).is_err() {
244 self.revert(OdraError::VmError(VmError::BalanceExceeded))
245 }
246 }
247
248 pub fn checked_transfer_tokens(
249 &self,
250 from: &Address,
251 to: &Address,
252 amount: &U512
253 ) -> Result<(), OdraError> {
254 if amount.is_zero() {
255 return Ok(());
256 }
257
258 let mut state = self.state.write().unwrap();
259 if state.reduce_balance(from, amount).is_err() {
260 return Err(OdraError::VmError(VmError::BalanceExceeded));
261 }
262 if state.increase_balance(to, amount).is_err() {
263 return Err(OdraError::VmError(VmError::BalanceExceeded));
264 }
265 Ok(())
266 }
267
268 pub fn self_balance(&self) -> U512 {
269 let address = self.callee();
270 self.state.read().unwrap().get_balance(address)
271 }
272
273 pub fn public_key(&self, address: &Address) -> PublicKey {
274 self.state.read().unwrap().public_key(address)
275 }
276}
277
/// Mutable state behind a `MockVm`: storage, callstack, events, error slot, block time and
/// the pre-generated accounts.
pub struct MockVmState {
    // Contract variables, dictionaries and token balances.
    storage: Storage,
    // Execution frames; the default state starts with a single account frame at the bottom.
    callstack: Callstack,
    // Emitted events, grouped by the address that was on top of the callstack when emitting.
    events: BTreeMap<Address, Vec<EventData>>,
    // Number of contracts registered so far; used for new contract addresses and namespaces.
    contract_counter: u32,
    // First error recorded since the last `clear_error` (later errors do not overwrite it).
    error: Option<OdraError>,
    // Current block time, advanced by `advance_block_time_by` (whose parameter is named
    // `milliseconds`).
    block_time: u64,
    // Addresses of the pre-generated accounts.
    accounts: Vec<Address>,
    // Secret/public key pair for each pre-generated account.
    key_pairs: BTreeMap<Address, (SecretKey, PublicKey)>
}
288
289impl MockVmState {
290 fn get_backend_name(&self) -> String {
291 "MockVM".to_string()
292 }
293
294 fn callee(&self) -> Address {
295 *self.callstack.current().address()
296 }
297
298 fn caller(&self) -> Address {
299 *self.callstack.previous().address()
300 }
301
302 fn callstack_tip(&self) -> &CallstackElement {
303 self.callstack.current()
304 }
305
306 fn set_caller(&mut self, address: Address) {
307 self.pop_callstack_element();
308 self.push_callstack_element(CallstackElement::Account(address));
309 }
310
311 fn set_var<T: ToBytes>(&mut self, key: &[u8], value: T) {
312 let ctx = self.callstack.current().address();
313 if let Err(error) = self.storage.set_value(ctx, key, value) {
314 self.set_error(Into::<ExecutionError>::into(error));
315 }
316 }
317
318 fn get_var<T: FromBytes>(&self, key: &[u8]) -> Result<Option<T>, Error> {
319 let ctx = self.callstack.current().address();
320 self.storage.get_value(ctx, key)
321 }
322
323 fn set_dict_value<T: ToBytes>(&mut self, dict: &[u8], key: &[u8], value: T) {
324 let ctx = self.callstack.current().address();
325 if let Err(error) = self.storage.insert_dict_value(ctx, dict, key, value) {
326 self.set_error(Into::<ExecutionError>::into(error));
327 }
328 }
329
330 fn get_dict_value<T: FromBytes>(&self, dict: &[u8], key: &[u8]) -> Result<Option<T>, Error> {
331 let ctx = &self.callstack.current().address();
332 self.storage.get_dict_value(ctx, dict, key)
333 }
334
335 fn emit_event(&mut self, event_data: &EventData) {
336 let contract_address = self.callstack.current().address();
337 let events = self.events.get_mut(contract_address).map(|events| {
338 events.push(event_data.clone());
339 events
340 });
341 if events.is_none() {
342 self.events
343 .insert(*contract_address, vec![event_data.clone()]);
344 }
345 }
346
347 fn get_event(&self, address: Address, index: i32) -> Result<EventData, EventError> {
348 let events = self.events.get(&address);
349 if events.is_none() {
350 return Err(EventError::IndexOutOfBounds);
351 }
352 let events: &Vec<EventData> = events.unwrap();
353 let event_position = odra_utils::event_absolute_position(events.len(), index)
354 .ok_or(EventError::IndexOutOfBounds)?;
355 Ok(events.get(event_position).unwrap().clone())
356 }
357
358 fn push_callstack_element(&mut self, element: CallstackElement) {
359 self.callstack.push(element);
360 }
361
362 fn pop_callstack_element(&mut self) {
363 self.callstack.pop();
364 }
365
366 fn clear_callstack(&mut self) {
367 let mut element = self.callstack.pop();
368 while element.is_some() {
369 let new_element = self.callstack.pop();
370 if new_element.is_none() {
371 self.callstack.push(element.unwrap());
372 return;
373 }
374 element = new_element;
375 }
376 }
377
378 fn next_contract_address(&mut self) -> Address {
379 self.contract_counter += 1;
380 Address::contract_from_u32(self.contract_counter)
381 }
382
383 fn get_contract_namespace(&self) -> String {
384 self.contract_counter.to_string()
385 }
386
387 fn set_error<E>(&mut self, error: E)
388 where
389 E: Into<OdraError>
390 {
391 if self.error.is_none() {
392 self.error = Some(error.into());
393 }
394 }
395
396 fn attached_value(&self) -> U512 {
397 self.callstack.current_amount()
398 }
399
400 fn clear_error(&mut self) {
401 self.error = None;
402 }
403
404 fn error(&self) -> Option<OdraError> {
405 self.error.clone()
406 }
407
408 fn is_in_caller_context(&self) -> bool {
409 self.callstack.len() == 1
410 }
411
412 fn take_snapshot(&mut self) {
413 self.storage.take_snapshot();
414 }
415
416 fn drop_snapshot(&mut self) {
417 self.storage.drop_snapshot();
418 }
419
420 fn restore_snapshot(&mut self) {
421 self.storage.restore_snapshot();
422 }
423
424 fn block_time(&self) -> u64 {
425 self.block_time
426 }
427
428 fn advance_block_time_by(&mut self, milliseconds: u64) {
429 self.block_time += milliseconds;
430 }
431
432 fn get_balance(&self, address: Address) -> U512 {
433 self.storage
434 .balance_of(&address)
435 .map(|b| b.value())
436 .unwrap_or_default()
437 }
438
439 fn set_balance(&mut self, address: Address, amount: U512) {
440 self.storage
441 .set_balance(address, AccountBalance::new(amount));
442 }
443
444 fn increase_balance(&mut self, address: &Address, amount: &U512) -> Result<()> {
445 self.storage.increase_balance(address, amount)
446 }
447
448 fn reduce_balance(&mut self, address: &Address, amount: &U512) -> Result<()> {
449 self.storage.reduce_balance(address, amount)
450 }
451
452 fn public_key(&self, address: &Address) -> PublicKey {
453 let (_, public_key) = self.key_pairs.get(address).unwrap();
454 public_key.clone()
455 }
456}
457
458impl Default for MockVmState {
459 fn default() -> Self {
460 let mut addresses: Vec<Address> = Vec::new();
461 let mut key_pairs = BTreeMap::<Address, (SecretKey, PublicKey)>::new();
462 for i in 0..20 {
463 let secret_key = SecretKey::ed25519_from_bytes([i; 32]).unwrap();
465 let public_key = PublicKey::from(&secret_key);
466
467 let account_addr = AccountHash::from(&public_key);
469
470 addresses.push(account_addr.try_into().unwrap());
471 key_pairs.insert(account_addr.try_into().unwrap(), (secret_key, public_key));
472 }
473
474 let mut balances = BTreeMap::<Address, AccountBalance>::new();
475 for address in addresses.clone() {
476 balances.insert(address, 100_000_000_000_000u64.into());
477 }
478
479 let mut backend = MockVmState {
480 storage: Storage::new(balances),
481 callstack: Default::default(),
482 events: Default::default(),
483 contract_counter: 0,
484 error: None,
485 block_time: 0,
486 accounts: addresses.clone(),
487 key_pairs
488 };
489 backend.push_callstack_element(CallstackElement::Account(*addresses.first().unwrap()));
490 backend
491 }
492}
493
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use odra_types::casper_types::bytesrepr::FromBytes;
    use odra_types::casper_types::{RuntimeArgs, U512};
    use odra_types::OdraAddress;
    use odra_types::{Address, EventData};
    use odra_types::{ExecutionError, OdraError, VmError};

    use crate::{callstack::CallstackElement, EntrypointArgs, EntrypointCall};

    use super::MockVm;

    /// Makes `address` the element on top of the VM callstack by pushing an account frame.
    fn push_address(vm: &MockVm, address: &Address) {
        vm.state
            .write()
            .unwrap()
            .push_callstack_element(CallstackElement::Account(*address));
    }

    /// Registers a contract exposing a single `abc` entrypoint that returns `[1, 1, 0, 0]`.
    ///
    /// Returns the contract address, the entrypoint name and those bytes decoded as `u32`.
    fn setup_contract(instance: &MockVm) -> (Address, String, u32) {
        let name = String::from("abc");
        let entrypoint_list: Vec<(String, (EntrypointArgs, EntrypointCall))> =
            vec![(name.clone(), (vec![], |_, _| vec![1, 1, 0, 0]))];

        let address = instance.register_contract(
            None,
            BTreeMap::new(),
            entrypoint_list.into_iter().collect::<BTreeMap<_, _>>()
        );

        let (expected, _) = <u32 as FromBytes>::from_vec(vec![1, 1, 0, 0]).unwrap();
        (address, name, expected)
    }

    #[test]
    fn contracts_have_different_addresses() {
        let vm = MockVm::default();
        let entrypoint_list: Vec<(String, (EntrypointArgs, EntrypointCall))> =
            vec![(String::from("abc"), (vec![], |_, _| vec![]))];
        let entrypoints = entrypoint_list.into_iter().collect::<BTreeMap<_, _>>();
        let constructors = BTreeMap::new();

        let first = vm.register_contract(None, constructors.clone(), entrypoints.clone());
        let second = vm.register_contract(None, constructors, entrypoints);

        assert_ne!(first, second);
    }

    #[test]
    fn addresses_have_different_type() {
        let vm = MockVm::default();
        let (contract, _, _) = setup_contract(&vm);
        let account = vm.get_account(0);

        assert!(contract.is_contract());
        assert!(!account.is_contract());
    }

    #[test]
    fn test_contract_call() {
        let vm = MockVm::default();
        let (address, entrypoint, expected) = setup_contract(&vm);

        let actual: u32 = vm.call_contract(address, &entrypoint, &RuntimeArgs::new(), None);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_call_non_existing_contract() {
        let vm = MockVm::default();
        let unknown = Address::contract_from_u32(42);

        vm.call_contract::<()>(unknown, "abc", &RuntimeArgs::new(), None);

        assert_eq!(
            vm.error(),
            Some(OdraError::VmError(VmError::InvalidContractAddress))
        );
    }

    #[test]
    fn test_call_non_existing_entrypoint() {
        let vm = MockVm::default();
        let (address, entrypoint, _) = setup_contract(&vm);

        // Only the first character of the registered name, which is not itself registered.
        let bogus: String = entrypoint.chars().take(1).collect();
        vm.call_contract::<()>(address, &bogus, &RuntimeArgs::new(), None);

        assert_eq!(
            vm.error(),
            Some(OdraError::VmError(VmError::NoSuchMethod(bogus)))
        );
    }

    #[test]
    fn test_caller_switching() {
        let vm = MockVm::default();
        let caller = Address::account_from_str("ff");

        vm.set_caller(caller);
        push_address(&vm, &caller);

        assert_eq!(vm.caller(), caller);
    }

    #[test]
    fn test_revert() {
        let vm = MockVm::default();

        vm.revert(ExecutionError::new(1, "err").into());

        let expected: OdraError = ExecutionError::new(1, "err").into();
        assert_eq!(vm.error(), Some(expected));
    }

    #[test]
    fn test_read_write_value() {
        let vm = MockVm::default();
        let key = b"key";

        vm.set_var(key, 32u8);

        assert_eq!(vm.get_var(key), Some(32u8));
        assert_eq!(vm.get_var::<()>(b"other_key"), None);
    }

    #[test]
    fn test_read_write_dict() {
        let vm = MockVm::default();
        let dict = b"dict";
        let key: [u8; 2] = [1, 2];

        vm.set_dict_value(dict, &key, 32u8);

        assert_eq!(vm.get_dict_value(dict, &key), Some(32u8));
        assert_eq!(vm.get_dict_value::<()>(b"other_dict", &key), None);
        assert_eq!(vm.get_dict_value::<()>(dict, &[]), None);
    }

    #[test]
    fn events() {
        let vm = MockVm::default();
        let first_contract = Address::account_from_str("abc");
        let second_contract = Address::account_from_str("bca");
        let first_batch: [EventData; 2] = [vec![1, 2, 3], vec![4, 5, 6]];
        let second_batch: [EventData; 2] = [vec![7, 8, 9], vec![11, 22, 33]];

        push_address(&vm, &first_contract);
        first_batch.iter().for_each(|event| vm.emit_event(event));

        push_address(&vm, &second_contract);
        second_batch.iter().for_each(|event| vm.emit_event(event));

        for (contract, batch) in [(first_contract, &first_batch), (second_contract, &second_batch)] {
            for (index, event) in batch.iter().enumerate() {
                assert_eq!(vm.get_event(contract, index as i32), Ok(event.clone()));
            }
        }
    }

    #[test]
    fn test_current_contract_address() {
        let vm = MockVm::default();
        let contract = Address::contract_from_u32(100);

        push_address(&vm, &contract);

        assert_eq!(vm.callee(), contract);
    }

    #[test]
    fn test_call_contract_with_amount() {
        let vm = MockVm::default();
        let (address, entrypoint, _) = setup_contract(&vm);
        let caller = vm.get_account(0);
        let whole_balance = vm.token_balance(caller);

        vm.call_contract::<u32>(address, &entrypoint, &RuntimeArgs::new(), Some(whole_balance));

        assert_eq!(vm.token_balance(address), whole_balance);
        assert_eq!(vm.token_balance(caller), U512::zero());
    }

    #[test]
    #[should_panic(expected = "VmError(BalanceExceeded)")]
    fn test_call_contract_with_amount_exceeding_balance() {
        let vm = MockVm::default();
        let (address, entrypoint, _) = setup_contract(&vm);
        let caller = vm.get_account(0);
        let too_much = vm.token_balance(caller) + 1;

        vm.call_contract::<()>(address, &entrypoint, &RuntimeArgs::new(), Some(too_much));
    }
}