#![allow(clippy::cast_possible_truncation)]
use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use hannibal::{TaskHandle, prelude::*, runtime::sleep};
/// Test actor that spawns background tasks on request (see `SpawnQuickTask`)
/// and keeps its own clone of every completion flag it hands out.
#[derive(Default)]
struct TaskSpawningActor {
// One flag per task spawned through `SpawnQuickTask`; the handler pushes a
// clone here in addition to returning one to the caller.
task_completed_flags: Vec<Arc<AtomicBool>>,
}
impl Actor for TaskSpawningActor {}
/// Asks `TaskSpawningActor` to spawn a task that sleeps for `duration_ms`
/// milliseconds and then raises a completion flag.
///
/// The response is that flag, which is `false` until the task has finished.
#[message(response = Arc<AtomicBool>)]
struct SpawnQuickTask {
    duration_ms: u64,
}

impl Handler<SpawnQuickTask> for TaskSpawningActor {
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: SpawnQuickTask) -> Arc<AtomicBool> {
        let SpawnQuickTask { duration_ms } = msg;
        let done = Arc::new(AtomicBool::new(false));

        // The spawned task owns its own handle to the flag so it can flip it
        // after the sleep.
        let task_done = Arc::clone(&done);
        ctx.spawn_task(async move {
            sleep(Duration::from_millis(duration_ms)).await;
            task_done.store(true, Ordering::SeqCst);
        });

        // Retain a copy on the actor as well as returning one to the caller.
        self.task_completed_flags.push(Arc::clone(&done));
        done
    }
}
/// Asks `TaskSpawningActor` to run a garbage-collection pass on its context.
#[message]
struct TriggerGC;

impl Handler<TriggerGC> for TaskSpawningActor {
    async fn handle(&mut self, context: &mut Context<Self>, _msg: TriggerGC) {
        context.gc();
    }
}
/// Spawns a few short tasks, waits for them to finish, runs GC twice, and
/// checks that the actor can still spawn and complete a new task.
#[test_log::test(tokio::test)]
async fn gc_runs_without_errors() {
    let addr = TaskSpawningActor::default().spawn();

    for _ in 0..5 {
        addr.call(SpawnQuickTask { duration_ms: 50 }).await.unwrap();
    }
    // Outlast the 50 ms tasks so GC sees them as finished.
    sleep(Duration::from_millis(100)).await;

    for _ in 0..2 {
        addr.send(TriggerGC).await.unwrap();
    }

    let probe = addr.call(SpawnQuickTask { duration_ms: 10 }).await.unwrap();
    sleep(Duration::from_millis(50)).await;
    assert!(
        probe.load(Ordering::SeqCst),
        "Actor should still work after GC"
    );
}
#[test_log::test(tokio::test)]
async fn gc_with_mixed_task_states() {
let actor = TaskSpawningActor::default();
let addr = actor.spawn();
let quick_flags: Vec<_> =
futures::future::join_all((0..3).map(|_| addr.call(SpawnQuickTask { duration_ms: 50 })))
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
let slow_flags: Vec<_> =
futures::future::join_all((0..2).map(|_| addr.call(SpawnQuickTask { duration_ms: 500 })))
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
sleep(Duration::from_millis(100)).await;
for flag in &quick_flags {
assert!(flag.load(Ordering::SeqCst), "Quick task should be done");
}
for flag in &slow_flags {
assert!(
!flag.load(Ordering::SeqCst),
"Slow task should still be running"
);
}
addr.send(TriggerGC).await.unwrap();
sleep(Duration::from_millis(500)).await;
for flag in &slow_flags {
assert!(flag.load(Ordering::SeqCst), "Slow task should now be done");
}
addr.send(TriggerGC).await.unwrap();
let new_flag = addr.call(SpawnQuickTask { duration_ms: 10 }).await.unwrap();
sleep(Duration::from_millis(50)).await;
assert!(new_flag.load(Ordering::SeqCst));
}
/// Stateless test actor used to exercise handle-based task tracking:
/// `spawn_task`, `is_task_finished` and `stop_task` on its context.
#[derive(Default)]
struct TaskTrackingActor;
impl Actor for TaskTrackingActor {}
/// Asks `TaskTrackingActor` to spawn a task that only sleeps for
/// `duration_ms` milliseconds; responds with the `TaskHandle` from `spawn_task`.
#[message(response = TaskHandle)]
struct SpawnTrackedTask {
    duration_ms: u64,
}

