1use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12 CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13 PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_wire::WireValue;
17use std::collections::HashMap;
18use std::path::Path;
19use std::sync::{Arc, RwLock};
20
21#[derive(Clone)]
46pub struct ProviderRegistry {
47 providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
49 default_provider: Arc<RwLock<Option<String>>>,
51 extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
53 extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
55 loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
57 extension_loader: Arc<RwLock<PluginLoader>>,
59 language_runtimes:
61 Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
62}
63
64impl ProviderRegistry {
65 pub fn new() -> Self {
67 Self {
68 providers: Arc::new(RwLock::new(HashMap::new())),
69 default_provider: Arc::new(RwLock::new(None)),
70 extension_sources: Arc::new(RwLock::new(HashMap::new())),
71 extension_modules: Arc::new(RwLock::new(HashMap::new())),
72 loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
73 extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
74 language_runtimes: Arc::new(RwLock::new(HashMap::new())),
75 }
76 }
77
78 pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
91 let mut providers = self.providers.write().unwrap();
92 providers.insert(name.to_string(), provider);
93 }
94
95 pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
105 let providers = self.providers.read().unwrap();
106 providers.get(name).cloned()
107 }
108
109 pub fn set_default(&self, name: &str) -> Result<()> {
119 let providers = self.providers.read().unwrap();
120 if !providers.contains_key(name) {
121 return Err(ShapeError::RuntimeError {
122 message: format!("Cannot set default provider: '{}' is not registered", name),
123 location: None,
124 });
125 }
126 drop(providers);
127
128 let mut default = self.default_provider.write().unwrap();
129 *default = Some(name.to_string());
130 Ok(())
131 }
132
133 pub fn get_default(&self) -> Option<SharedAsyncProvider> {
139 let default = self.default_provider.read().unwrap();
140 let name = default.as_ref().cloned();
141 drop(default);
142
143 name.and_then(|n| self.get(&n))
144 }
145
146 pub fn default_name(&self) -> Option<String> {
148 let default = self.default_provider.read().unwrap();
149 default.clone()
150 }
151
152 pub fn list_providers(&self) -> Vec<String> {
158 let providers = self.providers.read().unwrap();
159 providers.keys().cloned().collect()
160 }
161
162 pub fn has_provider(&self, name: &str) -> bool {
164 let providers = self.providers.read().unwrap();
165 providers.contains_key(name)
166 }
167
168 pub fn unregister(&self, name: &str) -> bool {
178 let mut providers = self.providers.write().unwrap();
179 let removed = providers.remove(name).is_some();
180
181 if removed {
183 let mut default = self.default_provider.write().unwrap();
184 if default.as_ref().map(|s| s == name).unwrap_or(false) {
185 *default = None;
186 }
187 }
188
189 removed
190 }
191
192 pub fn clear(&self) {
194 let mut providers = self.providers.write().unwrap();
195 providers.clear();
196
197 let mut default = self.default_provider.write().unwrap();
198 *default = None;
199
200 let mut extension_sources = self.extension_sources.write().unwrap();
201 extension_sources.clear();
202
203 let mut extension_modules = self.extension_modules.write().unwrap();
204 extension_modules.clear();
205
206 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
207 loaded_extensions.clear();
208
209 let mut runtimes = self.language_runtimes.write().unwrap();
210 runtimes.clear();
211 }
212
213 pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
232 let mut loader = self.extension_loader.write().unwrap();
234 let loaded_info = loader.load(path)?;
235 let name = loaded_info.name.clone();
236
237 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
240 let vtable = loader.get_data_source_vtable(&name)?;
241 let source = PluginDataSource::new(name.clone(), vtable, config)?;
242
243 let mut sources = self.extension_sources.write().unwrap();
244 sources.insert(name.clone(), Arc::new(source));
245 } else {
246 let mut sources = self.extension_sources.write().unwrap();
249 sources.remove(&name);
250 }
251
252 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
256 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
257 let mut modules = self.extension_modules.write().unwrap();
258 modules.insert(name.clone(), Arc::new(module));
259 }
260 }
261
262 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
264 let vtable = loader.get_language_runtime_vtable(&name)?;
265 let runtime =
266 crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
267 let lang_id = runtime.language_id().to_string();
268 let mut runtimes = self.language_runtimes.write().unwrap();
269 runtimes.insert(lang_id, Arc::new(runtime));
270 }
271
272 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
273 loaded_extensions.insert(name, loaded_info.clone());
274
275 Ok(loaded_info)
276 }
277
278 pub fn load_extension_with_sections(
284 &self,
285 path: &Path,
286 config: &serde_json::Value,
287 extension_sections: &std::collections::HashMap<String, toml::Value>,
288 all_claimed: &mut std::collections::HashSet<String>,
289 ) -> Result<LoadedPlugin> {
290 let mut loader = self.extension_loader.write().unwrap();
292 let loaded_info = loader.load(path)?;
293 let name = loaded_info.name.clone();
294
295 for claim in &loaded_info.claimed_sections {
297 if !all_claimed.insert(claim.name.clone()) {
298 return Err(ShapeError::RuntimeError {
299 message: format!(
300 "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
301 claim.name, name
302 ),
303 location: None,
304 });
305 }
306 }
307
308 let mut merged_config = config.clone();
311 if let serde_json::Value::Object(ref mut map) = merged_config {
312 for claim in &loaded_info.claimed_sections {
313 if let Some(section_value) = extension_sections.get(&claim.name) {
314 let json_value = crate::project::toml_to_json(section_value);
315 map.insert(claim.name.clone(), json_value);
316 } else if claim.required {
317 return Err(ShapeError::RuntimeError {
318 message: format!(
319 "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
320 name, claim.name
321 ),
322 location: None,
323 });
324 }
325 }
326 }
327
328 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
330 let vtable = loader.get_data_source_vtable(&name)?;
331 let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
332 let mut sources = self.extension_sources.write().unwrap();
333 sources.insert(name.clone(), Arc::new(source));
334 } else {
335 let mut sources = self.extension_sources.write().unwrap();
336 sources.remove(&name);
337 }
338
339 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
340 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
341 let mut modules = self.extension_modules.write().unwrap();
342 modules.insert(name.clone(), Arc::new(module));
343 }
344 }
345
346 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
347 let vtable = loader.get_language_runtime_vtable(&name)?;
348 let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
349 vtable,
350 &merged_config,
351 )?;
352 let lang_id = runtime.language_id().to_string();
353 let mut runtimes = self.language_runtimes.write().unwrap();
354 runtimes.insert(lang_id, Arc::new(runtime));
355 }
356
357 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
358 loaded_extensions.insert(name, loaded_info.clone());
359
360 Ok(loaded_info)
361 }
362
363 pub fn get_language_runtime(
365 &self,
366 language_id: &str,
367 ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
368 let runtimes = self.language_runtimes.read().unwrap();
369 runtimes.get(language_id).cloned()
370 }
371
372 pub fn language_runtimes(
374 &self,
375 ) -> std::collections::HashMap<
376 String,
377 Arc<crate::plugins::language_runtime::PluginLanguageRuntime>,
378 > {
379 let runtimes = self.language_runtimes.read().unwrap();
380 runtimes.clone()
381 }
382
383 pub fn language_runtime_lsp_configs(
385 &self,
386 ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
387 let runtimes = self.language_runtimes.read().unwrap();
388 let mut configs = Vec::new();
389
390 for runtime in runtimes.values() {
391 match runtime.lsp_config() {
392 Ok(Some(config)) => configs.push(config),
393 Ok(None) => {}
394 Err(err) => {
395 tracing::warn!("failed to query language runtime LSP config: {}", err);
396 }
397 }
398 }
399
400 configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
401 configs
402 }
403
404 pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
414 let sources = self.extension_sources.read().unwrap();
415 sources.get(name).cloned()
416 }
417
418 pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
420 let modules = self.extension_modules.read().unwrap();
421 modules
422 .values()
423 .find(|m| m.schema().module_name == module_name)
424 .map(|m| m.schema().clone())
425 }
426
427 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
434 Vec::new()
435 }
436
437 pub fn invoke_extension_module_wire(
439 &self,
440 module_name: &str,
441 function: &str,
442 args: &[WireValue],
443 ) -> Result<WireValue> {
444 let modules = self.extension_modules.read().unwrap();
445 let module = modules
446 .values()
447 .find(|m| m.schema().module_name == module_name)
448 .ok_or_else(|| ShapeError::RuntimeError {
449 message: format!("Module namespace '{}' is not loaded", module_name),
450 location: None,
451 })?;
452 module.invoke_wire(function, args)
453 }
454
455 pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
465 let sources = self.extension_sources.read().unwrap();
466 sources.get(name).map(|s| s.get_query_schema().clone())
467 }
468
469 pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
479 let sources = self.extension_sources.read().unwrap();
480 sources.get(name).map(|s| s.get_output_schema().clone())
481 }
482
483 pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
489 let sources = self.extension_sources.read().unwrap();
490 sources
491 .iter()
492 .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
493 .collect()
494 }
495
496 pub fn list_extensions(&self) -> Vec<String> {
498 let loaded = self.loaded_extensions.read().unwrap();
499 loaded.keys().cloned().collect()
500 }
501
502 pub fn has_extension(&self, name: &str) -> bool {
504 let loaded = self.loaded_extensions.read().unwrap();
505 loaded.contains_key(name)
506 }
507
508 pub fn unload_extension(&self, name: &str) -> bool {
518 let mut sources = self.extension_sources.write().unwrap();
519 let removed_source = sources.remove(name).is_some();
520 drop(sources);
521
522 let mut modules = self.extension_modules.write().unwrap();
523 let removed_module = modules.remove(name).is_some();
524 drop(modules);
525
526 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
527 let removed_plugin = loaded_extensions.remove(name).is_some();
528 drop(loaded_extensions);
529
530 if removed_plugin {
531 let mut loader = self.extension_loader.write().unwrap();
532 loader.unload(name);
533 }
534
535 removed_plugin || removed_source || removed_module
536 }
537}
538
539impl Default for ProviderRegistry {
540 fn default() -> Self {
541 Self::new()
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    #[test]
    fn test_register_and_get() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("test", provider.clone());

        assert!(registry.has_provider("test"));
        assert!(!registry.has_provider("nonexistent"));
        assert!(registry.get("test").is_some());
    }

    #[test]
    fn test_register_replaces_existing() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("test", provider.clone());
        registry.register("test", provider);

        assert_eq!(registry.list_providers(), vec!["test".to_string()]);
    }

    #[test]
    fn test_default_provider() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("test", provider);

        assert!(registry.set_default("test").is_ok());
        assert!(registry.get_default().is_some());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_default_unset() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_default().is_none());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_set_default_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(registry.set_default("nonexistent").is_err());
        // A failed set_default must not record a default.
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_list_providers() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("test1", provider.clone());
        registry.register("test2", provider);

        let mut names = registry.list_providers();
        names.sort();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    #[test]
    fn test_unregister() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("test", provider);
        registry.set_default("test").unwrap();

        assert!(registry.unregister("test"));
        assert!(!registry.has_provider("test"));
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_unregister_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(!registry.unregister("nonexistent"));
    }

    #[test]
    fn test_unregister_non_default_keeps_default() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("keep", provider.clone());
        registry.register("drop", provider);
        registry.set_default("keep").unwrap();

        assert!(registry.unregister("drop"));
        assert_eq!(registry.default_name(), Some("keep".to_string()));
        assert!(registry.get_default().is_some());
    }

    #[test]
    fn test_clear() {
        let registry = ProviderRegistry::new();
        let provider = Arc::new(NullAsyncProvider) as SharedAsyncProvider;

        registry.register("test1", provider.clone());
        registry.register("test2", provider);
        registry.set_default("test1").unwrap();

        registry.clear();

        assert_eq!(registry.list_providers().len(), 0);
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_plugin_not_loaded_by_default() {
        let registry = ProviderRegistry::new();

        assert!(!registry.has_extension("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_list_extensions_empty() {
        let registry = ProviderRegistry::new();

        let plugins = registry.list_extensions();
        assert!(plugins.is_empty());
    }

    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let registry = ProviderRegistry::new();

        let schemas = registry.list_extensions_with_schemas();
        assert!(schemas.is_empty());
    }

    #[test]
    fn test_get_extension_query_schema_not_found() {
        let registry = ProviderRegistry::new();

        let schema = registry.get_extension_query_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_get_extension_output_schema_not_found() {
        let registry = ProviderRegistry::new();

        let schema = registry.get_extension_output_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_unload_plugin_not_loaded() {
        let registry = ProviderRegistry::new();

        assert!(!registry.unload_extension("nonexistent"));
    }

    #[test]
    fn test_clear_removes_plugins() {
        let registry = ProviderRegistry::new();

        registry.clear();

        assert!(registry.list_extensions().is_empty());
    }
}