use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use indexmap::IndexMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::{trace, warn};
use uuid::Uuid;
use crate::error::{FastMcpError, Result};
/// Policy applied by `ToolManager::register` when a tool with the same name is
/// already registered.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DuplicateBehavior {
/// Reject the new registration with `FastMcpError::DuplicateTool` (the default).
#[default]
Error,
/// Overwrite the existing definition with the new one (logged at trace level).
Replace,
/// Keep the existing definition and drop the new one (logged at trace level);
/// registration still returns `Ok`.
Ignore,
/// Overwrite the existing definition, like `Replace`, but log at warn level.
Warn,
}
/// Per-call context passed to a tool handler alongside its arguments.
#[derive(Clone, Debug)]
pub struct InvocationContext {
/// Name of the tool this invocation targets.
pub tool_name: String,
/// Identifier for this invocation; `InvocationContext::new` assigns a fresh v4 UUID.
pub request_id: Uuid,
/// UTC timestamp for the invocation; `InvocationContext::new` records the current time.
pub timestamp: DateTime<Utc>,
/// Free-form JSON metadata; starts empty when built via `InvocationContext::new`.
pub metadata: Map<String, Value>,
}
impl InvocationContext {
    /// Builds a context for a call to `tool_name`.
    ///
    /// The request id is a freshly generated v4 UUID, the timestamp is the
    /// current UTC time, and the metadata map starts out empty.
    pub fn new(tool_name: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        let request_id = Uuid::new_v4();
        let timestamp = Utc::now();
        Self {
            tool_name,
            request_id,
            timestamp,
            metadata: Map::new(),
        }
    }
}
/// JSON object of annotations attached to tool definitions and tool responses.
pub type ToolAnnotations = Map<String, Value>;
/// Predicate used by the `skip_serializing_if` serde attributes in this file:
/// returns `true` when the annotation map has no entries, so the field is
/// left out of the serialized output.
fn annotations_is_empty(annotations: &ToolAnnotations) -> bool {
annotations.is_empty()
}
/// Value returned by a tool handler.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolResponse {
/// Content items produced by the tool, each an arbitrary JSON value.
pub content: Vec<Value>,
/// Annotations attached to the response. Defaults to an empty map when the
/// field is absent during deserialization, and is omitted from serialized
/// output when empty.
#[serde(default, skip_serializing_if = "annotations_is_empty")]
pub annotations: ToolAnnotations,
}
impl ToolResponse {
    /// Wraps `content` in a response that carries no annotations.
    pub fn new(content: Vec<Value>) -> Self {
        let annotations = ToolAnnotations::default();
        Self {
            content,
            annotations,
        }
    }

    /// Returns the response with its annotations replaced by `annotations`.
    ///
    /// Any annotations previously set are discarded, not merged.
    pub fn with_annotations(self, annotations: ToolAnnotations) -> Self {
        Self {
            annotations,
            ..self
        }
    }
}
/// Serializable description of a tool, i.e. a `ToolDefinition` without its handler.
///
/// Produced by `ToolDefinition::metadata` and returned from `ToolManager::list`
/// and `ToolManager::get`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolDefinitionMetadata {
/// Name the tool is registered under.
pub name: String,
/// Optional description; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional summary; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub summary: Option<String>,
/// Optional JSON value describing the tool's parameters; omitted when `None`.
/// NOTE(review): presumably a JSON Schema — nothing in this file validates it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<Value>,
/// Annotations for the tool; defaults to empty on deserialization and is
/// omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "annotations_is_empty")]
pub annotations: ToolAnnotations,
}
/// Behaviour executed when a tool is called.
///
/// Implementors must be `Send + Sync`; `ToolDefinition` stores them as
/// `Arc<dyn ToolInvocation>`. A blanket implementation in this file covers
/// `Fn(InvocationContext, Value) -> Fut` callables whose future resolves to
/// `Result<ToolResponse>`.
#[async_trait]
pub trait ToolInvocation: Send + Sync {
/// Runs the tool with the given invocation context and JSON arguments,
/// returning its response or an error.
async fn invoke(&self, ctx: InvocationContext, arguments: Value) -> Result<ToolResponse>;
}
/// Lets any thread-safe callable of the shape
/// `Fn(InvocationContext, Value) -> impl Future<Output = Result<ToolResponse>>`
/// act as a tool handler, so closures and async functions can be passed
/// wherever a `ToolInvocation` is expected.
#[async_trait]
impl<F, Fut> ToolInvocation for F
where
    Fut: std::future::Future<Output = Result<ToolResponse>> + Send,
    F: Fn(InvocationContext, Value) -> Fut + Send + Sync,
{
    /// Calls the wrapped callable and awaits the future it returns.
    async fn invoke(&self, ctx: InvocationContext, arguments: Value) -> Result<ToolResponse> {
        let pending = (self)(ctx, arguments);
        pending.await
    }
}
/// A tool that can be registered with a `ToolManager`: descriptive metadata
/// plus the handler that runs when the tool is called.
pub struct ToolDefinition {
    /// Name the tool is registered and looked up under.
    pub name: String,
    /// Optional description of the tool.
    pub description: Option<String>,
    /// Optional summary of the tool.
    pub summary: Option<String>,
    /// Optional JSON value describing the tool's parameters.
    /// NOTE(review): presumably a JSON Schema — nothing in this file validates it.
    pub parameters: Option<Value>,
    /// Annotations attached to the tool; empty by default.
    pub annotations: ToolAnnotations,
    /// Shared handler invoked by `ToolManager::call`.
    handler: Arc<dyn ToolInvocation>,
}

