use std::path::{Path, PathBuf};
use std::sync::Arc;
use libloading::Library;
use sha2::Digest;
use crate::extensions::Extension;
use crate::extensions::types::ExtensionError;
/// NUL-terminated name of the constructor symbol every native extension must export.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";
/// Signature the exported `oxi_extension_create` symbol is assumed to have.
///
/// NOTE(review): this is a Rust-ABI `fn` (no `extern "C"`) returning a fat
/// `*mut dyn Extension` pointer, so an extension presumably has to be built with the
/// same compiler and the same `Extension` trait definition as the host — confirm.
type CreateFn = unsafe fn() -> *mut dyn Extension;
/// File extension (without the leading dot) of shared libraries on the compile-time
/// target: `dylib` on macOS, `dll` on Windows, and `so` on every other target.
pub const SHARED_LIB_EXTENSION: &str = if cfg!(target_os = "macos") {
"dylib"
} else if cfg!(target_os = "windows") {
"dll"
} else {
"so"
};
/// Reports whether `path` ends in the platform shared-library extension
/// ([`SHARED_LIB_EXTENSION`]).
///
/// The comparison is exact and case-sensitive. Paths without an extension, or whose
/// extension is not valid UTF-8, never match.
fn is_shared_library(path: &Path) -> bool {
    matches!(
        path.extension().and_then(|suffix| suffix.to_str()),
        Some(suffix) if suffix == SHARED_LIB_EXTENSION
    )
}
/// Collects candidate native-extension libraries from every known location.
///
/// Searched, non-recursively:
/// - `~/.oxi/extensions` (only when a home directory is known and the directory exists);
/// - `<cwd>/.oxi/extensions` (only when it exists);
/// - each entry of `extra_paths`. A directory is scanned, and an existing file with the
///   platform shared-library extension is taken as-is. Anything else is ignored.
///
/// The result is sorted and has exact duplicates removed. Paths are not canonicalized,
/// so two spellings of the same file may both appear.
pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
    let user_dir = dirs::home_dir().map(|home| home.join(".oxi").join("extensions"));
    let project_dir = cwd.join(".oxi").join("extensions");

    let mut found: Vec<PathBuf> = Vec::new();

    // Well-known locations: silently skipped when absent.
    for dir in user_dir.iter().chain(std::iter::once(&project_dir)) {
        if dir.is_dir() {
            discover_in_dir(dir, &mut found);
        }
    }

    // Caller-supplied locations may be either directories or individual libraries.
    for candidate in extra_paths {
        if candidate.is_dir() {
            discover_in_dir(candidate, &mut found);
            continue;
        }
        if is_shared_library(candidate) && candidate.exists() {
            found.push(candidate.clone());
        }
    }

    found.sort();
    found.dedup();
    found
}
/// Lists the shared-library files directly inside `dir` (non-recursive).
///
/// Entries come back in whatever order the directory listing yields; no sorting is done.
/// An unreadable or missing `dir` produces an empty list rather than an error.
pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
    let mut libraries: Vec<PathBuf> = Vec::new();
    discover_in_dir(dir, &mut libraries);
    libraries
}
/// Appends to `out` every regular file directly inside `dir` whose extension is the
/// platform shared-library extension.
///
/// This is best-effort: if `dir` cannot be read, nothing is appended. Individual
/// directory entries that fail to enumerate are skipped.
fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
    if let Ok(listing) = std::fs::read_dir(dir) {
        out.extend(
            listing
                .filter_map(Result::ok)
                .map(|entry| entry.path())
                .filter(|candidate| candidate.is_file() && is_shared_library(candidate)),
        );
    }
}
pub fn load_extension(
path: &Path,
expected_checksum: Option<&str>,
) -> anyhow::Result<Arc<dyn Extension>> {
let path_display = path.display().to_string();
if std::env::var("OXI_NATIVE_EXTENSIONS").ok().as_deref() != Some("1") {
tracing::warn!(
path = %path_display,
"native extension skipped — set OXI_NATIVE_EXTENSIONS=1 to load unsandboxed extensions"
);
anyhow::bail!(
"Native extensions are disabled; set OXI_NATIVE_EXTENSIONS=1 to load '{}'",
path_display
);
}
if !path.exists() {
anyhow::bail!("Extension file not found: {}", path_display);
}
if !is_shared_library(path) {
anyhow::bail!(
"Not a shared library (expected .{}): {}",
SHARED_LIB_EXTENSION,
path_display
);
}
let validated = validate_extension(path).map_err(|e| {
anyhow::anyhow!(
"native extension pre-load validation failed for '{}': {}",
path_display,
e
)
})?;
if let Some(expected) = expected_checksum {
if !validated.checksum.eq_ignore_ascii_case(expected) {
anyhow::bail!(
"native extension checksum mismatch for '{}': expected sha256-{expected}, got sha256-{}",
path_display,
validated.checksum
);
}
tracing::debug!(
path = %path_display,
checksum = %validated.checksum,
"native extension integrity verified"
);
} else {
tracing::warn!(
path = %path_display,
"loading native extension WITHOUT integrity verification — caller passed None"
);
}
let library = unsafe { Library::new(path) }
.map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
let create: libloading::Symbol<CreateFn> =
unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
anyhow::anyhow!(
"Symbol 'oxi_extension_create' not found in '{}': {}",
path_display,
e
)
})?;
let raw_ptr = unsafe { create() };
if raw_ptr.is_null() {
anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
}
let extension: Arc<dyn Extension> = unsafe {
let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
Arc::from(boxed)
};
tracing::info!(
name = %extension.name(),
path = %path_display,
"Extension loaded"
);
std::mem::forget(library);
Ok(extension)
}
pub fn load_extensions(
paths: &[&Path],
checksums: &[Option<&str>],
) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
assert_eq!(
paths.len(),
checksums.len(),
"load_extensions: paths and checksums must be parallel slices"
);
let mut loaded = Vec::new();
let mut errors = Vec::new();
for (path, expected) in paths.iter().zip(checksums.iter()) {
match load_extension(path, *expected) {
Ok(ext) => loaded.push(ext),
Err(e) => {
tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
errors.push(e);
}
}
}
(loaded, errors)
}
/// Outcome of a successful [`validate_extension`] pre-load check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// The path exactly as it was passed to [`validate_extension`].
    pub path: PathBuf,
    /// Lowercase hexadecimal SHA-256 digest of the file contents (64 characters).
    pub checksum: String,
}
pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
if !path.exists() {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: "File not found".into(),
});
}
let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: format!("Cannot read file metadata: {e}"),
})?;
if metadata.len() == 0 {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: "Empty file".into(),
});
}
if metadata.len() > 100 * 1024 * 1024 {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: "File too large (>100MB)".into(),
});
}
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
let valid_ext = match std::env::consts::OS {
"linux" => ext == "so",
"macos" => ext == "dylib",
"windows" => ext == "dll",
_ => true,
};
if !valid_ext {
return Err(ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: format!("Invalid extension: .{ext}"),
});
}
let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
name: path.display().to_string(),
reason: format!("Cannot read file: {e}"),
})?;
let checksum = format!("{:x}", sha2::Sha256::digest(&data));
Ok(ValidatedExtension {
path: path.to_path_buf(),
checksum,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates (or truncates) `path` and fills it with `payload`.
    fn write_fake_ext(path: &Path, payload: &[u8]) {
        let mut f = std::fs::File::create(path).unwrap();
        f.write_all(payload).unwrap();
    }

    #[test]
    fn validate_extension_is_deterministic() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_path = tmp.path().join(format!("lib.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_path, b"deterministic test payload");
        let v1 = validate_extension(&ext_path).expect("validate should succeed");
        let v2 = validate_extension(&ext_path).expect("validate should succeed");
        assert_eq!(v1.checksum, v2.checksum);
        assert_eq!(v1.checksum.len(), 64);
        assert!(
            v1.checksum
                .chars()
                .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase())
        );
    }

    #[test]
    fn validate_extension_distinguishes_content() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_a = tmp.path().join(format!("a.{}", SHARED_LIB_EXTENSION));
        let ext_b = tmp.path().join(format!("b.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_a, b"alpha");
        write_fake_ext(&ext_b, b"beta");
        let v_a = validate_extension(&ext_a).unwrap();
        let v_b = validate_extension(&ext_b).unwrap();
        assert_ne!(v_a.checksum, v_b.checksum);
    }

    /// Pins the digest to the FIPS 180-2 SHA-256 test vector for "abc".
    #[test]
    fn validate_extension_matches_known_sha256_vector() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_path = tmp.path().join(format!("abc.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_path, b"abc");
        let v = validate_extension(&ext_path).expect("validate should succeed");
        assert_eq!(
            v.checksum,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
        assert_eq!(v.path, ext_path);
    }

    #[test]
    fn validate_extension_rejects_empty_file() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_path = tmp.path().join(format!("empty.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_path, b"");
        let err = validate_extension(&ext_path).expect_err("empty file must fail");
        assert!(format!("{err}").contains("Empty file"), "unexpected err: {err}");
    }

    #[test]
    fn validate_extension_rejects_directory() {
        let tmp = tempfile::tempdir().unwrap();
        let dir_path = tmp.path().join(format!("dir.{}", SHARED_LIB_EXTENSION));
        std::fs::create_dir(&dir_path).unwrap();
        assert!(validate_extension(&dir_path).is_err());
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn validate_extension_rejects_wrong_platform_ext_on_macos() {
        let tmp = tempfile::tempdir().unwrap();
        let wrong = tmp.path().join("lib.so");
        write_fake_ext(&wrong, b"x");
        let err = validate_extension(&wrong).expect_err("wrong platform ext must fail");
        let msg = format!("{err}");
        assert!(msg.contains("Invalid extension"), "unexpected err: {msg}");
    }

    #[test]
    fn validate_extension_handles_missing_path() {
        let tmp = tempfile::tempdir().unwrap();
        let missing = tmp.path().join("does-not-exist.dylib");
        let err = validate_extension(&missing).expect_err("missing path must fail");
        assert!(format!("{err}").contains("File not found"));
    }

    #[test]
    fn discover_extensions_in_dir_keeps_only_shared_library_files() {
        let tmp = tempfile::tempdir().unwrap();
        let lib = tmp.path().join(format!("plugin.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&lib, b"x");
        write_fake_ext(&tmp.path().join("notes.txt"), b"x");
        // A directory with a library-like name must not be reported.
        std::fs::create_dir(tmp.path().join(format!("subdir.{}", SHARED_LIB_EXTENSION)))
            .unwrap();
        let found = discover_extensions_in_dir(tmp.path());
        assert_eq!(found, vec![lib]);
    }
}