1use std::pin::Pin;
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use langgraph_checkpoint::config::RunnableConfig;
5use langgraph::types::{GraphInterrupt, InterruptError};
6use crate::types::Message;
7
/// Pinned, boxed, `Send` stream of chat-model output that may borrow data for `'a`.
///
/// Each item is either a produced [`Message`] or a [`ModelError`]. This is the
/// return type of [`BaseChatModel::astream`].
pub type MessageStream<'a> = Pin<Box<dyn tokio_stream::Stream<Item = Result<Message, ModelError>> + Send + 'a>>;
13
/// Token accounting for an LLM request.
///
/// The counts are stored exactly as given. In particular, `total_tokens` is not
/// checked against `prompt_tokens + completion_tokens` here.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct LlmUsage {
    /// Number of input (prompt) tokens.
    pub prompt_tokens: u32,
    /// Number of output (completion) tokens.
    pub completion_tokens: u32,
    /// Total token count.
    pub total_tokens: u32,
    /// Cache-creation token count, if reported (presumably prompt-cache writes;
    /// confirm against the provider adapters). Deserializes to `None` when the
    /// key is absent.
    #[serde(default)]
    pub cache_creation_tokens: Option<u32>,
    /// Cache-read token count, if reported (presumably prompt-cache hits;
    /// confirm against the provider adapters). Deserializes to `None` when the
    /// key is absent.
    #[serde(default)]
    pub cache_read_tokens: Option<u32>,
}
25
/// Errors returned by [`BaseTool`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool ran and failed. A plain `String` converts into this variant via `From`.
    #[error("tool execution error: {0}")]
    Execution(String),

    /// The supplied arguments were rejected.
    #[error("invalid arguments: {0}")]
    InvalidArgs(String),

    /// A requested tool was not found. The payload is included in the message.
    #[error("tool not found: {0}")]
    NotFound(String),

    /// Wraps a [`GraphInterrupt`] raised while the tool ran.
    ///
    /// The display text is the fixed string "graph interrupt" and does not
    /// include the interrupt's contents.
    #[error("graph interrupt")]
    Interrupt(GraphInterrupt),

    /// Any other boxed error.
    ///
    /// `#[from]` lets `?` convert it automatically, and `#[error(transparent)]`
    /// forwards its `Display` and `source`.
    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
44
45impl From<String> for ToolError {
46 fn from(s: String) -> Self {
47 ToolError::Execution(s)
48 }
49}
50
51impl From<GraphInterrupt> for ToolError {
52 fn from(interrupt: GraphInterrupt) -> Self {
53 ToolError::Interrupt(interrupt)
54 }
55}
56
57impl From<InterruptError> for ToolError {
58 fn from(e: InterruptError) -> Self {
59 ToolError::Interrupt(e.into())
60 }
61}
62
/// Errors returned by [`BaseChatModel`] implementations and carried by [`MessageStream`] items.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    /// Calling the model failed.
    #[error("model invocation error: {0}")]
    Invocation(String),

    /// The model was misconfigured.
    #[error("model configuration error: {0}")]
    Config(String),

    /// Any other boxed error.
    ///
    /// `#[from]` lets `?` convert it automatically, and `#[error(transparent)]`
    /// forwards its `Display` and `source`.
    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
