// forge_runtime/mcp/registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{ForgeMcpTool, McpToolContext, McpToolInfo, Result};
7use serde_json::Value;
8
9fn normalize_args(args: Value) -> Value {
10    let unwrapped = match args {
11        Value::Object(map) if map.len() == 1 => map
12            .get("args")
13            .or_else(|| map.get("input"))
14            .cloned()
15            .unwrap_or(Value::Object(map)),
16        other => other,
17    };
18
19    match unwrapped {
20        Value::Null => Value::Object(serde_json::Map::new()),
21        other => other,
22    }
23}
24
25pub type BoxedMcpToolFn = Arc<
26    dyn Fn(&McpToolContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
27        + Send
28        + Sync,
29>;
30
31/// A registered MCP tool entry.
32pub struct McpToolEntry {
33    pub info: McpToolInfo,
34    pub input_schema: Value,
35    pub output_schema: Option<Value>,
36    pub handler: BoxedMcpToolFn,
37}
38
39/// Registry of MCP tools.
40#[derive(Clone, Default)]
41pub struct McpToolRegistry {
42    tools: HashMap<String, Arc<McpToolEntry>>,
43}
44
45impl McpToolRegistry {
46    pub fn new() -> Self {
47        Self {
48            tools: HashMap::new(),
49        }
50    }
51
52    pub fn register<T: ForgeMcpTool>(&mut self)
53    where
54        T::Args: serde::de::DeserializeOwned + Send + 'static,
55        T::Output: serde::Serialize + Send + 'static,
56    {
57        let info = T::info();
58        if let Err(e) = info.validate() {
59            tracing::error!(error = %e, "Skipping invalid MCP tool registration");
60            return;
61        }
62
63        let name = info.name.to_string();
64        let input_schema = T::input_schema();
65        let output_schema = T::output_schema();
66
67        let handler: BoxedMcpToolFn = Arc::new(move |ctx, args| {
68            Box::pin(async move {
69                let parsed_args: T::Args = serde_json::from_value(normalize_args(args))
70                    .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
71                let result = T::execute(ctx, parsed_args).await?;
72                serde_json::to_value(result)
73                    .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
74            })
75        });
76
77        self.tools.insert(
78            name,
79            Arc::new(McpToolEntry {
80                info,
81                input_schema,
82                output_schema,
83                handler,
84            }),
85        );
86    }
87
88    pub fn get(&self, name: &str) -> Option<Arc<McpToolEntry>> {
89        self.tools.get(name).cloned()
90    }
91
92    pub fn list(&self) -> impl Iterator<Item = Arc<McpToolEntry>> + '_ {
93        self.tools.values().cloned()
94    }
95
96    pub fn len(&self) -> usize {
97        self.tools.len()
98    }
99
100    pub fn is_empty(&self) -> bool {
101        self.tools.is_empty()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// A registry built with `new` holds no tools.
    #[test]
    fn test_registry_defaults() {
        let empty = McpToolRegistry::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}