Skip to main content

langgraph_prebuilt/
traits.rs

1use std::pin::Pin;
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use langgraph_checkpoint::config::RunnableConfig;
5use langgraph::types::{GraphInterrupt, InterruptError};
6use crate::types::Message;
7
8/// A stream of message chunks from a chat model.
9///
10/// Each item is a `Message` representing either a partial token chunk
11/// (for real-time display) or the final complete message.
12pub type MessageStream<'a> = Pin<Box<dyn tokio_stream::Stream<Item = Result<Message, ModelError>> + Send + 'a>>;
13
14/// Token usage information from an LLM API response.
15#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
16pub struct LlmUsage {
17    pub prompt_tokens: u32,
18    pub completion_tokens: u32,
19    pub total_tokens: u32,
20}
21
22/// Error type for tool and model operations.
23#[derive(Debug, thiserror::Error)]
24pub enum ToolError {
25    #[error("tool execution error: {0}")]
26    Execution(String),
27
28    #[error("invalid arguments: {0}")]
29    InvalidArgs(String),
30
31    #[error("tool not found: {0}")]
32    NotFound(String),
33
34    #[error("graph interrupt")]
35    Interrupt(GraphInterrupt),
36
37    #[error(transparent)]
38    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
39}
40
41impl From<String> for ToolError {
42    fn from(s: String) -> Self {
43        ToolError::Execution(s)
44    }
45}
46
47impl From<GraphInterrupt> for ToolError {
48    fn from(interrupt: GraphInterrupt) -> Self {
49        ToolError::Interrupt(interrupt)
50    }
51}
52
53impl From<InterruptError> for ToolError {
54    fn from(e: InterruptError) -> Self {
55        ToolError::Interrupt(e.into())
56    }
57}
58
59/// Error type for chat model operations.
60#[derive(Debug, thiserror::Error)]
61pub enum ModelError {
62    #[error("model invocation error: {0}")]
63    Invocation(String),
64
65    #[error("model configuration error: {0}")]
66    Config(String),
67
68    #[error(transparent)]
69    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
70}
71
72/// A tool that can be invoked by an agent.
73///
74/// Mirrors langchain-core's BaseTool.
75#[async_trait]
76pub trait BaseTool: Send + Sync {
77    /// The name of the tool.
78    fn name(&self) -> &str;
79
80    /// A description of what the tool does.
81    fn description(&self) -> &str {
82        ""
83    }
84
85    /// The JSON schema for the tool's parameters.
86    fn parameters(&self) -> Option<&JsonValue> {
87        None
88    }
89
90    /// Invoke the tool synchronously with the given arguments.
91    fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError>;
92
93    /// Invoke the tool asynchronously. Default delegates to sync invoke via block_in_place.
94    ///
95    /// Sets up thread-local config/runtime context so that `get_config()` and
96    /// `get_runtime()` work inside sync tool code (needed by `interrupt()`).
97    async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError> {
98        let args = args.clone();
99        let config = config.clone();
100        // Capture runtime from async task-locals if available
101        let current_runtime = langgraph::config::get_runtime();
102        // Always use with_runtime_sync to set up thread-locals for get_config()/get_runtime()
103        let runtime = current_runtime.unwrap_or_else(|| {
104            std::sync::Arc::new(langgraph::runtime::Runtime {
105                context: (),
106                store: None,
107                stream_writer: None,
108                previous: None,
109                execution_info: None,
110                server_info: None,
111            })
112        });
113        tokio::task::block_in_place(|| {
114            langgraph::config::with_runtime_sync(config.clone(), runtime, || {
115                self.invoke(&args, &config)
116            })
117        })
118    }
119
120    /// Get the tool's schema as a ToolCall-compatible descriptor.
121    fn to_tool_def(&self) -> ToolDef {
122        ToolDef {
123            name: self.name().to_string(),
124            description: self.description().to_string(),
125            parameters: self.parameters().cloned().unwrap_or(serde_json::json!({})),
126        }
127    }
128}
129
130/// A tool definition that can be passed to a chat model.
131#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
132pub struct ToolDef {
133    pub name: String,
134    pub description: String,
135    pub parameters: JsonValue,
136}
137
138/// A chat model that can generate responses.
139///
140/// Mirrors langchain-core's BaseChatModel.
141#[async_trait]
142pub trait BaseChatModel: Send + Sync {
143    /// The name of the model.
144    fn name(&self) -> &str;
145
146    /// Invoke the model with a list of messages and get a response.
147    fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError>;
148
149    /// Invoke the model asynchronously. Default delegates to sync invoke via block_in_place.
150    async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
151        let messages = messages.to_vec();
152        let config = config.clone();
153        tokio::task::block_in_place(|| self.invoke(&messages, &config))
154    }
155
156    /// Stream tokens from the model. Returns a stream of partial Message chunks.
157    ///
158    /// Each yielded `Message` represents the accumulated content up to that point.
159    /// For example, if the model generates "Hello world", the stream might yield:
160    /// - `Message::ai("Hello")`
161    /// - `Message::ai("Hello world")`
162    ///
163    /// The final item in the stream is the complete response (including tool calls if any).
164    ///
165    /// Default implementation falls back to `ainvoke` (yields a single complete message).
166    fn astream<'a>(
167        &'a self,
168        messages: &'a [Message],
169        config: &'a RunnableConfig,
170    ) -> MessageStream<'a> {
171        let messages = messages.to_vec();
172        let config = config.clone();
173        Box::pin(async_stream::stream! {
174            let msg = self.ainvoke(&messages, &config).await?;
175            yield Ok(msg);
176        })
177    }
178
179    /// Bind tools to the model for tool-calling support.
180    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel>;
181}
182
183#[async_trait]
184impl BaseChatModel for Box<dyn BaseChatModel> {
185    fn name(&self) -> &str {
186        (**self).name()
187    }
188
189    fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
190        (**self).invoke(messages, config)
191    }
192
193    async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
194        (**self).ainvoke(messages, config).await
195    }
196
197    fn astream<'a>(
198        &'a self,
199        messages: &'a [Message],
200        config: &'a RunnableConfig,
201    ) -> MessageStream<'a> {
202        (**self).astream(messages, config)
203    }
204
205    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
206        (**self).bind_tools(tools)
207    }
208}
209
210/// A simple tool implemented as a closure.
211pub struct ClosureTool {
212    tool_name: String,
213    tool_description: String,
214    tool_parameters: Option<JsonValue>,
215    func: Box<dyn Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync>,
216}
217
218impl ClosureTool {
219    pub fn new(
220        name: impl Into<String>,
221        description: impl Into<String>,
222        func: impl Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync + 'static,
223    ) -> Self {
224        Self {
225            tool_name: name.into(),
226            tool_description: description.into(),
227            tool_parameters: None,
228            func: Box::new(func),
229        }
230    }
231
232    pub fn with_parameters(mut self, params: JsonValue) -> Self {
233        self.tool_parameters = Some(params);
234        self
235    }
236}
237
238#[async_trait]
239impl BaseTool for ClosureTool {
240    fn name(&self) -> &str {
241        &self.tool_name
242    }
243
244    fn description(&self) -> &str {
245        &self.tool_description
246    }
247
248    fn parameters(&self) -> Option<&JsonValue> {
249        self.tool_parameters.as_ref()
250    }
251
252    fn invoke(&self, args: &JsonValue, _config: &RunnableConfig) -> Result<JsonValue, ToolError> {
253        (self.func)(args)
254    }
255}
256
257/// Result of `prepare_tools()`: contains everything you need to work with tools.
258///
259/// # Fields
260/// - `tool_defs`: Tool definitions for binding to a model (`model.bind_tools(prepared.tool_defs)`)
261/// - `by_name`: Name-to-tool lookup map for executing tool calls
262/// - `tools`: The original tools list (for passing to `ToolNode`, etc.)
263pub struct PreparedTools {
264    pub tool_defs: Vec<ToolDef>,
265    pub by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>>,
266    pub tools: Vec<std::sync::Arc<dyn BaseTool>>,
267}
268
269/// Prepare tools for use in a graph.
270///
271/// Takes a list of tools and returns everything needed:
272/// - `tool_defs`: for `model.bind_tools()`
273/// - `by_name`: for looking up tools by name when executing calls
274/// - `tools`: original list for `ToolNode` or other uses
275///
276/// # Example
277/// ```ignore
278/// use langgraph_prebuilt::prepare_tools;
279///
280/// let prepared = prepare_tools(vec![
281///     Arc::new(Multiply::new()),
282///     Arc::new(Add::new()),
283/// ]);
284///
285/// let model = model.bind_tools(prepared.tool_defs);
286/// // Use prepared.by_name in tool_node closure
287/// ```
288pub fn prepare_tools(tools: Vec<std::sync::Arc<dyn BaseTool>>) -> PreparedTools {
289    let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
290    let by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>> = tools
291        .iter()
292        .map(|t| (t.name().to_string(), t.clone()))
293        .collect();
294    PreparedTools {
295        tool_defs,
296        by_name,
297        tools,
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A tool that returns its arguments unchanged.
    fn echo_tool() -> ClosureTool {
        ClosureTool::new("echo", "Echoes the input", |args| Ok(args.clone()))
    }

    #[test]
    fn test_closure_tool() {
        let tool = echo_tool();

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echoes the input");
        assert!(tool.parameters().is_none());

        let result = tool.invoke(&serde_json::json!("hello"), &RunnableConfig::new()).unwrap();
        assert_eq!(result, serde_json::json!("hello"));
    }

    #[test]
    fn test_closure_tool_error_passthrough() {
        let tool = ClosureTool::new("fail", "Always fails", |_| {
            Err(ToolError::from("boom".to_string()))
        });

        let err = tool.invoke(&JsonValue::Null, &RunnableConfig::new()).unwrap_err();
        assert!(matches!(err, ToolError::Execution(ref msg) if msg == "boom"));
        assert_eq!(err.to_string(), "tool execution error: boom");
    }

    #[test]
    fn test_to_tool_def_defaults_and_parameters() {
        let bare = echo_tool().to_tool_def();
        assert_eq!(bare.name, "echo");
        assert_eq!(bare.description, "Echoes the input");
        assert_eq!(bare.parameters, serde_json::json!({}));

        let schema = serde_json::json!({"type": "object", "properties": {}});
        let with_schema = echo_tool().with_parameters(schema.clone()).to_tool_def();
        assert_eq!(with_schema.parameters, schema);
    }

    #[test]
    fn test_prepare_tools() {
        let tools: Vec<Arc<dyn BaseTool>> = vec![
            Arc::new(echo_tool()),
            Arc::new(ClosureTool::new("noop", "Returns null", |_| Ok(JsonValue::Null))),
        ];
        let prepared = prepare_tools(tools);

        let names: Vec<&str> = prepared.tool_defs.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, ["echo", "noop"]);
        assert_eq!(prepared.tools.len(), 2);
        assert_eq!(prepared.by_name.len(), 2);
        assert!(!prepared.by_name.contains_key("missing"));

        let echoed = prepared.by_name["echo"]
            .invoke(&serde_json::json!(42), &RunnableConfig::new())
            .unwrap();
        assert_eq!(echoed, serde_json::json!(42));
    }
}