impl Handler<SpawnTrackedTask> for TaskTrackingActor {
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: SpawnTrackedTask) -> TaskHandle {
        let SpawnTrackedTask { duration_ms } = msg;
        let delay = Duration::from_millis(duration_ms);
        ctx.spawn_task(async move {
            sleep(delay).await;
        })
    }
}
/// Asks `TaskTrackingActor` for the status of the task behind `handle`.
///
/// Responds with `Context::is_task_finished` as-is. The tests in this file
/// expect `Some(false)`/`Some(true)` for a known task and `None` once it has
/// been stopped.
#[message(response = Option<bool>)]
struct CheckTaskFinished {
    handle: TaskHandle,
}

impl Handler<CheckTaskFinished> for TaskTrackingActor {
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: CheckTaskFinished) -> Option<bool> {
        let CheckTaskFinished { handle } = msg;
        ctx.is_task_finished(&handle)
    }
}
/// Asks `TaskTrackingActor` to stop the task behind `handle` via `Context::stop_task`.
#[message]
struct StopTrackedTask {
    handle: TaskHandle,
}

impl Handler<StopTrackedTask> for TaskTrackingActor {
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: StopTrackedTask) {
        let StopTrackedTask { handle } = msg;
        ctx.stop_task(handle);
    }
}
/// A tracked task reports `Some(false)` while it sleeps and `Some(true)` once done.
#[test_log::test(tokio::test)]
async fn is_task_finished_tracks_task_state() {
    let addr = TaskTrackingActor.spawn();
    let handle: TaskHandle = addr
        .call(SpawnTrackedTask { duration_ms: 200 })
        .await
        .unwrap();

    let while_running = addr.call(CheckTaskFinished { handle }).await.unwrap();
    assert_eq!(while_running, Some(false), "Task should not be finished yet");

    // Outlast the 200 ms task before asking again.
    sleep(Duration::from_millis(250)).await;
    let after_completion = addr.call(CheckTaskFinished { handle }).await.unwrap();
    assert_eq!(after_completion, Some(true), "Task should be finished now");
}
/// Stopping a running task removes it from tracking: its handle then reports `None`.
#[test_log::test(tokio::test)]
async fn stop_task_aborts_and_removes_task() {
    let addr = TaskTrackingActor.spawn();
    let long_task: TaskHandle = addr
        .call(SpawnTrackedTask { duration_ms: 5000 })
        .await
        .unwrap();
    sleep(Duration::from_millis(50)).await;

    let status = addr
        .call(CheckTaskFinished { handle: long_task })
        .await
        .unwrap();
    assert_eq!(status, Some(false), "Task should be running");

    addr.send(StopTrackedTask { handle: long_task }).await.unwrap();
    let status = addr
        .call(CheckTaskFinished { handle: long_task })
        .await
        .unwrap();
    assert_eq!(status, None, "Task should be removed after stop");
}
/// Repeats spawn, wait and GC five times, then checks the actor is still usable.
#[test_log::test(tokio::test)]
async fn multiple_gc_cycles() {
    let addr = TaskSpawningActor::default().spawn();

    for cycle in 0..5 {
        let batch = (0..3).map(|_| addr.call(SpawnQuickTask { duration_ms: 50 }));
        let flags = futures::future::join_all(batch)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        // Outlast the 50 ms tasks before checking them.
        sleep(Duration::from_millis(100)).await;
        for (task_idx, flag) in flags.iter().enumerate() {
            assert!(
                flag.load(Ordering::SeqCst),
                "Cycle {}: Task {} should be done",
                cycle,
                task_idx
            );
        }

        addr.send(TriggerGC).await.unwrap();
    }

    let last = addr.call(SpawnQuickTask { duration_ms: 10 }).await.unwrap();
    sleep(Duration::from_millis(50)).await;
    assert!(last.load(Ordering::SeqCst));
}
/// Stopping the actor must abort its in-flight tasks rather than let them run
/// to completion.
#[test_log::test(tokio::test)]
async fn tasks_aborted_on_actor_stop() {
    // Long enough that the tasks are still pending when the actor is stopped,
    // but shorter than the final wait below. If the task duration exceeded
    // that wait, the last assertion would pass even if aborting were broken.
    const TASK_MS: u64 = 300;

    let actor = TaskSpawningActor::default();
    let mut addr = actor.spawn();
    let flags: Vec<_> = futures::future::join_all(
        (0..3).map(|_| addr.call(SpawnQuickTask { duration_ms: TASK_MS })),
    )
    .await
    .into_iter()
    .map(|r| r.unwrap())
    .collect();

    sleep(Duration::from_millis(50)).await;
    for flag in &flags {
        assert!(!flag.load(Ordering::SeqCst), "Task should still be running");
    }

    addr.stop().unwrap();
    addr.await.unwrap();

    // Wait twice the task duration: any task that survived the stop would
    // have set its flag by now.
    sleep(Duration::from_millis(TASK_MS * 2)).await;
    for flag in &flags {
        assert!(
            !flag.load(Ordering::SeqCst),
            "Task should have been aborted, not completed"
        );
    }
}
/// Querying a handle whose task has been stopped yields `None`.
#[test_log::test(tokio::test)]
async fn check_invalid_handle_returns_none() {
    let addr = TaskTrackingActor.spawn();
    let stale: TaskHandle = addr
        .call(SpawnTrackedTask { duration_ms: 10 })
        .await
        .unwrap();

    addr.send(StopTrackedTask { handle: stale }).await.unwrap();

    let result = addr.call(CheckTaskFinished { handle: stale }).await.unwrap();
    assert_eq!(result, None, "Checking removed task should return None");
}
/// Spawns many short tasks concurrently, checks that they all complete, runs GC,
/// and verifies the actor still works afterwards.
#[test_log::test(tokio::test)]
async fn stress_test_many_tasks() {
    const TASK_COUNT: usize = 50;

    let addr = TaskSpawningActor::default().spawn();

    let requests = (0..TASK_COUNT).map(|_| addr.call(SpawnQuickTask { duration_ms: 50 }));
    let flags: Vec<_> = futures::future::join_all(requests)
        .await
        .into_iter()
        .map(|r| r.unwrap())
        .collect();

    // Generous margin over the 50 ms task duration.
    sleep(Duration::from_millis(150)).await;
    for done in &flags {
        assert!(done.load(Ordering::SeqCst), "Task should be done");
    }

    addr.send(TriggerGC).await.unwrap();
    let after_gc = addr.call(SpawnQuickTask { duration_ms: 10 }).await.unwrap();
    sleep(Duration::from_millis(50)).await;
    assert!(after_gc.load(Ordering::SeqCst));
}
/// Test actor whose handlers spawn `ChildActor`s and attach them to its context,
/// either with `add_child` ("generic") or `register_child::<UpdateChild>` ("typed").
#[derive(Default, Actor)]
struct ParentActor {
// Number of children added via `AddGenericChild`; doubles as the next child id.
generic_child_count: usize,
// Number of children added via `AddTypedChild`; doubles as the next child id.
typed_child_count: usize,
}
/// Child actor spawned by `ParentActor`; its handlers do nothing.
#[derive(Default, Actor)]
// `id` is only written at construction and never read, hence the allow.
#[allow(dead_code)]
struct ChildActor {
id: usize,
}
// No-op handler for the unit message.
// NOTE(review): presumably required by one of the child APIs used in this file
// (e.g. `Context::add_child`) — TODO confirm.
impl Handler<()> for ChildActor {
async fn handle(&mut self, _ctx: &mut Context<Self>, _msg: ()) {}
}
/// Message that `ParentActor` broadcasts to its typed children; the handler
/// ignores the payload.
// NOTE(review): `Clone` is presumably needed so `send_to_children` can hand a
// copy to each registered child — TODO confirm.
#[derive(Clone)]
#[message]
// The `String` payload is never read, hence the allow.
#[allow(dead_code)]
struct UpdateChild(String);
impl Handler<UpdateChild> for ChildActor {
async fn handle(&mut self, _ctx: &mut Context<Self>, _msg: UpdateChild) {}
}
/// Asks `ParentActor` to spawn a `ChildActor` and attach it with `add_child`.
/// Responds with the child's address and its sequential id.
#[message(response = (Addr<ChildActor>, usize))]
struct AddGenericChild;

