1use async_trait::async_trait;
2use candid::{decode_args, decode_one, utils::ArgumentDecoder, CandidType};
3use ic_agent::{
4 agent::{CallResponse, UpdateBuilder},
5 export::Principal,
6 Agent, AgentError,
7};
8use serde::de::DeserializeOwned;
9use std::fmt;
10use std::future::{Future, IntoFuture};
11use std::marker::PhantomData;
12use std::pin::Pin;
13
14mod expiry;
15pub use expiry::Expiry;
16
17#[cfg_attr(target_family = "wasm", async_trait(?Send))]
19#[cfg_attr(not(target_family = "wasm"), async_trait)]
20pub trait SyncCall: CallIntoFuture<Output = Result<Self::Value, AgentError>> {
21 type Value: for<'de> ArgumentDecoder<'de> + Send;
23 #[cfg(feature = "raw")]
25 async fn call_raw(self) -> Result<Vec<u8>, AgentError>;
26
27 async fn call(self) -> Result<Self::Value, AgentError>
30 where
31 Self: Sized + Send,
32 Self::Value: 'async_trait;
33}
34
35#[cfg_attr(target_family = "wasm", async_trait(?Send))]
42#[cfg_attr(not(target_family = "wasm"), async_trait)]
43pub trait AsyncCall: CallIntoFuture<Output = Result<Self::Value, AgentError>> {
44 type Value: for<'de> ArgumentDecoder<'de> + Send;
46 async fn call(self) -> Result<CallResponse<Self::Value>, AgentError>;
56
57 async fn call_and_wait(self) -> Result<Self::Value, AgentError>;
60
61 #[cfg_attr(unix, doc = " ```rust")] #[cfg_attr(not(unix), doc = " ```ignore")]
66 fn and_then<'a, Out2, R, AndThen>(
119 self,
120 and_then: AndThen,
121 ) -> AndThenAsyncCaller<'a, Self::Value, Out2, Self, R, AndThen>
122 where
123 Self: Sized + Send + 'a,
124 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
125 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
126 AndThen: Send + Fn(Self::Value) -> R + 'a,
127 {
128 AndThenAsyncCaller::new(self, and_then)
129 }
130
131 fn map<'a, Out, Map>(self, map: Map) -> MappedAsyncCaller<'a, Self::Value, Out, Self, Map>
133 where
134 Self: Sized + Send + 'a,
135 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
136 Map: Send + Fn(Self::Value) -> Out + 'a,
137 {
138 MappedAsyncCaller::new(self, map)
139 }
140}
141
142#[cfg(target_family = "wasm")]
143pub(crate) type CallFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, AgentError>> + 'a>>;
144#[cfg(not(target_family = "wasm"))]
145pub(crate) type CallFuture<'a, T> =
146 Pin<Box<dyn Future<Output = Result<T, AgentError>> + Send + 'a>>;
// On wasm targets no `Send` requirement is imposed, so the alias is plain `IntoFuture`.
#[cfg(target_family = "wasm")]
use std::future::IntoFuture as CallIntoFuture;

/// An [`IntoFuture`] whose produced future is additionally required to be `Send`.
///
/// Blanket-implemented for every type whose `IntoFuture::IntoFuture` is `Send`.
#[cfg(not(target_family = "wasm"))]
#[doc(hidden)]
pub trait CallIntoFuture: IntoFuture<IntoFuture = <Self as CallIntoFuture>::IntoFuture> {
    /// The same future type as `IntoFuture::IntoFuture`, constrained to be `Send`.
    type IntoFuture: Future<Output = Self::Output> + Send;
}

