use super::{McpMiddleware, MiddlewareContext, MiddlewareResult};
use crate::mcp::{CallToolRequest, McpError, McpResult};
/// Middleware that rejects malformed tool-call requests before they reach a handler.
///
/// Tool names are always checked. When `strict_mode` is enabled, the shape of the
/// request's arguments is checked as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValidationMiddleware {
    /// When `true`, any arguments present on a request must be a JSON object.
    strict_mode: bool,
}
impl ValidationMiddleware {
    /// Creates a validation middleware.
    ///
    /// With `strict_mode` set to `true`, request arguments (when present) must also be
    /// a JSON object; otherwise only the tool name is validated.
    #[must_use]
    pub fn new(strict_mode: bool) -> Self {
        Self { strict_mode }
    }

    /// Creates a middleware that also validates the shape of request arguments.
    #[must_use]
    pub fn strict() -> Self {
        Self::new(true)
    }

    /// Creates a middleware that validates only the tool name.
    #[must_use]
    pub fn lenient() -> Self {
        Self::new(false)
    }

    /// Checks `request` against this middleware's rules.
    ///
    /// # Errors
    ///
    /// Returns a validation error when:
    /// - the tool name is empty;
    /// - the tool name contains anything other than ASCII letters, ASCII digits, or `_`;
    /// - strict mode is on and arguments are present but are not a JSON object.
    fn validate_request(&self, request: &CallToolRequest) -> McpResult<()> {
        if request.name.is_empty() {
            return Err(McpError::validation_error("Tool name cannot be empty"));
        }

        // Restrict to ASCII: `char::is_alphanumeric` accepts every Unicode letter or digit,
        // which would let a visually identical homoglyph (e.g. Cyrillic 'а' for Latin 'a')
        // impersonate a legitimate tool name.
        let name_is_valid = request
            .name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_');
        if !name_is_valid {
            return Err(McpError::validation_error(
                "Tool name must contain only alphanumeric characters and underscores",
            ));
        }

        // Missing arguments are acceptable even in strict mode; only a present,
        // non-object value is rejected.
        if self.strict_mode && matches!(&request.arguments, Some(args) if !args.is_object()) {
            return Err(McpError::validation_error(
                "Arguments must be a JSON object",
            ));
        }

        Ok(())
    }
}
#[async_trait::async_trait]
impl McpMiddleware for ValidationMiddleware {
fn name(&self) -> &'static str {
"validation"
}
fn priority(&self) -> i32 {
50 }
async fn before_request(
&self,
request: &CallToolRequest,
context: &mut MiddlewareContext,
) -> McpResult<MiddlewareResult> {
if let Err(error) = self.validate_request(request) {
context.set_metadata(
"validation_error".to_string(),
serde_json::Value::String(error.to_string()),
);
return Ok(MiddlewareResult::Error(error));
}
context.set_metadata("validated".to_string(), serde_json::Value::Bool(true));
Ok(MiddlewareResult::Continue)
}
}