use std::pin::Pin;
use async_trait::async_trait;
use serde_json::Value as JsonValue;
use langgraph_checkpoint::config::RunnableConfig;
use langgraph::types::{GraphInterrupt, InterruptError};
use crate::types::Message;
/// Pinned, boxed, `Send` stream of `Result<Message, ModelError>` items borrowing for `'a`;
/// this is the return type of [`BaseChatModel::astream`].
pub type MessageStream<'a> =
    Pin<Box<dyn tokio_stream::Stream<Item = Result<Message, ModelError>> + Send + 'a>>;
/// Token accounting for a model call (presumably as reported by the provider — confirm
/// against the model implementations that fill it in).
///
/// The two cache counters are optional: when absent from serialized input they
/// deserialize as `None` (`#[serde(default)]`).
#[derive(
    Debug,
    Clone,
    Default,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct LlmUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total token count as stored; not derived from the other fields here.
    pub total_tokens: u32,
    /// Tokens written to a prompt cache, if reported.
    #[serde(default)]
    pub cache_creation_tokens: Option<u32>,
    /// Tokens read from a prompt cache, if reported.
    #[serde(default)]
    pub cache_read_tokens: Option<u32>,
}
/// Errors a [`BaseTool`] can return from `invoke` / `ainvoke`.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool ran but failed; the payload is a human-readable message.
    #[error("tool execution error: {0}")]
    Execution(String),

    /// The arguments supplied to the tool were rejected.
    #[error("invalid arguments: {0}")]
    InvalidArgs(String),

    /// No tool with the given name exists.
    #[error("tool not found: {0}")]
    NotFound(String),

    /// Wraps a langgraph `GraphInterrupt` raised while the tool ran
    /// (presumably a request to pause the graph — confirm against langgraph).
    #[error("graph interrupt")]
    Interrupt(GraphInterrupt),

    /// Any other boxed error; its `Display` is forwarded unchanged.
    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
impl From<String> for ToolError {
fn from(s: String) -> Self {
ToolError::Execution(s)
}
}
impl From<GraphInterrupt> for ToolError {
fn from(interrupt: GraphInterrupt) -> Self {
ToolError::Interrupt(interrupt)
}
}
impl From<InterruptError> for ToolError {
fn from(e: InterruptError) -> Self {
ToolError::Interrupt(e.into())
}
}
/// Errors a [`BaseChatModel`] can return.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    /// Calling the model failed; the payload is a human-readable message.
    #[error("model invocation error: {0}")]
    Invocation(String),

    /// The model was misconfigured; the payload is a human-readable message.
    #[error("model configuration error: {0}")]
    Config(String),

    /// Any other boxed error; its `Display` is forwarded unchanged.
    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
/// A tool that a chat model can call.
///
/// Implementors must provide [`BaseTool::name`] and the synchronous [`BaseTool::invoke`];
/// every other method has a default.
#[async_trait]
pub trait BaseTool: Send + Sync {
    /// Identifier for the tool; [`prepare_tools`] uses it as the lookup key.
    fn name(&self) -> &str;

    /// Description sent along with the tool definition. Defaults to `""`.
    fn description(&self) -> &str {
        ""
    }

    /// JSON schema for the tool's arguments, if any. Defaults to `None`.
    fn parameters(&self) -> Option<&JsonValue> {
        None
    }

    /// Runs the tool synchronously on `args`.
    ///
    /// # Errors
    /// Implementation-defined; see [`ToolError`].
    fn invoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError>;

    /// Async entry point. The default runs [`BaseTool::invoke`] inside the current langgraph
    /// runtime, or an empty one if none is installed.
    ///
    /// `invoke` is blocking. On a multi-thread tokio runtime it is wrapped in
    /// `tokio::task::block_in_place` so other tasks can move off this worker. Anywhere else it
    /// runs inline, because `block_in_place` panics on a current-thread runtime. Inline execution
    /// blocks the executor thread for the duration of the call.
    ///
    /// # Errors
    /// Whatever `invoke` returns.
    async fn ainvoke(&self, args: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, ToolError> {
        // Reuse the ambient runtime when there is one; otherwise fall back to an empty runtime.
        let runtime = langgraph::config::get_runtime().unwrap_or_else(|| {
            std::sync::Arc::new(langgraph::runtime::Runtime {
                context: (),
                store: None,
                stream_writer: None,
                previous: None,
                execution_info: None,
                server_info: None,
            })
        });
        // `args` and `config` outlive this synchronous call, so borrowing them is enough;
        // only the owned config handed to `with_runtime_sync` needs a clone.
        let run = move || {
            langgraph::config::with_runtime_sync(config.clone(), runtime, || self.invoke(args, config))
        };
        let on_multi_thread = matches!(
            tokio::runtime::Handle::try_current().map(|handle| handle.runtime_flavor()),
            Ok(tokio::runtime::RuntimeFlavor::MultiThread)
        );
        if on_multi_thread {
            tokio::task::block_in_place(run)
        } else {
            run()
        }
    }

    /// Builds the model-facing [`ToolDef`]. A missing parameter schema becomes `{}`.
    fn to_tool_def(&self) -> ToolDef {
        ToolDef {
            name: self.name().to_string(),
            description: self.description().to_string(),
            parameters: self.parameters().cloned().unwrap_or_else(|| serde_json::json!({})),
        }
    }
}
/// Serializable description of a tool as advertised to a model: name, description,
/// and JSON parameter schema ([`BaseTool::to_tool_def`] produces one).
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct ToolDef {
    /// Tool identifier.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// JSON schema for the arguments.
    pub parameters: JsonValue,
}
/// A chat model that turns a message history into a reply.
///
/// Implementors must provide [`BaseChatModel::name`], the synchronous
/// [`BaseChatModel::invoke`] and [`BaseChatModel::bind_tools`]; the async and
/// streaming entry points have defaults built on `invoke`.
#[async_trait]
pub trait BaseChatModel: Send + Sync {
    /// Identifier for the model.
    fn name(&self) -> &str;

