Skip to main content

aster/tracing/
observation_layer.rs

1use chrono::Utc;
2use serde_json::{json, Value};
3use std::collections::HashMap;
4use std::fmt;
5use std::sync::Arc;
6use tokio::sync::Mutex;
7use tracing::field::{Field, Visit};
8use tracing::{span, Event, Id, Level, Metadata, Subscriber};
9use tracing_subscriber::layer::Context;
10use tracing_subscriber::registry::LookupSpan;
11use tracing_subscriber::Layer;
12use uuid::Uuid;
13
14#[derive(Debug, Clone)]
15pub struct SpanData {
16    pub observation_id: String, // Langfuse requires ids to be UUID v4 strings
17    pub name: String,
18    pub start_time: String,
19    pub level: String,
20    pub metadata: serde_json::Map<String, Value>,
21    pub parent_span_id: Option<u64>,
22}
23
24pub fn map_level(level: &Level) -> &'static str {
25    match *level {
26        Level::ERROR => "ERROR",
27        Level::WARN => "WARNING",
28        Level::INFO => "DEFAULT",
29        Level::DEBUG => "DEBUG",
30        Level::TRACE => "DEBUG",
31    }
32}
33
34pub fn flatten_metadata(
35    metadata: serde_json::Map<String, Value>,
36) -> serde_json::Map<String, Value> {
37    let mut flattened = serde_json::Map::new();
38    for (key, value) in metadata {
39        match value {
40            Value::String(s) => {
41                flattened.insert(key, json!(s));
42            }
43            Value::Object(mut obj) => {
44                if let Some(text) = obj.remove("text") {
45                    flattened.insert(key, text);
46                } else {
47                    flattened.insert(key, json!(obj));
48                }
49            }
50            _ => {
51                flattened.insert(key, value);
52            }
53        }
54    }
55    flattened
56}
57
58pub trait BatchManager: Send + Sync + 'static {
59    fn add_event(&mut self, event_type: &str, body: Value);
60    fn send(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
61    fn is_empty(&self) -> bool;
62}
63
/// Bookkeeping that links live `tracing` spans to Langfuse observations.
#[derive(Debug, Default)]
pub struct SpanTracker {
    /// `tracing` span id -> Langfuse observation id.
    ///
    /// `tracing` identifies spans with `u64` ids, whereas Langfuse requires
    /// UUID v4 strings, so the two must be mapped explicitly.
    active_spans: HashMap<u64, String>,
    /// Trace id shared by all observations; `None` until one is created.
    current_trace_id: Option<String>,
}

impl SpanTracker {
    /// Creates an empty tracker (no spans, no trace id); identical to `Default`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Associates `span_id` with `observation_id`, replacing any earlier entry.
    pub fn add_span(&mut self, span_id: u64, observation_id: String) {
        // The displaced observation id (if any) is intentionally discarded.
        let _replaced = self.active_spans.insert(span_id, observation_id);
    }

    /// Looks up the observation id registered for `span_id`, if any.
    pub fn get_span(&self, span_id: u64) -> Option<&String> {
        self.active_spans.get(&span_id)
    }

