1use async_trait::async_trait;
2use candid::{decode_args, decode_one, utils::ArgumentDecoder, CandidType};
3use ic_agent::{
4 agent::{CallResponse, UpdateBuilder},
5 export::Principal,
6 Agent, AgentError,
7};
8use serde::de::DeserializeOwned;
9use std::fmt;
10use std::future::{Future, IntoFuture};
11use std::marker::PhantomData;
12use std::pin::Pin;
13
14mod expiry;
15pub use expiry::Expiry;
16
/// A call that can be executed and whose reply is decoded into [`SyncCall::Value`].
///
/// In this module it is implemented by [`SyncCaller`], which sends a query via
/// `Agent::query`. Because the trait requires
/// `CallIntoFuture<Output = Result<Self::Value, AgentError>>`, implementors can also be
/// `.await`ed directly to obtain a value of the same type that [`SyncCall::call`] returns.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait SyncCall: CallIntoFuture<Output = Result<Self::Value, AgentError>> {
    /// The type the reply is decoded into (via Candid's `ArgumentDecoder`).
    type Value: for<'de> ArgumentDecoder<'de> + Send;

    /// Executes the call and returns the reply bytes without decoding them.
    ///
    /// Only available with the `raw` feature.
    #[cfg(feature = "raw")]
    async fn call_raw(self) -> Result<Vec<u8>, AgentError>;

    /// Executes the call and decodes the reply into [`SyncCall::Value`].
    ///
    /// # Errors
    /// Returns an [`AgentError`] if the call fails. For [`SyncCaller`], a reply that
    /// cannot be decoded is reported as `AgentError::CandidError`.
    async fn call(self) -> Result<Self::Value, AgentError>
    where
        Self: Sized + Send,
        Self::Value: 'async_trait;
}
34
/// A call that is submitted and whose decoded reply may only become available later.
///
/// In this module it is implemented by [`AsyncCaller`] (which sends an update via
/// `Agent::update`) and by the [`AndThenAsyncCaller`] / [`MappedAsyncCaller`] adapters.
/// Those implementors behave like [`AsyncCall::call_and_wait`] when `.await`ed directly.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait AsyncCall: CallIntoFuture<Output = Result<Self::Value, AgentError>> {
    /// The type the reply is decoded into (via Candid's `ArgumentDecoder`).
    type Value: for<'de> ArgumentDecoder<'de> + Send;

    /// Submits the call without waiting for its final result.
    ///
    /// Returns `CallResponse::Response` with the decoded value when a reply is available
    /// immediately, or `CallResponse::Poll` carrying the request id when the result has
    /// to be fetched later.
    async fn call(self) -> Result<CallResponse<Self::Value>, AgentError>;

    /// Submits the call, waits for its result and returns the decoded value.
    async fn call_and_wait(self) -> Result<Self::Value, AgentError>;

    /// Chains an asynchronous step that receives the decoded value and resolves to a
    /// new result of type `Out2`.
    ///
    /// The step only runs once a decoded value is available. A `CallResponse::Poll` from
    /// [`AsyncCall::call`] is passed through unchanged.
    fn and_then<'a, Out2, R, AndThen>(
        self,
        and_then: AndThen,
    ) -> AndThenAsyncCaller<'a, Self::Value, Out2, Self, R, AndThen>
    where
        Self: Sized + Send + 'a,
        Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
        R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
        AndThen: Send + Fn(Self::Value) -> R + 'a,
    {
        AndThenAsyncCaller::new(self, and_then)
    }

    /// Transforms the decoded value with the synchronous function `map`.
    ///
    /// A `CallResponse::Poll` from [`AsyncCall::call`] is passed through unchanged.
    fn map<'a, Out, Map>(self, map: Map) -> MappedAsyncCaller<'a, Self::Value, Out, Self, Map>
    where
        Self: Sized + Send + 'a,
        Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
        Map: Send + Fn(Self::Value) -> Out + 'a,
    {
        MappedAsyncCaller::new(self, map)
    }
}
138
/// Boxed future type returned by the `IntoFuture` impls in this module.
///
/// On wasm targets the future is not required to be `Send`.
#[cfg(target_family = "wasm")]
pub(crate) type CallFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, AgentError>> + 'a>>;
/// Boxed future type returned by the `IntoFuture` impls in this module.
///
/// On non-wasm targets the future must be `Send`.
#[cfg(not(target_family = "wasm"))]
pub(crate) type CallFuture<'a, T> =
    Pin<Box<dyn Future<Output = Result<T, AgentError>> + Send + 'a>>;
