1use futures::executor::LocalPool;
2use std::any::TypeId;
3use std::collections::BTreeSet;
4use std::hash::Hasher;
5use std::io::Write;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use ic_cdk::export::candid::decode_args;
9use ic_cdk::export::candid::utils::{ArgumentDecoder, ArgumentEncoder};
10use ic_cdk::export::{candid, Principal};
11use serde::Serialize;
12
13use crate::candid::CandidType;
14use crate::inject::{get_context, inject};
15use crate::interface::{CallResponse, Context};
16use crate::stable::StableWriter;
17use crate::storage::Storage;
18use crate::{CallHandler, Method, StableMemoryError};
19
/// A mock implementation of [`Context`] for unit-testing canister code off-chain.
///
/// Build one with [`MockContext::new`] and the `with_*` builder methods, then call
/// [`MockContext::inject`] so the canister code under test runs against it.
pub struct MockContext {
    /// Records which context APIs were used; reset by `watch()`.
    watcher: Watcher,
    /// The canister's own id, returned by `id()`.
    id: Principal,
    /// The canister's cycle balance, returned by `balance()`.
    balance: u64,
    /// The principal returned by `caller()`.
    caller: Principal,
    /// Set by `call_raw`; while true, `caller()` panics. Cleared by `call_state_reset`.
    is_reply_callback_mode: bool,
    /// Set by `trap()`; cleared by `call_state_reset`.
    trapped: bool,
    /// Cycles still available on the current message (`msg_cycles_available`/`msg_cycles_accept`).
    cycles: u64,
    /// Cycles refunded by the most recent `call_raw`.
    cycles_refunded: u64,
    /// Type-keyed canister storage backing `with`, `get`, `store`, `take`, ...
    storage: Storage,
    /// Raw bytes of the mocked stable memory.
    stable: Vec<u8>,
    /// Data last set through `with_certified_data` / `set_certified_data`.
    certified_data: Option<Vec<u8>>,
    /// Mock certificate (produced by `MockContext::sign`) for `certified_data`.
    certificate: Option<Vec<u8>>,
    /// Call handlers; `call_raw` uses the first one, in insertion order, that accepts the call.
    handlers: Vec<Box<dyn CallHandler>>,
    /// Single-threaded executor used by `spawn` and `join`.
    pool: LocalPool,
}
51
/// Records how canister code interacted with a [`MockContext`].
///
/// Obtain a fresh one with `MockContext::watch()`; each `called_*` flag is set by the
/// context method of the same name and stays set until the next `watch()`.
pub struct Watcher {
    /// `id()` was called.
    pub called_id: bool,
    /// `time()` was called.
    pub called_time: bool,
    /// `balance()` was called.
    pub called_balance: bool,
    /// `caller()` was called.
    pub called_caller: bool,
    /// `msg_cycles_available()` was called.
    pub called_msg_cycles_available: bool,
    /// `msg_cycles_accept()` was called.
    pub called_msg_cycles_accept: bool,
    /// `msg_cycles_refunded()` was called.
    pub called_msg_cycles_refunded: bool,
    /// `stable_store()` was called.
    pub called_stable_store: bool,
    /// `stable_restore()` was called.
    pub called_stable_restore: bool,
    /// `set_certified_data()` was called.
    pub called_set_certified_data: bool,
    /// `data_certificate()` was called.
    pub called_data_certificate: bool,
    /// Types touched through a mutating storage API (`with_mut`, `maybe_with_mut`, `take`,
    /// `swap`, `store`, `get_mut`, `delete`).
    storage_modified: BTreeSet<TypeId>,
    /// Inter-canister calls recorded by `call_raw`, in call order.
    calls: Vec<WatcherCall>,
}
81
/// One inter-canister call made through `MockContext::call_raw`.
pub struct WatcherCall {
    /// The callee canister.
    canister_id: Principal,
    /// The callee method.
    method_name: String,
    /// Candid-encoded call arguments, exactly as passed to `call_raw`.
    args_raw: Vec<u8>,
    /// Cycles attached to the call.
    cycles_sent: u64,
    /// Cycles the call handler reported as refunded.
    cycles_refunded: u64,
}
89
90impl MockContext {
91 #[inline]
93 pub fn new() -> Self {
94 Self {
95 watcher: Watcher::default(),
96 id: Principal::from_text("sgymv-uiaaa-aaaaa-aaaia-cai").unwrap(),
97 balance: 100_000_000_000_000,
98 caller: Principal::anonymous(),
99 is_reply_callback_mode: false,
100 trapped: false,
101 cycles: 0,
102 cycles_refunded: 0,
103 storage: Storage::default(),
104 stable: Vec::new(),
105 certified_data: None,
106 certificate: None,
107 handlers: vec![],
108 pool: LocalPool::new(),
109 }
110 }
111
112 #[inline]
114 pub fn watch(&self) -> &Watcher {
115 self.as_mut().watcher = Watcher::default();
116 &self.watcher
117 }
118
119 #[inline]
135 pub fn with_id(mut self, id: Principal) -> Self {
136 self.id = id;
137 self
138 }
139
140 #[inline]
154 pub fn with_balance(mut self, cycles: u64) -> Self {
155 self.balance = cycles;
156 self
157 }
158
159 #[inline]
175 pub fn with_caller(mut self, caller: Principal) -> Self {
176 self.caller = caller;
177 self
178 }
179
180 #[inline]
198 pub fn with_msg_cycles(mut self, cycles: u64) -> Self {
199 self.cycles = cycles;
200 self
201 }
202
203 #[inline]
217 pub fn with_data<T: 'static>(mut self, data: T) -> Self {
218 self.storage.swap(data);
219 self
220 }
221
222 #[inline]
236 pub fn with_stable<T: Serialize>(mut self, data: T) -> Self
237 where
238 T: ArgumentEncoder,
239 {
240 self.stable.truncate(0);
241 candid::write_args(&mut self.stable, data).expect("Encoding stable data failed.");
242 self.stable_grow(0)
243 .expect("Can not write this amount of data to stable storage.");
244 self
245 }
246
247 #[inline]
249 pub fn with_certified_data(mut self, data: Vec<u8>) -> Self {
250 assert!(data.len() < 32);
251 self.certificate = Some(MockContext::sign(data.as_slice()));
252 self.certified_data = Some(data);
253 self
254 }
255
256 #[inline]
259 pub fn with_consume_cycles_handler(self, cycles: u64) -> Self {
260 self.with_handler(Method::new().cycles_consume(cycles))
261 }
262
263 #[inline]
266 pub fn with_expect_cycles_handler(self, cycles: u64) -> Self {
267 self.with_handler(Method::new().expect_cycles(cycles))
268 }
269
270 #[inline]
273 pub fn with_refund_cycles_handler(self, cycles: u64) -> Self {
274 self.with_handler(Method::new().cycles_refund(cycles))
275 }
276
277 #[inline]
279 pub fn with_constant_return_handler<T: CandidType>(self, value: T) -> Self {
280 self.with_handler(Method::new().response(value))
281 }
282
283 #[inline]
285 pub fn with_handler<T: 'static + CallHandler>(mut self, handler: T) -> Self {
286 self.use_handler(handler);
287 self
288 }
289
290 #[inline]
292 pub fn inject(self) -> &'static mut Self {
293 inject(self);
294 get_context()
295 }
296
297 pub fn sign(data: &[u8]) -> Vec<u8> {
300 let data = {
301 let mut tmp: Vec<u8> = vec![0; 32];
302 for (i, b) in data.iter().enumerate() {
303 tmp[i] = *b;
304 }
305 tmp
306 };
307
308 let mut certificate = Vec::with_capacity(32 * 8);
309
310 for i in 0..32 {
311 let mut hasher = std::collections::hash_map::DefaultHasher::new();
312 for b in &certificate {
313 hasher.write_u8(*b);
314 }
315 hasher.write_u8(data[i]);
316 let hash = hasher.finish().to_be_bytes();
317 certificate.extend_from_slice(&hash);
318 }
319
320 certificate
321 }
322
323 #[inline]
326 fn as_mut(&self) -> &mut Self {
327 unsafe {
328 let const_ptr = self as *const Self;
329 let mut_ptr = const_ptr as *mut Self;
330 &mut *mut_ptr
331 }
332 }
333}
334
335impl MockContext {
336 #[inline]
338 pub fn call_state_reset(&self) {
339 let mut_ref = self.as_mut();
340 mut_ref.is_reply_callback_mode = false;
341 mut_ref.trapped = false;
342 }
343
344 #[inline]
346 pub fn clear_storage(&self) {
347 self.as_mut().storage = Storage::default()
348 }
349
350 #[inline]
352 pub fn update_balance(&self, cycles: u64) {
353 self.as_mut().balance = cycles;
354 }
355
356 #[inline]
358 pub fn update_msg_cycles(&self, cycles: u64) {
359 self.as_mut().cycles = cycles;
360 }
361
362 #[inline]
364 pub fn update_caller(&self, caller: Principal) {
365 self.as_mut().caller = caller;
366 }
367
368 #[inline]
370 pub fn update_id(&self, canister_id: Principal) {
371 self.as_mut().id = canister_id;
372 }
373
374 #[inline]
376 pub fn get_certified_data(&self) -> Option<Vec<u8>> {
377 self.certified_data.as_ref().cloned()
378 }
379
380 #[inline]
382 pub fn use_handler<T: 'static + CallHandler>(&mut self, handler: T) {
383 self.handlers.push(Box::new(handler));
384 }
385
386 #[inline]
388 pub fn clear_handlers(&mut self) {
389 self.handlers.clear();
390 }
391
392 #[inline]
394 pub fn join(&mut self) {
395 self.pool.run();
396 }
397}
398
399impl Context for MockContext {
400 #[inline]
401 fn trap(&self, message: &str) -> ! {
402 self.as_mut().trapped = true;
403 panic!("Canister {} trapped with message: {}", self.id, message);
404 }
405
406 #[inline]
407 fn print<S: AsRef<str>>(&self, s: S) {
408 println!("{} : {}", self.id, s.as_ref())
409 }
410
411 #[inline]
412 fn id(&self) -> Principal {
413 self.as_mut().watcher.called_id = true;
414 self.id
415 }
416
417 #[inline]
418 fn time(&self) -> u64 {
419 self.as_mut().watcher.called_time = true;
420 SystemTime::now()
421 .duration_since(UNIX_EPOCH)
422 .expect("Time went backwards")
423 .as_nanos() as u64
424 }
425
426 #[inline]
427 fn balance(&self) -> u64 {
428 self.as_mut().watcher.called_balance = true;
429 self.balance
430 }
431
432 #[inline]
433 fn caller(&self) -> Principal {
434 self.as_mut().watcher.called_caller = true;
435
436 if self.is_reply_callback_mode {
437 panic!(
438 "Canister {} violated contract: \"{}\" cannot be executed in reply callback mode",
439 self.id(),
440 "ic0_msg_caller_size"
441 )
442 }
443
444 self.caller
445 }
446
447 #[inline]
448 fn msg_cycles_available(&self) -> u64 {
449 self.as_mut().watcher.called_msg_cycles_available = true;
450 self.cycles
451 }
452
453 #[inline]
454 fn msg_cycles_accept(&self, cycles: u64) -> u64 {
455 self.as_mut().watcher.called_msg_cycles_accept = true;
456 let mut_ref = self.as_mut();
457 if cycles > mut_ref.cycles {
458 let r = mut_ref.cycles;
459 mut_ref.cycles = 0;
460 mut_ref.balance += r;
461 r
462 } else {
463 mut_ref.cycles -= cycles;
464 mut_ref.balance += cycles;
465 cycles
466 }
467 }
468
469 #[inline]
470 fn msg_cycles_refunded(&self) -> u64 {
471 self.as_mut().watcher.called_msg_cycles_refunded = true;
472 self.cycles_refunded
473 }
474
475 #[inline]
476 fn stable_store<T>(&self, data: T) -> Result<(), candid::Error>
477 where
478 T: ArgumentEncoder,
479 {
480 self.as_mut().watcher.called_stable_store = true;
481 candid::write_args(&mut StableWriter::default(), data)
482 }
483
484 #[inline]
485 fn stable_restore<T>(&self) -> Result<T, String>
486 where
487 T: for<'de> ArgumentDecoder<'de>,
488 {
489 self.as_mut().watcher.called_stable_restore = true;
490 let bytes = &self.stable;
491
492 let mut de =
493 candid::de::IDLDeserialize::new(bytes.as_slice()).map_err(|e| format!("{:?}", e))?;
494 let res =
495 candid::utils::ArgumentDecoder::decode(&mut de).map_err(|e| format!("{:?}", e))?;
496 let _ = de.done();
499 Ok(res)
500 }
501
502 fn call_raw<S: Into<String>>(
503 &'static self,
504 id: Principal,
505 method: S,
506 args_raw: Vec<u8>,
507 cycles: u64,
508 ) -> CallResponse<Vec<u8>> {
509 if cycles > self.balance {
510 panic!(
511 "Calling canister {} with {} cycles when there is only {} cycles available.",
512 id, cycles, self.balance
513 );
514 }
515
516 let method = method.into();
517 let mut_ref = self.as_mut();
518 mut_ref.balance -= cycles;
519 mut_ref.is_reply_callback_mode = true;
520
521 let mut i = 0;
522 let (res, refunded) = loop {
523 if i == self.handlers.len() {
524 panic!("No handler found to handle the data.")
525 }
526
527 let handler = &self.handlers[i];
528 i += 1;
529
530 if handler.accept(&id, &method) {
531 break handler.perform(&self.id, cycles, &id, &method, &args_raw, None);
532 }
533 };
534
535 mut_ref.cycles_refunded = refunded;
536 mut_ref.balance += refunded;
537
538 mut_ref.watcher.record_call(WatcherCall {
539 canister_id: id,
540 method_name: method,
541 args_raw,
542 cycles_sent: cycles,
543 cycles_refunded: refunded,
544 });
545
546 Box::pin(async move { res })
547 }
548
549 #[inline]
550 fn set_certified_data(&self, data: &[u8]) {
551 if data.len() > 32 {
552 panic!("Data certificate has more than 32 bytes.");
553 }
554
555 let mut_ref = self.as_mut();
556 mut_ref.watcher.called_set_certified_data = true;
557 mut_ref.certificate = Some(MockContext::sign(data));
558 mut_ref.certified_data = Some(data.to_vec());
559 }
560
561 #[inline]
562 fn data_certificate(&self) -> Option<Vec<u8>> {
563 self.as_mut().watcher.called_data_certificate = true;
564 self.certificate.as_ref().cloned()
565 }
566
567 #[inline]
568 fn spawn<F: 'static + std::future::Future<Output = ()>>(&mut self, future: F) {
569 self.pool.run_until(future)
571 }
572
573 #[inline]
574 fn stable_size(&self) -> u32 {
575 (self.stable.len() >> 16) as u32
576 }
577
578 #[inline]
579 fn stable_grow(&self, new_pages: u32) -> Result<u32, StableMemoryError> {
580 let old_pages = (self.stable.len() >> 16) as u32 + 1;
581
582 if old_pages > 65536 {
583 panic!("stable storage");
584 }
585
586 let pages = old_pages + new_pages;
587 if pages > 65536 {
588 return Err(StableMemoryError());
589 }
590
591 let new_size = (pages as usize) << 16;
592 let additional = new_size - self.stable.len();
593 let stable = &mut self.as_mut().stable;
594 stable.reserve(additional);
595 for _ in 0..additional {
596 stable.push(0);
597 }
598
599 Ok(old_pages)
600 }
601
602 #[inline]
603 fn stable_write(&self, offset: u32, buf: &[u8]) {
604 let stable = &mut self.as_mut().stable;
605 for (i, &b) in buf.iter().enumerate() {
607 let index = (offset as usize) + i;
608 stable[index] = b;
609 }
610 }
611
612 #[inline]
613 fn stable_read(&self, offset: u32, mut buf: &mut [u8]) {
614 let stable = &mut self.as_mut().stable;
615 let slice = &stable[offset as usize..];
616 buf.write(slice).expect("Failed to write to buffer.");
617 }
618
619 #[inline]
620 fn with<T: 'static + Default, U, F: FnOnce(&T) -> U>(&self, callback: F) -> U {
621 self.as_mut().storage.with(callback)
622 }
623
624 #[inline]
625 fn maybe_with<T: 'static, U, F: FnOnce(&T) -> U>(&self, callback: F) -> Option<U> {
626 self.as_mut().storage.maybe_with(callback)
627 }
628
629 #[inline]
630 fn with_mut<T: 'static + Default, U, F: FnOnce(&mut T) -> U>(&self, callback: F) -> U {
631 self.as_mut()
632 .watcher
633 .storage_modified
634 .insert(TypeId::of::<T>());
635 self.as_mut().storage.with_mut(callback)
636 }
637
638 #[inline]
639 fn maybe_with_mut<T: 'static, U, F: FnOnce(&mut T) -> U>(&self, callback: F) -> Option<U> {
640 self.as_mut()
641 .watcher
642 .storage_modified
643 .insert(TypeId::of::<T>());
644 self.as_mut().storage.maybe_with_mut(callback)
645 }
646
647 #[inline]
648 fn take<T: 'static>(&self) -> Option<T> {
649 self.as_mut()
650 .watcher
651 .storage_modified
652 .insert(TypeId::of::<T>());
653 self.as_mut().storage.take()
654 }
655
656 #[inline]
657 fn swap<T: 'static>(&self, value: T) -> Option<T> {
658 self.as_mut()
659 .watcher
660 .storage_modified
661 .insert(TypeId::of::<T>());
662 self.as_mut().storage.swap(value)
663 }
664
665 #[inline]
666 fn store<T: 'static>(&self, data: T) {
667 self.as_mut()
668 .watcher
669 .storage_modified
670 .insert(TypeId::of::<T>());
671 self.as_mut().storage.swap(data);
672 }
673
674 #[inline]
675 fn get_maybe<T: 'static>(&self) -> Option<&T> {
676 self.as_mut().storage.get_maybe()
677 }
678
679 #[inline]
680 fn get<T: 'static + Default>(&self) -> &T {
681 self.as_mut().storage.get()
682 }
683
684 #[inline]
685 fn get_mut<T: 'static + Default>(&self) -> &mut T {
686 self.as_mut()
687 .watcher
688 .storage_modified
689 .insert(TypeId::of::<T>());
690 self.as_mut().storage.get_mut()
691 }
692
693 #[inline]
694 fn delete<T: 'static + Default>(&self) -> bool {
695 self.as_mut()
696 .watcher
697 .storage_modified
698 .insert(TypeId::of::<T>());
699 self.as_mut().storage.take::<T>().is_some()
700 }
701}
702
703impl Default for Watcher {
704 #[inline]
705 fn default() -> Self {
706 Watcher {
707 called_id: false,
708 called_time: false,
709 called_balance: false,
710 called_caller: false,
711 called_msg_cycles_available: false,
712 called_msg_cycles_accept: false,
713 called_msg_cycles_refunded: false,
714 called_stable_store: false,
715 called_stable_restore: false,
716 called_set_certified_data: false,
717 called_data_certificate: false,
718 storage_modified: Default::default(),
719 calls: Vec::with_capacity(3),
720 }
721 }
722}
723
724impl Watcher {
725 #[inline]
727 pub fn record_call(&mut self, call: WatcherCall) {
728 self.calls.push(call);
729 }
730
731 #[inline]
733 pub fn call_count(&self) -> usize {
734 self.calls.len()
735 }
736
737 #[inline]
739 pub fn cycles_consumed(&self) -> u64 {
740 let mut result = 0;
741 for call in &self.calls {
742 result += call.cycles_consumed();
743 }
744 result
745 }
746
747 #[inline]
749 pub fn cycles_refunded(&self) -> u64 {
750 let mut result = 0;
751 for call in &self.calls {
752 result += call.cycles_refunded();
753 }
754 result
755 }
756
757 #[inline]
760 pub fn cycles_sent(&self) -> u64 {
761 let mut result = 0;
762 for call in &self.calls {
763 result += call.cycles_sent();
764 }
765 result
766 }
767
768 #[inline]
770 pub fn get_call(&self, n: usize) -> &WatcherCall {
771 &self.calls[n]
772 }
773
774 #[inline]
776 pub fn is_method_called(&self, method_name: &str) -> bool {
777 for call in &self.calls {
778 if call.method_name() == method_name {
779 return true;
780 }
781 }
782 false
783 }
784
785 #[inline]
787 pub fn is_canister_called(&self, canister_id: &Principal) -> bool {
788 for call in &self.calls {
789 if &call.canister_id() == canister_id {
790 return true;
791 }
792 }
793 false
794 }
795
796 #[inline]
798 pub fn is_called(&self, canister_id: &Principal, method_name: &str) -> bool {
799 for call in &self.calls {
800 if &call.canister_id() == canister_id && call.method_name() == method_name {
801 return true;
802 }
803 }
804 false
805 }
806
807 #[inline]
813 pub fn is_modified<T: 'static>(&self) -> bool {
814 let type_id = std::any::TypeId::of::<T>();
815 self.storage_modified.contains(&type_id)
816 }
817}
818
819impl WatcherCall {
820 #[inline]
822 pub fn cycles_consumed(&self) -> u64 {
823 self.cycles_sent - self.cycles_refunded
824 }
825
826 #[inline]
828 pub fn cycles_sent(&self) -> u64 {
829 self.cycles_sent
830 }
831
832 #[inline]
834 pub fn cycles_refunded(&self) -> u64 {
835 self.cycles_refunded
836 }
837
838 #[inline]
840 pub fn args<T: for<'de> ArgumentDecoder<'de>>(&self) -> T {
841 decode_args(&self.args_raw).expect("Failed to decode arguments.")
842 }
843
844 #[inline]
846 pub fn method_name(&self) -> &str {
847 &self.method_name
848 }
849
850 #[inline]
852 pub fn canister_id(&self) -> Principal {
853 self.canister_id
854 }
855}
856
#[cfg(test)]
mod tests {
    use crate::Principal;
    use crate::{Context, MockContext};