    /// Unregisters `span_id`, handing back the observation id it mapped to.
    pub fn remove_span(&mut self, span_id: u64) -> Option<String> {
        self.active_spans.remove(&span_id)
    }
}
96
97#[derive(Clone)]
98pub struct ObservationLayer {
99    pub batch_manager: Arc<Mutex<dyn BatchManager>>,
100    pub span_tracker: Arc<Mutex<SpanTracker>>,
101}
102
103impl ObservationLayer {
104    pub async fn handle_span(&self, span_id: u64, span_data: SpanData) {
105        let observation_id = span_data.observation_id.clone();
106
107        {
108            let mut spans = self.span_tracker.lock().await;
109            spans.add_span(span_id, observation_id.clone());
110        }
111
112        // Get parent ID if it exists
113        let parent_id = if let Some(parent_span_id) = span_data.parent_span_id {
114            let spans = self.span_tracker.lock().await;
115            spans.get_span(parent_span_id).cloned()
116        } else {
117            None
118        };
119
120        let trace_id = self.ensure_trace_id().await;
121
122        // Create the span observation
123        let mut batch = self.batch_manager.lock().await;
124        batch.add_event(
125            "observation-create",
126            json!({
127                "id": observation_id,
128                "traceId": trace_id,
129                "type": "SPAN",
130                "name": span_data.name,
131                "startTime": span_data.start_time,
132                "parentObservationId": parent_id,
133                "metadata": span_data.metadata,
134                "level": span_data.level
135            }),
136        );
137    }
138
139    pub async fn handle_span_close(&self, span_id: u64) {
140        let observation_id = {
141            let mut spans = self.span_tracker.lock().await;
142            spans.remove_span(span_id)
143        };
144
145        if let Some(observation_id) = observation_id {
146            let trace_id = self.ensure_trace_id().await;
147            let mut batch = self.batch_manager.lock().await;
148            batch.add_event(
149                "observation-update",
150                json!({
151                    "id": observation_id,
152                    "type": "SPAN",
153                    "traceId": trace_id,
154                    "endTime": Utc::now().to_rfc3339()
155                }),
156            );
157        }
158    }
159
160    pub async fn ensure_trace_id(&self) -> String {
161        let mut spans = self.span_tracker.lock().await;
162        if let Some(id) = spans.current_trace_id.clone() {
163            return id;
164        }
165
166        let trace_id = Uuid::new_v4().to_string();
167        spans.current_trace_id = Some(trace_id.clone());
168
169        let mut batch = self.batch_manager.lock().await;
170        batch.add_event(
171            "trace-create",
172            json!({
173                "id": trace_id,
174                "name": Utc::now().timestamp().to_string(),
175                "timestamp": Utc::now().to_rfc3339(),
176                "input": {},
177                "metadata": {},
178                "tags": [],
179                "public": false
180            }),
181        );
182
183        trace_id
184    }
185
186    pub async fn handle_record(&self, span_id: u64, metadata: serde_json::Map<String, Value>) {
187        let observation_id = {
188            let spans = self.span_tracker.lock().await;
189            spans.get_span(span_id).cloned()
190        };
191
192        if let Some(observation_id) = observation_id {
193            let trace_id = self.ensure_trace_id().await;
194
195            let mut update = json!({
196                "id": observation_id,
197                "traceId": trace_id,
198                "type": "SPAN"
199            });
200
201            // Handle special fields
202            if let Some(val) = metadata.get("input") {
203                update["input"] = val.clone();
204            }
205
206            if let Some(val) = metadata.get("output") {
207                update["output"] = val.clone();
208            }
209
210            if let Some(val) = metadata.get("model_config") {
211                update["metadata"] = json!({ "model_config": val });
212            }
213
214            // Handle any remaining metadata
215            let remaining_metadata: serde_json::Map<String, Value> = metadata
216                .iter()
217                .filter(|(k, _)| !["input", "output", "model_config"].contains(&k.as_str()))
218                .map(|(k, v)| (k.clone(), v.clone()))
219                .collect();
220
221            if !remaining_metadata.is_empty() {
222                let flattened = flatten_metadata(remaining_metadata);
223                if update.get("metadata").is_some() {
224                    // If metadata exists (from model_config), merge with it
225                    if let Some(obj) = update["metadata"].as_object_mut() {
226                        for (k, v) in flattened {
227                            obj.insert(k, v);
228                        }
229                    }
230                } else {
231                    // Otherwise set it directly
232                    update["metadata"] = json!(flattened);
233                }
234            }
235
236            let mut batch = self.batch_manager.lock().await;
237            batch.add_event("span-update", update);
238        }
239    }
240}
241
242impl<S> Layer<S> for ObservationLayer
243where
244    S: Subscriber + for<'a> LookupSpan<'a>,
245{
246    fn enabled(&self, metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool {
247        metadata.target().starts_with("aster::")
248    }
249
250    fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
251        let span_id = id.into_u64();
252
253        let parent_span_id = ctx
254            .span_scope(id)
255            .and_then(|mut scope| scope.nth(1))
256            .map(|parent| parent.id().into_u64());
257
258        let mut visitor = JsonVisitor::new();
259        attrs.record(&mut visitor);
260
261        let span_data = SpanData {
262            observation_id: Uuid::new_v4().to_string(),
263            name: attrs.metadata().name().to_string(),
264            start_time: Utc::now().to_rfc3339(),
265            level: map_level(attrs.metadata().level()).to_owned(),
266            metadata: visitor.recorded_fields,
267            parent_span_id,
268        };
269
270        let layer = self.clone();
271        tokio::spawn(async move { layer.handle_span(span_id, span_data).await });
272    }
273
274    fn on_close(&self, id: Id, _ctx: Context<'_, S>) {
275        let span_id = id.into_u64();
276        let layer = self.clone();
277        tokio::spawn(async move { layer.handle_span_close(span_id).await });
278    }
279
280    fn on_record(&self, span: &Id, values: &span::Record<'_>, _ctx: Context<'_, S>) {
281        let span_id = span.into_u64();
282        let mut visitor = JsonVisitor::new();
283        values.record(&mut visitor);
284        let metadata = visitor.recorded_fields;
285
286        if !metadata.is_empty() {
287            let layer = self.clone();
288            tokio::spawn(async move { layer.handle_record(span_id, metadata).await });
289        }
290    }
291
292    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
293        let mut visitor = JsonVisitor::new();
294        event.record(&mut visitor);
295        let metadata = visitor.recorded_fields;
296
297        if let Some(span_id) = ctx.lookup_current().map(|span| span.id().into_u64()) {
298            let layer = self.clone();
299            tokio::spawn(async move { layer.handle_record(span_id, metadata).await });
300        }
301    }
302}
303
304#[derive(Debug)]
305struct JsonVisitor {
306    recorded_fields: serde_json::Map<String, Value>,
307}
308
309impl JsonVisitor {
310    fn new() -> Self {
311        Self {
312            recorded_fields: serde_json::Map::new(),
313        }
314    }
315
316    fn insert_value(&mut self, field: &Field, value: Value) {
317        self.recorded_fields.insert(field.name().to_string(), value);
318    }
319}
320
// Generates a `Visit::record_*` method that stores the primitive value as its
// JSON equivalent (via the `From<$value_ty> for Value` conversion).
macro_rules! record_field {
    ($method:ident, $value_ty:ty) => {
        fn $method(&mut self, field: &Field, value: $value_ty) {
            let json_value: Value = value.into();
            self.insert_value(field, json_value);
        }
    };
}
328
329impl Visit for JsonVisitor {
330    record_field!(record_i64, i64);
331    record_field!(record_u64, u64);
332    record_field!(record_bool, bool);
333    record_field!(record_str, &str);
334
335    fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) {
336        self.insert_value(field, Value::String(format!("{:?}", value)));
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tracing::dispatcher;

    type Events = Arc<Mutex<Vec<(String, Value)>>>;

    /// Upper bound on how long a test waits for the mock to observe its events.
    const TEST_WAIT_DURATION: Duration = Duration::from_secs(6);
    /// Delay between polls of the captured events while waiting.
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    struct TestFixture {
        original_subscriber: Option<dispatcher::Dispatch>,
        events: Option<Events>,
    }

    impl TestFixture {
        fn new() -> Self {
            Self {
                original_subscriber: Some(dispatcher::get_default(dispatcher::Dispatch::clone)),
                events: None,
            }
        }

        fn with_test_layer(mut self) -> (Self, ObservationLayer) {
            let events = Arc::new(Mutex::new(Vec::new()));
            let mock_manager = MockBatchManager::new(events.clone());

            let layer = ObservationLayer {
                batch_manager: Arc::new(Mutex::new(mock_manager)),
                span_tracker: Arc::new(Mutex::new(SpanTracker::new())),
            };

            self.events = Some(events);
            (self, layer)
        }

        async fn get_events(&self) -> Vec<(String, Value)> {
            self.events
                .as_ref()
                .expect("Events not initialized")
                .lock()
                .await
                .clone()
        }

        /// Polls the captured events until at least `expected` have arrived or
        /// `TEST_WAIT_DURATION` elapses, then returns whatever was captured.
        ///
        /// The mock forwards events through a channel drained by a spawned task,
        /// so they arrive asynchronously. Polling lets each test finish as soon
        /// as the events are in, instead of always sleeping for the full timeout.
        async fn wait_for_events(&self, expected: usize) -> Vec<(String, Value)> {
            let deadline = tokio::time::Instant::now() + TEST_WAIT_DURATION;
            loop {
                let events = self.get_events().await;
                if events.len() >= expected || tokio::time::Instant::now() >= deadline {
                    return events;
                }
                tokio::time::sleep(POLL_INTERVAL).await;
            }
        }
    }

    impl Drop for TestFixture {
        fn drop(&mut self) {
            // NOTE(review): `set_global_default` can only succeed once per
            // process; after that these calls fail and the error is discarded.
            if let Some(subscriber) = &self.original_subscriber {
                let _ = dispatcher::set_global_default(subscriber.clone());
            }
        }
    }

    struct MockBatchManager {
        events: Arc<Mutex<Vec<(String, Value)>>>,
        sender: mpsc::UnboundedSender<(String, Value)>,
    }

    impl MockBatchManager {
        fn new(events: Arc<Mutex<Vec<(String, Value)>>>) -> Self {
            let (sender, mut receiver) = mpsc::unbounded_channel();
            let events_clone = events.clone();

            tokio::spawn(async move {
                while let Some((event_type, body)) = receiver.recv().await {
                    events_clone.lock().await.push((event_type, body));
                }
            });

            Self { events, sender }
        }
    }

    impl BatchManager for MockBatchManager {
        fn add_event(&mut self, event_type: &str, body: Value) {
            self.sender
                .send((event_type.to_string(), body))
                .expect("Failed to send event");
        }

        fn send(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        fn is_empty(&self) -> bool {
            futures::executor::block_on(async { self.events.lock().await.is_empty() })
        }
    }

    fn create_test_span_data() -> SpanData {
        SpanData {
            observation_id: Uuid::new_v4().to_string(),
            name: "test_span".to_string(),
            start_time: Utc::now().to_rfc3339(),
            level: "DEFAULT".to_string(),
            metadata: serde_json::Map::new(),
            parent_span_id: None,
        }
    }

    #[tokio::test]
    async fn test_span_creation() {
        let (fixture, layer) = TestFixture::new().with_test_layer();
        let span_id = 1u64;
        let span_data = create_test_span_data();

        layer.handle_span(span_id, span_data.clone()).await;

        let events = fixture.wait_for_events(2).await;
        assert_eq!(events.len(), 2); // trace-create and observation-create

        let (event_type, body) = &events[1];
        assert_eq!(event_type, "observation-create");
        assert_eq!(body["id"], span_data.observation_id);
        assert_eq!(body["name"], "test_span");
        assert_eq!(body["type"], "SPAN");
    }

    #[tokio::test]
    async fn test_span_close() {
        let (fixture, layer) = TestFixture::new().with_test_layer();
        let span_id = 1u64;
        let span_data = create_test_span_data();

        layer.handle_span(span_id, span_data.clone()).await;
        layer.handle_span_close(span_id).await;

        let events = fixture.wait_for_events(3).await;
        assert_eq!(events.len(), 3); // trace-create, observation-create, observation-update

        let (event_type, body) = &events[2];
        assert_eq!(event_type, "observation-update");
        assert_eq!(body["id"], span_data.observation_id);
        assert!(body["endTime"].as_str().is_some());
    }

    #[tokio::test]
    async fn test_record_handling() {
        let (fixture, layer) = TestFixture::new().with_test_layer();
        let span_id = 1u64;
        let span_data = create_test_span_data();

        layer.handle_span(span_id, span_data.clone()).await;

        let mut metadata = serde_json::Map::new();
        metadata.insert("input".to_string(), json!("test input"));
        metadata.insert("output".to_string(), json!("test output"));
        metadata.insert("custom_field".to_string(), json!("custom value"));

        layer.handle_record(span_id, metadata).await;

        let events = fixture.wait_for_events(3).await;
        assert_eq!(events.len(), 3); // trace-create, observation-create, span-update

        let (event_type, body) = &events[2];
        assert_eq!(event_type, "span-update");
        assert_eq!(body["input"], "test input");
        assert_eq!(body["output"], "test output");
        assert_eq!(body["metadata"]["custom_field"], "custom value");
    }

    #[test]
    fn test_flatten_metadata() {
        let _fixture = TestFixture::new();
        let mut metadata = serde_json::Map::new();
        metadata.insert("simple".to_string(), json!("value"));
        metadata.insert(
            "complex".to_string(),
            json!({
                "text": "inner value"
            }),
        );

        let flattened = flatten_metadata(metadata);
        assert_eq!(flattened["simple"], "value");
        assert_eq!(flattened["complex"], "inner value");
    }
}