nidus-core 1.0.0

Core Nidus dependency injection, modules, provider lifetimes, and application lifecycle.
Documentation
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use nidus_core::{LifecycleHook, LifecycleRunner, NidusError};
use tracing::Level;
use tracing_subscriber::{Layer, fmt::MakeWriter, layer::SubscriberExt};

/// In-memory log sink whose clones all share one byte buffer.
///
/// A clone is handed to the tracing fmt layer while the test keeps the
/// original, so the test can read back whatever was written.
#[derive(Clone, Default)]
struct SharedLogWriter {
    /// Bytes captured so far; shared by every clone of this writer.
    output: Arc<Mutex<Vec<u8>>>,
}

impl SharedLogWriter {
    /// Returns a copy of everything captured so far, decoded as UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned or if the captured bytes are
    /// not valid UTF-8.
    fn contents(&self) -> String {
        let captured = self.output.lock().unwrap();
        String::from_utf8(captured.to_vec()).unwrap()
    }
}

/// Lets `SharedLogWriter` be passed to `with_writer` on a tracing fmt layer.
impl<'a> MakeWriter<'a> for SharedLogWriter {
    type Writer = SharedLogGuard;

    /// Produces a writer that appends to the same shared buffer as `self`.
    fn make_writer(&'a self) -> Self::Writer {
        let output = Arc::clone(&self.output);
        SharedLogGuard { output }
    }
}

/// Writer returned by `SharedLogWriter::make_writer`; appends into the
/// buffer it shares with the `SharedLogWriter` that created it.
struct SharedLogGuard {
    /// Handle to the shared capture buffer.
    output: Arc<Mutex<Vec<u8>>>,
}

impl std::io::Write for SharedLogGuard {
    /// Appends all of `buf` to the shared buffer and reports it as fully
    /// written.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut sink = self.output.lock().unwrap();
        sink.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Does nothing: `write` already stores the bytes in memory.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

/// Lifecycle hook test double that always succeeds.
///
/// Its `LifecycleHook` impl appends `"<name>:startup"` and
/// `"<name>:shutdown"` to `events`, so tests can assert the order in which
/// hooks were invoked.
#[derive(Clone)]
struct RecordingHook {
    /// Label used as the prefix of every recorded event.
    name: &'static str,
    /// Event log shared with the test (and with any other hooks).
    events: Arc<Mutex<Vec<String>>>,
}

/// Lifecycle hook test double whose startup always fails.
///
/// Its `on_startup` records `"<name>:startup"` in `events` and then returns
/// `NidusError::MissingProvider` carrying `name`. It does not define
/// `on_shutdown` itself.
#[derive(Clone)]
struct FailingStartupHook {
    /// Label used as the event prefix and as the error's `type_name`.
    name: &'static str,
    /// Event log shared with the test (and with any other hooks).
    events: Arc<Mutex<Vec<String>>>,
}

/// Lifecycle hook test double that starts successfully but fails on shutdown.
///
/// Its `on_shutdown` records `"<name>:shutdown"` in `events` and then returns
/// `NidusError::DuplicateProvider` carrying `name`.
#[derive(Clone)]
struct FailingShutdownHook {
    /// Label used as the event prefix and as the error's `type_name`.
    name: &'static str,
    /// Event log shared with the test (and with any other hooks).
    events: Arc<Mutex<Vec<String>>>,
}

#[async_trait]
impl LifecycleHook for RecordingHook {
    /// Records `"<name>:startup"` in the shared event log and succeeds.
    async fn on_startup(&self) -> nidus_core::Result<()> {
        let entry = format!("{}:startup", self.name);
        self.events.lock().unwrap().push(entry);
        Ok(())
    }

    /// Records `"<name>:shutdown"` in the shared event log and succeeds.
    async fn on_shutdown(&self) -> nidus_core::Result<()> {
        let entry = format!("{}:shutdown", self.name);
        self.events.lock().unwrap().push(entry);
        Ok(())
    }
}

#[async_trait]
impl LifecycleHook for FailingStartupHook {
    /// Records `"<name>:startup"` in the shared event log, then fails with
    /// `NidusError::MissingProvider` whose `type_name` is this hook's name.
    async fn on_startup(&self) -> nidus_core::Result<()> {
        let entry = format!("{}:startup", self.name);
        self.events.lock().unwrap().push(entry);

        let type_name = self.name;
        Err(NidusError::MissingProvider { type_name })
    }
}

#[async_trait]
impl LifecycleHook for FailingShutdownHook {
    /// Records `"<name>:startup"` in the shared event log and succeeds.
    async fn on_startup(&self) -> nidus_core::Result<()> {
        let entry = format!("{}:startup", self.name);
        self.events.lock().unwrap().push(entry);
        Ok(())
    }

