1use std::collections::HashMap;
6use std::ffi::CStr;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9
10use libloading::{Library, Symbol};
11
12use shape_abi_v1::{
13 ABI_VERSION, CAPABILITY_DATA_SOURCE, CAPABILITY_LANGUAGE_RUNTIME, CAPABILITY_MODULE,
14 CAPABILITY_OUTPUT_SINK, CapabilityKind, CapabilityManifest, DataSourceVTable, GetAbiVersionFn,
15 GetCapabilityManifestFn, GetCapabilityVTableFn, GetClaimedSectionsFn, GetPluginInfoFn,
16 LanguageRuntimeVTable, ModuleVTable, OutputSinkVTable, PluginType, SectionsManifest,
17};
18
19use shape_ast::error::{Result, ShapeError};
20
#[derive(Debug, Clone, PartialEq, Eq)]
/// A section name claimed by a plugin, parsed from the `SectionsManifest` returned by the
/// plugin's optional `shape_claimed_sections` export.
pub struct ClaimedSection {
    /// Section name, copied (lossily, as UTF-8) from the plugin's `SectionClaim.name` C string.
    pub name: String,
    /// Copied verbatim from `SectionClaim.required`; how it is enforced is up to the consumer.
    pub required: bool,
}
29
#[derive(Debug, Clone, PartialEq, Eq)]
/// One capability entry from a plugin's `CapabilityManifest`, converted to owned Rust data.
pub struct PluginCapability {
    /// Capability kind, copied from `CapabilityDescriptor.kind`.
    pub kind: CapabilityKind,
    /// Contract identifier (e.g. `"shape.module"`), copied from `CapabilityDescriptor.contract`.
    pub contract: String,
    /// Version string, copied from `CapabilityDescriptor.version`.
    pub version: String,
    /// Raw flag bits from `CapabilityDescriptor.flags`; not interpreted by the loader.
    pub flags: u64,
}
42
#[derive(Debug, Clone)]
/// Metadata describing a plugin after a successful [`PluginLoader::load`].
pub struct LoadedPlugin {
    /// Plugin name from `PluginInfo.name`; also the key under which the loader registers the
    /// library handle.
    pub name: String,
    /// Version string from `PluginInfo.version`.
    pub version: String,
    /// Plugin type copied from `PluginInfo.plugin_type`.
    pub plugin_type: PluginType,
    /// Description from `PluginInfo.description`.
    pub description: String,
    /// Capabilities parsed from the plugin's required `shape_capability_manifest` export.
    pub capabilities: Vec<PluginCapability>,
    /// Sections claimed through the optional `shape_claimed_sections` export; empty when the
    /// export is absent or returns null.
    pub claimed_sections: Vec<ClaimedSection>,
}
59
60impl LoadedPlugin {
61 pub fn has_capability_kind(&self, kind: CapabilityKind) -> bool {
63 self.capabilities.iter().any(|cap| cap.kind == kind)
64 }
65
66 pub fn claimed_section_names(&self) -> Vec<&str> {
68 self.claimed_sections
69 .iter()
70 .map(|s| s.name.as_str())
71 .collect()
72 }
73}
74
/// Opens plugin shared libraries and keeps their handles alive, keyed by plugin name.
pub struct PluginLoader {
    /// Open library handles keyed by the name each plugin reports in its `PluginInfo`.
    loaded_libraries: HashMap<String, Library>,
}
83
84impl PluginLoader {
85 pub fn new() -> Self {
87 Self {
88 loaded_libraries: HashMap::new(),
89 }
90 }
91
92 pub fn load(&mut self, path: &Path) -> Result<LoadedPlugin> {
103 let lib =
105 load_library_with_python_fallback(path).map_err(|e| ShapeError::RuntimeError {
106 message: format!("Failed to load plugin library '{}': {}", path.display(), e),
107 location: None,
108 })?;
109
110 if let Ok(get_version) = unsafe { lib.get::<GetAbiVersionFn>(b"shape_abi_version") } {
112 let version = unsafe { get_version() };
113 if version != ABI_VERSION {
114 return Err(ShapeError::RuntimeError {
115 message: format!(
116 "Plugin ABI version mismatch: expected {}, got {}",
117 ABI_VERSION, version
118 ),
119 location: None,
120 });
121 }
122 }
123
124 let get_info: Symbol<GetPluginInfoFn> = unsafe {
126 lib.get(b"shape_plugin_info")
127 .map_err(|e| ShapeError::RuntimeError {
128 message: format!("Plugin missing 'shape_plugin_info' export: {}", e),
129 location: None,
130 })?
131 };
132
133 let info_ptr = unsafe { get_info() };
134 if info_ptr.is_null() {
135 return Err(ShapeError::RuntimeError {
136 message: "Plugin returned null PluginInfo".to_string(),
137 location: None,
138 });
139 }
140
141 let info = unsafe { &*info_ptr };
142
143 let name = read_c_string(info.name, "PluginInfo.name")?;
145 let version = read_c_string(info.version, "PluginInfo.version")?;
146 let description = read_c_string(info.description, "PluginInfo.description")?;
147
148 let capabilities = self.load_capabilities(&lib)?;
149
150 let claimed_sections = if let Ok(get_sections) =
152 unsafe { lib.get::<GetClaimedSectionsFn>(b"shape_claimed_sections") }
153 {
154 let manifest_ptr = unsafe { get_sections() };
155 if manifest_ptr.is_null() {
156 vec![]
157 } else {
158 let manifest = unsafe { &*manifest_ptr };
159 parse_sections_manifest(manifest)?
160 }
161 } else {
162 vec![] };
164
165 self.loaded_libraries.insert(name.clone(), lib);
167
168 Ok(LoadedPlugin {
169 name,
170 version,
171 plugin_type: info.plugin_type,
172 description,
173 capabilities,
174 claimed_sections,
175 })
176 }
177
178 fn load_capabilities(&self, lib: &Library) -> Result<Vec<PluginCapability>> {
179 let get_manifest =
180 unsafe { lib.get::<GetCapabilityManifestFn>(b"shape_capability_manifest") }.map_err(
181 |e| ShapeError::RuntimeError {
182 message: format!(
183 "Plugin missing required 'shape_capability_manifest' export: {}",
184 e
185 ),
186 location: None,
187 },
188 )?;
189
190 let manifest_ptr = unsafe { get_manifest() };
191 if manifest_ptr.is_null() {
192 return Err(ShapeError::RuntimeError {
193 message: "Plugin returned null CapabilityManifest".to_string(),
194 location: None,
195 });
196 }
197 let manifest = unsafe { &*manifest_ptr };
198 parse_capability_manifest(manifest)
199 }
200
201 pub fn get_data_source_vtable(&self, name: &str) -> Result<&'static DataSourceVTable> {
209 let lib = self
210 .loaded_libraries
211 .get(name)
212 .ok_or_else(|| ShapeError::RuntimeError {
213 message: format!("Plugin '{}' not loaded", name),
214 location: None,
215 })?;
216
217 if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_DATA_SOURCE)? {
218 return Ok(unsafe { &*(vtable_ptr as *const DataSourceVTable) });
220 }
221
222 Err(ShapeError::RuntimeError {
223 message: format!(
224 "Plugin '{}' does not provide capability vtable for '{}'",
225 name, CAPABILITY_DATA_SOURCE
226 ),
227 location: None,
228 })
229 }
230
231 pub fn get_output_sink_vtable(&self, name: &str) -> Result<&'static OutputSinkVTable> {
239 let lib = self
240 .loaded_libraries
241 .get(name)
242 .ok_or_else(|| ShapeError::RuntimeError {
243 message: format!("Plugin '{}' not loaded", name),
244 location: None,
245 })?;
246
247 if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_OUTPUT_SINK)? {
248 return Ok(unsafe { &*(vtable_ptr as *const OutputSinkVTable) });
250 }
251
252 Err(ShapeError::RuntimeError {
253 message: format!(
254 "Plugin '{}' does not provide capability vtable for '{}'",
255 name, CAPABILITY_OUTPUT_SINK
256 ),
257 location: None,
258 })
259 }
260
261 pub fn get_module_vtable(&self, name: &str) -> Result<&'static ModuleVTable> {
263 let lib = self
264 .loaded_libraries
265 .get(name)
266 .ok_or_else(|| ShapeError::RuntimeError {
267 message: format!("Plugin '{}' not loaded", name),
268 location: None,
269 })?;
270
271 if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_MODULE)? {
272 return Ok(unsafe { &*(vtable_ptr as *const ModuleVTable) });
274 }
275
276 Err(ShapeError::RuntimeError {
277 message: format!(
278 "Plugin '{}' does not provide capability vtable for '{}'",
279 name, CAPABILITY_MODULE
280 ),
281 location: None,
282 })
283 }
284
285 pub fn get_language_runtime_vtable(
287 &self,
288 name: &str,
289 ) -> Result<&'static LanguageRuntimeVTable> {
290 let lib = self
291 .loaded_libraries
292 .get(name)
293 .ok_or_else(|| ShapeError::RuntimeError {
294 message: format!("Plugin '{}' not loaded", name),
295 location: None,
296 })?;
297
298 if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_LANGUAGE_RUNTIME)? {
299 return Ok(unsafe { &*(vtable_ptr as *const LanguageRuntimeVTable) });
300 }
301
302 Err(ShapeError::RuntimeError {
303 message: format!(
304 "Plugin '{}' does not provide capability vtable for '{}'",
305 name, CAPABILITY_LANGUAGE_RUNTIME
306 ),
307 location: None,
308 })
309 }
310
311 pub fn unload(&mut self, name: &str) -> bool {
316 self.loaded_libraries.remove(name).is_some()
317 }
318
319 pub fn loaded_plugins(&self) -> Vec<&str> {
321 self.loaded_libraries.keys().map(|s| s.as_str()).collect()
322 }
323
324 pub fn is_loaded(&self, name: &str) -> bool {
326 self.loaded_libraries.contains_key(name)
327 }
328
329 pub fn load_data_source(
341 &mut self,
342 path: &Path,
343 config: &serde_json::Value,
344 ) -> Result<super::PluginDataSource> {
345 let info = self.load(path)?;
347 let name = info.name.clone();
348
349 if !info.has_capability_kind(CapabilityKind::DataSource) {
350 return Err(ShapeError::RuntimeError {
351 message: format!(
352 "Plugin '{}' does not declare data source capability",
353 info.name
354 ),
355 location: None,
356 });
357 }
358
359 let vtable = self.get_data_source_vtable(&name)?;
361
362 super::PluginDataSource::new(name, vtable, config)
364 }
365}
366
367fn load_library_with_python_fallback(path: &Path) -> std::result::Result<Library, String> {
368 let initial = unsafe { Library::new(path) };
369 let initial_error = match initial {
370 Ok(lib) => return Ok(lib),
371 Err(err) => err,
372 };
373 let initial_msg = initial_error.to_string();
374
375 if !should_try_python_fallback(&initial_msg) {
376 return Err(initial_msg);
377 }
378
379 if !preload_python_shared_library() {
380 return Err(initial_msg);
381 }
382
383 match unsafe { Library::new(path) } {
384 Ok(lib) => Ok(lib),
385 Err(retry_err) => Err(format!(
386 "{} (retry after python preload failed: {})",
387 initial_msg, retry_err
388 )),
389 }
390}
391
/// Heuristically decides whether a library-load error was caused by a missing Python runtime.
///
/// Matching is ASCII case-insensitive against `libpython` (Linux-style sonames) and
/// `python.framework` (macOS framework paths).
fn should_try_python_fallback(error_message: &str) -> bool {
    const PYTHON_MARKERS: [&str; 2] = ["libpython", "python.framework"];
    let normalized = error_message.to_ascii_lowercase();
    PYTHON_MARKERS
        .iter()
        .any(|&marker| normalized.contains(marker))
}
396
397fn preload_python_shared_library() -> bool {
398 let candidates = discover_python_shared_library_candidates();
399 for candidate in candidates {
400 match unsafe { Library::new(&candidate) } {
401 Ok(lib) => {
402 tracing::info!(
403 "preloaded python runtime library for extension loading fallback: {}",
404 candidate.display()
405 );
406 std::mem::forget(lib);
408 return true;
409 }
410 Err(err) => {
411 tracing::debug!(
412 "failed to preload python runtime candidate '{}': {}",
413 candidate.display(),
414 err
415 );
416 }
417 }
418 }
419 false
420}
421
/// Asks a Python interpreter for the on-disk paths of its shared runtime library.
///
/// The interpreter is taken from `PYO3_PYTHON`, defaulting to `python3`. The probe script
/// prints each existing, de-duplicated (by real path) candidate on its own line. Any failure
/// to spawn the interpreter, or a non-zero exit status, yields an empty list.
fn discover_python_shared_library_candidates() -> Vec<PathBuf> {
    const PROBE_SCRIPT: &str = r#"import os, sys, sysconfig
cands = []
libdir = sysconfig.get_config_var("LIBDIR")
ldlibrary = sysconfig.get_config_var("LDLIBRARY")
if libdir and ldlibrary:
    cands.append(os.path.join(libdir, ldlibrary))
if libdir:
    for name in ("libpython3.so", "libpython3.so.1.0", "libpython3.dylib"):
        cands.append(os.path.join(libdir, name))
for base in {sys.base_prefix, sys.prefix}:
    if not base:
        continue
    for rel in ("lib", "lib64"):
        d = os.path.join(base, rel)
        if ldlibrary:
            cands.append(os.path.join(d, ldlibrary))
seen = set()
for cand in cands:
    if not cand:
        continue
    real = os.path.realpath(cand)
    if real in seen:
        continue
    seen.add(real)
    if os.path.exists(real):
        print(real)
"#;

    let interpreter =
        std::env::var("PYO3_PYTHON").unwrap_or_else(|_| String::from("python3"));

    let output = match Command::new(&interpreter)
        .arg("-c")
        .arg(PROBE_SCRIPT)
        .output()
    {
        Ok(output) if output.status.success() => output,
        _ => return Vec::new(),
    };

    let stdout = String::from_utf8_lossy(&output.stdout);
    stdout
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(PathBuf::from)
        .collect()
}
467
468impl Drop for PluginLoader {
469 fn drop(&mut self) {
470 for (_name, lib) in self.loaded_libraries.drain() {
475 if let Ok(get_manifest) =
476 unsafe { lib.get::<GetCapabilityManifestFn>(b"shape_capability_manifest") }
477 {
478 let manifest_ptr = unsafe { get_manifest() };
479 if !manifest_ptr.is_null() {
480 let manifest = unsafe { &*manifest_ptr };
481 if let Ok(caps) = parse_capability_manifest(manifest) {
482 if caps
483 .iter()
484 .any(|c| c.kind == CapabilityKind::LanguageRuntime)
485 {
486 std::mem::forget(lib);
488 continue;
489 }
490 }
491 }
492 }
493 drop(lib);
495 }
496 }
497}
498
499impl Default for PluginLoader {
500 fn default() -> Self {
501 Self::new()
502 }
503}
504
505fn try_capability_vtable(lib: &Library, contract: &str) -> Result<Option<*const std::ffi::c_void>> {
506 let get_vtable_fn = unsafe { lib.get::<GetCapabilityVTableFn>(b"shape_capability_vtable") };
507 let Ok(get_vtable_fn) = get_vtable_fn else {
508 return Ok(None);
509 };
510
511 let vtable_ptr = unsafe { get_vtable_fn(contract.as_ptr(), contract.len()) };
512 if vtable_ptr.is_null() {
513 return Ok(None);
514 }
515 Ok(Some(vtable_ptr))
516}
517
518fn parse_capability_manifest(manifest: &CapabilityManifest) -> Result<Vec<PluginCapability>> {
519 if manifest.capabilities_len == 0 {
520 return Err(ShapeError::RuntimeError {
521 message: "CapabilityManifest must contain at least one capability".to_string(),
522 location: None,
523 });
524 }
525 if manifest.capabilities.is_null() {
526 return Err(ShapeError::RuntimeError {
527 message: "CapabilityManifest.capabilities is null".to_string(),
528 location: None,
529 });
530 }
531
532 let caps =
533 unsafe { std::slice::from_raw_parts(manifest.capabilities, manifest.capabilities_len) };
534 let mut parsed = Vec::with_capacity(caps.len());
535 for cap in caps {
536 parsed.push(PluginCapability {
537 kind: cap.kind,
538 contract: read_c_string(cap.contract, "CapabilityDescriptor.contract")?,
539 version: read_c_string(cap.version, "CapabilityDescriptor.version")?,
540 flags: cap.flags,
541 });
542 }
543 Ok(parsed)
544}
545
546pub fn parse_sections_manifest(manifest: &SectionsManifest) -> Result<Vec<ClaimedSection>> {
547 if manifest.sections_len == 0 {
548 return Ok(vec![]);
549 }
550 if manifest.sections.is_null() {
551 return Err(ShapeError::RuntimeError {
552 message: "SectionsManifest.sections is null but sections_len > 0".to_string(),
553 location: None,
554 });
555 }
556
557 let claims = unsafe { std::slice::from_raw_parts(manifest.sections, manifest.sections_len) };
558 let mut parsed = Vec::with_capacity(claims.len());
559 for claim in claims {
560 parsed.push(ClaimedSection {
561 name: read_c_string(claim.name, "SectionClaim.name")?,
562 required: claim.required,
563 });
564 }
565 Ok(parsed)
566}
567
568fn read_c_string(ptr: *const std::ffi::c_char, field: &str) -> Result<String> {
569 if ptr.is_null() {
570 return Err(ShapeError::RuntimeError {
571 message: format!("{} is null", field),
572 location: None,
573 });
574 }
575
576 Ok(unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string())
577}
578
#[cfg(test)]
mod tests {
    // Unit tests for the loader bookkeeping and the pure manifest/string parsing helpers.
    // Loading real plugin libraries is not exercised here.
    use super::*;
    use shape_abi_v1::{CAPABILITY_MODULE, CapabilityDescriptor};

