1use std::cell::{Ref, RefCell};
4use std::collections::hash_map::Entry;
5use std::collections::HashMap;
6
7use ic_cdk::api::call::CallResult;
8use ic_cdk::export::candid::utils::{ArgumentDecoder, ArgumentEncoder};
9use ic_cdk::export::candid::{decode_args, encode_args};
10
11use crate::candid::CandidType;
12use crate::{Context, MockContext, Principal};
13
/// Something that can answer a mocked inter-canister call.
///
/// Implemented in this file by [`Method`], [`RawHandler`] and [`Canister`].
pub trait CallHandler {
    /// Returns `true` when this handler wants to serve a call to `method` on `canister_id`.
    fn accept(&self, canister_id: &Principal, method: &str) -> bool;

    /// Performs the call and returns its outcome.
    ///
    /// * `caller` – principal making the call.
    /// * `cycles` – cycles attached to the call.
    /// * `canister_id` / `method` – the call target.
    /// * `args_raw` – Candid-encoded call arguments.
    /// * `ctx` – context to run in. The handlers in this file build a fresh `MockContext`
    ///   when it is `None` (except [`Canister`], which requires `None` and uses its own).
    ///
    /// Returns the raw (Candid-encoded) result together with the number of message cycles
    /// still available after the call. The implementations here treat that number as
    /// the refund.
    fn perform(
        &self,
        caller: &Principal,
        cycles: u64,
        canister_id: &Principal,
        method: &str,
        args_raw: &Vec<u8>,
        ctx: Option<&mut MockContext>,
    ) -> (CallResult<Vec<u8>>, u64);
}
31
32pub struct Method {
34 name: Option<String>,
36 atoms: Vec<MethodAtom>,
38 expected_args: Option<Vec<u8>>,
40 expected_cycles: Option<u64>,
42 response: Option<Vec<u8>>,
44}
45
/// One scripted cycle operation applied by `Method::perform`.
enum MethodAtom {
    /// Accept every available cycle by requesting `u64::MAX` from `msg_cycles_accept`.
    /// The context caps the request, so nothing is left to refund.
    ConsumeAllCycles,
    /// Accept the given number of cycles via `msg_cycles_accept`.
    ConsumeCycles(u64),
    /// Accept everything except the given amount, so exactly that many cycles remain
    /// available for refund. Performing panics if fewer cycles than that are available.
    RefundCycles(u64),
}
51
/// A call handler backed by a closure that works on raw Candid-encoded bytes.
pub struct RawHandler {
    /// Invoked with the call context, the raw argument bytes, the target canister id and
    /// the method name. Returns the raw reply bytes wrapped in a `CallResult`.
    handler: Box<dyn Fn(&mut MockContext, &Vec<u8>, &Principal, &str) -> CallResult<Vec<u8>>>,
}
56
/// A mock canister that routes incoming calls to handlers registered per method name,
/// falling back to an optional default handler.
pub struct Canister {
    /// The canister's principal; calls are only accepted when addressed to it.
    id: Principal,
    /// Handlers keyed by the method name they were registered under.
    methods: HashMap<String, Box<dyn CallHandler>>,
    /// Fallback handler used for methods that have no entry in `methods`.
    default: Option<Box<dyn CallHandler>>,
    /// The canister's own context; mutably borrowed for the duration of each `perform`.
    context: RefCell<MockContext>,
}
69
70impl Method {
71 #[inline]
73 pub const fn new() -> Self {
74 Method {
75 name: None,
76 atoms: Vec::new(),
77 expected_args: None,
78 expected_cycles: None,
79 response: None,
80 }
81 }
82
83 #[inline]
89 pub fn name<S: Into<String>>(mut self, name: S) -> Self {
90 if self.name.is_some() {
91 panic!("Method already has a name.");
92 }
93
94 self.name = Some(name.into());
95 self
96 }
97
98 #[inline]
100 pub fn cycles_consume_all(mut self) -> Self {
101 self.atoms.push(MethodAtom::ConsumeAllCycles);
102 self
103 }
104
105 #[inline]
107 pub fn cycles_consume(mut self, cycles: u64) -> Self {
108 self.atoms.push(MethodAtom::ConsumeCycles(cycles));
109 self
110 }
111
112 #[inline]
114 pub fn cycles_refund(mut self, cycles: u64) -> Self {
115 self.atoms.push(MethodAtom::RefundCycles(cycles));
116 self
117 }
118
119 #[inline]
125 pub fn expect_arguments<T: ArgumentEncoder>(mut self, arguments: T) -> Self {
126 if self.expected_args.is_some() {
127 panic!("expect_arguments can only be called once on a method.");
128 }
129 self.expected_args = Some(encode_args(arguments).expect("Cannot encode arguments."));
130 self
131 }
132
133 pub fn expect_cycles(mut self, cycles: u64) -> Self {
138 if self.expected_cycles.is_some() {
139 panic!("expect_cycles can only be called once on a method.");
140 }
141 self.expected_cycles = Some(cycles);
142 self
143 }
144
145 #[inline]
150 pub fn response<T: CandidType>(mut self, value: T) -> Self {
151 if self.response.is_some() {
152 panic!("response can only be called once on a method.");
153 }
154 self.response = Some(encode_args((value,)).expect("Failed to encode response."));
155 self
156 }
157}
158
159impl Canister {
160 #[inline]
163 pub fn new(id: Principal) -> Self {
164 let context = MockContext::new().with_id(id);
165
166 Canister {
167 id,
168 methods: HashMap::new(),
169 default: None,
170 context: RefCell::new(context),
171 }
172 }
173
174 #[inline]
176 pub fn context(&self) -> Ref<'_, MockContext> {
177 self.context.borrow()
178 }
179
180 #[inline]
182 pub fn with_balance(self, cycles: u64) -> Self {
183 self.context.borrow_mut().update_balance(cycles);
184 self
185 }
186
187 #[inline]
192 pub fn method<S: Into<String> + Copy>(
193 mut self,
194 name: S,
195 handler: Box<dyn CallHandler>,
196 ) -> Self {
197 if let Entry::Vacant(o) = self.methods.entry(name.into()) {
198 o.insert(handler);
199 self
200 } else {
201 panic!(
202 "Method {} already exists on canister {}",
203 name.into(),
204 &self.id
205 );
206 }
207 }
208
209 #[inline]
214 pub fn or(mut self, handler: Box<dyn CallHandler>) -> Self {
215 if self.default.is_some() {
216 panic!("Default handler is already set for canister {}", self.id);
217 }
218 self.default = Some(handler);
219 self
220 }
221}
222
223impl RawHandler {
224 #[inline]
226 pub fn raw(
227 handler: Box<dyn Fn(&mut MockContext, &Vec<u8>, &Principal, &str) -> CallResult<Vec<u8>>>,
228 ) -> Self {
229 Self { handler }
230 }
231
232 #[inline]
234 pub fn new<
235 T: for<'de> ArgumentDecoder<'de>,
236 R: ArgumentEncoder,
237 F: 'static + Fn(&mut MockContext, T, &Principal, &str) -> CallResult<R>,
238 >(
239 handler: F,
240 ) -> Self {
241 Self {
242 handler: Box::new(move |ctx, bytes, canister_id, method_name| {
243 let args = decode_args(bytes).expect("Failed to decode arguments.");
244 handler(ctx, args, canister_id, method_name)
245 .map(|r| encode_args(r).expect("Failed to encode response."))
246 }),
247 }
248 }
249}
250
251impl CallHandler for Method {
252 #[inline]
253 fn accept(&self, _: &Principal, method: &str) -> bool {
254 if let Some(name) = &self.name {
255 name == method
256 } else {
257 true
258 }
259 }
260
261 #[inline]
262 fn perform(
263 &self,
264 _caller: &Principal,
265 cycles: u64,
266 _canister_id: &Principal,
267 _method: &str,
268 args_raw: &Vec<u8>,
269 ctx: Option<&mut MockContext>,
270 ) -> (CallResult<Vec<u8>>, u64) {
271 let mut default_ctx = MockContext::new().with_msg_cycles(cycles);
272 let ctx = ctx.unwrap_or(&mut default_ctx);
273
274 if let Some(expected_cycles) = &self.expected_cycles {
275 assert_eq!(*expected_cycles, ctx.msg_cycles_available());
276 }
277
278 if let Some(expected_args) = &self.expected_args {
279 assert_eq!(expected_args, args_raw);
280 }
281
282 for atom in &self.atoms {
283 match *atom {
284 MethodAtom::ConsumeAllCycles => {
285 ctx.msg_cycles_accept(u64::MAX);
286 }
287 MethodAtom::ConsumeCycles(cycles) => {
288 ctx.msg_cycles_accept(cycles);
289 }
290 MethodAtom::RefundCycles(amount) => {
291 let cycles = ctx.msg_cycles_available();
292 if amount > cycles {
293 panic!(
294 "Can not refund {} cycles when only {} cycles is available.",
295 amount, cycles
296 );
297 } else {
298 ctx.msg_cycles_accept(cycles - amount);
299 }
300 }
301 }
302 }
303
304 let refund = ctx.msg_cycles_available();
305
306 if let Some(v) = &self.response {
307 (Ok(v.clone()), refund)
308 } else {
309 (Ok(encode_args(()).unwrap()), refund)
310 }
311 }
312}
313
314impl CallHandler for RawHandler {
315 #[inline]
316 fn accept(&self, _: &Principal, _: &str) -> bool {
317 true
318 }
319
320 #[inline]
321 fn perform(
322 &self,
323 caller: &Principal,
324 cycles: u64,
325 canister_id: &Principal,
326 method: &str,
327 args_raw: &Vec<u8>,
328 ctx: Option<&mut MockContext>,
329 ) -> (CallResult<Vec<u8>>, u64) {
330 let mut default_ctx = MockContext::new()
331 .with_caller(*caller)
332 .with_msg_cycles(cycles)
333 .with_id(*canister_id);
334 let ctx = ctx.unwrap_or(&mut default_ctx);
335
336 let handler = &self.handler;
337 let res = handler(ctx, args_raw, canister_id, method);
338
339 (res, ctx.msg_cycles_available())
340 }
341}
342
343impl CallHandler for Canister {
344 #[inline]
345 fn accept(&self, canister_id: &Principal, method: &str) -> bool {
346 &self.id == canister_id
347 && (self.default.is_some() || {
348 let maybe_handler = self.methods.get(method);
349 if let Some(handler) = maybe_handler {
350 handler.accept(canister_id, method)
351 } else {
352 false
353 }
354 })
355 }
356
357 #[inline]
358 fn perform(
359 &self,
360 caller: &Principal,
361 cycles: u64,
362 canister_id: &Principal,
363 method: &str,
364 args_raw: &Vec<u8>,
365 ctx: Option<&mut MockContext>,
366 ) -> (CallResult<Vec<u8>>, u64) {
367 assert!(ctx.is_none());
368
369 let mut ctx = self.context.borrow_mut();
370 ctx.update_caller(*caller);
371 ctx.update_msg_cycles(cycles);
372
373 let res = if let Some(handler) = self.methods.get(method) {
374 handler.perform(
375 caller,
376 cycles,
377 canister_id,
378 method,
379 args_raw,
380 Some(&mut ctx),
381 )
382 } else {
383 let handler = self.default.as_ref().unwrap();
384 handler.perform(
385 caller,
386 cycles,
387 canister_id,
388 method,
389 args_raw,
390 Some(&mut ctx),
391 )
392 };
393
394 assert_eq!(res.1, ctx.msg_cycles_available());
395 ctx.update_msg_cycles(0);
396 res
397 }
398}
399
400#[cfg(test)]
401mod tests {
402 use super::*;
403
404 #[test]
405 #[should_panic]
406 fn method_repetitive_call_to_name() {
407 Method::new().name("A").name("B");
408 }
409
410 #[test]
411 fn method_name() {
412 let nameless = Method::new();
413 assert_eq!(
414 nameless.accept(&Principal::management_canister(), "XXX"),
415 true
416 );
417 let named = Method::new().name("deposit");
418 assert_eq!(
419 named.accept(&Principal::management_canister(), "XXX"),
420 false
421 );
422 assert_eq!(
423 named.accept(&Principal::management_canister(), "deposit"),
424 true
425 );
426 }
427
428 #[test]
429 fn cycles_consume_all() {
430 let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
431
432 let method = Method::new();
433 let (_, refunded) = method.perform(
434 &alice,
435 2000,
436 &Principal::management_canister(),
437 "deposit",
438 &vec![],
439 None,
440 );
441 assert_eq!(refunded, 2000);
442
443 let method = Method::new().cycles_consume_all();
444 let (_, refunded) = method.perform(
445 &alice,
446 2000,
447 &Principal::management_canister(),
448 "deposit",
449 &vec![],
450 None,
451 );
452 assert_eq!(refunded, 0);
453 }
454
455 #[test]
456 fn cycles_consume() {
457 let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
458 let method = Method::new().cycles_consume(100);
459 let (_, refunded) = method.perform(
460 &alice,
461 2000,
462 &Principal::management_canister(),
463 "deposit",
464 &vec![],
465 None,
466 );
467 assert_eq!(refunded, 1900);
468
469 let method = Method::new().cycles_consume(100).cycles_consume(150);
470 let (_, refunded) = method.perform(
471 &alice,
472 2000,
473 &Principal::management_canister(),
474 "deposit",
475 &vec![],
476 None,
477 );
478 assert_eq!(refunded, 1750);
479 }
480
481 #[test]
482 #[should_panic]
483 fn cycles_refund_panic() {
484 let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
485 let method = Method::new().cycles_refund(3000);
486 method
487 .perform(
488 &alice,
489 2000,
490 &Principal::management_canister(),
491 "deposit",
492 &vec![],
493 None,
494 )
495 .0
496 .unwrap();
497 }
498
499 #[test]
500 fn cycles_refund() {
501 let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
502 let method = Method::new().cycles_refund(100);
503 let (_, refunded) = method.perform(
504 &alice,
505 2000,
506 &Principal::management_canister(),
507 "deposit",
508 &vec![],
509 None,
510 );
511 assert_eq!(refunded, 100);
512
513 let method = Method::new().cycles_refund(170).cycles_consume(50);
514 let (_, refunded) = method.perform(
515 &alice,
516 2000,
517 &Principal::management_canister(),
518 "deposit",
519 &vec![],
520 None,
521 );
522 assert_eq!(refunded, 120);
523 }
524
525 #[test]
526 #[should_panic]
527 fn method_repetitive_call_to_expect_arguments() {
528 Method::new()
529 .expect_arguments((12,))
530 .expect_arguments((14,));
531 }
532
533 #[test]
534 #[should_panic]
535 fn expect_arguments_panic() {
536 let method = Method::new().expect_arguments((15u64,));
537 let bytes = encode_args((17u64,)).unwrap();
538 let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
539 method
540 .perform(
541 &alice,
542 2000,
543 &Principal::management_canister(),
544 "deposit",
545 &bytes,
546 None,
547 )
548 .0
549 .unwrap();
550 }
551
552 #[test]
553 fn expect_arguments() {
554 let method = Method::new().expect_arguments((17u64,));
555 let bytes = encode_args((17u64,)).unwrap();
556 let alice = Principal::from_text("ai7t5-aibaq-aaaaa-aaaaa-c").unwrap();
557 method
558 .perform(
559 &alice,
560 2000,
561 &Principal::management_canister(),
562 "deposit",
563 &bytes,
564 None,
565 )
566 .0
567 .unwrap();
568 }
569}