fastmcp 0.0.0

A Rust framework for building Model Context Protocol (MCP) services
Documentation
use std::fmt::{self, Debug};
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::error::{Error, Result};
use crate::protocol::ToolMetadata;

mod context;
pub use context::ToolContext;

mod validator;
pub use validator::ParameterValidator;

mod templates;
pub use templates::*;

/// Permission level required to invoke a tool.
///
/// Variants are declared from least to most privileged. The derived
/// `PartialOrd`/`Ord` follow declaration order, so
/// `Public < Authenticated < Protected < Admin`. Do not reorder the variants:
/// permission checks compare levels with `>=`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum PermissionLevel {
    /// Public tool; every user may access it. This is the default level.
    #[default]
    Public,

    /// Requires authentication; only logged-in users may access it.
    Authenticated,

    /// Requires specific permissions; only users holding them may access it.
    Protected,

    /// Accessible to administrators only.
    Admin,
}

/// Core abstraction for an MCP tool.
///
/// Implementors must provide [`Tool::name`], [`Tool::description`],
/// [`Tool::parameters`] and [`Tool::execute`]. Every other method has a
/// default implementation that may be overridden.
#[async_trait]
pub trait Tool: Send + Sync + Debug {
    /// Returns the tool's name.
    fn name(&self) -> &str;

    /// Returns a human-readable description of the tool.
    fn description(&self) -> &str;

    /// Returns the JSON Schema describing the tool's parameters.
    ///
    /// The default [`Tool::validate_params`] builds a [`ParameterValidator`]
    /// from this value.
    fn parameters(&self) -> Value;

    /// Returns the permission level a caller needs to use this tool.
    ///
    /// Defaults to [`PermissionLevel::Public`].
    fn permission_level(&self) -> PermissionLevel {
        PermissionLevel::Public
    }

    /// Checks whether the caller described by `context` may use this tool.
    ///
    /// Access is granted when the caller's level is greater than or equal to
    /// [`Tool::permission_level`], using `PermissionLevel`'s ordering.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ResourceAccess`] when the caller's level is lower than
    /// the required one.
    fn check_permission(&self, context: &ToolContext) -> Result<()> {
        let required = self.permission_level();

        // Public tools are open to everyone; skip looking up the caller's level.
        if required == PermissionLevel::Public {
            return Ok(());
        }

        // Compare the caller's level (taken from the context) to the requirement.
        let user_level = context.permission_level();

        if user_level >= required {
            Ok(())
        } else {
            // Message text: "insufficient permission: requires {required}, current is {user_level}".
            Err(Error::ResourceAccess(format!(
                "权限不足:需要 {required:?} 权限,当前权限为 {user_level:?}"
            )))
        }
    }

    /// Returns the JSON Schema of the tool's return value, if any.
    fn return_schema(&self) -> Option<Value> {
        None
    }

    /// Returns whether this tool supports streaming responses.
    fn streaming(&self) -> bool {
        false
    }

    /// Returns the tool's categories (empty by default).
    fn categories(&self) -> Vec<String> {
        Vec::new()
    }

    /// Returns the tool's version, if any.
    fn version(&self) -> Option<String> {
        None
    }

    /// Returns the tool's author, if any.
    fn author(&self) -> Option<String> {
        None
    }

    /// Returns the tool's documentation URL, if any.
    fn documentation_url(&self) -> Option<String> {
        None
    }

    /// Returns a per-tool timeout, if it differs from the server default.
    fn timeout(&self) -> Option<Duration> {
        None
    }

    /// Returns whether this tool is deprecated.
    fn deprecated(&self) -> bool {
        false
    }

    /// Returns the deprecation message, if the tool is deprecated and has one.
    fn deprecation_message(&self) -> Option<String> {
        None
    }