    /// A miniature canister whose entry points all go through the `ic` facade, so each test
    /// can drive it against the `MockContext` it injected.
    mod canister {
        use std::collections::BTreeMap;

        use crate::ic;
        use crate::interfaces::management::WithCanisterId;
        use crate::interfaces::*;
        use crate::Principal;

        pub fn whoami() -> Principal {
            ic::caller()
        }

        pub fn canister_id() -> Principal {
            ic::id()
        }

        pub fn balance() -> u64 {
            ic::balance()
        }

        pub fn msg_cycles_available() -> u64 {
            ic::msg_cycles_available()
        }

        pub fn msg_cycles_accept(cycles: u64) -> u64 {
            ic::msg_cycles_accept(cycles)
        }

        /// Per-key counters kept in canister storage.
        pub type Counter = BTreeMap<u64, i64>;

        pub fn increment(key: u64) -> i64 {
            let count = ic::get_mut::<Counter>().entry(key).or_insert(0);
            *count += 1;
            *count
        }

        pub fn decrement(key: u64) -> i64 {
            let count = ic::get_mut::<Counter>().entry(key).or_insert(0);
            *count -= 1;
            *count
        }

        /// Debit `amount` from the stored user balance and send it to `canister_id` through
        /// `deposit_cycles`. Refunded cycles are credited back to the user balance.
        pub async fn withdraw(canister_id: Principal, amount: u64) -> Result<(), String> {
            let user_balance = ic::get_mut::<u64>();

            if amount > *user_balance {
                return Err("Insufficient balance.".to_string());
            }

            *user_balance -= amount;

            match management::DepositCycles::perform_with_payment(
                Principal::management_canister(),
                (WithCanisterId { canister_id },),
                amount,
            )
            .await
            {
                Ok(()) => {
                    *user_balance += ic::msg_cycles_refunded();
                    Ok(())
                }
                Err((code, msg)) => {
                    // A failed call is expected to refund everything that was sent.
                    assert_eq!(amount, ic::msg_cycles_refunded());
                    *user_balance += amount;
                    Err(format!(
                        "An error happened during the call: {}: {}",
                        code as u8, msg
                    ))
                }
            }
        }

