use std::sync::{Arc, Mutex};
use serde_json::{Map, Value};
use tracing::field::{Field, Visit};
use tracing::Event;
use tracing_subscriber::layer::{Context, Layer};
/// One tracing event as captured by [`LogRecorder`].
///
/// Derives `PartialEq` so tests can compare whole records instead of field by field.
#[derive(Clone, Debug, PartialEq)]
pub struct LogRecord {
    /// The event's level rendered with `Display` (for example `"INFO"`).
    pub level: String,
    /// Value of the event's `tork.context` field, or the event's target when that field is absent.
    pub context: String,
    /// Text of the event's `message` field; empty when the event carried no message.
    pub message: String,
    /// Entries parsed from a `tork.fields` JSON object; empty when the field is absent or is not
    /// a JSON object.
    pub fields: Map<String, Value>,
}
/// A `tracing_subscriber` layer that stores every event it receives as a [`LogRecord`].
///
/// All clones share one buffer (the records live behind an `Arc`). You can install one clone as
/// a layer and keep another to query what was captured.
#[derive(Clone, Debug, Default)]
pub struct LogRecorder {
    /// Captured records, in the order the layer received them.
    records: Arc<Mutex<Vec<LogRecord>>>,
}
impl LogRecorder {
    /// Creates a recorder with an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a snapshot copy of every record captured so far, in capture order.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn records(&self) -> Vec<LogRecord> {
        self.lock().clone()
    }

    /// Discards every captured record.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn clear(&self) {
        self.lock().clear();
    }

    /// Returns `true` if some captured record's `context` equals `context` exactly.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn contains_context(&self, context: &str) -> bool {
        self.any_record(|record| record.context == context)
    }

    /// Returns `true` if some captured record's `message` contains `text` as a substring.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn contains_message(&self, text: &str) -> bool {
        self.any_record(|record| record.message.contains(text))
    }

    /// Scans the captured records in place, without cloning them, and reports whether any
    /// record matches.
    ///
    /// The predicate runs while the buffer lock is held. It must not emit events through a
    /// subscriber that contains this recorder, because that would try to re-lock the same mutex.
    fn any_record(&self, predicate: impl FnMut(&LogRecord) -> bool) -> bool {
        self.lock().iter().any(predicate)
    }

    /// Locks the shared buffer and panics with a uniform message if the mutex is poisoned.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<LogRecord>> {
        self.records.lock().expect("recorder mutex poisoned")
    }
}
impl<S: tracing::Subscriber> Layer<S> for LogRecorder {
fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
let mut visitor = RecordVisitor::default();
event.record(&mut visitor);
let record = LogRecord {
level: event.metadata().level().to_string(),
context: visitor
.context
.unwrap_or_else(|| event.metadata().target().to_owned()),
message: visitor.message.unwrap_or_default(),
fields: visitor.fields,
};
self.records
.lock()
.expect("recorder mutex poisoned")
.push(record);
}
}
/// Field visitor that collects the pieces of an event that [`LogRecorder`] stores.
///
/// Every slot starts empty and is filled by `RecordVisitor::set` as fields are visited.
#[derive(Default)]
struct RecordVisitor {
    /// Entries of the most recent `tork.fields` payload that parsed as a JSON object.
    fields: Map<String, Value>,
    /// Raw `tork.context` value, if the event carried one.
    context: Option<String>,
    /// Rendered `message` field, if the event carried one.
    message: Option<String>,
}
impl RecordVisitor {
    /// Stores the rendered value of the field called `name` in the matching slot.
    ///
    /// - `message` and `tork.context` are kept verbatim.
    /// - `tork.fields` is parsed as JSON and replaces `fields` only when it is a JSON object.
    ///   Invalid or non-object payloads leave `fields` untouched.
    /// - Any other field name is discarded.
    fn set(&mut self, name: &str, value: String) {
        if name == "tork.fields" {
            if let Ok(Value::Object(object)) = serde_json::from_str::<Value>(&value) {
                self.fields = object;
            }
            return;
        }

        let slot = match name {
            "message" => &mut self.message,
            "tork.context" => &mut self.context,
            _ => return,
        };
        *slot = Some(value);
    }
}
impl Visit for RecordVisitor {
    /// String-typed field values are passed through unchanged.
    fn record_str(&mut self, field: &Field, value: &str) {
        self.set(field.name(), String::from(value));
    }