    /// Produces a reply to `messages` synchronously.
    ///
    /// # Errors
    /// Implementation-defined; see [`ModelError`].
    fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError>;

    /// Async entry point; the default delegates to the blocking [`BaseChatModel::invoke`].
    ///
    /// On a multi-thread tokio runtime the call is wrapped in `tokio::task::block_in_place`.
    /// Anywhere else it runs inline, because `block_in_place` panics on a current-thread
    /// runtime. Inline execution blocks the executor thread for the duration of the call.
    ///
    /// # Errors
    /// Whatever `invoke` returns.
    async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
        let on_multi_thread = matches!(
            tokio::runtime::Handle::try_current().map(|handle| handle.runtime_flavor()),
            Ok(tokio::runtime::RuntimeFlavor::MultiThread)
        );
        // The borrowed inputs outlive this synchronous call, so no copies are needed.
        if on_multi_thread {
            tokio::task::block_in_place(|| self.invoke(messages, config))
        } else {
            self.invoke(messages, config)
        }
    }

    /// Streams the reply. The default emits exactly one item: the result of
    /// [`BaseChatModel::ainvoke`], either `Ok(message)` or `Err(error)`.
    fn astream<'a>(
        &'a self,
        messages: &'a [Message],
        config: &'a RunnableConfig,
    ) -> MessageStream<'a> {
        // `stream!`'s generator body returns `()`, so `?` cannot propagate the error there;
        // yield the `Result` itself so failures reach the consumer as an `Err` item.
        // The inputs already live for `'a`, so they are borrowed rather than copied.
        Box::pin(async_stream::stream! {
            yield self.ainvoke(messages, config).await;
        })
    }

    /// Returns a boxed model that has `tools` available for tool calls.
    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel>;
}
/// Lets a boxed, type-erased model (such as the value [`BaseChatModel::bind_tools`] returns)
/// be used wherever a `BaseChatModel` is expected.
///
/// Each method binds the boxed trait object to a `&dyn BaseChatModel` first, so the call
/// dispatches to the inner model instead of recursing into this impl.
#[async_trait]
impl BaseChatModel for Box<dyn BaseChatModel> {
    fn name(&self) -> &str {
        let inner: &dyn BaseChatModel = &**self;
        inner.name()
    }

    fn invoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
        let inner: &dyn BaseChatModel = &**self;
        inner.invoke(messages, config)
    }

    async fn ainvoke(&self, messages: &[Message], config: &RunnableConfig) -> Result<Message, ModelError> {
        let inner: &dyn BaseChatModel = &**self;
        inner.ainvoke(messages, config).await
    }

    fn astream<'a>(
        &'a self,
        messages: &'a [Message],
        config: &'a RunnableConfig,
    ) -> MessageStream<'a> {
        let inner: &'a dyn BaseChatModel = &**self;
        inner.astream(messages, config)
    }

    fn bind_tools(&self, tools: Vec<ToolDef>) -> Box<dyn BaseChatModel> {
        let inner: &dyn BaseChatModel = &**self;
        inner.bind_tools(tools)
    }
}
/// A [`BaseTool`] whose behavior is supplied as a closure over the JSON arguments.
pub struct ClosureTool {
    /// Returned by `BaseTool::name`.
    tool_name: String,
    /// Returned by `BaseTool::description`.
    tool_description: String,
    /// Optional JSON schema; `None` until `with_parameters` is called.
    tool_parameters: Option<JsonValue>,
    /// Called with the arguments on every `invoke`; the config is not passed through.
    func: Box<dyn Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync>,
}
impl ClosureTool {
    /// Creates a tool named `name` whose `invoke` calls `func` with the JSON arguments.
    ///
    /// No parameter schema is attached; see [`ClosureTool::with_parameters`].
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        func: impl Fn(&JsonValue) -> Result<JsonValue, ToolError> + Send + Sync + 'static,
    ) -> Self {
        Self {
            tool_name: name.into(),
            tool_description: description.into(),
            tool_parameters: None,
            func: Box::new(func),
        }
    }

    /// Attaches a JSON parameter schema, replacing any previous one.
    ///
    /// Consumes `self`; ignoring the return value drops the tool, hence `#[must_use]`.
    #[must_use]
    pub fn with_parameters(self, params: JsonValue) -> Self {
        Self {
            tool_parameters: Some(params),
            ..self
        }
    }
}
/// Exposes the stored name, description and schema, and forwards `invoke` to the closure.
/// `ainvoke` and `to_tool_def` use the trait defaults.
#[async_trait]
impl BaseTool for ClosureTool {
    fn name(&self) -> &str {
        self.tool_name.as_str()
    }

