1use hyper::{http::HeaderValue, Method, Request, Response, Uri, Version};
2
3use super::onion::{Layer, Service};
4use crate::{server::service::State, telemetry::TelemetrySender};
5use std::{net::IpAddr, time};
6
7pub mod logger {
8 use std::{io::BufWriter, io::Write, path::PathBuf};
9
10 use hyper::body::Bytes;
11 use tokio::task::JoinHandle;
12
13 use crate::shutdown::ShutdownSignal;
14
15 pub enum Target {
16 Stderr,
17 File(PathBuf),
18 }
19
20 struct LogFileWriter {
21 sender: tokio::sync::mpsc::Sender<Bytes>,
22 }
23
24 impl std::io::Write for LogFileWriter {
25 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
26 let _ = self.sender.try_send(Bytes::copy_from_slice(buf));
27 Ok(buf.len())
28 }
29 fn flush(&mut self) -> std::io::Result<()> {
30 Ok(())
31 }
32 }
33
34 fn start_log_writer_thread(
35 path: PathBuf,
36 max_file_size: Option<u64>,
37 shutdown: ShutdownSignal,
38 ) -> (LogFileWriter, JoinHandle<()>) {
39 let max_file_size = max_file_size.unwrap_or(u64::MAX);
40 let mut current_file_size = match std::fs::metadata(&path) {
41 Ok(md) => md.len(),
42 Err(_) => 0,
43 };
44 let file = std::fs::File::options()
45 .create(true)
46 .append(true)
47 .truncate(false)
48 .open(&path)
49 .expect("Unable to open or create log file");
50
51 let mut copy_path = path.clone();
53 copy_path.as_mut_os_string().push(".bak");
54
55 let mut writer = BufWriter::new(file);
56 let mut stderr = BufWriter::new(std::io::stderr());
57 let (sender, mut receiver) = tokio::sync::mpsc::channel::<Bytes>(1000);
58 let writer_thread = tokio::task::spawn(async move {
59 loop {
60 tokio::select! {
61 bytes = receiver.recv() => {
62 match bytes {
63 Some(bytes) => {
64 if let Err(e) = stderr.write_all(bytes.as_ref()) {
65 eprintln!("Unable to write to stderr: {e}");
66 };
67
68 if let Err(e) = writer.write_all(bytes.as_ref()) {
69 eprintln!("Unable to write to {path:?}: {e}");
70 };
71
72 current_file_size += bytes.len() as u64;
73 if current_file_size > max_file_size {
74 let _ = writer.flush();
76 let file = writer.get_mut();
77
78 if let Err(e) = std::fs::copy(&path, ©_path) {
80 log::error!("Unable to copy logs to backup file: {e}");
81 }
82
83 if let Err(e) = file.set_len(0) {
85 log::error!("Unable to truncate logs file: {e}");
86 }
87
88 current_file_size = 0;
89 }
90 },
91 None => break
92 }
93 },
94 _ = shutdown.wait() => break
95 }
96 }
97 let _ = writer.flush();
98 let _ = stderr.flush();
99 });
100 (LogFileWriter { sender }, writer_thread)
101 }
102
103 pub fn build_logger(
104 target: Target,
105 max_file_size: Option<u64>,
106 shutdown: ShutdownSignal,
107 ) -> Option<JoinHandle<()>> {
108 let (target, handle) = match target {
109 Target::File(path) => {
110 let (writer, handle) = start_log_writer_thread(path, max_file_size, shutdown);
111 (env_logger::Target::Pipe(Box::new(writer)), Some(handle))
112 }
113 Target::Stderr => (env_logger::Target::Stderr, None),
114 };
115
116 let mut env_builder = env_logger::Builder::new();
117 env_builder
118 .parse_env(env_logger::Env::new().filter_or("FAUCET_LOG", "info"))
119 .target(target)
120 .init();
121
122 handle
123 }
124}
125
126#[derive(Clone, Copy)]
127pub struct StateData {
128 pub uuid: uuid::Uuid,
129 pub ip: IpAddr,
130 pub worker_route: Option<&'static str>,
131 pub worker_id: usize,
132 pub target: &'static str,
133}
134
135trait StateLogData: Send + Sync + 'static {
136 fn get_state_data(&self) -> StateData;
137}
138
139impl StateLogData for State {
140 #[inline(always)]
141 fn get_state_data(&self) -> StateData {
142 let uuid = self.uuid;
143 let ip = self.remote_addr;
144 let worker_id = self.client.config.worker_id;
145 let worker_route = self.client.config.worker_route;
146 let target = self.client.config.target;
147 StateData {
148 uuid,
149 ip,
150 worker_id,
151 worker_route,
152 target,
153 }
154 }
155}
156
/// `Option` look-alike whose formatting suits access-log fields: a missing value is
/// rendered as `-` instead of `None`.
#[derive(PartialEq, Eq)]
pub enum LogOption<T> {
    /// No value; rendered as a dash.
    None,
    /// A present value, rendered with the inner type's own formatting.
    Some(T),
}
162
163impl<T> From<Option<T>> for LogOption<T> {
164 fn from(opt: Option<T>) -> Self {
165 match opt {
166 None => LogOption::None,
167 Some(v) => LogOption::Some(v),
168 }
169 }
170}
171
172impl<T> std::fmt::Display for LogOption<T>
173where
174 T: std::fmt::Display,
175{
176 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
177 match self {
178 LogOption::None => write!(f, "-"),
179 LogOption::Some(v) => write!(f, "{}", v),
180 }
181 }
182}
183
184impl<T> std::fmt::Debug for LogOption<T>
185where
186 T: std::fmt::Debug,
187{
188 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
189 match self {
190 LogOption::None => write!(f, r#""-""#),
191 LogOption::Some(v) => write!(f, "{:?}", v),
192 }
193 }
194}
195
196pub struct LogData {
197 pub state_data: StateData,
198 pub method: Method,
199 pub path: Uri,
200 pub version: Version,
201 pub status: i16,
202 pub user_agent: LogOption<HeaderValue>,
203 pub elapsed: i64,
204}
205
206impl LogData {
207 fn log(&self) {
208 log::info!(
209 target: self.state_data.target,
210 r#"{ip} "{method} {route}{path} {version:?}" {status} {user_agent:?} {elapsed}"#,
211 route = self.state_data.worker_route.map(|r| r.trim_end_matches('/')).unwrap_or_default(),
212 ip = self.state_data.ip,
213 method = self.method,
214 path = self.path,
215 version = self.version,
216 status = self.status,
217 user_agent = self.user_agent,
218 elapsed = self.elapsed,
219 );
220 }
221}
222
223#[inline(always)]
224async fn capture_log_data<Body, ResBody, Error, State: StateLogData>(
225 inner: &impl Service<Request<Body>, Response = Response<ResBody>, Error = Error>,
226 req: Request<Body>,
227) -> Result<(Response<ResBody>, LogData), Error> {
228 let start = time::Instant::now();
229
230 let state = req.extensions().get::<State>().expect("State not found");
232 let state_data = state.get_state_data();
233 let method = req.method().clone();
234 let path = req.uri().clone();
235 let version = req.version();
236 let user_agent: LogOption<_> = req.headers().get(hyper::header::USER_AGENT).cloned().into();
237
238 let res = inner.call(req, None).await?;
240
241 let status = res.status().as_u16() as i16;
243 let elapsed = start.elapsed().as_millis() as i64;
244
245 let log_data = LogData {
246 state_data,
247 method,
248 path,
249 version,
250 status,
251 user_agent,
252 elapsed,
253 };
254
255 Ok((res, log_data))
256}
257
258pub(super) struct LogService<S> {
259 inner: S,
260 telemetry: Option<TelemetrySender>,
261}
262
263impl<S, Body, ResBody> Service<Request<Body>> for LogService<S>
264where
265 S: Service<Request<Body>, Response = Response<ResBody>> + Send + Sync,
266{
267 type Error = S::Error;
268 type Response = Response<ResBody>;
269
270 async fn call(
271 &self,
272 req: Request<Body>,
273 _: Option<IpAddr>,
274 ) -> Result<Self::Response, Self::Error> {
275 let (res, log_data) = capture_log_data::<_, _, _, State>(&self.inner, req).await?;
276
277 log_data.log();
278 if let Some(telemetry) = &self.telemetry {
279 telemetry.send_http_event(log_data);
280 }
281
282 Ok(res)
283 }
284}
285
286pub(super) struct LogLayer {
287 pub telemetry: Option<TelemetrySender>,
288}
289
290impl<S> Layer<S> for LogLayer {
291 type Service = LogService<S>;
292 fn layer(&self, inner: S) -> Self::Service {
293 LogService {
294 inner,
295 telemetry: self.telemetry.clone(),
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use hyper::StatusCode;

    use super::*;

    #[tokio::test]
    async fn log_capture() {
        #[derive(Clone)]
        struct MockState;

        impl StateLogData for MockState {
            fn get_state_data(&self) -> StateData {
                StateData {
                    uuid: uuid::Uuid::now_v7(),
                    ip: IpAddr::V4([127, 0, 0, 1].into()),
                    target: "test",
                    worker_id: 1,
                    worker_route: None,
                }
            }
        }

        struct Svc;

        impl Service<Request<()>> for Svc {
            type Response = Response<()>;
            type Error = ();
            async fn call(
                &self,
                _: Request<()>,
                _: Option<IpAddr>,
            ) -> Result<Self::Response, Self::Error> {
                // Sleep so the measured `elapsed` is strictly positive.
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                Ok(Response::builder().status(StatusCode::OK).body(()).unwrap())
            }
        }

        let req = Request::builder()
            .method(Method::GET)
            .uri("https://example.com/")
            .extension(MockState)
            .version(Version::HTTP_11)
            .header(hyper::header::USER_AGENT, "test")
            .body(())
            .unwrap();

        let (_, log_data) = capture_log_data::<_, _, _, MockState>(&Svc, req)
            .await
            .unwrap();

        assert_eq!(log_data.state_data.ip, IpAddr::V4([127, 0, 0, 1].into()));
        assert_eq!(log_data.method, Method::GET);
        assert_eq!(log_data.path, "https://example.com/");
        assert_eq!(log_data.version, Version::HTTP_11);
        assert_eq!(log_data.status, 200);
        assert_eq!(
            log_data.user_agent,
            LogOption::Some(HeaderValue::from_static("test"))
        );
        assert!(log_data.elapsed > 0);
        assert_eq!(log_data.state_data.target, "test");
    }

    #[test]
    fn log_option_display() {
        assert_eq!(LogOption::<u8>::None.to_string(), "-");
        assert_eq!(LogOption::Some(1).to_string(), "1");
    }

    #[test]
    fn log_option_debug() {
        assert_eq!(format!("{:?}", LogOption::<u8>::None), r#""-""#);
        assert_eq!(format!("{:?}", LogOption::Some(1)), "1");
    }

    #[test]
    fn log_option_from_option() {
        assert_eq!(LogOption::<u8>::from(None), LogOption::None);
        assert_eq!(LogOption::from(Some(1)), LogOption::Some(1));
    }

    #[test]
    fn log_data_log() {
        use std::io::Write;
        use std::sync::{Arc, Mutex};

        /// Shared in-memory sink so the test can read back what the logger wrote.
        #[derive(Clone)]
        struct Buffer(Arc<Mutex<Vec<u8>>>);

        impl Write for Buffer {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                self.0.lock().unwrap().write(buf)
            }
            fn flush(&mut self) -> std::io::Result<()> {
                self.0.lock().unwrap().flush()
            }
        }

        impl Buffer {
            fn clone_buf(&self) -> Vec<u8> {
                self.0.lock().unwrap().clone()
            }
        }

        let log_data = LogData {
            state_data: StateData {
                uuid: uuid::Uuid::now_v7(),
                target: "test",
                ip: IpAddr::V4([127, 0, 0, 1].into()),
                worker_route: None,
                worker_id: 1,
            },
            method: Method::GET,
            path: "https://example.com/".parse().unwrap(),
            version: Version::HTTP_11,
            status: 200,
            user_agent: LogOption::Some(HeaderValue::from_static("test")),
            elapsed: 5,
        };

        let buf = Buffer(Arc::new(Mutex::new(Vec::new())));
        let mut logger = env_logger::Builder::new();
        // `init` installs a process-global logger. Only accept records for the "test"
        // target so info records from other concurrently running tests do not end up in
        // `buf` and break the exact-match assertion below.
        logger.filter_module("test", log::LevelFilter::Info);
        logger.format(|f, record| writeln!(f, "{}", record.args()));
        logger.target(env_logger::Target::Pipe(Box::new(buf.clone())));
        logger.init();

        log_data.log();

        let log = String::from_utf8(buf.clone_buf()).unwrap();

        assert_eq!(
            log.trim(),
            r#"127.0.0.1 "GET https://example.com/ HTTP/1.1" 200 "test" 5"#
        )
    }
}