use std::{
any::{Any, TypeId},
collections::HashMap,
sync::{Arc, RwLock},
};
use serde::{Deserialize, Serialize};
/// Context object passed alongside tool calls.
///
/// Carries an optional conversation id, a string-keyed JSON state map, a
/// type-keyed extension map and an idle flag. Both maps live behind
/// `Arc<RwLock<..>>`, so every method takes `&self`.
pub struct ToolContext {
    /// Conversation identifier supplied at construction, if any.
    conversation_id: Option<String>,
    /// String-keyed JSON values. May be supplied by the caller via
    /// `with_shared_state`, in which case other holders of the same `Arc`
    /// observe writes.
    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// At most one value per Rust type, keyed by that type's `TypeId`.
    extensions: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
    /// Idle flag. Both constructors set it to `false`; nothing visible in this
    /// file sets it to `true` — NOTE(review): confirm who is meant to set it.
    is_idle: bool,
}
impl ToolContext {
    /// Creates a context with empty state and extension maps and `is_idle`
    /// set to `false`.
    #[must_use]
    pub fn new(conversation_id: Option<String>) -> Self {
        Self {
            conversation_id,
            state: Arc::new(RwLock::new(HashMap::new())),
            extensions: Arc::new(RwLock::new(HashMap::new())),
            is_idle: false,
        }
    }

    /// Creates a context that reads and writes the caller-provided `state`
    /// map. Writes made through `set_state` are visible to every holder of the
    /// same `Arc`. The extension map is always fresh and not shared.
    #[must_use]
    pub fn with_shared_state(
        conversation_id: Option<String>,
        state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    ) -> Self {
        Self {
            conversation_id,
            state,
            extensions: Arc::new(RwLock::new(HashMap::new())),
            is_idle: false,
        }
    }

    /// Returns the conversation id, if one was supplied.
    #[must_use]
    pub fn conversation_id(&self) -> Option<&str> {
        self.conversation_id.as_deref()
    }

    /// Returns a clone of the value stored under `key`, or `default` when the
    /// key is absent.
    ///
    /// If the state lock is poisoned, a warning is logged and `default` is
    /// returned.
    #[must_use]
    pub fn get_state(&self, key: &str, default: serde_json::Value) -> serde_json::Value {
        match self.state.read() {
            Ok(guard) => guard.get(key).cloned().unwrap_or(default),
            Err(e) => {
                tracing::warn!(key, error = %e, "ToolContext::get_state: lock poisoned, returning default");
                default
            }
        }
    }

    /// Stores `value` under `key`, replacing any previous value.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] (also logged as a warning) if the state lock is
    /// poisoned.
    pub fn set_state(&self, key: &str, value: serde_json::Value) -> Result<(), ToolError> {
        match self.state.write() {
            Ok(mut guard) => {
                guard.insert(key.to_owned(), value);
                Ok(())
            }
            Err(e) => {
                let msg = format!("ToolContext::set_state: lock poisoned for key '{key}': {e}");
                tracing::warn!("{msg}");
                Err(ToolError::new(msg))
            }
        }
    }

    /// Returns the idle flag.
    #[must_use]
    pub const fn is_idle(&self) -> bool {
        self.is_idle
    }

