1use std::marker::PhantomData;
2
3use candid::{decode_one, CandidType, Principal};
4use serde::Deserialize;
5use thiserror::Error;
6
7use super::provider::{Provider, RejectResponse};
8
9#[derive(Debug, Error)]
10pub enum CallError {
11 #[error("failed to candid encode call arguments: {}", .0)]
12 ArgumentEncoding(candid::error::Error),
13 #[error("canister rejected: {}, error_code: {}", .0.reject_message, .0.error_code)]
14 Reject(RejectResponse),
15 #[error("failed to candid decode call result: {}", .0)]
16 ResultDecoding(candid::error::Error),
17}
18
/// How a canister call is submitted.
///
/// [`CallBuilder::maybe_call`] dispatches on this value: `Query` goes through the
/// provider's `query_call`, `Update` through its `update_call`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallMode {
    /// Sent with the provider's `query_call`.
    Query,
    /// Sent with the provider's `update_call`.
    Update,
}
23
24pub trait Caller {
25 type Provider: Provider;
26
27 fn call<ResultType>(
28 &self,
29 canister_id: Principal,
30 call_mode: CallMode,
31 method: &str,
32 args: Result<Vec<u8>, candid::error::Error>,
33 ) -> CallBuilder<ResultType, Self::Provider>
34 where
35 ResultType: for<'a> Deserialize<'a> + CandidType;
36}
37
38pub struct CallBuilder<R: for<'a> Deserialize<'a> + CandidType, P: Provider> {
39 pub provider: P,
40 pub canister_id: Principal,
41 pub call_mode: CallMode,
42 pub method: String,
43 pub args: Result<Vec<u8>, candid::error::Error>,
44 pub _result: PhantomData<R>,
45}
46
47impl<R: for<'a> Deserialize<'a> + CandidType, P: Provider> CallBuilder<R, P> {
48 pub fn with_caller<C: Caller>(self, caller: C) -> CallBuilder<R, C::Provider> {
49 caller.call::<R>(self.canister_id, self.call_mode, &self.method, self.args)
50 }
51
52 pub fn with_update(self) -> Self {
53 Self {
54 call_mode: CallMode::Update,
55 ..self
56 }
57 }
58
59 pub async fn maybe_call(self) -> Result<R, CallError> {
60 let args = self.args.map_err(CallError::ArgumentEncoding)?;
61
62 let result = match self.call_mode {
63 CallMode::Query => {
64 self.provider
65 .query_call(self.canister_id, &self.method, args)
66 .await
67 }
68 CallMode::Update => {
69 self.provider
70 .update_call(self.canister_id, &self.method, args)
71 .await
72 }
73 };
74
75 let reply = result.map_err(CallError::Reject)?;
76
77 decode_one(&reply).map_err(CallError::ResultDecoding)
78 }
79
80 pub async fn call(self) -> R {
81 self.maybe_call().await.unwrap()
82 }
83}