groundwork/
trace.rs

1use log::{Metadata, Record};
2use serde::Serialize;
3use std::io::Write;
4use std::{
5    sync::{Arc, Mutex},
6    time::{Duration, SystemTime, UNIX_EPOCH},
7};
8use thiserror::Error;
9
10pub struct Buffer<const SIZE: usize>(circular_buffer::CircularBuffer<SIZE, u8>);
11
12pub struct SpyLogger<const SIZE: usize, T: log::Log> {
13    buffer: Arc<Mutex<Buffer<SIZE>>>,
14    logger: T,
15}
16pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
17
18pub type SpyLoggerDefault<T> = SpyLogger<DEFAULT_BUFFER_SIZE, T>;
19
20impl<const SIZE: usize, T: log::Log> SpyLogger<SIZE, T> {
21    pub fn new(logger: T) -> Self {
22        Self {
23            buffer: Arc::new(Mutex::new(Buffer::new())),
24            logger,
25        }
26    }
27
28    pub fn buffer(&self) -> Arc<Mutex<Buffer<SIZE>>> {
29        self.buffer.clone()
30    }
31}
32
33#[derive(Serialize, Debug)]
34pub struct LogLine {
35    timestamp: Duration,
36    level: u8,
37    message: String,
38}
39
40#[derive(Error, Debug)]
41pub enum LogError {
42    #[error("Found unexpected character on timestamp place")]
43    UnexpectedTimestampValue,
44    #[error("Unexpected end")]
45    ValueExpected,
46    #[error("Cannot restore message from bytes")]
47    Utf8(#[from] std::string::FromUtf8Error),
48}
49
50impl<const SIZE: usize> Default for Buffer<SIZE> {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl<const SIZE: usize> Buffer<SIZE> {
57    pub fn new() -> Self {
58        Self(circular_buffer::CircularBuffer::new())
59    }
60
61    pub fn get_logs(&mut self) -> Result<Vec<LogLine>, LogError> {
62        let mut it = self.0.iter();
63        let mut result = vec![];
64        let mut add_result = |timestamp, level, message| -> Result<(), LogError> {
65            result.push(LogLine {
66                timestamp: Duration::from_secs(timestamp),
67                level,
68                message: String::from_utf8(message).map_err(LogError::Utf8)?,
69            });
70            Ok(())
71        };
72        if it.any(|&v| v == 0) {
73            'top: loop {
74                let mut timestamp = 0u64;
75                for _ in 0..16 {
76                    timestamp =
77                        (timestamp << 4) | read_hex(*it.next().ok_or(LogError::ValueExpected)?)?;
78                }
79                let level = read_hex(*it.next().ok_or(LogError::ValueExpected)?)? as u8;
80                let mut message = vec![];
81                for &c in it.by_ref() {
82                    if c == 0 {
83                        add_result(timestamp, level, message)?;
84                        continue 'top;
85                    }
86                    message.push(c);
87                }
88                add_result(timestamp, level, message)?;
89                break;
90            }
91        }
92        Ok(result)
93    }
94
95    pub fn get_traces(&mut self) -> Result<Vec<String>, LogError> {
96        let mut it = self.0.iter();
97        let mut result = vec![];
98        if it.any(|&v| v == 0) {
99            'top: loop {
100                let mut message = vec![];
101                for &v in it.by_ref() {
102                    if v != 0 {
103                        message.push(v);
104                    } else {
105                        result.push(String::from_utf8(message).map_err(LogError::Utf8)?);
106                        continue 'top;
107                    }
108                }
109                result.push(String::from_utf8(message).map_err(LogError::Utf8)?);
110                break;
111            }
112        }
113        Ok(result)
114    }
115
116    fn write_log(&mut self, level: log::Level, message: &str) -> std::fmt::Result {
117        let timestamp = SystemTime::now()
118            .duration_since(UNIX_EPOCH)
119            .map_err(|_| std::fmt::Error)?
120            .as_secs();
121        self.0
122            .write_all(format!("\0{timestamp:016X}{}", level as u8).as_bytes())
123            .map_err(|_| std::fmt::Error)?;
124        self.0
125            .write_all(message.as_bytes())
126            .map_err(|_| std::fmt::Error)
127    }
128
129    fn write_trace(&mut self, message: &str) -> std::io::Result<()> {
130        use std::io::Write;
131        self.0.write_all(&[0u8])?;
132        self.0.write_all(message.as_bytes())
133    }
134}
135
136fn read_hex(v: u8) -> Result<u64, LogError> {
137    Ok((match v {
138        b'0'..=b'9' => v - b'0',
139        b'A'..=b'F' => v - b'A' + 10,
140        _ => Err(LogError::UnexpectedTimestampValue)?,
141    }) as u64)
142}
143
144impl<const SIZE: usize, T: log::Log> log::Log for SpyLogger<SIZE, T> {
145    fn enabled(&self, metadata: &Metadata) -> bool {
146        self.logger.enabled(metadata)
147    }
148
149    fn log(&self, record: &Record) {
150        if self.enabled(record.metadata()) {
151            let mut m = self.buffer.lock().expect("can lock buffer mutex");
152            _ = m.write_log(record.level(), &format!("{}", record.args()));
153            self.logger.log(record);
154        }
155    }
156    fn flush(&self) {
157        self.logger.flush();
158    }
159}
160
161pub struct StdoutTraceWriterMaker<const SIZE: usize> {
162    buffer: Arc<Mutex<Buffer<SIZE>>>,
163}
164
165pub struct TraceWriter<const SIZE: usize> {
166    buffer: Arc<Mutex<Buffer<SIZE>>>,
167}
168
169impl<const SIZE: usize> StdoutTraceWriterMaker<SIZE> {
170    pub fn new(buffer: Arc<Mutex<Buffer<SIZE>>>) -> Self {
171        Self { buffer }
172    }
173}
174
175impl<'a, const SIZE: usize> tracing_subscriber::fmt::MakeWriter<'a>
176    for StdoutTraceWriterMaker<SIZE>
177{
178    type Writer = TraceWriter<SIZE>;
179
180    fn make_writer(&'a self) -> Self::Writer {
181        TraceWriter {
182            buffer: self.buffer.clone(),
183        }
184    }
185}
186
187impl<const SIZE: usize> std::io::Write for TraceWriter<SIZE> {
188    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
189        self.buffer
190            .lock()
191            .unwrap()
192            .write_trace(&String::from_utf8_lossy(buf))?;
193        std::io::stdout().write(buf)
194    }
195
196    fn flush(&mut self) -> std::io::Result<()> {
197        Ok(())
198    }
199}
200
201pub struct TraceWriterWrapperMaker<
202    const SIZE: usize,
203    T: for<'a> tracing_subscriber::fmt::MakeWriter<'a>,
204> {
205    buffer: Arc<Mutex<Buffer<SIZE>>>,
206    maker: T,
207}
208
209pub struct TraceWriterWrapper<const SIZE: usize, T: std::io::Write> {
210    buffer: Arc<Mutex<Buffer<SIZE>>>,
211    writer: T,
212}
213
214impl<const SIZE: usize, T: for<'a> tracing_subscriber::fmt::MakeWriter<'a>>
215    TraceWriterWrapperMaker<SIZE, T>
216{
217    pub fn new(buffer: Arc<Mutex<Buffer<SIZE>>>, maker: T) -> Self {
218        Self { buffer, maker }
219    }
220}
221
222impl<'a, const SIZE: usize, T: for<'b> tracing_subscriber::fmt::MakeWriter<'b>>
223    tracing_subscriber::fmt::MakeWriter<'a> for TraceWriterWrapperMaker<SIZE, T>
224{
225    type Writer = TraceWriterWrapper<SIZE, <T as tracing_subscriber::fmt::MakeWriter<'a>>::Writer>;
226
227    fn make_writer(&'a self) -> Self::Writer {
228        Self::Writer {
229            buffer: self.buffer.clone(),
230            writer: self.maker.make_writer(),
231        }
232    }
233}
234
235impl<const SIZE: usize, T: std::io::Write> std::io::Write for TraceWriterWrapper<SIZE, T> {
236    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
237        self.buffer
238            .lock()
239            .unwrap()
240            .write_trace(&String::from_utf8_lossy(buf))?;
241        self.writer.write(buf)
242    }
243
244    fn flush(&mut self) -> std::io::Result<()> {
245        Ok(())
246    }
247}