    /// Stores `value` as the extension for type `T`, replacing any previous
    /// extension of the same type.
    ///
    /// A poisoned extensions lock is logged and recovered rather than
    /// panicking. This is sound because the only writer (this method) never
    /// runs caller code while holding the guard.
    pub fn set_ext<T: Send + Sync + 'static>(&self, value: T) {
        let replaced = {
            let mut exts = self.extensions.write().unwrap_or_else(|poisoned| {
                tracing::warn!("ToolContext::set_ext: extensions lock poisoned, recovering");
                poisoned.into_inner()
            });
            exts.insert(TypeId::of::<T>(), Box::new(value))
        };
        // Drop the old value only after the write guard is released. A
        // panicking `Drop` must not poison the lock, and a `Drop` that
        // re-enters this context must not deadlock.
        drop(replaced);
    }

    /// Returns a clone of the extension stored for type `T`, or `None` if no
    /// extension of that type was set.
    ///
    /// A poisoned extensions lock is logged and recovered rather than
    /// panicking, matching the non-panicking style of `get_state`.
    #[must_use]
    pub fn get_ext<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
        let exts = self.extensions.read().unwrap_or_else(|poisoned| {
            tracing::warn!("ToolContext::get_ext: extensions lock poisoned, recovering");
            poisoned.into_inner()
        });
        exts.get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
            .cloned()
    }
}
pub use llm_tool_macros::llm_tool;
pub use schemars::JsonSchema;
fn other_type_name(value: &serde_json::Value) -> &'static str {
match value {
serde_json::Value::Null => "null",
serde_json::Value::Bool(_) => "bool",
serde_json::Value::Number(_) => "number",
serde_json::Value::String(_) => "string",
serde_json::Value::Array(_) => "array",
serde_json::Value::Object(_) => "object",
}
}
/// Output of a tool: text `content` plus a string-keyed map of JSON metadata.
///
/// This is the `Ok` side of the `Result<ToolOutput, ToolError>` values
/// produced in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ToolOutput {
    /// Text result; this is exactly what `Display` prints.
    content: String,
    /// Structured data attached via `with_meta`, `with_metadata` or
    /// `from_metadata`.
    metadata: std::collections::HashMap<String, serde_json::Value>,
}
impl ToolOutput {
    /// Creates an output with the given text content and no metadata.
    pub fn new(content: impl Into<String>) -> Self {
        let content: String = content.into();
        let metadata = std::collections::HashMap::new();
        Self { content, metadata }
    }

    /// Serializes `value` to compact JSON and uses that text as the content.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] ("serialization failed: …") if `serde_json`
    /// cannot serialize `value`.
    pub fn json<T: serde::Serialize>(value: &T) -> Result<Self, ToolError> {
        match serde_json::to_string(value) {
            Ok(text) => Ok(Self::new(text)),
            Err(e) => Err(ToolError::new(format!("serialization failed: {e}"))),
        }
    }

    /// Builds an output from a value that serializes to a JSON object.
    ///
    /// The object's compact JSON text becomes the content, and its top-level
    /// fields become the metadata.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] if `value` fails to serialize, or if it
    /// serializes to anything other than a JSON object.
    pub fn from_metadata<T: serde::Serialize>(value: &T) -> Result<Self, ToolError> {
        let json_value = Self::metadata_json(value)?;
        // Render the text before `metadata_object` takes ownership of the value.
        let content = json_value.to_string();
        let fields = Self::metadata_object(json_value)?;
        Ok(Self {
            content,
            metadata: fields.into_iter().collect(),
        })
    }

