use derive_more::From;
use speare::{Actor, Ctx, Handle, Node, ReqErr, Request};
use std::time::Duration;
use tokio::task;
/// Stateless actor that answers each request with a value computed from its payload.
struct Echo;

/// Requests understood by [`Echo`].
///
/// `#[derive(From)]` generates a `From<Request<..>>` impl for each variant.
#[derive(From)]
enum EchoMsg {
    /// Replied to with a clone of the incoming string.
    Echo(Request<String, String>),
    /// Replied to with the sum of the two operands.
    Add(Request<(u32, u32), u32>),
}

impl Actor for Echo {
    type Props = ();
    type Msg = EchoMsg;
    type Err = ();

    async fn init(_: &mut Ctx<Self>) -> Result<Self, Self::Err> {
        Ok(Self)
    }

    async fn handle(&mut self, msg: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
        match msg {
            EchoMsg::Echo(req) => {
                // Copy the payload out first so the reply does not borrow `req`.
                let echoed = req.data().clone();
                req.reply(echoed);
            }
            EchoMsg::Add(req) => {
                let &(lhs, rhs) = req.data();
                req.reply(lhs + rhs);
            }
        }
        Ok(())
    }
}
/// Actor that accepts `Request<(), ()>` messages and never replies to them.
///
/// The request is dropped unanswered once `handle` returns.
struct BlackHole;

impl Actor for BlackHole {
    type Props = ();
    type Msg = Request<(), ()>;
    type Err = ();

    async fn init(_: &mut Ctx<Self>) -> Result<Self, Self::Err> {
        Ok(Self)
    }

    async fn handle(&mut self, _ignored: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
        // Intentionally no reply.
        Ok(())
    }
}
/// Actor that waits for the requested duration and then replies `"done"`.
struct SlowReplier;

/// Messages accepted by [`SlowReplier`].
#[derive(From)]
enum SlowMsg {
    /// The payload is how long to sleep before replying.
    Slow(Request<Duration, String>),
}

impl Actor for SlowReplier {
    type Props = ();
    type Msg = SlowMsg;
    type Err = ();

    async fn init(_: &mut Ctx<Self>) -> Result<Self, Self::Err> {
        Ok(Self)
    }

    async fn handle(&mut self, msg: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
        // Single-variant enum, so this pattern is irrefutable.
        let SlowMsg::Slow(req) = msg;

        // Copy the duration out so no borrow of `req` is held across the await.
        let pause = *req.data();
        tokio::time::sleep(pause).await;

        // The sleep is awaited inside `handle`, so the reply is sent only after it elapses.
        req.reply(String::from("done"));
        Ok(())
    }
}
/// Counter actor whose message enum has no `From` derive.
///
/// Requests to it must be wrapped explicitly, e.g. by passing `ManualMsg::GetCount` to `reqw`.
struct ManualWrap {
    /// Number of `Inc` messages handled so far.
    count: u32,
}

/// Messages accepted by [`ManualWrap`].
enum ManualMsg {
    /// Fire-and-forget increment of the counter.
    Inc,
    /// Replied to with the current counter value.
    GetCount(Request<(), u32>),
}

impl Actor for ManualWrap {
    type Props = ();
    type Msg = ManualMsg;
    type Err = ();

    async fn init(_: &mut Ctx<Self>) -> Result<Self, Self::Err> {
        Ok(Self { count: 0 })
    }

    async fn handle(&mut self, msg: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
        match msg {
            ManualMsg::Inc => {
                self.count += 1;
            }
            ManualMsg::GetCount(req) => {
                let current = self.count;
                req.reply(current);
            }
        }
        Ok(())
    }
}
/// Actor that relays string requests to another actor and passes the reply back.
struct Forwarder {
    /// Downstream actor that answers the forwarded requests (cloned from props).
    target: Handle<EchoMsg>,
}

/// Messages accepted by [`Forwarder`].
#[derive(From)]
enum ForwarderMsg {
    /// The payload is re-sent to `target`; the downstream reply is returned to the caller.
    Forward(Request<String, String>),
}

impl Actor for Forwarder {
    type Props = Handle<EchoMsg>;
    type Msg = ForwarderMsg;
    type Err = ();

    async fn init(ctx: &mut Ctx<Self>) -> Result<Self, Self::Err> {
        Ok(Forwarder {
            target: ctx.props().clone(),
        })
    }