    /// Assembles the complete [`ToolMetadata`] from the individual accessors.
    ///
    /// NOTE(review): `timeout` and `permission_level` are not part of the
    /// metadata built here.
    fn metadata(&self) -> ToolMetadata {
        ToolMetadata {
            name: self.name().to_string(),
            description: self.description().to_string(),
            parameters: self.parameters(),
            return_schema: self.return_schema(),
            streaming: self.streaming(),
            categories: self.categories(),
            version: self.version(),
            author: self.author(),
            documentation_url: self.documentation_url(),
            deprecated: self.deprecated(),
            deprecation_message: self.deprecation_message(),
        }
    }

    /// Executes the tool with the given parameters and shared context.
    async fn execute(&self, params: Value, context: Arc<ToolContext>) -> Result<Value>;

    /// Sends a partial result (for streaming tools).
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn send_partial_result(&self, _result: &Value, _context: &ToolContext) -> Result<()> {
        Ok(())
    }

    /// Validates `params` before execution.
    ///
    /// The default implementation builds a [`ParameterValidator`] from
    /// [`Tool::parameters`] on every call and validates `params` against it.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the validator or from validation.
    fn validate_params(&self, params: &Value) -> Result<()> {
        let validator = ParameterValidator::new(self.parameters())?;
        validator.validate(params)
    }

    /// Hook invoked before execution.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn before_execute(&self, _context: &ToolContext) -> Result<()> {
        Ok(())
    }

    /// Hook invoked after execution, regardless of success or failure.
    ///
    /// * `status` — `true` if execution succeeded, `false` if it failed.
    /// * `value` — the result value when execution succeeded.
    /// * `error_message` — the error message when execution failed.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn after_execute(
        &self, _context: &ToolContext, _status: bool, _value: Option<&Value>,
        _error_message: Option<&str>,
    ) -> Result<()> {
        Ok(())
    }
}

/// Heap-allocated, type-erased [`Tool`] trait object.
///
/// The explicit `'static` bound is the default object lifetime for
/// `Box<dyn Tool>`; it is spelled out here for clarity.
pub type BoxedTool = Box<dyn Tool + 'static>;

/// Wraps an async function or closure so it can be used as a [`Tool`].
///
/// Build one with [`ToolFn::new`] and refine it with the builder methods.
/// The struct places no bounds on `F`; the bounds needed to act as a tool
/// live on its `Tool` implementation.
pub struct ToolFn<F> {
    // --- identity and schemas ---
    /// Tool name.
    name: String,
    /// Human-readable description.
    description: String,
    /// JSON Schema for the tool's parameters.
    parameters: Value,
    /// Optional JSON Schema for the return value.
    return_schema: Option<Value>,

    // --- execution behaviour ---
    /// Whether the tool streams partial results.
    streaming: bool,
    /// Per-tool timeout overriding the server default, if set.
    timeout: Option<Duration>,
    /// Permission level callers need.
    permission_level: PermissionLevel,

    // --- descriptive metadata ---
    /// Free-form categories.
    categories: Vec<String>,
    /// Optional version string.
    version: Option<String>,
    /// Optional author.
    author: Option<String>,
    /// Optional documentation URL.
    documentation_url: Option<String>,

    // --- deprecation ---
    /// Whether the tool is deprecated.
    deprecated: bool,
    /// Optional deprecation message.
    deprecation_message: Option<String>,

    /// The wrapped function invoked on execution.
    function: F,
}

// Hand-written so that `F` (typically a closure) is not required to implement
// `Debug`; the wrapped function is printed as a fixed placeholder.
impl<F> Debug for ToolFn<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `ToolFn` makes this fail to
        // compile until the new field is considered here.
        let Self {
            name,
            description,
            parameters,
            return_schema,
            streaming,
            categories,
            version,
            author,
            documentation_url,
            timeout,
            deprecated,
            deprecation_message,
            permission_level,
            function: _,
        } = self;

        let mut builder = f.debug_struct("ToolFn");
        builder
            .field("name", name)
            .field("description", description)
            .field("parameters", parameters)
            .field("return_schema", return_schema)
            .field("streaming", streaming)
            .field("categories", categories)
            .field("version", version)
            .field("author", author)
            .field("documentation_url", documentation_url)
            .field("timeout", timeout)
            .field("deprecated", deprecated)
            .field("deprecation_message", deprecation_message)
            .field("permission_level", permission_level)
            .field("function", &"<function>");
        builder.finish()
    }
}

