1use std::sync::Arc;
5
6use cognis_core::callbacks::CallbackHandler;
7use dashmap::DashMap;
8use uuid::Uuid;
9
10use crate::batch::{Batcher, BatcherConfig};
11use crate::cost::PriceTable;
12use crate::exporter::TraceExporter;
13use crate::span::{ScoreRecord, Span, SpanBuilder, SpanKind};
14
15pub struct TracingHandler {
17 exporters: Vec<Arc<dyn TraceExporter>>,
18 inflight: DashMap<Uuid, SpanBuilder>,
19 span_batchers: Vec<Batcher<Span>>,
20 score_batchers: Vec<Batcher<ScoreRecord>>,
21 pricing: Arc<PriceTable>,
22}
23
24impl TracingHandler {
25 pub fn builder() -> TracingHandlerBuilder {
27 TracingHandlerBuilder::default()
28 }
29
30 pub fn record_score(&self, score: ScoreRecord) {
32 for b in &self.score_batchers {
33 b.send(score.clone());
34 }
35 }
36
37 pub fn stats(&self, exporter_name: &str) -> Option<(usize, usize, usize)> {
39 for (i, e) in self.exporters.iter().enumerate() {
40 if e.name() == exporter_name {
41 return self.span_batchers.get(i).map(|b| b.stats().snapshot());
42 }
43 }
44 None
45 }
46
47 pub async fn shutdown(self) {
50 let Self {
51 exporters,
52 span_batchers,
53 score_batchers,
54 ..
55 } = self;
56 for b in span_batchers {
57 b.shutdown().await;
58 }
59 for b in score_batchers {
60 b.shutdown().await;
61 }
62 for e in exporters {
63 if let Err(err) = e.shutdown().await {
64 tracing::warn!(exporter = e.name(), error = %err, "exporter shutdown failed");
65 }
66 }
67 }
68}
69
70#[derive(Default)]
72pub struct TracingHandlerBuilder {
73 exporters: Vec<Arc<dyn TraceExporter>>,
74 pricing: Option<PriceTable>,
75 batcher_cfg: BatcherConfig,
76}
77
78impl TracingHandlerBuilder {
79 pub fn with_exporter<E: TraceExporter + 'static>(mut self, e: E) -> Self {
81 self.exporters.push(Arc::new(e));
82 self
83 }
84
85 pub fn with_default_pricing(mut self) -> Self {
87 self.pricing = Some(PriceTable::with_defaults());
88 self
89 }
90
91 pub fn with_pricing(mut self, p: PriceTable) -> Self {
93 self.pricing = Some(p);
94 self
95 }
96
97 pub fn override_price(mut self, model: impl Into<String>, p: crate::cost::ModelPrice) -> Self {
99 let mut t = self.pricing.unwrap_or_default();
100 t.insert(model, p);
101 self.pricing = Some(t);
102 self
103 }
104
105 pub fn with_batcher_config(mut self, cfg: BatcherConfig) -> Self {
107 self.batcher_cfg = cfg;
108 self
109 }
110
111 pub fn build(self) -> TracingHandler {
114 let cfg = self.batcher_cfg;
115 let pricing = Arc::new(self.pricing.unwrap_or_default());
116
117 let mut span_batchers = Vec::with_capacity(self.exporters.len());
118 let mut score_batchers = Vec::with_capacity(self.exporters.len());
119 for e in &self.exporters {
120 let e_for_spans = e.clone();
121 span_batchers.push(Batcher::spawn(cfg, move |batch: Vec<Span>| {
122 let e = e_for_spans.clone();
123 async move { e.export_spans(batch).await }
124 }));
125 let e_for_scores = e.clone();
126 score_batchers.push(Batcher::spawn(cfg, move |batch: Vec<ScoreRecord>| {
127 let e = e_for_scores.clone();
128 async move { e.export_scores(batch).await }
129 }));
130 }
131
132 TracingHandler {
133 exporters: self.exporters,
134 inflight: DashMap::new(),
135 span_batchers,
136 score_batchers,
137 pricing,
138 }
139 }
140}
141
142impl TracingHandler {
143 fn start_span(
144 &self,
145 kind: SpanKind,
146 name: &str,
147 input: Option<serde_json::Value>,
148 run_id: Uuid,
149 ) {
150 let parent = crate::parent::peek();
151 let trace_id = parent.unwrap_or(run_id);
152 let b = SpanBuilder::open(
153 run_id,
154 parent,
155 trace_id,
156 kind,
157 name.to_string(),
158 input,
159 std::time::SystemTime::now(),
160 );
161 self.inflight.insert(run_id, b);
162 crate::parent::push(run_id);
163 }
164
165 fn finish_ok(&self, run_id: Uuid, output: Option<serde_json::Value>) {
166 if let Some((_, b)) = self.inflight.remove(&run_id) {
167 let span = b.finish_ok(output, std::time::SystemTime::now());
168 self.dispatch(span);
169 }
170 crate::parent::pop(run_id);
171 }
172
173 fn finish_error(&self, run_id: Uuid, message: &str) {
174 if let Some((_, b)) = self.inflight.remove(&run_id) {
175 let span = b.finish_error(message, std::time::SystemTime::now());
176 self.dispatch(span);
177 }
178 crate::parent::pop(run_id);
179 }
180
181 fn dispatch(&self, span: Span) {
182 for b in &self.span_batchers {
183 b.send(span.clone());
184 }
185 }
186
187 fn parse_generation(
190 &self,
191 model_hint: &str,
192 payload: &serde_json::Value,
193 ) -> crate::span::Generation {
194 use crate::span::{Generation, TokenUsage};
195
196 let obj = payload.as_object();
197 let model = obj
198 .and_then(|o| o.get("model"))
199 .and_then(|v| v.as_str())
200 .unwrap_or(model_hint)
201 .to_string();
202 let provider = obj
203 .and_then(|o| o.get("provider"))
204 .and_then(|v| v.as_str())
205 .unwrap_or("")
206 .to_string();
207 let finish_reason = obj
208 .and_then(|o| o.get("finish_reason"))
209 .and_then(|v| v.as_str())
210 .map(String::from);
211 let model_parameters = obj
212 .and_then(|o| o.get("model_parameters"))
213 .and_then(|v| v.as_object())
214 .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
215 .unwrap_or_default();
216 let usage = obj
217 .and_then(|o| o.get("usage"))
218 .and_then(|v| v.as_object())
219 .map(|u| TokenUsage {
220 input: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
221 output: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
222 cache_read: u
223 .get("cache_read_tokens")
224 .and_then(|v| v.as_u64())
225 .unwrap_or(0) as u32,
226 cache_write: u
227 .get("cache_creation_tokens")
228 .and_then(|v| v.as_u64())
229 .unwrap_or(0) as u32,
230 })
231 .unwrap_or_default();
232 let prompt_name = obj
233 .and_then(|o| o.get("prompt_name"))
234 .and_then(|v| v.as_str())
235 .map(String::from);
236 let prompt_version = obj
237 .and_then(|o| o.get("prompt_version"))
238 .and_then(|v| v.as_u64())
239 .map(|n| n as u32);
240 let cost = self.pricing.compute(&model, usage);
241 Generation {
242 model,
243 provider,
244 model_parameters,
245 usage,
246 cost,
247 completion_start_time: None,
248 finish_reason,
249 prompt_name,
250 prompt_version,
251 }
252 }
253}
254
255impl CallbackHandler for TracingHandler {
256 fn name(&self) -> &str {
257 "cognis_trace::TracingHandler"
258 }
259
260 fn on_chain_start(&self, runnable: &str, input: &serde_json::Value, run_id: Uuid) {
261 self.start_span(SpanKind::Chain, runnable, Some(input.clone()), run_id);
262 }
263
264 fn on_chain_end(&self, _runnable: &str, output: &serde_json::Value, run_id: Uuid) {
265 self.finish_ok(run_id, Some(output.clone()));
266 }
267
268 fn on_chain_error(&self, _runnable: &str, error: &str, run_id: Uuid) {
269 self.finish_error(run_id, error);
270 }
271
272 fn on_tool_start(&self, tool: &str, args: &serde_json::Value, run_id: Uuid) {
273 self.start_span(SpanKind::Tool, tool, Some(args.clone()), run_id);
274 }
275
276 fn on_tool_end(&self, _tool: &str, result: &serde_json::Value, run_id: Uuid) {
277 self.finish_ok(run_id, Some(result.clone()));
278 }
279
280 fn on_tool_error(&self, _tool: &str, error: &str, run_id: Uuid) {
281 self.finish_error(run_id, error);
282 }
283
284 fn on_node_start(&self, node: &str, _step: u64, run_id: Uuid) {
285 self.start_span(SpanKind::Span, node, None, run_id);
286 }
287
288 fn on_node_end(&self, _node: &str, _step: u64, output: &serde_json::Value, run_id: Uuid) {
289 self.finish_ok(run_id, Some(output.clone()));
290 }
291
292 fn on_checkpoint(&self, _step: u64, _run_id: Uuid) {
293 }
296
297 fn on_custom(&self, kind: &str, payload: &serde_json::Value, run_id: Uuid) {
298 let trace_id = crate::parent::peek().unwrap_or(run_id);
300 let now = std::time::SystemTime::now();
301 let mut b = SpanBuilder::open(
302 run_id,
303 crate::parent::peek(),
304 trace_id,
305 SpanKind::Event,
306 kind,
307 Some(payload.clone()),
308 now,
309 );
310 b.span
311 .metadata
312 .insert("kind".into(), serde_json::Value::String(kind.into()));
313 let span = b.finish_ok(None, now);
314 self.dispatch(span);
315 }
316
317 fn on_llm_start(&self, model: &str, prompt: &serde_json::Value, run_id: Uuid) {
318 self.start_span(SpanKind::Generation, model, Some(prompt.clone()), run_id);
319 }
320
321 fn on_llm_token(&self, _token: &str, _run_id: Uuid) {
322 }
324
325 fn on_llm_end(&self, model: &str, output: &serde_json::Value, run_id: Uuid) {
326 let generation = self.parse_generation(model, output);
327 if let Some((_, b)) = self.inflight.remove(&run_id) {
328 let b = b.with_generation(generation);
329 let out = output
331 .as_object()
332 .and_then(|o| o.get("content").cloned())
333 .or_else(|| Some(output.clone()));
334 let span = b.finish_ok(out, std::time::SystemTime::now());
335 self.dispatch(span);
336 }
337 crate::parent::pop(run_id);
338 }
339
340 fn on_llm_error(&self, _model: &str, error: &str, run_id: Uuid) {
341 self.finish_error(run_id, error);
342 }
343}