    /// Forwards the request to `target` and awaits its reply inside `handle`.
    ///
    /// # Errors
    /// Returns `Err(())` when the downstream request fails (any `ReqErr`). In that case `req`
    /// is dropped without a reply.
    async fn handle(&mut self, msg: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
        match msg {
            ForwarderMsg::Forward(req) => {
                // Report a failed downstream request through the actor's declared error type
                // instead of panicking inside the actor task via `unwrap`.
                let result: String = self
                    .target
                    .req(req.data().clone())
                    .await
                    .map_err(|_| ())?;
                req.reply(result);
            }
        }
        Ok(())
    }
}
#[tokio::test]
async fn req_returns_response_from_actor() {
    // A plain `req` round-trip yields exactly what `Echo` replied with.
    let mut node = Node::default();
    let echo = node.actor::<Echo>(()).spawn();
    task::yield_now().await;

    let reply: String = echo.req(String::from("hello")).await.unwrap();
    assert_eq!(reply, "hello");
}
#[tokio::test]
async fn req_works_with_complex_data() {
    // Tuple payloads are routed to `EchoMsg::Add`, which replies with their sum.
    let mut node = Node::default();
    let echo = node.actor::<Echo>(()).spawn();
    task::yield_now().await;

    let operands: (u32, u32) = (3, 7);
    let total: u32 = echo.req(operands).await.unwrap();
    assert_eq!(total, 10);
}
#[tokio::test]
async fn req_returns_dropped_when_actor_never_replies() {
    // `BlackHole` discards every request, so the caller must see `Dropped`.
    let mut node = Node::default();
    let hole = node.actor::<BlackHole>(()).spawn();
    task::yield_now().await;

    let outcome = hole.req(()).await;
    assert!(matches!(outcome, Err(ReqErr::Dropped)));
}
#[tokio::test]
async fn req_returns_dropped_when_actor_is_stopped() {
    // Build the request/response pair by hand so the request can be sent after `stop()`.
    let mut node = Node::default();
    let echo = node.actor::<Echo>(()).spawn();
    task::yield_now().await;

    let (request, response) = speare::req_res::<String, String>("hello".into());

    echo.stop();
    // Yield so the actor task can run before the message is sent.
    task::yield_now().await;

    echo.send(request);
    let outcome = response.recv().await;
    assert!(matches!(outcome, Err(ReqErr::Dropped)));
}
#[tokio::test]
async fn req_timeout_returns_response_within_deadline() {
    // The actor's delay is well inside the deadline, so the reply arrives.
    let mut node = Node::default();
    let slow = node.actor::<SlowReplier>(()).spawn();
    task::yield_now().await;

    let work = Duration::from_millis(5);
    let deadline = Duration::from_secs(1);
    let reply: String = slow.req_timeout(work, deadline).await.unwrap();
    assert_eq!(reply, "done");
}
#[tokio::test]
async fn req_timeout_returns_timeout_when_deadline_exceeded() {
    // The actor sleeps far longer than the deadline, so the caller gives up with `Timeout`.
    let mut node = Node::default();
    let slow = node.actor::<SlowReplier>(()).spawn();
    task::yield_now().await;

    let work = Duration::from_secs(10);
    let deadline = Duration::from_millis(10);
    let outcome = slow.req_timeout(work, deadline).await;
    assert!(matches!(outcome, Err(ReqErr::Timeout)));
}
#[tokio::test]
async fn reqw_sends_request_using_wrapper_function() {
    // `ManualMsg` has no `From` impl, so the request is wrapped via the variant constructor.
    let mut node = Node::default();
    let actor = node.actor::<ManualWrap>(()).spawn();
    task::yield_now().await;

    for _ in 0..3 {
        actor.send(ManualMsg::Inc);
    }
    task::yield_now().await;

    let count: u32 = actor.reqw(ManualMsg::GetCount, ()).await.unwrap();
    assert_eq!(count, 3);
}
#[tokio::test]
async fn reqw_timeout_returns_response_within_deadline() {
    // Same explicit wrapping as `reqw`, but bounded by a generous deadline.
    let mut node = Node::default();
    let actor = node.actor::<ManualWrap>(()).spawn();
    task::yield_now().await;

    actor.send(ManualMsg::Inc);
    task::yield_now().await;

    let deadline = Duration::from_secs(1);
    let count: u32 = actor
        .reqw_timeout(ManualMsg::GetCount, (), deadline)
        .await
        .unwrap();
    assert_eq!(count, 1);
}
#[tokio::test]
async fn multiple_sequential_requests_to_same_actor() {
    // Each awaited request gets its own matching reply; none are crossed or lost.
    let mut node = Node::default();
    let echo = node.actor::<Echo>(()).spawn();
    task::yield_now().await;

    for expected in (0..10).map(|i| format!("msg-{i}")) {
        let reply: String = echo.req(expected.clone()).await.unwrap();
        assert_eq!(reply, expected);
    }
}
#[tokio::test]
async fn concurrent_requests_from_multiple_callers() {
    // Ten tasks share cloned handles; every caller must get back its own payload.
    let mut node = Node::default();
    let echo = node.actor::<Echo>(()).spawn();
    task::yield_now().await;

    let callers: Vec<_> = (0..10)
        .map(|i| {
            let handle = echo.clone();
            tokio::spawn(async move {
                let sent = format!("caller-{i}");
                let reply: String = handle.req(sent.clone()).await.unwrap();
                (reply, sent)
            })
        })
        .collect();

    for caller in callers {
        let (reply, sent) = caller.await.unwrap();
        assert_eq!(reply, sent);
    }
}
#[tokio::test]
async fn actor_can_make_request_to_another_actor_inside_handle() {
    // `Forwarder` awaits a request to `Echo` from inside its own `handle`.
    let mut node = Node::default();
    let downstream = node.actor::<Echo>(()).spawn();
    let relay = node.actor::<Forwarder>(downstream).spawn();
    task::yield_now().await;

    let reply: String = relay.req(String::from("forwarded")).await.unwrap();
    assert_eq!(reply, "forwarded");
}
#[tokio::test]
async fn response_recv_timeout_returns_dropped_when_request_dropped() {
    // Dropping the request side must resolve the response as `Dropped`, not wait for the timeout.
    let response = {
        let (request, response) = speare::req_res::<(), String>(());
        drop(request);
        response
    };

    let outcome = response.recv_timeout(Duration::from_secs(1)).await;
    assert!(matches!(outcome, Err(ReqErr::Dropped)));
}
#[tokio::test]
async fn request_data_is_accessible() {
    // `data()` exposes the payload the request was built with.
    let payload = String::from("payload");
    let (request, _response) = speare::req_res::<String, ()>(payload.clone());
    assert_eq!(request.data(), &payload);
}