impl Handler<AddGenericChild> for ParentActor {
    async fn handle(
        &mut self,
        context: &mut Context<Self>,
        _msg: AddGenericChild,
    ) -> (Addr<ChildActor>, usize) {
        let next_id = self.generic_child_count;
        self.generic_child_count = next_id + 1;

        let child_addr = ChildActor { id: next_id }.spawn();
        context.add_child(child_addr.clone());
        (child_addr, next_id)
    }
}
/// Asks `ParentActor` to spawn a `ChildActor` and register it as a receiver of
/// `UpdateChild`. Responds with the child's address and its sequential id.
#[message(response = (Addr<ChildActor>, usize))]
struct AddTypedChild;

impl Handler<AddTypedChild> for ParentActor {
    async fn handle(
        &mut self,
        context: &mut Context<Self>,
        _msg: AddTypedChild,
    ) -> (Addr<ChildActor>, usize) {
        let next_id = self.typed_child_count;
        self.typed_child_count = next_id + 1;

        let child_addr = ChildActor { id: next_id }.spawn();
        context.register_child::<UpdateChild>(child_addr.clone());
        (child_addr, next_id)
    }
}
/// Asks `ParentActor` to run a garbage-collection pass on its context.
#[message]
struct TriggerChildGC;

impl Handler<TriggerChildGC> for ParentActor {
    async fn handle(&mut self, context: &mut Context<Self>, _msg: TriggerChildGC) {
        context.gc();
    }
}
/// Asks `ParentActor` to broadcast an `UpdateChild` to its typed children.
#[message]
struct SendToTypedChildren;