/// An [`IntoFuture`] whose future is `Send`.
///
/// Used as a supertrait of [`SyncCall`] and [`AsyncCall`], so that `.await`ing a call
/// produces a `Send` future on non-wasm targets. It is blanket-implemented below for
/// every qualifying type and never needs to be implemented by hand.
#[cfg(not(target_family = "wasm"))]
#[doc(hidden)]
pub trait CallIntoFuture: IntoFuture<IntoFuture = <Self as CallIntoFuture>::IntoFuture> {
    /// Same as [`IntoFuture::IntoFuture`], with an added `Send` bound.
    type IntoFuture: Future<Output = Self::Output> + Send;
}
// Blanket impl: any `IntoFuture` whose future is `Send` is a `CallIntoFuture`.
#[cfg(not(target_family = "wasm"))]
impl<T> CallIntoFuture for T
where
    T: IntoFuture + ?Sized,
    T::IntoFuture: Send,
{
    type IntoFuture = T::IntoFuture;
}
// On wasm the `Send` requirement is dropped (matching the wasm `CallFuture` alias), so
// `CallIntoFuture` is simply another name for `IntoFuture`.
#[cfg(target_family = "wasm")]
use IntoFuture as CallIntoFuture;
159
/// A query call to a canister method, ready to be executed.
///
/// Run it with [`SyncCall::call`] or by `.await`ing it; the reply is decoded into `Out`.
#[derive(Debug)]
pub struct SyncCaller<'agent, Out>
where
    Out: for<'de> ArgumentDecoder<'de> + Send,
{
    /// Agent used to send the query.
    pub(crate) agent: &'agent Agent,
    /// Effective canister id attached to the request via `with_effective_canister_id`.
    pub(crate) effective_canister_id: Principal,
    /// Canister that receives the query.
    pub(crate) canister_id: Principal,
    /// Name of the canister method to invoke.
    pub(crate) method_name: String,
    /// Argument bytes passed to `with_arg`, or an error. An error stored here is returned
    /// when the call is executed, and the query is not sent.
    pub(crate) arg: Result<Vec<u8>, AgentError>,
    /// Expiry applied to the query builder before sending.
    pub(crate) expiry: Expiry,
    /// Records the reply type `Out` without storing a value of it.
    pub(crate) phantom_out: PhantomData<Out>,
}
174
175impl<'agent, Out> SyncCaller<'agent, Out>
176where
177 Out: for<'de> ArgumentDecoder<'de> + Send,
178{
179 async fn call_raw(self) -> Result<Vec<u8>, AgentError> {
181 let mut builder = self.agent.query(&self.canister_id, &self.method_name);
182 builder = self.expiry.apply_to_query(builder);
183 builder
184 .with_arg(self.arg?)
185 .with_effective_canister_id(self.effective_canister_id)
186 .call()
187 .await
188 }
189}
190
191#[cfg_attr(target_family = "wasm", async_trait(?Send))]
192#[cfg_attr(not(target_family = "wasm"), async_trait)]
193impl<'agent, Out> SyncCall for SyncCaller<'agent, Out>
194where
195 Self: Sized,
196 Out: 'agent + for<'de> ArgumentDecoder<'de> + Send,
197{
198 type Value = Out;
199 #[cfg(feature = "raw")]
200 async fn call_raw(self) -> Result<Vec<u8>, AgentError> {
201 Ok(self.call_raw().await?)
202 }
203
204 async fn call(self) -> Result<Out, AgentError> {
205 let result = self.call_raw().await?;
206
207 decode_args(&result).map_err(|e| AgentError::CandidError(Box::new(e)))
208 }
209}
210
211impl<'agent, Out> IntoFuture for SyncCaller<'agent, Out>
212where
213 Self: Sized,
214 Out: 'agent + for<'de> ArgumentDecoder<'de> + Send,
215{
216 type IntoFuture = CallFuture<'agent, Out>;
217 type Output = Result<Out, AgentError>;
218 fn into_future(self) -> Self::IntoFuture {
219 SyncCall::call(self)
220 }
221}
222
/// An update call to a canister method, ready to be submitted.
///
/// Submit it with [`AsyncCaller::call`] or [`AsyncCaller::call_and_wait`], or `.await` it.
/// Awaiting behaves like `call_and_wait`. The reply is decoded into `Out`.
#[derive(Debug)]
pub struct AsyncCaller<'agent, Out>
where
    Out: for<'de> ArgumentDecoder<'de> + Send,
{
    /// Agent used to send the update.
    pub(crate) agent: &'agent Agent,
    /// Effective canister id attached to the request via `with_effective_canister_id`.
    pub(crate) effective_canister_id: Principal,
    /// Canister that receives the update.
    pub(crate) canister_id: Principal,
    /// Name of the canister method to invoke.
    pub(crate) method_name: String,
    /// Argument bytes passed to `with_arg`, or an error. An error stored here is returned
    /// by `build_call` and nothing is sent.
    pub(crate) arg: Result<Vec<u8>, AgentError>,
    /// Expiry applied to the update builder before sending.
    pub(crate) expiry: Expiry,
    /// Records the reply type `Out` without storing a value of it.
    pub(crate) phantom_out: PhantomData<Out>,
}
237
238impl<'agent, Out> AsyncCaller<'agent, Out>
239where
240 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
241{
242 pub fn build_call(self) -> Result<UpdateBuilder<'agent>, AgentError> {
245 let mut builder = self.agent.update(&self.canister_id, &self.method_name);
246 builder = self.expiry.apply_to_update(builder);
247 builder = builder
248 .with_arg(self.arg?)
249 .with_effective_canister_id(self.effective_canister_id);
250 Ok(builder)
251 }
252
253 pub async fn call(self) -> Result<CallResponse<Out>, AgentError> {
255 let response_bytes = match self.build_call()?.call().await? {
256 CallResponse::Response((response_bytes, _)) => response_bytes,
257 CallResponse::Poll(request_id) => return Ok(CallResponse::Poll(request_id)),
258 };
259
260 let decoded_response =
261 decode_args(&response_bytes).map_err(|e| AgentError::CandidError(Box::new(e)))?;
262
263 Ok(CallResponse::Response(decoded_response))
264 }
265
266 pub async fn call_and_wait(self) -> Result<Out, AgentError> {
268 self.build_call()?
269 .call_and_wait()
270 .await
271 .and_then(|r| decode_args(&r).map_err(|e| AgentError::CandidError(Box::new(e))))
272 }
273
274 pub async fn call_and_wait_one<T>(self) -> Result<T, AgentError>
276 where
277 T: DeserializeOwned + CandidType,
278 {
279 self.build_call()?
280 .call_and_wait()
281 .await
282 .and_then(|r| decode_one(&r).map_err(|e| AgentError::CandidError(Box::new(e))))
283 }
284
285 pub fn map<Out2, Map>(self, map: Map) -> MappedAsyncCaller<'agent, Out, Out2, Self, Map>
287 where
288 Out2: for<'de> ArgumentDecoder<'de> + Send,
289 Map: Send + Fn(Out) -> Out2,
290 {
291 MappedAsyncCaller::new(self, map)
292 }
293}
294
295#[cfg_attr(target_family = "wasm", async_trait(?Send))]
296#[cfg_attr(not(target_family = "wasm"), async_trait)]
297impl<'agent, Out> AsyncCall for AsyncCaller<'agent, Out>
298where
299 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
300{
301 type Value = Out;
302 async fn call(self) -> Result<CallResponse<Out>, AgentError> {
303 self.call().await
304 }
305 async fn call_and_wait(self) -> Result<Out, AgentError> {
306 self.call_and_wait().await
307 }
308}
309
310impl<'agent, Out> IntoFuture for AsyncCaller<'agent, Out>
311where
312 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
313{
314 type IntoFuture = CallFuture<'agent, Out>;
315 type Output = Result<Out, AgentError>;
316 fn into_future(self) -> Self::IntoFuture {
317 AsyncCall::call_and_wait(self)
318 }
319}
320
/// An [`AsyncCall`] adapter that passes the decoded reply of `Inner` to the asynchronous
/// function `AndThen`, whose future resolves to `Out2`.
///
/// Created by the `and_then` methods in this module or by [`AndThenAsyncCaller::new`].
pub struct AndThenAsyncCaller<
    'a,
    Out: for<'de> ArgumentDecoder<'de> + Send,
    Out2: for<'de> ArgumentDecoder<'de> + Send,
    Inner: AsyncCall<Value = Out> + Send + 'a,
    R: Future<Output = Result<Out2, AgentError>> + Send,
    AndThen: Send + Fn(Out) -> R,
