use std::fmt::{self, Debug};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{Error, Result};
use crate::protocol::ToolMetadata;
mod context;
pub use context::ToolContext;
mod validator;
pub use validator::ParameterValidator;
mod templates;
pub use templates::*;
/// Permission level required by a tool or held by a caller.
///
/// Variants are declared from least to most privileged. The derived
/// `PartialOrd`/`Ord` follow declaration order, so
/// `Public < Authenticated < Protected < Admin`. [`Tool::check_permission`]
/// relies on this ordering when it compares the caller's level against the
/// tool's requirement.
///
/// The default level is [`PermissionLevel::Public`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum PermissionLevel {
    /// Lowest level. Tools requiring it skip the comparison in
    /// `Tool::check_permission` entirely. This is the default.
    #[default]
    Public,
    /// Ranked above `Public`.
    Authenticated,
    /// Ranked above `Authenticated`.
    Protected,
    /// Highest level.
    Admin,
}
/// A tool that can be described, permission-checked, validated and executed.
///
/// Implementors must supply [`Tool::name`], [`Tool::description`],
/// [`Tool::parameters`] and [`Tool::execute`]. Every other method has a
/// default implementation that may be overridden.
#[async_trait]
pub trait Tool: Send + Sync + Debug {
    /// Tool name. [`Tool::metadata`] copies it into the metadata `name` field.
    fn name(&self) -> &str;

    /// Human-readable description. [`Tool::metadata`] copies it into the metadata.
    fn description(&self) -> &str;

    /// Parameter schema for this tool.
    ///
    /// [`Tool::validate_params`] builds a [`ParameterValidator`] from this value.
    /// NOTE(review): presumably a JSON Schema object; confirm against
    /// `ParameterValidator`.
    fn parameters(&self) -> Value;

    /// Minimum [`PermissionLevel`] a caller needs. Defaults to `Public`.
    fn permission_level(&self) -> PermissionLevel {
        PermissionLevel::Public
    }

    /// Checks that `context` holds at least [`Tool::permission_level`].
    ///
    /// `Public` tools are accepted without consulting the context. Otherwise
    /// the two levels are compared using the declaration order of
    /// [`PermissionLevel`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::ResourceAccess`] when the context's level is lower than
    /// the required level.
    fn check_permission(&self, context: &ToolContext) -> Result<()> {
        let required = self.permission_level();
        // Fast path: public tools never need to inspect the caller.
        if required == PermissionLevel::Public {
            return Ok(());
        }
        let user_level = context.permission_level();
        if user_level >= required {
            Ok(())
        } else {
            // The message reads: "Insufficient permission: requires
            // {required:?} permission, current permission is {user_level:?}".
            Err(Error::ResourceAccess(format!(
                "权限不足:需要 {required:?} 权限,当前权限为 {user_level:?}"
            )))
        }
    }

    /// Optional schema describing the value returned by [`Tool::execute`].
    /// Defaults to `None`.
    fn return_schema(&self) -> Option<Value> {
        None
    }

    /// Whether the tool is flagged as streaming. Defaults to `false`.
    fn streaming(&self) -> bool {
        false
    }

    /// Category labels. Defaults to an empty list.
    fn categories(&self) -> Vec<String> {
        Vec::new()
    }

    /// Optional version string. Defaults to `None`.
    fn version(&self) -> Option<String> {
        None
    }

    /// Optional author. Defaults to `None`.
    fn author(&self) -> Option<String> {
        None
    }

    /// Optional documentation link. Defaults to `None`.
    fn documentation_url(&self) -> Option<String> {
        None
    }

    /// Optional execution timeout. Defaults to `None`.
    ///
    /// It is not included in [`Tool::metadata`], and nothing in this trait
    /// enforces it.
    fn timeout(&self) -> Option<Duration> {
        None
    }

    /// Whether the tool is marked deprecated. Defaults to `false`.
    fn deprecated(&self) -> bool {
        false
    }

    /// Optional explanation accompanying [`Tool::deprecated`]. Defaults to `None`.
    fn deprecation_message(&self) -> Option<String> {
        None
    }

