use crate::error::ActorError;
use spawned_rt::tasks::oneshot;
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
/// Default request timeout: five seconds (5 000 ms).
///
/// Not referenced inside this file; presumably callers pass it as the `duration`
/// argument of [`Response::from_with_timeout`] — confirm at the call sites.
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Internal state machine behind a [`Response`].
enum ResponseState<T> {
    /// Waiting on a oneshot receiver with no deadline (built by the `From` impl).
    Receiver(oneshot::Receiver<T>),
    /// Waiting on a boxed future that races the receiver against a timeout
    /// (built by `Response::from_with_timeout`).
    TimedReceiver(Pin<Box<dyn Future<Output = Result<T, ActorError>> + Send>>),
    /// The result is already available; it is handed out by the next `poll`, or
    /// synchronously by `unwrap` / `expect` / `map`.
    Ready(Result<T, ActorError>),
    /// The result has been taken by `poll`; polling again panics.
    Done,
}
/// Handle to the eventual `Result<T, ActorError>` of a request.
///
/// A `Response` either already holds its result (see `Response::ready`) or wraps a
/// pending oneshot receiver, optionally guarded by a timeout. `.await` it to obtain the
/// result. The synchronous accessors (`unwrap`, `expect`, `map`) only work on responses
/// that are already resolved and panic on pending ones.
pub struct Response<T>(ResponseState<T>);
// `Response` never structurally pins its contents:
// - the plain receiver is pinned on the fly with `Pin::new` (which already requires it
//   to be `Unpin`);
// - the timed future lives behind its own `Pin<Box<..>>`;
// - a ready `T` is only ever moved out, never pinned.
// Declaring `Unpin` for every `T` therefore lets `poll` use `Pin::get_mut` without
// forcing callers to have `T: Unpin`.
impl<T> Unpin for Response<T> {}
impl<T> Response<T> {
    /// Creates a response that is already resolved with `result`.
    ///
    /// Awaiting it yields `result` on the first poll. `unwrap`, `expect`, `map`, `is_ok`
    /// and `is_err` can also inspect it synchronously.
    pub fn ready(result: Result<T, ActorError>) -> Self {
        Self(ResponseState::Ready(result))
    }

