use async_trait::async_trait;
use std::time::Duration;
use tokio::{task, time};
use crate::{McReceiver, MpSender, Run};
/// Request/response helper for a [`Run`] job processor.
///
/// A request is sent tagged with a job id (`JId`), and the caller waits for the
/// response that carries the same job id.
#[async_trait]
pub trait SendReceive<JId, ReqTx, RespRx>: Run
where
    JId: Copy + Eq + Send + Sync + 'static,
    <Self as Run>::Error: Clone + Send + 'static,
    <Self as Run>::Request: Send + 'static,
    <Self as Run>::Response: Send + 'static,
    ReqTx: MpSender<Message = (JId, <Self as Run>::Request)> + Send + Sync + 'static,
    RespRx: McReceiver<Message = (JId, Result<<Self as Run>::Response, <Self as Run>::Error>)>
        + Send
        + Sync
        + 'static,
{
    /// Sends `(jid, req)` through `request_sender` and waits up to `timeout` for a
    /// response tagged with the same `jid` on `response_receiver`.
    ///
    /// The receiver is polled on a spawned task, started before the request is sent.
    /// Responses tagged with other job ids are skipped. The polling task is aborted
    /// whenever this call finishes without it: on send failure, on timeout, or if
    /// the returned future is dropped. It therefore never keeps running after the
    /// call is over.
    ///
    /// # Errors
    ///
    /// - [`SendReceiveError::RequestSender`] if sending the request fails.
    /// - [`SendReceiveError::Elapsed`] if no matching response arrives within `timeout`.
    /// - [`SendReceiveError::Join`] if the polling task fails to join (e.g. it panicked).
    /// - [`SendReceiveError::Job`] if the matching response carries a job error.
    async fn send_receive(
        request_sender: ReqTx,
        response_receiver: RespRx,
        jid: JId,
        req: <Self as Run>::Request,
        timeout: Duration,
    ) -> Result<<Self as Run>::Response, SendReceiveError<<Self as Run>::Error>> {
        /// Aborts the wrapped task when dropped. Dropping a bare `JoinHandle` only
        /// detaches the task, which would leave the poller looping forever.
        struct AbortOnDrop<T>(task::JoinHandle<T>);

        impl<T> Drop for AbortOnDrop<T> {
            fn drop(&mut self) {
                // Aborting an already-finished task is a no-op.
                self.0.abort();
            }
        }

        // Start listening before sending so the response cannot be missed while
        // the send is in flight.
        let mut poller = AbortOnDrop(task::spawn(async move {
            loop {
                match response_receiver.receive().await {
                    Ok((resp_job_id, resp_rslt)) if jid == resp_job_id => return resp_rslt,
                    // Responses for other jobs are skipped.
                    // NOTE(review): receive errors are skipped too, so a permanently
                    // failing receiver is retried until the timeout aborts this task.
                    // `SendReceiveError::ResponseReceiver` could surface it instead if
                    // such errors are known to be terminal.
                    _ => {}
                }
            }
        }));

        // On failure `?` returns early and `poller` is dropped, aborting the task.
        request_sender
            .send((jid, req))
            .await
            .map_err(|_| SendReceiveError::RequestSender)?;

        // Await a borrow of the handle so that on timeout the guard, not
        // `time::timeout`, owns the handle and aborts the task.
        let resp = time::timeout(timeout, &mut poller.0)
            .await
            .map_err(|_| SendReceiveError::Elapsed)?
            .map_err(|_| SendReceiveError::Join)?
            .map_err(SendReceiveError::Job)?;
        Ok(resp)
    }
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum SendReceiveError<JErr>
where
JErr: Clone,
{
#[error("error occured sending a request to the job processor")]
RequestSender,
#[error("error occured receiving a response from the job processor")]
ResponseReceiver,
#[error(transparent)]
Job(JErr),
#[error("error occured joining on threads")]
Join,
#[error("response was not received by within the expected timeout period")]
Elapsed,
}