forge_runtime/mcp/
registry.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::{ForgeMcpTool, McpToolContext, McpToolInfo, Result};
7use serde_json::Value;
8
9fn normalize_args(args: Value) -> Value {
10 let unwrapped = match args {
11 Value::Object(map) if map.len() == 1 => map
12 .get("args")
13 .or_else(|| map.get("input"))
14 .cloned()
15 .unwrap_or(Value::Object(map)),
16 other => other,
17 };
18
19 match unwrapped {
20 Value::Null => Value::Object(serde_json::Map::new()),
21 other => other,
22 }
23}
24
25pub type BoxedMcpToolFn = Arc<
26 dyn Fn(&McpToolContext, Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>>
27 + Send
28 + Sync,
29>;
30
31pub struct McpToolEntry {
33 pub info: McpToolInfo,
34 pub input_schema: Value,
35 pub output_schema: Option<Value>,
36 pub handler: BoxedMcpToolFn,
37}
38
39#[derive(Clone, Default)]
41pub struct McpToolRegistry {
42 tools: HashMap<String, Arc<McpToolEntry>>,
43}
44
45impl McpToolRegistry {
46 pub fn new() -> Self {
47 Self {
48 tools: HashMap::new(),
49 }
50 }
51
52 pub fn register<T: ForgeMcpTool>(&mut self)
53 where
54 T::Args: serde::de::DeserializeOwned + Send + 'static,
55 T::Output: serde::Serialize + Send + 'static,
56 {
57 let info = T::info();
58 if let Err(e) = info.validate() {
59 tracing::error!(error = %e, "Skipping invalid MCP tool registration");
60 return;
61 }
62
63 let name = info.name.to_string();
64 let input_schema = T::input_schema();
65 let output_schema = T::output_schema();
66
67 let handler: BoxedMcpToolFn = Arc::new(move |ctx, args| {
68 Box::pin(async move {
69 let parsed_args: T::Args = serde_json::from_value(normalize_args(args))
70 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
71 let result = T::execute(ctx, parsed_args).await?;
72 serde_json::to_value(result)
73 .map_err(|e| forge_core::ForgeError::Internal(e.to_string()))
74 })
75 });
76
77 self.tools.insert(
78 name,
79 Arc::new(McpToolEntry {
80 info,
81 input_schema,
82 output_schema,
83 handler,
84 }),
85 );
86 }
87
88 pub fn get(&self, name: &str) -> Option<Arc<McpToolEntry>> {
89 self.tools.get(name).cloned()
90 }
91
92 pub fn list(&self) -> impl Iterator<Item = Arc<McpToolEntry>> + '_ {
93 self.tools.values().cloned()
94 }
95
96 pub fn len(&self) -> usize {
97 self.tools.len()
98 }
99
100 pub fn is_empty(&self) -> bool {
101 self.tools.is_empty()
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// A registry built with `new` starts out holding no tools.
    #[test]
    fn test_registry_defaults() {
        let empty = McpToolRegistry::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}