> {
    /// The wrapped call whose decoded reply is fed to `and_then`.
    inner: Inner,
    /// Continuation applied to the decoded reply of `inner`.
    and_then: AndThen,
    // Marker fields: record the generic parameters without storing values of them.
    _out: PhantomData<Out>,
    _out2: PhantomData<Out2>,
    _lifetime: PhantomData<&'a ()>,
}
338
339impl<'a, Out, Out2, Inner, R, AndThen> fmt::Debug
340 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
341where
342 Out: for<'de> ArgumentDecoder<'de> + Send,
343 Out2: for<'de> ArgumentDecoder<'de> + Send,
344 Inner: AsyncCall<Value = Out> + Send + fmt::Debug + 'a,
345 R: Future<Output = Result<Out2, AgentError>> + Send,
346 AndThen: Send + Fn(Out) -> R + fmt::Debug,
347{
348 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349 f.debug_struct("AndThenAsyncCaller")
350 .field("inner", &self.inner)
351 .field("and_then", &self.and_then)
352 .field("_out", &self._out)
353 .field("_out2", &self._out2)
354 .finish()
355 }
356}
357
358impl<'a, Out, Out2, Inner, R, AndThen> AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
359where
360 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
361 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
362 Inner: AsyncCall<Value = Out> + Send + 'a,
363 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
364 AndThen: Send + Fn(Out) -> R + 'a,
365{
366 pub fn new(inner: Inner, and_then: AndThen) -> Self {
368 Self {
369 inner,
370 and_then,
371 _out: PhantomData,
372 _out2: PhantomData,
373 _lifetime: PhantomData,
374 }
375 }
376
377 pub async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
379 let raw_response = self.inner.call().await?;
380
381 let response = match raw_response {
382 CallResponse::Response(response_bytes) => {
383 let mapped_response = (self.and_then)(response_bytes);
384 CallResponse::Response(mapped_response.await?)
385 }
386 CallResponse::Poll(request_id) => CallResponse::Poll(request_id),
387 };
388
389 Ok(response)
390 }
391 pub async fn call_and_wait(self) -> Result<Out2, AgentError> {
393 let v = self.inner.call_and_wait().await?;
394
395 let f = (self.and_then)(v);
396
397 f.await
398 }
399
400 pub fn and_then<Out3, R2, AndThen2>(
402 self,
403 and_then: AndThen2,
404 ) -> AndThenAsyncCaller<'a, Out2, Out3, Self, R2, AndThen2>
405 where
406 Out3: for<'de> ArgumentDecoder<'de> + Send + 'a,
407 R2: Future<Output = Result<Out3, AgentError>> + Send + 'a,
408 AndThen2: Send + Fn(Out2) -> R2 + 'a,
409 {
410 AndThenAsyncCaller::new(self, and_then)
411 }
412
413 pub fn map<Out3, Map>(self, map: Map) -> MappedAsyncCaller<'a, Out2, Out3, Self, Map>
415 where
416 Out3: for<'de> ArgumentDecoder<'de> + Send,
417 Map: Send + Fn(Out2) -> Out3,
418 {
419 MappedAsyncCaller::new(self, map)
420 }
421}
422
423#[cfg_attr(target_family = "wasm", async_trait(?Send))]
424#[cfg_attr(not(target_family = "wasm"), async_trait)]
425impl<'a, Out, Out2, Inner, R, AndThen> AsyncCall
426 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
427where
428 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
429 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
430 Inner: AsyncCall<Value = Out> + Send + 'a,
431 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
432 AndThen: Send + Fn(Out) -> R + 'a,
433{
434 type Value = Out2;
435
436 async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
437 self.call().await
438 }
439
440 async fn call_and_wait(self) -> Result<Out2, AgentError> {
441 self.call_and_wait().await
442 }
443}
444
445impl<'a, Out, Out2, Inner, R, AndThen> IntoFuture
446 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
447where
448 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
449 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
450 Inner: AsyncCall<Value = Out> + Send + 'a,
451 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
452 AndThen: Send + Fn(Out) -> R + 'a,
453{
454 type IntoFuture = CallFuture<'a, Out2>;
455 type Output = Result<Out2, AgentError>;
456 fn into_future(self) -> Self::IntoFuture {
457 AsyncCall::call_and_wait(self)
458 }
459}
460
/// An [`AsyncCall`] adapter that transforms the decoded reply of `Inner` with the
/// synchronous function `Map`, producing `Out2`.
///
/// Created by the `map` methods in this module or by [`MappedAsyncCaller::new`].
pub struct MappedAsyncCaller<
    'a,
    Out: for<'de> ArgumentDecoder<'de> + Send,
    Out2: for<'de> ArgumentDecoder<'de> + Send,
    Inner: AsyncCall<Value = Out> + Send + 'a,
    Map: Send + Fn(Out) -> Out2,