    /// Collects the descriptive getters into a [`ToolMetadata`].
    ///
    /// [`Tool::timeout`] and [`Tool::permission_level`] are not part of the
    /// returned metadata.
    fn metadata(&self) -> ToolMetadata {
        ToolMetadata {
            name: self.name().to_string(),
            description: self.description().to_string(),
            parameters: self.parameters(),
            return_schema: self.return_schema(),
            streaming: self.streaming(),
            categories: self.categories(),
            version: self.version(),
            author: self.author(),
            documentation_url: self.documentation_url(),
            deprecated: self.deprecated(),
            deprecation_message: self.deprecation_message(),
        }
    }

    /// Runs the tool with `params` and a shared `context`.
    ///
    /// # Errors
    ///
    /// Implementation-defined.
    async fn execute(&self, params: Value, context: Arc<ToolContext>) -> Result<Value>;

    /// Hook for emitting a partial result.
    ///
    /// The default ignores its arguments and returns `Ok(())`.
    async fn send_partial_result(&self, _result: &Value, _context: &ToolContext) -> Result<()> {
        Ok(())
    }

    /// Validates `params` against the schema from [`Tool::parameters`].
    ///
    /// # Errors
    ///
    /// Propagates the error from [`ParameterValidator::new`] if the schema is
    /// rejected, or from the validator's `validate` if `params` do not conform.
    fn validate_params(&self, params: &Value) -> Result<()> {
        let validator = ParameterValidator::new(self.parameters())?;
        validator.validate(params)
    }

    /// Hook with access to the context. The default does nothing and returns `Ok(())`.
    ///
    /// NOTE(review): by its name, presumably invoked by the executor before
    /// [`Tool::execute`]. The caller is not shown here.
    async fn before_execute(&self, _context: &ToolContext) -> Result<()> {
        Ok(())
    }

    /// Hook receiving a status flag, an optional value and an optional error
    /// message. The default does nothing and returns `Ok(())`.
    ///
    /// NOTE(review): presumably invoked after [`Tool::execute`], with `_status`
    /// indicating success. Confirm against the caller.
    async fn after_execute(
        &self, _context: &ToolContext, _status: bool, _value: Option<&Value>,
        _error_message: Option<&str>,
    ) -> Result<()> {
        Ok(())
    }
}
/// Owned, heap-allocated trait object holding any [`Tool`] implementation.
///
/// The `'static` bound is the implicit default for `Box<dyn Trait>`. It is
/// spelled out here to make the requirement visible.
pub type BoxedTool = Box<dyn Tool + 'static>;
/// A tool backed by a callable plus builder-configured metadata.
///
/// Create it with `ToolFn::new` and adjust it with the builder methods. The
/// stored callable is what `Tool::execute` invokes.
pub struct ToolFn<F> {
    // Identity and documentation.
    /// Name returned by `Tool::name`.
    name: String,
    /// Description returned by `Tool::description`.
    description: String,
    /// Optional version string.
    version: Option<String>,
    /// Optional author.
    author: Option<String>,
    /// Optional documentation link.
    documentation_url: Option<String>,
    /// Category labels.
    categories: Vec<String>,

    // Call interface.
    /// Parameter schema returned by `Tool::parameters`.
    parameters: Value,
    /// Optional schema of the result.
    return_schema: Option<Value>,
    /// Streaming flag.
    streaming: bool,

    // Policy.
    /// Required caller permission.
    permission_level: PermissionLevel,
    /// Timeout reported through `Tool::timeout`; `execute` does not apply it.
    timeout: Option<Duration>,

    // Lifecycle.
    /// Deprecation flag.
    deprecated: bool,
    /// Optional deprecation explanation.
    deprecation_message: Option<String>,

    // Behaviour.
    /// The callable run by `execute`.
    function: F,
}
impl<F> Debug for ToolFn<F> {
    /// Prints every configuration field. The callable is shown as the
    /// placeholder `"<function>"` because `F` is not required to be `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `ToolFn` without
        // handling it here becomes a compile error.
        let Self {
            name,
            description,
            parameters,
            return_schema,
            streaming,
            categories,
            version,
            author,
            documentation_url,
            timeout,
            deprecated,
            deprecation_message,
            permission_level,
            function: _,
        } = self;

