Skip to main content

oxi/extensions/
loading.rs

//! Extension dynamic loading.
//!
//! Loads Rust extensions compiled as `cdylib` shared libraries (`.dylib`/`.so`/`.dll`).
//!
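//! # Extension ABI
//!
//! Every extension must export a single `extern "C"` entry point that hands a
//! heap-allocated extension to the host via `Box::into_raw`:
//!
//! ```ignore
//! #[no_mangle]
//! pub extern "C" fn oxi_extension_create() -> *mut dyn oxi_cli::extensions::Extension {
//!     Box::into_raw(Box::new(MyExtension))
//! }
//! ```
//!
//! The loader takes ownership of the returned pointer and never unloads the
//! library afterwards, because the extension's code lives in it.
//!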
//! # Directory layout
//!
//! ```text
//! ~/.oxi/extensions/
//!   ├── my_ext.dylib    # macOS
//!   ├── other_ext.so    # Linux
//!   └── win_ext.dll     # Windows
//! ```
//!
//! Only files with the current platform's shared-library extension are picked up.
//!
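//! Extensions are discovered in `~/.oxi/extensions/`, in the project-local
//! `.oxi/extensions/` directory (relative to the working directory passed to
//! `discover_extensions`), and in any extra paths configured in settings.
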
28use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::types::ExtensionError;
35use crate::extensions::Extension;
36
37/// Entry point symbol that every extension must export.
38const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";
39
40/// Function signature for the extension creation entry point.
41type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// File-name extension (without the leading dot) used for shared libraries on
/// the target platform: `dylib` on macOS, `dll` on Windows and `so` on every
/// other target.
pub const SHARED_LIB_EXTENSION: &str =
    match (cfg!(target_os = "macos"), cfg!(target_os = "windows")) {
        (true, _) => "dylib",
        (_, true) => "dll",
        _ => "so",
    };

52/// Check if a file looks like a shared library for the current platform.
53fn is_shared_library(path: &Path) -> bool {
54    path.extension()
55        .and_then(|e| e.to_str())
56        .map(|e| e == SHARED_LIB_EXTENSION)
57        .unwrap_or(false)
58}
59
60/// Discover extension shared libraries in `~/.oxi/extensions/` and extra paths.
61pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62    let mut paths = Vec::new();
63
64    // ~/.oxi/extensions/
65    if let Some(home) = dirs::home_dir() {
66        let ext_dir = home.join(".oxi").join("extensions");
67        if ext_dir.is_dir() {
68            discover_in_dir(&ext_dir, &mut paths);
69        }
70    }
71
72    // .oxi/extensions/ (project-local)
73    let project_ext_dir = cwd.join(".oxi").join("extensions");
74    if project_ext_dir.is_dir() {
75        discover_in_dir(&project_ext_dir, &mut paths);
76    }
77
78    // Extra paths from settings
79    for extra in extra_paths {
80        if extra.is_dir() {
81            discover_in_dir(extra, &mut paths);
82        } else if is_shared_library(extra) && extra.exists() {
83            paths.push(extra.clone());
84        }
85    }
86
87    paths.sort();
88    paths.dedup();
89    paths
90}
91
92/// Discover extension shared libraries in a single directory.
93pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94    let mut paths = Vec::new();
95    discover_in_dir(dir, &mut paths);
96    paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100    let Ok(entries) = std::fs::read_dir(dir) else {
101        return;
102    };
103    for entry in entries.flatten() {
104        let path = entry.path();
105        if path.is_file() && is_shared_library(&path) {
106            out.push(path);
107        }
108    }
109}
110
111/// Load a single extension from a shared library.
112///
113/// # Safety
114///
115/// The loaded library must export `oxi_extension_create` returning a valid
116/// pointer to a `dyn Extension`. The library must have been compiled with
117/// a compatible Rust toolchain version.
118pub fn load_extension(path: &Path) -> anyhow::Result<Arc<dyn Extension>> {
119    let path_display = path.display().to_string();
120
121    if !path.exists() {
122        anyhow::bail!("Extension file not found: {}", path_display);
123    }
124
125    if !is_shared_library(path) {
126        anyhow::bail!(
127            "Not a shared library (expected .{}): {}",
128            SHARED_LIB_EXTENSION,
129            path_display
130        );
131    }
132
133    // Load the library
134    let library = unsafe { Library::new(path) }
135        .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
136
137    // Get the entry point symbol
138    let create: libloading::Symbol<CreateFn> =
139        unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
140            anyhow::anyhow!(
141                "Symbol 'oxi_extension_create' not found in '{}': {}",
142                path_display,
143                e
144            )
145        })?;
146
147    // Call the entry point
148    let raw_ptr = unsafe { create() };
149    if raw_ptr.is_null() {
150        anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
151    }
152
153    // Wrap in Arc (take ownership from the raw pointer)
154    let extension: Arc<dyn Extension> = unsafe {
155        let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
156        Arc::from(boxed)
157    };
158
159    tracing::info!(
160        name = %extension.name(),
161        path = %path_display,
162        "Extension loaded"
163    );
164
165    // IMPORTANT: We must keep the Library alive for the entire lifetime
166    // of the extension. Leak it intentionally — the extension's code lives
167    // in this library. Unloading it while extension objects exist would
168    // cause undefined behavior.
169    std::mem::forget(library);
170
171    Ok(extension)
172}
173
174/// Load multiple extensions from the given paths.
175///
176/// Returns successfully loaded extensions and any errors encountered.
177/// Does not abort on individual failures — loads as many as possible.
178pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
179    let mut loaded = Vec::new();
180    let mut errors = Vec::new();
181
182    for path in paths {
183        match load_extension(path) {
184            Ok(ext) => loaded.push(ext),
185            Err(e) => {
186                tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
187                errors.push(e);
188            }
189        }
190    }
191
192    (loaded, errors)
193}
194
/// Extension binary validation result, as produced by [`validate_extension`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// Path to the validated extension binary, as passed to the validator.
    pub path: PathBuf,
    /// SHA-256 digest of the file contents, formatted as lowercase hex.
    pub checksum: String,
}

203/// Perform pre-load validation on an extension binary.
204///
205/// Checks file existence, size bounds, and platform-appropriate extension.
206pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
207    if !path.exists() {
208        return Err(ExtensionError::LoadFailed {
209            name: path.display().to_string(),
210            reason: "File not found".into(),
211        });
212    }
213
214    let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
215        name: path.display().to_string(),
216        reason: format!("Cannot read file metadata: {e}"),
217    })?;
218
219    if metadata.len() == 0 {
220        return Err(ExtensionError::LoadFailed {
221            name: path.display().to_string(),
222            reason: "Empty file".into(),
223        });
224    }
225    if metadata.len() > 100 * 1024 * 1024 {
226        return Err(ExtensionError::LoadFailed {
227            name: path.display().to_string(),
228            reason: "File too large (>100MB)".into(),
229        });
230    }
231
232    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
233    let valid_ext = match std::env::consts::OS {
234        "linux" => ext == "so",
235        "macos" => ext == "dylib",
236        "windows" => ext == "dll",
237        _ => true,
238    };
239    if !valid_ext {
240        return Err(ExtensionError::LoadFailed {
241            name: path.display().to_string(),
242            reason: format!("Invalid extension: .{ext}"),
243        });
244    }
245
246    let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
247        name: path.display().to_string(),
248        reason: format!("Cannot read file: {e}"),
249    })?;
250    let checksum = format!("{:x}", sha2::Sha256::digest(&data));
251
252    Ok(ValidatedExtension {
253        path: path.to_path_buf(),
254        checksum,
255    })
256}