// `Debug` cannot be derived because `dyn ToolInvocation` has no `Debug` bound
// (handlers are frequently closures). Implement it by hand so the type can be
// logged and embedded in `#[derive(Debug)]` types; the handler is elided and
// signalled by the trailing `..` that `finish_non_exhaustive` emits.
impl std::fmt::Debug for ToolDefinition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ToolDefinition")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("summary", &self.summary)
            .field("parameters", &self.parameters)
            .field("annotations", &self.annotations)
            .finish_non_exhaustive()
    }
}
impl ToolDefinition {
    /// Creates a definition named `name` whose calls are served by `handler`.
    ///
    /// Description, summary and parameters start as `None`; annotations start
    /// empty. Use the `with_*` builders to fill them in.
    pub fn new(name: impl Into<String>, handler: impl ToolInvocation + 'static) -> Self {
        let handler: Arc<dyn ToolInvocation> = Arc::new(handler);
        Self {
            name: name.into(),
            description: None,
            summary: None,
            parameters: None,
            annotations: ToolAnnotations::default(),
            handler,
        }
    }

    /// Sets the description, replacing any previous one.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Sets the summary, replacing any previous one.
    pub fn with_summary(self, summary: impl Into<String>) -> Self {
        Self {
            summary: Some(summary.into()),
            ..self
        }
    }

    /// Sets the parameters value, replacing any previous one.
    pub fn with_parameters(self, parameters: Value) -> Self {
        Self {
            parameters: Some(parameters),
            ..self
        }
    }

    /// Replaces the annotation map wholesale (no merging with existing entries).
    pub fn with_annotations(self, annotations: ToolAnnotations) -> Self {
        Self {
            annotations,
            ..self
        }
    }

    /// Clones the descriptive fields into a handler-free `ToolDefinitionMetadata`.
    pub(crate) fn metadata(&self) -> ToolDefinitionMetadata {
        let Self {
            name,
            description,
            summary,
            parameters,
            annotations,
            ..
        } = self;
        ToolDefinitionMetadata {
            name: name.clone(),
            description: description.clone(),
            summary: summary.clone(),
            parameters: parameters.clone(),
            annotations: annotations.clone(),
        }
    }