        pub fn user_balance() -> u64 {
            *ic::get::<u64>()
        }

        pub fn pre_upgrade() {
            let map = ic::get::<Counter>();
            ic::stable_store((map,)).expect("Failed to write to stable storage");
        }

        pub fn post_upgrade() {
            if let Ok((map,)) = ic::stable_restore() {
                ic::store::<Counter>(map);
            }
        }

        pub fn set_certified_data(data: &[u8]) {
            ic::set_certified_data(data);
        }

        pub fn data_certificate() -> Option<Vec<u8>> {
            ic::data_certificate()
        }
    }

    /// Fixed principals used as callers in the tests.
    mod users {
        use crate::Principal;

        pub fn bob() -> Principal {
            Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap()
        }

        pub fn john() -> Principal {
            Principal::from_text("hozae-racaq-aaaaa-aaaaa-c").unwrap()
        }
    }

    #[test]
    fn test_with_id() {
        let ctx = MockContext::new()
            .with_id(Principal::management_canister())
            .inject();
        let watcher = ctx.watch();

        assert_eq!(canister::canister_id(), Principal::management_canister());
        assert!(watcher.called_id);
    }

    #[test]
    fn test_balance() {
        let ctx = MockContext::new().with_balance(1000).inject();
        let watcher = ctx.watch();

        assert_eq!(canister::balance(), 1000);
        assert!(watcher.called_balance);

        ctx.update_balance(2000);
        assert_eq!(canister::balance(), 2000);
    }