#[cfg(not(target_family = "wasm"))]
impl<T: IntoFuture + ?Sized> CallIntoFuture for T
where
    <T as IntoFuture>::IntoFuture: Send,
{
    type IntoFuture = <T as IntoFuture>::IntoFuture;
}
162
163#[derive(Debug)]
165pub struct SyncCaller<'agent, Out>
166where
167 Out: for<'de> ArgumentDecoder<'de> + Send,
168{
169 pub(crate) agent: &'agent Agent,
170 pub(crate) effective_canister_id: Principal,
171 pub(crate) canister_id: Principal,
172 pub(crate) method_name: String,
173 pub(crate) arg: Result<Vec<u8>, AgentError>,
174 pub(crate) expiry: Expiry,
175 pub(crate) phantom_out: PhantomData<Out>,
176}
177
178impl<'agent, Out> SyncCaller<'agent, Out>
179where
180 Out: for<'de> ArgumentDecoder<'de> + Send,
181{
182 async fn call_raw(self) -> Result<Vec<u8>, AgentError> {
184 let mut builder = self.agent.query(&self.canister_id, &self.method_name);
185 builder = self.expiry.apply_to_query(builder);
186 builder
187 .with_arg(self.arg?)
188 .with_effective_canister_id(self.effective_canister_id)
189 .call()
190 .await
191 }
192}
193
194#[cfg_attr(target_family = "wasm", async_trait(?Send))]
195#[cfg_attr(not(target_family = "wasm"), async_trait)]
196impl<'agent, Out> SyncCall for SyncCaller<'agent, Out>
197where
198 Self: Sized,
199 Out: 'agent + for<'de> ArgumentDecoder<'de> + Send,
200{
201 type Value = Out;
202 #[cfg(feature = "raw")]
203 async fn call_raw(self) -> Result<Vec<u8>, AgentError> {
204 Ok(self.call_raw().await?)
205 }
206
207 async fn call(self) -> Result<Out, AgentError> {
208 let result = self.call_raw().await?;
209
210 decode_args(&result).map_err(|e| AgentError::CandidError(Box::new(e)))
211 }
212}
213
214impl<'agent, Out> IntoFuture for SyncCaller<'agent, Out>
215where
216 Self: Sized,
217 Out: 'agent + for<'de> ArgumentDecoder<'de> + Send,
218{
219 type IntoFuture = CallFuture<'agent, Out>;
220 type Output = Result<Out, AgentError>;
221 fn into_future(self) -> Self::IntoFuture {
222 SyncCall::call(self)
223 }
224}
225
226#[derive(Debug)]
228pub struct AsyncCaller<'agent, Out>
229where
230 Out: for<'de> ArgumentDecoder<'de> + Send,
231{
232 pub(crate) agent: &'agent Agent,
233 pub(crate) effective_canister_id: Principal,
234 pub(crate) canister_id: Principal,
235 pub(crate) method_name: String,
236 pub(crate) arg: Result<Vec<u8>, AgentError>,
237 pub(crate) expiry: Expiry,
238 pub(crate) phantom_out: PhantomData<Out>,
239}
240
241impl<'agent, Out> AsyncCaller<'agent, Out>
242where
243 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
244{
245 pub fn build_call(self) -> Result<UpdateBuilder<'agent>, AgentError> {
248 let mut builder = self.agent.update(&self.canister_id, &self.method_name);
249 builder = self.expiry.apply_to_update(builder);
250 builder = builder
251 .with_arg(self.arg?)
252 .with_effective_canister_id(self.effective_canister_id);
253 Ok(builder)
254 }
255
256 pub async fn call(self) -> Result<CallResponse<Out>, AgentError> {
258 let response_bytes = match self.build_call()?.call().await? {
259 CallResponse::Response((response_bytes, _)) => response_bytes,
260 CallResponse::Poll(request_id) => return Ok(CallResponse::Poll(request_id)),
261 };
262
263 let decoded_response =
264 decode_args(&response_bytes).map_err(|e| AgentError::CandidError(Box::new(e)))?;
265
266 Ok(CallResponse::Response(decoded_response))
267 }
268
269 pub async fn call_and_wait(self) -> Result<Out, AgentError> {
271 self.build_call()?
272 .call_and_wait()
273 .await
274 .and_then(|r| decode_args(&r).map_err(|e| AgentError::CandidError(Box::new(e))))
275 }
276
277 pub async fn call_and_wait_one<T>(self) -> Result<T, AgentError>
279 where
280 T: DeserializeOwned + CandidType,
281 {
282 self.build_call()?
283 .call_and_wait()
284 .await
285 .and_then(|r| decode_one(&r).map_err(|e| AgentError::CandidError(Box::new(e))))
286 }
287
288 pub fn map<Out2, Map>(self, map: Map) -> MappedAsyncCaller<'agent, Out, Out2, Self, Map>
290 where
291 Out2: for<'de> ArgumentDecoder<'de> + Send,
292 Map: Send + Fn(Out) -> Out2,
293 {
294 MappedAsyncCaller::new(self, map)
295 }
296}
297
298#[cfg_attr(target_family = "wasm", async_trait(?Send))]
299#[cfg_attr(not(target_family = "wasm"), async_trait)]
300impl<'agent, Out> AsyncCall for AsyncCaller<'agent, Out>
301where
302 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
303{
304 type Value = Out;
305 async fn call(self) -> Result<CallResponse<Out>, AgentError> {
306 self.call().await
307 }
308 async fn call_and_wait(self) -> Result<Out, AgentError> {
309 self.call_and_wait().await
310 }
311}
312
313impl<'agent, Out> IntoFuture for AsyncCaller<'agent, Out>
314where
315 Out: for<'de> ArgumentDecoder<'de> + Send + 'agent,
316{
317 type IntoFuture = CallFuture<'agent, Out>;
318 type Output = Result<Out, AgentError>;
319 fn into_future(self) -> Self::IntoFuture {
320 AsyncCall::call_and_wait(self)
321 }
322}
323
324pub struct AndThenAsyncCaller<
328 'a,
329 Out: for<'de> ArgumentDecoder<'de> + Send,
330 Out2: for<'de> ArgumentDecoder<'de> + Send,
331 Inner: AsyncCall<Value = Out> + Send + 'a,
332 R: Future<Output = Result<Out2, AgentError>> + Send,
333 AndThen: Send + Fn(Out) -> R,
334> {
335 inner: Inner,
336 and_then: AndThen,
337 _out: PhantomData<Out>,
338 _out2: PhantomData<Out2>,
339 _lifetime: PhantomData<&'a ()>,
340}
341
342impl<'a, Out, Out2, Inner, R, AndThen> fmt::Debug
343 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
344where
345 Out: for<'de> ArgumentDecoder<'de> + Send,
346 Out2: for<'de> ArgumentDecoder<'de> + Send,
347 Inner: AsyncCall<Value = Out> + Send + fmt::Debug + 'a,
348 R: Future<Output = Result<Out2, AgentError>> + Send,
349 AndThen: Send + Fn(Out) -> R + fmt::Debug,
350{
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 f.debug_struct("AndThenAsyncCaller")
353 .field("inner", &self.inner)
354 .field("and_then", &self.and_then)
355 .field("_out", &self._out)
356 .field("_out2", &self._out2)
357 .finish()
358 }
359}
360
361impl<'a, Out, Out2, Inner, R, AndThen> AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
362where
363 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
364 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
365 Inner: AsyncCall<Value = Out> + Send + 'a,
366 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
367 AndThen: Send + Fn(Out) -> R + 'a,
368{
369 pub fn new(inner: Inner, and_then: AndThen) -> Self {
371 Self {
372 inner,
373 and_then,
374 _out: PhantomData,
375 _out2: PhantomData,
376 _lifetime: PhantomData,
377 }
378 }
379
380 pub async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
382 let raw_response = self.inner.call().await?;
383
384 let response = match raw_response {
385 CallResponse::Response(response_bytes) => {
386 let mapped_response = (self.and_then)(response_bytes);
387 CallResponse::Response(mapped_response.await?)
388 }
389 CallResponse::Poll(request_id) => CallResponse::Poll(request_id),
390 };
391
392 Ok(response)
393 }
394 pub async fn call_and_wait(self) -> Result<Out2, AgentError> {
396 let v = self.inner.call_and_wait().await?;
397
398 let f = (self.and_then)(v);
399
400 f.await
401 }
402
403 pub fn and_then<Out3, R2, AndThen2>(
405 self,
406 and_then: AndThen2,
407 ) -> AndThenAsyncCaller<'a, Out2, Out3, Self, R2, AndThen2>
408 where
409 Out3: for<'de> ArgumentDecoder<'de> + Send + 'a,
410 R2: Future<Output = Result<Out3, AgentError>> + Send + 'a,
411 AndThen2: Send + Fn(Out2) -> R2 + 'a,
412 {
413 AndThenAsyncCaller::new(self, and_then)
414 }
415
416 pub fn map<Out3, Map>(self, map: Map) -> MappedAsyncCaller<'a, Out2, Out3, Self, Map>
418 where
419 Out3: for<'de> ArgumentDecoder<'de> + Send,
420 Map: Send + Fn(Out2) -> Out3,
421 {
422 MappedAsyncCaller::new(self, map)
423 }
424}
425
426#[cfg_attr(target_family = "wasm", async_trait(?Send))]
427#[cfg_attr(not(target_family = "wasm"), async_trait)]
428impl<'a, Out, Out2, Inner, R, AndThen> AsyncCall
429 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
430where
431 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
432 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
433 Inner: AsyncCall<Value = Out> + Send + 'a,
434 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
435 AndThen: Send + Fn(Out) -> R + 'a,
436{
437 type Value = Out2;
438
439 async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
440 self.call().await
441 }
442
443 async fn call_and_wait(self) -> Result<Out2, AgentError> {
444 self.call_and_wait().await
445 }
446}
447
448impl<'a, Out, Out2, Inner, R, AndThen> IntoFuture
449 for AndThenAsyncCaller<'a, Out, Out2, Inner, R, AndThen>
450where
451 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
452 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
453 Inner: AsyncCall<Value = Out> + Send + 'a,
454 R: Future<Output = Result<Out2, AgentError>> + Send + 'a,
455 AndThen: Send + Fn(Out) -> R + 'a,
456{
457 type IntoFuture = CallFuture<'a, Out2>;
458 type Output = Result<Out2, AgentError>;
459 fn into_future(self) -> Self::IntoFuture {
460 AsyncCall::call_and_wait(self)
461 }
462}
463
464pub struct MappedAsyncCaller<
467 'a,
468 Out: for<'de> ArgumentDecoder<'de> + Send,
469 Out2: for<'de> ArgumentDecoder<'de> + Send,
470 Inner: AsyncCall<Value = Out> + Send + 'a,
471 Map: Send + Fn(Out) -> Out2,
472> {
473 inner: Inner,
474 map: Map,
475 _out: PhantomData<Out>,
476 _out2: PhantomData<Out2>,
477 _lifetime: PhantomData<&'a ()>,
478}
479
480impl<'a, Out, Out2, Inner, Map> fmt::Debug for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
481where
482 Out: for<'de> ArgumentDecoder<'de> + Send,
483 Out2: for<'de> ArgumentDecoder<'de> + Send,
484 Inner: AsyncCall<Value = Out> + Send + fmt::Debug + 'a,
485 Map: Send + Fn(Out) -> Out2 + fmt::Debug,
486{
487 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
488 f.debug_struct("MappedAsyncCaller")
489 .field("inner", &self.inner)
490 .field("map", &self.map)
491 .field("_out", &self._out)
492 .field("_out2", &self._out2)
493 .finish()
494 }
495}
496
497impl<'a, Out, Out2, Inner, Map> MappedAsyncCaller<'a, Out, Out2, Inner, Map>
498where
499 Out: for<'de> ArgumentDecoder<'de> + Send,
500 Out2: for<'de> ArgumentDecoder<'de> + Send,
501 Inner: AsyncCall<Value = Out> + Send + 'a,
502 Map: Send + Fn(Out) -> Out2,
503{
504 pub fn new(inner: Inner, map: Map) -> Self {
506 Self {
507 inner,
508 map,
509 _out: PhantomData,
510 _out2: PhantomData,
511 _lifetime: PhantomData,
512 }
513 }
514
515 pub async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
517 self.inner.call().await.map(|response| match response {
518 CallResponse::Response(response_bytes) => {
519 let mapped_response = (self.map)(response_bytes);
520 CallResponse::Response(mapped_response)
521 }
522 CallResponse::Poll(request_id) => CallResponse::Poll(request_id),
523 })
524 }
525
526 pub async fn call_and_wait(self) -> Result<Out2, AgentError> {
528 let v = self.inner.call_and_wait().await?;
529 Ok((self.map)(v))
530 }
531
532 pub fn and_then<Out3, R2, AndThen2>(
534 self,
535 and_then: AndThen2,
536 ) -> AndThenAsyncCaller<'a, Out2, Out3, Self, R2, AndThen2>
537 where
538 Out3: for<'de> ArgumentDecoder<'de> + Send + 'a,
539 R2: Future<Output = Result<Out3, AgentError>> + Send + 'a,
540 AndThen2: Send + Fn(Out2) -> R2 + 'a,
541 {
542 AndThenAsyncCaller::new(self, and_then)
543 }
544
545 pub fn map<Out3, Map2>(self, map: Map2) -> MappedAsyncCaller<'a, Out2, Out3, Self, Map2>
547 where
548 Out3: for<'de> ArgumentDecoder<'de> + Send,
549 Map2: Send + Fn(Out2) -> Out3,
550 {
551 MappedAsyncCaller::new(self, map)
552 }
553}
554
555#[cfg_attr(target_family = "wasm", async_trait(?Send))]
556#[cfg_attr(not(target_family = "wasm"), async_trait)]
557impl<'a, Out, Out2, Inner, Map> AsyncCall for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
558where
559 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
560 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
561 Inner: AsyncCall<Value = Out> + Send + 'a,
562 Map: Send + Fn(Out) -> Out2 + 'a,
563{
564 type Value = Out2;
565
566 async fn call(self) -> Result<CallResponse<Out2>, AgentError> {
567 self.call().await
568 }
569
570 async fn call_and_wait(self) -> Result<Out2, AgentError> {
571 self.call_and_wait().await
572 }
573}
574
575impl<'a, Out, Out2, Inner, Map> IntoFuture for MappedAsyncCaller<'a, Out, Out2, Inner, Map>
576where
577 Out: for<'de> ArgumentDecoder<'de> + Send + 'a,
578 Out2: for<'de> ArgumentDecoder<'de> + Send + 'a,
579 Inner: AsyncCall<Value = Out> + Send + 'a,
580 Map: Send + Fn(Out) -> Out2 + 'a,
581{
582 type IntoFuture = CallFuture<'a, Out2>;
583 type Output = Result<Out2, AgentError>;
584
585 fn into_future(self) -> Self::IntoFuture {
586 AsyncCall::call_and_wait(self)
587 }
588}