// converge_provider/tools/registry.rs
//
// Copyright 2024-2025 Aprio One AB, Sweden
// SPDX-License-Identifier: MIT

//! Tool registry for unified tool discovery and invocation.
6use super::{ToolCall, ToolDefinition, ToolError, ToolResult, ToolSource};
7use std::collections::HashMap;
8
9/// Registry for tool discovery and invocation.
10#[derive(Debug, Default)]
11pub struct ToolRegistry {
12    tools: HashMap<String, ToolDefinition>,
13}
14
15impl Clone for ToolRegistry {
16    fn clone(&self) -> Self {
17        Self { tools: self.tools.clone() }
18    }
19}
20
21impl ToolRegistry {
22    #[must_use]
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    pub fn register(&mut self, tool: ToolDefinition) {
28        self.tools.insert(tool.name.clone(), tool);
29    }
30
31    pub fn register_all(&mut self, tools: impl IntoIterator<Item = ToolDefinition>) {
32        for tool in tools {
33            self.register(tool);
34        }
35    }
36
37    #[must_use]
38    pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
39        self.tools.get(name)
40    }
41
42    #[must_use]
43    pub fn contains(&self, name: &str) -> bool {
44        self.tools.contains_key(name)
45    }
46
47    #[must_use]
48    pub fn list_tools(&self) -> Vec<&ToolDefinition> {
49        self.tools.values().collect()
50    }
51
52    #[must_use]
53    pub fn len(&self) -> usize {
54        self.tools.len()
55    }
56
57    #[must_use]
58    pub fn is_empty(&self) -> bool {
59        self.tools.is_empty()
60    }
61
62    #[must_use]
63    pub fn tools_by_source(&self, filter: SourceFilter) -> Vec<&ToolDefinition> {
64        self.tools.values().filter(|t| filter.matches(&t.source)).collect()
65    }
66
67    pub fn call_tool(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
68        let tool = self.get(&call.tool_name)
69            .ok_or_else(|| ToolError::not_found(&call.tool_name))?;
70
71        match &tool.source {
72            ToolSource::Inline => Err(ToolError::unsupported_source("inline")),
73            ToolSource::Mcp { .. } => Err(ToolError::unsupported_source("mcp (use McpClient)")),
74            ToolSource::OpenApi { .. } => Err(ToolError::unsupported_source("openapi")),
75            ToolSource::GraphQl { .. } => Err(ToolError::unsupported_source("graphql")),
76        }
77    }
78
79    #[must_use]
80    pub fn to_llm_tools(&self) -> Vec<serde_json::Value> {
81        self.tools.values().map(|tool| {
82            serde_json::json!({
83                "type": "function",
84                "function": {
85                    "name": tool.name,
86                    "description": tool.description,
87                    "parameters": tool.input_schema.schema
88                }
89            })
90        }).collect()
91    }
92
93    #[must_use]
94    pub fn to_anthropic_tools(&self) -> Vec<serde_json::Value> {
95        self.tools.values().map(|tool| {
96            serde_json::json!({
97                "name": tool.name,
98                "description": tool.description,
99                "input_schema": tool.input_schema.schema
100            })
101        }).collect()
102    }
103}
104
105/// Filter for selecting tools by source type.
106#[derive(Debug, Clone, Copy, Default)]
107pub enum SourceFilter {
108    #[default]
109    All,
110    Mcp,
111    OpenApi,
112    GraphQl,
113    Inline,
114}
115
116impl SourceFilter {
117    #[must_use]
118    pub fn matches(&self, source: &ToolSource) -> bool {
119        match self {
120            Self::All => true,
121            Self::Mcp => matches!(source, ToolSource::Mcp { .. }),
122            Self::OpenApi => matches!(source, ToolSource::OpenApi { .. }),
123            Self::GraphQl => matches!(source, ToolSource::GraphQl { .. }),
124            Self::Inline => matches!(source, ToolSource::Inline),
125        }
126    }
127}
128
129/// Trait for tool execution handlers.
130pub trait ToolHandler: std::fmt::Debug + Send + Sync {
131    fn can_handle(&self, tool: &ToolDefinition) -> bool;
132    fn execute(&self, tool: &ToolDefinition, call: &ToolCall) -> Result<ToolResult, ToolError>;
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::InputSchema;

    #[test]
    fn test_registry_operations() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert!(registry.get("test").is_none());

        registry.register(ToolDefinition::new("test", "Test", InputSchema::empty()));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.contains("test"));
        assert!(!registry.contains("missing"));
        assert_eq!(registry.get("test").map(|t| t.name.as_str()), Some("test"));
    }

    #[test]
    fn test_register_replaces_tool_with_same_name() {
        let mut registry = ToolRegistry::new();
        registry.register(ToolDefinition::new("test", "First", InputSchema::empty()));
        registry.register(ToolDefinition::new("test", "Second", InputSchema::empty()));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test"));
    }

    #[test]
    fn test_register_all_listing_and_exports() {
        let mut registry = ToolRegistry::new();
        registry.register_all(vec![
            ToolDefinition::new("alpha", "A", InputSchema::empty()),
            ToolDefinition::new("beta", "B", InputSchema::empty()),
        ]);
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.list_tools().len(), 2);
        assert_eq!(registry.tools_by_source(SourceFilter::All).len(), 2);
        assert_eq!(registry.clone().len(), 2);

        let anthropic = registry.to_anthropic_tools();
        assert_eq!(anthropic.len(), 2);
        assert!(anthropic.iter().any(|t| t["name"] == "alpha"));
        assert!(anthropic.iter().any(|t| t["name"] == "beta"));

        let llm = registry.to_llm_tools();
        assert_eq!(llm.len(), 2);
        assert!(llm.iter().all(|t| t["type"] == "function"));
        assert!(llm.iter().any(|t| t["function"]["name"] == "alpha"));
    }
}