1use async_trait::async_trait;
2use candid::{decode_args, decode_one, utils::ArgumentDecoder, CandidType};
3use ic_agent::{
4 agent::{CallResponse, UpdateBuilder},
5 export::Principal,
6 Agent, AgentError,
7};
8use serde::de::DeserializeOwned;
9use std::fmt;
10use std::future::{Future, IntoFuture};
11use std::marker::PhantomData;
12use std::pin::Pin;
13
14mod expiry;
15pub use expiry::Expiry;
16
17#[cfg_attr(target_family = "wasm", async_trait(?Send))]
19#[cfg_attr(not(target_family = "wasm"), async_trait)]
20pub trait SyncCall: CallIntoFuture<Output = Result<Self::Value, AgentError>> {
21 type Value: for<'de> ArgumentDecoder<'de> + Send;
23 #[cfg(feature = "raw")]
25 async fn call_raw(self) -> Result<Vec<u8>, AgentError>;
26
27 async fn call(self) -> Result<Self::Value, AgentError>
30 where
31 Self: Sized + Send,
32 Self::Value: 'async_trait;
33}
34
35#[cfg_attr(target_family = "wasm", async_trait(?Send))]
42#[cfg_attr(not(target_family = "wasm"), async_trait)]
43pub trait AsyncCall: CallIntoFuture<Output = Result<Self::Value, AgentError>> {
44 type Value: for<'de> ArgumentDecoder<'de> + Send;
46 async fn call(self) -> Result<CallResponse<Self::Value>, AgentError>;
56
57 async fn call_and_wait(self) -> Result<Self::Value, AgentError>;
60
61 fn and_then<'a, Out2, R, AndThen>(
118 self,
119 and_then: AndThen,
120 ) -> AndThenAsyncCaller<'a, Self::Value, Out2, Self, R, AndThen>
121 where
122 Self: Sized + Send + 'a,
123 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
124 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
125 AndThen: Send + Fn(Self::Value) -> R + 'a,
126 {
127 AndThenAsyncCaller::new(self, and_then)
128 }
129
130 fn map<'a, Out, Map>(self, map: Map) -> MappedAsyncCaller<'a, Self::Value, Out, Self, Map>
132 where
133 Self: Sized + Send + 'a,
134 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
135 Map: Send + Fn(Self::Value) -> Out + 'a,
136 {
137 MappedAsyncCaller::new(self, map)
138 }
139}
140
141#[cfg(target_family = "wasm")]
142pub(crate) type CallFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, AgentError>> + 'a>>;
143#[cfg(not(target_family = "wasm"))]
144pub(crate) type CallFuture<'a, T> =
145 Pin<Box<dyn Future<Output = Result<T, AgentError>> + Send + 'a>>;
/// `IntoFuture` whose produced future is additionally `Send` (non-`wasm` targets only).
///
/// Used as the supertrait of the call traits so that awaiting a call yields a `Send`
/// future; it is blanket-implemented below and never needs a manual impl.
#[cfg(not(target_family = "wasm"))]
#[doc(hidden)]
pub trait CallIntoFuture: IntoFuture<IntoFuture = <Self as CallIntoFuture>::IntoFuture> {
    /// The future produced by `into_future`, constrained to be `Send`.
    type IntoFuture: Future<Output = Self::Output> + Send;
}
/// Blanket impl: every `IntoFuture` whose future is `Send` is a `CallIntoFuture`.
#[cfg(not(target_family = "wasm"))]
impl<T: IntoFuture + ?Sized> CallIntoFuture for T
where
    <T as IntoFuture>::IntoFuture: Send,
{
    type IntoFuture = <T as IntoFuture>::IntoFuture;
}
// On `wasm` no `Send` requirement is imposed, so the name is a plain alias of `IntoFuture`.
#[cfg(target_family = "wasm")]
use IntoFuture as CallIntoFuture;
161
162#[derive(Debug)]
164pub struct SyncCaller<'agent, Out>
165where
166 Out: for<'de> ArgumentDecoder<'de> + Send,
167{
168 pub(crate) agent: &'agent Agent,
169 pub(crate) effective_canister_id: Principal,
170 pub(crate) canister_id: Principal,
171 pub(crate) method_name: String,
172 pub(crate) arg: Result<Vec<u8>, AgentError>,
173 pub(crate) expiry: Expiry,
174 pub(crate) phantom_out: PhantomData<Out>,
175}
176
177impl<'agent, Out> SyncCaller<'agent, Out>
178where
179 Out: for<'de> ArgumentDecoder<'de> + Send,
180{
181 async fn call_raw(self) -> Result<Vec<u8>, AgentError> {
183 let mut builder = self.agent.query(&self.canister_id, &self.method_name);
184 builder = self.expiry.apply_to_query(builder);
185 builder
186 .with_arg(self.arg?)
187 .with_effective_canister_id(self.effective_canister_id)
188 .call()
189 .await
190 }
191}
192
193#[cfg_attr(target_family = "wasm", async_trait(?Send))]
194#[cfg_attr(not(target_family = "wasm"), async_trait)]
195impl<'agent, Out> SyncCall for SyncCaller<'agent, Out>
196where
197 Self: Sized,
198 Out: 'agent + for<'de> ArgumentDecoder<'de> + Send,
199{
200 type Value = Out;
201 #[cfg(feature = "raw")]
202 async fn call_raw(self) -> Result<Vec<u8>, AgentError> {
203 Ok(self.call_raw().await?)
204 }
205
206 async fn call(self) -> Result<Out, AgentError> {
207 let result = self.call_raw().await?;
208
209 decode_args(&result).map_err(|e| AgentError::CandidError(Box::new(e)))
210 }
211}
212
213impl<'agent, Out> IntoFuture for SyncCaller<'agent, Out>
214where
215 Self: Sized,
216 Out: 'agent + for<'de> ArgumentDecoder<'de> + Send,
217{
218 type IntoFuture = CallFuture<'agent, Out>;
219 type Output = Result<Out, AgentError>;
220 fn into_future(self) -> Self::IntoFuture {
221 SyncCall::call(self)
222 }
223}
224
225#[derive(Debug)]
227pub struct AsyncCaller<'agent, Out>
228where
229 Out: for<'de> ArgumentDecoder<'de> + Send,
230{
231 pub(crate) agent: &'agent Agent,
232 pub(crate) effective_canister_id: Principal,
233 pub(crate) canister_id: Principal,
234 pub(crate) method_name: String,
235 pub(crate) arg: Result<Vec<u8>, AgentError>,
236 pub(crate) expiry: Expiry,
237 pub(crate) phantom_out: PhantomData<Out>,
238}
239
240impl<'agent, Out> AsyncCaller<'agent, Out>
241where
242 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
243{
244 pub fn build_call(self) -> Result<UpdateBuilder<'agent>, AgentError> {
247 let mut builder = self.agent.update(&self.canister_id, &self.method_name);
248 builder = self.expiry.apply_to_update(builder);
249 builder = builder
250 .with_arg(self.arg?)
251 .with_effective_canister_id(self.effective_canister_id);
252 Ok(builder)
253 }
254
255 pub async fn call(self) -> Result<CallResponse<Out>, AgentError> {
257 let response_bytes = match self.build_call()?.call().await? {
258 CallResponse::Response((response_bytes, _)) => response_bytes,
259 CallResponse::Poll(request_id) => return Ok(CallResponse::Poll(request_id)),
260 };
261
262 let decoded_response =
263 decode_args(&response_bytes).map_err(|e| AgentError::CandidError(Box::new(e)))?;
264
265 Ok(CallResponse::Response(decoded_response))
266 }
267
268 pub async fn call_and_wait(self) -> Result<Out, AgentError> {
270 self.build_call()?
271 .call_and_wait()
272 .await
273 .and_then(|r| decode_args(&r).map_err(|e| AgentError::CandidError(Box::new(e))))
274 }
275
276 pub async fn call_and_wait_one<T>(self) -> Result<T, AgentError>
278 where
279 T: DeserializeOwned + CandidType,
280 {
281 self.build_call()?
282 .call_and_wait()
283 .await
284 .and_then(|r| decode_one(&r).map_err(|e| AgentError::CandidError(Box::new(e))))
285 }
286
287 pub fn map<Out2, Map>(self, map: Map) -> MappedAsyncCaller<'agent, Out, Out2, Self, Map>
289 where
290 Out2: for<'de> ArgumentDecoder<'de> + Send,
291 Map: Send + Fn(Out) -> Out2,
292 {
293 MappedAsyncCaller::new(self, map)
294 }
295}
296
297#[cfg_attr(target_family = "wasm", async_trait(?Send))]
298#[cfg_attr(not(target_family = "wasm"), async_trait)]
299impl<'agent, Out> AsyncCall for AsyncCaller<'agent, Out>
300where
301 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
302{
303 type Value = Out;
304 async fn call(self) -> Result<CallResponse<Out>, AgentError> {
305 self.call().await
306 }
307 async fn call_and_wait(self) -> Result<Out, AgentError> {
308 self.call_and_wait().await
309 }
310}
311
312impl<'agent, Out> IntoFuture for AsyncCaller<'agent, Out>
313where
314 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
315{
316 type IntoFuture = CallFuture<'agent, Out>;
317 type Output = Result<Out, AgentError>;
318 fn into_future(self) -> Self::IntoFuture {
319 AsyncCall::call_and_wait(self)
320 }
321}
322
323pub struct AndThenAsyncCaller<
327 'a,
328 Out: for<'de> ArgumentDecoder<'de> + Send,
329 Out2: for<'de> ArgumentDecoder<'de> + Send,
330 Inner: AsyncCall<Value = Out> + Send + 'a,
331 R: Future<Output = Result<Out2, AgentError>> + Send,
332 AndThen: Send + Fn(Out) -> R,
333> {
334 inner: Inner,
335 and_then: AndThen,
336 _out: PhantomData<Out>,
337 _out2: PhantomData<Out2>,
338 _lifetime: PhantomData<&'a ()>,
339}
340
341impl<'a, Out, Out2, Inner, R, AndThen> fmt::Debug
342 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
343where
344 Out: for<'de> ArgumentDecoder<'de> + Send,
345 Out2: for<'de> ArgumentDecoder<'de> + Send,
346 Inner: AsyncCall<Value = Out> + Send + fmt::Debug + 'a,
347 R: Future<Output = Result<Out2, AgentError>> + Send,
348 AndThen: Send + Fn(Out) -> R + fmt::Debug,
349{
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 f.debug_struct("AndThenAsyncCaller")
352 .field("inner", &self.inner)
353 .field("and_then", &self.and_then)
354 .field("_out", &self._out)
355 .field("_out2", &self._out2)
356 .finish()
357 }
358}
359
360impl<'a, Out, Out2, Inner, R, AndThen> AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
361where
362 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
363 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
364 Inner: AsyncCall<Value = Out> + Send + 'a,
365 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
366 AndThen: Send + Fn(Out) -> R + 'a,
367{
368 pub fn new(inner: Inner, and_then: AndThen) -> Self {
370 Self {
371 inner,
372 and_then,
373 _out: PhantomData,
374 _out2: PhantomData,
375 _lifetime: PhantomData,
376 }
377 }
378
379 pub async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
381 let raw_response = self.inner.call().await?;
382
383 let response = match raw_response {
384 CallResponse::Response(response_bytes) => {
385 let mapped_response = (self.and_then)(response_bytes);
386 CallResponse::Response(mapped_response.await?)
387 }
388 CallResponse::Poll(request_id) => CallResponse::Poll(request_id),
389 };
390
391 Ok(response)
392 }
393 pub async fn call_and_wait(self) -> Result<Out2, AgentError> {
395 let v = self.inner.call_and_wait().await?;
396
397 let f = (self.and_then)(v);
398
399 f.await
400 }
401
402 pub fn and_then<Out3, R2, AndThen2>(
404 self,
405 and_then: AndThen2,
406 ) -> AndThenAsyncCaller<'a, Out2, Out3, Self, R2, AndThen2>
407 where
408 Out3: for<'de> ArgumentDecoder<'de> + Send + 'a,
409 R2: Future<Output = Result<Out3, AgentError>> + Send + 'a,
410 AndThen2: Send + Fn(Out2) -> R2 + 'a,
411 {
412 AndThenAsyncCaller::new(self, and_then)
413 }
414
415 pub fn map<Out3, Map>(self, map: Map) -> MappedAsyncCaller<'a, Out2, Out3, Self, Map>
417 where
418 Out3: for<'de> ArgumentDecoder<'de> + Send,
419 Map: Send + Fn(Out2) -> Out3,
420 {
421 MappedAsyncCaller::new(self, map)
422 }
423}
424
425#[cfg_attr(target_family = "wasm", async_trait(?Send))]
426#[cfg_attr(not(target_family = "wasm"), async_trait)]
427impl<'a, Out, Out2, Inner, R, AndThen> AsyncCall
428 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
429where
430 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
431 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
432 Inner: AsyncCall<Value = Out> + Send + 'a,
433 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
434 AndThen: Send + Fn(Out) -> R + 'a,
435{
436 type Value = Out2;
437
438 async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
439 self.call().await
440 }
441
442 async fn call_and_wait(self) -> Result<Out2, AgentError> {
443 self.call_and_wait().await
444 }
445}
446
447impl<'a, Out, Out2, Inner, R, AndThen> IntoFuture
448 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
449where
450 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
451 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
452 Inner: AsyncCall<Value = Out> + Send + 'a,
453 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
454 AndThen: Send + Fn(Out) -> R + 'a,
455{
456 type IntoFuture = CallFuture<'a, Out2>;
457 type Output = Result<Out2, AgentError>;
458 fn into_future(self) -> Self::IntoFuture {
459 AsyncCall::call_and_wait(self)
460 }
461}
462
463pub struct MappedAsyncCaller<
466 'a,
467 Out: for<'de> ArgumentDecoder<'de> + Send,
468 Out2: for<'de> ArgumentDecoder<'de> + Send,
469 Inner: AsyncCall<Value = Out> + Send + 'a,
470 Map: Send + Fn(Out) -> Out2,
471> {
472 inner: Inner,
473 map: Map,
474 _out: PhantomData<Out>,
475 _out2: PhantomData<Out2>,
476 _lifetime: PhantomData<&'a ()>,
477}
478
479impl<'a, Out, Out2, Inner, Map> fmt::Debug for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
480where
481 Out: for<'de> ArgumentDecoder<'de> + Send,
482 Out2: for<'de> ArgumentDecoder<'de> + Send,
483 Inner: AsyncCall<Value = Out> + Send + fmt::Debug + 'a,
484 Map: Send + Fn(Out) -> Out2 + fmt::Debug,
485{
486 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
487 f.debug_struct("MappedAsyncCaller")
488 .field("inner", &self.inner)
489 .field("map", &self.map)
490 .field("_out", &self._out)
491 .field("_out2", &self._out2)
492 .finish()
493 }
494}
495
496impl<'a, Out, Out2, Inner, Map> MappedAsyncCaller<'a, Out, Out2, Inner, Map>
497where
498 Out: for<'de> ArgumentDecoder<'de> + Send,
499 Out2: for<'de> ArgumentDecoder<'de> + Send,
500 Inner: AsyncCall<Value = Out> + Send + 'a,
501 Map: Send + Fn(Out) -> Out2,
502{
503 pub fn new(inner: Inner, map: Map) -> Self {
505 Self {
506 inner,
507 map,
508 _out: PhantomData,
509 _out2: PhantomData,
510 _lifetime: PhantomData,
511 }
512 }
513
514 pub async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
516 self.inner.call().await.map(|response| match response {
517 CallResponse::Response(response_bytes) => {
518 let mapped_response = (self.map)(response_bytes);
519 CallResponse::Response(mapped_response)
520 }
521 CallResponse::Poll(request_id) => CallResponse::Poll(request_id),
522 })
523 }
524
525 pub async fn call_and_wait(self) -> Result<Out2, AgentError> {
527 let v = self.inner.call_and_wait().await?;
528 Ok((self.map)(v))
529 }
530
531 pub fn and_then<Out3, R2, AndThen2>(
533 self,
534 and_then: AndThen2,
535 ) -> AndThenAsyncCaller<'a, Out2, Out3, Self, R2, AndThen2>
536 where
537 Out3: for<'de> ArgumentDecoder<'de> + Send + 'a,
538 R2: Future<Output = Result<Out3, AgentError>> + Send + 'a,
539 AndThen2: Send + Fn(Out2) -> R2 + 'a,
540 {
541 AndThenAsyncCaller::new(self, and_then)
542 }
543
544 pub fn map<Out3, Map2>(self, map: Map2) -> MappedAsyncCaller<'a, Out2, Out3, Self, Map2>
546 where
547 Out3: for<'de> ArgumentDecoder<'de> + Send,
548 Map2: Send + Fn(Out2) -> Out3,
549 {
550 MappedAsyncCaller::new(self, map)
551 }
552}
553
554#[cfg_attr(target_family = "wasm", async_trait(?Send))]
555#[cfg_attr(not(target_family = "wasm"), async_trait)]
556impl<'a, Out, Out2, Inner, Map> AsyncCall for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
557where
558 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
559 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
560 Inner: AsyncCall<Value = Out> + Send + 'a,
561 Map: Send + Fn(Out) -> Out2 + 'a,
562{
563 type Value = Out2;
564
565 async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
566 self.call().await
567 }
568
569 async fn call_and_wait(self) -> Result<Out2, AgentError> {
570 self.call_and_wait().await
571 }
572}
573
574impl<'a, Out, Out2, Inner, Map> IntoFuture for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
575where
576 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
577 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
578 Inner: AsyncCall<Value = Out> + Send + 'a,
579 Map: Send + Fn(Out) -> Out2 + 'a,
580{
581 type IntoFuture = CallFuture<'a, Out2>;
582 type Output = Result<Out2, AgentError>;
583
584 fn into_future(self) -> Self::IntoFuture {
585 AsyncCall::call_and_wait(self)
586 }
587}