Skip to main content

oxi/extensions/
loading.rs

1//! Extension dynamic loading.
2//!
3//! Loads Rust extensions compiled as `cdylib` shared libraries (`.dylib`/`.so`/`.dll`).
4//!
5//! # Extension ABI
6//!
7//! Every extension must export a single entry point:
8//!
9//! ```ignore
10//! #[no_mangle]
//! pub extern "C" fn oxi_extension_create() -> *mut dyn oxi_cli::extensions::Extension {
12//!     Box::into_raw(Box::new(MyExtension))
13//! }
14//! ```
15//!
16//! # Directory layout
17//!
18//! ```text
19//! ~/.oxi/extensions/
20//!   ├── my_ext.dylib    # macOS
21//!   ├── other_ext.so    # Linux
22//!   └── win_ext.dll     # Windows
23//! ```
24//!
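//! Extensions are discovered in `~/.oxi/extensions/`, in the project-local
//! `.oxi/extensions/` directory under the `cwd` passed to
//! `discover_extensions`, and in any extra paths configured in settings.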
27
28use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::Extension;
35use crate::extensions::types::ExtensionError;
36
/// Entry point symbol that every extension must export.
///
/// The bytes include a trailing NUL and are passed unchanged to
/// `Library::get` when `load_extension` resolves the entry point.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";

/// Function signature for the extension creation entry point.
///
/// The function takes no arguments and returns a raw trait-object pointer to
/// the extension instance; `load_extension` rejects a null result and adopts
/// a non-null one with `Box::from_raw`.
///
/// NOTE(review): this alias uses the default Rust ABI, whereas the
/// module-level example declares the entry point as `extern "C"`. The ABI of
/// the pointer type used for the call must match the ABI the extension was
/// compiled with — confirm which one extensions actually export.
type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// Shared-library file suffix (no leading dot) for the compilation target:
/// `dylib` on macOS.
#[cfg(target_os = "macos")]
pub const SHARED_LIB_EXTENSION: &str = "dylib";

/// Shared-library file suffix (no leading dot) for the compilation target:
/// `dll` on Windows.
#[cfg(target_os = "windows")]
pub const SHARED_LIB_EXTENSION: &str = "dll";

/// Shared-library file suffix (no leading dot) for the compilation target:
/// `so` on every target that is neither macOS nor Windows.
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
pub const SHARED_LIB_EXTENSION: &str = "so";
51
52/// Check if a file looks like a shared library for the current platform.
53fn is_shared_library(path: &Path) -> bool {
54    path.extension()
55        .and_then(|e| e.to_str())
56        .map(|e| e == SHARED_LIB_EXTENSION)
57        .unwrap_or(false)
58}
59
60/// Discover extension shared libraries in `~/.oxi/extensions/` and extra paths.
61pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62    let mut paths = Vec::new();
63
64    // ~/.oxi/extensions/
65    if let Some(home) = dirs::home_dir() {
66        let ext_dir = home.join(".oxi").join("extensions");
67        if ext_dir.is_dir() {
68            discover_in_dir(&ext_dir, &mut paths);
69        }
70    }
71
72    // .oxi/extensions/ (project-local)
73    let project_ext_dir = cwd.join(".oxi").join("extensions");
74    if project_ext_dir.is_dir() {
75        discover_in_dir(&project_ext_dir, &mut paths);
76    }
77
78    // Extra paths from settings
79    for extra in extra_paths {
80        if extra.is_dir() {
81            discover_in_dir(extra, &mut paths);
82        } else if is_shared_library(extra) && extra.exists() {
83            paths.push(extra.clone());
84        }
85    }
86
87    paths.sort();
88    paths.dedup();
89    paths
90}
91
92/// Discover extension shared libraries in a single directory.
93pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94    let mut paths = Vec::new();
95    discover_in_dir(dir, &mut paths);
96    paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100    let Ok(entries) = std::fs::read_dir(dir) else {
101        return;
102    };
103    for entry in entries.flatten() {
104        let path = entry.path();
105        if path.is_file() && is_shared_library(&path) {
106            out.push(path);
107        }
108    }
109}
110
111/// Load a single extension from a shared library.
112///
113/// # Integrity (audit F-2)
114///
115/// `expected_checksum` is the SHA-256 hex digest that the caller (e.g. the
116/// package manager's lockfile reader) has on record for this binary. When
117/// `Some`, the binary is hashed before loading and rejected on mismatch —
118/// this is the supply-chain integrity gate for native extensions, which
119/// otherwise run arbitrary in-process code with no sandbox (libloading +
120/// `unsafe extern "C"` entry). When `None`, the caller is opting out of
121/// verification explicitly; this is reserved for locally-built extensions
122/// the user just compiled and trusts by construction.
123///
124/// The hash comparison is constant-time on the hex string length via
125/// `subtle::ConstantTimeEq` if the `subtle` dep is added; until then
126/// `eq_ignore_ascii_case` is used (timing leak is negligible here since
127/// the hash is not a secret and an attacker who can swap the binary
128/// already controls the comparison outcome).
129///
130/// # Safety
131///
132/// The loaded library must export `oxi_extension_create` returning a valid
133/// pointer to a `dyn Extension`. The library must have been compiled with
134/// a compatible Rust toolchain version.
135pub fn load_extension(
136    path: &Path,
137    expected_checksum: Option<&str>,
138) -> anyhow::Result<Arc<dyn Extension>> {
139    let path_display = path.display().to_string();
140    // Security: native extensions are unsandboxed arbitrary in-process code
141    // (loaded via libloading with no sandbox). Require explicit opt-in so
142    // they cannot execute by default — mirrors the `OXI_EXTENSION_EXEC`
143    // opt-in for WASM extensions.
144    if std::env::var("OXI_NATIVE_EXTENSIONS").ok().as_deref() != Some("1") {
145        tracing::warn!(
146            path = %path_display,
147            "native extension skipped — set OXI_NATIVE_EXTENSIONS=1 to load unsandboxed extensions"
148        );
149        anyhow::bail!(
150            "Native extensions are disabled; set OXI_NATIVE_EXTENSIONS=1 to load '{}'",
151            path_display
152        );
153    }
154
155    if !path.exists() {
156        anyhow::bail!("Extension file not found: {}", path_display);
157    }
158
159    if !is_shared_library(path) {
160        anyhow::bail!(
161            "Not a shared library (expected .{}): {}",
162            SHARED_LIB_EXTENSION,
163            path_display
164        );
165    }
166
167    // F-2 (audit 2026-06-21): integrity check before mmap.
168    //
169    // `validate_extension` performs pre-load validation (file exists, size
170    // bounds, platform extension, SHA-256). It returns `ValidatedExtension`
171    // with the actual checksum; we compare it to the caller-supplied
172    // expected checksum and bail on mismatch — refusing to load a binary
173    // that has been swapped since the lockfile was written.
174    let validated = validate_extension(path).map_err(|e| {
175        anyhow::anyhow!(
176            "native extension pre-load validation failed for '{}': {}",
177            path_display,
178            e
179        )
180    })?;
181    if let Some(expected) = expected_checksum {
182        if !validated.checksum.eq_ignore_ascii_case(expected) {
183            anyhow::bail!(
184                "native extension checksum mismatch for '{}': expected sha256-{expected}, got sha256-{}",
185                path_display,
186                validated.checksum
187            );
188        }
189        tracing::debug!(
190            path = %path_display,
191            checksum = %validated.checksum,
192            "native extension integrity verified"
193        );
194    } else {
195        tracing::warn!(
196            path = %path_display,
197            "loading native extension WITHOUT integrity verification — caller passed None"
198        );
199    }
200
201    // SAFETY: Library::new loads a shared library from the given path.
202    // This is unsafe because the loaded code can perform arbitrary operations.
203    // We trust the user-installed extension at the given path, AND its
204    // integrity has been verified above when `expected_checksum` is Some.
205    let library = unsafe { Library::new(path) }
206        .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
207
208    // SAFETY: library.get looks up a symbol by name in the loaded shared library.
209    // The symbol name is a static constant, not user-controlled.
210    let create: libloading::Symbol<CreateFn> =
211        unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
212            anyhow::anyhow!(
213                "Symbol 'oxi_extension_create' not found in '{}': {}",
214                path_display,
215                e
216            )
217        })?;
218
219    // SAFETY: Calling the extension's oxi_extension_create entry point.
220    // The function signature is `unsafe fn() -> *mut dyn Extension`.
221    // We check the returned pointer for null below.
222    let raw_ptr = unsafe { create() };
223    if raw_ptr.is_null() {
224        anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
225    }
226
227    // SAFETY: Box::from_raw takes ownership of the pointer returned by
228    // oxi_extension_create. The extension must have allocated this with
229    // Box::new (documented contract). Null was checked above.
230    let extension: Arc<dyn Extension> = unsafe {
231        let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
232        Arc::from(boxed)
233    };
234
235    tracing::info!(
236        name = %extension.name(),
237        path = %path_display,
238        "Extension loaded"
239    );
240
241    // IMPORTANT: We must keep the Library alive for the entire lifetime
242    // of the extension. Leak it intentionally — the extension's code lives
243    // in this library. Unloading it while extension objects exist would
244    // cause undefined behavior.
245    std::mem::forget(library);
246
247    Ok(extension)
248}
249
250/// Load multiple extensions from the given paths.
251///
252/// Returns successfully loaded extensions and any errors encountered.
253/// Does not abort on individual failures — loads as many as possible.
254///
255/// `checksums` is parallel to `paths`: `checksums[i]` is the expected
256/// SHA-256 of `paths[i]`. Pass `None` to opt out of integrity verification
257/// for a particular extension (the same semantics as `load_extension`).
258/// A `Some(_)` mismatch is reported as an error but does not stop the
259/// other extensions from loading.
260pub fn load_extensions(
261    paths: &[&Path],
262    checksums: &[Option<&str>],
263) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
264    assert_eq!(
265        paths.len(),
266        checksums.len(),
267        "load_extensions: paths and checksums must be parallel slices"
268    );
269    let mut loaded = Vec::new();
270    let mut errors = Vec::new();
271
272    for (path, expected) in paths.iter().zip(checksums.iter()) {
273        match load_extension(path, *expected) {
274            Ok(ext) => loaded.push(ext),
275            Err(e) => {
276                tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
277                errors.push(e);
278            }
279        }
280    }
281
282    (loaded, errors)
283}
284
/// Outcome of a successful pre-load validation of an extension binary.
///
/// Produced by `validate_extension`. The type is cheap to clone and supports
/// equality so callers and tests can compare validation results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// Path to the validated extension binary, as it was passed in.
    pub path: PathBuf,
    /// SHA-256 hex digest of the file contents.
    pub checksum: String,
}
293
294/// Perform pre-load validation on an extension binary.
295///
296/// Checks file existence, size bounds, and platform-appropriate extension.
297pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
298    if !path.exists() {
299        return Err(ExtensionError::LoadFailed {
300            name: path.display().to_string(),
301            reason: "File not found".into(),
302        });
303    }
304
305    let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
306        name: path.display().to_string(),
307        reason: format!("Cannot read file metadata: {e}"),
308    })?;
309
310    if metadata.len() == 0 {
311        return Err(ExtensionError::LoadFailed {
312            name: path.display().to_string(),
313            reason: "Empty file".into(),
314        });
315    }
316    if metadata.len() > 100 * 1024 * 1024 {
317        return Err(ExtensionError::LoadFailed {
318            name: path.display().to_string(),
319            reason: "File too large (>100MB)".into(),
320        });
321    }
322
323    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
324    let valid_ext = match std::env::consts::OS {
325        "linux" => ext == "so",
326        "macos" => ext == "dylib",
327        "windows" => ext == "dll",
328        _ => true,
329    };
330    if !valid_ext {
331        return Err(ExtensionError::LoadFailed {
332            name: path.display().to_string(),
333            reason: format!("Invalid extension: .{ext}"),
334        });
335    }
336
337    let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
338        name: path.display().to_string(),
339        reason: format!("Cannot read file: {e}"),
340    })?;
341    let checksum = format!("{:x}", sha2::Sha256::digest(&data));
342
343    Ok(ValidatedExtension {
344        path: path.to_path_buf(),
345        checksum,
346    })
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    // ── F-2 regression: validate_extension computes deterministic SHA-256 ──

    /// Create (or truncate) the file at `path` and fill it with `payload`.
    fn write_fake_ext(path: &Path, payload: &[u8]) {
        std::fs::write(path, payload).unwrap();
    }

    /// Validating the same file twice yields the same digest, and that digest
    /// is a 64-character lowercase hex string.
    #[test]
    fn validate_extension_is_deterministic() {
        let scratch = tempfile::tempdir().unwrap();
        let lib_path = scratch
            .path()
            .join(format!("lib.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&lib_path, b"deterministic test payload");

        let first = validate_extension(&lib_path).expect("validate should succeed");
        let second = validate_extension(&lib_path).expect("validate should succeed");

        assert_eq!(first.checksum, second.checksum);
        // SHA-256 hex is 64 chars, lowercase.
        assert_eq!(first.checksum.len(), 64);
        let lowercase_hex = first
            .checksum
            .chars()
            .all(|c| matches!(c, '0'..='9' | 'a'..='f'));
        assert!(lowercase_hex);
    }

    /// Files with different contents get different checksums.
    #[test]
    fn validate_extension_distinguishes_content() {
        let scratch = tempfile::tempdir().unwrap();
        let alpha_path = scratch.path().join(format!("a.{}", SHARED_LIB_EXTENSION));
        let beta_path = scratch.path().join(format!("b.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&alpha_path, b"alpha");
        write_fake_ext(&beta_path, b"beta");

        let alpha = validate_extension(&alpha_path).unwrap();
        let beta = validate_extension(&beta_path).unwrap();

        assert_ne!(alpha.checksum, beta.checksum);
    }

    /// A file carrying another platform's suffix (here `.so` on macOS) is
    /// rejected by the pre-load gate, before any `libloading::Library::new`
    /// call could happen.
    #[test]
    #[cfg(target_os = "macos")]
    fn validate_extension_rejects_wrong_platform_ext_on_macos() {
        let scratch = tempfile::tempdir().unwrap();
        // `.so` is the Linux extension; on macOS a `.dylib` is required.
        let linux_style = scratch.path().join("lib.so");
        write_fake_ext(&linux_style, b"x");

        let err = validate_extension(&linux_style).expect_err("wrong platform ext must fail");
        let msg = err.to_string();
        assert!(msg.contains("Invalid extension"), "unexpected err: {msg}");
    }

    /// A path that does not exist is reported as `File not found` rather than
    /// causing a panic.
    #[test]
    fn validate_extension_handles_missing_path() {
        let scratch = tempfile::tempdir().unwrap();
        let absent = scratch.path().join("does-not-exist.dylib");

        let err = validate_extension(&absent).expect_err("missing path must fail");
        assert!(err.to_string().contains("File not found"));
    }
}