faucet_server/server/
logging.rs

1use hyper::{http::HeaderValue, Method, Request, Response, Uri, Version};
2
3use super::onion::{Layer, Service};
4use crate::{server::service::State, telemetry::TelemetrySender};
5use std::{net::IpAddr, time};
6
7pub mod logger {
8    use std::{io::BufWriter, io::Write, path::PathBuf};
9
10    use hyper::body::Bytes;
11    use tokio::task::JoinHandle;
12
13    use crate::shutdown::ShutdownSignal;
14
15    pub enum Target {
16        Stderr,
17        File(PathBuf),
18    }
19
20    struct LogFileWriter {
21        sender: tokio::sync::mpsc::Sender<Bytes>,
22    }
23
24    impl std::io::Write for LogFileWriter {
25        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
26            let _ = self.sender.try_send(Bytes::copy_from_slice(buf));
27            Ok(buf.len())
28        }
29        fn flush(&mut self) -> std::io::Result<()> {
30            Ok(())
31        }
32    }
33
34    fn start_log_writer_thread(
35        path: PathBuf,
36        max_file_size: Option<u64>,
37        shutdown: ShutdownSignal,
38    ) -> (LogFileWriter, JoinHandle<()>) {
39        let max_file_size = max_file_size.unwrap_or(u64::MAX);
40        let mut current_file_size = match std::fs::metadata(&path) {
41            Ok(md) => md.len(),
42            Err(_) => 0,
43        };
44        let file = std::fs::File::options()
45            .create(true)
46            .append(true)
47            .truncate(false)
48            .open(&path)
49            .expect("Unable to open or create log file");
50
51        // Create a file path to a backup of the previous logs with MAX file size
52        let mut copy_path = path.clone();
53        copy_path.as_mut_os_string().push(".bak");
54
55        let mut writer = BufWriter::new(file);
56        let mut stderr = BufWriter::new(std::io::stderr());
57        let (sender, mut receiver) = tokio::sync::mpsc::channel::<Bytes>(1000);
58        let writer_thread = tokio::task::spawn(async move {
59            loop {
60                tokio::select! {
61                    bytes = receiver.recv() => {
62                        match bytes {
63                            Some(bytes) => {
64                                if let Err(e) = stderr.write_all(bytes.as_ref()) {
65                                    eprintln!("Unable to write to stderr: {e}");
66                                };
67
68                                if let Err(e) = writer.write_all(bytes.as_ref()) {
69                                    eprintln!("Unable to write to {path:?}: {e}");
70                                };
71
72                                current_file_size += bytes.len() as u64;
73                                if current_file_size > max_file_size {
74                                    // Flush the writer
75                                    let _ = writer.flush();
76                                    let file = writer.get_mut();
77
78                                    // Copy the current file to the backup
79                                    if let Err(e) = std::fs::copy(&path, &copy_path) {
80                                        log::error!("Unable to copy logs to backup file: {e}");
81                                    }
82
83                                    // Truncate the logs file
84                                    if let Err(e) = file.set_len(0) {
85                                        log::error!("Unable to truncate logs file: {e}");
86                                    }
87
88                                    current_file_size = 0;
89                                }
90                            },
91                            None => break
92                        }
93                    },
94                    _ = shutdown.wait() => break
95                }
96            }
97            let _ = writer.flush();
98            let _ = stderr.flush();
99        });
100        (LogFileWriter { sender }, writer_thread)
101    }
102
103    pub fn build_logger(
104        target: Target,
105        max_file_size: Option<u64>,
106        shutdown: ShutdownSignal,
107    ) -> Option<JoinHandle<()>> {
108        let (target, handle) = match target {
109            Target::File(path) => {
110                let (writer, handle) = start_log_writer_thread(path, max_file_size, shutdown);
111                (env_logger::Target::Pipe(Box::new(writer)), Some(handle))
112            }
113            Target::Stderr => (env_logger::Target::Stderr, None),
114        };
115
116        let mut env_builder = env_logger::Builder::new();
117        env_builder
118            .parse_env(env_logger::Env::new().filter_or("FAUCET_LOG", "info"))
119            .target(target)
120            .init();
121
122        handle
123    }
124}
125
126#[derive(Clone, Copy)]
127pub struct StateData {
128    pub uuid: uuid::Uuid,
129    pub ip: IpAddr,
130    pub worker_route: Option<&'static str>,
131    pub worker_id: usize,
132    pub target: &'static str,
133}
134
135trait StateLogData: Send + Sync + 'static {
136    fn get_state_data(&self) -> StateData;
137}
138
139impl StateLogData for State {
140    #[inline(always)]
141    fn get_state_data(&self) -> StateData {
142        let uuid = self.uuid;
143        let ip = self.remote_addr;
144        let worker_id = self.client.config.worker_id;
145        let worker_route = self.client.config.worker_route;
146        let target = self.client.config.target;
147        StateData {
148            uuid,
149            ip,
150            worker_id,
151            worker_route,
152            target,
153        }
154    }
155}
156
/// An optional value that formats a missing entry as a dash, as is customary
/// in access logs, instead of printing `None`.
#[derive(Eq, PartialEq)]
pub enum LogOption<T> {
    /// No value; rendered as `-` by `Display` and as `"-"` by `Debug`.
    None,
    /// A value; rendered using the inner type's own formatting.
    Some(T),
}

