ricecoder_ide/
provider_chain.rs

1//! Provider chain implementation for IDE features
2//!
3//! This module implements the LSP-first provider priority chain:
4//! 1. External LSP Servers (rust-analyzer, typescript-language-server, pylsp, etc.)
5//! 2. Configured IDE Rules (YAML/JSON configuration)
6//! 3. Built-in Language Providers (Rust, TypeScript, Python)
7//! 4. Generic Text-based Features (fallback for any language)
8
9use crate::error::IdeResult;
10use crate::provider::{IdeProvider, ProviderChange};
11use crate::types::*;
12use std::collections::HashMap;
13use std::sync::Arc;
14use tracing::{debug, info, warn};
15
16/// Type alias for provider availability change callback
17type ProviderAvailabilityCallback = Box<dyn Fn(ProviderChange) + Send + Sync>;
18
19/// Provider registry for managing multiple providers
20pub struct ProviderRegistry {
21    /// External LSP providers by language
22    lsp_providers: HashMap<String, Arc<dyn IdeProvider>>,
23    /// Configured rules providers by language
24    configured_providers: HashMap<String, Arc<dyn IdeProvider>>,
25    /// Built-in language providers by language
26    builtin_providers: HashMap<String, Arc<dyn IdeProvider>>,
27    /// Generic fallback provider
28    generic_provider: Arc<dyn IdeProvider>,
29}
30
31impl ProviderRegistry {
32    /// Create a new provider registry
33    pub fn new(generic_provider: Arc<dyn IdeProvider>) -> Self {
34        ProviderRegistry {
35            lsp_providers: HashMap::new(),
36            configured_providers: HashMap::new(),
37            builtin_providers: HashMap::new(),
38            generic_provider,
39        }
40    }
41
42    /// Register an external LSP provider for a language
43    pub fn register_lsp_provider(&mut self, language: String, provider: Arc<dyn IdeProvider>) {
44        debug!("Registering LSP provider for language: {}", language);
45        self.lsp_providers.insert(language, provider);
46    }
47
48    /// Register a configured rules provider for a language
49    pub fn register_configured_provider(
50        &mut self,
51        language: String,
52        provider: Arc<dyn IdeProvider>,
53    ) {
54        debug!("Registering configured rules provider for language: {}", language);
55        self.configured_providers.insert(language, provider);
56    }
57
58    /// Register a built-in provider for a language
59    pub fn register_builtin_provider(&mut self, language: String, provider: Arc<dyn IdeProvider>) {
60        debug!("Registering built-in provider for language: {}", language);
61        self.builtin_providers.insert(language, provider);
62    }
63
64    /// Get a provider for a language following the priority chain
65    pub fn get_provider(&self, language: &str) -> Arc<dyn IdeProvider> {
66        // Priority 1: External LSP
67        if let Some(provider) = self.lsp_providers.get(language) {
68            debug!("Using LSP provider for language: {}", language);
69            return provider.clone();
70        }
71
72        // Priority 2: Configured rules
73        if let Some(provider) = self.configured_providers.get(language) {
74            debug!("Using configured rules provider for language: {}", language);
75            return provider.clone();
76        }
77
78        // Priority 3: Built-in
79        if let Some(provider) = self.builtin_providers.get(language) {
80            debug!("Using built-in provider for language: {}", language);
81            return provider.clone();
82        }
83
84        // Priority 4: Generic fallback
85        debug!("Using generic fallback provider for language: {}", language);
86        self.generic_provider.clone()
87    }
88
89    /// Check if a provider is available for a language
90    pub fn is_provider_available(&self, language: &str) -> bool {
91        self.lsp_providers.contains_key(language)
92            || self.configured_providers.contains_key(language)
93            || self.builtin_providers.contains_key(language)
94    }
95
96    /// Get all available languages
97    pub fn available_languages(&self) -> Vec<String> {
98        let mut languages = Vec::new();
99        languages.extend(self.lsp_providers.keys().cloned());
100        languages.extend(self.configured_providers.keys().cloned());
101        languages.extend(self.builtin_providers.keys().cloned());
102        languages.sort();
103        languages.dedup();
104        languages
105    }
106
107    /// Unregister an LSP provider for a language
108    pub fn unregister_lsp_provider(&mut self, language: &str) {
109        debug!("Unregistering LSP provider for language: {}", language);
110        self.lsp_providers.remove(language);
111    }
112
113    /// Unregister a configured rules provider for a language
114    pub fn unregister_configured_provider(&mut self, language: &str) {
115        debug!("Unregistering configured rules provider for language: {}", language);
116        self.configured_providers.remove(language);
117    }
118
119    /// Unregister a built-in provider for a language
120    pub fn unregister_builtin_provider(&mut self, language: &str) {
121        debug!("Unregistering built-in provider for language: {}", language);
122        self.builtin_providers.remove(language);
123    }
124}
125
126/// Provider chain manager that orchestrates the provider priority chain
127pub struct ProviderChainManager {
128    registry: Arc<tokio::sync::RwLock<ProviderRegistry>>,
129    availability_callbacks: Arc<tokio::sync::RwLock<Vec<ProviderAvailabilityCallback>>>,
130}
131
132impl ProviderChainManager {
133    /// Create a new provider chain manager
134    pub fn new(registry: ProviderRegistry) -> Self {
135        ProviderChainManager {
136            registry: Arc::new(tokio::sync::RwLock::new(registry)),
137            availability_callbacks: Arc::new(tokio::sync::RwLock::new(Vec::new())),
138        }
139    }
140
141    /// Get completions through the provider chain
142    pub async fn get_completions(&self, params: &CompletionParams) -> IdeResult<Vec<CompletionItem>> {
143        debug!(
144            "Getting completions for language: {} through provider chain",
145            params.language
146        );
147
148        let registry = self.registry.read().await;
149        let provider = registry.get_provider(&params.language);
150
151        match provider.get_completions(params).await {
152            Ok(completions) => {
153                info!(
154                    "Successfully got {} completions for language: {}",
155                    completions.len(),
156                    params.language
157                );
158                Ok(completions)
159            }
160            Err(e) => {
161                warn!(
162                    "Failed to get completions for language: {}: {}",
163                    params.language, e
164                );
165                Err(e)
166            }
167        }
168    }
169
170    /// Get diagnostics through the provider chain
171    pub async fn get_diagnostics(&self, params: &DiagnosticsParams) -> IdeResult<Vec<Diagnostic>> {
172        debug!(
173            "Getting diagnostics for language: {} through provider chain",
174            params.language
175        );
176
177        let registry = self.registry.read().await;
178        let provider = registry.get_provider(&params.language);
179
180        match provider.get_diagnostics(params).await {
181            Ok(diagnostics) => {
182                info!(
183                    "Successfully got {} diagnostics for language: {}",
184                    diagnostics.len(),
185                    params.language
186                );
187                Ok(diagnostics)
188            }
189            Err(e) => {
190                warn!(
191                    "Failed to get diagnostics for language: {}: {}",
192                    params.language, e
193                );
194                Err(e)
195            }
196        }
197    }
198
199    /// Get hover information through the provider chain
200    pub async fn get_hover(&self, params: &HoverParams) -> IdeResult<Option<Hover>> {
201        debug!(
202            "Getting hover information for language: {} through provider chain",
203            params.language
204        );
205
206        let registry = self.registry.read().await;
207        let provider = registry.get_provider(&params.language);
208
209        match provider.get_hover(params).await {
210            Ok(hover) => {
211                if hover.is_some() {
212                    info!("Successfully got hover information for language: {}", params.language);
213                }
214                Ok(hover)
215            }
216            Err(e) => {
217                warn!(
218                    "Failed to get hover information for language: {}: {}",
219                    params.language, e
220                );
221                Err(e)
222            }
223        }
224    }
225
226    /// Get definition location through the provider chain
227    pub async fn get_definition(&self, params: &DefinitionParams) -> IdeResult<Option<Location>> {
228        debug!(
229            "Getting definition for language: {} through provider chain",
230            params.language
231        );
232
233        let registry = self.registry.read().await;
234        let provider = registry.get_provider(&params.language);
235
236        match provider.get_definition(params).await {
237            Ok(location) => {
238                if location.is_some() {
239                    info!("Successfully got definition for language: {}", params.language);
240                }
241                Ok(location)
242            }
243            Err(e) => {
244                warn!(
245                    "Failed to get definition for language: {}: {}",
246                    params.language, e
247                );
248                Err(e)
249            }
250        }
251    }
252
253    /// Register a provider availability change callback
254    pub async fn on_provider_availability_changed(
255        &self,
256        callback: Box<dyn Fn(ProviderChange) + Send + Sync>,
257    ) {
258        let mut callbacks = self.availability_callbacks.write().await;
259        callbacks.push(callback);
260    }
261
262    /// Notify all callbacks of a provider availability change
263    pub async fn notify_provider_change(&self, change: ProviderChange) {
264        let callbacks = self.availability_callbacks.read().await;
265        for callback in callbacks.iter() {
266            callback(change.clone());
267        }
268    }
269
270    /// Reload configuration without restart
271    pub async fn reload_configuration(&self) -> IdeResult<()> {
272        debug!("Reloading provider chain configuration");
273        // This will be implemented when configuration hot-reload is added
274        info!("Provider chain configuration reloaded");
275        Ok(())
276    }
277
278    /// Update configuration and refresh providers
279    pub async fn update_config(&self, config: IdeIntegrationConfig) -> IdeResult<()> {
280        debug!("Updating provider chain configuration");
281
282        // Update external LSP servers if configuration changed
283        if config.providers.external_lsp.enabled {
284            for language in config.providers.external_lsp.servers.keys() {
285                // Providers will be re-registered based on new configuration
286                debug!("Updated LSP configuration for language: {}", language);
287            }
288        }
289
290        info!("Provider chain configuration updated");
291        Ok(())
292    }
293
294    /// Get the provider registry
295    pub async fn registry(&self) -> tokio::sync::RwLockReadGuard<'_, ProviderRegistry> {
296        self.registry.read().await
297    }
298
299    /// Get mutable access to the provider registry
300    pub async fn registry_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, ProviderRegistry> {
301        self.registry.write().await
302    }
303
304    /// Check if a provider is available for a language
305    pub async fn is_provider_available(&self, language: &str) -> bool {
306        let registry = self.registry.read().await;
307        registry.is_provider_available(language)
308    }
309
310    /// Get all available languages
311    pub async fn available_languages(&self) -> Vec<String> {
312        let registry = self.registry.read().await;
313        registry.available_languages()
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Minimal provider used by the tests: one canned completion, nothing else.
    struct MockProvider {
        name: String,
        language: String,
    }

    #[async_trait]
    impl IdeProvider for MockProvider {
        async fn get_completions(&self, _params: &CompletionParams) -> IdeResult<Vec<CompletionItem>> {
            let item = CompletionItem {
                label: "test".to_string(),
                kind: CompletionItemKind::Function,
                detail: None,
                documentation: None,
                insert_text: "test()".to_string(),
            };
            Ok(vec![item])
        }

        async fn get_diagnostics(&self, _params: &DiagnosticsParams) -> IdeResult<Vec<Diagnostic>> {
            Ok(Vec::new())
        }

        async fn get_hover(&self, _params: &HoverParams) -> IdeResult<Option<Hover>> {
            Ok(None)
        }

        async fn get_definition(&self, _params: &DefinitionParams) -> IdeResult<Option<Location>> {
            Ok(None)
        }

        fn is_available(&self, language: &str) -> bool {
            language == self.language
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Build a shared mock provider with the given name and language.
    fn mock(name: &str, language: &str) -> Arc<dyn IdeProvider> {
        Arc::new(MockProvider {
            name: name.to_string(),
            language: language.to_string(),
        })
    }

    /// Registry containing only the generic fallback provider.
    fn empty_registry() -> ProviderRegistry {
        ProviderRegistry::new(mock("generic", "generic"))
    }

    #[test]
    fn test_provider_registry_creation() {
        let registry = empty_registry();
        assert_eq!(registry.available_languages().len(), 0);
    }

    #[test]
    fn test_register_lsp_provider() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));

        assert!(registry.is_provider_available("rust"));
        assert_eq!(registry.available_languages(), vec!["rust"]);
    }

    #[test]
    fn test_provider_priority_chain() {
        let mut registry = empty_registry();
        let lsp = mock("rust-analyzer", "rust");
        let builtin = mock("builtin", "rust");

        registry.register_lsp_provider("rust".to_string(), lsp.clone());
        registry.register_builtin_provider("rust".to_string(), builtin);

        // With both tiers populated, the LSP tier (priority 1) must win.
        let selected = registry.get_provider("rust");
        assert_eq!(selected.name(), "rust-analyzer");
    }

    #[test]
    fn test_provider_fallback_to_builtin() {
        let mut registry = empty_registry();
        registry.register_builtin_provider("rust".to_string(), mock("builtin", "rust"));

        // No LSP entry exists, so the built-in provider is chosen.
        let selected = registry.get_provider("rust");
        assert_eq!(selected.name(), "builtin");
    }

    #[test]
    fn test_provider_fallback_to_generic() {
        let registry = empty_registry();

        // Nothing is registered for this language, so the generic provider is used.
        let selected = registry.get_provider("unknown");
        assert_eq!(selected.name(), "generic");
    }

    #[test]
    fn test_unregister_lsp_provider() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));
        assert!(registry.is_provider_available("rust"));

        registry.unregister_lsp_provider("rust");
        assert!(!registry.is_provider_available("rust"));
    }

    #[tokio::test]
    async fn test_provider_chain_manager_creation() {
        let manager = ProviderChainManager::new(empty_registry());

        assert_eq!(manager.available_languages().await.len(), 0);
    }

    #[tokio::test]
    async fn test_provider_chain_manager_get_completions() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));
        let manager = ProviderChainManager::new(registry);

        let params = CompletionParams {
            language: "rust".to_string(),
            file_path: "src/main.rs".to_string(),
            position: Position {
                line: 10,
                character: 5,
            },
            context: "fn test".to_string(),
        };

        let result = manager.get_completions(&params).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_provider_availability_callback() {
        let manager = ProviderChainManager::new(empty_registry());

        let fired = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&fired);

        manager
            .on_provider_availability_changed(Box::new(move |_change| {
                flag.store(true, Ordering::SeqCst);
            }))
            .await;

        let change = ProviderChange {
            provider_name: "rust-analyzer".to_string(),
            language: "rust".to_string(),
            available: true,
        };

        manager.notify_provider_change(change).await;
        assert!(fired.load(Ordering::SeqCst));
    }
}
536}