    /// Adds (or overwrites) one metadata entry and returns the output.
    #[must_use]
    pub fn with_meta(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Merges the top-level fields of `value` into the metadata, overwriting
    /// entries that share a key.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] if `value` fails to serialize, or if it
    /// serializes to anything other than a JSON object. The output itself is
    /// consumed in that case.
    pub fn with_metadata<T: serde::Serialize>(mut self, value: &T) -> Result<Self, ToolError> {
        let fields = Self::metadata_object(Self::metadata_json(value)?)?;
        self.metadata.extend(fields);
        Ok(self)
    }

    /// Borrows the text content.
    #[must_use]
    pub fn content(&self) -> &str {
        self.content.as_str()
    }

    /// Consumes the output and returns its text content, discarding metadata.
    #[must_use]
    pub fn into_content(self) -> String {
        let Self { content, .. } = self;
        content
    }

    /// Borrows the metadata map.
    #[must_use]
    pub fn metadata(&self) -> &std::collections::HashMap<String, serde_json::Value> {
        &self.metadata
    }

    /// Serializes a metadata value to a `serde_json::Value`, tagging failures
    /// as metadata serialization errors.
    fn metadata_json<T: serde::Serialize>(value: &T) -> Result<serde_json::Value, ToolError> {
        serde_json::to_value(value)
            .map_err(|e| ToolError::new(format!("metadata serialization failed: {e}")))
    }

    /// Unwraps a JSON object into its field map and rejects every other JSON
    /// kind.
    fn metadata_object(
        json: serde_json::Value,
    ) -> Result<serde_json::Map<String, serde_json::Value>, ToolError> {
        match json {
            serde_json::Value::Object(fields) => Ok(fields),
            other => Err(ToolError::new(format!(
                "metadata must serialize to a JSON object, got {}",
                other_type_name(&other),
            ))),
        }
    }
}
/// Prints the raw text content. Width, fill and precision flags on the outer
/// formatter are ignored, and metadata is not shown.
impl std::fmt::Display for ToolOutput {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "{}", self.content)
    }
}
impl From<String> for ToolOutput {
fn from(content: String) -> Self {
Self::new(content)
}
}
impl From<&str> for ToolOutput {
fn from(content: &str) -> Self {
Self::new(content)
}
}
impl From<i64> for ToolOutput {
fn from(value: i64) -> Self {
Self::new(value.to_string())
}
}
impl From<f64> for ToolOutput {
fn from(value: f64) -> Self {
Self::new(value.to_string())
}
}
/// Renders the boolean as "true" or "false".
impl From<bool> for ToolOutput {
    fn from(flag: bool) -> Self {
        Self::new(if flag { "true" } else { "false" })
    }
}
impl From<serde_json::Value> for ToolOutput {
fn from(value: serde_json::Value) -> Self {
Self::new(value.to_string())
}
}
/// Marker wrapper: a tool result wrapped in `Json` is converted to a
/// [`ToolOutput`] by serializing the inner value to compact JSON text.
///
/// Derives the standard traits, each conditional on `T`, so wrapped values
/// can be debugged, compared and copied like the value they hold.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Json<T>(pub T);
/// Serializes the wrapped value to compact JSON text and uses it as the
/// content.
///
/// # Panics
///
/// `From` cannot report failure, so this panics if `serde_json` rejects the
/// value. That happens when its `Serialize` impl returns an error or when it
/// contains a map whose keys do not serialize as strings. Use
/// [`ToolOutput::json`] to get a `Result` instead.
impl<T: serde::Serialize> From<Json<T>> for ToolOutput {
    fn from(json: Json<T>) -> Self {
        let text = serde_json::to_string(&json.0).unwrap_or_else(|e| {
            panic!("Json<T> serialization failed: {e} (use ToolOutput::json to handle this as an error)")
        });
        Self::new(text)
    }
}
/// Error produced by a tool: a human-readable `message` plus a string-keyed
/// map of JSON metadata.
///
/// The `std::io::Error` and `serde_json::Error` conversions, for example,
/// record `error_kind` / `category` in the metadata.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ToolError {
    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
    /// Structured context attached via `with_meta` / `with_metadata`.
    metadata: std::collections::HashMap<String, serde_json::Value>,
}
impl ToolError {
    /// Creates an error with the given message and no metadata.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Adds (or overwrites) one metadata entry and returns the error.
    #[must_use]
    pub fn with_meta(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Merges the top-level fields of `value` into this error's metadata,
    /// overwriting entries that share a key.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `value` fails to serialize, or serializes to
    /// something other than a JSON object. The `Err` is this same error with a
    /// parenthesised note appended to `message`.
    ///
    /// Metadata attached before the call is kept on this path. For example,
    /// the `error_kind` recorded by the `std::io::Error` conversion survives;
    /// previously it was discarded along with the rest of the metadata.
    pub fn with_metadata<T: serde::Serialize>(mut self, value: &T) -> Result<Self, Self> {
        let json = match serde_json::to_value(value) {
            Ok(json) => json,
            Err(e) => {
                self.message = format!(
                    "{} (metadata serialization also failed: {e})",
                    self.message
                );
                return Err(self);
            }
        };
        match json {
            serde_json::Value::Object(map) => {
                self.metadata.extend(map);
                Ok(self)
            }
            other => {
                self.message = format!(
                    "{} (metadata must serialize to a JSON object, got {})",
                    self.message,
                    other_type_name(&other),
                );
                Err(self)
            }
        }
    }

    /// Borrows the metadata map.
    #[must_use]
    pub fn metadata(&self) -> &std::collections::HashMap<String, serde_json::Value> {
        &self.metadata
    }
}
/// Prints only the message. Outer formatter flags are ignored, and metadata
/// is not shown.
impl std::fmt::Display for ToolError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.message)
    }
}
impl std::error::Error for ToolError {}
impl From<String> for ToolError {
fn from(message: String) -> Self {
Self::new(message)
}
}
impl From<&str> for ToolError {
fn from(message: &str) -> Self {
Self::new(message)
}
}
impl From<std::io::Error> for ToolError {
fn from(e: std::io::Error) -> Self {
Self::new(e.to_string())
.with_meta("error_kind", serde_json::json!(format!("{:?}", e.kind())))
}
}
impl From<serde_json::Error> for ToolError {
fn from(e: serde_json::Error) -> Self {
Self::new(e.to_string())
.with_meta("category", serde_json::json!(format!("{:?}", e.classify())))
}
}
impl From<Box<dyn std::error::Error + Send + Sync>> for ToolError {
fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self {
Self::new(e.to_string())
}
}
/// Lets `?` convert `Infallible` errors into `ToolError`; the body can never
/// execute.
impl From<std::convert::Infallible> for ToolError {
    fn from(impossible: std::convert::Infallible) -> Self {
        // `Infallible` has no values, so an empty match is exhaustive.
        match impossible {}
    }
}
/// Legacy alias for [`ToolOutput::json`]; kept only for backward
/// compatibility.
#[doc(hidden)]
#[deprecated(since = "0.2.0", note = "Use ToolOutput::json() instead")]
pub fn __serialize_tool_result<T: serde::Serialize>(value: &T) -> Result<ToolOutput, ToolError> {
    ToolOutput::json::<T>(value)
}
/// Hidden support items for converting a tool's return value into
/// `Result<ToolOutput, ToolError>`.
///
/// NOTE(review): presumably used by code generated by the re-exported
/// `llm_tool` macro — confirm against `llm_tool_macros`.
///
/// How dispatch works: `Wrap(x).__convert()` resolves to one of the inherent
/// `__convert` methods below when one exists for `x`'s type. Inherent methods
/// take precedence over trait methods during method-call resolution.
/// Otherwise, if `SerializeFallback` is in scope at the call site, it falls
/// back to JSON-serializing any `T: Serialize`.
#[doc(hidden)]
pub mod __private {
    use super::{Json, ToolError, ToolOutput};

