use std::path::{Path, PathBuf};
use std::sync::Arc;
use libloading::Library;
use sha2::Digest;
use crate::extensions::types::ExtensionError;
use crate::extensions::Extension;
/// NUL-terminated name of the factory symbol every extension library must export.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";
/// Signature of the exported factory. The loader treats a non-null result as a
/// pointer obtained from a `Box<dyn Extension>` and rejects a null result.
///
/// NOTE(review): this is a Rust-ABI `unsafe fn` returning a fat `*mut dyn Extension`,
/// so the plugin presumably must be built with a compatible compiler and the exact
/// same `Extension` trait definition — confirm against the plugin-side declaration.
type CreateFn = unsafe fn() -> *mut dyn Extension;
/// File extension (without the leading dot) used for shared libraries on the
/// target platform, e.g. `so` on Linux, `dylib` on macOS and `dll` on Windows.
///
/// Sourced from [`std::env::consts::DLL_EXTENSION`] so that every target the
/// standard library knows about gets its correct value, instead of silently
/// falling back to `so` for anything that is not macOS or Windows.
pub const SHARED_LIB_EXTENSION: &str = std::env::consts::DLL_EXTENSION;
/// Reports whether `path` carries the platform's shared-library extension.
///
/// The comparison is exact (case-sensitive). A path with no extension, or whose
/// extension is not valid UTF-8, is never considered a shared library.
fn is_shared_library(path: &Path) -> bool {
    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ext == SHARED_LIB_EXTENSION,
        None => false,
    }
}
/// Collects candidate extension libraries from all known locations.
///
/// Sources, in scan order:
/// 1. `~/.oxi/extensions` (when a home directory is known and the dir exists),
/// 2. `<cwd>/.oxi/extensions` (when it exists),
/// 3. each entry of `extra_paths`: directories are scanned, while a plain path is
///    included when it has the shared-library extension and exists.
///
/// The combined list is sorted and exact duplicate paths are removed.
pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
    let user_dir = dirs::home_dir().map(|home| home.join(".oxi").join("extensions"));
    let project_dir = cwd.join(".oxi").join("extensions");

    let mut found = Vec::new();

    for dir in user_dir.into_iter().chain(std::iter::once(project_dir)) {
        if dir.is_dir() {
            discover_in_dir(&dir, &mut found);
        }
    }

    for candidate in extra_paths {
        if candidate.is_dir() {
            discover_in_dir(candidate, &mut found);
            continue;
        }
        if is_shared_library(candidate) && candidate.exists() {
            found.push(candidate.clone());
        }
    }

    found.sort();
    found.dedup();
    found
}
/// Returns every shared library located directly inside `dir` (non-recursive),
/// sorted by path.
///
/// `std::fs::read_dir` yields entries in a platform- and filesystem-dependent
/// order, so the result is sorted to give callers a deterministic ordering,
/// matching what [`discover_extensions`] returns. A missing or unreadable `dir`
/// yields an empty list.
pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
    let mut paths = Vec::new();
    discover_in_dir(dir, &mut paths);
    paths.sort();
    paths
}
/// Appends to `out` every regular file directly inside `dir` whose extension
/// marks it as a shared library.
///
/// Best-effort: if `dir` cannot be read nothing is appended, and individual
/// directory entries that fail to read are skipped.
fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
    let entries = match std::fs::read_dir(dir) {
        Ok(iter) => iter,
        Err(_) => return,
    };
    out.extend(
        entries
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|candidate| candidate.is_file() && is_shared_library(candidate)),
    );
}
/// Loads one extension shared library from `path` and instantiates its extension.
///
/// Steps: check that the file exists and has the platform's shared-library
/// extension, open it with `libloading`, resolve the `oxi_extension_create`
/// symbol, call it, and take ownership of the returned trait object as an
/// `Arc<dyn Extension>`. A successful load is logged at `info` level with the
/// extension's name and path.
///
/// On success the library handle is intentionally leaked (`std::mem::forget`) so
/// the library is never unloaded: the trait object's vtable and method code live
/// inside it. On the error paths after opening (missing symbol, null result) the
/// handle is dropped normally and no extension object is held.
///
/// # Errors
///
/// Returns an error if the file does not exist, lacks the expected extension,
/// cannot be opened as a library, does not export the entry symbol, or if the
/// entry function returns null.
pub fn load_extension(path: &Path) -> anyhow::Result<Arc<dyn Extension>> {
let path_display = path.display().to_string();
if !path.exists() {
anyhow::bail!("Extension file not found: {}", path_display);
}
if !is_shared_library(path) {
anyhow::bail!(
"Not a shared library (expected .{}): {}",
SHARED_LIB_EXTENSION,
path_display
);
}
// SAFETY: opening a library runs its initialisation routines; nothing here can
// vet them, so callers must only pass paths to trusted libraries.
let library = unsafe { Library::new(path) }
.map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
// SAFETY: the exported symbol's real type must match `CreateFn`; this cannot be
// verified at runtime, so a plugin exporting a different signature is UB.
let create: libloading::Symbol<CreateFn> =
unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
anyhow::anyhow!(
"Symbol 'oxi_extension_create' not found in '{}': {}",
path_display,
e
)
})?;
// SAFETY: relies on the plugin honouring the `CreateFn` contract.
let raw_ptr = unsafe { create() };
if raw_ptr.is_null() {
anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
}
// SAFETY: assumes `raw_ptr` came from `Box::into_raw` inside the plugin.
// `Arc::from(Box)` moves the value into a new allocation and frees the original
// box with the host's global allocator.
// NOTE(review): this is only sound if plugin and host share the same global
// allocator — confirm (e.g. neither side installs a custom `#[global_allocator]`).
let extension: Arc<dyn Extension> = unsafe {
let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
Arc::from(boxed)
};
tracing::info!(
name = %extension.name(),
path = %path_display,
"Extension loaded"
);
// Keep the library mapped forever: `extension` (and every clone of it) calls
// into code and a vtable that live inside this library.
std::mem::forget(library);
Ok(extension)
}
pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
let mut loaded = Vec::new();
let mut errors = Vec::new();
for path in paths {
match load_extension(path) {
Ok(ext) => loaded.push(ext),
Err(e) => {
tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
errors.push(e);
}
}
}
(loaded, errors)
}
/// Result of a successful [`validate_extension`] check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// Path of the validated library, as passed to `validate_extension`.
    pub path: PathBuf,
    /// Lowercase hexadecimal SHA-256 digest of the library file's contents.
    pub checksum: String,
}
pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
if !path.exists() {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: "File not found".into(),
});
}
let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: format!("Cannot read file metadata: {e}"),
})?;
if metadata.len() == 0 {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: "Empty file".into(),
});
}
if metadata.len() > 100 * 1024 * 1024 {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: "File too large (>100MB)".into(),
});
}
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
let valid_ext = match std::env::consts::OS {
"linux" => ext == "so",
"macos" => ext == "dylib",
"windows" => ext == "dll",
_ => true,
};
if !valid_ext {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: format!("Invalid extension: .{ext}"),
});
}
let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: format!("Cannot read file: {e}"),
})?;
let checksum = format!("{:x}", sha2::Sha256::digest(&data));
Ok(ValidatedExtension {
path: path.to_path_buf(),
checksum,
})
}