use anyhow::Result;
use rsactor::{message_handlers, spawn, Actor, ActorRef, Error};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
/// Test actor whose handlers spawn Tokio tasks and reply with their `JoinHandle`s,
/// so that `ask_join` can be exercised against successful, panicking and aborted tasks.
#[derive(Actor)]
struct TaskSpawnerActor {
    /// Incremented once per `ComputeTask` handled and read back by `GetTaskCount`.
    /// Held in an `Arc` so a test can keep its own handle to the same counter.
    task_counter: Arc<AtomicU32>,
}
/// Asks the actor to spawn a task that sleeps for `delay_ms` milliseconds and then
/// returns `value * 2` plus the actor's task-counter value at the time of the request.
struct ComputeTask {
    /// Input doubled by the spawned task.
    value: u32,
    /// How long the spawned task sleeps before producing its result, in milliseconds.
    delay_ms: u64,
}
/// Asks the actor to spawn a task that panics after sleeping.
struct PanicTask {
    /// How long the spawned task sleeps before panicking, in milliseconds.
    delay_ms: u64,
}
/// Asks the actor for the `JoinHandle` of a task that is aborted before the handle is returned.
struct CancelledTask;
/// Queries the current value of the actor's task counter.
struct GetTaskCount;
#[message_handlers]
impl TaskSpawnerActor {
#[handler]
async fn handle_compute_task(
&mut self,
msg: ComputeTask,
_: &ActorRef<Self>,
) -> JoinHandle<u64> {
let counter = self.task_counter.fetch_add(1, Ordering::SeqCst);
let value = msg.value;
let delay = Duration::from_millis(msg.delay_ms);
tokio::spawn(async move {
tokio::time::sleep(delay).await;
(value as u64) * 2 + (counter as u64)
})
}
#[handler]
async fn handle_panic_task(
&mut self,
msg: PanicTask,
_: &ActorRef<Self>,
) -> JoinHandle<String> {
let delay = Duration::from_millis(msg.delay_ms);
tokio::spawn(async move {
tokio::time::sleep(delay).await;
panic!("Intentional panic for testing");
})
}
#[handler]
async fn handle_cancelled_task(
&mut self,
_: CancelledTask,
_: &ActorRef<Self>,
) -> JoinHandle<String> {
let handle = tokio::spawn(async {
tokio::time::sleep(Duration::from_secs(10)).await;
"This should never complete".to_string()
});
handle.abort();
handle
}
#[handler]
async fn handle_get_task_count(&mut self, _: GetTaskCount, _: &ActorRef<Self>) -> u32 {
self.task_counter.load(Ordering::SeqCst)
}
}
/// `ask_join` awaits the spawned task and returns its output directly, and the
/// actor reports exactly one handled compute request afterwards.
#[tokio::test]
async fn test_ask_join_successful_computation() -> Result<()> {
    let shared_counter = Arc::new(AtomicU32::new(0));
    let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor {
        task_counter: Arc::clone(&shared_counter),
    });

    // First compute request: the counter is 0, so the task yields 10 * 2 + 0.
    let request = ComputeTask {
        value: 10,
        delay_ms: 100,
    };
    let output: u64 = actor_ref.ask_join(request).await?;
    assert_eq!(output, 20);

    let handled: u32 = actor_ref.ask(GetTaskCount).await?;
    assert_eq!(handled, 1);

    actor_ref.stop().await?;
    actor_handle.await?;
    Ok(())
}
#[tokio::test]
async fn test_ask_join_multiple_concurrent_tasks() -> Result<()> {
let task_counter = Arc::new(AtomicU32::new(0));
let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor {
task_counter: task_counter.clone(),
});
let tasks = vec![
(5u32, 50u64), (10u32, 30u64), (15u32, 70u64), ];
let mut handles = Vec::new();
for (value, delay) in tasks {
let actor_ref_clone = actor_ref.clone();
let handle = tokio::spawn(async move {
actor_ref_clone
.ask_join(ComputeTask {
value,
delay_ms: delay,
})
.await
});
handles.push(handle);
}
let mut results = Vec::new();
for handle in handles {
let result = handle.await??;
results.push(result);
}
assert_eq!(results.len(), 3);
results.sort();
assert!(results.windows(2).all(|w| w[0] != w[1]));
let task_count = actor_ref.ask(GetTaskCount).await?;
assert_eq!(task_count, 3);
actor_ref.stop().await?;
actor_handle.await?;
Ok(())
}
#[tokio::test]
async fn test_ask_join_panicked_task() -> Result<()> {
let task_counter = Arc::new(AtomicU32::new(0));
let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor { task_counter });
let result = actor_ref.ask_join(PanicTask { delay_ms: 50 }).await;
match result {
Ok(_) => panic!("Expected join error for panicked task"),
Err(Error::Join { identity, source }) => {
assert!(identity.name().contains("TaskSpawnerActor"));
assert!(source.is_panic());
}
Err(e) => panic!("Expected Join error, got: {:?}", e),
}
actor_ref.stop().await?;
actor_handle.await?;
Ok(())
}
#[tokio::test]
async fn test_ask_join_cancelled_task() -> Result<()> {
let task_counter = Arc::new(AtomicU32::new(0));
let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor { task_counter });
let result = actor_ref.ask_join(CancelledTask).await;
match result {
Ok(_) => panic!("Expected join error for cancelled task"),
Err(Error::Join { identity, source }) => {
assert!(identity.name().contains("TaskSpawnerActor"));
assert!(source.is_cancelled());
}
Err(e) => panic!("Expected Join error, got: {:?}", e),
}
actor_ref.stop().await?;
actor_handle.await?;
Ok(())
}
/// `ask` followed by a manual `.await` on the returned `JoinHandle` must produce
/// the same kind of result as a single `ask_join` call.
///
/// The two requests are issued strictly one after the other, so the first claims
/// counter value 0 and the second claims 1. That makes both results exact.
#[tokio::test]
async fn test_ask_join_vs_regular_ask() -> Result<()> {
    let task_counter = Arc::new(AtomicU32::new(0));
    let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor { task_counter });

    // Plain `ask` replies with the JoinHandle itself; the caller awaits it.
    let join_handle: JoinHandle<u64> = actor_ref
        .ask(ComputeTask {
            value: 20,
            delay_ms: 50,
        })
        .await?;
    let manual_result = join_handle.await?;

    // `ask_join` performs both steps in one call.
    let auto_result: u64 = actor_ref
        .ask_join(ComputeTask {
            value: 20,
            delay_ms: 50,
        })
        .await?;

    // 2 * 20 + counter, with counter 0 for the first request and 1 for the second.
    assert_eq!(manual_result, 40);
    assert_eq!(auto_result, 41);

    actor_ref.stop().await?;
    actor_handle.await?;
    Ok(())
}
/// Once the actor has stopped, `ask_join` fails with `Error::Send` carrying the
/// actor's identity instead of producing a result.
#[tokio::test]
async fn test_ask_join_with_actor_stopped() -> Result<()> {
    let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor {
        task_counter: Arc::new(AtomicU32::new(0)),
    });

    actor_ref.stop().await?;
    actor_handle.await?;

    let late_request = ComputeTask {
        value: 5,
        delay_ms: 10,
    };
    match actor_ref.ask_join(late_request).await {
        Err(Error::Send { identity, .. }) => {
            assert!(identity.name().contains("TaskSpawnerActor"));
        }
        Err(e) => panic!("Expected Send error, got: {:?}", e),
        Ok(_) => panic!("Expected error when using ask_join on stopped actor"),
    }
    Ok(())
}
/// `ask_join` does not return until the spawned task finishes its 200 ms sleep.
/// The elapsed-time check allows 10 ms of slack.
#[tokio::test]
async fn test_ask_join_timeout_behavior() -> Result<()> {
    let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor {
        task_counter: Arc::new(AtomicU32::new(0)),
    });

    let started = std::time::Instant::now();
    let slow_request = ComputeTask {
        value: 1,
        delay_ms: 200,
    };
    let result: u64 = actor_ref.ask_join(slow_request).await?;
    let waited = started.elapsed();

    assert!(waited >= Duration::from_millis(190));
    assert_eq!(result, 2);

    actor_ref.stop().await?;
    actor_handle.await?;
    Ok(())
}
#[tokio::test]
async fn test_ask_join_error_source() -> Result<()> {
let task_counter = Arc::new(AtomicU32::new(0));
let (actor_ref, actor_handle) = spawn::<TaskSpawnerActor>(TaskSpawnerActor { task_counter });
let result = actor_ref.ask_join(PanicTask { delay_ms: 10 }).await;
match result {
Err(Error::Join { source, .. }) => {
assert!(source.is_panic());
let error_as_std_error: &dyn std::error::Error = &Error::Join {
identity: actor_ref.identity(),
source,
};
assert!(error_as_std_error.source().is_some());
}
_ => panic!("Expected Join error"),
}
actor_ref.stop().await?;
actor_handle.await?;
Ok(())
}