use http::{Request, Response};
use pin_project_lite::pin_project;
use std::{
    fmt,
    future::Future,
    io::{self, Write},
    pin::Pin,
    sync::Arc,
    task::{self, Poll},
    time::{Instant, SystemTime, UNIX_EPOCH},
};
use tower_layer::Layer;
use tower_service::Service;
57
/// One audited request/response exchange, as handed to an [`AuditSink`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuditEvent {
    /// Wall-clock time the request entered the middleware, in whole seconds
    /// since the Unix epoch (0 if the system clock reads earlier than 1970).
    pub timestamp_secs: u64,
    /// HTTP method, e.g. `GET`.
    pub method: String,
    /// Request path, without the query string.
    pub path: String,
    /// Query string without the leading `?`, if the URI had one.
    pub query: Option<String>,
    /// Client address taken from the `X-Forwarded-For` / `X-Real-IP` request
    /// headers, if present. These headers are client-controllable unless a
    /// trusted proxy overwrites them.
    pub client_ip: Option<String>,
    /// `User-Agent` request header, if present and representable as text.
    pub user_agent: Option<String>,
    /// Response status code, or `0` when the inner service returned an error
    /// instead of a response.
    pub status: u16,
    /// Milliseconds between the middleware receiving the request and the
    /// inner response future resolving.
    pub latency_ms: u64,
}
100
101impl fmt::Display for AuditEvent {
102 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
103 let timestamp = format_timestamp(self.timestamp_secs);
104 let path_and_query = match &self.query {
105 Some(query) => format!("{}?{}", self.path, query),
106 None => self.path.clone(),
107 };
108 write!(
109 formatter,
110 "[{timestamp}] {method} {path_and_query} status={status} latency={latency_ms}ms",
111 method = self.method,
112 status = self.status,
113 latency_ms = self.latency_ms,
114 )?;
115 if let Some(ip) = &self.client_ip {
116 write!(formatter, " ip={ip}")?;
117 }
118 if let Some(ua) = &self.user_agent {
119 write!(formatter, " ua=\"{ua}\"")?;
120 }
121 Ok(())
122 }
123}
124
125fn format_timestamp(secs: u64) -> String {
126 let remaining = secs;
127 let seconds_per_minute = 60u64;
128 let seconds_per_hour = 3600u64;
129 let seconds_per_day = 86400u64;
130
131 let time_of_day = remaining % seconds_per_day;
132 let hour = time_of_day / seconds_per_hour;
133 let minute = (time_of_day % seconds_per_hour) / seconds_per_minute;
134 let second = time_of_day % seconds_per_minute;
135
136 let days_since_epoch = remaining / seconds_per_day;
137 let (year, month, day) = days_to_ymd(days_since_epoch);
138
139 format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
140}
141
/// Converts a count of whole days since the Unix epoch (1970-01-01) into a
/// proleptic-Gregorian `(year, month, day)` triple. `month` and `day` are
/// both 1-based.
///
/// The 400/100/4/1-year decomposition below assumes the leap day falls in the
/// *last* year of every 4-, 100- and 400-year block. That only holds when
/// counting starts on the first day of a year ≡ 1 (mod 400). 1970 is not such
/// a year: its first leap year, 1972, is the *third* year of its block. Counting
/// directly from 1970 therefore put Dec 31 of every leap year into the next
/// year and shifted the whole following year by one day. The count is
/// rebased onto 1601-01-01, the start of the 400-year cycle containing 1970.
///
/// `days` must not exceed `u64::MAX - 134_774` (the rebasing addition would
/// overflow). Day counts derived from a `u64` of seconds are far below that.
fn days_to_ymd(days: u64) -> (u64, u64, u64) {
    // First year of the Gregorian 400-year cycle that contains 1970.
    const CYCLE_START_YEAR: u64 = 1601;
    // Days from 1601-01-01 to 1970-01-01.
    const DAYS_FROM_CYCLE_START_TO_EPOCH: u64 = 134_774;
    const DAYS_PER_400_YEARS: u64 = 146_097;
    const DAYS_PER_100_YEARS: u64 = 36_524;
    const DAYS_PER_4_YEARS: u64 = 1_461;
    const DAYS_PER_YEAR: u64 = 365;

    let mut remaining = days + DAYS_FROM_CYCLE_START_TO_EPOCH;

    let quadricentennials = remaining / DAYS_PER_400_YEARS;
    remaining %= DAYS_PER_400_YEARS;

    // The 4th century of a cycle is one day longer (it ends on a year divisible
    // by 400), so clamp to 3 to keep that extra day inside the 4th century.
    let centurials = (remaining / DAYS_PER_100_YEARS).min(3);
    remaining -= centurials * DAYS_PER_100_YEARS;

    let quadrennials = remaining / DAYS_PER_4_YEARS;
    remaining %= DAYS_PER_4_YEARS;

    // Same reasoning: the 4th year of a 4-year block holds the leap day.
    let annuals = (remaining / DAYS_PER_YEAR).min(3);
    remaining -= annuals * DAYS_PER_YEAR;

    let year =
        CYCLE_START_YEAR + quadricentennials * 400 + centurials * 100 + quadrennials * 4 + annuals;

    let is_leap = (year % 4 == 0 && year % 100 != 0) || year % 400 == 0;
    let month_lengths = [
        31u64,
        if is_leap { 29 } else { 28 },
        31,
        30,
        31,
        30,
        31,
        31,
        30,
        31,
        30,
        31,
    ];

    // `remaining` is now the 0-based day of the year. Peel off whole months.
    let mut month = 1u64;
    let mut day_of_month = remaining;
    for length in month_lengths {
        if day_of_month < length {
            break;
        }
        day_of_month -= length;
        month += 1;
    }

    (year, month, day_of_month + 1)
}
192
193pub trait AuditSink: Send + Sync + 'static {
215 fn record(&self, event: AuditEvent);
217}
218
219pub struct StdoutSink;
232
233impl AuditSink for StdoutSink {
234 fn record(&self, event: AuditEvent) {
235 println!("{event}");
236 }
237}
238
239pub struct StderrSink;
252
253impl AuditSink for StderrSink {
254 fn record(&self, event: AuditEvent) {
255 eprintln!("{event}");
256 }
257}
258
259pub struct NullSink;
274
275impl AuditSink for NullSink {
276 fn record(&self, _event: AuditEvent) {}
277}
278
/// An [`AuditSink`] that forwards every event to a user-supplied closure.
///
/// Build it with [`CallbackSink::new`]. The `Fn(AuditEvent) + Send + Sync +
/// 'static` requirement is enforced on the `impl` blocks that need it rather
/// than duplicated on the type definition.
pub struct CallbackSink<F> {
    callback: F,
}
302
303impl<F> CallbackSink<F>
304where
305 F: Fn(AuditEvent) + Send + Sync + 'static,
306{
307 pub fn new(callback: F) -> Self {
317 Self { callback }
318 }
319}
320
321impl<F> AuditSink for CallbackSink<F>
322where
323 F: Fn(AuditEvent) + Send + Sync + 'static,
324{
325 fn record(&self, event: AuditEvent) {
326 (self.callback)(event);
327 }
328}
329
330#[derive(Clone)]
346pub struct AuditLayer<Sink = StdoutSink> {
347 sink: Arc<Sink>,
348}
349
350impl AuditLayer<StdoutSink> {
351 pub fn stdout() -> Self {
364 Self {
365 sink: Arc::new(StdoutSink),
366 }
367 }
368}
369
370impl AuditLayer<StderrSink> {
371 pub fn stderr() -> Self {
384 Self {
385 sink: Arc::new(StderrSink),
386 }
387 }
388}
389
390impl<Sink: AuditSink> AuditLayer<Sink> {
391 pub fn new(sink: Sink) -> Self {
404 Self {
405 sink: Arc::new(sink),
406 }
407 }
408
409 pub fn with_arc(sink: Arc<Sink>) -> Self {
427 Self { sink }
428 }
429}
430
431impl<S, Sink: AuditSink> Layer<S> for AuditLayer<Sink> {
432 type Service = AuditService<S, Sink>;
433
434 fn layer(&self, inner: S) -> Self::Service {
435 AuditService {
436 inner,
437 sink: Arc::clone(&self.sink),
438 }
439 }
440}
441
/// Service produced by [`AuditLayer`]. It forwards each request to `inner` and
/// reports it to `sink` once the response future completes.
pub struct AuditService<S, Sink> {
    inner: S,
    sink: Arc<Sink>,
}

