ai_agents_tools/
provider.rs1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::sync::Arc;
5
6use super::types::{ToolAliases, ToolMetadata, ToolProviderType};
7use super::{Tool, ToolResult};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ToolDescriptor {
11 pub id: String,
12 pub name: String,
13 pub description: String,
14 pub input_schema: Value,
15 #[serde(default, skip_serializing_if = "Option::is_none")]
16 pub aliases: Option<ToolAliases>,
17 #[serde(default, skip_serializing_if = "Option::is_none")]
18 pub metadata: Option<ToolMetadata>,
19}
20
21impl ToolDescriptor {
22 pub fn new(
23 id: impl Into<String>,
24 name: impl Into<String>,
25 description: impl Into<String>,
26 input_schema: Value,
27 ) -> Self {
28 Self {
29 id: id.into(),
30 name: name.into(),
31 description: description.into(),
32 input_schema,
33 aliases: None,
34 metadata: None,
35 }
36 }
37
38 pub fn with_aliases(mut self, aliases: ToolAliases) -> Self {
39 self.aliases = Some(aliases);
40 self
41 }
42
43 pub fn with_metadata(mut self, metadata: ToolMetadata) -> Self {
44 self.metadata = Some(metadata);
45 self
46 }
47
48 pub fn get_name(&self, lang: Option<&str>) -> &str {
49 if let Some(lang) = lang {
50 if let Some(ref aliases) = self.aliases {
51 if let Some(name) = aliases.get_name(lang) {
52 return name;
53 }
54 }
55 }
56 &self.name
57 }
58
59 pub fn get_description(&self, lang: Option<&str>) -> &str {
60 if let Some(lang) = lang {
61 if let Some(ref aliases) = self.aliases {
62 if let Some(desc) = aliases.get_description(lang) {
63 return desc;
64 }
65 }
66 }
67 &self.description
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, Default)]
72#[serde(tag = "status", rename_all = "snake_case")]
73pub enum ProviderHealth {
74 #[default]
75 Healthy,
76 Degraded {
77 message: String,
78 },
79 Unavailable {
80 message: String,
81 },
82}
83
84impl ProviderHealth {
85 pub fn is_healthy(&self) -> bool {
86 matches!(self, ProviderHealth::Healthy)
87 }
88
89 pub fn is_available(&self) -> bool {
90 !matches!(self, ProviderHealth::Unavailable { .. })
91 }
92
93 pub fn degraded(message: impl Into<String>) -> Self {
94 ProviderHealth::Degraded {
95 message: message.into(),
96 }
97 }
98
99 pub fn unavailable(message: impl Into<String>) -> Self {
100 ProviderHealth::Unavailable {
101 message: message.into(),
102 }
103 }
104}
105
106#[derive(Debug, thiserror::Error)]
107pub enum ToolProviderError {
108 #[error("Tool not found: {0}")]
109 ToolNotFound(String),
110
111 #[error("Execution failed: {0}")]
112 ExecutionFailed(String),
113
114 #[error("Provider unavailable: {0}")]
115 Unavailable(String),
116
117 #[error("Connection error: {0}")]
118 ConnectionError(String),
119
120 #[error("Configuration error: {0}")]
121 ConfigError(String),
122
123 #[error("Timeout after {0}ms")]
124 Timeout(u64),
125
126 #[error("{0}")]
127 Other(String),
128}
129
130#[async_trait]
131pub trait ToolProvider: Send + Sync {
132 fn id(&self) -> &str;
133
134 fn name(&self) -> &str;
135
136 fn provider_type(&self) -> ToolProviderType;
137
138 async fn list_tools(&self) -> Vec<ToolDescriptor>;
139
140 async fn get_tool(&self, tool_id: &str) -> Option<Arc<dyn Tool>>;
141
142 async fn execute(&self, tool_id: &str, args: Value) -> Result<ToolResult, ToolProviderError> {
143 if let Some(tool) = self.get_tool(tool_id).await {
144 Ok(tool.execute(args).await)
145 } else {
146 Err(ToolProviderError::ToolNotFound(tool_id.to_string()))
147 }
148 }
149
150 fn supports_refresh(&self) -> bool {
151 false
152 }
153
154 async fn refresh(&self) -> Result<(), ToolProviderError> {
155 Ok(())
156 }
157
158 async fn health_check(&self) -> ProviderHealth {
159 ProviderHealth::Healthy
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults are returned when no language is requested, or when no
    /// aliases are attached at all.
    #[test]
    fn test_tool_descriptor() {
        let desc = ToolDescriptor::new(
            "search",
            "Web Search",
            "Search the web",
            serde_json::json!({"type": "object"}),
        );

        assert_eq!(desc.id, "search");
        assert_eq!(desc.get_name(None), "Web Search");
        assert_eq!(desc.get_description(None), "Search the web");
        // Without aliases, any requested language falls back to the defaults.
        assert_eq!(desc.get_name(Some("ko")), "Web Search");
        assert_eq!(desc.get_description(Some("ko")), "Search the web");
        assert!(desc.aliases.is_none());
        assert!(desc.metadata.is_none());
    }

    /// A matching alias wins. Unknown or absent languages fall back to the
    /// defaults.
    #[test]
    fn test_tool_descriptor_with_aliases() {
        let aliases = ToolAliases::new()
            .with_name("ko", "검색")
            .with_description("ko", "웹 검색");

        let desc = ToolDescriptor::new(
            "search",
            "Web Search",
            "Search the web",
            serde_json::json!({}),
        )
        .with_aliases(aliases);

        assert_eq!(desc.get_name(Some("ko")), "검색");
        assert_eq!(desc.get_name(Some("en")), "Web Search");
        assert_eq!(desc.get_name(None), "Web Search");
        assert_eq!(desc.get_description(Some("ko")), "웹 검색");
        assert_eq!(desc.get_description(Some("en")), "Search the web");
        assert_eq!(desc.get_description(None), "Search the web");
    }

    /// Unset optional fields are omitted on output and default to `None` on
    /// input.
    #[test]
    fn test_tool_descriptor_serde_skips_unset_optionals() {
        let desc = ToolDescriptor::new(
            "search",
            "Web Search",
            "Search the web",
            serde_json::json!({"type": "object"}),
        );

        let value = serde_json::to_value(&desc).expect("descriptor should serialize");
        assert_eq!(
            value,
            serde_json::json!({
                "id": "search",
                "name": "Web Search",
                "description": "Search the web",
                "input_schema": {"type": "object"}
            })
        );

        let back: ToolDescriptor =
            serde_json::from_value(value).expect("descriptor should deserialize");
        assert_eq!(back.id, "search");
        assert_eq!(back.name, "Web Search");
        assert_eq!(back.input_schema, serde_json::json!({"type": "object"}));
        assert!(back.aliases.is_none());
        assert!(back.metadata.is_none());
    }

    #[test]
    fn test_provider_health() {
        let healthy = ProviderHealth::Healthy;
        assert!(healthy.is_healthy());
        assert!(healthy.is_available());

        let degraded = ProviderHealth::degraded("Some tools failing");
        assert!(!degraded.is_healthy());
        assert!(degraded.is_available());

        let unavailable = ProviderHealth::unavailable("Connection lost");
        assert!(!unavailable.is_healthy());
        assert!(!unavailable.is_available());

        assert!(ProviderHealth::default().is_healthy());
    }

    /// Pins the internally tagged wire format produced by
    /// `#[serde(tag = "status", rename_all = "snake_case")]`.
    #[test]
    fn test_provider_health_serde_format() {
        assert_eq!(
            serde_json::to_value(ProviderHealth::Healthy).expect("serialize healthy"),
            serde_json::json!({"status": "healthy"})
        );
        assert_eq!(
            serde_json::to_value(ProviderHealth::degraded("slow")).expect("serialize degraded"),
            serde_json::json!({"status": "degraded", "message": "slow"})
        );
        assert_eq!(
            serde_json::to_value(ProviderHealth::unavailable("down"))
                .expect("serialize unavailable"),
            serde_json::json!({"status": "unavailable", "message": "down"})
        );

        let parsed: ProviderHealth = serde_json::from_value(
            serde_json::json!({"status": "unavailable", "message": "Connection lost"}),
        )
        .expect("deserialize unavailable");
        assert!(!parsed.is_available());
        assert!(matches!(
            &parsed,
            ProviderHealth::Unavailable { message } if message == "Connection lost"
        ));
    }

    #[test]
    fn test_provider_error_messages() {
        assert_eq!(
            ToolProviderError::ToolNotFound("search".into()).to_string(),
            "Tool not found: search"
        );
        assert_eq!(
            ToolProviderError::ExecutionFailed("bad args".into()).to_string(),
            "Execution failed: bad args"
        );
        assert_eq!(ToolProviderError::Timeout(1500).to_string(), "Timeout after 1500ms");
        assert_eq!(ToolProviderError::Other("boom".into()).to_string(), "boom");
    }
}