        let mut builder = f.debug_struct("ToolFn");
        builder.field("name", name);
        builder.field("description", description);
        builder.field("parameters", parameters);
        builder.field("return_schema", return_schema);
        builder.field("streaming", streaming);
        builder.field("categories", categories);
        builder.field("version", version);
        builder.field("author", author);
        builder.field("documentation_url", documentation_url);
        builder.field("timeout", timeout);
        builder.field("deprecated", deprecated);
        builder.field("deprecation_message", deprecation_message);
        builder.field("permission_level", permission_level);
        builder.field("function", &"<function>");
        builder.finish()
    }
}
impl<F> ToolFn<F> {
    /// Creates a tool called `name` that runs `function` when executed.
    ///
    /// Starting state:
    /// - every optional setting is `None`;
    /// - `categories` is empty;
    /// - `streaming` and `deprecated` are `false`;
    /// - the permission level is [`PermissionLevel::Public`].
    ///
    /// Use the `with_*` methods and [`ToolFn::deprecated`] to change them.
    #[must_use]
    pub fn new(
        name: impl Into<String>, description: impl Into<String>, parameters: Value, function: F,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters,
            return_schema: None,
            streaming: false,
            categories: Vec::new(),
            version: None,
            author: None,
            documentation_url: None,
            timeout: None,
            deprecated: false,
            deprecation_message: None,
            permission_level: PermissionLevel::Public,
            function,
        }
    }

    /// Sets the schema reported by `Tool::return_schema`.
    #[must_use]
    pub fn with_return_schema(mut self, schema: Value) -> Self {
        self.return_schema = Some(schema);
        self
    }

    /// Sets the flag reported by `Tool::streaming`.
    #[must_use]
    pub fn with_streaming(mut self, streaming: bool) -> Self {
        self.streaming = streaming;
        self
    }

    /// Replaces the category labels reported by `Tool::categories`.
    #[must_use]
    pub fn with_categories(mut self, categories: Vec<String>) -> Self {
        self.categories = categories;
        self
    }

    /// Sets the version reported by `Tool::version`.
    #[must_use]
    pub fn with_version(mut self, version: impl Into<String>) -> Self {
        self.version = Some(version.into());
        self
    }

    /// Sets the author reported by `Tool::author`.
    #[must_use]
    pub fn with_author(mut self, author: impl Into<String>) -> Self {
        self.author = Some(author.into());
        self
    }

    /// Sets the link reported by `Tool::documentation_url`.
    #[must_use]
    pub fn with_documentation_url(mut self, url: impl Into<String>) -> Self {
        self.documentation_url = Some(url.into());
        self
    }

    /// Sets the value reported by `Tool::timeout`.
    ///
    /// This type's `execute` does not apply the timeout itself.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Marks the tool as deprecated, with an optional explanatory message.
    ///
    /// A bare `None` needs a type annotation, e.g. `None::<String>`, because
    /// the message type is generic.
    #[must_use]
    pub fn deprecated(mut self, message: Option<impl Into<String>>) -> Self {
        self.deprecated = true;
        self.deprecation_message = message.map(Into::into);
        self
    }

    /// Sets the level reported by `Tool::permission_level`.
    #[must_use]
    pub fn with_permission_level(mut self, level: PermissionLevel) -> Self {
        self.permission_level = level;
        self
    }
}
#[async_trait]
impl<F, Fut> Tool for ToolFn<F>
where
F: Fn(Value, Arc<ToolContext>) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<Value>> + Send,
{
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters(&self) -> Value {
self.parameters.clone()
}
fn permission_level(&self) -> PermissionLevel {
self.permission_level
}
fn return_schema(&self) -> Option<Value> {
self.return_schema.clone()
}
fn streaming(&self) -> bool {
self.streaming
}
fn categories(&self) -> Vec<String> {
self.categories.clone()
}
fn version(&self) -> Option<String> {
self.version.clone()
}
fn author(&self) -> Option<String> {
self.author.clone()
}
fn documentation_url(&self) -> Option<String> {
self.documentation_url.clone()
}
fn timeout(&self) -> Option<Duration> {
self.timeout
}
fn deprecated(&self) -> bool {
self.deprecated
}
fn deprecation_message(&self) -> Option<String> {
self.deprecation_message.clone()
}
async fn execute(&self, params: Value, context: Arc<ToolContext>) -> Result<Value> {
(self.function)(params, context).await
}
}