use async_trait::async_trait;
use joerl::{
ActorSystem,
gen_server::{self, CallResponse, GenServer, GenServerContext, ReplyHandle},
};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
/// Test server holding an `i32` counter that can be read with an immediate or a deferred reply.
struct FlexibleCounter;
/// Call messages understood by `FlexibleCounter`; each one asks for the current counter value.
#[derive(Debug, Clone)]
// Every variant deliberately shares the `Get` prefix, which this lint would otherwise flag.
#[allow(clippy::enum_variant_names)]
enum CounterCall {
/// Answered directly from `handle_call` via `CallResponse::Reply`.
GetImmediate,
/// Answered from a spawned task after a 10 ms sleep.
GetDeferred,
/// Answered from a spawned task after a 100 ms sleep.
GetDeferredSlow,
}
/// Cast messages understood by `FlexibleCounter`.
#[derive(Debug)]
enum CounterCast {
/// Adds one to the counter state.
Increment,
}
/// `GenServer` implementation for the counter used by the reply-mode tests.
///
/// The state is a plain `i32`. Calls read it either immediately or through a
/// reply handle that is answered later from a spawned task; casts increment it.
#[async_trait]
impl GenServer for FlexibleCounter {
    type State = i32;
    type Call = CounterCall;
    type Cast = CounterCast;
    type CallReply = i32;

    /// Starts the counter at zero.
    async fn init(&mut self, _ctx: &mut GenServerContext<'_, Self>) -> Self::State {
        0
    }

    /// Replies with the counter value, either right away or after a delay.
    ///
    /// `GetImmediate` returns `CallResponse::Reply` with the current value.
    /// The deferred variants take the reply handle, copy the value as it is at
    /// call time, and answer from a spawned task after sleeping (10 ms for
    /// `GetDeferred`, 100 ms for `GetDeferredSlow`), returning
    /// `CallResponse::NoReply` from the handler itself.
    async fn handle_call(
        &mut self,
        call: Self::Call,
        state: &mut Self::State,
        ctx: &mut GenServerContext<'_, Self>,
    ) -> CallResponse<Self::CallReply> {
        // Only the delay differs between the two deferred variants.
        let delay = match call {
            CounterCall::GetImmediate => return CallResponse::Reply(*state),
            CounterCall::GetDeferred => Duration::from_millis(10),
            CounterCall::GetDeferredSlow => Duration::from_millis(100),
        };

        let reply_to = ctx.reply_handle();
        // Copy the value now so the reply reflects the state at call time.
        let snapshot = *state;
        tokio::spawn(async move {
            tokio::time::sleep(delay).await;
            reply_to.reply(snapshot).expect("reply failed");
        });
        CallResponse::NoReply
    }

