Skip to main content

converge_provider/tools/
registry.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Tool registry for unified tool discovery and invocation.
5
6use super::{ToolCall, ToolDefinition, ToolError, ToolResult, ToolSource};
7use std::collections::HashMap;
8
9/// Registry for tool discovery and invocation.
10#[derive(Debug, Default)]
11pub struct ToolRegistry {
12    tools: HashMap<String, ToolDefinition>,
13}
14
15impl Clone for ToolRegistry {
16    fn clone(&self) -> Self {
17        Self {
18            tools: self.tools.clone(),
19        }
20    }
21}
22
23impl ToolRegistry {
24    #[must_use]
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    pub fn register(&mut self, tool: ToolDefinition) {
30        self.tools.insert(tool.name.clone(), tool);
31    }
32
33    pub fn register_all(&mut self, tools: impl IntoIterator<Item = ToolDefinition>) {
34        for tool in tools {
35            self.register(tool);
36        }
37    }
38
39    #[must_use]
40    pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
41        self.tools.get(name)
42    }
43
44    #[must_use]
45    pub fn contains(&self, name: &str) -> bool {
46        self.tools.contains_key(name)
47    }
48
49    #[must_use]
50    pub fn list_tools(&self) -> Vec<&ToolDefinition> {
51        self.tools.values().collect()
52    }
53
54    #[must_use]
55    pub fn len(&self) -> usize {
56        self.tools.len()
57    }
58
59    #[must_use]
60    pub fn is_empty(&self) -> bool {
61        self.tools.is_empty()
62    }
63
64    #[must_use]
65    pub fn tools_by_source(&self, filter: SourceFilter) -> Vec<&ToolDefinition> {
66        self.tools
67            .values()
68            .filter(|t| filter.matches(&t.source))
69            .collect()
70    }
71
72    pub fn call_tool(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
73        let tool = self
74            .get(&call.tool_name)
75            .ok_or_else(|| ToolError::not_found(&call.tool_name))?;
76
77        match &tool.source {
78            ToolSource::Inline => Err(ToolError::unsupported_source("inline")),
79            ToolSource::Mcp { .. } => Err(ToolError::unsupported_source("mcp (use McpClient)")),
80            ToolSource::OpenApi { .. } => Err(ToolError::unsupported_source("openapi")),
81            ToolSource::GraphQl { .. } => Err(ToolError::unsupported_source("graphql")),
82        }
83    }
84
85    #[must_use]
86    pub fn to_llm_tools(&self) -> Vec<serde_json::Value> {
87        self.tools
88            .values()
89            .map(|tool| {
90                serde_json::json!({
91                    "type": "function",
92                    "function": {
93                        "name": tool.name,
94                        "description": tool.description,
95                        "parameters": tool.input_schema.schema
96                    }
97                })
98            })
99            .collect()
100    }
101
102    #[must_use]
103    pub fn to_anthropic_tools(&self) -> Vec<serde_json::Value> {
104        self.tools
105            .values()
106            .map(|tool| {
107                serde_json::json!({
108                    "name": tool.name,
109                    "description": tool.description,
110                    "input_schema": tool.input_schema.schema
111                })
112            })
113            .collect()
114    }
115}
116
/// Filter for selecting tools by source type.
///
/// Passed to `ToolRegistry::tools_by_source`. The default,
/// [`SourceFilter::All`], accepts every tool.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SourceFilter {
    /// Accepts tools from any source.
    #[default]
    All,
    /// Accepts only tools whose source is `ToolSource::Mcp`.
    Mcp,
    /// Accepts only tools whose source is `ToolSource::OpenApi`.
    OpenApi,
    /// Accepts only tools whose source is `ToolSource::GraphQl`.
    GraphQl,
    /// Accepts only tools whose source is `ToolSource::Inline`.
    Inline,
}
127
128impl SourceFilter {
129    #[must_use]
130    pub fn matches(&self, source: &ToolSource) -> bool {
131        match self {
132            Self::All => true,
133            Self::Mcp => matches!(source, ToolSource::Mcp { .. }),
134            Self::OpenApi => matches!(source, ToolSource::OpenApi { .. }),
135            Self::GraphQl => matches!(source, ToolSource::GraphQl { .. }),
136            Self::Inline => matches!(source, ToolSource::Inline),
137        }
138    }
139}
140
141/// Trait for tool execution handlers.
142pub trait ToolHandler: std::fmt::Debug + Send + Sync {
143    fn can_handle(&self, tool: &ToolDefinition) -> bool;
144    fn execute(&self, tool: &ToolDefinition, call: &ToolCall) -> Result<ToolResult, ToolError>;
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::InputSchema;

    #[test]
    fn test_registry_operations() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(ToolDefinition::new("test", "Test", InputSchema::empty()));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test"));
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_get_hit_and_miss() {
        let mut registry = ToolRegistry::new();
        registry.register(ToolDefinition::new("test", "Test", InputSchema::empty()));

        assert_eq!(registry.get("test").map(|t| t.name.as_str()), Some("test"));
        assert!(registry.get("missing").is_none());
        assert!(!registry.contains("missing"));
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register_all(vec![
            ToolDefinition::new("test", "First", InputSchema::empty()),
            ToolDefinition::new("test", "Second", InputSchema::empty()),
        ]);
        // Names are registry keys, so the second definition overwrites the first.
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test"));
    }

    #[test]
    fn test_filter_clone_and_exports() {
        let mut registry = ToolRegistry::new();
        registry.register_all(vec![
            ToolDefinition::new("a", "A", InputSchema::empty()),
            ToolDefinition::new("b", "B", InputSchema::empty()),
        ]);

        assert_eq!(registry.list_tools().len(), 2);
        assert_eq!(registry.tools_by_source(SourceFilter::All).len(), 2);
        assert_eq!(registry.clone().len(), 2);

        let llm = registry.to_llm_tools();
        assert_eq!(llm.len(), 2);
        assert!(llm.iter().all(|t| t["type"] == "function"));

        // Compare as a sorted set so the assertion does not depend on export order.
        let mut names: Vec<String> = registry
            .to_anthropic_tools()
            .iter()
            .map(|t| t["name"].as_str().expect("name is a string").to_owned())
            .collect();
        names.sort();
        assert_eq!(names, ["a", "b"]);
    }
}