    /// Carrier for a tool's raw return value.
    pub struct Wrap<T>(pub T);

    impl Wrap<ToolOutput> {
        /// A `ToolOutput` passes through unchanged.
        pub fn __convert(self) -> Result<ToolOutput, ToolError> {
            Ok(self.0)
        }
    }

    impl Wrap<String> {
        /// A `String` becomes the content verbatim.
        ///
        /// This overrides the fallback, which would JSON-quote the string.
        pub fn __convert(self) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::new(self.0))
        }
    }

    impl<T: serde::Serialize> Wrap<Json<T>> {
        /// Serializes the inner value to compact JSON text.
        ///
        /// Unlike `ToolOutput::from(Json(..))`, which panics on failure, a
        /// serialization error is returned as `Err` so the tool reports it
        /// instead of crashing.
        pub fn __convert(self) -> Result<ToolOutput, ToolError> {
            let Json(inner) = self.0;
            ToolOutput::json(&inner)
        }
    }

    /// Fallback conversion for any serializable value. It applies only when
    /// no inherent `Wrap::__convert` matches.
    pub trait SerializeFallback {
        fn __convert(self) -> Result<ToolOutput, ToolError>;
    }

    impl<T: serde::Serialize> SerializeFallback for Wrap<T> {
        fn __convert(self) -> Result<ToolOutput, ToolError> {
            ToolOutput::json(&self.0)
        }
    }
}
/// Serializable description of a tool.
///
/// Field order is kept stable because it determines serde output order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// Parameter schema as raw JSON. NOTE(review): presumably a JSON Schema
    /// (the module re-exports `schemars::JsonSchema`) — confirm with the
    /// producer.
    pub parameter_schema: serde_json::Value,
}