ic_kit/
mock.rs

1use futures::executor::LocalPool;
2use std::any::TypeId;
3use std::collections::BTreeSet;
4use std::hash::Hasher;
5use std::io::Write;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use ic_cdk::export::candid::decode_args;
9use ic_cdk::export::candid::utils::{ArgumentDecoder, ArgumentEncoder};
10use ic_cdk::export::{candid, Principal};
11use serde::Serialize;
12
13use crate::candid::CandidType;
14use crate::inject::{get_context, inject};
15use crate::interface::{CallResponse, Context};
16use crate::stable::StableWriter;
17use crate::storage::Storage;
18use crate::{CallHandler, Method, StableMemoryError};
19
20/// A context that could be used to fake/control the behaviour of the IC when testing the canister.
21pub struct MockContext {
22    /// The watcher on the context.
23    watcher: Watcher,
24    /// ID of the current canister.
25    id: Principal,
26    /// The balance of the canister. By default set to 100TC.
27    balance: u64,
28    /// The caller principal passed to the calls, by default `anonymous` is used.
29    caller: Principal,
30    /// Determines if a call was made or not.
31    is_reply_callback_mode: bool,
32    /// Whatever the canister called trap or not.
33    trapped: bool,
34    /// Available cycles sent by the caller.
35    cycles: u64,
36    /// Cycles refunded by the previous call.
37    cycles_refunded: u64,
38    /// The storage tree for the current context.
39    storage: Storage,
40    /// The stable storage data.
41    stable: Vec<u8>,
42    /// The certified data.
43    certified_data: Option<Vec<u8>>,
44    /// The certificate certifying the certified_data.
45    certificate: Option<Vec<u8>>,
46    /// The handlers used to handle inter-canister calls.
47    handlers: Vec<Box<dyn CallHandler>>,
48    /// All of the spawned futures.
49    pool: LocalPool,
50}
51
52/// A watcher can be used to inspect the calls made in a call.
53pub struct Watcher {
54    /// True if the `context.id()` was called during execution.
55    pub called_id: bool,
56    /// True if the `context.time()` was called during execution.
57    pub called_time: bool,
58    /// True if the `context.balance()` was called during execution.
59    pub called_balance: bool,
60    /// True if the `context.caller()` was called during execution.
61    pub called_caller: bool,
62    /// True if the `context.msg_cycles_available()` was called during execution.
63    pub called_msg_cycles_available: bool,
64    /// True if the `context.msg_cycles_accept()` was called during execution.
65    pub called_msg_cycles_accept: bool,
66    /// True if the `context.msg_cycles_refunded()` was called during execution.
67    pub called_msg_cycles_refunded: bool,
68    /// True if the `context.stable_store()` was called during execution.
69    pub called_stable_store: bool,
70    /// True if the `context.stable_restore()` was called during execution.
71    pub called_stable_restore: bool,
72    /// True if the `context.set_certified_data()` was called during execution.
73    pub called_set_certified_data: bool,
74    /// True if the `context.data_certificate()` was called during execution.
75    pub called_data_certificate: bool,
76    /// Storage items that were mutated.
77    storage_modified: BTreeSet<TypeId>,
78    /// List of all the inter-canister calls that took place.
79    calls: Vec<WatcherCall>,
80}
81
82pub struct WatcherCall {
83    canister_id: Principal,
84    method_name: String,
85    args_raw: Vec<u8>,
86    cycles_sent: u64,
87    cycles_refunded: u64,
88}
89
90impl MockContext {
91    /// Create a new mock context which could be injected for testing.
92    #[inline]
93    pub fn new() -> Self {
94        Self {
95            watcher: Watcher::default(),
96            id: Principal::from_text("sgymv-uiaaa-aaaaa-aaaia-cai").unwrap(),
97            balance: 100_000_000_000_000,
98            caller: Principal::anonymous(),
99            is_reply_callback_mode: false,
100            trapped: false,
101            cycles: 0,
102            cycles_refunded: 0,
103            storage: Storage::default(),
104            stable: Vec::new(),
105            certified_data: None,
106            certificate: None,
107            handlers: vec![],
108            pool: LocalPool::new(),
109        }
110    }
111
112    /// Reset the current watcher on the MockContext and return a reference to it.
113    #[inline]
114    pub fn watch(&self) -> &Watcher {
115        self.as_mut().watcher = Watcher::default();
116        &self.watcher
117    }
118
119    /// Set the ID of the canister.
120    ///
121    /// # Example
122    ///
123    /// ```
124    /// use ic_kit::*;
125    ///
126    /// let id = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
127    ///
128    /// MockContext::new()
129    ///     .with_id(id.clone())
130    ///     .inject();
131    ///
132    /// assert_eq!(ic::id(), id);
133    /// ```
134    #[inline]
135    pub fn with_id(mut self, id: Principal) -> Self {
136        self.id = id;
137        self
138    }
139
140    /// Set the balance of the canister.
141    ///
142    /// # Example
143    ///
144    /// ```
145    /// use ic_kit::*;
146    ///
147    /// MockContext::new()
148    ///     .with_balance(1000)
149    ///     .inject();
150    ///
151    /// assert_eq!(ic::balance(), 1000);
152    /// ```
153    #[inline]
154    pub fn with_balance(mut self, cycles: u64) -> Self {
155        self.balance = cycles;
156        self
157    }
158
159    /// Set the caller for the current call.
160    ///
161    /// # Example
162    ///
163    /// ```
164    /// use ic_kit::*;
165    ///
166    /// let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
167    ///
168    /// MockContext::new()
169    ///     .with_caller(alice.clone())
170    ///     .inject();
171    ///
172    /// assert_eq!(ic::caller(), alice);
173    /// ```
174    #[inline]
175    pub fn with_caller(mut self, caller: Principal) -> Self {
176        self.caller = caller;
177        self
178    }
179
180    /// Make the given amount of cycles available for the call. This amount of cycles will
181    /// be deduced if the call accepts them or will be refunded. If the canister accepts any
182    /// cycles the balance of the canister will be increased.
183    ///
184    /// # Example
185    ///
186    /// ```
187    /// use ic_kit::*;
188    ///
189    /// MockContext::new()
190    ///     .with_msg_cycles(1000)
191    ///     .inject();
192    ///
193    /// assert_eq!(ic::msg_cycles_available(), 1000);
194    /// ic::msg_cycles_accept(300);
195    /// assert_eq!(ic::msg_cycles_available(), 700);
196    /// ```
197    #[inline]
198    pub fn with_msg_cycles(mut self, cycles: u64) -> Self {
199        self.cycles = cycles;
200        self
201    }
202
203    /// Initialize the context with the given value inserted in the storage.
204    ///
205    /// # Example
206    ///
207    /// ```
208    /// use ic_kit::*;
209    ///
210    /// MockContext::new()
211    ///     .with_data(String::from("Hello"))
212    ///     .inject();
213    ///
214    /// assert_eq!(ic::get::<String>(), &"Hello".to_string());
215    /// ```
216    #[inline]
217    pub fn with_data<T: 'static>(mut self, data: T) -> Self {
218        self.storage.swap(data);
219        self
220    }
221
222    /// Initialize the context with the given value inserted into the stable storage.
223    ///
224    /// # Example
225    ///
226    /// ```
227    /// use ic_kit::*;
228    ///
229    /// MockContext::new()
230    ///     .with_stable(("Bella".to_string(), ))
231    ///     .inject();
232    ///
233    /// assert_eq!(ic::stable_restore::<(String, )>(), Ok(("Bella".to_string(), )));
234    /// ```
235    #[inline]
236    pub fn with_stable<T: Serialize>(mut self, data: T) -> Self
237    where
238        T: ArgumentEncoder,
239    {
240        self.stable.truncate(0);
241        candid::write_args(&mut self.stable, data).expect("Encoding stable data failed.");
242        self.stable_grow(0)
243            .expect("Can not write this amount of data to stable storage.");
244        self
245    }
246
247    /// Set the certified data of the canister.
248    #[inline]
249    pub fn with_certified_data(mut self, data: Vec<u8>) -> Self {
250        assert!(data.len() < 32);
251        self.certificate = Some(MockContext::sign(data.as_slice()));
252        self.certified_data = Some(data);
253        self
254    }
255
256    /// Creates a mock context with a default handler that accepts the given amount of cycles
257    /// on every request.
258    #[inline]
259    pub fn with_consume_cycles_handler(self, cycles: u64) -> Self {
260        self.with_handler(Method::new().cycles_consume(cycles))
261    }
262
263    /// Create a mock context with a default handler that expects this amount of cycles to
264    /// be passed to it.
265    #[inline]
266    pub fn with_expect_cycles_handler(self, cycles: u64) -> Self {
267        self.with_handler(Method::new().expect_cycles(cycles))
268    }
269
270    /// Creates a mock context with a default handler that refunds the given amount of cycles
271    /// on every request.
272    #[inline]
273    pub fn with_refund_cycles_handler(self, cycles: u64) -> Self {
274        self.with_handler(Method::new().cycles_refund(cycles))
275    }
276
277    /// Create a mock context with a default handler that returns the given value.
278    #[inline]
279    pub fn with_constant_return_handler<T: CandidType>(self, value: T) -> Self {
280        self.with_handler(Method::new().response(value))
281    }
282
283    /// Add the given handler to the handlers pipeline.
284    #[inline]
285    pub fn with_handler<T: 'static + CallHandler>(mut self, handler: T) -> Self {
286        self.use_handler(handler);
287        self
288    }
289
290    /// Use this context as the default context for this thread.
291    #[inline]
292    pub fn inject(self) -> &'static mut Self {
293        inject(self);
294        get_context()
295    }
296
297    /// Sign a data and return the certificate, this is the method used in set_certified_data
298    /// to set the data certificate for the given certified data.
299    pub fn sign(data: &[u8]) -> Vec<u8> {
300        let data = {
301            let mut tmp: Vec<u8> = vec![0; 32];
302            for (i, b) in data.iter().enumerate() {
303                tmp[i] = *b;
304            }
305            tmp
306        };
307
308        let mut certificate = Vec::with_capacity(32 * 8);
309
310        for i in 0..32 {
311            let mut hasher = std::collections::hash_map::DefaultHasher::new();
312            for b in &certificate {
313                hasher.write_u8(*b);
314            }
315            hasher.write_u8(data[i]);
316            let hash = hasher.finish().to_be_bytes();
317            certificate.extend_from_slice(&hash);
318        }
319
320        certificate
321    }
322
323    /// This is how we do interior mutability for MockContext. Since the context is only accessible
324    /// by only one thread, it is safe to do it here.
325    #[inline]
326    fn as_mut(&self) -> &mut Self {
327        unsafe {
328            let const_ptr = self as *const Self;
329            let mut_ptr = const_ptr as *mut Self;
330            &mut *mut_ptr
331        }
332    }
333}
334
335impl MockContext {
336    /// Reset the state after a call.
337    #[inline]
338    pub fn call_state_reset(&self) {
339        let mut_ref = self.as_mut();
340        mut_ref.is_reply_callback_mode = false;
341        mut_ref.trapped = false;
342    }
343
344    /// Clear the storage.
345    #[inline]
346    pub fn clear_storage(&self) {
347        self.as_mut().storage = Storage::default()
348    }
349
350    /// Update the balance of the canister.
351    #[inline]
352    pub fn update_balance(&self, cycles: u64) {
353        self.as_mut().balance = cycles;
354    }
355
356    /// Update the cycles of the next message.
357    #[inline]
358    pub fn update_msg_cycles(&self, cycles: u64) {
359        self.as_mut().cycles = cycles;
360    }
361
362    /// Update the caller for the next message.
363    #[inline]
364    pub fn update_caller(&self, caller: Principal) {
365        self.as_mut().caller = caller;
366    }
367
368    /// Update the canister id the call happens from for the next message.
369    #[inline]
370    pub fn update_id(&self, canister_id: Principal) {
371        self.as_mut().id = canister_id;
372    }
373
374    /// Return the certified data set on the canister.
375    #[inline]
376    pub fn get_certified_data(&self) -> Option<Vec<u8>> {
377        self.certified_data.as_ref().cloned()
378    }
379
380    /// Add the given handler to the call handlers pipeline.
381    #[inline]
382    pub fn use_handler<T: 'static + CallHandler>(&mut self, handler: T) {
383        self.handlers.push(Box::new(handler));
384    }
385
386    /// Remove all of the call handlers that are already registered to this context.
387    #[inline]
388    pub fn clear_handlers(&mut self) {
389        self.handlers.clear();
390    }
391
392    /// Block the current thread until all the spawned futures are complete.
393    #[inline]
394    pub fn join(&mut self) {
395        self.pool.run();
396    }
397}
398
399impl Context for MockContext {
400    #[inline]
401    fn trap(&self, message: &str) -> ! {
402        self.as_mut().trapped = true;
403        panic!("Canister {} trapped with message: {}", self.id, message);
404    }
405
406    #[inline]
407    fn print<S: AsRef<str>>(&self, s: S) {
408        println!("{} : {}", self.id, s.as_ref())
409    }
410
411    #[inline]
412    fn id(&self) -> Principal {
413        self.as_mut().watcher.called_id = true;
414        self.id
415    }
416
417    #[inline]
418    fn time(&self) -> u64 {
419        self.as_mut().watcher.called_time = true;
420        SystemTime::now()
421            .duration_since(UNIX_EPOCH)
422            .expect("Time went backwards")
423            .as_nanos() as u64
424    }
425
426    #[inline]
427    fn balance(&self) -> u64 {
428        self.as_mut().watcher.called_balance = true;
429        self.balance
430    }
431
432    #[inline]
433    fn caller(&self) -> Principal {
434        self.as_mut().watcher.called_caller = true;
435
436        if self.is_reply_callback_mode {
437            panic!(
438                "Canister {} violated contract: \"{}\" cannot be executed in reply callback mode",
439                self.id(),
440                "ic0_msg_caller_size"
441            )
442        }
443
444        self.caller
445    }
446
447    #[inline]
448    fn msg_cycles_available(&self) -> u64 {
449        self.as_mut().watcher.called_msg_cycles_available = true;
450        self.cycles
451    }
452
453    #[inline]
454    fn msg_cycles_accept(&self, cycles: u64) -> u64 {
455        self.as_mut().watcher.called_msg_cycles_accept = true;
456        let mut_ref = self.as_mut();
457        if cycles > mut_ref.cycles {
458            let r = mut_ref.cycles;
459            mut_ref.cycles = 0;
460            mut_ref.balance += r;
461            r
462        } else {
463            mut_ref.cycles -= cycles;
464            mut_ref.balance += cycles;
465            cycles
466        }
467    }
468
469    #[inline]
470    fn msg_cycles_refunded(&self) -> u64 {
471        self.as_mut().watcher.called_msg_cycles_refunded = true;
472        self.cycles_refunded
473    }
474
475    #[inline]
476    fn stable_store<T>(&self, data: T) -> Result<(), candid::Error>
477    where
478        T: ArgumentEncoder,
479    {
480        self.as_mut().watcher.called_stable_store = true;
481        candid::write_args(&mut StableWriter::default(), data)
482    }
483
484    #[inline]
485    fn stable_restore<T>(&self) -> Result<T, String>
486    where
487        T: for<'de> ArgumentDecoder<'de>,
488    {
489        self.as_mut().watcher.called_stable_restore = true;
490        let bytes = &self.stable;
491
492        let mut de =
493            candid::de::IDLDeserialize::new(bytes.as_slice()).map_err(|e| format!("{:?}", e))?;
494        let res =
495            candid::utils::ArgumentDecoder::decode(&mut de).map_err(|e| format!("{:?}", e))?;
496        // The idea here is to ignore an error that comes from Candid, because we have trailing
497        // bytes.
498        let _ = de.done();
499        Ok(res)
500    }
501
502    fn call_raw<S: Into<String>>(
503        &'static self,
504        id: Principal,
505        method: S,
506        args_raw: Vec<u8>,
507        cycles: u64,
508    ) -> CallResponse<Vec<u8>> {
509        if cycles > self.balance {
510            panic!(
511                "Calling canister {} with {} cycles when there is only {} cycles available.",
512                id, cycles, self.balance
513            );
514        }
515
516        let method = method.into();
517        let mut_ref = self.as_mut();
518        mut_ref.balance -= cycles;
519        mut_ref.is_reply_callback_mode = true;
520
521        let mut i = 0;
522        let (res, refunded) = loop {
523            if i == self.handlers.len() {
524                panic!("No handler found to handle the data.")
525            }
526
527            let handler = &self.handlers[i];
528            i += 1;
529
530            if handler.accept(&id, &method) {
531                break handler.perform(&self.id, cycles, &id, &method, &args_raw, None);
532            }
533        };
534
535        mut_ref.cycles_refunded = refunded;
536        mut_ref.balance += refunded;
537
538        mut_ref.watcher.record_call(WatcherCall {
539            canister_id: id,
540            method_name: method,
541            args_raw,
542            cycles_sent: cycles,
543            cycles_refunded: refunded,
544        });
545
546        Box::pin(async move { res })
547    }
548
549    #[inline]
550    fn set_certified_data(&self, data: &[u8]) {
551        if data.len() > 32 {
552            panic!("Data certificate has more than 32 bytes.");
553        }
554
555        let mut_ref = self.as_mut();
556        mut_ref.watcher.called_set_certified_data = true;
557        mut_ref.certificate = Some(MockContext::sign(data));
558        mut_ref.certified_data = Some(data.to_vec());
559    }
560
561    #[inline]
562    fn data_certificate(&self) -> Option<Vec<u8>> {
563        self.as_mut().watcher.called_data_certificate = true;
564        self.certificate.as_ref().cloned()
565    }
566
567    #[inline]
568    fn spawn<F: 'static + std::future::Future<Output = ()>>(&mut self, future: F) {
569        // TODO(qti3e) Setup the context in the pool thread.
570        self.pool.run_until(future)
571    }
572
573    #[inline]
574    fn stable_size(&self) -> u32 {
575        (self.stable.len() >> 16) as u32
576    }
577
578    #[inline]
579    fn stable_grow(&self, new_pages: u32) -> Result<u32, StableMemoryError> {
580        let old_pages = (self.stable.len() >> 16) as u32 + 1;
581
582        if old_pages > 65536 {
583            panic!("stable storage");
584        }
585
586        let pages = old_pages + new_pages;
587        if pages > 65536 {
588            return Err(StableMemoryError());
589        }
590
591        let new_size = (pages as usize) << 16;
592        let additional = new_size - self.stable.len();
593        let stable = &mut self.as_mut().stable;
594        stable.reserve(additional);
595        for _ in 0..additional {
596            stable.push(0);
597        }
598
599        Ok(old_pages)
600    }
601
602    #[inline]
603    fn stable_write(&self, offset: u32, buf: &[u8]) {
604        let stable = &mut self.as_mut().stable;
605        // todo(qti3e) improve this implementation using copy.
606        for (i, &b) in buf.iter().enumerate() {
607            let index = (offset as usize) + i;
608            stable[index] = b;
609        }
610    }
611
612    #[inline]
613    fn stable_read(&self, offset: u32, mut buf: &mut [u8]) {
614        let stable = &mut self.as_mut().stable;
615        let slice = &stable[offset as usize..];
616        buf.write(slice).expect("Failed to write to buffer.");
617    }
618
619    #[inline]
620    fn with<T: 'static + Default, U, F: FnOnce(&T) -> U>(&self, callback: F) -> U {
621        self.as_mut().storage.with(callback)
622    }
623
624    #[inline]
625    fn maybe_with<T: 'static, U, F: FnOnce(&T) -> U>(&self, callback: F) -> Option<U> {
626        self.as_mut().storage.maybe_with(callback)
627    }
628
629    #[inline]
630    fn with_mut<T: 'static + Default, U, F: FnOnce(&mut T) -> U>(&self, callback: F) -> U {
631        self.as_mut()
632            .watcher
633            .storage_modified
634            .insert(TypeId::of::<T>());
635        self.as_mut().storage.with_mut(callback)
636    }
637
638    #[inline]
639    fn maybe_with_mut<T: 'static, U, F: FnOnce(&mut T) -> U>(&self, callback: F) -> Option<U> {
640        self.as_mut()
641            .watcher
642            .storage_modified
643            .insert(TypeId::of::<T>());
644        self.as_mut().storage.maybe_with_mut(callback)
645    }
646
647    #[inline]
648    fn take<T: 'static>(&self) -> Option<T> {
649        self.as_mut()
650            .watcher
651            .storage_modified
652            .insert(TypeId::of::<T>());
653        self.as_mut().storage.take()
654    }
655
656    #[inline]
657    fn swap<T: 'static>(&self, value: T) -> Option<T> {
658        self.as_mut()
659            .watcher
660            .storage_modified
661            .insert(TypeId::of::<T>());
662        self.as_mut().storage.swap(value)
663    }
664
665    #[inline]
666    fn store<T: 'static>(&self, data: T) {
667        self.as_mut()
668            .watcher
669            .storage_modified
670            .insert(TypeId::of::<T>());
671        self.as_mut().storage.swap(data);
672    }
673
674    #[inline]
675    fn get_maybe<T: 'static>(&self) -> Option<&T> {
676        self.as_mut().storage.get_maybe()
677    }
678
679    #[inline]
680    fn get<T: 'static + Default>(&self) -> &T {
681        self.as_mut().storage.get()
682    }
683
684    #[inline]
685    fn get_mut<T: 'static + Default>(&self) -> &mut T {
686        self.as_mut()
687            .watcher
688            .storage_modified
689            .insert(TypeId::of::<T>());
690        self.as_mut().storage.get_mut()
691    }
692
693    #[inline]
694    fn delete<T: 'static + Default>(&self) -> bool {
695        self.as_mut()
696            .watcher
697            .storage_modified
698            .insert(TypeId::of::<T>());
699        self.as_mut().storage.take::<T>().is_some()
700    }
701}
702
703impl Default for Watcher {
704    #[inline]
705    fn default() -> Self {
706        Watcher {
707            called_id: false,
708            called_time: false,
709            called_balance: false,
710            called_caller: false,
711            called_msg_cycles_available: false,
712            called_msg_cycles_accept: false,
713            called_msg_cycles_refunded: false,
714            called_stable_store: false,
715            called_stable_restore: false,
716            called_set_certified_data: false,
717            called_data_certificate: false,
718            storage_modified: Default::default(),
719            calls: Vec::with_capacity(3),
720        }
721    }
722}
723
724impl Watcher {
725    /// Push a call to the call history of the watcher.
726    #[inline]
727    pub fn record_call(&mut self, call: WatcherCall) {
728        self.calls.push(call);
729    }
730
731    /// Return the number of calls made during the last execution.
732    #[inline]
733    pub fn call_count(&self) -> usize {
734        self.calls.len()
735    }
736
737    /// Returns the total amount of cycles consumed in inter-canister calls.
738    #[inline]
739    pub fn cycles_consumed(&self) -> u64 {
740        let mut result = 0;
741        for call in &self.calls {
742            result += call.cycles_consumed();
743        }
744        result
745    }
746
747    /// Returns the total amount of cycles refunded in inter-canister calls.
748    #[inline]
749    pub fn cycles_refunded(&self) -> u64 {
750        let mut result = 0;
751        for call in &self.calls {
752            result += call.cycles_refunded();
753        }
754        result
755    }
756
757    /// Returns the total amount of cycles sent in inter-canister calls, not deducing the refunded
758    /// amounts.
759    #[inline]
760    pub fn cycles_sent(&self) -> u64 {
761        let mut result = 0;
762        for call in &self.calls {
763            result += call.cycles_sent();
764        }
765        result
766    }
767
768    /// Return the n-th call that took place during the execution.
769    #[inline]
770    pub fn get_call(&self, n: usize) -> &WatcherCall {
771        &self.calls[n]
772    }
773
774    /// Returns true if the given method was called during the execution.
775    #[inline]
776    pub fn is_method_called(&self, method_name: &str) -> bool {
777        for call in &self.calls {
778            if call.method_name() == method_name {
779                return true;
780            }
781        }
782        false
783    }
784
785    /// Returns true if the given canister was called during the execution.
786    #[inline]
787    pub fn is_canister_called(&self, canister_id: &Principal) -> bool {
788        for call in &self.calls {
789            if &call.canister_id() == canister_id {
790                return true;
791            }
792        }
793        false
794    }
795
796    /// Returns true if the given method was called.
797    #[inline]
798    pub fn is_called(&self, canister_id: &Principal, method_name: &str) -> bool {
799        for call in &self.calls {
800            if &call.canister_id() == canister_id && call.method_name() == method_name {
801                return true;
802            }
803        }
804        false
805    }
806
807    /// Returns true if the given storage item was accessed in a mutable way during the execution.
808    /// This method tracks calls to:
809    /// - context.store()
810    /// - context.get_mut()
811    /// - context.delete()
812    #[inline]
813    pub fn is_modified<T: 'static>(&self) -> bool {
814        let type_id = std::any::TypeId::of::<T>();
815        self.storage_modified.contains(&type_id)
816    }
817}
818
819impl WatcherCall {
820    /// The amount of cycles consumed by this call.
821    #[inline]
822    pub fn cycles_consumed(&self) -> u64 {
823        self.cycles_sent - self.cycles_refunded
824    }
825
826    /// The amount of cycles sent to the call.
827    #[inline]
828    pub fn cycles_sent(&self) -> u64 {
829        self.cycles_sent
830    }
831
832    /// The amount of cycles refunded from the call.
833    #[inline]
834    pub fn cycles_refunded(&self) -> u64 {
835        self.cycles_refunded
836    }
837
838    /// Return the arguments passed to the call.
839    #[inline]
840    pub fn args<T: for<'de> ArgumentDecoder<'de>>(&self) -> T {
841        decode_args(&self.args_raw).expect("Failed to decode arguments.")
842    }
843
844    /// Name of the method that was called.
845    #[inline]
846    pub fn method_name(&self) -> &str {
847        &self.method_name
848    }
849
850    /// Canister ID that was target of the call.
851    #[inline]
852    pub fn canister_id(&self) -> Principal {
853        self.canister_id
854    }
855}
856
857#[cfg(test)]
858mod tests {
859    use crate::Principal;
860    use crate::{Context, MockContext};
861
862    /// A simple canister implementation which helps the testing.
863    mod canister {
864        use std::collections::BTreeMap;
865
866        use crate::ic;
867        use crate::interfaces::management::WithCanisterId;
868        use crate::interfaces::*;
869        use crate::Principal;
870
871        /// An update method that returns the principal id of the caller.
872        pub fn whoami() -> Principal {
873            ic::caller()
874        }
875
876        /// An update method that returns the principal id of the canister.
877        pub fn canister_id() -> Principal {
878            ic::id()
879        }
880
881        /// An update method that returns the balance of the canister.
882        pub fn balance() -> u64 {
883            ic::balance()
884        }
885
886        /// An update method that returns the number of cycles provided by the user in the call.
887        pub fn msg_cycles_available() -> u64 {
888            ic::msg_cycles_available()
889        }
890
891        /// An update method that accepts the given number of cycles from the caller, the number of
892        /// accepted cycles is returned.
893        pub fn msg_cycles_accept(cycles: u64) -> u64 {
894            ic::msg_cycles_accept(cycles)
895        }
896
897        pub type Counter = BTreeMap<u64, i64>;
898
899        /// An update method that increments one to the given key, the new value is returned.
900        pub fn increment(key: u64) -> i64 {
901            let count = ic::get_mut::<Counter>().entry(key).or_insert(0);
902            *count += 1;
903            *count
904        }
905
906        /// An update method that decrement one from the given key. The new value is returned.
907        pub fn decrement(key: u64) -> i64 {
908            let count = ic::get_mut::<Counter>().entry(key).or_insert(0);
909            *count -= 1;
910            *count
911        }
912
913        pub async fn withdraw(canister_id: Principal, amount: u64) -> Result<(), String> {
914            let user_balance = ic::get_mut::<u64>();
915
916            if amount > *user_balance {
917                return Err("Insufficient balance.".to_string());
918            }
919
920            *user_balance -= amount;
921
922            match management::DepositCycles::perform_with_payment(
923                Principal::management_canister(),
924                (WithCanisterId { canister_id },),
925                amount,
926            )
927            .await
928            {
929                Ok(()) => {
930                    *user_balance += ic::msg_cycles_refunded();
931                    Ok(())
932                }
933                Err((code, msg)) => {
934                    assert_eq!(amount, ic::msg_cycles_refunded());
935                    *user_balance += amount;
936                    Err(format!(
937                        "An error happened during the call: {}: {}",
938                        code as u8, msg
939                    ))
940                }
941            }
942        }
943
944        pub fn user_balance() -> u64 {
945            *ic::get::<u64>()
946        }
947
948        pub fn pre_upgrade() {
949            let map = ic::get::<Counter>();
950            ic::stable_store((map,)).expect("Failed to write to stable storage");
951        }
952
953        pub fn post_upgrade() {
954            if let Ok((map,)) = ic::stable_restore() {
955                ic::store::<Counter>(map);
956            }
957        }
958
959        pub fn set_certified_data(data: &[u8]) {
960            ic::set_certified_data(data);
961        }
962
963        pub fn data_certificate() -> Option<Vec<u8>> {
964            ic::data_certificate()
965        }
966    }
967
968    /// Some mock principal ids.
969    mod users {
970        use crate::Principal;
971
972        pub fn bob() -> Principal {
973            Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap()
974        }
975
976        pub fn john() -> Principal {
977            Principal::from_text("hozae-racaq-aaaaa-aaaaa-c").unwrap()
978        }
979    }
980
981    #[test]
982    fn test_with_id() {
983        let ctx = MockContext::new()
984            .with_id(Principal::management_canister())
985            .inject();
986        let watcher = ctx.watch();
987
988        assert_eq!(canister::canister_id(), Principal::management_canister());
989        assert!(watcher.called_id);
990    }
991
992    #[test]
993    fn test_balance() {
994        let ctx = MockContext::new().with_balance(1000).inject();
995        let watcher = ctx.watch();
996
997        assert_eq!(canister::balance(), 1000);
998        assert!(watcher.called_balance);
999
1000        ctx.update_balance(2000);
1001        assert_eq!(canister::balance(), 2000);
1002    }
1003
1004    #[test]
1005    fn test_caller() {
1006        let ctx = MockContext::new().with_caller(users::john()).inject();
1007        let watcher = ctx.watch();
1008
1009        assert_eq!(canister::whoami(), users::john());
1010        assert!(watcher.called_caller);
1011
1012        ctx.update_caller(users::bob());
1013        assert_eq!(canister::whoami(), users::bob());
1014    }
1015
1016    #[test]
1017    fn test_msg_cycles() {
1018        let ctx = MockContext::new().with_msg_cycles(1000).inject();
1019        let watcher = ctx.watch();
1020
1021        assert_eq!(canister::msg_cycles_available(), 1000);
1022        assert!(watcher.called_msg_cycles_available);
1023
1024        ctx.update_msg_cycles(50);
1025        assert_eq!(canister::msg_cycles_available(), 50);
1026    }
1027
1028    #[test]
1029    fn test_msg_cycles_accept() {
1030        let ctx = MockContext::new()
1031            .with_msg_cycles(1000)
1032            .with_balance(240)
1033            .inject();
1034        let watcher = ctx.watch();
1035
1036        assert_eq!(canister::msg_cycles_accept(100), 100);
1037        assert!(watcher.called_msg_cycles_accept);
1038        assert_eq!(ctx.msg_cycles_available(), 900);
1039        assert_eq!(ctx.balance(), 340);
1040
1041        ctx.update_msg_cycles(50);
1042        assert_eq!(canister::msg_cycles_accept(100), 50);
1043        assert_eq!(ctx.msg_cycles_available(), 0);
1044        assert_eq!(ctx.balance(), 390);
1045    }
1046
1047    #[test]
1048    fn test_storage_simple() {
1049        let ctx = MockContext::new().inject();
1050        let watcher = ctx.watch();
1051        assert_eq!(watcher.is_modified::<canister::Counter>(), false);
1052        assert_eq!(canister::increment(0), 1);
1053        assert_eq!(watcher.is_modified::<canister::Counter>(), true);
1054        assert_eq!(canister::increment(0), 2);
1055        assert_eq!(canister::increment(0), 3);
1056        assert_eq!(canister::increment(1), 1);
1057        assert_eq!(canister::decrement(0), 2);
1058        assert_eq!(canister::decrement(2), -1);
1059    }
1060
1061    #[test]
1062    fn test_storage() {
1063        let ctx = MockContext::new()
1064            .with_data({
1065                let mut map = canister::Counter::default();
1066                map.insert(0, 12);
1067                map.insert(1, 17);
1068                map
1069            })
1070            .inject();
1071        assert_eq!(canister::increment(0), 13);
1072        assert_eq!(canister::decrement(1), 16);
1073
1074        let watcher = ctx.watch();
1075        assert_eq!(watcher.is_modified::<canister::Counter>(), false);
1076        ctx.store({
1077            let mut map = canister::Counter::default();
1078            map.insert(0, 12);
1079            map.insert(1, 17);
1080            map
1081        });
1082        assert_eq!(watcher.is_modified::<canister::Counter>(), true);
1083
1084        assert_eq!(canister::increment(0), 13);
1085        assert_eq!(canister::decrement(1), 16);
1086
1087        ctx.clear_storage();
1088
1089        assert_eq!(canister::increment(0), 1);
1090        assert_eq!(canister::decrement(1), -1);
1091    }
1092
1093    #[test]
1094    fn stable_storage() {
1095        let ctx = MockContext::new()
1096            .with_data({
1097                let mut map = canister::Counter::default();
1098                map.insert(0, 2);
1099                map.insert(1, 27);
1100                map.insert(2, 5);
1101                map.insert(3, 17);
1102                map
1103            })
1104            .inject();
1105
1106        let watcher = ctx.watch();
1107
1108        canister::pre_upgrade();
1109        assert!(watcher.called_stable_store);
1110        ctx.clear_storage();
1111        canister::post_upgrade();
1112        assert!(watcher.called_stable_restore);
1113
1114        let counter = ctx.get::<canister::Counter>();
1115        let data: Vec<(u64, i64)> = counter.iter().map(|(k, v)| (*k, *v)).collect();
1116        assert_eq!(data, vec![(0, 2), (1, 27), (2, 5), (3, 17)]);
1117
1118        assert_eq!(canister::increment(0), 3);
1119        assert_eq!(canister::decrement(1), 26);
1120    }
1121
1122    #[test]
1123    fn certified_data() {
1124        let ctx = MockContext::new()
1125            .with_certified_data(vec![0, 1, 2, 3, 4, 5])
1126            .inject();
1127        let watcher = ctx.watch();
1128
1129        assert_eq!(ctx.get_certified_data(), Some(vec![0, 1, 2, 3, 4, 5]));
1130        assert_eq!(
1131            ctx.data_certificate(),
1132            Some(MockContext::sign(&[0, 1, 2, 3, 4, 5]))
1133        );
1134
1135        canister::set_certified_data(&[1, 2, 3]);
1136        assert_eq!(watcher.called_set_certified_data, true);
1137        assert_eq!(ctx.get_certified_data(), Some(vec![1, 2, 3]));
1138        assert_eq!(ctx.data_certificate(), Some(MockContext::sign(&[1, 2, 3])));
1139
1140        canister::data_certificate();
1141        assert_eq!(watcher.called_data_certificate, true);
1142    }
1143
1144    #[async_std::test]
1145    async fn withdraw_accept() {
1146        let ctx = MockContext::new()
1147            .with_consume_cycles_handler(200)
1148            .with_data(1000u64)
1149            .with_balance(2000)
1150            .inject();
1151        let watcher = ctx.watch();
1152
1153        assert_eq!(canister::user_balance(), 1000);
1154
1155        assert_eq!(
1156            watcher.is_canister_called(&Principal::management_canister()),
1157            false
1158        );
1159        assert_eq!(watcher.is_method_called("deposit_cycles"), false);
1160        assert_eq!(
1161            watcher.is_called(&Principal::management_canister(), "deposit_cycles"),
1162            false
1163        );
1164        assert_eq!(watcher.cycles_consumed(), 0);
1165
1166        canister::withdraw(users::bob(), 100).await.unwrap();
1167
1168        assert_eq!(watcher.call_count(), 1);
1169        assert_eq!(
1170            watcher.is_canister_called(&Principal::management_canister()),
1171            true
1172        );
1173        assert_eq!(watcher.is_method_called("deposit_cycles"), true);
1174        assert_eq!(
1175            watcher.is_called(&Principal::management_canister(), "deposit_cycles"),
1176            true
1177        );
1178        assert_eq!(watcher.cycles_consumed(), 100);
1179
1180        // The user balance needs to be decremented.
1181        assert_eq!(canister::user_balance(), 900);
1182        // The canister balance needs to be decremented.
1183        assert_eq!(canister::balance(), 1900);
1184    }
1185
1186    #[async_std::test]
1187    async fn withdraw_accept_portion() {
1188        let ctx = MockContext::new()
1189            .with_consume_cycles_handler(50)
1190            .with_data(1000u64)
1191            .with_balance(2000)
1192            .inject();
1193        let watcher = ctx.watch();
1194
1195        assert_eq!(canister::user_balance(), 1000);
1196
1197        canister::withdraw(users::bob(), 100).await.unwrap();
1198        assert_eq!(watcher.cycles_sent(), 100);
1199        assert_eq!(watcher.cycles_consumed(), 50);
1200        assert_eq!(watcher.cycles_refunded(), 50);
1201
1202        // The user balance needs to be decremented.
1203        assert_eq!(canister::user_balance(), 950);
1204        // The canister balance needs to be decremented.
1205        assert_eq!(canister::balance(), 1950);
1206    }
1207
1208    #[async_std::test]
1209    async fn withdraw_accept_zero() {
1210        let ctx = MockContext::new()
1211            .with_consume_cycles_handler(0)
1212            .with_data(1000u64)
1213            .with_balance(2000)
1214            .inject();
1215        let watcher = ctx.watch();
1216
1217        assert_eq!(canister::user_balance(), 1000);
1218
1219        canister::withdraw(users::bob(), 100).await.unwrap();
1220        assert_eq!(watcher.cycles_sent(), 100);
1221        assert_eq!(watcher.cycles_consumed(), 0);
1222        assert_eq!(watcher.cycles_refunded(), 100);
1223
1224        // The balance should not be decremented.
1225        assert_eq!(canister::user_balance(), 1000);
1226        assert_eq!(canister::balance(), 2000);
1227    }
1228
1229    #[async_std::test]
1230    async fn with_refund() {
1231        let ctx = MockContext::new()
1232            .with_refund_cycles_handler(30)
1233            .with_data(1000u64)
1234            .with_balance(2000)
1235            .inject();
1236        let watcher = ctx.watch();
1237
1238        canister::withdraw(users::bob(), 100).await.unwrap();
1239        assert_eq!(watcher.cycles_sent(), 100);
1240        assert_eq!(watcher.cycles_consumed(), 70);
1241        assert_eq!(watcher.cycles_refunded(), 30);
1242        assert_eq!(canister::user_balance(), 930);
1243        assert_eq!(canister::balance(), 1930);
1244    }
1245
1246    #[test]
1247    #[should_panic]
1248    fn trap_should_panic() {
1249        let ctx = MockContext::new().inject();
1250        ctx.trap("Unreachable");
1251    }
1252
1253    #[test]
1254    #[should_panic]
1255    fn large_certificate_should_panic() {
1256        let ctx = MockContext::new().inject();
1257        let bytes = vec![0; 33];
1258        ctx.set_certified_data(bytes.as_slice());
1259    }
1260}