use std::sync::Arc;
use serde::Serialize;
use serde_json::{Map, Value};
use tracing::Level;
use super::event::LogEvent;
use crate::error::Result;
use crate::extract::{FromRequest, RequestContext};
const DEFAULT_CONTEXT: &str = "app";
const REQUEST_ID_HEADER: &str = "x-request-id";
#[derive(Clone)]
pub struct Logger {
context: Arc<str>,
base: Arc<LogFields>,
}
/// Immutable singly linked list of base fields, newest entry first.
///
/// Each `Logger::with_field` call wraps the previous chain in a new `Field`
/// node, so loggers derived from a common ancestor share its nodes.
enum LogFields {
    /// Root of every chain; contributes no fields.
    Empty,
    /// One key/value pair layered on top of `parent`.
    Field {
        key: &'static str,
        value: Value,
        parent: Arc<LogFields>,
    },
}
impl Logger {
pub fn new(context: impl AsRef<str>) -> Self {
Self {
context: Arc::from(context.as_ref()),
base: Arc::new(LogFields::Empty),
}
}
pub(crate) fn framework(context: &'static str) -> Self {
Self::new(context)
}
pub fn context(&self) -> &str {
&self.context
}
pub fn for_context(&self, context: impl AsRef<str>) -> Logger {
Logger {
context: Arc::from(context.as_ref()),
base: self.base.clone(),
}
}
pub fn with_field<T: Serialize>(&self, key: &'static str, value: T) -> Logger {
if let Ok(value) = serde_json::to_value(value) {
return Logger {
context: self.context.clone(),
base: Arc::new(LogFields::Field {
parent: self.base.clone(),
key,
value,
}),
};
}
self.clone()
}
fn event(&self, level: Level, message: impl Into<String>) -> LogEvent {
let mut fields = Map::new();
populate_fields(&self.base, &mut fields);
LogEvent {
level,
context: self.context.clone(),
message: message.into(),
fields,
error: None,
}
}
pub fn trace(&self, message: impl Into<String>) -> LogEvent {
self.event(Level::TRACE, message)
}
pub fn debug(&self, message: impl Into<String>) -> LogEvent {
self.event(Level::DEBUG, message)
}
pub fn info(&self, message: impl Into<String>) -> LogEvent {
self.event(Level::INFO, message)
}
pub fn warn(&self, message: impl Into<String>) -> LogEvent {
self.event(Level::WARN, message)
}
pub fn error(&self, message: impl Into<String>) -> LogEvent {
self.event(Level::ERROR, message)
}
pub fn span(&self, name: impl Into<String>) -> super::LogSpan {
let mut fields = Map::new();
populate_fields(&self.base, &mut fields);
super::LogSpan::new(self.context.clone(), name, fields)
}
pub fn instrument(&self, name: impl Into<String>) -> super::LogSpan {
let mut fields = Map::new();
populate_fields(&self.base, &mut fields);
super::LogSpan::new(self.context.clone(), name, fields)
}
}
/// Flattens a base-field chain into `out`, oldest field first.
///
/// Because fields are inserted oldest-first, a newer field overwrites an older
/// one with the same key. The chain is walked iteratively, so the stack depth
/// does not grow with the number of chained `with_field` calls.
fn populate_fields(fields: &Arc<LogFields>, out: &mut Map<String, Value>) {
    // Collect the entries newest-first by following the parent links...
    let mut chain: Vec<(&'static str, &Value)> = Vec::new();
    let mut node: &LogFields = &**fields;
    while let LogFields::Field { parent, key, value } = node {
        chain.push((*key, value));
        node = &**parent;
    }
    // ...then replay them oldest-first, preserving the original insertion order.
    for (key, value) in chain.into_iter().rev() {
        out.insert(key.to_owned(), value.clone());
    }
}
impl FromRequest for Logger {
    /// Builds a request-scoped logger labelled with `DEFAULT_CONTEXT`.
    ///
    /// Base fields, oldest first: `request_id` (only when the `x-request-id`
    /// header is present and `to_str` succeeds on it), `method`, and `path`.
    /// The returned future resolves immediately and always yields `Ok`.
    fn from_request(
        ctx: &RequestContext,
    ) -> impl std::future::Future<Output = Result<Self>> + Send {
        let request_id = ctx
            .headers()
            .get(REQUEST_ID_HEADER)
            .and_then(|header| header.to_str().ok())
            .map(|id| ("request_id", Value::String(id.to_owned())));
        let method = ("method", Value::String(ctx.method().to_string()));
        let path = ("path", Value::String(ctx.uri().path().to_owned()));

        // Each entry becomes the new head of the chain, so later entries sit
        // closer to the top.
        let mut chain = Arc::new(LogFields::Empty);
        for (key, value) in request_id.into_iter().chain([method, path]) {
            chain = Arc::new(LogFields::Field {
                parent: chain,
                key,
                value,
            });
        }

        let logger = Logger {
            context: Arc::from(DEFAULT_CONTEXT),
            base: chain,
        };
        async move { Ok(logger) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::{Arc, Mutex};
    use super::super::format::{JsonFormat, TorkFormat};
    use crate::extract::FromRequest;
    use crate::{box_body, PathParams, RequestContext, StateMap};
    use bytes::Bytes;
    use http_body_util::Full;
    use serde::ser::Error as _;
    use serde::Serializer;
    use tracing_subscriber::fmt::MakeWriter;
    use tracing_subscriber::prelude::*;

    /// In-memory sink shared between the subscriber and the test body.
    #[derive(Clone)]
    struct BufWriter(Arc<Mutex<Vec<u8>>>);

    /// A value whose `Serialize` impl always fails.
    struct BadSerialize;

    impl serde::Serialize for BadSerialize {
        fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            Err(S::Error::custom("nope"))
        }
    }

    impl Write for BufWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    impl<'a> MakeWriter<'a> for BufWriter {
        type Writer = BufWriter;

        fn make_writer(&'a self) -> Self::Writer {
            self.clone()
        }
    }

    /// Runs `emit` with a JSON-formatting subscriber installed as the default
    /// for the current thread and returns everything that subscriber wrote.
    fn capture_json(emit: impl FnOnce()) -> String {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let layer = tracing_subscriber::fmt::layer()
            .event_format(TorkFormat::Json(JsonFormat {
                service_name: "svc".to_owned(),
            }))
            .with_writer(BufWriter(buffer.clone()));
        let subscriber = tracing_subscriber::registry().with(layer);
        tracing::subscriber::with_default(subscriber, emit);
        let bytes = buffer.lock().unwrap().clone();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn emits_context_message_and_fields() {
        let output = capture_json(|| {
            Logger::new("PaymentService")
                .with_field("tenant", "acme")
                .info("Charging user")
                .field("user_id", 42)
                .emit();
        });
        assert!(
            output.contains("\"context\":\"PaymentService\""),
            "{output}"
        );
        assert!(output.contains("\"message\":\"Charging user\""), "{output}");
        assert!(output.contains("\"user_id\":42"), "{output}");
        assert!(output.contains("\"tenant\":\"acme\""), "{output}");
    }

    #[test]
    fn for_context_and_framework_preserve_base_fields() {
        let logger = Logger::framework("startup").with_field("tenant", "acme");
        let relabeled = logger.for_context("payments");
        assert_eq!(logger.context(), "startup");
        assert_eq!(relabeled.context(), "payments");
        let output = capture_json(|| {
            relabeled.info("Boot").emit();
        });
        assert!(output.contains("\"context\":\"payments\""), "{output}");
        assert!(output.contains("\"tenant\":\"acme\""), "{output}");
    }

    #[test]
    fn with_field_ignores_unserializable_values() {
        let logger = Logger::new("logger").with_field("tenant", BadSerialize);
        let output = capture_json(|| {
            logger.info("Hello").emit();
        });
        assert!(!output.contains("tenant"), "{output}");
    }

    #[test]
    fn trace_debug_warn_error_span_and_instrument_cover_helper_methods() {
        let output = capture_json(|| {
            Logger::framework("boot").trace("trace").emit();
            Logger::new("worker").debug("debug").emit();
            Logger::new("worker").warn("warn").emit();
            Logger::new("worker").error("error").emit();
            let _ = Logger::new("worker").span("span").enter();
            let _ = Logger::new("worker").instrument("task");
        });
        assert!(output.contains("\"message\":\"trace\""), "{output}");
        assert!(output.contains("\"message\":\"debug\""), "{output}");
        assert!(output.contains("\"message\":\"warn\""), "{output}");
        assert!(output.contains("\"message\":\"error\""), "{output}");
    }

    #[tokio::test]
    async fn from_request_uses_request_metadata_and_default_context() {
        let head = http::Request::builder()
            .method("GET")
            .uri("/logs")
            .header("x-request-id", "req-123")
            .body(())
            .unwrap()
            .into_parts()
            .0;
        let ctx = RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(StateMap::new()),
            box_body(Full::new(Bytes::new())),
        );
        let logger = Logger::from_request(&ctx).await.unwrap();
        assert_eq!(logger.context(), "app");
        let output = capture_json(|| {
            logger.info("Hello").emit();
        });
        assert!(output.contains("\"request_id\":\"req-123\""), "{output}");
        assert!(output.contains("\"method\":\"GET\""), "{output}");
        assert!(output.contains("\"path\":\"/logs\""), "{output}");
    }
}