converge_provider/tools/
registry.rs1use super::{ToolCall, ToolDefinition, ToolError, ToolResult, ToolSource};
7use std::collections::HashMap;
8
9#[derive(Debug, Default)]
11pub struct ToolRegistry {
12 tools: HashMap<String, ToolDefinition>,
13}
14
15impl Clone for ToolRegistry {
16 fn clone(&self) -> Self {
17 Self { tools: self.tools.clone() }
18 }
19}
20
21impl ToolRegistry {
22 #[must_use]
23 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn register(&mut self, tool: ToolDefinition) {
28 self.tools.insert(tool.name.clone(), tool);
29 }
30
31 pub fn register_all(&mut self, tools: impl IntoIterator<Item = ToolDefinition>) {
32 for tool in tools {
33 self.register(tool);
34 }
35 }
36
37 #[must_use]
38 pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
39 self.tools.get(name)
40 }
41
42 #[must_use]
43 pub fn contains(&self, name: &str) -> bool {
44 self.tools.contains_key(name)
45 }
46
47 #[must_use]
48 pub fn list_tools(&self) -> Vec<&ToolDefinition> {
49 self.tools.values().collect()
50 }
51
52 #[must_use]
53 pub fn len(&self) -> usize {
54 self.tools.len()
55 }
56
57 #[must_use]
58 pub fn is_empty(&self) -> bool {
59 self.tools.is_empty()
60 }
61
62 #[must_use]
63 pub fn tools_by_source(&self, filter: SourceFilter) -> Vec<&ToolDefinition> {
64 self.tools.values().filter(|t| filter.matches(&t.source)).collect()
65 }
66
67 pub fn call_tool(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
68 let tool = self.get(&call.tool_name)
69 .ok_or_else(|| ToolError::not_found(&call.tool_name))?;
70
71 match &tool.source {
72 ToolSource::Inline => Err(ToolError::unsupported_source("inline")),
73 ToolSource::Mcp { .. } => Err(ToolError::unsupported_source("mcp (use McpClient)")),
74 ToolSource::OpenApi { .. } => Err(ToolError::unsupported_source("openapi")),
75 ToolSource::GraphQl { .. } => Err(ToolError::unsupported_source("graphql")),
76 }
77 }
78
79 #[must_use]
80 pub fn to_llm_tools(&self) -> Vec<serde_json::Value> {
81 self.tools.values().map(|tool| {
82 serde_json::json!({
83 "type": "function",
84 "function": {
85 "name": tool.name,
86 "description": tool.description,
87 "parameters": tool.input_schema.schema
88 }
89 })
90 }).collect()
91 }
92
93 #[must_use]
94 pub fn to_anthropic_tools(&self) -> Vec<serde_json::Value> {
95 self.tools.values().map(|tool| {
96 serde_json::json!({
97 "name": tool.name,
98 "description": tool.description,
99 "input_schema": tool.input_schema.schema
100 })
101 }).collect()
102 }
103}
104
/// Selects which [`ToolSource`] kinds a registry query returns.
///
/// `All` (the default) accepts every source. Each other variant accepts only
/// tools whose source is the `ToolSource` variant of the same name.
// `PartialEq`/`Eq`/`Hash` are derived so filters can be compared and used as map/set keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SourceFilter {
    /// Accept tools from every source.
    #[default]
    All,
    /// Accept only `ToolSource::Mcp` tools.
    Mcp,
    /// Accept only `ToolSource::OpenApi` tools.
    OpenApi,
    /// Accept only `ToolSource::GraphQl` tools.
    GraphQl,
    /// Accept only `ToolSource::Inline` tools.
    Inline,
}
115
116impl SourceFilter {
117 #[must_use]
118 pub fn matches(&self, source: &ToolSource) -> bool {
119 match self {
120 Self::All => true,
121 Self::Mcp => matches!(source, ToolSource::Mcp { .. }),
122 Self::OpenApi => matches!(source, ToolSource::OpenApi { .. }),
123 Self::GraphQl => matches!(source, ToolSource::GraphQl { .. }),
124 Self::Inline => matches!(source, ToolSource::Inline),
125 }
126 }
127}
128
129pub trait ToolHandler: std::fmt::Debug + Send + Sync {
131 fn can_handle(&self, tool: &ToolDefinition) -> bool;
132 fn execute(&self, tool: &ToolDefinition, call: &ToolCall) -> Result<ToolResult, ToolError>;
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::InputSchema;

    #[test]
    fn test_registry_operations() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(ToolDefinition::new("test", "Test", InputSchema::empty()));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.contains("test"));
        assert!(registry.get("test").is_some());

        // Lookups for unregistered names must miss rather than panic.
        assert!(!registry.contains("missing"));
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_register_same_name_replaces_previous() {
        let mut registry = ToolRegistry::new();
        registry.register(ToolDefinition::new("dup", "First", InputSchema::empty()));
        registry.register(ToolDefinition::new("dup", "Second", InputSchema::empty()));

        assert_eq!(registry.len(), 1);
        let tools = registry.to_anthropic_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "dup");
        assert_eq!(tools[0]["description"], "Second");
    }

    #[test]
    fn test_register_all_and_clone_independence() {
        let mut registry = ToolRegistry::new();
        registry.register_all(vec![
            ToolDefinition::new("a", "A", InputSchema::empty()),
            ToolDefinition::new("b", "B", InputSchema::empty()),
        ]);
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.list_tools().len(), 2);
        assert_eq!(registry.tools_by_source(SourceFilter::All).len(), 2);

        // A clone is a snapshot: later registrations must not leak into it.
        let snapshot = registry.clone();
        registry.register(ToolDefinition::new("c", "C", InputSchema::empty()));
        assert_eq!(registry.len(), 3);
        assert_eq!(snapshot.len(), 2);
        assert!(!snapshot.contains("c"));
    }

    #[test]
    fn test_serialized_tool_shapes() {
        let mut registry = ToolRegistry::new();
        registry.register(ToolDefinition::new("only", "Only tool", InputSchema::empty()));

        let llm = registry.to_llm_tools();
        assert_eq!(llm.len(), 1);
        assert_eq!(llm[0]["type"], "function");
        assert_eq!(llm[0]["function"]["name"], "only");

        let anthropic = registry.to_anthropic_tools();
        assert_eq!(anthropic.len(), 1);
        assert_eq!(anthropic[0]["name"], "only");
        assert!(anthropic[0].get("input_schema").is_some());
    }
}