    /// Records `"<name>:shutdown"` in the shared event log, then fails with
    /// `NidusError::DuplicateProvider` whose `type_name` is this hook's name.
    async fn on_shutdown(&self) -> nidus_core::Result<()> {
        let entry = format!("{}:shutdown", self.name);
        self.events.lock().unwrap().push(entry);

        let type_name = self.name;
        Err(NidusError::DuplicateProvider { type_name })
    }
}

/// Hooks start in registration order and shut down in the reverse order.
#[tokio::test]
async fn lifecycle_runner_starts_in_order_and_shuts_down_in_reverse_order() {
    let events = Arc::new(Mutex::new(Vec::new()));
    let database = RecordingHook {
        name: "database",
        events: Arc::clone(&events),
    };
    let server = RecordingHook {
        name: "server",
        events: Arc::clone(&events),
    };
    let runner = LifecycleRunner::new().hook(database).hook(server);

    runner.startup().await.unwrap();
    runner.shutdown().await.unwrap();

    let recorded = events.lock().unwrap();
    assert_eq!(
        *recorded,
        [
            "database:startup",
            "server:startup",
            "server:shutdown",
            "database:shutdown"
        ]
    );
}

/// A successful startup/shutdown cycle emits the expected debug log messages.
#[tokio::test(flavor = "current_thread")]
async fn lifecycle_runner_emits_startup_and_shutdown_debug_logs() {
    // Route formatted tracing output (DEBUG and above, no ANSI codes, no
    // targets) into an in-memory buffer that the assertions can read.
    let capture = SharedLogWriter::default();
    let debug_only = tracing_subscriber::filter::LevelFilter::from_level(Level::DEBUG);
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_writer(capture.clone())
        .with_ansi(false)
        .with_target(false)
        .with_filter(debug_only);
    let subscriber = tracing_subscriber::registry().with(fmt_layer);
    // Keep the guard alive for the whole test so the subscriber stays installed.
    let _guard = tracing::subscriber::set_default(subscriber);

    let database = RecordingHook {
        name: "database",
        events: Arc::new(Mutex::new(Vec::new())),
    };
    let runner = LifecycleRunner::new().hook(database);

    runner.startup().await.unwrap();
    runner.shutdown().await.unwrap();

    let logs = capture.contents();
    let logged = |message: &str| logs.contains(message);
    assert!(logged("lifecycle startup begin"));
    assert!(logged("lifecycle startup hook begin"));
    assert!(logged("lifecycle startup hook complete"));
    assert!(logged("lifecycle shutdown hook begin"));
    assert!(logged("lifecycle shutdown hook complete"));
    assert!(logged("lifecycle shutdown complete"));
}

/// When a later hook fails to start, hooks that already started are shut
/// down again and the original failure is reported as the error's source.
#[tokio::test]
async fn lifecycle_runner_rolls_back_started_hooks_when_startup_fails() {
    let events = Arc::new(Mutex::new(Vec::new()));
    let database = RecordingHook {
        name: "database",
        events: Arc::clone(&events),
    };
    let server = FailingStartupHook {
        name: "server",
        events: Arc::clone(&events),
    };
    let runner = LifecycleRunner::new().hook(database).hook(server);

    let (source, rollback_errors) = match runner.startup().await.unwrap_err() {
        NidusError::LifecycleStartup {
            source,
            rollback_errors,
        } => (source, rollback_errors),
        _ => panic!("expected lifecycle startup error"),
    };

    assert!(matches!(*source, NidusError::MissingProvider { .. }));
    assert!(rollback_errors.is_empty());

    // The failing hook's own shutdown is not expected in the log.
    let recorded = events.lock().unwrap();
    assert_eq!(
        *recorded,
        ["database:startup", "server:startup", "database:shutdown"]
    );
}

/// A failed startup logs the hook failure and the rollback of earlier hooks.
#[tokio::test(flavor = "current_thread")]
async fn lifecycle_runner_emits_failure_and_rollback_logs() {
    // Route formatted tracing output (DEBUG and above, no ANSI codes, no
    // targets) into an in-memory buffer that the assertions can read.
    let capture = SharedLogWriter::default();
    let debug_only = tracing_subscriber::filter::LevelFilter::from_level(Level::DEBUG);
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_writer(capture.clone())
        .with_ansi(false)
        .with_target(false)
        .with_filter(debug_only);
    let subscriber = tracing_subscriber::registry().with(fmt_layer);
    // Keep the guard alive for the whole test so the subscriber stays installed.
    let _guard = tracing::subscriber::set_default(subscriber);

    let events = Arc::new(Mutex::new(Vec::new()));
    let database = RecordingHook {
        name: "database",
        events: Arc::clone(&events),
    };
    let server = FailingStartupHook {
        name: "server",
        events,
    };
    let runner = LifecycleRunner::new().hook(database).hook(server);

    let error = runner.startup().await.unwrap_err();
    assert!(matches!(error, NidusError::LifecycleStartup { .. }));

    let logs = capture.contents();
    let logged = |message: &str| logs.contains(message);
    assert!(logged("lifecycle startup hook failed"));
    assert!(logged("lifecycle startup rollback hook begin"));
    assert!(logged("lifecycle startup rollback hook complete"));
}

/// Errors raised while rolling back a failed startup are collected in
/// `rollback_errors` alongside the original startup failure.
#[tokio::test]
async fn lifecycle_runner_reports_rollback_errors() {
    let events = Arc::new(Mutex::new(Vec::new()));
    let cache = FailingShutdownHook {
        name: "cache",
        events: Arc::clone(&events),
    };
    let server = FailingStartupHook {
        name: "server",
        events: Arc::clone(&events),
    };
    let runner = LifecycleRunner::new().hook(cache).hook(server);

    let (source, rollback_errors) = match runner.startup().await.unwrap_err() {
        NidusError::LifecycleStartup {
            source,
            rollback_errors,
        } => (source, rollback_errors),
        _ => panic!("expected lifecycle startup error"),
    };

    assert!(matches!(*source, NidusError::MissingProvider { .. }));
    assert_eq!(rollback_errors.len(), 1);
    assert!(matches!(
        rollback_errors[0],
        NidusError::DuplicateProvider { .. }
    ));

    let recorded = events.lock().unwrap();
    assert_eq!(
        *recorded,
        ["cache:startup", "server:startup", "cache:shutdown"]
    );
}