async_log/
logger.rs

1use crate::backtrace::async_log_capture_caller;
2use log::{set_boxed_logger, LevelFilter, Log, Metadata, Record};
3
4use std::thread;
5
6/// A Logger that wraps other loggers to extend it with async functionality.
7#[derive(Debug)]
8pub struct Logger<L: Log + 'static, F>
9where
10    F: Fn() -> u64 + Send + Sync + 'static,
11{
12    backtrace: bool,
13    logger: L,
14    with: F,
15}
16
17impl<L: Log + 'static, F> Logger<L, F>
18where
19    F: Fn() -> u64 + Send + Sync + 'static,
20{
21    /// Wrap an existing logger, extending it with async functionality.
22    pub fn wrap(logger: L, with: F) -> Self {
23        let backtrace = std::env::var_os("RUST_BACKTRACE")
24            .map(|x| &x != "0")
25            .unwrap_or(false);
26        Self {
27            logger,
28            backtrace,
29            with,
30        }
31    }
32
33    /// Start logging.
34    pub fn start(self, filter: LevelFilter) -> Result<(), log::SetLoggerError> {
35        let res = set_boxed_logger(Box::new(self));
36        if res.is_ok() {
37            log::set_max_level(filter);
38        }
39        res
40    }
41
42    /// Call the `self.with` closure, and return its results.
43    fn with(&self) -> u64 {
44        (self.with)()
45    }
46
47    /// Compute which stack frame to log based on an offset defined inside the log message.
48    /// This message is then stripped from the resulting record.
49    fn compute_stack_depth(&self, _record: &Record<'_>) -> u8 {
50        4
51    }
52}
53
/// Get the current thread's id as a plain integer.
///
/// `ThreadId` does not implement `Display`, so the number is recovered from
/// its `Debug` output (currently `ThreadId(N)`). That output is not a stable
/// format, so the digits are extracted rather than sliced at fixed byte
/// offsets, and `0` is returned if no number can be parsed: a formatting
/// change must never panic inside the logging path.
fn thread_id() -> u64 {
    let repr = format!("{:?}", thread::current().id());
    let digits: String = repr.chars().filter(|c| c.is_ascii_digit()).collect();
    digits.parse().unwrap_or(0)
}
61
62impl<L: Log, F> log::Log for Logger<L, F>
63where
64    F: Fn() -> u64 + Send + Sync + 'static,
65{
66    fn enabled(&self, metadata: &Metadata<'_>) -> bool {
67        self.logger.enabled(metadata)
68    }
69
70    fn log(&self, record: &Record<'_>) {
71        if self.enabled(record.metadata()) {
72            let depth = self.compute_stack_depth(&record);
73            let symbol = async_log_capture_caller(depth);
74
75            let key_values = KeyValues {
76                thread_id: thread_id(),
77                task_id: self.with(),
78                kvs: record.key_values(),
79            };
80
81            let (line, filename, fn_name) = if self.backtrace {
82                match symbol {
83                    Some(symbol) => {
84                        let line = symbol
85                            .lineno
86                            .map(|l| format!(", line={}", l))
87                            .unwrap_or_else(|| String::from(""));
88
89                        let filename = symbol
90                            .filename
91                            .map(|f| format!(", filename={}", f.to_string_lossy()))
92                            .unwrap_or_else(|| String::from(""));
93
94                        let fn_name = symbol
95                            .name
96                            .map(|l| format!(", fn_name={}", l))
97                            .unwrap_or_else(|| String::from(""));
98
99                        (line, filename, fn_name)
100                    }
101                    None => (String::from(""), String::from(""), String::from("")),
102                }
103            } else {
104                (String::from(""), String::from(""), String::from(""))
105            };
106
107            // This is done this way b/c `Record` + `format_args` needs to be built inline. See:
108            // https://stackoverflow.com/q/56304313/1541707
109            self.logger.log(
110                &log::Record::builder()
111                    .args(format_args!(
112                        "{}{}{}{}",
113                        record.args(),
114                        filename,
115                        line,
116                        fn_name,
117                    ))
118                    .metadata(record.metadata().clone())
119                    .key_values(&key_values)
120                    .level(record.level())
121                    .target(record.target())
122                    .module_path(record.module_path())
123                    .file(record.file())
124                    .line(record.line())
125                    .build(),
126            )
127        }
128    }
129    fn flush(&self) {}
130}
131
132struct KeyValues<'a> {
133    thread_id: u64,
134    task_id: u64,
135    kvs: &'a dyn log::kv::Source,
136}
137impl<'a> log::kv::Source for KeyValues<'a> {
138    fn visit<'kvs>(
139        &'kvs self,
140        visitor: &mut dyn log::kv::Visitor<'kvs>,
141    ) -> Result<(), log::kv::Error> {
142        self.kvs.visit(visitor)?;
143        visitor.visit_pair("thread_id".into(), self.thread_id.into())?;
144        visitor.visit_pair("task_id".into(), self.task_id.into())?;
145        Ok(())
146    }
147}