Skip to main content

langgraph_prebuilt/
traits.rs

1use std::pin::Pin;
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use langgraph_checkpoint::config::RunnableConfig;
5use langgraph::types::{GraphInterrupt, InterruptError};
6use crate::types::Message;
7
8/// A stream of message chunks from a chat model.
9///
10/// Each item is a `Message` representing either a partial token chunk
11/// (for real-time display) or the final complete message.
12pub type MessageStream<'a> = Pin<Box<dyn tokio_stream::Stream<Item = Result<Message, ModelError>> + Send + 'a>>;
13
14/// Token usage information from an LLM API response.
15#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
16pub struct LlmUsage {
17    pub prompt_tokens: u32,
18    pub completion_tokens: u32,
19    pub total_tokens: u32,
20    #[serde(default)]
21    pub cache_creation_tokens: Option<u32>,
22    #[serde(default)]
23    pub cache_read_tokens: Option<u32>,
24}
25
26/// Error type for tool and model operations.
27#[derive(Debug, thiserror::Error)]
28pub enum ToolError {
29    #[error("tool execution error: {0}")]
30    Execution(String),
31
32    #[error("invalid arguments: {0}")]
33    InvalidArgs(String),
34
35    #[error("tool not found: {0}")]
36    NotFound(String),
37
38    #[error("graph interrupt")]
39    Interrupt(GraphInterrupt),
40
41    #[error(transparent)]
42    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
43}
44
45impl From<String> for ToolError {
46    fn from(s: String) -> Self {
47        ToolError::Execution(s)
48    }
49}
50
51impl From<GraphInterrupt> for ToolError {
52    fn from(interrupt: GraphInterrupt) -> Self {
53        ToolError::Interrupt(interrupt)
54    }
55}
56
57impl From<InterruptError> for ToolError {
58    fn from(e: InterruptError) -> Self {
59        ToolError::Interrupt(e.into())
60    }
61}
62
63/// Error type for chat model operations.
64#[derive(Debug, thiserror::Error)]
65pub enum ModelError {
66    #[error("model invocation error: {0}")]
67    Invocation(String),
68
69    #[error("model configuration error: {0}")]
70    Config(String),
71
72    #[error(transparent)]
73    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
74}
75
76/// A tool that can be invoked by an agent.
77///
78/// Mirrors langchain-core's BaseTool.
79#[async_trait]
80pub trait BaseTool: Send + Sync {
81    /// The name of the tool.
82    fn name(&self) -> &str;
83
84    /// A description of what the tool does.
85    fn description(&self) -> &str {
86        ""
87    }
88
89    /// The JSON schema for the tool's parameters.
90    fn parameters(&self) -> Option<&JsonValue> {
91        None
92    }
93
94    /// Invoke the tool synchronously with the given arguments.
95    fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError>;
96
97    /// Invoke the tool asynchronously. Default delegates to sync invoke via block_in_place.
98    ///
99    /// Sets up thread-local config/runtime context so that `get_config()` and
100    /// `get_runtime()` work inside sync tool code (needed by `interrupt()`).
101    async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError> {
102        let args = args.clone();
103        let config = config.clone();
104        // Capture runtime from async task-locals if available
105        let current_runtime = langgraph::config::get_runtime();
106        // Always use with_runtime_sync to set up thread-locals for get_config()/get_runtime()
107        let runtime = current_runtime.unwrap_or_else(|| {
108            std::sync::Arc::new(langgraph::runtime::Runtime {
109                context: (),
110                store: None,
111                stream_writer: None,
112                previous: None,
113                execution_info: None,
114                server_info: None,
115            })
116        });
117        tokio::task::block_in_place(|| {
118            langgraph::config::with_runtime_sync(config.clone(), runtime, || {
119                self.invoke(&args, &config)
120            })
121        })
122    }
123
124    /// Get the tool's schema as a ToolCall-compatible descriptor.
125    fn to_tool_def(&self) -> ToolDef {
126        ToolDef {
127            name: self.name().to_string(),
128            description: self.description().to_string(),
129            parameters: self.parameters().cloned().unwrap_or(serde_json::json!({})),
130        }
131    }
132}
133
134/// A tool definition that can be passed to a chat model.
135#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
136pub struct ToolDef {
137    pub name: String,
138    pub description: String,
139    pub parameters: JsonValue,
140}
141
142/// A chat model that can generate responses.
143///
144/// Mirrors langchain-core's BaseChatModel.
145#[async_trait]
146pub trait BaseChatModel: Send + Sync {
147    /// The name of the model.
148    fn name(&self) -> &str;
149
150    /// Invoke the model with a list of messages and get a response.
151    fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError>;
152
153    /// Invoke the model asynchronously. Default delegates to sync invoke via block_in_place.
154    async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
155        let messages = messages.to_vec();
156        let config = config.clone();
157        tokio::task::block_in_place(|| self.invoke(&messages, &config))
158    }
159
160    /// Stream tokens from the model. Returns a stream of partial Message chunks.
161    ///
162    /// Each yielded `Message` represents the accumulated content up to that point.
163    /// For example, if the model generates "Hello world", the stream might yield:
164    /// - `Message::ai("Hello")`
165    /// - `Message::ai("Hello world")`
166    ///
167    /// The final item in the stream is the complete response (including tool calls if any).
168    ///
169    /// Default implementation falls back to `ainvoke` (yields a single complete message).
170    fn astream<'a>(
171        &'a self,
172        messages: &'a [Message],
173        config: &'a RunnableConfig,
174    ) -> MessageStream<'a> {
175        let messages = messages.to_vec();
176        let config = config.clone();
177        Box::pin(async_stream::stream! {
178            let msg = self.ainvoke(&messages, &config).await?;
179            yield Ok(msg);
180        })
181    }
182
183    /// Bind tools to the model for tool-calling support.
184    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel>;
185}
186
187#[async_trait]
188impl BaseChatModel for Box<dyn BaseChatModel> {
189    fn name(&self) -> &str {
190        (**self).name()
191    }
192
193    fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
194        (**self).invoke(messages, config)
195    }
196
197    async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
198        (**self).ainvoke(messages, config).await
199    }
200
201    fn astream<'a>(
202        &'a self,
203        messages: &'a [Message],
204        config: &'a RunnableConfig,
205    ) -> MessageStream<'a> {
206        (**self).astream(messages, config)
207    }
208
209    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
210        (**self).bind_tools(tools)
211    }
212}
213
214/// A simple tool implemented as a closure.
215pub struct ClosureTool {
216    tool_name: String,
217    tool_description: String,
218    tool_parameters: Option<JsonValue>,
219    func: Box<dyn Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync>,
220}
221
222impl ClosureTool {
223    pub fn new(
224        name: impl Into<String>,
225        description: impl Into<String>,
226        func: impl Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync + 'static,
227    ) -> Self {
228        Self {
229            tool_name: name.into(),
230            tool_description: description.into(),
231            tool_parameters: None,
232            func: Box::new(func),
233        }
234    }
235
236    pub fn with_parameters(mut self, params: JsonValue) -> Self {
237        self.tool_parameters = Some(params);
238        self
239    }
240}
241
242#[async_trait]
243impl BaseTool for ClosureTool {
244    fn name(&self) -> &str {
245        &self.tool_name
246    }
247
248    fn description(&self) -> &str {
249        &self.tool_description
250    }
251
252    fn parameters(&self) -> Option<&JsonValue> {
253        self.tool_parameters.as_ref()
254    }
255
256    fn invoke(&self, args: &JsonValue, _config: &RunnableConfig) -> Result<JsonValue, ToolError> {
257        (self.func)(args)
258    }
259}
260
261/// Result of `prepare_tools()`: contains everything you need to work with tools.
262///
263/// # Fields
264/// - `tool_defs`: Tool definitions for binding to a model (`model.bind_tools(prepared.tool_defs)`)
265/// - `by_name`: Name-to-tool lookup map for executing tool calls
266/// - `tools`: The original tools list (for passing to `ToolNode`, etc.)
267pub struct PreparedTools {
268    pub tool_defs: Vec<ToolDef>,
269    pub by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>>,
270    pub tools: Vec<std::sync::Arc<dyn BaseTool>>,
271}
272
273/// Prepare tools for use in a graph.
274///
275/// Takes a list of tools and returns everything needed:
276/// - `tool_defs`: for `model.bind_tools()`
277/// - `by_name`: for looking up tools by name when executing calls
278/// - `tools`: original list for `ToolNode` or other uses
279///
280/// # Example
281/// ```ignore
282/// use langgraph_prebuilt::prepare_tools;
283///
284/// let prepared = prepare_tools(vec![
285///     Arc::new(Multiply::new()),
286///     Arc::new(Add::new()),
287/// ]);
288///
289/// let model = model.bind_tools(prepared.tool_defs);
290/// // Use prepared.by_name in tool_node closure
291/// ```
292pub fn prepare_tools(tools: Vec<std::sync::Arc<dyn BaseTool>>) -> PreparedTools {
293    let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
294    let by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>> = tools
295        .iter()
296        .map(|t| (t.name().to_string(), t.clone()))
297        .collect();
298    PreparedTools {
299        tool_defs,
300        by_name,
301        tools,
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// A closure tool reports the metadata it was built with and returns
    /// whatever its closure returns.
    #[test]
    fn test_closure_tool() {
        let echo = ClosureTool::new("echo", "Echoes the input", |input| Ok(input.clone()));

        assert_eq!(echo.name(), "echo");
        assert_eq!(echo.description(), "Echoes the input");

        let payload = serde_json::json!("hello");
        let config = RunnableConfig::new();
        let output = echo.invoke(&payload, &config).unwrap();
        assert_eq!(output, serde_json::json!("hello"));
    }
}