    #[test]
    fn test_caller() {
        let ctx = MockContext::new().with_caller(users::john()).inject();
        let watcher = ctx.watch();

        assert_eq!(canister::whoami(), users::john());
        assert!(watcher.called_caller);

        ctx.update_caller(users::bob());
        assert_eq!(canister::whoami(), users::bob());
    }

    #[test]
    fn test_msg_cycles() {
        let ctx = MockContext::new().with_msg_cycles(1000).inject();
        let watcher = ctx.watch();

        assert_eq!(canister::msg_cycles_available(), 1000);
        assert!(watcher.called_msg_cycles_available);

        ctx.update_msg_cycles(50);
        assert_eq!(canister::msg_cycles_available(), 50);
    }

    /// Accepting moves cycles from the message into the balance, capped by what is available.
    #[test]
    fn test_msg_cycles_accept() {
        let ctx = MockContext::new()
            .with_msg_cycles(1000)
            .with_balance(240)
            .inject();
        let watcher = ctx.watch();

        assert_eq!(canister::msg_cycles_accept(100), 100);
        assert!(watcher.called_msg_cycles_accept);
        assert_eq!(ctx.msg_cycles_available(), 900);
        assert_eq!(ctx.balance(), 340);

        ctx.update_msg_cycles(50);
        assert_eq!(canister::msg_cycles_accept(100), 50);
        assert_eq!(ctx.msg_cycles_available(), 0);
        assert_eq!(ctx.balance(), 390);
    }