    /// Synchronously extracts the `Ok` value of an already-resolved response.
    ///
    /// This never waits; in async code, `.await` the response instead.
    ///
    /// # Panics
    ///
    /// Panics if the response resolved to an `Err`. Also panics if it is still pending
    /// (built from a receiver and not yet resolved) or was already consumed by polling
    /// it to completion. Thanks to `#[track_caller]`, the panic location points at the
    /// caller rather than at this method.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self.0 {
            ResponseState::Ready(result) => result.unwrap(),
            ResponseState::Receiver(_) | ResponseState::TimedReceiver(_) => {
                panic!("called unwrap() on a pending Response; use .await in async contexts")
            }
            ResponseState::Done => panic!("Response already consumed"),
        }
    }

    /// Like [`Response::unwrap`], but every panic message is prefixed with `msg`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Response::unwrap`]. The panic location is reported at the
    /// caller (`#[track_caller]`).
    #[track_caller]
    pub fn expect(self, msg: &str) -> T {
        match self.0 {
            ResponseState::Ready(result) => result.expect(msg),
            ResponseState::Receiver(_) | ResponseState::TimedReceiver(_) => {
                panic!("{msg}: called expect() on a pending Response; use .await in async contexts")
            }
            ResponseState::Done => panic!("{msg}: Response already consumed"),
        }
    }

    /// Returns `true` only if the response is already resolved with `Ok`.
    ///
    /// A pending or consumed response returns `false` from both this method and
    /// [`Response::is_err`].
    #[must_use]
    pub fn is_ok(&self) -> bool {
        matches!(&self.0, ResponseState::Ready(Ok(_)))
    }

    /// Returns `true` only if the response is already resolved with `Err`.
    ///
    /// A pending or consumed response returns `false` from both this method and
    /// [`Response::is_ok`].
    #[must_use]
    pub fn is_err(&self) -> bool {
        matches!(&self.0, ResponseState::Ready(Err(_)))
    }

    /// Applies `f` to the `Ok` value of an already-resolved response.
    ///
    /// An `Err` is carried through untouched.
    ///
    /// # Panics
    ///
    /// Panics if the response is still pending or was already consumed. The panic
    /// location is reported at the caller (`#[track_caller]`).
    #[track_caller]
    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> Response<U> {
        match self.0 {
            ResponseState::Ready(result) => Response(ResponseState::Ready(result.map(f))),
            ResponseState::Receiver(_) | ResponseState::TimedReceiver(_) => {
                panic!("called map() on a pending Response; use .await in async contexts")
            }
            ResponseState::Done => panic!("Response already consumed"),
        }
    }
}
impl<T: Send + 'static> Response<T> {
pub fn from_with_timeout(
result: Result<oneshot::Receiver<T>, ActorError>,
duration: Duration,
) -> Self {
match result {
Ok(rx) => {
let fut = Box::pin(async move {
match spawned_rt::tasks::timeout(duration, rx).await {
Ok(Ok(val)) => Ok(val),
Ok(Err(_)) => Err(ActorError::ActorStopped),
Err(_) => Err(ActorError::RequestTimeout),
}
});
Self(ResponseState::TimedReceiver(fut))
}
Err(e) => Self(ResponseState::Ready(Err(e))),
}
}
}
impl<T> From<Result<oneshot::Receiver<T>, ActorError>> for Response<T> {
fn from(result: Result<oneshot::Receiver<T>, ActorError>) -> Self {
match result {
Ok(rx) => Self(ResponseState::Receiver(rx)),
Err(e) => Self(ResponseState::Ready(Err(e))),
}
}
}
/// Awaiting a `Response` yields its result.
///
/// - A plain receiver resolves to its value, or to `ActorError::ActorStopped` if the
///   channel closes without one.
/// - A timed receiver resolves to whatever its wrapped future produces.
/// - A ready response hands out its stored result on the first poll.
///
/// Once a result has been returned the response is marked consumed; polling it again
/// panics.
impl<T: Send + 'static> Future for Response<T> {
    type Output = Result<T, ActorError>;

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        use std::task::Poll;

        let this = self.get_mut();
        let outcome = match &mut this.0 {
            ResponseState::Receiver(rx) => match Pin::new(rx).poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(received) => received.map_err(|_| ActorError::ActorStopped),
            },
            ResponseState::TimedReceiver(fut) => match fut.as_mut().poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            },
            ResponseState::Ready(_) => match std::mem::replace(&mut this.0, ResponseState::Done) {
                ResponseState::Ready(result) => result,
                _ => unreachable!(),
            },
            ResponseState::Done => panic!("Response polled after completion"),
        };

        // Every completing path funnels through here, so the response is marked
        // consumed exactly once before the result is handed out.
        this.0 = ResponseState::Done;
        Poll::Ready(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use spawned_rt::tasks::oneshot;

    #[test]
    fn ready_ok_unwrap() {
        let r: Response<i32> = Response::ready(Ok(42));
        assert_eq!(r.unwrap(), 42);
    }

    #[test]
    fn ready_ok_is_ok() {
        let r: Response<i32> = Response::ready(Ok(1));
        assert!(r.is_ok());
        assert!(!r.is_err());
    }

    #[test]
    fn ready_err_is_err() {
        let r: Response<i32> = Response::ready(Err(ActorError::ActorStopped));
        assert!(r.is_err());
    }

    #[test]
    #[should_panic(expected = "ActorStopped")]
    fn ready_err_unwrap_panics() {
        let r: Response<i32> = Response::ready(Err(ActorError::ActorStopped));
        r.unwrap();
    }

    #[test]
    fn expect_returns_ready_value() {
        let r: Response<i32> = Response::ready(Ok(7));
        assert_eq!(r.expect("value should be ready"), 7);
    }

    #[test]
    #[should_panic(expected = "boom")]
    fn expect_on_err_panics_with_message() {
        let r: Response<i32> = Response::ready(Err(ActorError::ActorStopped));
        r.expect("boom");
    }

    #[test]
    fn from_err_is_ready_err() {
        let r = Response::<i32>::from(Err::<oneshot::Receiver<i32>, _>(
            ActorError::ActorStopped,
        ));
        assert!(r.is_err());
    }

    #[test]
    fn from_with_timeout_err_is_ready_err() {
        let r: Response<i32> =
            Response::from_with_timeout(Err(ActorError::ActorStopped), Duration::from_millis(10));
        assert!(r.is_err());
    }

    #[test]
    fn ready_future_resolves() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let r: Response<i32> = Response::ready(Ok(3));
            assert_eq!(r.await.unwrap(), 3);
        });
    }

    #[test]
    fn future_resolves_from_receiver() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (tx, rx) = oneshot::channel::<i32>();
            let resp: Response<i32> = Response::from(Ok(rx));
            tx.send(99).unwrap();
            let val = resp.await.unwrap();
            assert_eq!(val, 99);
        });
    }

    #[test]
    fn future_err_on_dropped_sender() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (tx, rx) = oneshot::channel::<i32>();
            let resp: Response<i32> = Response::from(Ok(rx));
            drop(tx);
            let result = resp.await;
            assert!(matches!(result, Err(ActorError::ActorStopped)));
        });
    }

    #[test]
    fn pending_response_is_neither_ok_nor_err() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (_tx, rx) = oneshot::channel::<i32>();
            let resp: Response<i32> = Response::from(Ok(rx));
            assert!(!resp.is_ok());
            assert!(!resp.is_err());
        });
    }

    #[test]
    #[should_panic(expected = "called unwrap() on a pending Response")]
    fn unwrap_on_pending_panics() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (_tx, rx) = oneshot::channel::<i32>();
            let resp: Response<i32> = Response::from(Ok(rx));
            resp.unwrap();
        });
    }

    #[test]
    fn map_transforms_value() {
        let r: Response<i32> = Response::ready(Ok(2));
        let mapped = r.map(|x| x * 3);
        assert_eq!(mapped.unwrap(), 6);
    }

    #[test]
    fn map_preserves_error() {
        let r: Response<i32> = Response::ready(Err(ActorError::ActorStopped));
        assert!(r.map(|x| x + 1).is_err());
    }

    #[test]
    fn timed_receiver_resolves() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (tx, rx) = oneshot::channel::<i32>();
            let resp = Response::from_with_timeout(Ok(rx), Duration::from_secs(5));
            tx.send(42).unwrap();
            assert_eq!(resp.await.unwrap(), 42);
        });
    }

    #[test]
    fn timed_receiver_err_on_dropped_sender() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (tx, rx) = oneshot::channel::<i32>();
            let resp = Response::from_with_timeout(Ok(rx), Duration::from_secs(5));
            drop(tx);
            assert!(matches!(resp.await, Err(ActorError::ActorStopped)));
        });
    }

    #[test]
    fn timed_receiver_times_out() {
        let rt = spawned_rt::tasks::Runtime::new().unwrap();
        rt.block_on(async {
            let (_tx, rx) = oneshot::channel::<i32>();
            let resp = Response::from_with_timeout(Ok(rx), Duration::from_millis(50));
            let result = resp.await;
            assert!(matches!(result, Err(ActorError::RequestTimeout)));
        });
    }
}