//! fastmcp 0.0.0
//!
//! A Rust framework for building Model Context Protocol (MCP) services.
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};

use serde_json::Value;
use tokio::sync::mpsc;
use tokio::time::timeout;

use crate::error::{Error, Result};
use crate::protocol::SessionId;
use crate::tool::PermissionLevel;

/// Resource usage limits for a tool execution.
///
/// Every limit is optional: `None` means the corresponding resource is not limited.
/// The [`Default`] value therefore imposes no limits at all.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ResourceLimits {
    /// Maximum memory usage, in bytes.
    pub max_memory: Option<u64>,

    /// Maximum CPU time, in milliseconds.
    pub max_cpu_time: Option<u64>,

    /// Maximum number of API calls.
    pub max_api_calls: Option<u64>,

    /// Maximum number of concurrent requests.
    pub max_concurrent_requests: Option<u32>,
}

/// Live resource-usage counters for a tool execution.
///
/// Each counter is an independent [`AtomicU64`] accessed with [`Ordering::Relaxed`], and all
/// updates go through `&self`, so the counters can be shared. Caller-supplied deltas
/// ([`add_memory`](Self::add_memory), [`add_cpu_time`](Self::add_cpu_time)) saturate at
/// `u64::MAX`. [`decrement_concurrent_requests`](Self::decrement_concurrent_requests) stops at
/// zero. As a result, no counter can wrap around and make an exceeded limit look satisfied
/// (or a satisfied limit look exceeded).
#[derive(Debug)]
pub struct ResourceUsage {
    /// Memory used so far, in bytes.
    memory_used: AtomicU64,

    /// CPU time used so far, in milliseconds.
    cpu_time: AtomicU64,

    /// Number of API calls made so far.
    api_calls: AtomicU64,

    /// Number of requests currently in flight.
    current_concurrent_requests: AtomicU64,

    /// When tracking started (set to `Instant::now()` by `Default`).
    start_time: Instant,
}

impl Default for ResourceUsage {
    /// Returns a tracker with every counter at zero whose clock starts now.
    fn default() -> Self {
        Self {
            memory_used: AtomicU64::new(0),
            cpu_time: AtomicU64::new(0),
            api_calls: AtomicU64::new(0),
            current_concurrent_requests: AtomicU64::new(0),
            start_time: Instant::now(),
        }
    }
}

impl ResourceUsage {
    /// Returns the memory used so far, in bytes.
    pub fn memory_used(&self) -> u64 {
        self.memory_used.load(Ordering::Relaxed)
    }

    /// Returns the CPU time used so far, in milliseconds.
    pub fn cpu_time(&self) -> u64 {
        self.cpu_time.load(Ordering::Relaxed)
    }

    /// Returns the number of API calls made so far.
    pub fn api_calls(&self) -> u64 {
        self.api_calls.load(Ordering::Relaxed)
    }

    /// Returns the number of requests currently in flight.
    pub fn current_concurrent_requests(&self) -> u64 {
        self.current_concurrent_requests.load(Ordering::Relaxed)
    }

    /// Returns the wall-clock time elapsed since this tracker was created.
    pub fn elapsed_time(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Adds `bytes` to the memory counter, saturating at `u64::MAX`.
    pub fn add_memory(&self, bytes: u64) {
        Self::saturating_add(&self.memory_used, bytes);
    }

    /// Adds `ms` milliseconds to the CPU-time counter, saturating at `u64::MAX`.
    pub fn add_cpu_time(&self, ms: u64) {
        Self::saturating_add(&self.cpu_time, ms);
    }

    /// Increments the API call counter by one.
    pub fn increment_api_calls(&self) {
        self.api_calls.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the in-flight request counter and returns its value *before* the increment.
    pub fn increment_concurrent_requests(&self) -> u64 {
        self.current_concurrent_requests
            .fetch_add(1, Ordering::Relaxed)
    }

    /// Decrements the in-flight request counter and returns its value *before* the decrement.
    ///
    /// On an unbalanced decrement (counter already zero), the counter stays at zero and `0` is
    /// returned. It does not wrap to `u64::MAX`, which would make every later concurrency check
    /// fail.
    pub fn decrement_concurrent_requests(&self) -> u64 {
        // `fetch_update` returns `Ok(prev)` when it stores the new value and `Err(prev)` when
        // the closure returns `None` (counter already zero); both carry the previous value.
        self.current_concurrent_requests
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1))
            .unwrap_or_else(|prev| prev)
    }