impl<F> ToolFn<F> {
    /// Creates a new `ToolFn` wrapping `function`.
    ///
    /// The initial state is:
    /// * all optional metadata unset;
    /// * no categories;
    /// * streaming and deprecation off;
    /// * permission level [`PermissionLevel::Public`].
    #[must_use]
    pub fn new(
        name: impl Into<String>, description: impl Into<String>, parameters: Value, function: F,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters,
            return_schema: None,
            streaming: false,
            categories: Vec::new(),
            version: None,
            author: None,
            documentation_url: None,
            timeout: None,
            deprecated: false,
            deprecation_message: None,
            permission_level: PermissionLevel::Public,
            function,
        }
    }

    /// Sets the JSON Schema describing the tool's return value.
    #[must_use]
    pub fn with_return_schema(self, schema: Value) -> Self {
        Self { return_schema: Some(schema), ..self }
    }

    /// Sets whether this tool supports streaming responses.
    #[must_use]
    pub fn with_streaming(self, streaming: bool) -> Self {
        Self { streaming, ..self }
    }

    /// Sets the tool's categories, replacing any previously set ones.
    #[must_use]
    pub fn with_categories(self, categories: Vec<String>) -> Self {
        Self { categories, ..self }
    }

    /// Sets the tool's version.
    #[must_use]
    pub fn with_version(self, version: impl Into<String>) -> Self {
        Self { version: Some(version.into()), ..self }
    }

    /// Sets the tool's author.
    #[must_use]
    pub fn with_author(self, author: impl Into<String>) -> Self {
        Self { author: Some(author.into()), ..self }
    }

    /// Sets the tool's documentation URL.
    #[must_use]
    pub fn with_documentation_url(self, url: impl Into<String>) -> Self {
        Self { documentation_url: Some(url.into()), ..self }
    }

    /// Sets a per-tool timeout.
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout: Some(timeout), ..self }
    }

    /// Marks the tool as deprecated, optionally with a message.
    ///
    /// Passing a bare `None` needs a type annotation because the message type
    /// is generic, e.g. `tool.deprecated(None::<String>)`.
    #[must_use]
    pub fn deprecated(self, message: Option<impl Into<String>>) -> Self {
        Self {
            deprecated: true,
            deprecation_message: message.map(Into::into),
            ..self
        }
    }

    /// Sets the permission level callers need to use this tool.
    #[must_use]
    pub fn with_permission_level(self, level: PermissionLevel) -> Self {
        Self { permission_level: level, ..self }
    }
}

#[async_trait]
impl<F, Fut> Tool for ToolFn<F>
where
    F: Fn(Value, Arc<ToolContext>) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Value>> + Send,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    fn parameters(&self) -> Value {
        self.parameters.clone()
    }

    fn permission_level(&self) -> PermissionLevel {
        self.permission_level
    }

    fn return_schema(&self) -> Option<Value> {
        self.return_schema.clone()
    }

    fn streaming(&self) -> bool {
        self.streaming
    }

    fn categories(&self) -> Vec<String> {
        self.categories.clone()
    }

    fn version(&self) -> Option<String> {
        self.version.clone()
    }

    fn author(&self) -> Option<String> {
        self.author.clone()
    }

    fn documentation_url(&self) -> Option<String> {
        self.documentation_url.clone()
    }

    fn timeout(&self) -> Option<Duration> {
        self.timeout
    }

    fn deprecated(&self) -> bool {
        self.deprecated
    }

    fn deprecation_message(&self) -> Option<String> {
        self.deprecation_message.clone()
    }

    async fn execute(&self, params: Value, context: Arc<ToolContext>) -> Result<Value> {
        (self.function)(params, context).await
    }
}