Skip to main content

forge_runtime/observability/
tracing_layer.rs

1//! Custom tracing layer that forwards logs to the LogCollector.
2
3use std::sync::Arc;
4
5use forge_core::LogLevel;
6use forge_core::observability::LogEntry;
7use tracing::field::{Field, Visit};
8use tracing::{Event, Level, Subscriber};
9use tracing_subscriber::Layer;
10use tracing_subscriber::layer::Context;
11
12use super::LogCollector;
13
14/// A tracing layer that forwards log events to the LogCollector.
15pub struct ForgeTracingLayer {
16    collector: Arc<LogCollector>,
17}
18
19impl ForgeTracingLayer {
20    /// Create a new tracing layer.
21    pub fn new(collector: Arc<LogCollector>) -> Self {
22        Self { collector }
23    }
24}
25
26impl<S> Layer<S> for ForgeTracingLayer
27where
28    S: Subscriber,
29{
30    fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
31        let metadata = event.metadata();
32        let level = convert_level(*metadata.level());
33
34        // Skip if below minimum level
35        if level < self.collector.min_level() {
36            return;
37        }
38
39        // Extract message and fields
40        let mut visitor = FieldVisitor::default();
41        event.record(&mut visitor);
42
43        let message = visitor.message.unwrap_or_default();
44
45        let mut entry = LogEntry::new(level, message);
46        entry.target = Some(metadata.target().to_string());
47        entry.fields = visitor.fields;
48
49        // Record asynchronously using a spawned task
50        let collector = self.collector.clone();
51        tokio::spawn(async move {
52            collector.record(entry).await;
53        });
54    }
55}
56
57/// Convert tracing Level to LogLevel.
58fn convert_level(level: Level) -> LogLevel {
59    match level {
60        Level::TRACE => LogLevel::Trace,
61        Level::DEBUG => LogLevel::Debug,
62        Level::INFO => LogLevel::Info,
63        Level::WARN => LogLevel::Warn,
64        Level::ERROR => LogLevel::Error,
65    }
66}
67
68/// Visitor to extract fields from a tracing event.
69#[derive(Default)]
70struct FieldVisitor {
71    message: Option<String>,
72    fields: std::collections::HashMap<String, serde_json::Value>,
73}
74
75impl Visit for FieldVisitor {
76    fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
77        let name = field.name();
78        if name == "message" {
79            self.message = Some(format!("{:?}", value));
80        } else {
81            self.fields.insert(
82                name.to_string(),
83                serde_json::Value::String(format!("{:?}", value)),
84            );
85        }
86    }
87
88    fn record_str(&mut self, field: &Field, value: &str) {
89        let name = field.name();
90        if name == "message" {
91            self.message = Some(value.to_string());
92        } else {
93            self.fields.insert(
94                name.to_string(),
95                serde_json::Value::String(value.to_string()),
96            );
97        }
98    }
99
100    fn record_i64(&mut self, field: &Field, value: i64) {
101        self.fields.insert(
102            field.name().to_string(),
103            serde_json::Value::Number(value.into()),
104        );
105    }
106
107    fn record_u64(&mut self, field: &Field, value: u64) {
108        self.fields.insert(
109            field.name().to_string(),
110            serde_json::Value::Number(value.into()),
111        );
112    }
113
114    fn record_bool(&mut self, field: &Field, value: bool) {
115        self.fields
116            .insert(field.name().to_string(), serde_json::Value::Bool(value));
117    }
118
119    fn record_f64(&mut self, field: &Field, value: f64) {
120        if let Some(n) = serde_json::Number::from_f64(value) {
121            self.fields
122                .insert(field.name().to_string(), serde_json::Value::Number(n));
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convert_level() {
        // Each tracing level paired with the LogLevel it must map to.
        let cases = [
            (Level::TRACE, LogLevel::Trace),
            (Level::DEBUG, LogLevel::Debug),
            (Level::INFO, LogLevel::Info),
            (Level::WARN, LogLevel::Warn),
            (Level::ERROR, LogLevel::Error),
        ];
        for (input, expected) in cases {
            assert_eq!(convert_level(input), expected);
        }
    }
}