// Implemented by hand: `#[derive(Clone)]` would also demand `Sink: Clone`,
// although only the `Arc` handle is cloned. Only the wrapped service itself
// needs to be `Clone`.
impl<S: Clone, Sink> Clone for AuditService<S, Sink> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            sink: Arc::clone(&self.sink),
        }
    }
}
452
453impl<S, Sink, ReqB> Service<Request<ReqB>> for AuditService<S, Sink>
454where
455 S: Service<Request<ReqB>, Response = Response<axum::body::Body>>,
456 Sink: AuditSink,
457 ReqB: Send + 'static,
458{
459 type Response = Response<axum::body::Body>;
460 type Error = S::Error;
461 type Future = AuditFuture<S::Future, Sink>;
462
463 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
464 self.inner.poll_ready(cx)
465 }
466
467 fn call(&mut self, req: Request<ReqB>) -> Self::Future {
468 let start = Instant::now();
469 let timestamp_secs = SystemTime::now()
470 .duration_since(UNIX_EPOCH)
471 .unwrap_or_default()
472 .as_secs();
473
474 let method = req.method().to_string();
475 let path = req.uri().path().to_owned();
476 let query = req.uri().query().map(str::to_owned);
477
478 let client_ip = req
479 .headers()
480 .get("x-forwarded-for")
481 .or_else(|| req.headers().get("x-real-ip"))
482 .and_then(|value| value.to_str().ok())
483 .and_then(|value| value.split(',').next())
484 .map(str::trim)
485 .map(str::to_owned);
486
487 let user_agent = req
488 .headers()
489 .get(http::header::USER_AGENT)
490 .and_then(|value| value.to_str().ok())
491 .map(str::to_owned);
492
493 AuditFuture {
494 inner: self.inner.call(req),
495 sink: Arc::clone(&self.sink),
496 meta: Some(EventMeta {
497 timestamp_secs,
498 method,
499 path,
500 query,
501 client_ip,
502 user_agent,
503 start,
504 }),
505 }
506 }
507}
508
/// Request details captured when a request enters `AuditService::call`. They
/// are held by the response future until the inner result is ready, then
/// moved into an [`AuditEvent`].
struct EventMeta {
    /// Monotonic starting point for `latency_ms`.
    start: Instant,
    /// Wall-clock seconds since the Unix epoch at call time.
    timestamp_secs: u64,
    method: String,
    path: String,
    query: Option<String>,
    client_ip: Option<String>,
    user_agent: Option<String>,
}
518
pin_project! {
    /// Response future returned by [`AuditService`]. It resolves to the inner
    /// service's result unchanged and records an [`AuditEvent`] when it does.
    pub struct AuditFuture<F, Sink> {
        // The inner service's response future; pinned in place so it can be
        // polled through the projection.
        #[pin]
        inner: F,
        // Shared destination for the event.
        sink: Arc<Sink>,
        // Request details captured in `call`; taken (left as `None`) when the
        // event is recorded.
        meta: Option<EventMeta>,
    }
}
528
529impl<F, Sink, E> Future for AuditFuture<F, Sink>
530where
531 F: Future<Output = Result<Response<axum::body::Body>, E>>,
532 Sink: AuditSink,
533{
534 type Output = Result<Response<axum::body::Body>, E>;
535
536 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
537 let this = self.project();
538 match this.inner.poll(cx) {
539 Poll::Pending => Poll::Pending,
540 Poll::Ready(result) => {
541 if let Some(meta) = this.meta.take() {
542 let latency_ms = meta.start.elapsed().as_millis() as u64;
543 let status = result
544 .as_ref()
545 .map(|resp| resp.status().as_u16())
546 .unwrap_or(0);
547
548 this.sink.record(AuditEvent {
549 timestamp_secs: meta.timestamp_secs,
550 method: meta.method,
551 path: meta.path,
552 query: meta.query,
553 client_ip: meta.client_ip,
554 user_agent: meta.user_agent,
555 status,
556 latency_ms,
557 });
558 }
559 Poll::Ready(result)
560 }
561 }
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, body::Body, routing::get};
    use http::{Request, StatusCode};
    use std::sync::Mutex;
    use std::time::Duration;
    use tower::ServiceExt;

    /// Shared buffer that test sinks append recorded events to.
    type Events = Arc<Mutex<Vec<AuditEvent>>>;

    /// Sink that keeps every recorded event for later inspection. It derives
    /// `Clone` so it also satisfies any `Sink: Clone` bound.
    #[derive(Clone)]
    struct TestSink {
        events: Events,
    }

    impl AuditSink for TestSink {
        fn record(&self, event: AuditEvent) {
            self.events.lock().unwrap().push(event);
        }
    }

    fn new_events() -> Events {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Router with one `GET /api/items` route that always answers `status`,
    /// audited into `events`.
    fn build_app(events: Events, status: StatusCode) -> Router {
        Router::new()
            .route("/api/items", get(move || async move { status }))
            .layer(AuditLayer::new(TestSink { events }))
    }

    fn empty_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    async fn send(app: Router, req: Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    #[tokio::test]
    async fn records_method_path_and_status() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::OK);

        send(app, empty_request("/api/items")).await;

        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        let event = &captured[0];
        assert_eq!(event.method, "GET");
        assert_eq!(event.path, "/api/items");
        assert_eq!(event.status, 200);
    }

    #[tokio::test]
    async fn records_non_success_status() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::INTERNAL_SERVER_ERROR);

        let resp = send(app, empty_request("/api/items")).await;

        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].status, 500);
    }

    #[tokio::test]
    async fn records_latency_of_slow_handler() {
        let events = new_events();
        let app: Router = Router::new()
            .route(
                "/slow",
                get(|| async {
                    // A blocking sleep is acceptable in a test: the handler only
                    // needs to take measurable wall-clock time while audited.
                    std::thread::sleep(Duration::from_millis(20));
                    "done"
                }),
            )
            .layer(AuditLayer::new(TestSink {
                events: Arc::clone(&events),
            }));

        send(app, empty_request("/slow")).await;

        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        let latency_ms = captured[0].latency_ms;
        assert!(latency_ms >= 20, "expected latency >= 20ms, got {latency_ms}ms");
    }

    #[tokio::test]
    async fn null_sink_does_not_panic() {
        let app: Router = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(AuditLayer::new(NullSink));

        let resp = send(app, empty_request("/")).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn callback_sink_receives_event() {
        let events = new_events();
        let captured = Arc::clone(&events);

        let app: Router = Router::new()
            .route("/ping", get(|| async { "pong" }))
            .layer(AuditLayer::new(CallbackSink::new(move |event| {
                captured.lock().unwrap().push(event);
            })));

        send(app, empty_request("/ping")).await;

        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].path, "/ping");
        assert_eq!(captured[0].status, 200);
    }

    #[tokio::test]
    async fn captures_client_ip_from_x_forwarded_for() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::OK);

        send(
            app,
            Request::builder()
                .uri("/api/items")
                .header("x-forwarded-for", "1.2.3.4, 5.6.7.8")
                .body(Body::empty())
                .unwrap(),
        )
        .await;

        let captured = events.lock().unwrap();
        assert_eq!(captured[0].client_ip.as_deref(), Some("1.2.3.4"));
    }

    #[tokio::test]
    async fn falls_back_to_x_real_ip() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::OK);

        send(
            app,
            Request::builder()
                .uri("/api/items")
                .header("x-real-ip", "9.9.9.9")
                .body(Body::empty())
                .unwrap(),
        )
        .await;

        let captured = events.lock().unwrap();
        assert_eq!(captured[0].client_ip.as_deref(), Some("9.9.9.9"));
    }

    #[tokio::test]
    async fn absent_optional_fields_are_none() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::OK);

        send(app, empty_request("/api/items")).await;

        let captured = events.lock().unwrap();
        let event = &captured[0];
        assert_eq!(event.query, None);
        assert_eq!(event.client_ip, None);
        assert_eq!(event.user_agent, None);
    }

    #[tokio::test]
    async fn captures_user_agent() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::OK);

        send(
            app,
            Request::builder()
                .uri("/api/items")
                .header("user-agent", "curl/7.68.0")
                .body(Body::empty())
                .unwrap(),
        )
        .await;

        let captured = events.lock().unwrap();
        assert_eq!(captured[0].user_agent.as_deref(), Some("curl/7.68.0"));
    }

    #[tokio::test]
    async fn captures_query_string() {
        let events = new_events();
        let app = build_app(Arc::clone(&events), StatusCode::OK);

        send(app, empty_request("/api/items?page=2&limit=10")).await;

        let captured = events.lock().unwrap();
        assert_eq!(captured[0].query.as_deref(), Some("page=2&limit=10"));
    }

    #[test]
    fn display_formats_correctly() {
        let event = AuditEvent {
            timestamp_secs: 0,
            method: "GET".to_owned(),
            path: "/api".to_owned(),
            query: None,
            client_ip: Some("1.2.3.4".to_owned()),
            user_agent: Some("curl/7.68.0".to_owned()),
            status: 200,
            latency_ms: 5,
        };
        assert_eq!(
            event.to_string(),
            "[1970-01-01T00:00:00Z] GET /api status=200 latency=5ms ip=1.2.3.4 ua=\"curl/7.68.0\""
        );
    }

    #[test]
    fn display_appends_query_and_omits_missing_fields() {
        let event = AuditEvent {
            timestamp_secs: 3661,
            method: "POST".to_owned(),
            path: "/api".to_owned(),
            query: Some("page=2".to_owned()),
            client_ip: None,
            user_agent: None,
            status: 201,
            latency_ms: 0,
        };
        assert_eq!(
            event.to_string(),
            "[1970-01-01T01:01:01Z] POST /api?page=2 status=201 latency=0ms"
        );
    }
}