    #[test]
    fn test_storage_simple() {
        let ctx = MockContext::new().inject();
        let watcher = ctx.watch();
        assert_eq!(watcher.is_modified::<canister::Counter>(), false);
        assert_eq!(canister::increment(0), 1);
        assert_eq!(watcher.is_modified::<canister::Counter>(), true);
        assert_eq!(canister::increment(0), 2);
        assert_eq!(canister::increment(0), 3);
        assert_eq!(canister::increment(1), 1);
        assert_eq!(canister::decrement(0), 2);
        assert_eq!(canister::decrement(2), -1);
    }

    #[test]
    fn test_storage() {
        let ctx = MockContext::new()
            .with_data({
                let mut map = canister::Counter::default();
                map.insert(0, 12);
                map.insert(1, 17);
                map
            })
            .inject();
        assert_eq!(canister::increment(0), 13);
        assert_eq!(canister::decrement(1), 16);

        // `watch()` resets the watcher, so earlier modifications are no longer reported.
        let watcher = ctx.watch();
        assert_eq!(watcher.is_modified::<canister::Counter>(), false);
        ctx.store({
            let mut map = canister::Counter::default();
            map.insert(0, 12);
            map.insert(1, 17);
            map
        });
        assert_eq!(watcher.is_modified::<canister::Counter>(), true);

        assert_eq!(canister::increment(0), 13);
        assert_eq!(canister::decrement(1), 16);

        ctx.clear_storage();

        assert_eq!(canister::increment(0), 1);
        assert_eq!(canister::decrement(1), -1);
    }

