use super::PastaLogger;
use std::collections::HashMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, OnceLock};
use tracing_subscriber::fmt::MakeWriter;
// Process-wide registry instance. It starts empty and is populated lazily the
// first time `GlobalLoggerRegistry::instance` is called.
static REGISTRY: OnceLock<GlobalLoggerRegistry> = OnceLock::new();
/// Table associating each load directory with the [`PastaLogger`] registered
/// for it.
///
/// The table is stored behind `Arc<Mutex<..>>`, so cloning the registry yields
/// another handle to the same underlying map rather than an independent copy.
#[derive(Clone)]
pub struct GlobalLoggerRegistry {
    /// Registered loggers, keyed by the load directory they were registered under.
    loggers: Arc<Mutex<HashMap<PathBuf, Arc<PastaLogger>>>>,
}
impl GlobalLoggerRegistry {
    /// Creates an empty registry.
    fn new() -> Self {
        Self {
            loggers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Locks the logger table, recovering the guard if the mutex was poisoned.
    ///
    /// A poisoned mutex only means another thread panicked while holding the
    /// lock; the map is still usable, so the registry keeps working instead of
    /// propagating the panic.
    fn lock_loggers(&self) -> std::sync::MutexGuard<'_, HashMap<PathBuf, Arc<PastaLogger>>> {
        self.loggers
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns the process-wide registry, creating it on first use.
    pub fn instance() -> &'static Self {
        REGISTRY.get_or_init(Self::new)
    }

    /// Registers `logger` for `load_dir`, replacing any logger previously
    /// registered under the same directory.
    pub fn register(&self, load_dir: PathBuf, logger: Arc<PastaLogger>) {
        // The guard is a temporary of this `let` statement, so the lock is
        // released before `displaced` is dropped below. If this was the last
        // reference to the old logger, its destructor therefore never runs
        // while the registry mutex is held (a destructor that logs would come
        // back into `get` and block on the same non-reentrant mutex).
        let displaced = self.lock_loggers().insert(load_dir, logger);
        drop(displaced);
    }

    /// Removes the logger registered for `load_dir`, if there is one.
    pub fn unregister(&self, load_dir: &Path) {
        // Same reasoning as in `register`: take the entry out under the lock,
        // but let it be dropped only after the lock has been released.
        let removed = self.lock_loggers().remove(load_dir);
        drop(removed);
    }

    /// Returns a new handle to the logger registered for `load_dir`, or `None`
    /// when nothing is registered under that directory.
    pub fn get(&self, load_dir: &Path) -> Option<Arc<PastaLogger>> {
        self.lock_loggers().get(load_dir).cloned()
    }
}
/// Writer handed out by the registry's `MakeWriter` implementation.
///
/// It either forwards to a specific [`PastaLogger`] or, when no logger is
/// attached, accepts and discards everything written to it.
pub struct RoutingWriter {
    /// Destination logger; `None` turns the writer into a no-op sink.
    logger: Option<Arc<PastaLogger>>,
}
impl Write for RoutingWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if let Some(ref logger) = self.logger {
logger.write(buf)
} else {
Ok(buf.len())
}
}
fn flush(&mut self) -> io::Result<()> {
if let Some(ref logger) = self.logger {
logger.flush()
} else {
Ok(())
}
}
}
impl<'a> MakeWriter<'a> for GlobalLoggerRegistry {
    type Writer = RoutingWriter;

