1use crate::event_bus::EventBus;
2use crate::store::TracingStore;
3use crate::types::*;
4use async_trait::async_trait;
5use langgraph_checkpoint::config::RunnableConfig;
6use langgraph_prebuilt::traits::{BaseChatModel, BaseTool, ToolDef};
7use langgraph_prebuilt::types::Message;
8use serde_json::Value as JsonValue;
9use std::sync::Arc;
10use std::time::Instant;
11use uuid::Uuid;
12
13pub struct TracingChatModel<M: BaseChatModel> {
15 inner: M,
16 store: Arc<dyn TracingStore>,
17 event_bus: EventBus,
18 trace_id: String,
19 parent_span_id: Option<String>,
20}
21
22impl<M: BaseChatModel> TracingChatModel<M> {
23 pub fn new(
24 inner: M,
25 store: Arc<dyn TracingStore>,
26 event_bus: EventBus,
27 trace_id: String,
28 ) -> Self {
29 Self {
30 inner,
31 store,
32 event_bus,
33 trace_id,
34 parent_span_id: None,
35 }
36 }
37
38 pub fn with_parent_span(mut self, span_id: String) -> Self {
39 self.parent_span_id = Some(span_id);
40 self
41 }
42}
43
44fn record_llm_span(
45 store: &dyn TracingStore,
46 event_bus: &EventBus,
47 trace_id: &str,
48 parent_span_id: &Option<String>,
49 model_name: &str,
50 input_json: JsonValue,
51 result: &Result<Message, langgraph_prebuilt::traits::ModelError>,
52) {
53 let span_id = Uuid::new_v4().to_string();
54 match result {
55 Ok(response) => {
56 let output_json = serde_json::to_value(response).unwrap_or(JsonValue::Null);
57 let usage = response.usage();
58 let mut span = Span::new(
59 span_id,
60 trace_id.to_string(),
61 parent_span_id.clone(),
62 model_name.to_string(),
63 SpanType::LlmGeneration,
64 input_json,
65 );
66 span.finish(output_json, SpanStatus::Success);
67 span.metadata.model = Some(model_name.to_string());
68 span.metadata.tokens_in = usage.map(|u| u.prompt_tokens);
69 span.metadata.tokens_out = usage.map(|u| u.completion_tokens);
70 span.metadata.total_tokens = usage.map(|u| u.total_tokens);
71 store.add_span(span.clone());
72 event_bus.publish(crate::event_bus::TracingEvent::SpanCreated { span });
73 }
74 Err(e) => {
75 let mut span = Span::new(
76 span_id,
77 trace_id.to_string(),
78 parent_span_id.clone(),
79 model_name.to_string(),
80 SpanType::LlmGeneration,
81 input_json,
82 );
83 span.finish(
84 serde_json::json!({"error": e.to_string()}),
85 SpanStatus::Error,
86 );
87 span.metadata.model = Some(model_name.to_string());
88 store.add_span(span.clone());
89 event_bus.publish(crate::event_bus::TracingEvent::SpanCreated { span });
90 }
91 }
92}
93
94#[async_trait]
95impl<M: BaseChatModel + 'static> BaseChatModel for TracingChatModel<M> {
96 fn name(&self) -> &str {
97 self.inner.name()
98 }
99
100 fn invoke(
101 &self,
102 messages: &[Message],
103 config: &RunnableConfig,
104 ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
105 let start = Instant::now();
106 let result = self.inner.invoke(messages, config);
107 let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
108 record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
109 let _ = start;
110 result
111 }
112
113 async fn ainvoke(
114 &self,
115 messages: &[Message],
116 config: &RunnableConfig,
117 ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
118 let start = Instant::now();
119 let result = self.inner.ainvoke(messages, config).await;
120 let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
121 record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
122 let _ = start;
123 result
124 }
125
126 fn astream<'a>(
127 &'a self,
128 messages: &'a [Message],
129 config: &'a RunnableConfig,
130 ) -> langgraph_prebuilt::MessageStream<'a> {
131 let store = self.store.clone();
132 let event_bus = self.event_bus.clone();
133 let trace_id = self.trace_id.clone();
134 let parent_span_id = self.parent_span_id.clone();
135 let model_name = self.inner.name().to_string();
136 let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
137
138 let mut stream = self.inner.astream(messages, config);
139
140 Box::pin(async_stream::stream! {
141 let mut accumulated_message: Option<Message> = None;
142
143 while let Some(result) = tokio_stream::StreamExt::next(&mut stream).await {
144 if let Ok(ref msg) = result {
145 match accumulated_message {
146 None => {
147 accumulated_message = Some(msg.clone());
148 }
149 Some(langgraph_prebuilt::types::Message::Ai {
150 content: langgraph_prebuilt::types::MessageContent::Text(ref mut acc_text),
151 ref mut tool_calls,
152 ref mut usage,
153 ..
154 }) => {
155 if let langgraph_prebuilt::types::Message::Ai {
156 content: langgraph_prebuilt::types::MessageContent::Text(ref msg_text),
157 tool_calls: ref msg_tools,
158 usage: ref msg_usage,
159 ..
160 } = msg {
161 acc_text.push_str(msg_text);
162 for tc in msg_tools {
163 if !tool_calls.iter().any(|existing| existing.id == tc.id && tc.id.is_some()) {
164 tool_calls.push(tc.clone());
165 }
166 }
167 if msg_usage.is_some() {
168 *usage = msg_usage.clone();
169 }
170 }
171 }
172 _ => {
173 accumulated_message = Some(msg.clone());
175 }
176 }
177 }
178 yield result;
179 }
180
181 if let Some(final_msg) = accumulated_message {
183 record_llm_span(
184 store.as_ref(),
185 &event_bus,
186 &trace_id,
187 &parent_span_id,
188 &model_name,
189 input_json,
190 &Ok(final_msg),
191 );
192 }
193 })
194 }
195
196 fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
197 let inner = self.inner.bind_tools(tools);
202 Box::new(DynamicTracingChatModel {
203 inner,
204 store: self.store.clone(),
205 event_bus: self.event_bus.clone(),
206 trace_id: self.trace_id.clone(),
207 parent_span_id: self.parent_span_id.clone(),
208 })
209 }
210}
211
212struct DynamicTracingChatModel {
215 inner: Box<dyn BaseChatModel>,
216 store: Arc<dyn TracingStore>,
217 event_bus: EventBus,
218 trace_id: String,
219 parent_span_id: Option<String>,
220}
221
222#[async_trait]
223impl BaseChatModel for DynamicTracingChatModel {
224 fn name(&self) -> &str {
225 self.inner.name()
226 }
227
228 fn invoke(
229 &self,
230 messages: &[Message],
231 config: &RunnableConfig,
232 ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
233 let start = Instant::now();
234 let result = self.inner.invoke(messages, config);
235 let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
236 record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
237 let _ = start;
238 result
239 }
240
241 async fn ainvoke(
242 &self,
243 messages: &[Message],
244 config: &RunnableConfig,
245 ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
246 let start = Instant::now();
247 let result = self.inner.ainvoke(messages, config).await;
248 let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
249 record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
250 let _ = start;
251 result
252 }
253
254 fn astream<'a>(
255 &'a self,
256 messages: &'a [Message],
257 config: &'a RunnableConfig,
258 ) -> langgraph_prebuilt::MessageStream<'a> {
259 let store = self.store.clone();
260 let event_bus = self.event_bus.clone();
261 let trace_id = self.trace_id.clone();
262 let parent_span_id = self.parent_span_id.clone();
263 let model_name = self.inner.name().to_string();
264 let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
265
266 let mut stream = self.inner.astream(messages, config);
267
268 Box::pin(async_stream::stream! {
269 let mut accumulated_message: Option<Message> = None;
270
271 while let Some(result) = tokio_stream::StreamExt::next(&mut stream).await {
272 if let Ok(ref msg) = result {
273 match accumulated_message {
274 None => {
275 accumulated_message = Some(msg.clone());
276 }
277 Some(langgraph_prebuilt::types::Message::Ai {
278 content: langgraph_prebuilt::types::MessageContent::Text(ref mut acc_text),
279 ref mut tool_calls,
280 ref mut usage,
281 ..
282 }) => {
283 if let langgraph_prebuilt::types::Message::Ai {
284 content: langgraph_prebuilt::types::MessageContent::Text(ref msg_text),
285 tool_calls: ref msg_tools,
286 usage: ref msg_usage,
287 ..
288 } = msg {
289 acc_text.push_str(msg_text);
290 for tc in msg_tools {
291 if !tool_calls.iter().any(|existing| existing.id == tc.id && tc.id.is_some()) {
292 tool_calls.push(tc.clone());
293 }
294 }
295 if msg_usage.is_some() {
296 *usage = msg_usage.clone();
297 }
298 }
299 }
300 _ => {
301 accumulated_message = Some(msg.clone());
302 }
303 }
304 }
305 yield result;
306 }
307
308 if let Some(final_msg) = accumulated_message {
310 record_llm_span(
311 store.as_ref(),
312 &event_bus,
313 &trace_id,
314 &parent_span_id,
315 &model_name,
316 input_json,
317 &Ok(final_msg),
318 );
319 }
320 })
321 }
322
323 fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
324 let inner = self.inner.bind_tools(tools);
325 Box::new(DynamicTracingChatModel {
326 inner,
327 store: self.store.clone(),
328 event_bus: self.event_bus.clone(),
329 trace_id: self.trace_id.clone(),
330 parent_span_id: self.parent_span_id.clone(),
331 })
332 }
333}
334
335pub struct TracingTool<T: BaseTool> {
337 inner: T,
338 store: Arc<dyn TracingStore>,
339 event_bus: EventBus,
340 trace_id: String,
341 parent_span_id: Option<String>,
342}
343
344impl<T: BaseTool> TracingTool<T> {
345 pub fn new(
346 inner: T,
347 store: Arc<dyn TracingStore>,
348 event_bus: EventBus,
349 trace_id: String,
350 ) -> Self {
351 Self {
352 inner,
353 store,
354 event_bus,
355 trace_id,
356 parent_span_id: None,
357 }
358 }
359
360 pub fn with_parent_span(mut self, span_id: String) -> Self {
361 self.parent_span_id = Some(span_id);
362 self
363 }
364}
365
366fn record_tool_span(
367 store: &dyn TracingStore,
368 event_bus: &EventBus,
369 trace_id: &str,
370 parent_span_id: &Option<String>,
371 tool_name: &str,
372 input: &JsonValue,
373 result: &Result<JsonValue, langgraph_prebuilt::traits::ToolError>,
374) {
375 let span_id = Uuid::new_v4().to_string();
376 let mut span = Span::new(
377 span_id,
378 trace_id.to_string(),
379 parent_span_id.clone(),
380 tool_name.to_string(),
381 SpanType::ToolCall,
382 input.clone(),
383 );
384 span.metadata.tool_name = Some(tool_name.to_string());
385
386 match result {
387 Ok(output) => {
388 span.finish(output.clone(), SpanStatus::Success);
389 }
390 Err(e) => {
391 span.finish(
392 serde_json::json!({"error": e.to_string()}),
393 SpanStatus::Error,
394 );
395 }
396 }
397
398 store.add_span(span.clone());
399 event_bus.publish(crate::event_bus::TracingEvent::SpanCreated { span });
400}
401
402#[async_trait]
403impl<T: BaseTool + 'static> BaseTool for TracingTool<T> {
404 fn name(&self) -> &str {
405 self.inner.name()
406 }
407
408 fn description(&self) -> &str {
409 self.inner.description()
410 }
411
412 fn parameters(&self) -> Option<&JsonValue> {
413 self.inner.parameters()
414 }
415
416 fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, langgraph_prebuilt::traits::ToolError> {
417 let start = Instant::now();
418 let result = self.inner.invoke(args, config);
419 record_tool_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), args, &result);
420 let _ = start;
421 result
422 }
423
424 async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, langgraph_prebuilt::traits::ToolError> {
425 let start = Instant::now();
426 let result = self.inner.ainvoke(args, config).await;
427 record_tool_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), args, &result);
428 let _ = start;
429 result
430 }
431
432 fn to_tool_def(&self) -> ToolDef {
433 self.inner.to_tool_def()
434 }
435}