ricecoder_tools/
provider.rs

1//! Provider trait and registry for tool implementations
2//!
3//! Implements the hybrid provider pattern with MCP → Built-in → Error priority chain.
4//! Supports hot-reload of MCP server availability without restart.
5
6use crate::error::ToolError;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::RwLock;
12use tracing::{debug, warn, trace};
13use uuid::Uuid;
14
15/// Trait for tool providers
16#[async_trait]
17pub trait Provider: Send + Sync {
18    /// Execute a tool operation
19    async fn execute(&self, input: &str) -> Result<String, ToolError>;
20}
21
/// Cached result of one MCP availability check for a single tool.
///
/// Both fields are `Copy`, so the entry is too; it can be read out of the cache by value.
#[derive(Debug, Clone, Copy)]
struct AvailabilityEntry {
    /// Whether an MCP provider was available when the check ran.
    available: bool,
    /// When the check ran; compared against the registry's cache TTL.
    checked_at: Instant,
}
28
/// Outcome of a successful provider lookup.
///
/// Returned alongside the provider and also handed to the selection callback, if one is set.
#[derive(Clone, Debug)]
pub struct ProviderSelection {
    /// Which link of the chain served the request: `"mcp"` or `"builtin"`.
    pub provider_name: String,
    /// Per-lookup identifier (a UUID v4 string) for correlating log lines.
    pub trace_id: String,
    /// `true` when the built-in provider was used rather than an MCP provider.
    pub is_fallback: bool,
}
39
40/// Registry for managing tool providers with fallback chain
41pub struct ProviderRegistry {
42    mcp_providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
43    builtin_providers: Arc<RwLock<HashMap<String, Arc<dyn Provider>>>>,
44    availability_cache: Arc<RwLock<HashMap<String, AvailabilityEntry>>>,
45    cache_ttl: Duration,
46    /// Callback for provider selection events
47    on_provider_selected: Arc<RwLock<Option<Arc<dyn Fn(ProviderSelection) + Send + Sync>>>>,
48}
49
50impl ProviderRegistry {
51    /// Create a new provider registry
52    pub fn new() -> Self {
53        Self {
54            mcp_providers: Arc::new(RwLock::new(HashMap::new())),
55            builtin_providers: Arc::new(RwLock::new(HashMap::new())),
56            availability_cache: Arc::new(RwLock::new(HashMap::new())),
57            cache_ttl: Duration::from_secs(5),
58            on_provider_selected: Arc::new(RwLock::new(None)),
59        }
60    }
61
62    /// Create a new provider registry with custom cache TTL
63    pub fn with_cache_ttl(cache_ttl: Duration) -> Self {
64        Self {
65            mcp_providers: Arc::new(RwLock::new(HashMap::new())),
66            builtin_providers: Arc::new(RwLock::new(HashMap::new())),
67            availability_cache: Arc::new(RwLock::new(HashMap::new())),
68            cache_ttl,
69            on_provider_selected: Arc::new(RwLock::new(None)),
70        }
71    }
72
73    /// Set a callback for provider selection events
74    pub async fn on_provider_selected<F>(&self, callback: F)
75    where
76        F: Fn(ProviderSelection) + Send + Sync + 'static,
77    {
78        *self.on_provider_selected.write().await = Some(Arc::new(callback));
79    }
80
81    /// Register an MCP provider for a tool
82    pub async fn register_mcp_provider(
83        &self,
84        tool_name: impl Into<String>,
85        provider: Arc<dyn Provider>,
86    ) {
87        let tool_name = tool_name.into();
88        debug!("Registering MCP provider for tool: {}", tool_name);
89        self.mcp_providers.write().await.insert(tool_name, provider);
90    }
91
92    /// Register a built-in provider for a tool
93    pub async fn register_builtin_provider(
94        &self,
95        tool_name: impl Into<String>,
96        provider: Arc<dyn Provider>,
97    ) {
98        let tool_name = tool_name.into();
99        debug!("Registering built-in provider for tool: {}", tool_name);
100        self.builtin_providers.write().await.insert(tool_name, provider);
101    }
102
103    /// Check if an MCP provider is available (with cache)
104    async fn is_mcp_available(&self, tool_name: &str) -> bool {
105        let cache = self.availability_cache.read().await;
106        if let Some(entry) = cache.get(tool_name) {
107            if entry.checked_at.elapsed() < self.cache_ttl {
108                trace!("MCP availability cache hit for tool: {}", tool_name);
109                return entry.available;
110            }
111        }
112        drop(cache);
113
114        // Check availability (in real implementation, would query ricecoder-mcp)
115        let available = self.mcp_providers.read().await.contains_key(tool_name);
116
117        // Update cache
118        self.availability_cache.write().await.insert(
119            tool_name.to_string(),
120            AvailabilityEntry {
121                available,
122                checked_at: Instant::now(),
123            },
124        );
125
126        trace!("MCP availability check for tool: {} = {}", tool_name, available);
127        available
128    }
129
130    /// Get a provider for a tool using the priority chain
131    ///
132    /// Returns the provider and selection information including trace ID for debugging.
133    pub async fn get_provider(
134        &self,
135        tool_name: &str,
136    ) -> Result<(Arc<dyn Provider>, ProviderSelection), ToolError> {
137        let trace_id = Uuid::new_v4().to_string();
138
139        // Try MCP first
140        if self.is_mcp_available(tool_name).await {
141            if let Some(provider) = self.mcp_providers.read().await.get(tool_name) {
142                let selection = ProviderSelection {
143                    provider_name: "mcp".to_string(),
144                    trace_id: trace_id.clone(),
145                    is_fallback: false,
146                };
147                debug!(
148                    trace_id = %trace_id,
149                    tool = tool_name,
150                    "Using MCP provider for tool"
151                );
152                self.notify_provider_selected(selection).await;
153                return Ok((provider.clone(), ProviderSelection {
154                    provider_name: "mcp".to_string(),
155                    trace_id,
156                    is_fallback: false,
157                }));
158            }
159        }
160
161        // Fall back to built-in
162        if let Some(provider) = self.builtin_providers.read().await.get(tool_name) {
163            let selection = ProviderSelection {
164                provider_name: "builtin".to_string(),
165                trace_id: trace_id.clone(),
166                is_fallback: true,
167            };
168            debug!(
169                trace_id = %trace_id,
170                tool = tool_name,
171                "Using built-in provider for tool (MCP unavailable)"
172            );
173            self.notify_provider_selected(selection.clone()).await;
174            return Ok((provider.clone(), selection));
175        }
176
177        // No provider available
178        warn!(
179            trace_id = %trace_id,
180            tool = tool_name,
181            "No provider available for tool"
182        );
183        Err(ToolError::new(
184            "PROVIDER_NOT_FOUND",
185            format!("No provider available for tool: {}", tool_name),
186        )
187        .with_details(format!("trace_id: {}", trace_id))
188        .with_suggestion("Ensure the tool is registered or MCP server is available"))
189    }
190
191    /// Get a provider without selection information (simplified API)
192    pub async fn get_provider_simple(&self, tool_name: &str) -> Result<Arc<dyn Provider>, ToolError> {
193        self.get_provider(tool_name)
194            .await
195            .map(|(provider, _)| provider)
196    }
197
198    /// Notify listeners of provider selection
199    async fn notify_provider_selected(&self, selection: ProviderSelection) {
200        if let Some(callback) = self.on_provider_selected.read().await.as_ref() {
201            callback(selection);
202        }
203    }
204
205    /// Invalidate availability cache for a tool
206    pub async fn invalidate_cache(&self, tool_name: &str) {
207        debug!("Invalidating availability cache for tool: {}", tool_name);
208        self.availability_cache.write().await.remove(tool_name);
209    }
210
211    /// Invalidate all availability caches
212    pub async fn invalidate_all_caches(&self) {
213        debug!("Invalidating all availability caches");
214        self.availability_cache.write().await.clear();
215    }
216}
217
218impl Default for ProviderRegistry {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test double whose output names the instance that served the call.
    struct MockProvider {
        name: String,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn execute(&self, _input: &str) -> Result<String, ToolError> {
            Ok(format!("Mock response from {}", self.name))
        }
    }

    /// Build a shared mock provider labelled `name`.
    fn mock(name: &str) -> Arc<MockProvider> {
        Arc::new(MockProvider {
            name: name.to_string(),
        })
    }

    /// Install a selection callback that records the most recent selection.
    ///
    /// The registry invokes the callback synchronously inside `get_provider`, so a plain
    /// mutex suffices and no sleeping or polling is needed.
    async fn record_selections(registry: &ProviderRegistry) -> Arc<Mutex<Option<ProviderSelection>>> {
        let slot = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&slot);
        registry
            .on_provider_selected(move |selection| {
                *sink.lock().unwrap() = Some(selection);
            })
            .await;
        slot
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = ProviderRegistry::new();
        assert!(registry.mcp_providers.read().await.is_empty());
        assert!(registry.builtin_providers.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_register_builtin_provider() {
        let registry = ProviderRegistry::new();
        registry
            .register_builtin_provider("test_tool", mock("test"))
            .await;

        let (retrieved, selection) = registry.get_provider("test_tool").await.unwrap();
        let result = retrieved.execute("test").await.unwrap();
        assert_eq!(result, "Mock response from test");
        assert_eq!(selection.provider_name, "builtin");
        assert!(selection.is_fallback);
    }

    #[tokio::test]
    async fn test_mcp_priority_over_builtin() {
        let registry = ProviderRegistry::new();
        registry
            .register_builtin_provider("test_tool", mock("builtin"))
            .await;
        registry.register_mcp_provider("test_tool", mock("mcp")).await;

        let (retrieved, selection) = registry.get_provider("test_tool").await.unwrap();
        let result = retrieved.execute("test").await.unwrap();
        assert_eq!(result, "Mock response from mcp");
        assert_eq!(selection.provider_name, "mcp");
        assert!(!selection.is_fallback);
    }

    #[tokio::test]
    async fn test_provider_not_found() {
        let registry = ProviderRegistry::new();
        let err = match registry.get_provider("nonexistent").await {
            Ok(_) => panic!("expected PROVIDER_NOT_FOUND for an unregistered tool"),
            Err(err) => err,
        };
        assert_eq!(err.code, "PROVIDER_NOT_FOUND");
    }

    #[tokio::test]
    async fn test_provider_selection_callback() {
        let registry = ProviderRegistry::new();
        registry
            .register_builtin_provider("test_tool", mock("test"))
            .await;
        let selected = record_selections(&registry).await;

        let (_, selection) = registry.get_provider("test_tool").await.unwrap();

        let observed = selected
            .lock()
            .unwrap()
            .take()
            .expect("selection callback was not invoked");
        assert_eq!(observed.provider_name, "builtin");
        assert!(observed.is_fallback);
        assert_eq!(observed.trace_id, selection.trace_id);
    }

    #[tokio::test]
    async fn test_provider_selection_callback_for_mcp() {
        let registry = ProviderRegistry::new();
        registry.register_mcp_provider("test_tool", mock("mcp")).await;
        let selected = record_selections(&registry).await;

        let (_, selection) = registry.get_provider("test_tool").await.unwrap();

        let observed = selected
            .lock()
            .unwrap()
            .take()
            .expect("selection callback was not invoked");
        assert_eq!(observed.provider_name, "mcp");
        assert!(!observed.is_fallback);
        assert_eq!(observed.trace_id, selection.trace_id);
    }

    #[tokio::test]
    async fn test_callback_not_invoked_on_missing_provider() {
        let registry = ProviderRegistry::new();
        let selected = record_selections(&registry).await;

        assert!(registry.get_provider("nonexistent").await.is_err());
        assert!(selected.lock().unwrap().is_none());
    }

    #[tokio::test]
    async fn test_get_provider_simple() {
        let registry = ProviderRegistry::new();
        registry
            .register_builtin_provider("test_tool", mock("test"))
            .await;

        let retrieved = registry.get_provider_simple("test_tool").await.unwrap();
        let result = retrieved.execute("test").await.unwrap();
        assert_eq!(result, "Mock response from test");
    }

    #[tokio::test]
    async fn test_cache_invalidation() {
        let registry = ProviderRegistry::new();
        registry
            .register_builtin_provider("test_tool", mock("test"))
            .await;

        // Check availability (populates cache)
        let _ = registry.get_provider("test_tool").await;
        assert!(!registry.availability_cache.read().await.is_empty());

        // Invalidate cache
        registry.invalidate_cache("test_tool").await;
        assert!(registry.availability_cache.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_invalidate_all_caches() {
        let registry = ProviderRegistry::new();

        // Lookups populate the cache even when they fail.
        let _ = registry.get_provider("a").await;
        let _ = registry.get_provider("b").await;
        assert_eq!(registry.availability_cache.read().await.len(), 2);

        registry.invalidate_all_caches().await;
        assert!(registry.availability_cache.read().await.is_empty());
    }
}