    /// Builds a writer for the calling thread's current load directory.
    ///
    /// The writer targets the logger registered for that directory. When the
    /// thread has no current load directory, or nothing is registered for it,
    /// the returned writer silently discards its input.
    fn make_writer(&'a self) -> Self::Writer {
        // This runs for every log event, so the path is borrowed in place
        // rather than cloned just to perform the lookup.
        //
        // `try_with` is used instead of `with` because `with` panics once the
        // thread-local has been destroyed (e.g. an event emitted from another
        // thread-local's destructor during thread teardown). In that situation
        // there is no context left, so fall back to the no-op writer.
        let logger = CURRENT_LOAD_DIR
            .try_with(|cell| cell.borrow().as_deref().and_then(|dir| self.get(dir)))
            .ok()
            .flatten();
        RoutingWriter { logger }
    }
}
// Per-thread "current load directory" used to pick the logger for events
// emitted on this thread. Every thread starts out with no directory set.
thread_local! {
    static CURRENT_LOAD_DIR: std::cell::RefCell<Option<PathBuf>> =
        const { std::cell::RefCell::new(None) };
}
/// Sets the calling thread's current load directory.
///
/// Passing `None` clears it. Only the calling thread is affected, because the
/// value is stored in a thread-local.
pub fn set_current_load_dir(load_dir: Option<PathBuf>) {
    CURRENT_LOAD_DIR.with(|slot| {
        let mut current = slot.borrow_mut();
        *current = load_dir;
    });
}
/// Returns an owned copy of the calling thread's current load directory, or
/// `None` when no directory is set on this thread.
pub fn get_current_load_dir() -> Option<PathBuf> {
    CURRENT_LOAD_DIR.with(|slot| {
        let current = slot.borrow();
        (*current).clone()
    })
}
/// Scope guard for the calling thread's current load directory.
///
/// Creating the guard switches the thread to a new load directory; dropping
/// it puts back whatever was current before. The guard must be bound to a
/// named variable (for example `let _guard = ...`): a guard that is discarded
/// right away is dropped right away, which undoes the switch immediately.
#[must_use = "the previous load directory is restored as soon as the guard is dropped"]
pub struct LoadDirGuard {
    /// Load directory that was current when the guard was created.
    previous: Option<PathBuf>,
}
impl LoadDirGuard {
    /// Makes `load_dir` the calling thread's current load directory and
    /// returns a guard that remembers the directory it displaced.
    pub fn new(load_dir: PathBuf) -> Self {
        // Capture the old value first; the guard hands it back on drop.
        let displaced = get_current_load_dir();
        set_current_load_dir(Some(load_dir));
        Self {
            previous: displaced,
        }
    }
}
impl Drop for LoadDirGuard {
    /// Reinstates the load directory that was current when the guard was
    /// created (which may be `None`).
    fn drop(&mut self) {
        let restored = self.previous.take();
        set_current_load_dir(restored);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for turning a string literal into an owned path.
    fn dir(raw: &str) -> PathBuf {
        PathBuf::from(raw)
    }

    /// `instance` must hand back the very same object on every call.
    #[test]
    fn test_registry_singleton() {
        let first = GlobalLoggerRegistry::instance();
        let second = GlobalLoggerRegistry::instance();
        assert!(std::ptr::eq(first, second));
    }

    /// The thread-local context starts empty and follows set/clear calls.
    #[test]
    fn test_load_dir_context() {
        assert_eq!(get_current_load_dir(), None);

        set_current_load_dir(Some(dir("/test/path")));
        assert_eq!(get_current_load_dir(), Some(dir("/test/path")));

        set_current_load_dir(None);
        assert_eq!(get_current_load_dir(), None);
    }

    /// Dropping a guard restores the directory that was set before it.
    #[test]
    fn test_load_dir_guard() {
        set_current_load_dir(Some(dir("/original")));

        let scoped = LoadDirGuard::new(dir("/guarded"));
        assert_eq!(get_current_load_dir(), Some(dir("/guarded")));
        drop(scoped);

        assert_eq!(get_current_load_dir(), Some(dir("/original")));
        set_current_load_dir(None);
    }

    /// Nested guards unwind in reverse order, one level at a time.
    #[test]
    fn test_load_dir_guard_nested() {
        set_current_load_dir(None);

        let outer = LoadDirGuard::new(dir("/outer"));
        assert_eq!(get_current_load_dir(), Some(dir("/outer")));

        let inner = LoadDirGuard::new(dir("/inner"));
        assert_eq!(get_current_load_dir(), Some(dir("/inner")));
        drop(inner);

        assert_eq!(get_current_load_dir(), Some(dir("/outer")));
        drop(outer);

        assert_eq!(get_current_load_dir(), None);
    }

    /// A logger can be registered, looked up, and removed again.
    #[test]
    fn test_registry_register_get_unregister() {
        let scratch = tempfile::TempDir::new().unwrap();
        let key = scratch.path().to_path_buf();
        let registry = GlobalLoggerRegistry::instance();
        assert!(registry.get(&key).is_none());

        let logger = Arc::new(PastaLogger::new(&key, None).unwrap());
        registry.register(key.clone(), Arc::clone(&logger));
        let fetched = registry.get(&key).expect("logger should be registered");
        assert_eq!(fetched.log_path(), logger.log_path());

        registry.unregister(&key);
        assert!(registry.get(&key).is_none());
    }

    /// Registering twice under one key keeps only the most recent logger.
    #[test]
    fn test_registry_register_replaces_existing() {
        let temp_a = tempfile::TempDir::new().unwrap();
        let temp_b = tempfile::TempDir::new().unwrap();
        let key = temp_a.path().to_path_buf();
        let registry = GlobalLoggerRegistry::instance();
        let first = Arc::new(PastaLogger::new(temp_a.path(), None).unwrap());
        let second = Arc::new(PastaLogger::new(temp_b.path(), None).unwrap());

        registry.register(key.clone(), Arc::clone(&first));
        registry.register(key.clone(), Arc::clone(&second));

        let fetched = registry.get(&key).expect("logger should be registered");
        assert_eq!(
            fetched.log_path(),
            second.log_path(),
            "second registration should replace the first"
        );
        registry.unregister(&key);
    }

    /// With no load directory set, the writer accepts and discards input.
    #[test]
    fn test_make_writer_without_context_is_noop() {
        set_current_load_dir(None);
        let mut writer = GlobalLoggerRegistry::instance().make_writer();

        let payload = b"discarded";
        let written = writer.write(payload).unwrap();
        assert_eq!(written, payload.len());
        writer.flush().unwrap();
    }

    /// A load directory without a registered logger also yields a no-op writer.
    #[test]
    fn test_make_writer_with_unregistered_dir_is_noop() {
        let scratch = tempfile::TempDir::new().unwrap();
        let _guard = LoadDirGuard::new(scratch.path().to_path_buf());
        let mut writer = GlobalLoggerRegistry::instance().make_writer();

        let payload = b"discarded";
        let written = writer.write(payload).unwrap();
        assert_eq!(written, payload.len());
        writer.flush().unwrap();
    }

    /// Output written under a guard lands in the registered logger's file.
    #[test]
    fn test_make_writer_routes_to_registered_logger() {
        let scratch = tempfile::TempDir::new().unwrap();
        let load_dir = scratch.path().to_path_buf();
        let registry = GlobalLoggerRegistry::instance();
        let logger = Arc::new(PastaLogger::new(&load_dir, None).unwrap());
        let log_path = logger.log_path().to_path_buf();
        registry.register(load_dir.clone(), logger);

        let guard = LoadDirGuard::new(load_dir.clone());
        let mut writer = registry.make_writer();
        writer.write_all(b"routed via registry\n").unwrap();
        writer.flush().unwrap();
        // Release the writer before the guard, then unregister, before the
        // file is read back.
        drop(writer);
        drop(guard);

        registry.unregister(&load_dir);
        let content = std::fs::read_to_string(&log_path).unwrap();
        assert!(
            content.contains("routed via registry"),
            "log file should contain the routed line, got: {:?}",
            content
        );
    }
}