    /// Round-trip storage through stable memory as an upgrade would.
    #[test]
    fn stable_storage() {
        let ctx = MockContext::new()
            .with_data({
                let mut map = canister::Counter::default();
                map.insert(0, 2);
                map.insert(1, 27);
                map.insert(2, 5);
                map.insert(3, 17);
                map
            })
            .inject();

        let watcher = ctx.watch();

        canister::pre_upgrade();
        assert!(watcher.called_stable_store);
        ctx.clear_storage();
        canister::post_upgrade();
        assert!(watcher.called_stable_restore);

        let counter = ctx.get::<canister::Counter>();
        let data: Vec<(u64, i64)> = counter.iter().map(|(k, v)| (*k, *v)).collect();
        assert_eq!(data, vec![(0, 2), (1, 27), (2, 5), (3, 17)]);

        assert_eq!(canister::increment(0), 3);
        assert_eq!(canister::decrement(1), 26);
    }

    #[test]
    fn certified_data() {
        let ctx = MockContext::new()
            .with_certified_data(vec![0, 1, 2, 3, 4, 5])
            .inject();
        let watcher = ctx.watch();

        assert_eq!(ctx.get_certified_data(), Some(vec![0, 1, 2, 3, 4, 5]));
        assert_eq!(
            ctx.data_certificate(),
            Some(MockContext::sign(&[0, 1, 2, 3, 4, 5]))
        );

        canister::set_certified_data(&[1, 2, 3]);
        assert_eq!(watcher.called_set_certified_data, true);
        assert_eq!(ctx.get_certified_data(), Some(vec![1, 2, 3]));
        assert_eq!(ctx.data_certificate(), Some(MockContext::sign(&[1, 2, 3])));

        canister::data_certificate();
        assert_eq!(watcher.called_data_certificate, true);
    }