    /// Applies a cast; the only cast, `Increment`, adds one to the counter.
    async fn handle_cast(
        &mut self,
        cast: Self::Cast,
        state: &mut Self::State,
        _ctx: &mut GenServerContext<'_, Self>,
    ) {
        let CounterCast::Increment = cast;
        *state += 1;
    }
}
/// An immediate call observes the counter value after a single increment.
#[tokio::test]
async fn test_immediate_reply() {
    let system = Arc::new(ActorSystem::new());
    let server = gen_server::spawn(&system, FlexibleCounter);

    server.cast(CounterCast::Increment).await.unwrap();
    // Brief pause before querying (presumably to let the cast be handled).
    tokio::time::sleep(Duration::from_millis(10)).await;

    let observed = server.call(CounterCall::GetImmediate).await.unwrap();
    assert_eq!(observed, 1);
}
/// A deferred call (answered from a spawned task) still delivers the counter value.
#[tokio::test]
async fn test_deferred_reply() {
    let system = Arc::new(ActorSystem::new());
    let server = gen_server::spawn(&system, FlexibleCounter);

    server.cast(CounterCast::Increment).await.unwrap();
    // Brief pause before querying (presumably to let the cast be handled).
    tokio::time::sleep(Duration::from_millis(10)).await;

    let observed = server.call(CounterCall::GetDeferred).await.unwrap();
    assert_eq!(observed, 1);
}
/// While a slow deferred call is outstanding, the server keeps handling other messages.
///
/// The slow call is started first, then three increments are cast. An
/// immediate call must see all three increments before the slow call finishes,
/// and the slow call must still report the value captured when it was made.
#[tokio::test]
async fn test_deferred_reply_allows_mailbox_processing() {
    let system = Arc::new(ActorSystem::new());
    let counter = gen_server::spawn(&system, FlexibleCounter);

    // Issue the slow call from its own task so this test can keep sending messages.
    let caller = counter.clone();
    let slow_call =
        tokio::spawn(async move { caller.call(CounterCall::GetDeferredSlow).await });

    tokio::time::sleep(Duration::from_millis(10)).await;
    for _ in 0..3 {
        counter.cast(CounterCast::Increment).await.unwrap();
    }
    tokio::time::sleep(Duration::from_millis(50)).await;

    let current = counter.call(CounterCall::GetImmediate).await.unwrap();
    assert_eq!(current, 3, "mailbox should have processed casts");

    let slow_value = slow_call.await.unwrap().unwrap();
    assert_eq!(slow_value, 0, "deferred reply should return original value");
}
/// Five deferred calls issued from separate tasks each receive the counter value.
#[tokio::test]
async fn test_multiple_concurrent_deferred_replies() {
    let system = Arc::new(ActorSystem::new());
    let counter = gen_server::spawn(&system, FlexibleCounter);

    counter.cast(CounterCast::Increment).await.unwrap();
    tokio::time::sleep(Duration::from_millis(10)).await;

    // Spawn every caller up front; collecting eagerly starts all tasks before any is awaited.
    let pending: Vec<_> = (0..5)
        .map(|_| {
            let caller = counter.clone();
            tokio::spawn(async move { caller.call(CounterCall::GetDeferred).await.unwrap() })
        })
        .collect();

    for task in pending {
        let value = task.await.unwrap();
        assert_eq!(value, 1);
    }
}
/// Test server that parks reply handles in a shared vector and answers them only on request.
struct ControlledReplyServer {
/// Shared list of pending reply handles; `init` clones this `Arc` to use as the server state.
reply_handles: Arc<Mutex<Vec<ReplyHandle<i32>>>>,
}
/// Call messages understood by `ControlledReplyServer`.
#[derive(Debug)]
enum ControlledCall {
/// Stores the call's reply handle instead of answering; the handler returns `NoReply`.
DeferReply,
}
/// Cast messages understood by `ControlledReplyServer`.
#[derive(Debug)]
enum ControlledCast {
/// Replies with the given value to every stored reply handle and empties the list.
ReplyToAll(i32),
}
/// `GenServer` implementation that lets a test decide when parked calls are answered.
///
/// The state is the same `Arc<Mutex<Vec<ReplyHandle<i32>>>>` the struct was
/// built with, so the test can inspect the list of parked handles from outside.
#[async_trait]
impl GenServer for ControlledReplyServer {
    type State = Arc<Mutex<Vec<ReplyHandle<i32>>>>;
    type Call = ControlledCall;
    type Cast = ControlledCast;
    type CallReply = i32;

    /// Uses a clone of the shared handle list as the server state.
    async fn init(&mut self, _ctx: &mut GenServerContext<'_, Self>) -> Self::State {
        Arc::clone(&self.reply_handles)
    }

    /// Parks the caller: its reply handle is pushed onto the shared list and no reply is sent.
    async fn handle_call(
        &mut self,
        call: Self::Call,
        state: &mut Self::State,
        ctx: &mut GenServerContext<'_, Self>,
    ) -> CallResponse<Self::CallReply> {
        let ControlledCall::DeferReply = call;
        let parked = ctx.reply_handle();
        state.lock().await.push(parked);
        CallResponse::NoReply
    }

