fastmcp 0.0.0

A Rust framework for building Model Context Protocol (MCP) services
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use jsonschema::{Draft, JSONSchema};
use serde_json::Value;

use crate::error::{Error, Result};

/// 参数验证器
#[derive(Debug, Clone)]
pub struct ParameterValidator {
    /// JSON Schema 验证器实例
    schema: Arc<JSONSchema>,

    /// 缓存的验证器实例
    static_validators: HashMap<String, Arc<JSONSchema>>,
}

impl ParameterValidator {
    /// 从JSON Schema创建新的参数验证器
    pub fn new(schema: Value) -> Result<Self> {
        let compiled = JSONSchema::options()
            .with_draft(Draft::Draft7)
            .compile(&schema)
            .map_err(|e| Error::InvalidInput(format!("无效的JSON Schema: {e}")))?;

        Ok(Self {
            schema: Arc::new(compiled),
            static_validators: HashMap::new(),
        })
    }

    /// 添加静态验证器(用于特定字段验证)
    pub fn with_validator(mut self, name: &str, schema: Value) -> Result<Self> {
        let compiled = JSONSchema::options()
            .with_draft(Draft::Draft7)
            .compile(&schema)
            .map_err(|e| Error::InvalidInput(format!("无效的字段验证器: {e}")))?;

        self.static_validators
            .insert(name.to_string(), Arc::new(compiled));
        Ok(self)
    }

    /// 验证参数
    pub fn validate(&self, params: &Value) -> Result<()> {
        // 首先验证整体结构
        let result = self.schema.validate(params);

        if let Err(errors) = result {
            let error_messages: Vec<String> = errors
                .map(|e| format!("Path '{}': {}", e.instance_path, e))
                .collect();

            return Err(Error::InvalidInput(format!(
                "参数验证失败: {}",
                error_messages.join("; ")
            )));
        }

        // 针对特定字段进行进一步验证
        for (field, validator) in &self.static_validators {
            if let Some(value) = params.get(field) {
                let result = validator.validate(value);

                if let Err(errors) = result {
                    let error_messages: Vec<String> = errors.map(|e| format!("{e}")).collect();

                    return Err(Error::InvalidInput(format!(
                        "字段 '{}' 验证失败: {}",
                        field,
                        error_messages.join("; ")
                    )));
                }
            }
        }

        Ok(())
    }
}