    /// The handler may consume up to 200 cycles, so all 100 sent cycles are kept.
    #[async_std::test]
    async fn withdraw_accept() {
        let ctx = MockContext::new()
            .with_consume_cycles_handler(200)
            .with_data(1000u64)
            .with_balance(2000)
            .inject();
        let watcher = ctx.watch();

        assert_eq!(canister::user_balance(), 1000);

        assert_eq!(
            watcher.is_canister_called(&Principal::management_canister()),
            false
        );
        assert_eq!(watcher.is_method_called("deposit_cycles"), false);
        assert_eq!(
            watcher.is_called(&Principal::management_canister(), "deposit_cycles"),
            false
        );
        assert_eq!(watcher.cycles_consumed(), 0);

        canister::withdraw(users::bob(), 100).await.unwrap();

        assert_eq!(watcher.call_count(), 1);
        assert_eq!(
            watcher.is_canister_called(&Principal::management_canister()),
            true
        );
        assert_eq!(watcher.is_method_called("deposit_cycles"), true);
        assert_eq!(
            watcher.is_called(&Principal::management_canister(), "deposit_cycles"),
            true
        );
        assert_eq!(watcher.cycles_consumed(), 100);

        assert_eq!(canister::user_balance(), 900);
        assert_eq!(canister::balance(), 1900);
    }