    /// Answers every parked caller with the cast's value, leaving the list empty.
    async fn handle_cast(
        &mut self,
        cast: Self::Cast,
        state: &mut Self::State,
        _ctx: &mut GenServerContext<'_, Self>,
    ) {
        let ControlledCast::ReplyToAll(value) = cast;
        let mut waiting = state.lock().await;
        waiting.drain(..).for_each(|handle| {
            // A failed reply is ignored on purpose (best-effort delivery).
            handle.reply(value).ok();
        });
    }
}
/// Three parked calls stay pending until a `ReplyToAll` cast answers them all at once.
#[tokio::test]
async fn test_manual_reply_control() {
    let system = Arc::new(ActorSystem::new());
    let reply_handles = Arc::new(Mutex::new(Vec::new()));
    let server = gen_server::spawn(
        &system,
        ControlledReplyServer {
            reply_handles: reply_handles.clone(),
        },
    );

    // Start three callers, each from its own task, so all of them can be pending together.
    let callers: Vec<_> = (0..3)
        .map(|_| {
            let server = server.clone();
            tokio::spawn(async move { server.call(ControlledCall::DeferReply).await })
        })
        .collect();

    tokio::time::sleep(Duration::from_millis(20)).await;
    // Every call should have parked its reply handle by now.
    assert_eq!(reply_handles.lock().await.len(), 3);

    server.cast(ControlledCast::ReplyToAll(42)).await.unwrap();

    for caller in callers {
        assert_eq!(caller.await.unwrap().unwrap(), 42);
    }
    // Replying drains the shared list.
    assert_eq!(reply_handles.lock().await.len(), 0);
}
/// Stateless test server used to check what happens when a reply handle is requested twice.
struct ErrorTestServer;
/// Call messages understood by `ErrorTestServer`.
#[derive(Debug)]
enum ErrorCall {
/// Makes the handler request the reply handle twice within one call.
TakeTwice,
}
/// `GenServer` implementation that checks a second `reply_handle()` request panics.
#[async_trait]
impl GenServer for ErrorTestServer {
    type State = ();
    type Call = ErrorCall;
    type Cast = ();
    type CallReply = String;

    /// There is no state to set up.
    async fn init(&mut self, _ctx: &mut GenServerContext<'_, Self>) -> Self::State {}

    /// Takes the reply handle, expects a second take to panic, then replies `"success"`.
    ///
    /// The second `reply_handle()` call runs inside `catch_unwind`, so the
    /// expected panic is caught here. If it does not panic, `expect_err`
    /// panics instead.
    async fn handle_call(
        &mut self,
        call: Self::Call,
        _state: &mut Self::State,
        ctx: &mut GenServerContext<'_, Self>,
    ) -> CallResponse<Self::CallReply> {
        let ErrorCall::TakeTwice = call;

        let first = ctx.reply_handle();
        let second_attempt = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = ctx.reply_handle();
        }));
        second_attempt.expect_err("should panic on second take");

        // The handle taken first is still usable for answering the caller.
        first.reply("success".to_string()).ok();
        CallResponse::NoReply
    }

    /// Casts carry no data and are ignored.
    async fn handle_cast(
        &mut self,
        _: Self::Cast,
        _: &mut Self::State,
        _: &mut GenServerContext<'_, Self>,
    ) {
    }
}
/// The `TakeTwice` call succeeds and replies `"success"`.
///
/// That only happens when the second `reply_handle()` request inside the
/// handler panics as expected and the first handle still delivers the reply.
#[tokio::test]
async fn test_reply_handle_can_only_be_taken_once() {
    let system = Arc::new(ActorSystem::new());
    let server = gen_server::spawn(&system, ErrorTestServer);

    let outcome = server.call(ErrorCall::TakeTwice).await;
    assert!(outcome.is_ok());

    let reply = outcome.unwrap();
    assert_eq!(reply, "success");
}