    fn description(&self) -> &str {
        self.tool_description.as_str()
    }

    fn parameters(&self) -> Option<&JsonValue> {
        self.tool_parameters.as_ref()
    }

    /// Calls the wrapped closure; the runnable config is ignored.
    fn invoke(&self, args: &JsonValue, _config: &RunnableConfig) -> Result<JsonValue, ToolError> {
        let callback = &self.func;
        callback(args)
    }
}
/// A tool list pre-processed by [`prepare_tools`]: model-facing definitions, a
/// name-keyed index, and the original list.
pub struct PreparedTools {
    /// One [`ToolDef`] per entry of `tools`, in the same order.
    pub tool_defs: Vec<ToolDef>,
    /// Tools keyed by [`BaseTool::name`]. When [`prepare_tools`] builds this and two tools
    /// share a name, the later one in `tools` is the one kept.
    pub by_name: std::collections::HashMap<String, std::sync::Arc<dyn BaseTool>>,
    /// The tools exactly as passed in.
    pub tools: Vec<std::sync::Arc<dyn BaseTool>>,
}
/// Builds a [`PreparedTools`] from `tools`.
///
/// Definitions are generated in input order. The name index is filled afterwards; on
/// duplicate names the later tool replaces the earlier one in `by_name`, while `tool_defs`
/// and `tools` still contain both.
pub fn prepare_tools(tools: Vec<std::sync::Arc<dyn BaseTool>>) -> PreparedTools {
    let tool_defs = tools
        .iter()
        .map(|tool| tool.to_tool_def())
        .collect::<Vec<ToolDef>>();

    let mut by_name = std::collections::HashMap::with_capacity(tools.len());
    for tool in &tools {
        by_name.insert(tool.name().to_owned(), std::sync::Arc::clone(tool));
    }

    PreparedTools {
        tools,
        tool_defs,
        by_name,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ClosureTool` reports its metadata and forwards `invoke` to the closure.
    #[test]
    fn test_closure_tool() {
        let tool = ClosureTool::new("echo", "Echoes the input", |args| {
            Ok(args.clone())
        });
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echoes the input");
        let result = tool.invoke(&serde_json::json!("hello"), &RunnableConfig::new()).unwrap();
        assert_eq!(result, serde_json::json!("hello"));
    }

    /// Without a schema, `to_tool_def` advertises an empty JSON object.
    #[test]
    fn test_to_tool_def_defaults_parameters_to_empty_object() {
        let tool = ClosureTool::new("noop", "Does nothing", |_| Ok(serde_json::Value::Null));
        assert_eq!(tool.parameters(), None);
        let def = tool.to_tool_def();
        assert_eq!(def.name, "noop");
        assert_eq!(def.description, "Does nothing");
        assert_eq!(def.parameters, serde_json::json!({}));
    }

    /// A schema attached with `with_parameters` is reported and advertised verbatim.
    #[test]
    fn test_with_parameters_is_reported_in_tool_def() {
        let schema = serde_json::json!({"type": "object", "properties": {}});
        let tool = ClosureTool::new("echo", "", |args| Ok(args.clone()))
            .with_parameters(schema.clone());
        assert_eq!(tool.parameters(), Some(&schema));
        assert_eq!(tool.to_tool_def().parameters, schema);
    }

    /// `prepare_tools` keeps input order for definitions and indexes tools by name.
    #[test]
    fn test_prepare_tools_keeps_order_and_indexes_by_name() {
        let echo: std::sync::Arc<dyn BaseTool> =
            std::sync::Arc::new(ClosureTool::new("echo", "", |args| Ok(args.clone())));
        let noop: std::sync::Arc<dyn BaseTool> =
            std::sync::Arc::new(ClosureTool::new("noop", "", |_| Ok(serde_json::Value::Null)));
        let prepared = prepare_tools(vec![echo, noop]);
        let names: Vec<&str> = prepared.tool_defs.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, vec!["echo", "noop"]);
        assert_eq!(prepared.tools.len(), 2);
        assert_eq!(prepared.by_name.len(), 2);
        assert_eq!(prepared.by_name["echo"].name(), "echo");
        assert_eq!(prepared.by_name["noop"].name(), "noop");
    }
}