use std::fmt;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Errors produced while resolving, authorizing, or executing a tool.
///
/// Each variant wraps one `String` that is interpolated into the `Display`
/// text declared by its `#[error(...)]` attribute.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum ToolError {
/// Failure during execution; `Tool::call_json` uses this variant when the
/// tool's output cannot be serialized to JSON.
#[error("Execution error: {0}")]
Execution(String),
/// The supplied arguments could not be deserialized into the tool's
/// argument type; also the target of `From<serde_json::Error>`.
#[error("Invalid arguments: {0}")]
InvalidArguments(String),
/// No tool matched the request. The payload is presumably the requested
/// name — confirm against the code that constructs this variant.
#[error("Tool not found: {0}")]
NotFound(String),
/// The tool named by the payload is forbidden by policy.
#[error("Tool '{0}' is forbidden by policy")]
Forbidden(String),
/// Execution of the tool named by the payload was denied by confirmation.
#[error("Tool '{0}' execution denied by confirmation")]
ConfirmationDenied(String),
/// Catch-all for free-form messages; the target of `From<String>` and
/// `From<&str>`.
#[error("Tool error: {0}")]
Other(String),
}
impl From<String> for ToolError {
fn from(s: String) -> Self {
Self::Other(s)
}
}
impl From<&str> for ToolError {
fn from(s: &str) -> Self {
Self::Other(s.to_string())
}
}
impl From<serde_json::Error> for ToolError {
fn from(err: serde_json::Error) -> Self {
Self::InvalidArguments(err.to_string())
}
}
/// Shorthand for a `Result` whose error type is [`ToolError`].
pub type ToolResult<T> = Result<T, ToolError>;
/// Name, description, and parameter schema describing a tool.
///
/// `Deserialize` is derived, so input is read from the flat field layout
/// below. The hand-written `Serialize` impl in this file instead emits a
/// `{"type": "function", "function": {...}}` wrapper.
/// NOTE(review): the two shapes differ, so serialized output will not
/// deserialize back into this type — confirm that is intended.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Description of the tool.
pub description: String,
/// JSON value describing the tool's parameters (presumably a JSON Schema
/// object; nothing in this type validates it).
pub parameters: Value,
/// Optional strict-mode flag; `None` when absent from deserialized input.
/// NOTE(review): `skip_serializing_if` only influences a derived
/// `Serialize`, which this type does not use, so it appears to be inert.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
impl ToolDefinition {
    /// Creates a definition with no explicit strictness setting
    /// (`strict` is `None`). `parameters` is stored unchanged.
    #[must_use]
    pub fn new(name: impl Into<String>, description: impl Into<String>, parameters: Value) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters,
            strict: None,
        }
    }

    /// Creates a definition with strict mode enabled.
    ///
    /// Equivalent to `ToolDefinition::new(..).with_strict(true)`: besides
    /// setting `strict` to `Some(true)`, a top-level object schema that has
    /// no `additionalProperties` key receives `additionalProperties: false`.
    /// Both strict constructors therefore yield the same definition; schemas
    /// that already carry the key are left untouched.
    #[must_use]
    pub fn new_strict(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: Value,
    ) -> Self {
        Self::new(name, description, parameters).with_strict(true)
    }

    /// Sets the strict flag, consuming and returning the definition.
    ///
    /// When `strict` is `true` and `parameters` is a JSON object without an
    /// `additionalProperties` key, `additionalProperties: false` is added.
    /// An existing value for that key is never overwritten, nested schemas
    /// are not visited, and passing `false` leaves `parameters` unchanged.
    #[must_use]
    pub fn with_strict(mut self, strict: bool) -> Self {
        self.strict = Some(strict);
        if strict
            && let Some(schema) = self.parameters.as_object_mut()
        {
            // Entry API: insert the default only when the key is missing.
            schema
                .entry("additionalProperties")
                .or_insert(Value::Bool(false));
        }
        self
    }

    /// Returns `true` only when `strict` is `Some(true)`; `None` and
    /// `Some(false)` both count as non-strict.
    #[must_use]
    pub const fn is_strict(&self) -> bool {
        matches!(self.strict, Some(true))
    }

    /// Borrows the tool name.
    #[inline]
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Borrows the tool description.
    #[inline]
    #[must_use]
    pub fn description(&self) -> &str {
        &self.description
    }
}
/// Serializes the definition as a two-entry map:
/// `{"type": "function", "function": {name, description, parameters[, strict]}}`.
/// The `strict` entry is written only when the field is `Some`.
impl Serialize for ToolDefinition {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeMap;