> {
    /// The wrapped call whose decoded reply is transformed.
    inner: Inner,
    /// Function applied to the decoded reply of `inner`.
    map: Map,
    // Marker fields: record the generic parameters without storing values of them.
    _out: PhantomData<Out>,
    _out2: PhantomData<Out2>,
    _lifetime: PhantomData<&'a ()>,
}
476
477impl<'a, Out, Out2, Inner, Map> fmt::Debug for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
478where
479 Out: for<'de> ArgumentDecoder<'de> + Send,
480 Out2: for<'de> ArgumentDecoder<'de> + Send,
481 Inner: AsyncCall<Value = Out> + Send + fmt::Debug + 'a,
482 Map: Send + Fn(Out) -> Out2 + fmt::Debug,
483{
484 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
485 f.debug_struct("MappedAsyncCaller")
486 .field("inner", &self.inner)
487 .field("map", &self.map)
488 .field("_out", &self._out)
489 .field("_out2", &self._out2)
490 .finish()
491 }
492}
493
494impl<'a, Out, Out2, Inner, Map> MappedAsyncCaller<'a, Out, Out2, Inner, Map>
495where
496 Out: for<'de> ArgumentDecoder<'de> + Send,
497 Out2: for<'de> ArgumentDecoder<'de> + Send,
498 Inner: AsyncCall<Value = Out> + Send + 'a,
499 Map: Send + Fn(Out) -> Out2,
500{
501 pub fn new(inner: Inner, map: Map) -> Self {
503 Self {
504 inner,
505 map,
506 _out: PhantomData,
507 _out2: PhantomData,
508 _lifetime: PhantomData,
509 }
510 }
511
512 pub async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
514 self.inner.call().await.map(|response| match response {
515 CallResponse::Response(response_bytes) => {
516 let mapped_response = (self.map)(response_bytes);
517 CallResponse::Response(mapped_response)
518 }
519 CallResponse::Poll(request_id) => CallResponse::Poll(request_id),
520 })
521 }
522
523 pub async fn call_and_wait(self) -> Result<Out2, AgentError> {
525 let v = self.inner.call_and_wait().await?;
526 Ok((self.map)(v))
527 }
528
529 pub fn and_then<Out3, R2, AndThen2>(
531 self,
532 and_then: AndThen2,
533 ) -> AndThenAsyncCaller<'a, Out2, Out3, Self, R2, AndThen2>
534 where
535 Out3: for<'de> ArgumentDecoder<'de> + Send + 'a,
536 R2: Future<Output = Result<Out3, AgentError>> + Send + 'a,
537 AndThen2: Send + Fn(Out2) -> R2 + 'a,
538 {
539 AndThenAsyncCaller::new(self, and_then)
540 }
541
542 pub fn map<Out3, Map2>(self, map: Map2) -> MappedAsyncCaller<'a, Out2, Out3, Self, Map2>
544 where
545 Out3: for<'de> ArgumentDecoder<'de> + Send,
546 Map2: Send + Fn(Out2) -> Out3,
547 {
548 MappedAsyncCaller::new(self, map)
549 }
550}
551
552#[cfg_attr(target_family = "wasm", async_trait(?Send))]
553#[cfg_attr(not(target_family = "wasm"), async_trait)]
554impl<'a, Out, Out2, Inner, Map> AsyncCall for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
555where
556 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
557 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
558 Inner: AsyncCall<Value = Out> + Send + 'a,
559 Map: Send + Fn(Out) -> Out2 + 'a,
560{
561 type Value = Out2;
562
563 async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
564 self.call().await
565 }
566
567 async fn call_and_wait(self) -> Result<Out2, AgentError> {
568 self.call_and_wait().await
569 }
570}
571
572impl<'a, Out, Out2, Inner, Map> IntoFuture for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
573where
574 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
575 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
576 Inner: AsyncCall<Value = Out> + Send + 'a,
577 Map: Send + Fn(Out) -> Out2 + 'a,
578{
579 type IntoFuture = CallFuture<'a, Out2>;
580 type Output = Result<Out2, AgentError>;
581
582 fn into_future(self) -> Self::IntoFuture {
583 AsyncCall::call_and_wait(self)
584 }
585}