    /// The handler consumes only 50 of the 100 sent cycles; the other 50 are refunded.
    #[async_std::test]
    async fn withdraw_accept_portion() {
        let ctx = MockContext::new()
            .with_consume_cycles_handler(50)
            .with_data(1000u64)
            .with_balance(2000)
            .inject();
        let watcher = ctx.watch();

        assert_eq!(canister::user_balance(), 1000);

        canister::withdraw(users::bob(), 100).await.unwrap();
        assert_eq!(watcher.cycles_sent(), 100);
        assert_eq!(watcher.cycles_consumed(), 50);
        assert_eq!(watcher.cycles_refunded(), 50);

        assert_eq!(canister::user_balance(), 950);
        assert_eq!(canister::balance(), 1950);
    }

    /// The handler consumes nothing, so balances end where they started.
    #[async_std::test]
    async fn withdraw_accept_zero() {
        let ctx = MockContext::new()
            .with_consume_cycles_handler(0)
            .with_data(1000u64)
            .with_balance(2000)
            .inject();
        let watcher = ctx.watch();

        assert_eq!(canister::user_balance(), 1000);

        canister::withdraw(users::bob(), 100).await.unwrap();
        assert_eq!(watcher.cycles_sent(), 100);
        assert_eq!(watcher.cycles_consumed(), 0);
        assert_eq!(watcher.cycles_refunded(), 100);

        assert_eq!(canister::user_balance(), 1000);
        assert_eq!(canister::balance(), 2000);
    }

    /// The handler refunds a fixed 30 cycles out of the 100 sent.
    #[async_std::test]
    async fn with_refund() {
        let ctx = MockContext::new()
            .with_refund_cycles_handler(30)
            .with_data(1000u64)
            .with_balance(2000)
            .inject();
        let watcher = ctx.watch();

        canister::withdraw(users::bob(), 100).await.unwrap();
        assert_eq!(watcher.cycles_sent(), 100);
        assert_eq!(watcher.cycles_consumed(), 70);
        assert_eq!(watcher.cycles_refunded(), 30);
        assert_eq!(canister::user_balance(), 930);
        assert_eq!(canister::balance(), 1930);
    }

    #[test]
    #[should_panic]
    fn trap_should_panic() {
        let ctx = MockContext::new().inject();
        ctx.trap("Unreachable");
    }

    /// Certified data is limited to 32 bytes.
    #[test]
    #[should_panic]
    fn large_certificate_should_panic() {
        let ctx = MockContext::new().inject();
        let bytes = vec![0; 33];
        ctx.set_certified_data(bytes.as_slice());
    }
}