    /// Atomically adds `delta` to `counter`, clamping the result at `u64::MAX`.
    fn saturating_add(counter: &AtomicU64, delta: u64) {
        // The closure never returns `None`, so this update always succeeds.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
            Some(n.saturating_add(delta))
        });
    }
}

/// Context information for tool execution
///
/// Bundles the session ID, a cancellation flag, an optional channel for streaming partial
/// results, an optional default timeout, the user's permission level and ID, and the resource
/// limits together with the usage counters that are checked against them.
#[derive(Debug)]
pub struct ToolContext {
    /// Session ID
    pub session_id: SessionId,

    /// Whether the execution has been cancelled (set by `cancel`, read with `SeqCst`)
    cancelled: AtomicBool,

    /// Channel for sending partial results; when `None`, `send_partial_result` discards them
    partial_results_tx: Option<mpsc::Sender<Value>>,

    /// Default timeout for operations; when `None`, `with_timeout_future` applies no deadline
    timeout: Option<Duration>,

    /// The user's permission level
    permission_level: PermissionLevel,

    /// The user ID (if authenticated)
    user_id: Option<String>,

    /// Resource usage limits
    resource_limits: ResourceLimits,

    /// Resource usage counters checked against `resource_limits`
    resource_usage: ResourceUsage,
}

impl ToolContext {
    /// Creates a context for `session_id` with default settings.
    ///
    /// The new context is not cancelled, has no partial-results channel and no timeout, runs
    /// at [`PermissionLevel::Public`] with no user ID, has no resource limits, and starts with
    /// zeroed usage counters.
    pub fn new(session_id: SessionId) -> Self {
        Self {
            session_id,
            cancelled: AtomicBool::new(false),
            partial_results_tx: None,
            timeout: None,
            permission_level: PermissionLevel::Public,
            user_id: None,
            resource_limits: ResourceLimits::default(),
            resource_usage: ResourceUsage::default(),
        }
    }

    /// Creates a context like [`ToolContext::new`], but with a default timeout.
    ///
    /// [`ToolContext::with_timeout_future`] applies this timeout to every future it runs.
    pub fn with_timeout(session_id: SessionId, timeout_duration: Duration) -> Self {
        // Build on `new` so the two constructors share one set of defaults and cannot drift
        // apart when fields are added.
        let mut ctx = Self::new(session_id);
        ctx.timeout = Some(timeout_duration);
        ctx
    }

    /// Sets the user's permission level (builder style).
    #[must_use]
    pub fn with_permission_level(mut self, level: PermissionLevel) -> Self {
        self.permission_level = level;
        self
    }

    /// Sets the user ID (builder style).
    ///
    /// If the context is still at [`PermissionLevel::Public`], it is promoted to
    /// [`PermissionLevel::Authenticated`]; any other level is left unchanged. A later call to
    /// [`ToolContext::with_permission_level`] overrides this.
    #[must_use]
    pub fn with_user_id(mut self, user_id: String) -> Self {
        self.user_id = Some(user_id);
        if self.permission_level == PermissionLevel::Public {
            self.permission_level = PermissionLevel::Authenticated;
        }
        self
    }

    /// Returns the user's permission level.
    pub fn permission_level(&self) -> PermissionLevel {
        self.permission_level
    }

    /// Returns the user ID, if one has been set.
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Raises the permission level to [`PermissionLevel::Admin`].
    ///
    /// This performs no authorization check of its own. Callers must decide whether the
    /// elevation is allowed.
    pub fn elevate_to_admin(&mut self) {
        self.permission_level = PermissionLevel::Admin;
    }

    /// Attaches the channel used by [`ToolContext::send_partial_result`] (builder style).
    #[must_use]
    pub fn with_partial_results(mut self, tx: mpsc::Sender<Value>) -> Self {
        self.partial_results_tx = Some(tx);
        self
    }

    /// Marks the execution as cancelled.
    ///
    /// Takes `&self`, so it can be called through a shared reference.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns `true` if [`ToolContext::cancel`] has been called.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Returns an error if the execution has been cancelled.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Cancelled`] if [`ToolContext::cancel`] has been called.
    pub fn check_cancelled(&self) -> Result<()> {
        if self.is_cancelled() {
            return Err(Error::Cancelled("Tool execution cancelled".to_string()));
        }
        Ok(())
    }

