ash-flare 2.3.2

Fault-tolerant supervision trees for Rust with distributed capabilities inspired by Erlang/OTP
Documentation
use ash_flare::{RestartPolicy, SupervisorHandle, SupervisorSpec, Worker};
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::time::{Duration, sleep};

struct TestWorker {
    counter: Arc<AtomicU32>,
    fail_after: Option<u32>,
}

#[async_trait]
impl Worker for TestWorker {
    type Error = std::io::Error;

    async fn run(&mut self) -> Result<(), Self::Error> {
        loop {
            let count = self.counter.fetch_add(1, Ordering::SeqCst);
            if let Some(limit) = self.fail_after {
                if count >= limit {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "intentional failure",
                    ));
                }
            }
            sleep(Duration::from_millis(10)).await;
        }
    }
}

#[tokio::test]
async fn test_worker_basic_lifecycle() {
    let counter = Arc::new(AtomicU32::new(0));
    let counter_clone = Arc::clone(&counter);

    let spec = SupervisorSpec::new("test").with_worker(
        "worker-1",
        move || TestWorker {
            counter: Arc::clone(&counter_clone),
            fail_after: Some(5),
        },
        RestartPolicy::Temporary,
    );

    let handle = SupervisorHandle::start(spec);
    sleep(Duration::from_millis(100)).await;

    let count = counter.load(Ordering::SeqCst);
    assert!(count >= 5, "Worker should have run at least 5 times");

    handle.shutdown().await.unwrap();
}

#[tokio::test]
async fn test_worker_permanent_restart() {
    let counter = Arc::new(AtomicU32::new(0));
    let counter_clone = Arc::clone(&counter);

    let spec = SupervisorSpec::new("test").with_worker(
        "worker-1",
        move || TestWorker {
            counter: Arc::clone(&counter_clone),
            fail_after: Some(3),
        },
        RestartPolicy::Permanent,
    );

    let handle = SupervisorHandle::start(spec);
    sleep(Duration::from_millis(200)).await;

    let count = counter.load(Ordering::SeqCst);
    assert!(
        count > 3,
        "Worker should have restarted and continued counting"
    );

    handle.shutdown().await.unwrap();
}

#[tokio::test]
async fn test_worker_transient_normal_exit() {
    let counter = Arc::new(AtomicU32::new(0));
    let counter_clone = Arc::clone(&counter);

    let spec = SupervisorSpec::new("test").with_worker(
        "worker-1",
        move || TestWorker {
            counter: Arc::clone(&counter_clone),
            fail_after: None,
        },
        RestartPolicy::Transient,
    );

    let handle = SupervisorHandle::start(spec);
    sleep(Duration::from_millis(50)).await;

    let children = handle.which_children().await.unwrap();
    assert_eq!(children.len(), 1);

    handle.shutdown().await.unwrap();
}

#[tokio::test]
async fn test_multiple_workers() {
    let counter1 = Arc::new(AtomicU32::new(0));
    let counter2 = Arc::new(AtomicU32::new(0));
    let c1 = Arc::clone(&counter1);
    let c2 = Arc::clone(&counter2);

    let spec = SupervisorSpec::new("test")
        .with_worker(
            "worker-1",
            move || TestWorker {
                counter: Arc::clone(&c1),
                fail_after: None,
            },
            RestartPolicy::Permanent,
        )
        .with_worker(
            "worker-2",
            move || TestWorker {
                counter: Arc::clone(&c2),
                fail_after: None,
            },
            RestartPolicy::Permanent,
        );

    let handle = SupervisorHandle::start(spec);
    sleep(Duration::from_millis(50)).await;

    let count1 = counter1.load(Ordering::SeqCst);
    let count2 = counter2.load(Ordering::SeqCst);

    assert!(count1 > 0, "Worker 1 should have executed");
    assert!(count2 > 0, "Worker 2 should have executed");

    handle.shutdown().await.unwrap();
}