ic_kit/
handler.rs

1//! Create mock handlers for simulating inter-canister calls.
2
3use std::cell::{Ref, RefCell};
4use std::collections::hash_map::Entry;
5use std::collections::HashMap;
6
7use ic_cdk::api::call::CallResult;
8use ic_cdk::export::candid::utils::{ArgumentDecoder, ArgumentEncoder};
9use ic_cdk::export::candid::{decode_args, encode_args};
10
11use crate::candid::CandidType;
12use crate::{Context, MockContext, Principal};
13
14/// Anything that could be used to simulate a inter-canister call.
15pub trait CallHandler {
16    /// Whatever the handler can handle the given call or not, if this method returns false, we
17    /// skip this handler and try to find the next handler that can handle the call.
18    fn accept(&self, canister_id: &Principal, method: &str) -> bool;
19
20    /// Perform the call using this handler. Only called if `accept()` first returned true.
21    fn perform(
22        &self,
23        caller: &Principal,
24        cycles: u64,
25        canister_id: &Principal,
26        method: &str,
27        args_raw: &Vec<u8>,
28        ctx: Option<&mut MockContext>,
29    ) -> (CallResult<Vec<u8>>, u64);
30}
31
32/// A method that is constructed using nested calls.
33pub struct Method {
34    /// An optional name for the method.
35    name: Option<String>,
36    /// The sub-commands that should be executed by the method.
37    atoms: Vec<MethodAtom>,
38    /// If set we assert that the arguments passed to the method are this value.
39    expected_args: Option<Vec<u8>>,
40    /// If set we assert the number of cycles sent to the canister.
41    expected_cycles: Option<u64>,
42    /// The response that we send back from the caller. By default `()` is returned.
43    response: Option<Vec<u8>>,
44}
45
/// One cycle operation of a [`Method`]; a method applies its operations in insertion order.
enum MethodAtom {
    /// Accept every cycle still available to the call.
    ConsumeAllCycles,
    /// Accept up to the given number of cycles.
    ConsumeCycles(u64),
    /// Accept all available cycles except the given number, which remain available for the
    /// refund.
    RefundCycles(u64),
}
51
52/// A method which uses Rust closures to handle the calls, it accepts every call.
53pub struct RawHandler {
54    handler: Box<dyn Fn(&mut MockContext, &Vec<u8>, &Principal, &str) -> CallResult<Vec<u8>>>,
55}
56
57/// Can be used to represent a canister and different method on the canister.
58pub struct Canister {
59    /// ID of the canister, makes the CallHandler skip the call to this canister if it's trying
60    /// to make a call to a canister with different id.
61    id: Principal,
62    /// Implementation of the methods on this canister.
63    methods: HashMap<String, Box<dyn CallHandler>>,
64    /// The default callback which can be called if the method was not found on this canister.
65    default: Option<Box<dyn CallHandler>>,
66    /// The context used in this canister.
67    context: RefCell<MockContext>,
68}
69
70impl Method {
71    /// Create a new method.
72    #[inline]
73    pub const fn new() -> Self {
74        Method {
75            name: None,
76            atoms: Vec::new(),
77            expected_args: None,
78            expected_cycles: None,
79            response: None,
80        }
81    }
82
83    /// Put a name for the method. Setting a name on the method makes the CallHandler for this
84    /// method skip this method if it's trying to make a call to a method with a different name.
85    ///
86    /// # Panics
87    /// If the method already has a name.
88    #[inline]
89    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
90        if self.name.is_some() {
91            panic!("Method already has a name.");
92        }
93
94        self.name = Some(name.into());
95        self
96    }
97
98    /// Make the method consume all of the cycles provided to it.
99    #[inline]
100    pub fn cycles_consume_all(mut self) -> Self {
101        self.atoms.push(MethodAtom::ConsumeAllCycles);
102        self
103    }
104
105    /// Make the method consume at most the given amount of cycles.
106    #[inline]
107    pub fn cycles_consume(mut self, cycles: u64) -> Self {
108        self.atoms.push(MethodAtom::ConsumeCycles(cycles));
109        self
110    }
111
112    /// Make the method refund the given amount of cycles.
113    #[inline]
114    pub fn cycles_refund(mut self, cycles: u64) -> Self {
115        self.atoms.push(MethodAtom::RefundCycles(cycles));
116        self
117    }
118
119    /// Make the method expect the given value as the argument, this method makes the method
120    /// panic if it's called with an argument other than what is provided.
121    ///
122    /// # Panics
123    /// If called more than once.
124    #[inline]
125    pub fn expect_arguments<T: ArgumentEncoder>(mut self, arguments: T) -> Self {
126        if self.expected_args.is_some() {
127            panic!("expect_arguments can only be called once on a method.");
128        }
129        self.expected_args = Some(encode_args(arguments).expect("Cannot encode arguments."));
130        self
131    }
132
133    /// Create a method that expects this amount of cycles to be sent to it.
134    ///
135    /// # Panics
136    /// If called more than once on a method.
137    pub fn expect_cycles(mut self, cycles: u64) -> Self {
138        if self.expected_cycles.is_some() {
139            panic!("expect_cycles can only be called once on a method.");
140        }
141        self.expected_cycles = Some(cycles);
142        self
143    }
144
145    /// Make the method return the given constant value every time.
146    ///
147    /// # Panics
148    /// If called more than once.
149    #[inline]
150    pub fn response<T: CandidType>(mut self, value: T) -> Self {
151        if self.response.is_some() {
152            panic!("response can only be called once on a method.");
153        }
154        self.response = Some(encode_args((value,)).expect("Failed to encode response."));
155        self
156    }
157}
158
159impl Canister {
160    /// Create a new canister with the given principal id, this handler rejects any call to a
161    /// different canister id.
162    #[inline]
163    pub fn new(id: Principal) -> Self {
164        let context = MockContext::new().with_id(id);
165
166        Canister {
167            id,
168            methods: HashMap::new(),
169            default: None,
170            context: RefCell::new(context),
171        }
172    }
173
174    /// Return a reference to the context associated with this canister.
175    #[inline]
176    pub fn context(&self) -> Ref<'_, MockContext> {
177        self.context.borrow()
178    }
179
180    /// Update the balance of this canister.
181    #[inline]
182    pub fn with_balance(self, cycles: u64) -> Self {
183        self.context.borrow_mut().update_balance(cycles);
184        self
185    }
186
187    /// Add the given method to the canister.
188    ///
189    /// # Panics
190    /// If a method with the same name is already defined on the canister.
191    #[inline]
192    pub fn method<S: Into<String> + Copy>(
193        mut self,
194        name: S,
195        handler: Box<dyn CallHandler>,
196    ) -> Self {
197        if let Entry::Vacant(o) = self.methods.entry(name.into()) {
198            o.insert(handler);
199            self
200        } else {
201            panic!(
202                "Method {} already exists on canister {}",
203                name.into(),
204                &self.id
205            );
206        }
207    }
208
209    /// Add a default handler to the canister.
210    ///
211    /// # Panics
212    /// If a default handler is already set.
213    #[inline]
214    pub fn or(mut self, handler: Box<dyn CallHandler>) -> Self {
215        if self.default.is_some() {
216            panic!("Default handler is already set for canister {}", self.id);
217        }
218        self.default = Some(handler);
219        self
220    }
221}
222
223impl RawHandler {
224    /// Create a raw handler.
225    #[inline]
226    pub fn raw(
227        handler: Box<dyn Fn(&mut MockContext, &Vec<u8>, &Principal, &str) -> CallResult<Vec<u8>>>,
228    ) -> Self {
229        Self { handler }
230    }
231
232    /// Create a new handler.
233    #[inline]
234    pub fn new<
235        T: for<'de> ArgumentDecoder<'de>,
236        R: ArgumentEncoder,
237        F: 'static + Fn(&mut MockContext, T, &Principal, &str) -> CallResult<R>,
238    >(
239        handler: F,
240    ) -> Self {
241        Self {
242            handler: Box::new(move |ctx, bytes, canister_id, method_name| {
243                let args = decode_args(bytes).expect("Failed to decode arguments.");
244                handler(ctx, args, canister_id, method_name)
245                    .map(|r| encode_args(r).expect("Failed to encode response."))
246            }),
247        }
248    }
249}
250
251impl CallHandler for Method {
252    #[inline]
253    fn accept(&self, _: &Principal, method: &str) -> bool {
254        if let Some(name) = &self.name {
255            name == method
256        } else {
257            true
258        }
259    }
260
261    #[inline]
262    fn perform(
263        &self,
264        _caller: &Principal,
265        cycles: u64,
266        _canister_id: &Principal,
267        _method: &str,
268        args_raw: &Vec<u8>,
269        ctx: Option<&mut MockContext>,
270    ) -> (CallResult<Vec<u8>>, u64) {
271        let mut default_ctx = MockContext::new().with_msg_cycles(cycles);
272        let ctx = ctx.unwrap_or(&mut default_ctx);
273
274        if let Some(expected_cycles) = &self.expected_cycles {
275            assert_eq!(*expected_cycles, ctx.msg_cycles_available());
276        }
277
278        if let Some(expected_args) = &self.expected_args {
279            assert_eq!(expected_args, args_raw);
280        }
281
282        for atom in &self.atoms {
283            match *atom {
284                MethodAtom::ConsumeAllCycles => {
285                    ctx.msg_cycles_accept(u64::MAX);
286                }
287                MethodAtom::ConsumeCycles(cycles) => {
288                    ctx.msg_cycles_accept(cycles);
289                }
290                MethodAtom::RefundCycles(amount) => {
291                    let cycles = ctx.msg_cycles_available();
292                    if amount > cycles {
293                        panic!(
294                            "Can not refund {} cycles when only {} cycles is available.",
295                            amount, cycles
296                        );
297                    } else {
298                        ctx.msg_cycles_accept(cycles - amount);
299                    }
300                }
301            }
302        }
303
304        let refund = ctx.msg_cycles_available();
305
306        if let Some(v) = &self.response {
307            (Ok(v.clone()), refund)
308        } else {
309            (Ok(encode_args(()).unwrap()), refund)
310        }
311    }
312}
313
314impl CallHandler for RawHandler {
315    #[inline]
316    fn accept(&self, _: &Principal, _: &str) -> bool {
317        true
318    }
319
320    #[inline]
321    fn perform(
322        &self,
323        caller: &Principal,
324        cycles: u64,
325        canister_id: &Principal,
326        method: &str,
327        args_raw: &Vec<u8>,
328        ctx: Option<&mut MockContext>,
329    ) -> (CallResult<Vec<u8>>, u64) {
330        let mut default_ctx = MockContext::new()
331            .with_caller(*caller)
332            .with_msg_cycles(cycles)
333            .with_id(*canister_id);
334        let ctx = ctx.unwrap_or(&mut default_ctx);
335
336        let handler = &self.handler;
337        let res = handler(ctx, args_raw, canister_id, method);
338
339        (res, ctx.msg_cycles_available())
340    }
341}
342
343impl CallHandler for Canister {
344    #[inline]
345    fn accept(&self, canister_id: &Principal, method: &str) -> bool {
346        &self.id == canister_id
347            && (self.default.is_some() || {
348                let maybe_handler = self.methods.get(method);
349                if let Some(handler) = maybe_handler {
350                    handler.accept(canister_id, method)
351                } else {
352                    false
353                }
354            })
355    }
356
357    #[inline]
358    fn perform(
359        &self,
360        caller: &Principal,
361        cycles: u64,
362        canister_id: &Principal,
363        method: &str,
364        args_raw: &Vec<u8>,
365        ctx: Option<&mut MockContext>,
366    ) -> (CallResult<Vec<u8>>, u64) {
367        assert!(ctx.is_none());
368
369        let mut ctx = self.context.borrow_mut();
370        ctx.update_caller(*caller);
371        ctx.update_msg_cycles(cycles);
372
373        let res = if let Some(handler) = self.methods.get(method) {
374            handler.perform(
375                caller,
376                cycles,
377                canister_id,
378                method,
379                args_raw,
380                Some(&mut ctx),
381            )
382        } else {
383            let handler = self.default.as_ref().unwrap();
384            handler.perform(
385                caller,
386                cycles,
387                canister_id,
388                method,
389                args_raw,
390                Some(&mut ctx),
391            )
392        };
393
394        assert_eq!(res.1, ctx.msg_cycles_available());
395        ctx.update_msg_cycles(0);
396        res
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// The principal used as the caller in every test.
    fn alice() -> Principal {
        Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap()
    }

    /// Call `handler` as alice, targeting `deposit` on the management canister with no
    /// external context, and return whatever `perform` returned.
    fn deposit(
        handler: &dyn CallHandler,
        cycles: u64,
        args: Vec<u8>,
    ) -> (CallResult<Vec<u8>>, u64) {
        handler.perform(
            &alice(),
            cycles,
            &Principal::management_canister(),
            "deposit",
            &args,
            None,
        )
    }

    #[test]
    #[should_panic(expected = "Method already has a name.")]
    fn method_repetitive_call_to_name() {
        Method::new().name("A").name("B");
    }

    #[test]
    fn method_name() {
        let target = Principal::management_canister();

        assert!(Method::new().accept(&target, "XXX"));

        let named = Method::new().name("deposit");
        assert!(!named.accept(&target, "XXX"));
        assert!(named.accept(&target, "deposit"));
    }

    #[test]
    fn cycles_consume_all() {
        let (_, refunded) = deposit(&Method::new(), 2000, vec![]);
        assert_eq!(refunded, 2000);

        let (_, refunded) = deposit(&Method::new().cycles_consume_all(), 2000, vec![]);
        assert_eq!(refunded, 0);
    }

    #[test]
    fn cycles_consume() {
        let (_, refunded) = deposit(&Method::new().cycles_consume(100), 2000, vec![]);
        assert_eq!(refunded, 1900);

        let method = Method::new().cycles_consume(100).cycles_consume(150);
        let (_, refunded) = deposit(&method, 2000, vec![]);
        assert_eq!(refunded, 1750);
    }

    #[test]
    #[should_panic(expected = "Can not refund 3000 cycles")]
    fn cycles_refund_panic() {
        deposit(&Method::new().cycles_refund(3000), 2000, vec![])
            .0
            .unwrap();
    }

    #[test]
    fn cycles_refund() {
        let (_, refunded) = deposit(&Method::new().cycles_refund(100), 2000, vec![]);
        assert_eq!(refunded, 100);

        let method = Method::new().cycles_refund(170).cycles_consume(50);
        let (_, refunded) = deposit(&method, 2000, vec![]);
        assert_eq!(refunded, 120);
    }

    #[test]
    #[should_panic(expected = "expect_arguments can only be called once on a method.")]
    fn method_repetitive_call_to_expect_arguments() {
        Method::new()
            .expect_arguments((12,))
            .expect_arguments((14,));
    }

    #[test]
    #[should_panic]
    fn expect_arguments_panic() {
        let method = Method::new().expect_arguments((15u64,));
        deposit(&method, 2000, encode_args((17u64,)).unwrap())
            .0
            .unwrap();
    }

    #[test]
    fn expect_arguments() {
        let method = Method::new().expect_arguments((17u64,));
        deposit(&method, 2000, encode_args((17u64,)).unwrap())
            .0
            .unwrap();
    }

    #[test]
    fn canister_routes_calls_by_id_and_method() {
        let bank = Principal::management_canister();
        let canister =
            Canister::new(bank).method("deposit", Box::new(Method::new().cycles_consume(100)));

        assert!(canister.accept(&bank, "deposit"));
        assert!(!canister.accept(&bank, "withdraw"));
        assert!(!canister.accept(&alice(), "deposit"));

        let (_, refunded) = deposit(&canister, 2000, vec![]);
        assert_eq!(refunded, 1900);
    }

    #[test]
    fn canister_default_handler() {
        let bank = Principal::management_canister();
        let canister = Canister::new(bank).or(Box::new(Method::new().cycles_consume_all()));

        assert!(canister.accept(&bank, "anything"));

        let (_, refunded) = deposit(&canister, 2000, vec![]);
        assert_eq!(refunded, 0);
    }
}