Skip to main content

langgraph_tracing/
wrappers.rs

1use crate::event_bus::EventBus;
2use crate::store::TracingStore;
3use crate::types::*;
4use async_trait::async_trait;
5use langgraph_checkpoint::config::RunnableConfig;
6use langgraph_prebuilt::traits::{BaseChatModel, BaseTool, ToolDef};
7use langgraph_prebuilt::types::Message;
8use serde_json::Value as JsonValue;
9use std::sync::Arc;
10use std::time::Instant;
11use uuid::Uuid;
12
13/// Wrapper around any BaseChatModel that records LLM call traces.
14pub struct TracingChatModel<M: BaseChatModel> {
15    inner: M,
16    store: Arc<dyn TracingStore>,
17    event_bus: EventBus,
18    trace_id: String,
19    parent_span_id: Option<String>,
20}
21
22impl<M: BaseChatModel> TracingChatModel<M> {
23    pub fn new(
24        inner: M,
25        store: Arc<dyn TracingStore>,
26        event_bus: EventBus,
27        trace_id: String,
28    ) -> Self {
29        Self {
30            inner,
31            store,
32            event_bus,
33            trace_id,
34            parent_span_id: None,
35        }
36    }
37
38    pub fn with_parent_span(mut self, span_id: String) -> Self {
39        self.parent_span_id = Some(span_id);
40        self
41    }
42}
43
44fn record_llm_span(
45    store: &dyn TracingStore,
46    event_bus: &EventBus,
47    trace_id: &str,
48    parent_span_id: &Option<String>,
49    model_name: &str,
50    input_json: JsonValue,
51    result: &Result<Message, langgraph_prebuilt::traits::ModelError>,
52) {
53    let span_id = Uuid::new_v4().to_string();
54    match result {
55        Ok(response) => {
56            let output_json = serde_json::to_value(response).unwrap_or(JsonValue::Null);
57            let usage = response.usage();
58            let mut span = Span::new(
59                span_id,
60                trace_id.to_string(),
61                parent_span_id.clone(),
62                model_name.to_string(),
63                SpanType::LlmGeneration,
64                input_json,
65            );
66            span.finish(output_json, SpanStatus::Success);
67            span.metadata.model = Some(model_name.to_string());
68            span.metadata.tokens_in = usage.map(|u| u.prompt_tokens);
69            span.metadata.tokens_out = usage.map(|u| u.completion_tokens);
70            span.metadata.total_tokens = usage.map(|u| u.total_tokens);
71            store.add_span(span.clone());
72            event_bus.publish(crate::event_bus::TracingEvent::SpanCreated { span });
73        }
74        Err(e) => {
75            let mut span = Span::new(
76                span_id,
77                trace_id.to_string(),
78                parent_span_id.clone(),
79                model_name.to_string(),
80                SpanType::LlmGeneration,
81                input_json,
82            );
83            span.finish(
84                serde_json::json!({"error": e.to_string()}),
85                SpanStatus::Error,
86            );
87            span.metadata.model = Some(model_name.to_string());
88            store.add_span(span.clone());
89            event_bus.publish(crate::event_bus::TracingEvent::SpanCreated { span });
90        }
91    }
92}
93
94#[async_trait]
95impl<M: BaseChatModel + 'static> BaseChatModel for TracingChatModel<M> {
96    fn name(&self) -> &str {
97        self.inner.name()
98    }
99
100    fn invoke(
101        &self,
102        messages: &[Message],
103        config: &RunnableConfig,
104    ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
105        let start = Instant::now();
106        let result = self.inner.invoke(messages, config);
107        let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
108        record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
109        let _ = start;
110        result
111    }
112
113    async fn ainvoke(
114        &self,
115        messages: &[Message],
116        config: &RunnableConfig,
117    ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
118        let start = Instant::now();
119        let result = self.inner.ainvoke(messages, config).await;
120        let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
121        record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
122        let _ = start;
123        result
124    }
125
126    fn astream<'a>(
127        &'a self,
128        messages: &'a [Message],
129        config: &'a RunnableConfig,
130    ) -> langgraph_prebuilt::MessageStream<'a> {
131        let store = self.store.clone();
132        let event_bus = self.event_bus.clone();
133        let trace_id = self.trace_id.clone();
134        let parent_span_id = self.parent_span_id.clone();
135        let model_name = self.inner.name().to_string();
136        let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
137
138        let mut stream = self.inner.astream(messages, config);
139
140        Box::pin(async_stream::stream! {
141            let mut accumulated_message: Option<Message> = None;
142            
143            while let Some(result) = tokio_stream::StreamExt::next(&mut stream).await {
144                if let Ok(ref msg) = result {
145                    match accumulated_message {
146                        None => {
147                            accumulated_message = Some(msg.clone());
148                        }
149                        Some(langgraph_prebuilt::types::Message::Ai { 
150                            content: langgraph_prebuilt::types::MessageContent::Text(ref mut acc_text),
151                            ref mut tool_calls,
152                            ref mut usage,
153                            ..
154                        }) => {
155                            if let langgraph_prebuilt::types::Message::Ai { 
156                                content: langgraph_prebuilt::types::MessageContent::Text(ref msg_text),
157                                tool_calls: ref msg_tools,
158                                usage: ref msg_usage,
159                                ..
160                            } = msg {
161                                acc_text.push_str(msg_text);
162                                for tc in msg_tools {
163                                    if !tool_calls.iter().any(|existing| existing.id == tc.id && tc.id.is_some()) {
164                                        tool_calls.push(tc.clone());
165                                    }
166                                }
167                                if msg_usage.is_some() {
168                                    *usage = msg_usage.clone();
169                                }
170                            }
171                        }
172                        _ => {
173                            // For other message types, just replace (though astream usually only yields AI messages)
174                            accumulated_message = Some(msg.clone());
175                        }
176                    }
177                }
178                yield result;
179            }
180
181            // Record span when stream ends
182            if let Some(final_msg) = accumulated_message {
183                record_llm_span(
184                    store.as_ref(),
185                    &event_bus,
186                    &trace_id,
187                    &parent_span_id,
188                    &model_name,
189                    input_json,
190                    &Ok(final_msg),
191                );
192            }
193        })
194    }
195
196    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
197        // We can't wrap Box<dyn BaseChatModel> in TracingChatModel because
198        // Box<dyn BaseChatModel> doesn't implement BaseChatModel.
199        // Instead, bind tools on the inner model and wrap the result.
200        // We need to create a dynamic wrapper.
201        let inner = self.inner.bind_tools(tools);
202        Box::new(DynamicTracingChatModel {
203            inner,
204            store: self.store.clone(),
205            event_bus: self.event_bus.clone(),
206            trace_id: self.trace_id.clone(),
207            parent_span_id: self.parent_span_id.clone(),
208        })
209    }
210}
211
212/// Dynamic wrapper that holds a Box<dyn BaseChatModel> instead of a generic type.
213/// This is needed for bind_tools which returns Box<dyn BaseChatModel>.
214struct DynamicTracingChatModel {
215    inner: Box<dyn BaseChatModel>,
216    store: Arc<dyn TracingStore>,
217    event_bus: EventBus,
218    trace_id: String,
219    parent_span_id: Option<String>,
220}
221
222#[async_trait]
223impl BaseChatModel for DynamicTracingChatModel {
224    fn name(&self) -> &str {
225        self.inner.name()
226    }
227
228    fn invoke(
229        &self,
230        messages: &[Message],
231        config: &RunnableConfig,
232    ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
233        let start = Instant::now();
234        let result = self.inner.invoke(messages, config);
235        let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
236        record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
237        let _ = start;
238        result
239    }
240
241    async fn ainvoke(
242        &self,
243        messages: &[Message],
244        config: &RunnableConfig,
245    ) -> Result<Message, langgraph_prebuilt::traits::ModelError> {
246        let start = Instant::now();
247        let result = self.inner.ainvoke(messages, config).await;
248        let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
249        record_llm_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), input_json, &result);
250        let _ = start;
251        result
252    }
253
254    fn astream<'a>(
255        &'a self,
256        messages: &'a [Message],
257        config: &'a RunnableConfig,
258    ) -> langgraph_prebuilt::MessageStream<'a> {
259        let store = self.store.clone();
260        let event_bus = self.event_bus.clone();
261        let trace_id = self.trace_id.clone();
262        let parent_span_id = self.parent_span_id.clone();
263        let model_name = self.inner.name().to_string();
264        let input_json = serde_json::to_value(messages).unwrap_or(JsonValue::Null);
265
266        let mut stream = self.inner.astream(messages, config);
267
268        Box::pin(async_stream::stream! {
269            let mut accumulated_message: Option<Message> = None;
270            
271            while let Some(result) = tokio_stream::StreamExt::next(&mut stream).await {
272                if let Ok(ref msg) = result {
273                    match accumulated_message {
274                        None => {
275                            accumulated_message = Some(msg.clone());
276                        }
277                        Some(langgraph_prebuilt::types::Message::Ai { 
278                            content: langgraph_prebuilt::types::MessageContent::Text(ref mut acc_text),
279                            ref mut tool_calls,
280                            ref mut usage,
281                            ..
282                        }) => {
283                            if let langgraph_prebuilt::types::Message::Ai { 
284                                content: langgraph_prebuilt::types::MessageContent::Text(ref msg_text),
285                                tool_calls: ref msg_tools,
286                                usage: ref msg_usage,
287                                ..
288                            } = msg {
289                                acc_text.push_str(msg_text);
290                                for tc in msg_tools {
291                                    if !tool_calls.iter().any(|existing| existing.id == tc.id && tc.id.is_some()) {
292                                        tool_calls.push(tc.clone());
293                                    }
294                                }
295                                if msg_usage.is_some() {
296                                    *usage = msg_usage.clone();
297                                }
298                            }
299                        }
300                        _ => {
301                            accumulated_message = Some(msg.clone());
302                        }
303                    }
304                }
305                yield result;
306            }
307
308            // Record span when stream ends
309            if let Some(final_msg) = accumulated_message {
310                record_llm_span(
311                    store.as_ref(),
312                    &event_bus,
313                    &trace_id,
314                    &parent_span_id,
315                    &model_name,
316                    input_json,
317                    &Ok(final_msg),
318                );
319            }
320        })
321    }
322
323    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
324        let inner = self.inner.bind_tools(tools);
325        Box::new(DynamicTracingChatModel {
326            inner,
327            store: self.store.clone(),
328            event_bus: self.event_bus.clone(),
329            trace_id: self.trace_id.clone(),
330            parent_span_id: self.parent_span_id.clone(),
331        })
332    }
333}
334
335/// Wrapper around any BaseTool that records tool call traces.
336pub struct TracingTool<T: BaseTool> {
337    inner: T,
338    store: Arc<dyn TracingStore>,
339    event_bus: EventBus,
340    trace_id: String,
341    parent_span_id: Option<String>,
342}
343
344impl<T: BaseTool> TracingTool<T> {
345    pub fn new(
346        inner: T,
347        store: Arc<dyn TracingStore>,
348        event_bus: EventBus,
349        trace_id: String,
350    ) -> Self {
351        Self {
352            inner,
353            store,
354            event_bus,
355            trace_id,
356            parent_span_id: None,
357        }
358    }
359
360    pub fn with_parent_span(mut self, span_id: String) -> Self {
361        self.parent_span_id = Some(span_id);
362        self
363    }
364}
365
366fn record_tool_span(
367    store: &dyn TracingStore,
368    event_bus: &EventBus,
369    trace_id: &str,
370    parent_span_id: &Option<String>,
371    tool_name: &str,
372    input: &JsonValue,
373    result: &Result<JsonValue, langgraph_prebuilt::traits::ToolError>,
374) {
375    let span_id = Uuid::new_v4().to_string();
376    let mut span = Span::new(
377        span_id,
378        trace_id.to_string(),
379        parent_span_id.clone(),
380        tool_name.to_string(),
381        SpanType::ToolCall,
382        input.clone(),
383    );
384    span.metadata.tool_name = Some(tool_name.to_string());
385
386    match result {
387        Ok(output) => {
388            span.finish(output.clone(), SpanStatus::Success);
389        }
390        Err(e) => {
391            span.finish(
392                serde_json::json!({"error": e.to_string()}),
393                SpanStatus::Error,
394            );
395        }
396    }
397
398    store.add_span(span.clone());
399    event_bus.publish(crate::event_bus::TracingEvent::SpanCreated { span });
400}
401
402#[async_trait]
403impl<T: BaseTool + 'static> BaseTool for TracingTool<T> {
404    fn name(&self) -> &str {
405        self.inner.name()
406    }
407
408    fn description(&self) -> &str {
409        self.inner.description()
410    }
411
412    fn parameters(&self) -> Option<&JsonValue> {
413        self.inner.parameters()
414    }
415
416    fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, langgraph_prebuilt::traits::ToolError> {
417        let start = Instant::now();
418        let result = self.inner.invoke(args, config);
419        record_tool_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), args, &result);
420        let _ = start;
421        result
422    }
423
424    async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, langgraph_prebuilt::traits::ToolError> {
425        let start = Instant::now();
426        let result = self.inner.ainvoke(args, config).await;
427        record_tool_span(self.store.as_ref(), &self.event_bus, &self.trace_id, &self.parent_span_id, self.inner.name(), args, &result);
428        let _ = start;
429        result
430    }
431
432    fn to_tool_def(&self) -> ToolDef {
433        self.inner.to_tool_def()
434    }
435}