Skip to main content

cognis_trace/
handler.rs

1//! `TracingHandler` — `CallbackHandler` impl that translates lifecycle
2//! events into `Span`s and fans them out to per-exporter batchers.
3
4use std::sync::Arc;
5
6use cognis_core::callbacks::CallbackHandler;
7use dashmap::DashMap;
8use uuid::Uuid;
9
10use crate::batch::{Batcher, BatcherConfig};
11use crate::cost::PriceTable;
12use crate::exporter::TraceExporter;
13use crate::span::{ScoreRecord, Span, SpanBuilder, SpanKind};
14
/// Bridges `CallbackHandler` events to one or more `TraceExporter`s.
///
/// Lifecycle callbacks open spans in `inflight`; finished spans are cloned
/// into every exporter's span batcher, and scores into every score batcher.
pub struct TracingHandler {
    /// Registered exporters, in the order they were added to the builder.
    exporters: Vec<Arc<dyn TraceExporter>>,
    /// Spans that have been opened but not yet finished, keyed by run id.
    inflight: DashMap<Uuid, SpanBuilder>,
    /// One span batcher per exporter, index-aligned with `exporters`.
    span_batchers: Vec<Batcher<Span>>,
    /// One score batcher per exporter, index-aligned with `exporters`.
    score_batchers: Vec<Batcher<ScoreRecord>>,
    /// Price table used to compute generation costs.
    pricing: Arc<PriceTable>,
}
23
24impl TracingHandler {
25    /// Start a new builder.
26    pub fn builder() -> TracingHandlerBuilder {
27        TracingHandlerBuilder::default()
28    }
29
30    /// Submit an out-of-band evaluation score for an existing run_id.
31    pub fn record_score(&self, score: ScoreRecord) {
32        for b in &self.score_batchers {
33            b.send(score.clone());
34        }
35    }
36
37    /// Stats per exporter (sent, dropped, failed).
38    pub fn stats(&self, exporter_name: &str) -> Option<(usize, usize, usize)> {
39        for (i, e) in self.exporters.iter().enumerate() {
40            if e.name() == exporter_name {
41                return self.span_batchers.get(i).map(|b| b.stats().snapshot());
42            }
43        }
44        None
45    }
46
47    /// Graceful shutdown: drain batchers, then call each exporter's
48    /// `shutdown()`. Must be awaited.
49    pub async fn shutdown(self) {
50        let Self {
51            exporters,
52            span_batchers,
53            score_batchers,
54            ..
55        } = self;
56        for b in span_batchers {
57            b.shutdown().await;
58        }
59        for b in score_batchers {
60            b.shutdown().await;
61        }
62        for e in exporters {
63            if let Err(err) = e.shutdown().await {
64                tracing::warn!(exporter = e.name(), error = %err, "exporter shutdown failed");
65            }
66        }
67    }
68}
69
70/// Builder for `TracingHandler`.
71#[derive(Default)]
72pub struct TracingHandlerBuilder {
73    exporters: Vec<Arc<dyn TraceExporter>>,
74    pricing: Option<PriceTable>,
75    batcher_cfg: BatcherConfig,
76}
77
78impl TracingHandlerBuilder {
79    /// Append an exporter.
80    pub fn with_exporter<E: TraceExporter + 'static>(mut self, e: E) -> Self {
81        self.exporters.push(Arc::new(e));
82        self
83    }
84
85    /// Use the dated default pricing snapshot.
86    pub fn with_default_pricing(mut self) -> Self {
87        self.pricing = Some(PriceTable::with_defaults());
88        self
89    }
90
91    /// Provide a fully custom price table.
92    pub fn with_pricing(mut self, p: PriceTable) -> Self {
93        self.pricing = Some(p);
94        self
95    }
96
97    /// Override or insert one model's price.
98    pub fn override_price(mut self, model: impl Into<String>, p: crate::cost::ModelPrice) -> Self {
99        let mut t = self.pricing.unwrap_or_default();
100        t.insert(model, p);
101        self.pricing = Some(t);
102        self
103    }
104
105    /// Override the per-exporter batcher config.
106    pub fn with_batcher_config(mut self, cfg: BatcherConfig) -> Self {
107        self.batcher_cfg = cfg;
108        self
109    }
110
111    /// Finalize. Spawns one `Batcher<Span>` and one `Batcher<ScoreRecord>`
112    /// per exporter.
113    pub fn build(self) -> TracingHandler {
114        let cfg = self.batcher_cfg;
115        let pricing = Arc::new(self.pricing.unwrap_or_default());
116
117        let mut span_batchers = Vec::with_capacity(self.exporters.len());
118        let mut score_batchers = Vec::with_capacity(self.exporters.len());
119        for e in &self.exporters {
120            let e_for_spans = e.clone();
121            span_batchers.push(Batcher::spawn(cfg, move |batch: Vec<Span>| {
122                let e = e_for_spans.clone();
123                async move { e.export_spans(batch).await }
124            }));
125            let e_for_scores = e.clone();
126            score_batchers.push(Batcher::spawn(cfg, move |batch: Vec<ScoreRecord>| {
127                let e = e_for_scores.clone();
128                async move { e.export_scores(batch).await }
129            }));
130        }
131
132        TracingHandler {
133            exporters: self.exporters,
134            inflight: DashMap::new(),
135            span_batchers,
136            score_batchers,
137            pricing,
138        }
139    }
140}
141
142impl TracingHandler {
143    fn start_span(
144        &self,
145        kind: SpanKind,
146        name: &str,
147        input: Option<serde_json::Value>,
148        run_id: Uuid,
149    ) {
150        let parent = crate::parent::peek();
151        let trace_id = parent.unwrap_or(run_id);
152        let b = SpanBuilder::open(
153            run_id,
154            parent,
155            trace_id,
156            kind,
157            name.to_string(),
158            input,
159            std::time::SystemTime::now(),
160        );
161        self.inflight.insert(run_id, b);
162        crate::parent::push(run_id);
163    }
164
165    fn finish_ok(&self, run_id: Uuid, output: Option<serde_json::Value>) {
166        if let Some((_, b)) = self.inflight.remove(&run_id) {
167            let span = b.finish_ok(output, std::time::SystemTime::now());
168            self.dispatch(span);
169        }
170        crate::parent::pop(run_id);
171    }
172
173    fn finish_error(&self, run_id: Uuid, message: &str) {
174        if let Some((_, b)) = self.inflight.remove(&run_id) {
175            let span = b.finish_error(message, std::time::SystemTime::now());
176            self.dispatch(span);
177        }
178        crate::parent::pop(run_id);
179    }
180
181    fn dispatch(&self, span: Span) {
182        for b in &self.span_batchers {
183            b.send(span.clone());
184        }
185    }
186
187    /// Parse a provider's `on_llm_end` payload into a `Generation` per the
188    /// schema documented in spec §4.3. Missing fields default cleanly.
189    fn parse_generation(
190        &self,
191        model_hint: &str,
192        payload: &serde_json::Value,
193    ) -> crate::span::Generation {
194        use crate::span::{Generation, TokenUsage};
195
196        let obj = payload.as_object();
197        let model = obj
198            .and_then(|o| o.get("model"))
199            .and_then(|v| v.as_str())
200            .unwrap_or(model_hint)
201            .to_string();
202        let provider = obj
203            .and_then(|o| o.get("provider"))
204            .and_then(|v| v.as_str())
205            .unwrap_or("")
206            .to_string();
207        let finish_reason = obj
208            .and_then(|o| o.get("finish_reason"))
209            .and_then(|v| v.as_str())
210            .map(String::from);
211        let model_parameters = obj
212            .and_then(|o| o.get("model_parameters"))
213            .and_then(|v| v.as_object())
214            .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
215            .unwrap_or_default();
216        let usage = obj
217            .and_then(|o| o.get("usage"))
218            .and_then(|v| v.as_object())
219            .map(|u| TokenUsage {
220                input: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
221                output: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
222                cache_read: u
223                    .get("cache_read_tokens")
224                    .and_then(|v| v.as_u64())
225                    .unwrap_or(0) as u32,
226                cache_write: u
227                    .get("cache_creation_tokens")
228                    .and_then(|v| v.as_u64())
229                    .unwrap_or(0) as u32,
230            })
231            .unwrap_or_default();
232        let prompt_name = obj
233            .and_then(|o| o.get("prompt_name"))
234            .and_then(|v| v.as_str())
235            .map(String::from);
236        let prompt_version = obj
237            .and_then(|o| o.get("prompt_version"))
238            .and_then(|v| v.as_u64())
239            .map(|n| n as u32);
240        let cost = self.pricing.compute(&model, usage);
241        Generation {
242            model,
243            provider,
244            model_parameters,
245            usage,
246            cost,
247            completion_start_time: None,
248            finish_reason,
249            prompt_name,
250            prompt_version,
251        }
252    }
253}
254
255impl CallbackHandler for TracingHandler {
256    fn name(&self) -> &str {
257        "cognis_trace::TracingHandler"
258    }
259
260    fn on_chain_start(&self, runnable: &str, input: &serde_json::Value, run_id: Uuid) {
261        self.start_span(SpanKind::Chain, runnable, Some(input.clone()), run_id);
262    }
263
264    fn on_chain_end(&self, _runnable: &str, output: &serde_json::Value, run_id: Uuid) {
265        self.finish_ok(run_id, Some(output.clone()));
266    }
267
268    fn on_chain_error(&self, _runnable: &str, error: &str, run_id: Uuid) {
269        self.finish_error(run_id, error);
270    }
271
272    fn on_tool_start(&self, tool: &str, args: &serde_json::Value, run_id: Uuid) {
273        self.start_span(SpanKind::Tool, tool, Some(args.clone()), run_id);
274    }
275
276    fn on_tool_end(&self, _tool: &str, result: &serde_json::Value, run_id: Uuid) {
277        self.finish_ok(run_id, Some(result.clone()));
278    }
279
280    fn on_tool_error(&self, _tool: &str, error: &str, run_id: Uuid) {
281        self.finish_error(run_id, error);
282    }
283
284    fn on_node_start(&self, node: &str, _step: u64, run_id: Uuid) {
285        self.start_span(SpanKind::Span, node, None, run_id);
286    }
287
288    fn on_node_end(&self, _node: &str, _step: u64, output: &serde_json::Value, run_id: Uuid) {
289        self.finish_ok(run_id, Some(output.clone()));
290    }
291
292    fn on_checkpoint(&self, _step: u64, _run_id: Uuid) {
293        // No-op for trace tree shape; checkpoints are visible via the
294        // graph engine's own observers.
295    }
296
297    fn on_custom(&self, kind: &str, payload: &serde_json::Value, run_id: Uuid) {
298        // Emit a discrete EVENT span: open and immediately close.
299        let trace_id = crate::parent::peek().unwrap_or(run_id);
300        let now = std::time::SystemTime::now();
301        let mut b = SpanBuilder::open(
302            run_id,
303            crate::parent::peek(),
304            trace_id,
305            SpanKind::Event,
306            kind,
307            Some(payload.clone()),
308            now,
309        );
310        b.span
311            .metadata
312            .insert("kind".into(), serde_json::Value::String(kind.into()));
313        let span = b.finish_ok(None, now);
314        self.dispatch(span);
315    }
316
317    fn on_llm_start(&self, model: &str, prompt: &serde_json::Value, run_id: Uuid) {
318        self.start_span(SpanKind::Generation, model, Some(prompt.clone()), run_id);
319    }
320
321    fn on_llm_token(&self, _token: &str, _run_id: Uuid) {
322        // Streaming tokens are not buffered into the trace by default.
323    }
324
325    fn on_llm_end(&self, model: &str, output: &serde_json::Value, run_id: Uuid) {
326        let generation = self.parse_generation(model, output);
327        if let Some((_, b)) = self.inflight.remove(&run_id) {
328            let b = b.with_generation(generation);
329            // Use the payload's `content` field as `output` if present, else the whole payload.
330            let out = output
331                .as_object()
332                .and_then(|o| o.get("content").cloned())
333                .or_else(|| Some(output.clone()));
334            let span = b.finish_ok(out, std::time::SystemTime::now());
335            self.dispatch(span);
336        }
337        crate::parent::pop(run_id);
338    }
339
340    fn on_llm_error(&self, _model: &str, error: &str, run_id: Uuid) {
341        self.finish_error(run_id, error);
342    }
343}