impl<T> From<Option<T>> for LogOption<T> {
    /// Maps `Option::None` to `LogOption::None` and `Some(v)` to `LogOption::Some(v)`.
    fn from(opt: Option<T>) -> Self {
        opt.map_or(Self::None, Self::Some)
    }
}

impl<T: std::fmt::Display> std::fmt::Display for LogOption<T> {
    /// Writes the inner value, or a bare `-` when there is none.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Some(value) = self {
            write!(f, "{value}")
        } else {
            f.write_str("-")
        }
    }
}

impl<T: std::fmt::Debug> std::fmt::Debug for LogOption<T> {
    /// Writes the inner value's debug form, or a quoted `"-"` when there is none.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Some(value) = self {
            write!(f, "{value:?}")
        } else {
            f.write_str("\"-\"")
        }
    }
}
195
196pub struct LogData {
197    pub state_data: StateData,
198    pub method: Method,
199    pub path: Uri,
200    pub version: Version,
201    pub status: i16,
202    pub user_agent: LogOption<HeaderValue>,
203    pub elapsed: i64,
204}
205
206impl LogData {
207    fn log(&self) {
208        log::info!(
209            target: self.state_data.target,
210            r#"{ip} "{method} {route}{path} {version:?}" {status} {user_agent:?} {elapsed}"#,
211            route = self.state_data.worker_route.map(|r| r.trim_end_matches('/')).unwrap_or_default(),
212            ip = self.state_data.ip,
213            method = self.method,
214            path = self.path,
215            version = self.version,
216            status = self.status,
217            user_agent = self.user_agent,
218            elapsed = self.elapsed,
219        );
220    }
221}
222
223#[inline(always)]
224async fn capture_log_data<Body, ResBody, Error, State: StateLogData>(
225    inner: &impl Service<Request<Body>, Response = Response<ResBody>, Error = Error>,
226    req: Request<Body>,
227) -> Result<(Response<ResBody>, LogData), Error> {
228    let start = time::Instant::now();
229
230    // Extract request info for logging
231    let state = req.extensions().get::<State>().expect("State not found");
232    let state_data = state.get_state_data();
233    let method = req.method().clone();
234    let path = req.uri().clone();
235    let version = req.version();
236    let user_agent: LogOption<_> = req.headers().get(hyper::header::USER_AGENT).cloned().into();
237
238    // Make the request
239    let res = inner.call(req, None).await?;
240
241    // Extract response info for logging
242    let status = res.status().as_u16() as i16;
243    let elapsed = start.elapsed().as_millis() as i64;
244
245    let log_data = LogData {
246        state_data,
247        method,
248        path,
249        version,
250        status,
251        user_agent,
252        elapsed,
253    };
254
255    Ok((res, log_data))
256}
257
258pub(super) struct LogService<S> {
259    inner: S,
260    telemetry: Option<TelemetrySender>,
261}
262
263impl<S, Body, ResBody> Service<Request<Body>> for LogService<S>
264where
265    S: Service<Request<Body>, Response = Response<ResBody>> + Send + Sync,
266{
267    type Error = S::Error;
268    type Response = Response<ResBody>;
269
270    async fn call(
271        &self,
272        req: Request<Body>,
273        _: Option<IpAddr>,
274    ) -> Result<Self::Response, Self::Error> {
275        let (res, log_data) = capture_log_data::<_, _, _, State>(&self.inner, req).await?;
276
277        log_data.log();
278        if let Some(telemetry) = &self.telemetry {
279            telemetry.send_http_event(log_data);
280        }
281
282        Ok(res)
283    }
284}
285
286pub(super) struct LogLayer {
287    pub telemetry: Option<TelemetrySender>,
288}
289
290impl<S> Layer<S> for LogLayer {
291    type Service = LogService<S>;
292    fn layer(&self, inner: S) -> Self::Service {
293        LogService {
294            inner,
295            telemetry: self.telemetry.clone(),
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use hyper::StatusCode;

    use super::*;

    /// `capture_log_data` copies request, response and state details into `LogData`.
    #[tokio::test]
    async fn log_capture() {
        #[derive(Clone)]
        struct MockState;

        impl StateLogData for MockState {
            fn get_state_data(&self) -> StateData {
                StateData {
                    uuid: uuid::Uuid::now_v7(),
                    ip: IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
                    worker_route: None,
                    worker_id: 1,
                    target: "test",
                }
            }
        }

        // Inner service that sleeps briefly so the measured time is non-zero.
        struct Svc;

        impl Service<Request<()>> for Svc {
            type Response = Response<()>;
            type Error = ();
            async fn call(
                &self,
                _: Request<()>,
                _: Option<IpAddr>,
            ) -> Result<Self::Response, Self::Error> {
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                let response = Response::builder().status(StatusCode::OK).body(()).unwrap();
                Ok(response)
            }
        }

        let request = Request::builder()
            .method(Method::GET)
            .uri("https://example.com/")
            .version(Version::HTTP_11)
            .header(hyper::header::USER_AGENT, "test")
            .extension(MockState)
            .body(())
            .unwrap();

        let (_, captured) = capture_log_data::<_, _, _, MockState>(&Svc, request)
            .await
            .unwrap();

        let localhost = IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1));
        assert_eq!(captured.state_data.ip, localhost);
        assert_eq!(captured.state_data.target, "test");
        assert_eq!(captured.method, Method::GET);
        assert_eq!(captured.path, "https://example.com/");
        assert_eq!(captured.version, Version::HTTP_11);
        assert_eq!(captured.status, 200);
        assert_eq!(
            captured.user_agent,
            LogOption::Some(HeaderValue::from_static("test"))
        );
        assert!(captured.elapsed > 0);
    }

    /// `Display` prints a dash for a missing value and the value itself otherwise.
    #[test]
    fn log_option_display() {
        let missing = LogOption::<u8>::None;
        let present = LogOption::Some(1);
        assert_eq!(missing.to_string(), "-");
        assert_eq!(present.to_string(), "1");
    }

    /// `Debug` prints a quoted dash for a missing value and the inner debug form otherwise.
    #[test]
    fn log_option_debug() {
        let missing = LogOption::<u8>::None;
        let present = LogOption::Some(1);
        assert_eq!(format!("{:?}", missing), r#""-""#);
        assert_eq!(format!("{:?}", present), "1");
    }

    /// Conversion from `Option` preserves the variant.
    #[test]
    fn log_option_from_option() {
        assert_eq!(LogOption::<u8>::from(None), LogOption::None);
        assert_eq!(LogOption::from(Some(1)), LogOption::Some(1));
    }

    /// `LogData::log` emits the expected access-log line through the global logger.
    #[test]
    fn log_data_log() {
        use std::io::Write;
        use std::sync::{Arc, Mutex};

        // Shared in-memory sink so the test can read back what the logger wrote.
        #[derive(Clone, Default)]
        struct Buffer(Arc<Mutex<Vec<u8>>>);

        impl Write for Buffer {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                self.0.lock().unwrap().extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        impl Buffer {
            fn contents(&self) -> Vec<u8> {
                self.0.lock().unwrap().clone()
            }
        }

        let entry = LogData {
            state_data: StateData {
                uuid: uuid::Uuid::now_v7(),
                ip: IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
                worker_route: None,
                worker_id: 1,
                target: "test",
            },
            method: Method::GET,
            path: "https://example.com/".parse().unwrap(),
            version: Version::HTTP_11,
            status: 200,
            user_agent: LogOption::Some(HeaderValue::from_static("test")),
            elapsed: 5,
        };

        let sink = Buffer::default();
        // ALWAYS USE INFO LEVEL FOR LOGGING
        env_logger::Builder::new()
            .filter_level(log::LevelFilter::Info)
            .format(|f, record| writeln!(f, "{}", record.args()))
            .target(env_logger::Target::Pipe(Box::new(sink.clone())))
            .init();

        entry.log();

        let written = String::from_utf8(sink.contents()).unwrap();

        assert_eq!(
            written.trim(),
            r#"127.0.0.1 "GET https://example.com/ HTTP/1.1" 200 "test" 5"#
        )
    }
}