75
76#[async_trait]
80pub trait BaseTool: Send + Sync {
81 fn name(&self) -> &str;
83
84 fn description(&self) -> &str {
86 ""
87 }
88
89 fn parameters(&self) -> Option<&JsonValue> {
91 None
92 }
93
94 fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError>;
96
97 async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError> {
102 let args = args.clone();
103 let config = config.clone();
104 let current_runtime = langgraph::config::get_runtime();
106 let runtime = current_runtime.unwrap_or_else(|| {
108 std::sync::Arc::new(langgraph::runtime::Runtime {
109 context: (),
110 store: None,
111 stream_writer: None,
112 previous: None,
113 execution_info: None,
114 server_info: None,
115 })
116 });
117 tokio::task::block_in_place(|| {
118 langgraph::config::with_runtime_sync(config.clone(), runtime, || {
119 self.invoke(&args, &config)
120 })
121 })
122 }
123
124 fn to_tool_def(&self) -> ToolDef {
126 ToolDef {
127 name: self.name().to_string(),
128 description: self.description().to_string(),
129 parameters: self.parameters().cloned().unwrap_or(serde_json::json!({})),
130 }
131 }
132}
133
/// Serializable description of a tool as advertised to a chat model.
///
/// [`BaseTool::to_tool_def`] sets `parameters` to an empty JSON object (`{}`)
/// when the tool declares none.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolDef {
    /// Tool name, as returned by [`BaseTool::name`].
    pub name: String,
    /// Tool description, as returned by [`BaseTool::description`].
    pub description: String,
    /// Argument description, as returned by [`BaseTool::parameters`].
    pub parameters: JsonValue,
}
141
142#[async_trait]
146pub trait BaseChatModel: Send + Sync {
147 fn name(&self) -> &str;
149
150 fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError>;
152
153 async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
155 let messages = messages.to_vec();
156 let config = config.clone();
157 tokio::task::block_in_place(|| self.invoke(&messages, &config))
158 }
159
160 fn astream<'a>(
171 &'a self,
172 messages: &'a [Message],
173 config: &'a RunnableConfig,
174 ) -> MessageStream<'a> {
175 let messages = messages.to_vec();
176 let config = config.clone();
177 Box::pin(async_stream::stream! {
178 let msg = self.ainvoke(&messages, &config).await?;
179 yield Ok(msg);
180 })
181 }
182
183 fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel>;
185}
186
187#[async_trait]
188impl BaseChatModel for Box<dyn BaseChatModel> {
189 fn name(&self) -> &str {
190 (**self).name()
191 }
192
193 fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
194 (**self).invoke(messages, config)
195 }
196
197 async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
198 (**self).ainvoke(messages, config).await
199 }
200
201 fn astream<'a>(
202 &'a self,
203 messages: &'a [Message],
204 config: &'a RunnableConfig,
205 ) -> MessageStream<'a> {
206 (**self).astream(messages, config)
207 }
208
209 fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
210 (**self).bind_tools(tools)
211 }
212}
213
/// A [`BaseTool`] implemented by a closure over the tool's JSON arguments.
pub struct ClosureTool {
    /// Value returned by [`BaseTool::name`].
    tool_name: String,
    /// Value returned by [`BaseTool::description`].
    tool_description: String,
    /// Value returned by [`BaseTool::parameters`]. `new` sets it to `None`;
    /// `with_parameters` sets it to `Some`.
    tool_parameters: Option<JsonValue>,
    /// Called by [`BaseTool::invoke`] with the arguments only. The
    /// `RunnableConfig` is not passed through.
    func: Box<dyn Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync>,
}
221
222impl ClosureTool {
223 pub fn new(
224 name: impl Into<String>,
225 description: impl Into<String>,
226 func: impl Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync + 'static,
227 ) -> Self {
228 Self {
229 tool_name: name.into(),
230 tool_description: description.into(),
231 tool_parameters: None,
232 func: Box::new(func),
233 }
234 }
235
236 pub fn with_parameters(mut self, params: JsonValue) -> Self {
237 self.tool_parameters = Some(params);
238 self
239 }
240}
241
242#[async_trait]
243impl BaseTool for ClosureTool {
244 fn name(&self) -> &str {
245 &self.tool_name
246 }
247
248 fn description(&self) -> &str {
249 &self.tool_description
250 }
251
252 fn parameters(&self) -> Option<&JsonValue> {
253 self.tool_parameters.as_ref()
254 }
255
256 fn invoke(&self, args: &JsonValue, _config: &RunnableConfig) -> Result<JsonValue, ToolError> {
257 (self.func)(args)
258 }
259}
260
/// A tool list prepared for use with a chat model.
///
/// Holds the definitions to advertise, a lookup by tool name, and the tools themselves.
pub struct PreparedTools {
    /// One [`ToolDef`] per tool.
    pub tool_defs: Vec<ToolDef>,
    /// Tools keyed by [`BaseTool::name`].
    pub by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>>,
    /// The tools themselves.
    pub tools: Vec<std::sync::Arc<dyn BaseTool>>,
}
272
273pub fn prepare_tools(tools: Vec<std::sync::Arc<dyn BaseTool>>) -> PreparedTools {
293 let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
294 let by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>> = tools
295 .iter()
296 .map(|t| (t.name().to_string(), t.clone()))
297 .collect();
298 PreparedTools {
299 tool_defs,
300 by_name,
301 tools,
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_closure_tool() {
        let tool = ClosureTool::new("echo", "Echoes the input", |args| {
            Ok(args.clone())
        });

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echoes the input");

        let result = tool.invoke(&serde_json::json!("hello"), &RunnableConfig::new()).unwrap();
        assert_eq!(result, serde_json::json!("hello"));
    }

    /// Without declared parameters, `to_tool_def` falls back to an empty JSON object.
    #[test]
    fn test_tool_def_defaults_to_empty_parameters() {
        let tool = ClosureTool::new("echo", "Echoes the input", |args| Ok(args.clone()));

        assert!(tool.parameters().is_none());
        let def = tool.to_tool_def();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echoes the input");
        assert_eq!(def.parameters, serde_json::json!({}));
    }

    /// `with_parameters` is reflected by both `parameters()` and `to_tool_def()`.
    #[test]
    fn test_closure_tool_with_parameters() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "text": { "type": "string" } }
        });
        let tool = ClosureTool::new("echo", "Echoes the input", |args| Ok(args.clone()))
            .with_parameters(schema.clone());

        assert_eq!(tool.parameters(), Some(&schema));
        assert_eq!(tool.to_tool_def().parameters, schema);
    }

    /// Errors returned by the closure reach the caller unchanged.
    #[test]
    fn test_closure_tool_propagates_errors() {
        let tool = ClosureTool::new("fail", "Always fails", |_| {
            Err(ToolError::InvalidArgs("bad".to_string()))
        });

        let err = tool.invoke(&serde_json::json!(null), &RunnableConfig::new()).unwrap_err();
        assert!(matches!(err, ToolError::InvalidArgs(ref msg) if msg == "bad"));
    }

    /// `prepare_tools` keeps input order and indexes every tool by name.
    #[test]
    fn test_prepare_tools_indexes_by_name() {
        let echo: std::sync::Arc<dyn BaseTool> =
            std::sync::Arc::new(ClosureTool::new("echo", "Echoes the input", |args| Ok(args.clone())));
        let null: std::sync::Arc<dyn BaseTool> =
            std::sync::Arc::new(ClosureTool::new("null", "Returns null", |_| Ok(JsonValue::Null)));

        let prepared = prepare_tools(vec![echo, null]);

        let names: Vec<&str> = prepared.tool_defs.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, ["echo", "null"]);
        assert_eq!(prepared.tools.len(), 2);
        assert_eq!(prepared.by_name.len(), 2);

        let result = prepared.by_name["null"]
            .invoke(&serde_json::json!(1), &RunnableConfig::new())
            .unwrap();
        assert_eq!(result, JsonValue::Null);
    }
}