    #[test]
    fn test_plugin_loader_new() {
        let loader = PluginLoader::new();
        assert!(loader.loaded_plugins().is_empty());
    }

    #[test]
    fn test_is_loaded_false() {
        let loader = PluginLoader::new();
        assert!(!loader.is_loaded("nonexistent"));
    }

    #[test]
    fn test_unload_missing_returns_false() {
        let mut loader = PluginLoader::new();
        assert!(!loader.unload("nonexistent"));
    }

    #[test]
    fn test_vtable_lookup_for_unloaded_plugin_errors() {
        let loader = PluginLoader::new();
        assert!(loader.get_data_source_vtable("nonexistent").is_err());
        assert!(loader.get_output_sink_vtable("nonexistent").is_err());
        assert!(loader.get_module_vtable("nonexistent").is_err());
        assert!(loader.get_language_runtime_vtable("nonexistent").is_err());
    }

    #[test]
    fn test_should_try_python_fallback_matches_libpython_errors() {
        assert!(should_try_python_fallback(
            "libpython3.13.so.1.0: cannot open shared object file"
        ));
        assert!(should_try_python_fallback(
            "Library not loaded: @rpath/Python.framework/Versions/3.12/Python"
        ));
        assert!(!should_try_python_fallback(
            "undefined symbol: sqlite3_open"
        ));
    }

