Skip to main content

ai_agents_tools/
provider.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::sync::Arc;
5
6use super::types::{ToolAliases, ToolMetadata, ToolProviderType};
7use super::{Tool, ToolResult};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ToolDescriptor {
11    pub id: String,
12    pub name: String,
13    pub description: String,
14    pub input_schema: Value,
15    #[serde(default, skip_serializing_if = "Option::is_none")]
16    pub aliases: Option<ToolAliases>,
17    #[serde(default, skip_serializing_if = "Option::is_none")]
18    pub metadata: Option<ToolMetadata>,
19}
20
21impl ToolDescriptor {
22    pub fn new(
23        id: impl Into<String>,
24        name: impl Into<String>,
25        description: impl Into<String>,
26        input_schema: Value,
27    ) -> Self {
28        Self {
29            id: id.into(),
30            name: name.into(),
31            description: description.into(),
32            input_schema,
33            aliases: None,
34            metadata: None,
35        }
36    }
37
38    pub fn with_aliases(mut self, aliases: ToolAliases) -> Self {
39        self.aliases = Some(aliases);
40        self
41    }
42
43    pub fn with_metadata(mut self, metadata: ToolMetadata) -> Self {
44        self.metadata = Some(metadata);
45        self
46    }
47
48    pub fn get_name(&self, lang: Option<&str>) -> &str {
49        if let Some(lang) = lang {
50            if let Some(ref aliases) = self.aliases {
51                if let Some(name) = aliases.get_name(lang) {
52                    return name;
53                }
54            }
55        }
56        &self.name
57    }
58
59    pub fn get_description(&self, lang: Option<&str>) -> &str {
60        if let Some(lang) = lang {
61            if let Some(ref aliases) = self.aliases {
62                if let Some(desc) = aliases.get_description(lang) {
63                    return desc;
64                }
65            }
66        }
67        &self.description
68    }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, Default)]
72#[serde(tag = "status", rename_all = "snake_case")]
73pub enum ProviderHealth {
74    #[default]
75    Healthy,
76    Degraded {
77        message: String,
78    },
79    Unavailable {
80        message: String,
81    },
82}
83
84impl ProviderHealth {
85    pub fn is_healthy(&self) -> bool {
86        matches!(self, ProviderHealth::Healthy)
87    }
88
89    pub fn is_available(&self) -> bool {
90        !matches!(self, ProviderHealth::Unavailable { .. })
91    }
92
93    pub fn degraded(message: impl Into<String>) -> Self {
94        ProviderHealth::Degraded {
95            message: message.into(),
96        }
97    }
98
99    pub fn unavailable(message: impl Into<String>) -> Self {
100        ProviderHealth::Unavailable {
101            message: message.into(),
102        }
103    }
104}
105
106#[derive(Debug, thiserror::Error)]
107pub enum ToolProviderError {
108    #[error("Tool not found: {0}")]
109    ToolNotFound(String),
110
111    #[error("Execution failed: {0}")]
112    ExecutionFailed(String),
113
114    #[error("Provider unavailable: {0}")]
115    Unavailable(String),
116
117    #[error("Connection error: {0}")]
118    ConnectionError(String),
119
120    #[error("Configuration error: {0}")]
121    ConfigError(String),
122
123    #[error("Timeout after {0}ms")]
124    Timeout(u64),
125
126    #[error("{0}")]
127    Other(String),
128}
129
130#[async_trait]
131pub trait ToolProvider: Send + Sync {
132    fn id(&self) -> &str;
133
134    fn name(&self) -> &str;
135
136    fn provider_type(&self) -> ToolProviderType;
137
138    async fn list_tools(&self) -> Vec<ToolDescriptor>;
139
140    async fn get_tool(&self, tool_id: &str) -> Option<Arc<dyn Tool>>;
141
142    async fn execute(&self, tool_id: &str, args: Value) -> Result<ToolResult, ToolProviderError> {
143        if let Some(tool) = self.get_tool(tool_id).await {
144            Ok(tool.execute(args).await)
145        } else {
146            Err(ToolProviderError::ToolNotFound(tool_id.to_string()))
147        }
148    }
149
150    fn supports_refresh(&self) -> bool {
151        false
152    }
153
154    async fn refresh(&self) -> Result<(), ToolProviderError> {
155        Ok(())
156    }
157
158    async fn health_check(&self) -> ProviderHealth {
159        ProviderHealth::Healthy
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_descriptor() {
        let desc = ToolDescriptor::new(
            "search",
            "Web Search",
            "Search the web",
            serde_json::json!({"type": "object"}),
        );

        assert_eq!(desc.id, "search");
        assert_eq!(desc.get_name(None), "Web Search");
        assert_eq!(desc.get_description(None), "Search the web");
        // Without aliases, a requested language falls back to the defaults.
        assert_eq!(desc.get_name(Some("ko")), "Web Search");
        assert_eq!(desc.get_description(Some("ko")), "Search the web");
    }

    #[test]
    fn test_tool_descriptor_with_aliases() {
        let aliases = ToolAliases::new()
            .with_name("ko", "검색")
            .with_description("ko", "웹 검색");

        let desc = ToolDescriptor::new(
            "search",
            "Web Search",
            "Search the web",
            serde_json::json!({}),
        )
        .with_aliases(aliases);

        assert_eq!(desc.get_name(Some("ko")), "검색");
        assert_eq!(desc.get_name(Some("en")), "Web Search");
        assert_eq!(desc.get_name(None), "Web Search");
        assert_eq!(desc.get_description(Some("ko")), "웹 검색");
        // Aliases exist but not for this language: fall back to the default.
        assert_eq!(desc.get_description(Some("en")), "Search the web");
        assert_eq!(desc.get_description(None), "Search the web");
    }

    /// Pins the wire format: `None` optionals are omitted on output and
    /// default to `None` when missing on input.
    #[test]
    fn test_tool_descriptor_serde_optional_fields() {
        let desc = ToolDescriptor::new(
            "search",
            "Web Search",
            "Search the web",
            serde_json::json!({"type": "object"}),
        );

        let value = serde_json::to_value(&desc).expect("descriptor should serialize");
        assert_eq!(value["id"], "search");
        assert_eq!(value["name"], "Web Search");
        assert_eq!(value["input_schema"], serde_json::json!({"type": "object"}));
        assert!(value.get("aliases").is_none());
        assert!(value.get("metadata").is_none());

        let parsed: ToolDescriptor =
            serde_json::from_value(value).expect("descriptor should deserialize");
        assert_eq!(parsed.id, "search");
        assert_eq!(parsed.description, "Search the web");
        assert!(parsed.aliases.is_none());
        assert!(parsed.metadata.is_none());
    }

    #[test]
    fn test_provider_health() {
        let healthy = ProviderHealth::Healthy;
        assert!(healthy.is_healthy());
        assert!(healthy.is_available());

        let degraded = ProviderHealth::degraded("Some tools failing");
        assert!(!degraded.is_healthy());
        assert!(degraded.is_available());

        let unavailable = ProviderHealth::unavailable("Connection lost");
        assert!(!unavailable.is_healthy());
        assert!(!unavailable.is_available());

        assert!(ProviderHealth::default().is_healthy());
    }

    /// Pins the internally tagged (`status`), snake_case wire format.
    #[test]
    fn test_provider_health_serde_tagging() {
        assert_eq!(
            serde_json::to_value(ProviderHealth::Healthy).expect("health should serialize"),
            serde_json::json!({"status": "healthy"})
        );
        assert_eq!(
            serde_json::to_value(ProviderHealth::degraded("slow"))
                .expect("health should serialize"),
            serde_json::json!({"status": "degraded", "message": "slow"})
        );

        let parsed: ProviderHealth = serde_json::from_value(serde_json::json!({
            "status": "unavailable",
            "message": "down"
        }))
        .expect("health should deserialize");
        assert!(matches!(
            parsed,
            ProviderHealth::Unavailable { ref message } if message.as_str() == "down"
        ));
    }
}