        // Build the inner "function" payload as an owned JSON object.
        let mut payload = serde_json::Map::new();
        for (key, value) in [
            ("name", Value::String(self.name.clone())),
            ("description", Value::String(self.description.clone())),
            ("parameters", self.parameters.clone()),
        ] {
            payload.insert(key.to_owned(), value);
        }
        if let Some(flag) = self.strict {
            payload.insert("strict".to_owned(), Value::Bool(flag));
        }

        // Outer envelope: exactly two entries, "type" followed by "function".
        let mut envelope = serializer.serialize_map(Some(2))?;
        envelope.serialize_entry("type", "function")?;
        envelope.serialize_entry("function", &payload)?;
        envelope.end()
    }
}
#[async_trait]
pub trait Tool: Send + Sync {
const NAME: &'static str;
type Args: for<'de> Deserialize<'de> + Send;
type Output: Serialize + Send;
type Error: Into<ToolError> + Send;
fn name(&self) -> &'static str {
Self::NAME
}
fn description(&self) -> String;
fn parameters_schema(&self) -> Value;
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error>;
fn definition(&self) -> ToolDefinition {
ToolDefinition {
name: self.name().to_owned(),
description: self.description(),
parameters: self.parameters_schema(),
strict: None,
}
}
async fn call_json(&self, args: Value) -> Result<Value, ToolError>
where
Self::Output: 'static,
{
let typed_args: Self::Args = match &args {
Value::String(s) => {
serde_json::from_str(s).map_err(|e| ToolError::InvalidArguments(e.to_string()))?
}
_ => serde_json::from_value(args)
.map_err(|e| ToolError::InvalidArguments(e.to_string()))?,
};
let result = self.call(typed_args).await.map_err(Into::into)?;
serde_json::to_value(result).map_err(|e| ToolError::Execution(e.to_string()))
}
}
/// Heap-allocated, type-erased tool: a boxed [`DynTool`] trait object.
pub type BoxedTool = Box<dyn DynTool>;
/// Type-erased view of a tool that exchanges arguments and results as JSON,
/// usable as a trait object (see `BoxedTool`). A blanket impl in this file
/// provides it for every `Tool` implementor that is `'static` and whose
/// output is `'static`.
#[async_trait]
pub trait DynTool: Send + Sync {
/// Returns the tool's name.
fn name(&self) -> &str;
/// Returns the tool's description.
fn description(&self) -> String;
/// Returns the tool's definition.
fn definition(&self) -> ToolDefinition;
/// Invokes the tool with JSON arguments, returning JSON output or a `ToolError`.
async fn call_json(&self, args: Value) -> Result<Value, ToolError>;
}
/// Blanket implementation: every `'static` [`Tool`] with a `'static` output
/// is a [`DynTool`]. Each method forwards to the same-named `Tool` method.
#[async_trait]
impl<T: Tool + 'static> DynTool for T
where
    T::Output: 'static,
{
    // Fully qualified paths make explicit that the `Tool` method is called,
    // not the `DynTool` method of the same name.
    fn name(&self) -> &str {
        <T as Tool>::name(self)
    }

    fn description(&self) -> String {
        <T as Tool>::description(self)
    }

    fn definition(&self) -> ToolDefinition {
        <T as Tool>::definition(self)
    }

    async fn call_json(&self, args: Value) -> Result<Value, ToolError> {
        <T as Tool>::call_json(self, args).await
    }
}
/// Policy attached to tool execution. Enforcement is not part of this
/// definition; the variant notes below are inferred from the names.
///
/// NOTE(review): the derived serde impls have no `rename_all`, so they use
/// the variant identifiers as written, while the `Display` impl in this file
/// prints snake_case labels — confirm the mismatch is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ToolExecutionPolicy {
/// Default variant; presumably the tool runs without asking.
#[default]
Auto,
/// Presumably a confirmation is requested before the tool runs.
RequireConfirmation,
/// Presumably the tool must not run.
Forbidden,
}
/// Writes the policy as a fixed snake_case label: `auto`,
/// `require_confirmation`, or `forbidden`. Width and alignment flags on the
/// formatter are not applied.
impl fmt::Display for ToolExecutionPolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Auto => "auto",
            Self::RequireConfirmation => "require_confirmation",
            Self::Forbidden => "forbidden",
        };
        f.write_str(label)
    }
}
/// Data handed to a `ConfirmationHandler` when a tool invocation needs a decision.
#[derive(Debug, Clone)]
pub struct ToolConfirmationRequest {
/// Caller-supplied identifier (presumably the tool-call id — confirm with callers).
pub id: String,
/// Name of the tool that wants to execute.
pub name: String,
/// JSON arguments the tool would be invoked with.
pub arguments: Value,
/// Text summary; `ToolConfirmationRequest::new` builds it from `name` and `arguments`.
pub description: String,
}
impl ToolConfirmationRequest {
    /// Builds a request and generates its `description`.
    ///
    /// The description names the tool and embeds the arguments as
    /// pretty-printed JSON, falling back to the compact rendering if
    /// pretty-printing fails.
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: Value) -> Self {
        let name: String = name.into();

        let rendered_args = match serde_json::to_string_pretty(&arguments) {
            Ok(pretty) => pretty,
            Err(_) => arguments.to_string(),
        };
        let description = format!(
            "Tool '{}' wants to execute with arguments: {}",
            name, rendered_args
        );

        Self {
            id: id.into(),
            name,
            arguments,
            description,
        }
    }
}
/// Decision returned by a `ConfirmationHandler`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolConfirmationResponse {
/// The request is approved.
Approved,
/// The request is denied.
Denied,
/// Counts as approved in `is_approved`; presumably also approves later
/// requests — that handling is not visible in this definition.
ApproveAll,
}
impl ToolConfirmationResponse {
    /// Returns `true` for `Approved` and `ApproveAll`, `false` for `Denied`.
    #[must_use]
    pub const fn is_approved(&self) -> bool {
        match self {
            Self::Approved | Self::ApproveAll => true,
            Self::Denied => false,
        }
    }
}
/// Asynchronous decision point for tool invocations that need confirmation.
#[async_trait]
pub trait ConfirmationHandler: Send + Sync {
/// Inspects `request` and returns the decision for it.
async fn confirm(&self, request: &ToolConfirmationRequest) -> ToolConfirmationResponse;
}
/// Reference-counted (`Arc`) [`ConfirmationHandler`] trait object.
pub type SharedConfirmationHandler = std::sync::Arc<dyn ConfirmationHandler>;
/// Stateless handler whose `ConfirmationHandler` impl answers every request
/// with `ToolConfirmationResponse::Approved`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AutoApproveHandler;
/// Approves every request without inspecting it.
#[async_trait]
impl ConfirmationHandler for AutoApproveHandler {
    async fn confirm(&self, _req: &ToolConfirmationRequest) -> ToolConfirmationResponse {
        // The request is deliberately ignored: the answer is unconditional.
        ToolConfirmationResponse::Approved
    }
}
/// Stateless handler whose `ConfirmationHandler` impl answers every request
/// with `ToolConfirmationResponse::Denied`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AlwaysDenyHandler;
/// Denies every request without inspecting it.
#[async_trait]
impl ConfirmationHandler for AlwaysDenyHandler {
    async fn confirm(&self, _req: &ToolConfirmationRequest) -> ToolConfirmationResponse {
        // The request is deliberately ignored: the answer is unconditional.
        ToolConfirmationResponse::Denied
    }
}