    #[test]
    fn test_parse_capability_manifest() {
        static CAPS: [CapabilityDescriptor; 2] = [
            CapabilityDescriptor {
                kind: CapabilityKind::DataSource,
                contract: c"shape.datasource".as_ptr(),
                version: c"1".as_ptr(),
                flags: 0,
            },
            CapabilityDescriptor {
                kind: CapabilityKind::Compute,
                contract: c"shape.compute".as_ptr(),
                version: c"1".as_ptr(),
                flags: 42,
            },
        ];
        static MANIFEST: CapabilityManifest = CapabilityManifest {
            capabilities: CAPS.as_ptr(),
            capabilities_len: CAPS.len(),
        };

        let parsed = parse_capability_manifest(&MANIFEST).expect("manifest should parse");
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].contract, "shape.datasource");
        assert_eq!(parsed[1].kind, CapabilityKind::Compute);
        assert_eq!(parsed[1].flags, 42);
    }

    #[test]
    fn test_parse_capability_manifest_rejects_empty() {
        static MANIFEST: CapabilityManifest = CapabilityManifest {
            capabilities: std::ptr::null(),
            capabilities_len: 0,
        };
        let result = parse_capability_manifest(&MANIFEST);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_capability_manifest_rejects_null_array_with_len() {
        static MANIFEST: CapabilityManifest = CapabilityManifest {
            capabilities: std::ptr::null(),
            capabilities_len: 1,
        };
        assert!(parse_capability_manifest(&MANIFEST).is_err());
    }

    #[test]
    fn test_parse_capability_manifest_rejects_null_contract() {
        static CAPS: [CapabilityDescriptor; 1] = [CapabilityDescriptor {
            kind: CapabilityKind::DataSource,
            contract: std::ptr::null(),
            version: c"1".as_ptr(),
            flags: 0,
        }];
        static MANIFEST: CapabilityManifest = CapabilityManifest {
            capabilities: CAPS.as_ptr(),
            capabilities_len: CAPS.len(),
        };
        assert!(parse_capability_manifest(&MANIFEST).is_err());
    }

    #[test]
    fn test_module_contract_constant_is_expected() {
        assert_eq!(CAPABILITY_MODULE, "shape.module");
    }

    #[test]
    fn test_parse_sections_manifest_valid() {
        use shape_abi_v1::SectionClaim as AbiSectionClaim;

        static CLAIMS: [AbiSectionClaim; 2] = [
            AbiSectionClaim {
                name: c"native-dependencies".as_ptr(),
                required: false,
            },
            AbiSectionClaim {
                name: c"custom-config".as_ptr(),
                required: true,
            },
        ];
        static MANIFEST: SectionsManifest = SectionsManifest {
            sections: CLAIMS.as_ptr(),
            sections_len: CLAIMS.len(),
        };

        let parsed = parse_sections_manifest(&MANIFEST).expect("should parse");
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].name, "native-dependencies");
        assert!(!parsed[0].required);
        assert_eq!(parsed[1].name, "custom-config");
        assert!(parsed[1].required);
    }

    #[test]
    fn test_parse_sections_manifest_empty() {
        static MANIFEST: SectionsManifest = SectionsManifest {
            sections: std::ptr::null(),
            sections_len: 0,
        };
        let parsed = parse_sections_manifest(&MANIFEST).expect("empty should parse");
        assert!(parsed.is_empty());
    }

    #[test]
    fn test_parse_sections_manifest_rejects_null_array_with_len() {
        static MANIFEST: SectionsManifest = SectionsManifest {
            sections: std::ptr::null(),
            sections_len: 2,
        };
        assert!(parse_sections_manifest(&MANIFEST).is_err());
    }

    #[test]
    fn test_read_c_string_null_and_valid() {
        assert!(read_c_string(std::ptr::null(), "Field").is_err());
        let text = read_c_string(c"hello".as_ptr(), "Field").expect("valid C string");
        assert_eq!(text, "hello");
    }
}