mt_mock_logger/
lib.rs

1use std::{ collections::HashMap, sync::{ LazyLock, RwLock }, thread::ThreadId };
2
3static MOCK_LOGGER: MockLogger = MockLogger::new();
4
5pub struct MockLoggerGuard;
6
7impl Drop for MockLoggerGuard {
8    fn drop(&mut self) {
9        MockLogger::remove_logger();
10    }
11}
12
13pub struct MockLogger {
14    #[allow(clippy::type_complexity)]
15    mutex: LazyLock<RwLock<HashMap<ThreadId, (Box<dyn log::Log>, log::LevelFilter)>>>,
16}
17
18impl MockLogger {
19    const fn new() -> Self {
20        MockLogger {
21            mutex: LazyLock::new(|| {
22                let _ = log::set_logger(&MOCK_LOGGER);
23                log::set_max_level(log::LevelFilter::Trace);
24                RwLock::new(HashMap::new())
25            }),
26        }
27    }
28
29    pub fn set_logger(
30        logger: impl log::Log + 'static,
31        max_level: log::LevelFilter
32    ) -> MockLoggerGuard {
33        MOCK_LOGGER.mutex.write()
34            .expect("mutex is poisoned")
35            .insert(std::thread::current().id(), (Box::new(logger), max_level));
36
37        MockLoggerGuard
38    }
39
40    fn remove_logger() {
41        MOCK_LOGGER.mutex.write().expect("mutex is poisoned").remove(&std::thread::current().id());
42    }
43}
44
45impl log::Log for MockLogger {
46    fn enabled(&self, metadata: &log::Metadata) -> bool {
47        if
48            let Some((logger, _)) = self.mutex
49                .read()
50                .expect("mutex is poisoned")
51                .get(&std::thread::current().id())
52        {
53            return logger.enabled(metadata);
54        }
55
56        false
57    }
58
59    fn log(&self, record: &log::Record) {
60        if
61            let Some((logger, max_level)) = self.mutex
62                .read()
63                .expect("mutex is poisoned")
64                .get(&std::thread::current().id())
65        {
66            if record.level() <= *max_level {
67                logger.log(record);
68            }
69        }
70    }
71
72    fn flush(&self) {
73        if
74            let Some((logger, _)) = self.mutex
75                .read()
76                .expect("mutex is poisoned")
77                .get(&std::thread::current().id())
78        {
79            logger.flush();
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use mockall::mock;

    use super::*;

    mock! {
        pub MyLogger {}
        impl log::Log for MyLogger {
            fn enabled<'a>(&self, metadata: &log::Metadata<'a>) -> bool;
            fn log<'a>(&self, record: &log::Record<'a>);
            fn flush(&self);
        }
    }

    /// A logger that only counts the records it receives. This lets a test
    /// observe forwarding after the logger has been moved into the registry.
    struct CountingLogger(Arc<AtomicUsize>);

    impl log::Log for CountingLogger {
        fn enabled(&self, _metadata: &log::Metadata) -> bool {
            true
        }

        fn log(&self, _record: &log::Record) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }

        fn flush(&self) {}
    }

    #[test]
    fn test_logging() {
        let mut my_logger = MockMyLogger::new();
        my_logger
            .expect_log()
            .withf(|r| r.level() == log::Level::Info && r.args().as_str() == Some("ok"))
            .once()
            .return_const(());

        let _guard = MockLogger::set_logger(my_logger, log::LevelFilter::Info);

        log::info!("ok");
        // More verbose than this thread's max level: must not reach the mock.
        log::trace!("ok");
    }

    #[test]
    fn test_logging_below_max_level() {
        let mut my_logger = MockMyLogger::new();
        my_logger.expect_log().never().return_const(());

        let _guard = MockLogger::set_logger(my_logger, log::LevelFilter::Info);

        log::trace!("ok");
    }

    #[test]
    fn test_no_logger() {
        // With no logger registered for this thread, logging is simply a no-op.
        log::trace!("ok");
    }

    #[test]
    fn test_guard_drop_unregisters_logger() {
        let count = Arc::new(AtomicUsize::new(0));
        let guard = MockLogger::set_logger(
            CountingLogger(Arc::clone(&count)),
            log::LevelFilter::Trace,
        );

        log::info!("registered");
        assert_eq!(count.load(Ordering::SeqCst), 1);

        drop(guard);
        log::info!("unregistered");
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_other_threads_are_isolated() {
        let count = Arc::new(AtomicUsize::new(0));
        let _guard = MockLogger::set_logger(
            CountingLogger(Arc::clone(&count)),
            log::LevelFilter::Trace,
        );

        std::thread::spawn(|| log::info!("from another thread"))
            .join()
            .expect("logging thread panicked");
        assert_eq!(count.load(Ordering::SeqCst), 0);

        log::info!("from this thread");
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }
}