1use std::pin::Pin;
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use langgraph_checkpoint::config::RunnableConfig;
5use langgraph::types::{GraphInterrupt, InterruptError};
6use crate::types::Message;
7
/// Pinned, boxed, `Send` stream of model outputs, each a `Result<Message, ModelError>`,
/// that may borrow data for `'a`. This is the return type of [`BaseChatModel::astream`].
pub type MessageStream<'a> = Pin<Box<dyn tokio_stream::Stream<Item = Result<Message, ModelError>> + Send + 'a>>;
13
/// Token counts for a model call. `Default` yields all zeros.
///
/// NOTE(review): nothing in this file fills these in; presumably they mirror the
/// provider's prompt/completion/total usage report — confirm against producers.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct LlmUsage {
    /// Number of prompt (input) tokens.
    pub prompt_tokens: u32,
    /// Number of completion (output) tokens.
    pub completion_tokens: u32,
    /// Total token count; stored as given, not derived here from the other two fields.
    pub total_tokens: u32,
}
21
/// Errors returned by [`BaseTool::invoke`] / [`BaseTool::ainvoke`].
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool ran but failed; displayed as `tool execution error: <msg>`.
    /// Also produced by `From<String>`.
    #[error("tool execution error: {0}")]
    Execution(String),

    /// The arguments passed to the tool were rejected; displayed as `invalid arguments: <msg>`.
    #[error("invalid arguments: {0}")]
    InvalidArgs(String),

    /// No tool with the given name exists; displayed as `tool not found: <name>`.
    #[error("tool not found: {0}")]
    NotFound(String),

    /// Wraps a [`GraphInterrupt`] (also reachable from `InterruptError` via `From`).
    /// Its `Display` text is only `graph interrupt`; inspect the payload for details.
    #[error("graph interrupt")]
    Interrupt(GraphInterrupt),

    /// Any other boxed error. `Display` and `source` are forwarded to it
    /// (`#[error(transparent)]`), and `?` converts such boxes via `#[from]`.
    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
40
41impl From<String> for ToolError {
42 fn from(s: String) -> Self {
43 ToolError::Execution(s)
44 }
45}
46
47impl From<GraphInterrupt> for ToolError {
48 fn from(interrupt: GraphInterrupt) -> Self {
49 ToolError::Interrupt(interrupt)
50 }
51}
52
53impl From<InterruptError> for ToolError {
54 fn from(e: InterruptError) -> Self {
55 ToolError::Interrupt(e.into())
56 }
57}
58
/// Errors returned by [`BaseChatModel`] methods and yielded by a [`MessageStream`].
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    /// The call to the model failed; displayed as `model invocation error: <msg>`.
    #[error("model invocation error: {0}")]
    Invocation(String),

    /// The model is misconfigured; displayed as `model configuration error: <msg>`.
    #[error("model configuration error: {0}")]
    Config(String),

    /// Any other boxed error. `Display` and `source` are forwarded to it
    /// (`#[error(transparent)]`), and `?` converts such boxes via `#[from]`.
    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