impl Handler<SendToTypedChildren> for ParentActor {
    async fn handle(&mut self, context: &mut Context<Self>, _msg: SendToTypedChildren) {
        let update = UpdateChild(String::from("test"));
        context.send_to_children(update);
    }
}
/// Runs GC on a parent with a live generic child, drops the test's handle to
/// that child, and runs GC again. The parent must keep processing messages
/// throughout.
#[test_log::test(tokio::test)]
async fn gc_removes_stopped_generic_children() {
    let parent = ParentActor::default().spawn();
    let (child_addr, _id) = parent.call(AddGenericChild).await.unwrap();
    sleep(Duration::from_millis(50)).await;
    assert!(child_addr.ping().await.is_ok(), "Child should be running");

    parent.send(TriggerChildGC).await.unwrap();
    sleep(Duration::from_millis(50)).await;

    // NOTE(review): the parent was handed a clone of this address via
    // `add_child`, so dropping ours may not stop the child — TODO confirm.
    drop(child_addr);
    sleep(Duration::from_millis(50)).await;

    parent.send(TriggerChildGC).await.unwrap();
    // Await the last GC with `call`, which only resolves once the handler has
    // produced its response. Ending on `send` lets the test (and its runtime)
    // finish before the GC ever runs, so a broken GC could not fail the test.
    parent.call(TriggerChildGC).await.unwrap();
}
/// Registers and immediately drops several typed children, then runs GC and a
/// broadcast. The parent must still handle the final GC afterwards.
#[test_log::test(tokio::test)]
async fn gc_removes_stopped_typed_children() {
    let parent = ParentActor::default().spawn();
    for _ in 0..10 {
        let (child_addr, _id) = parent.call(AddTypedChild).await.unwrap();
        drop(child_addr);
        sleep(Duration::from_millis(10)).await;
    }

    parent.send(TriggerChildGC).await.unwrap();
    parent.send(TriggerChildGC).await.unwrap();
    parent.send(SendToTypedChildren).await.unwrap();
    // `call` resolves only after the handler ran, so the test cannot finish
    // (and tear down the runtime) before the final GC has actually executed.
    parent.call(TriggerChildGC).await.unwrap();
}
/// Runs GC on a parent holding one generic and one typed child. After the
/// generic child's handle is dropped, the typed child must still respond.
#[test_log::test(tokio::test)]
async fn gc_with_mixed_child_types() {
    let parent = ParentActor::default().spawn();
    let (generic_child, _) = parent.call(AddGenericChild).await.unwrap();
    let (typed_child, _) = parent.call(AddTypedChild).await.unwrap();
    sleep(Duration::from_millis(50)).await;
    assert!(generic_child.ping().await.is_ok());
    assert!(typed_child.ping().await.is_ok());

    parent.send(TriggerChildGC).await.unwrap();
    sleep(Duration::from_millis(50)).await;

    drop(generic_child);
    sleep(Duration::from_millis(50)).await;
    parent.send(TriggerChildGC).await.unwrap();
    assert!(typed_child.ping().await.is_ok());

    drop(typed_child);
    // Await the last GC via `call` so it is guaranteed to have run before the
    // test returns; a trailing `send` could be dropped with the runtime.
    parent.call(TriggerChildGC).await.unwrap();
}
/// Registers 100 typed children, dropping each handle right away, then runs GC
/// and a broadcast to exercise cleanup of the typed-child registry.
#[test_log::test(tokio::test)]
async fn typed_children_memory_leak_demonstration() {
    let parent = ParentActor::default().spawn();
    for i in 0..100 {
        let (child, id) = parent.call(AddTypedChild).await.unwrap();
        assert_eq!(id, i, "Child ID should match iteration");
        drop(child);
    }

    sleep(Duration::from_millis(100)).await;
    parent.send(TriggerChildGC).await.unwrap();
    // `call` waits for the handler to finish, so the broadcast (and the GC
    // queued before it) actually run before the test and its runtime end.
    parent.call(SendToTypedChildren).await.unwrap();
}