    /// Runs `future`, bounded by the context's timeout if one is set.
    ///
    /// Cancellation is checked once, before the future is awaited. A cancellation requested
    /// while the future is running is not observed here.
    ///
    /// # Errors
    ///
    /// - [`Error::Cancelled`] if the context was already cancelled; the future is then dropped
    ///   without being awaited.
    /// - [`Error::Timeout`] if the timeout elapses before the future completes.
    /// - The future's own error, converted via `E: Into<Error>`.
    pub async fn with_timeout_future<F, T, E>(&self, future: F) -> Result<T>
    where
        F: Future<Output = std::result::Result<T, E>>,
        E: Into<Error>,
    {
        self.check_cancelled()?;

        match self.timeout {
            Some(duration) => match timeout(duration, future).await {
                Ok(result) => result.map_err(Into::into),
                Err(_) => Err(Error::Timeout("Tool execution timed out".to_string())),
            },
            None => future.await.map_err(Into::into),
        }
    }

    /// Sends a partial result over the attached channel.
    ///
    /// If no channel was attached with [`ToolContext::with_partial_results`], the result is
    /// discarded and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] if sending on the attached channel fails.
    pub async fn send_partial_result(&self, result: Value) -> Result<()> {
        if let Some(tx) = &self.partial_results_tx {
            tx.send(result)
                .await
                .map_err(|_| Error::Transport("Failed to send partial result".to_string()))?;
        }
        Ok(())
    }

    /// Sets the resource usage limits (builder style).
    #[must_use]
    pub fn with_resource_limits(mut self, limits: ResourceLimits) -> Self {
        self.resource_limits = limits;
        self
    }

    /// Checks the current resource usage against the configured limits.
    ///
    /// Limits set to `None` are skipped. A limit counts as exceeded only when usage is
    /// strictly greater than it. Checks run in the order memory, CPU time, API calls,
    /// concurrent requests, and the first violation found is reported.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ResourceAccess`] describing the first exceeded limit.
    pub fn check_resource_limits(&self) -> Result<()> {
        // Memory usage
        if let Some(max_memory) = self.resource_limits.max_memory {
            let memory_used = self.resource_usage.memory_used();
            if memory_used > max_memory {
                return Err(Error::ResourceAccess(format!(
                    "内存使用超出限制:已使用 {memory_used} 字节,最大限制 {max_memory} 字节"
                )));
            }
        }

        // CPU time
        if let Some(max_cpu_time) = self.resource_limits.max_cpu_time {
            let cpu_time = self.resource_usage.cpu_time();
            if cpu_time > max_cpu_time {
                return Err(Error::ResourceAccess(format!(
                    "CPU时间超出限制:已使用 {cpu_time} 毫秒,最大限制 {max_cpu_time} 毫秒"
                )));
            }
        }

        // API call count
        if let Some(max_api_calls) = self.resource_limits.max_api_calls {
            let api_calls = self.resource_usage.api_calls();
            if api_calls > max_api_calls {
                return Err(Error::ResourceAccess(format!(
                    "API调用次数超出限制:已调用 {api_calls} 次,最大限制 {max_api_calls}"
                )));
            }
        }

        // Concurrent requests
        if let Some(max_concurrent_requests) = self.resource_limits.max_concurrent_requests {
            let concurrent_requests = self.resource_usage.current_concurrent_requests();
            // Lossless widening; avoids an `as` cast.
            if concurrent_requests > u64::from(max_concurrent_requests) {
                return Err(Error::ResourceAccess(format!(
                    "并发请求数超出限制:当前 {concurrent_requests} 个,最大限制 {max_concurrent_requests}"
                )));
            }
        }

        Ok(())
    }

    /// Counts one API call, then checks all resource limits.
    ///
    /// The call is counted even if the subsequent check fails.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`ToolContext::check_resource_limits`].
    pub fn record_api_call(&self) -> Result<()> {
        self.resource_usage.increment_api_calls();
        self.check_resource_limits()
    }

    /// Adds `bytes` to the recorded memory usage, then checks all resource limits.
    ///
    /// The usage is recorded even if the subsequent check fails.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`ToolContext::check_resource_limits`].
    pub fn record_memory_usage(&self, bytes: u64) -> Result<()> {
        self.resource_usage.add_memory(bytes);
        self.check_resource_limits()
    }

    /// Returns the resource usage counters.
    pub fn resource_usage(&self) -> &ResourceUsage {
        &self.resource_usage
    }

    /// Returns the configured resource limits.
    pub fn resource_limits(&self) -> &ResourceLimits {
        &self.resource_limits
    }
}