    /// Any value routed here is rendered with its `Debug` implementation before being stored.
    fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
        let rendered = format!("{:?}", value);
        self.set(field.name(), rendered);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tracing_subscriber::layer::SubscriberExt;

    /// Runs `emit` with a fresh recorder installed as the thread-local default subscriber and
    /// returns that recorder for inspection.
    fn capture(emit: impl FnOnce()) -> LogRecorder {
        let recorder = LogRecorder::new();
        let subscriber = tracing_subscriber::registry().with(recorder.clone());
        tracing::subscriber::with_default(subscriber, emit);
        recorder
    }

    #[test]
    fn recorder_helpers_find_context_and_message() {
        let recorder = capture(|| {
            tracing::info!(
                tork.context = "Orders",
                tork.fields = "{\"id\":1}",
                "created order"
            );
        });
        assert!(recorder.contains_context("Orders"));
        assert!(recorder.contains_message("created order"));
        assert_eq!(recorder.records()[0].fields["id"], Value::from(1));
    }

    #[test]
    fn context_falls_back_to_event_target() {
        let recorder = capture(|| {
            tracing::warn!(target: "billing", "no explicit context");
        });
        let records = recorder.records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, "WARN");
        assert_eq!(records[0].context, "billing");
        assert_eq!(records[0].message, "no explicit context");
        assert!(records[0].fields.is_empty());
    }

    #[test]
    fn clear_discards_captured_records() {
        let recorder = capture(|| tracing::info!("first"));
        assert!(recorder.contains_message("first"));
        recorder.clear();
        assert!(recorder.records().is_empty());
        assert!(!recorder.contains_message("first"));
    }

    #[test]
    fn visitor_ignores_invalid_json_fields_payload() {
        let mut visitor = RecordVisitor::default();
        visitor.set("tork.fields", "not-json".to_owned());
        visitor.set("message", "hello".to_owned());
        assert_eq!(visitor.message.as_deref(), Some("hello"));
        assert!(visitor.fields.is_empty());
    }

    #[test]
    fn visitor_ignores_non_object_json_fields_payload() {
        let mut visitor = RecordVisitor::default();
        visitor.set("tork.fields", "[1, 2]".to_owned());
        assert!(visitor.fields.is_empty());
    }
}
/// Asserts that `$recorder` captured at least one record whose `context` equals `$context`
/// and whose `message` contains `$message`.
///
/// `$recorder` must provide a `records()` method (such as `LogRecorder::records`) that returns
/// records with `context` and `message` fields. On failure, the panic message lists every
/// captured record.
///
/// `$context` and `$message` are each evaluated exactly once and are only borrowed. Caller
/// variables stay usable afterwards, and side effects in the argument expressions happen once.
#[macro_export]
macro_rules! assert_logs {
    ($recorder:expr, context = $context:expr, message = $message:expr $(,)?) => {{
        let records = $recorder.records();
        // Bind the expected values once, as `assert_eq!` does. Matching on the borrowed tuple
        // keeps temporaries (e.g. `format!(..).as_str()`) alive for the whole check.
        match (&$context, &$message) {
            (expected_context, expected_message) => {
                assert!(
                    records.iter().any(|record| {
                        record.context == *expected_context
                            && record.message.contains(*expected_message)
                    }),
                    "no log with context {:?} and message containing {:?}; captured: {:?}",
                    expected_context,
                    expected_message,
                    records,
                );
            }
        }
    }};
}