crb_superagent/interplay/
fetcher.rs1use crate::supervisor::ForwardTo;
2use anyhow::{Error, Result, anyhow};
3use async_trait::async_trait;
4use crb_agent::{Address, Agent, AgentSession, Context, DoAsync, MessageFor, Next, RunAgent};
5use crb_core::{Msg, Slot, Tag};
6use crb_runtime::InterruptionLevel;
7use crb_send::{Recipient, Sender};
8use futures::channel::oneshot::{self, Canceled};
9use futures::{
10 Future,
11 task::{Context as FutContext, Poll},
12};
13use std::future::IntoFuture;
14use std::pin::Pin;
15use thiserror::Error;
16
/// A request bundled with the [`Responder`] used to answer it.
///
/// Created together with its matching [`Fetcher`] by [`Interplay::new_pair`];
/// whatever is sent through `responder` is what that fetcher resolves to.
pub struct Interplay<IN, OUT> {
    /// The request payload.
    pub request: IN,
    /// One-shot handle that delivers the answer to the paired [`Fetcher`].
    pub responder: Responder<OUT>,
}

22pub struct Responder<OUT> {
23 tx: oneshot::Sender<Result<OUT>>,
24}
25
26impl<OUT> Responder<OUT> {
27 pub fn send(self, resp: OUT) -> Result<()> {
28 self.send_result(Ok(resp))
29 }
30
31 pub fn send_result(self, resp: Result<OUT>) -> Result<()> {
32 self.tx
33 .send(resp)
34 .map_err(|_| anyhow!("Can't send the response."))
35 }
36}
37
38impl<IN, OUT> Interplay<IN, OUT> {
39 pub fn new_pair(request: IN) -> (Self, Fetcher<OUT>) {
40 let (tx, rx) = oneshot::channel();
41 let responder = Responder { tx };
42 let interplay = Interplay { request, responder };
43 let fetcher = Fetcher { rx };
44 (interplay, fetcher)
45 }
46}
47
48#[must_use]
49pub struct Fetcher<OUT> {
50 rx: oneshot::Receiver<Result<OUT>>,
51}
52
53impl<OUT> Fetcher<OUT> {
54 pub fn grasp(self, result: Result<()>) -> Self {
55 match result {
56 Ok(_) => self,
57 Err(err) => Self::spoiled(err),
58 }
59 }
60
61 pub fn spoiled(err: Error) -> Fetcher<OUT> {
62 let (tx, rx) = oneshot::channel();
63 tx.send(Err(err)).ok();
64 Fetcher { rx }
65 }
66}
67
/// Reasons a [`Fetcher`] can resolve without a successful response.
#[derive(Error, Debug)]
pub enum FetchError {
    /// The responder answered, but with an error (for example via
    /// `Responder::send_result(Err(..))` or a spoiled fetcher).
    #[error("Request failed: {0}")]
    Failed(#[from] anyhow::Error),
    /// The channel closed before any response arrived; per the `futures`
    /// oneshot docs this happens when the sending half is dropped unused.
    #[error("Request canceled: {0}")]
    Canceled(#[from] Canceled),
}

76pub type Output<R> = Result<R, FetchError>;
77
78impl<OUT> Future for Fetcher<OUT> {
79 type Output = Output<OUT>;
80 fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
81 Pin::new(&mut self.rx).poll(cx).map(|result| {
82 result
83 .map_err(FetchError::from)
84 .and_then(|res| res.map_err(FetchError::from))
85 })
86 }
87}
88
89impl<A, OUT, T> ForwardTo<A, T> for Fetcher<OUT>
90where
91 A: OnResponse<OUT, T>,
92 OUT: Msg,
93 T: Tag,
94{
95 type Runtime = RunAgent<FetcherTask<OUT, T>>;
96
97 fn into_trackable(self, address: Address<A>, tag: T) -> Self::Runtime {
98 let task = FetcherTask {
99 recipient: address.sender(),
100 fetcher: self,
101 tag: Slot::filled(tag),
102 };
103 let mut runtime = RunAgent::new(task);
104 runtime.level = InterruptionLevel::ABORT;
105 runtime
106 }
107}
108
109pub struct FetcherTask<OUT, T> {
110 recipient: Recipient<Response<OUT, T>>,
111 fetcher: Fetcher<OUT>,
112 tag: Slot<T>,
113}
114
115impl<OUT, T> Agent for FetcherTask<OUT, T>
116where
117 OUT: Msg,
118 T: Tag,
119{
120 type Context = AgentSession<Self>;
121 type Link = Address<Self>;
122
123 fn begin(&mut self) -> Next<Self> {
124 Next::do_async(())
125 }
126}
127
128#[async_trait]
129impl<OUT, T> DoAsync for FetcherTask<OUT, T>
130where
131 OUT: Msg,
132 T: Tag,
133{
134 async fn once(&mut self, _: &mut ()) -> Result<Next<Self>> {
135 let response = (&mut self.fetcher).await;
136 self.recipient.send(Response {
137 response,
138 tag: self.tag.take()?,
139 })?;
140 Ok(Next::done())
141 }
142}
143
144impl<OUT, T> IntoFuture for FetcherTask<OUT, T> {
145 type Output = Output<OUT>;
146 type IntoFuture = Fetcher<OUT>;
147
148 fn into_future(self) -> Self::IntoFuture {
149 self.fetcher
150 }
151}
152
/// Implemented by agents that forward [`Fetcher`]s and want to receive the outcome.
///
/// `T` is the tag supplied when the fetcher was forwarded (defaults to `()`).
#[async_trait]
pub trait OnResponse<OUT, T = ()>: Agent {
    /// Called when a forwarded fetch completes.
    ///
    /// `response` is `Err` if the request failed or was canceled (see
    /// [`FetchError`]); `tag` is the value given when the fetcher was forwarded.
    async fn on_response(
        &mut self,
        response: Output<OUT>,
        tag: T,
        ctx: &mut Context<Self>,
    ) -> Result<()>;
}

163struct Response<OUT, T> {
164 response: Output<OUT>,
165 tag: T,
166}
167
168#[async_trait]
169impl<A, OUT, T> MessageFor<A> for Response<OUT, T>
170where
171 A: OnResponse<OUT, T>,
172 OUT: Msg,
173 T: Tag,
174{
175 async fn handle(self: Box<Self>, agent: &mut A, ctx: &mut Context<A>) -> Result<()> {
176 agent.on_response(self.response, self.tag, ctx).await
177 }
178}