Skip to main content

ralph_core/diagnostics/trace_layer.rs

1use serde::{Deserialize, Serialize};
2use std::fs::{File, OpenOptions};
3use std::io::{BufWriter, Write};
4use std::path::Path;
5use std::sync::{Arc, Mutex};
6use tracing::Subscriber;
7use tracing_subscriber::Layer;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TraceEntry {
11    pub timestamp: String,
12    pub iteration: Option<u32>,
13    pub hat: Option<String>,
14    pub level: String,
15    pub target: String,
16    pub message: String,
17    pub fields: serde_json::Value,
18}
19
/// A `tracing_subscriber` layer that appends every event as one JSON object
/// per line to `<session_dir>/trace.jsonl`.
///
/// Cloning is cheap and clones share both the output file and the
/// iteration/hat context. Keep a clone before installing the layer in a
/// subscriber so that [`DiagnosticTraceLayer::set_context`] can still be
/// called once the original has been moved into the subscriber.
#[derive(Clone, Debug)]
pub struct DiagnosticTraceLayer {
    writer: Arc<Mutex<BufWriter<File>>>,
    context: Arc<Mutex<TraceContext>>,
}

/// Orchestration context stamped onto every [`TraceEntry`].
#[derive(Debug, Default)]
struct TraceContext {
    iteration: Option<u32>,
    hat: Option<String>,
}

impl DiagnosticTraceLayer {
    /// Opens (creating if absent) `session_dir/trace.jsonl` in append mode.
    ///
    /// The context starts empty, so entries carry `iteration`/`hat` of `None`
    /// until [`set_context`](Self::set_context) is called.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening the file, e.g. when `session_dir`
    /// does not exist (the directory itself is not created here).
    pub fn new(session_dir: &Path) -> std::io::Result<Self> {
        let trace_file = session_dir.join("trace.jsonl");
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(trace_file)?;

        Ok(Self {
            writer: Arc::new(Mutex::new(BufWriter::new(file))),
            context: Arc::new(Mutex::new(TraceContext::default())),
        })
    }

    /// Sets the iteration and hat recorded on all subsequently written entries.
    ///
    /// Because clones share the context, this affects every clone of the layer.
    pub fn set_context(&self, iteration: u32, hat: &str) {
        // A poisoned lock only means some other thread panicked while holding
        // it. The context is plain data that is about to be overwritten, so
        // recover the guard instead of panicking on every later call.
        let mut ctx = self
            .context
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        ctx.iteration = Some(iteration);
        ctx.hat = Some(hat.to_string());
    }
}
51
52impl<S: Subscriber> Layer<S> for DiagnosticTraceLayer {
53    fn on_event(
54        &self,
55        event: &tracing::Event<'_>,
56        _ctx: tracing_subscriber::layer::Context<'_, S>,
57    ) {
58        let metadata = event.metadata();
59
60        // Extract message and fields
61        let mut visitor = FieldVisitor::default();
62        event.record(&mut visitor);
63
64        // Get context
65        let ctx = self.context.lock().unwrap();
66
67        let entry = TraceEntry {
68            timestamp: chrono::Local::now().to_rfc3339(),
69            iteration: ctx.iteration,
70            hat: ctx.hat.clone(),
71            level: metadata.level().to_string(),
72            target: metadata.target().to_string(),
73            message: visitor.message,
74            fields: serde_json::to_value(&visitor.fields).unwrap_or(serde_json::Value::Null),
75        };
76
77        // Write to file
78        let mut writer = self.writer.lock().unwrap();
79        if let Ok(json) = serde_json::to_string(&entry) {
80            let _ = writeln!(writer, "{}", json);
81            let _ = writer.flush();
82        }
83    }
84}
85
86#[derive(Default)]
87struct FieldVisitor {
88    message: String,
89    fields: std::collections::HashMap<String, serde_json::Value>,
90}
91
92impl tracing::field::Visit for FieldVisitor {
93    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
94        if field.name() == "message" {
95            self.message = format!("{:?}", value).trim_matches('"').to_string();
96        } else {
97            self.fields.insert(
98                field.name().to_string(),
99                serde_json::Value::String(format!("{:?}", value)),
100            );
101        }
102    }
103
104    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
105        self.fields.insert(
106            field.name().to_string(),
107            serde_json::Value::Number(value.into()),
108        );
109    }
110
111    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
112        self.fields.insert(
113            field.name().to_string(),
114            serde_json::Value::Number(value.into()),
115        );
116    }
117
118    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
119        if field.name() == "message" {
120            self.message = value.to_string();
121        } else {
122            self.fields.insert(
123                field.name().to_string(),
124                serde_json::Value::String(value.to_string()),
125            );
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::BufRead;
    use tempfile::TempDir;
    use tracing::{debug, error, info, warn};
    use tracing_subscriber::prelude::*;

    /// Runs `emit` with `layer` installed as the default subscriber.
    fn capture(layer: DiagnosticTraceLayer, emit: impl FnOnce()) {
        let subscriber = tracing_subscriber::registry().with(layer);
        tracing::subscriber::with_default(subscriber, emit);
    }

    /// Parses every line of `<dir>/trace.jsonl` into a [`TraceEntry`].
    fn read_entries(dir: &Path) -> Vec<TraceEntry> {
        let file = File::open(dir.join("trace.jsonl")).unwrap();
        std::io::BufReader::new(file)
            .lines()
            .map(|line| serde_json::from_str(&line.unwrap()).unwrap())
            .collect()
    }

    #[test]
    fn test_layer_captures_all_levels() {
        let temp_dir = TempDir::new().unwrap();
        let layer = DiagnosticTraceLayer::new(temp_dir.path()).unwrap();

        capture(layer, || {
            info!("info message");
            debug!("debug message");
            warn!("warn message");
            error!("error message");
        });

        let entries = read_entries(temp_dir.path());
        let observed: Vec<(&str, &str)> = entries
            .iter()
            .map(|entry| (entry.level.as_str(), entry.message.as_str()))
            .collect();
        assert_eq!(
            observed,
            [
                ("INFO", "info message"),
                ("DEBUG", "debug message"),
                ("WARN", "warn message"),
                ("ERROR", "error message"),
            ]
        );
    }

    #[test]
    fn test_fields_serialized_correctly() {
        let temp_dir = TempDir::new().unwrap();
        let layer = DiagnosticTraceLayer::new(temp_dir.path()).unwrap();

        capture(layer, || {
            info!(bytes = 1024, status = "ok", "message with fields");
        });

        let entries = read_entries(temp_dir.path());
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.message, "message with fields");
        assert_eq!(entry.fields["bytes"], 1024);
        assert_eq!(entry.fields["status"], "ok");
    }

    #[test]
    fn test_context_included() {
        let temp_dir = TempDir::new().unwrap();
        let layer = DiagnosticTraceLayer::new(temp_dir.path()).unwrap();
        layer.set_context(5, "builder");

        capture(layer, || {
            info!("test message");
        });

        let entries = read_entries(temp_dir.path());
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.iteration, Some(5));
        assert_eq!(entry.hat, Some("builder".to_string()));
    }
}