    /// Returns another shared reference to the handler (a reference-count
    /// bump, not a copy of the handler itself).
    pub(crate) fn handler(&self) -> Arc<dyn ToolInvocation> {
        let shared = &self.handler;
        Arc::clone(shared)
    }
}
/// Registry of tools keyed by name.
///
/// All methods take `&self`; the tool map lives behind an `RwLock`, so
/// registration and lookup go through that lock.
pub struct ToolManager {
// How `register` reacts when a tool name is already present.
duplicate_behavior: DuplicateBehavior,
// Registered tools by name, stored in an `IndexMap`; `list` walks it in the
// map's iteration order.
tools: RwLock<IndexMap<String, Arc<ToolDefinition>>>,
}
impl ToolManager {
pub fn new(duplicate_behavior: DuplicateBehavior) -> Self {
Self {
duplicate_behavior,
tools: RwLock::new(IndexMap::new()),
}
}
pub fn len(&self) -> usize {
self.tools.read().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn register(&self, tool: ToolDefinition) -> Result<()> {
let mut guard = self.tools.write();
match guard.get_mut(&tool.name) {
Some(existing) => match self.duplicate_behavior {
DuplicateBehavior::Error => {
return Err(FastMcpError::DuplicateTool(tool.name));
}
DuplicateBehavior::Ignore => {
trace!("Ignoring duplicate registration for tool {}", tool.name);
}
DuplicateBehavior::Replace => {
trace!("Replacing tool {}", tool.name);
*existing = Arc::new(tool);
}
DuplicateBehavior::Warn => {
warn!("Replacing duplicate tool {}", tool.name);
*existing = Arc::new(tool);
}
},
None => {
guard.insert(tool.name.clone(), Arc::new(tool));
}
}
Ok(())
}
pub fn list(&self) -> Vec<ToolDefinitionMetadata> {
self.tools
.read()
.values()
.map(|tool| tool.metadata())
.collect()
}
pub fn get(&self, name: &str) -> Option<ToolDefinitionMetadata> {
self.tools.read().get(name).map(|tool| tool.metadata())
}
pub fn contains(&self, name: &str) -> bool {
self.tools.read().contains_key(name)
}
pub async fn call(&self, name: &str, arguments: Value) -> Result<ToolResponse> {
let tool = {
let guard = self.tools.read();
guard
.get(name)
.cloned()
.ok_or_else(|| FastMcpError::ToolNotFound(name.to_string()))?
};
let ctx = InvocationContext::new(name.to_string());
tool.handler().invoke(ctx, arguments).await
}
}
/// Function pointer that builds a `ToolDefinition`; the element type of
/// `MCP_TOOL_FACTORIES`. Only compiled with the `auto-register` feature.
#[cfg(feature = "auto-register")]
pub type ToolFactory = fn() -> ToolDefinition;
/// `linkme` distributed slice of `ToolFactory` entries; iterated by
/// `register_discovered_tools`. Only compiled with the `auto-register` feature.
#[cfg(feature = "auto-register")]
#[linkme::distributed_slice]
pub static MCP_TOOL_FACTORIES: [ToolFactory];
/// Runs every factory in `MCP_TOOL_FACTORIES` and registers the resulting
/// tool with `server`.
///
/// Registration is best-effort: a failure is logged at warn level and the
/// remaining factories are still processed.
#[cfg(feature = "auto-register")]
pub fn register_discovered_tools(server: &crate::server::FastMcpServer) {
    MCP_TOOL_FACTORIES
        .iter()
        .map(|factory| server.register_tool(factory()))
        .filter_map(|outcome| outcome.err())
        .for_each(|e| warn!("Auto-register tool failed: {}", e));
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Registers a closure-backed tool, calls it by name, and checks that the
    /// closure saw the supplied arguments.
    #[tokio::test]
    async fn registers_and_invokes_tool() {
        let registry = ToolManager::new(DuplicateBehavior::Error);

        let greeter = ToolDefinition::new("greet", |_, payload: Value| async move {
            let name = payload
                .get("name")
                .and_then(Value::as_str)
                .unwrap_or("world");
            let line = json!({
                "type": "text",
                "text": format!("Hello, {name}!"),
            });
            Ok(ToolResponse::new(vec![line]))
        })
        .with_description("Greets a user");
        registry.register(greeter).unwrap();

        let reply = registry
            .call("greet", json!({ "name": "FastMCP" }))
            .await
            .unwrap();

        assert_eq!(reply.content.len(), 1);
        let text = reply.content[0]["text"].as_str();
        assert_eq!(text, Some("Hello, FastMCP!"));
    }
}