71
72#[async_trait]
76pub trait BaseTool: Send + Sync {
77 fn name(&self) -> &str;
79
80 fn description(&self) -> &str {
82 ""
83 }
84
85 fn parameters(&self) -> Option<&JsonValue> {
87 None
88 }
89
90 fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError>;
92
93 async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError> {
98 let args = args.clone();
99 let config = config.clone();
100 let current_runtime = langgraph::config::get_runtime();
102 let runtime = current_runtime.unwrap_or_else(|| {
104 std::sync::Arc::new(langgraph::runtime::Runtime {
105 context: (),
106 store: None,
107 stream_writer: None,
108 previous: None,
109 execution_info: None,
110 server_info: None,
111 })
112 });
113 tokio::task::block_in_place(|| {
114 langgraph::config::with_runtime_sync(config.clone(), runtime, || {
115 self.invoke(&args, &config)
116 })
117 })
118 }
119
120 fn to_tool_def(&self) -> ToolDef {
122 ToolDef {
123 name: self.name().to_string(),
124 description: self.description().to_string(),
125 parameters: self.parameters().cloned().unwrap_or(serde_json::json!({})),
126 }
127 }
128}
129
/// Serializable description of a tool, as passed to [`BaseChatModel::bind_tools`]
/// and produced by [`BaseTool::to_tool_def`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolDef {
    /// Tool name, copied from [`BaseTool::name`].
    pub name: String,
    /// Tool description, copied from [`BaseTool::description`] (may be empty).
    pub description: String,
    /// Argument specification; `to_tool_def` uses `{}` when the tool declares none.
    pub parameters: JsonValue,
}
137
138#[async_trait]
142pub trait BaseChatModel: Send + Sync {
143 fn name(&self) -> &str;
145
146 fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError>;
148
149 async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
151 let messages = messages.to_vec();
152 let config = config.clone();
153 tokio::task::block_in_place(|| self.invoke(&messages, &config))
154 }
155
156 fn astream<'a>(
167 &'a self,
168 messages: &'a [Message],
169 config: &'a RunnableConfig,
170 ) -> MessageStream<'a> {
171 let messages = messages.to_vec();
172 let config = config.clone();
173 Box::pin(async_stream::stream! {
174 let msg = self.ainvoke(&messages, &config).await?;
175 yield Ok(msg);
176 })
177 }
178
179 fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel>;
181}
182
183#[async_trait]
184impl BaseChatModel for Box<dyn BaseChatModel> {
185 fn name(&self) -> &str {
186 (**self).name()
187 }
188
189 fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
190 (**self).invoke(messages, config)
191 }
192
193 async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
194 (**self).ainvoke(messages, config).await
195 }
196
197 fn astream<'a>(
198 &'a self,
199 messages: &'a [Message],
200 config: &'a RunnableConfig,
201 ) -> MessageStream<'a> {
202 (**self).astream(messages, config)
203 }
204
205 fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
206 (**self).bind_tools(tools)
207 }
208}
209
/// A [`BaseTool`] whose behavior is supplied by a closure over the JSON arguments.
///
/// The closure receives only the arguments, not the `RunnableConfig`.
pub struct ClosureTool {
    /// Returned by `name()`.
    tool_name: String,
    /// Returned by `description()`.
    tool_description: String,
    /// Returned by `parameters()`; `None` until `with_parameters` is called.
    tool_parameters: Option<JsonValue>,
    /// Called by `invoke` with the tool's arguments.
    func: Box<dyn Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync>,
}
217
218impl ClosureTool {
219 pub fn new(
220 name: impl Into<String>,
221 description: impl Into<String>,
222 func: impl Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync + 'static,
223 ) -> Self {
224 Self {
225 tool_name: name.into(),
226 tool_description: description.into(),
227 tool_parameters: None,
228 func: Box::new(func),
229 }
230 }
231
232 pub fn with_parameters(mut self, params: JsonValue) -> Self {
233 self.tool_parameters = Some(params);
234 self
235 }
236}
237
238#[async_trait]
239impl BaseTool for ClosureTool {
240 fn name(&self) -> &str {
241 &self.tool_name
242 }
243
244 fn description(&self) -> &str {
245 &self.tool_description
246 }
247
248 fn parameters(&self) -> Option<&JsonValue> {
249 self.tool_parameters.as_ref()
250 }
251
252 fn invoke(&self, args: &JsonValue, _config: &RunnableConfig) -> Result<JsonValue, ToolError> {
253 (self.func)(args)
254 }
255}
256
/// A tool list pre-processed for a model call, as built by [`prepare_tools`].
pub struct PreparedTools {
    /// One [`ToolDef`] per tool; `prepare_tools` keeps the order of `tools`.
    pub tool_defs: Vec<ToolDef>,
    /// Tools keyed by [`BaseTool::name`]; in `prepare_tools`, when names collide
    /// the later tool wins.
    pub by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>>,
    /// The tool list itself (in `prepare_tools`, exactly as passed in).
    pub tools: Vec<std::sync::Arc<dyn BaseTool>>,
}
268
269pub fn prepare_tools(tools: Vec<std::sync::Arc<dyn BaseTool>>) -> PreparedTools {
289 let tool_defs: Vec<ToolDef> = tools.iter().map(|t| t.to_tool_def()).collect();
290 let by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>> = tools
291 .iter()
292 .map(|t| (t.name().to_string(), t.clone()))
293 .collect();
294 PreparedTools {
295 tool_defs,
296 by_name,
297 tools,
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// `ClosureTool` exposes its name/description and echoes via the closure.
    #[test]
    fn test_closure_tool() {
        let tool = ClosureTool::new("echo", "Echoes the input", |args| Ok(args.clone()));

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echoes the input");

        let result = tool
            .invoke(&serde_json::json!("hello"), &RunnableConfig::new())
            .unwrap();
        assert_eq!(result, serde_json::json!("hello"));
    }

    /// Errors returned by the closure reach the caller unchanged.
    #[test]
    fn test_closure_tool_propagates_errors() {
        let tool = ClosureTool::new("fail", "Always fails", |_| {
            Err(ToolError::InvalidArgs("bad".to_string()))
        });
        let err = tool
            .invoke(&serde_json::json!({}), &RunnableConfig::new())
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidArgs(ref msg) if msg == "bad"));
    }

    /// `to_tool_def` falls back to `{}` and honors `with_parameters`.
    #[test]
    fn test_to_tool_def_parameters() {
        let def = ClosureTool::new("noop", "Does nothing", |_| Ok(JsonValue::Null)).to_tool_def();
        assert_eq!(def.name, "noop");
        assert_eq!(def.description, "Does nothing");
        assert_eq!(def.parameters, serde_json::json!({}));

        let schema = serde_json::json!({ "type": "object" });
        let def = ClosureTool::new("noop", "Does nothing", |_| Ok(JsonValue::Null))
            .with_parameters(schema.clone())
            .to_tool_def();
        assert_eq!(def.parameters, schema);
    }

    /// `prepare_tools` builds one def per tool and indexes tools by name.
    #[test]
    fn test_prepare_tools_indexes_by_name() {
        let echo: std::sync::Arc<dyn BaseTool> = std::sync::Arc::new(ClosureTool::new(
            "echo",
            "Echoes the input",
            |args| Ok(args.clone()),
        ));
        let prepared = prepare_tools(vec![echo]);
        assert_eq!(prepared.tools.len(), 1);
        assert_eq!(prepared.tool_defs.len(), 1);
        assert_eq!(prepared.tool_defs[0].name, "echo");
        assert!(prepared.by_name.contains_key("echo"));
    }
}