1use crate::error::ToolError;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::RwLock;
12use tracing::{debug, warn, trace};
13use uuid::Uuid;
14
15#[async_trait]
17pub trait Provider: Send + Sync {
18 async fn execute(&self, input: &str) -> Result<String, ToolError>;
20}
21
/// One cached result of an MCP availability check for a single tool.
#[derive(Clone)]
struct AvailabilityEntry {
    /// Outcome of the check: `true` when an MCP provider was found.
    available: bool,
    /// Moment the check was performed; compared against the registry's TTL
    /// to decide whether the entry is still fresh.
    checked_at: Instant,
}
28
/// Describes which provider the registry picked for one lookup.
///
/// A value of this type is returned from `ProviderRegistry::get_provider` and
/// is also passed to the selection callback, if one is installed.
/// Two selections compare equal when all three fields are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderSelection {
    /// Kind of provider that was chosen (`"mcp"` or `"builtin"`).
    pub provider_name: String,
    /// Identifier generated for this lookup; it also appears in the
    /// corresponding log events.
    pub trace_id: String,
    /// `true` when the built-in provider was used instead of an MCP provider.
    pub is_fallback: bool,
}
39
40pub struct ProviderRegistry {
42 mcp_providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
43 builtin_providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
44 availability_cache: Arc<RwLock<HashMap<String, AvailabilityEntry>>>,
45 cache_ttl: Duration,
46 on_provider_selected: Arc<RwLock<Option<Arc<dyn Fn(ProviderSelection) + Send + Sync>>>>,
48}
49
50impl ProviderRegistry {
51 pub fn new() -> Self {
53 Self {
54 mcp_providers: Arc::new(RwLock::new(HashMap::new())),
55 builtin_providers: Arc::new(RwLock::new(HashMap::new())),
56 availability_cache: Arc::new(RwLock::new(HashMap::new())),
57 cache_ttl: Duration::from_secs(5),
58 on_provider_selected: Arc::new(RwLock::new(None)),
59 }
60 }
61
62 pub fn with_cache_ttl(cache_ttl: Duration) -> Self {
64 Self {
65 mcp_providers: Arc::new(RwLock::new(HashMap::new())),
66 builtin_providers: Arc::new(RwLock::new(HashMap::new())),
67 availability_cache: Arc::new(RwLock::new(HashMap::new())),
68 cache_ttl,
69 on_provider_selected: Arc::new(RwLock::new(None)),
70 }
71 }
72
73 pub async fn on_provider_selected<F>(&self, callback: F)
75 where
76 F: Fn(ProviderSelection) + Send + Sync + 'static,
77 {
78 *self.on_provider_selected.write().await = Some(Arc::new(callback));
79 }
80
81 pub async fn register_mcp_provider(
83 &self,
84 tool_name: impl Into<String>,
85 provider: Arc<dyn Provider>,
86 ) {
87 let tool_name = tool_name.into();
88 debug!("Registering MCP provider for tool: {}", tool_name);
89 self.mcp_providers.write().await.insert(tool_name, provider);
90 }
91
92 pub async fn register_builtin_provider(
94 &self,
95 tool_name: impl Into<String>,
96 provider: Arc<dyn Provider>,
97 ) {
98 let tool_name = tool_name.into();
99 debug!("Registering built-in provider for tool: {}", tool_name);
100 self.builtin_providers.write().await.insert(tool_name, provider);
101 }
102
103 async fn is_mcp_available(&self, tool_name: &str) -> bool {
105 let cache = self.availability_cache.read().await;
106 if let Some(entry) = cache.get(tool_name) {
107 if entry.checked_at.elapsed() < self.cache_ttl {
108 trace!("MCP availability cache hit for tool: {}", tool_name);
109 return entry.available;
110 }
111 }
112 drop(cache);
113
114 let available = self.mcp_providers.read().await.contains_key(tool_name);
116
117 self.availability_cache.write().await.insert(
119 tool_name.to_string(),
120 AvailabilityEntry {
121 available,
122 checked_at: Instant::now(),
123 },
124 );
125
126 trace!("MCP availability check for tool: {} = {}", tool_name, available);
127 available
128 }
129
130 pub async fn get_provider(
134 &self,
135 tool_name: &str,
136 ) -> Result<(Arc<dyn Provider>, ProviderSelection), ToolError> {
137 let trace_id = Uuid::new_v4().to_string();
138
139 if self.is_mcp_available(tool_name).await {
141 if let Some(provider) = self.mcp_providers.read().await.get(tool_name) {
142 let selection = ProviderSelection {
143 provider_name: "mcp".to_string(),
144 trace_id: trace_id.clone(),
145 is_fallback: false,
146 };
147 debug!(
148 trace_id = %trace_id,
149 tool = tool_name,
150 "Using MCP provider for tool"
151 );
152 self.notify_provider_selected(selection).await;
153 return Ok((provider.clone(), ProviderSelection {
154 provider_name: "mcp".to_string(),
155 trace_id,
156 is_fallback: false,
157 }));
158 }
159 }
160
161 if let Some(provider) = self.builtin_providers.read().await.get(tool_name) {
163 let selection = ProviderSelection {
164 provider_name: "builtin".to_string(),
165 trace_id: trace_id.clone(),
166 is_fallback: true,
167 };
168 debug!(
169 trace_id = %trace_id,
170 tool = tool_name,
171 "Using built-in provider for tool (MCP unavailable)"
172 );
173 self.notify_provider_selected(selection.clone()).await;
174 return Ok((provider.clone(), selection));
175 }
176
177 warn!(
179 trace_id = %trace_id,
180 tool = tool_name,
181 "No provider available for tool"
182 );
183 Err(ToolError::new(
184 "PROVIDER_NOT_FOUND",
185 format!("No provider available for tool: {}", tool_name),
186 )
187 .with_details(format!("trace_id: {}", trace_id))
188 .with_suggestion("Ensure the tool is registered or MCP server is available"))
189 }
190
191 pub async fn get_provider_simple(&self, tool_name: &str) -> Result<Arc<dyn Provider>, ToolError> {
193 self.get_provider(tool_name)
194 .await
195 .map(|(provider, _)| provider)
196 }
197
198 async fn notify_provider_selected(&self, selection: ProviderSelection) {
200 if let Some(callback) = self.on_provider_selected.read().await.as_ref() {
201 callback(selection);
202 }
203 }
204
205 pub async fn invalidate_cache(&self, tool_name: &str) {
207 debug!("Invalidating availability cache for tool: {}", tool_name);
208 self.availability_cache.write().await.remove(tool_name);
209 }
210
211 pub async fn invalidate_all_caches(&self) {
213 debug!("Invalidating all availability caches");
214 self.availability_cache.write().await.clear();
215 }
216}
217
218impl Default for ProviderRegistry {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double whose response embeds its own name, so assertions can tell
    /// which registered provider actually ran.
    struct MockProvider {
        name: String,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn execute(&self, _input: &str) -> Result<String, ToolError> {
            Ok(format!("Mock response from {}", self.name))
        }
    }

    /// A fresh registry starts with no providers of either kind.
    #[tokio::test]
    async fn test_registry_creation() {
        let registry = ProviderRegistry::new();
        assert!(registry.mcp_providers.read().await.is_empty());
        assert!(registry.builtin_providers.read().await.is_empty());
    }

    /// With only a built-in provider registered, lookup returns it and marks
    /// the selection as a fallback.
    #[tokio::test]
    async fn test_register_builtin_provider() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(MockProvider {
            name: "test".to_string(),
        });
        registry
            .register_builtin_provider("test_tool", provider)
            .await;

        let (retrieved, selection) = registry.get_provider("test_tool").await.unwrap();
        let result = retrieved.execute("test").await.unwrap();
        assert_eq!(result, "Mock response from test");
        assert_eq!(selection.provider_name, "builtin");
        assert!(selection.is_fallback);
    }

    /// When both kinds are registered for a tool, the MCP provider wins.
    #[tokio::test]
    async fn test_mcp_priority_over_builtin() {
        let registry = ProviderRegistry::new();
        let builtin = Arc::new(MockProvider {
            name: "builtin".to_string(),
        });
        let mcp = Arc::new(MockProvider {
            name: "mcp".to_string(),
        });

        registry
            .register_builtin_provider("test_tool", builtin)
            .await;
        registry.register_mcp_provider("test_tool", mcp).await;

        let (retrieved, selection) = registry.get_provider("test_tool").await.unwrap();
        let result = retrieved.execute("test").await.unwrap();
        assert_eq!(result, "Mock response from mcp");
        assert_eq!(selection.provider_name, "mcp");
        assert!(!selection.is_fallback);
    }

    /// Looking up an unregistered tool yields `PROVIDER_NOT_FOUND`.
    #[tokio::test]
    async fn test_provider_not_found() {
        let registry = ProviderRegistry::new();
        match registry.get_provider("nonexistent").await {
            Ok(_) => panic!("lookup of an unregistered tool should fail"),
            Err(err) => assert_eq!(err.code, "PROVIDER_NOT_FOUND"),
        }
    }

    /// The selection callback receives the same selection that
    /// `get_provider` returns.
    #[tokio::test]
    async fn test_provider_selection_callback() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(MockProvider {
            name: "test".to_string(),
        });
        registry
            .register_builtin_provider("test_tool", provider)
            .await;

        // A std mutex suffices: the callback is a plain synchronous `Fn`, so
        // it can record the selection without spawning threads or runtimes.
        let recorded = Arc::new(std::sync::Mutex::new(None::<ProviderSelection>));
        let recorded_in_callback = Arc::clone(&recorded);
        registry
            .on_provider_selected(move |selection| {
                *recorded_in_callback.lock().unwrap() = Some(selection);
            })
            .await;

        let (_, selection) = registry.get_provider("test_tool").await.unwrap();

        let observed = recorded
            .lock()
            .unwrap()
            .take()
            .expect("selection callback should have been invoked");
        assert_eq!(observed.provider_name, "builtin");
        assert!(observed.is_fallback);
        assert_eq!(observed.trace_id, selection.trace_id);
    }

    /// `get_provider_simple` returns the same provider without the metadata.
    #[tokio::test]
    async fn test_get_provider_simple() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(MockProvider {
            name: "test".to_string(),
        });
        registry
            .register_builtin_provider("test_tool", provider)
            .await;

        let retrieved = registry.get_provider_simple("test_tool").await.unwrap();
        let result = retrieved.execute("test").await.unwrap();
        assert_eq!(result, "Mock response from test");
    }

    /// A lookup populates the availability cache; invalidating the tool
    /// removes its entry again.
    #[tokio::test]
    async fn test_cache_invalidation() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(MockProvider {
            name: "test".to_string(),
        });
        registry
            .register_builtin_provider("test_tool", provider)
            .await;

        let _ = registry.get_provider("test_tool").await;
        assert!(!registry.availability_cache.read().await.is_empty());

        registry.invalidate_cache("test_tool").await;
        assert!(registry.availability_cache.read().await.is_empty());
    }
}