1use crate::error::IdeResult;
10use crate::provider::{IdeProvider, ProviderChange};
11use crate::types::*;
12use std::collections::HashMap;
13use std::sync::Arc;
14use tracing::{debug, info, warn};
15
16type ProviderAvailabilityCallback = Box<dyn Fn(ProviderChange) + Send + Sync>;
18
19pub struct ProviderRegistry {
21 lsp_providers: HashMap<String, Arc<dyn IdeProvider>>,
23 configured_providers: HashMap<String, Arc<dyn IdeProvider>>,
25 builtin_providers: HashMap<String, Arc<dyn IdeProvider>>,
27 generic_provider: Arc<dyn IdeProvider>,
29}
30
31impl ProviderRegistry {
32 pub fn new(generic_provider: Arc<dyn IdeProvider>) -> Self {
34 ProviderRegistry {
35 lsp_providers: HashMap::new(),
36 configured_providers: HashMap::new(),
37 builtin_providers: HashMap::new(),
38 generic_provider,
39 }
40 }
41
42 pub fn register_lsp_provider(&mut self, language: String, provider: Arc<dyn IdeProvider>) {
44 debug!("Registering LSP provider for language: {}", language);
45 self.lsp_providers.insert(language, provider);
46 }
47
48 pub fn register_configured_provider(
50 &mut self,
51 language: String,
52 provider: Arc<dyn IdeProvider>,
53 ) {
54 debug!("Registering configured rules provider for language: {}", language);
55 self.configured_providers.insert(language, provider);
56 }
57
58 pub fn register_builtin_provider(&mut self, language: String, provider: Arc<dyn IdeProvider>) {
60 debug!("Registering built-in provider for language: {}", language);
61 self.builtin_providers.insert(language, provider);
62 }
63
64 pub fn get_provider(&self, language: &str) -> Arc<dyn IdeProvider> {
66 if let Some(provider) = self.lsp_providers.get(language) {
68 debug!("Using LSP provider for language: {}", language);
69 return provider.clone();
70 }
71
72 if let Some(provider) = self.configured_providers.get(language) {
74 debug!("Using configured rules provider for language: {}", language);
75 return provider.clone();
76 }
77
78 if let Some(provider) = self.builtin_providers.get(language) {
80 debug!("Using built-in provider for language: {}", language);
81 return provider.clone();
82 }
83
84 debug!("Using generic fallback provider for language: {}", language);
86 self.generic_provider.clone()
87 }
88
89 pub fn is_provider_available(&self, language: &str) -> bool {
91 self.lsp_providers.contains_key(language)
92 || self.configured_providers.contains_key(language)
93 || self.builtin_providers.contains_key(language)
94 }
95
96 pub fn available_languages(&self) -> Vec<String> {
98 let mut languages = Vec::new();
99 languages.extend(self.lsp_providers.keys().cloned());
100 languages.extend(self.configured_providers.keys().cloned());
101 languages.extend(self.builtin_providers.keys().cloned());
102 languages.sort();
103 languages.dedup();
104 languages
105 }
106
107 pub fn unregister_lsp_provider(&mut self, language: &str) {
109 debug!("Unregistering LSP provider for language: {}", language);
110 self.lsp_providers.remove(language);
111 }
112
113 pub fn unregister_configured_provider(&mut self, language: &str) {
115 debug!("Unregistering configured rules provider for language: {}", language);
116 self.configured_providers.remove(language);
117 }
118
119 pub fn unregister_builtin_provider(&mut self, language: &str) {
121 debug!("Unregistering built-in provider for language: {}", language);
122 self.builtin_providers.remove(language);
123 }
124}
125
126pub struct ProviderChainManager {
128 registry: Arc<tokio::sync::RwLock<ProviderRegistry>>,
129 availability_callbacks: Arc<tokio::sync::RwLock<Vec<ProviderAvailabilityCallback>>>,
130}
131
132impl ProviderChainManager {
133 pub fn new(registry: ProviderRegistry) -> Self {
135 ProviderChainManager {
136 registry: Arc::new(tokio::sync::RwLock::new(registry)),
137 availability_callbacks: Arc::new(tokio::sync::RwLock::new(Vec::new())),
138 }
139 }
140
141 pub async fn get_completions(&self, params: &CompletionParams) -> IdeResult<Vec<CompletionItem>> {
143 debug!(
144 "Getting completions for language: {} through provider chain",
145 params.language
146 );
147
148 let registry = self.registry.read().await;
149 let provider = registry.get_provider(¶ms.language);
150
151 match provider.get_completions(params).await {
152 Ok(completions) => {
153 info!(
154 "Successfully got {} completions for language: {}",
155 completions.len(),
156 params.language
157 );
158 Ok(completions)
159 }
160 Err(e) => {
161 warn!(
162 "Failed to get completions for language: {}: {}",
163 params.language, e
164 );
165 Err(e)
166 }
167 }
168 }
169
170 pub async fn get_diagnostics(&self, params: &DiagnosticsParams) -> IdeResult<Vec<Diagnostic>> {
172 debug!(
173 "Getting diagnostics for language: {} through provider chain",
174 params.language
175 );
176
177 let registry = self.registry.read().await;
178 let provider = registry.get_provider(¶ms.language);
179
180 match provider.get_diagnostics(params).await {
181 Ok(diagnostics) => {
182 info!(
183 "Successfully got {} diagnostics for language: {}",
184 diagnostics.len(),
185 params.language
186 );
187 Ok(diagnostics)
188 }
189 Err(e) => {
190 warn!(
191 "Failed to get diagnostics for language: {}: {}",
192 params.language, e
193 );
194 Err(e)
195 }
196 }
197 }
198
199 pub async fn get_hover(&self, params: &HoverParams) -> IdeResult<Option<Hover>> {
201 debug!(
202 "Getting hover information for language: {} through provider chain",
203 params.language
204 );
205
206 let registry = self.registry.read().await;
207 let provider = registry.get_provider(¶ms.language);
208
209 match provider.get_hover(params).await {
210 Ok(hover) => {
211 if hover.is_some() {
212 info!("Successfully got hover information for language: {}", params.language);
213 }
214 Ok(hover)
215 }
216 Err(e) => {
217 warn!(
218 "Failed to get hover information for language: {}: {}",
219 params.language, e
220 );
221 Err(e)
222 }
223 }
224 }
225
226 pub async fn get_definition(&self, params: &DefinitionParams) -> IdeResult<Option<Location>> {
228 debug!(
229 "Getting definition for language: {} through provider chain",
230 params.language
231 );
232
233 let registry = self.registry.read().await;
234 let provider = registry.get_provider(¶ms.language);
235
236 match provider.get_definition(params).await {
237 Ok(location) => {
238 if location.is_some() {
239 info!("Successfully got definition for language: {}", params.language);
240 }
241 Ok(location)
242 }
243 Err(e) => {
244 warn!(
245 "Failed to get definition for language: {}: {}",
246 params.language, e
247 );
248 Err(e)
249 }
250 }
251 }
252
253 pub async fn on_provider_availability_changed(
255 &self,
256 callback: Box<dyn Fn(ProviderChange) + Send + Sync>,
257 ) {
258 let mut callbacks = self.availability_callbacks.write().await;
259 callbacks.push(callback);
260 }
261
262 pub async fn notify_provider_change(&self, change: ProviderChange) {
264 let callbacks = self.availability_callbacks.read().await;
265 for callback in callbacks.iter() {
266 callback(change.clone());
267 }
268 }
269
270 pub async fn reload_configuration(&self) -> IdeResult<()> {
272 debug!("Reloading provider chain configuration");
273 info!("Provider chain configuration reloaded");
275 Ok(())
276 }
277
278 pub async fn update_config(&self, config: IdeIntegrationConfig) -> IdeResult<()> {
280 debug!("Updating provider chain configuration");
281
282 if config.providers.external_lsp.enabled {
284 for language in config.providers.external_lsp.servers.keys() {
285 debug!("Updated LSP configuration for language: {}", language);
287 }
288 }
289
290 info!("Provider chain configuration updated");
291 Ok(())
292 }
293
294 pub async fn registry(&self) -> tokio::sync::RwLockReadGuard<'_, ProviderRegistry> {
296 self.registry.read().await
297 }
298
299 pub async fn registry_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, ProviderRegistry> {
301 self.registry.write().await
302 }
303
304 pub async fn is_provider_available(&self, language: &str) -> bool {
306 let registry = self.registry.read().await;
307 registry.is_provider_available(language)
308 }
309
310 pub async fn available_languages(&self) -> Vec<String> {
312 let registry = self.registry.read().await;
313 registry.available_languages()
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal provider used to observe which registry entry was selected.
    struct MockProvider {
        /// Value returned by `name()`; the tests use it to identify the provider.
        name: String,
        /// The only language for which `is_available` returns `true`.
        language: String,
    }

    #[async_trait]
    impl IdeProvider for MockProvider {
        async fn get_completions(&self, _params: &CompletionParams) -> IdeResult<Vec<CompletionItem>> {
            Ok(vec![CompletionItem {
                label: "test".to_string(),
                kind: CompletionItemKind::Function,
                detail: None,
                documentation: None,
                insert_text: "test()".to_string(),
            }])
        }

        async fn get_diagnostics(&self, _params: &DiagnosticsParams) -> IdeResult<Vec<Diagnostic>> {
            Ok(vec![])
        }

        async fn get_hover(&self, _params: &HoverParams) -> IdeResult<Option<Hover>> {
            Ok(None)
        }

        async fn get_definition(&self, _params: &DefinitionParams) -> IdeResult<Option<Location>> {
            Ok(None)
        }

        fn is_available(&self, language: &str) -> bool {
            language == self.language
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Builds a shared mock provider with the given name and language.
    fn mock(name: &str, language: &str) -> Arc<MockProvider> {
        Arc::new(MockProvider {
            name: name.to_string(),
            language: language.to_string(),
        })
    }

    /// Builds a registry whose only provider is the generic fallback mock.
    fn empty_registry() -> ProviderRegistry {
        ProviderRegistry::new(mock("generic", "generic"))
    }

    #[test]
    fn test_provider_registry_creation() {
        let registry = empty_registry();
        assert_eq!(registry.available_languages().len(), 0);
    }

    #[test]
    fn test_register_lsp_provider() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));

        assert!(registry.is_provider_available("rust"));
        assert_eq!(registry.available_languages(), vec!["rust"]);
    }

    #[test]
    fn test_provider_priority_chain() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));
        registry.register_builtin_provider("rust".to_string(), mock("builtin", "rust"));

        // The LSP tier outranks the built-in tier.
        let provider = registry.get_provider("rust");
        assert_eq!(provider.name(), "rust-analyzer");
    }

    #[test]
    fn test_provider_fallback_to_builtin() {
        let mut registry = empty_registry();
        registry.register_builtin_provider("rust".to_string(), mock("builtin", "rust"));

        let provider = registry.get_provider("rust");
        assert_eq!(provider.name(), "builtin");
    }

    #[test]
    fn test_provider_fallback_to_generic() {
        let registry = empty_registry();

        let provider = registry.get_provider("unknown");
        assert_eq!(provider.name(), "generic");
    }

    #[test]
    fn test_unregister_lsp_provider() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));
        assert!(registry.is_provider_available("rust"));

        registry.unregister_lsp_provider("rust");
        assert!(!registry.is_provider_available("rust"));
    }

    #[tokio::test]
    async fn test_provider_chain_manager_creation() {
        let manager = ProviderChainManager::new(empty_registry());

        assert_eq!(manager.available_languages().await.len(), 0);
    }

    #[tokio::test]
    async fn test_provider_chain_manager_get_completions() {
        let mut registry = empty_registry();
        registry.register_lsp_provider("rust".to_string(), mock("rust-analyzer", "rust"));

        let manager = ProviderChainManager::new(registry);

        let params = CompletionParams {
            language: "rust".to_string(),
            file_path: "src/main.rs".to_string(),
            position: Position {
                line: 10,
                character: 5,
            },
            context: "fn test".to_string(),
        };

        let result = manager.get_completions(&params).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_provider_availability_callback() {
        let manager = ProviderChainManager::new(empty_registry());

        let called = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let called_clone = called.clone();

        manager
            .on_provider_availability_changed(Box::new(move |_change| {
                called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
            }))
            .await;

        let change = ProviderChange {
            provider_name: "rust-analyzer".to_string(),
            language: "rust".to_string(),
            available: true,
        };

        manager.notify_provider_change(change).await;
        assert!(called.load(std::sync::atomic::Ordering::SeqCst));
    }
}