Skip to main content

oxi/extensions/
loading.rs

1//! Extension dynamic loading.
2//!
3//! Loads Rust extensions compiled as `cdylib` shared libraries (`.dylib`/`.so`/`.dll`).
4//!
5//! # Extension ABI
6//!
7//! Every extension must export a single entry point:
8//!
//! ```ignore
//! #[no_mangle]
//! pub extern "C" fn oxi_extension_create() -> *mut dyn oxi_cli::extensions::Extension {
//!     Box::into_raw(Box::new(MyExtension) as Box<dyn oxi_cli::extensions::Extension>)
//! }
//! ```
15//!
16//! # Directory layout
17//!
18//! ```text
19//! ~/.oxi/extensions/
20//!   ├── my_ext.dylib    # macOS
21//!   ├── other_ext.so    # Linux
22//!   └── win_ext.dll     # Windows
23//! ```
24//!
//! Extensions are discovered in `~/.oxi/extensions/`, in the project-local
//! `.oxi/extensions/` directory under the current working directory, and in
//! any extra paths configured in settings.
27
28use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::types::ExtensionError;
35use crate::extensions::Extension;
36
37/// Entry point symbol that every extension must export.
38const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";
39
40/// Function signature for the extension creation entry point.
41type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// File extension (without the leading dot) used by shared libraries on the
/// platform this crate is compiled for: `dll` on Windows, `dylib` on macOS
/// and `so` everywhere else.
pub const SHARED_LIB_EXTENSION: &str = if cfg!(target_os = "windows") {
    "dll"
} else if cfg!(target_os = "macos") {
    "dylib"
} else {
    "so"
};
51
52/// Check if a file looks like a shared library for the current platform.
53fn is_shared_library(path: &Path) -> bool {
54    path.extension()
55        .and_then(|e| e.to_str())
56        .map(|e| e == SHARED_LIB_EXTENSION)
57        .unwrap_or(false)
58}
59
60/// Discover extension shared libraries in `~/.oxi/extensions/` and extra paths.
61pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62    let mut paths = Vec::new();
63
64    // ~/.oxi/extensions/
65    if let Some(home) = dirs::home_dir() {
66        let ext_dir = home.join(".oxi").join("extensions");
67        if ext_dir.is_dir() {
68            discover_in_dir(&ext_dir, &mut paths);
69        }
70    }
71
72    // .oxi/extensions/ (project-local)
73    let project_ext_dir = cwd.join(".oxi").join("extensions");
74    if project_ext_dir.is_dir() {
75        discover_in_dir(&project_ext_dir, &mut paths);
76    }
77
78    // Extra paths from settings
79    for extra in extra_paths {
80        if extra.is_dir() {
81            discover_in_dir(extra, &mut paths);
82        } else if is_shared_library(extra) && extra.exists() {
83            paths.push(extra.clone());
84        }
85    }
86
87    paths.sort();
88    paths.dedup();
89    paths
90}
91
92/// Discover extension shared libraries in a single directory.
93pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94    let mut paths = Vec::new();
95    discover_in_dir(dir, &mut paths);
96    paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100    let Ok(entries) = std::fs::read_dir(dir) else {
101        return;
102    };
103    for entry in entries.flatten() {
104        let path = entry.path();
105        if path.is_file() && is_shared_library(&path) {
106            out.push(path);
107        }
108    }
109}
110
111/// Load a single extension from a shared library.
112///
113/// # Safety
114///
115/// The loaded library must export `oxi_extension_create` returning a valid
116/// pointer to a `dyn Extension`. The library must have been compiled with
117/// a compatible Rust toolchain version.
118pub fn load_extension(path: &Path) -> anyhow::Result<Arc<dyn Extension>> {
119    let path_display = path.display().to_string();
120
121    if !path.exists() {
122        anyhow::bail!("Extension file not found: {}", path_display);
123    }
124
125    if !is_shared_library(path) {
126        anyhow::bail!(
127            "Not a shared library (expected .{}): {}",
128            SHARED_LIB_EXTENSION,
129            path_display
130        );
131    }
132
133    // SAFETY: Library::new loads a shared library from the given path.
134    // This is unsafe because the loaded code can perform arbitrary operations.
135    // We trust the user-installed extension at the given path.
136    let library = unsafe { Library::new(path) }
137        .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
138
139    // SAFETY: library.get looks up a symbol by name in the loaded shared library.
140    // The symbol name is a static constant, not user-controlled.
141    let create: libloading::Symbol<CreateFn> =
142        unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
143            anyhow::anyhow!(
144                "Symbol 'oxi_extension_create' not found in '{}': {}",
145                path_display,
146                e
147            )
148        })?;
149
150    // SAFETY: Calling the extension's oxi_extension_create entry point.
151    // The function signature is `unsafe fn() -> *mut dyn Extension`.
152    // We check the returned pointer for null below.
153    let raw_ptr = unsafe { create() };
154    if raw_ptr.is_null() {
155        anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
156    }
157
158    // SAFETY: Box::from_raw takes ownership of the pointer returned by
159    // oxi_extension_create. The extension must have allocated this with
160    // Box::new (documented contract). Null was checked above.
161    let extension: Arc<dyn Extension> = unsafe {
162        let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
163        Arc::from(boxed)
164    };
165
166    tracing::info!(
167        name = %extension.name(),
168        path = %path_display,
169        "Extension loaded"
170    );
171
172    // IMPORTANT: We must keep the Library alive for the entire lifetime
173    // of the extension. Leak it intentionally — the extension's code lives
174    // in this library. Unloading it while extension objects exist would
175    // cause undefined behavior.
176    std::mem::forget(library);
177
178    Ok(extension)
179}
180
181/// Load multiple extensions from the given paths.
182///
183/// Returns successfully loaded extensions and any errors encountered.
184/// Does not abort on individual failures — loads as many as possible.
185pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
186    let mut loaded = Vec::new();
187    let mut errors = Vec::new();
188
189    for path in paths {
190        match load_extension(path) {
191            Ok(ext) => loaded.push(ext),
192            Err(e) => {
193                tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
194                errors.push(e);
195            }
196        }
197    }
198
199    (loaded, errors)
200}
201
/// Extension binary validation result.
///
/// Produced by pre-load validation; pairs the checked file's location with a
/// digest of its contents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// Path to the validated extension binary.
    pub path: PathBuf,
    /// SHA-256 hex digest of the file contents.
    pub checksum: String,
}
209
210/// Perform pre-load validation on an extension binary.
211///
212/// Checks file existence, size bounds, and platform-appropriate extension.
213pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
214    if !path.exists() {
215        return Err(ExtensionError::LoadFailed {
216            name: path.display().to_string(),
217            reason: "File not found".into(),
218        });
219    }
220
221    let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
222        name: path.display().to_string(),
223        reason: format!("Cannot read file metadata: {e}"),
224    })?;
225
226    if metadata.len() == 0 {
227        return Err(ExtensionError::LoadFailed {
228            name: path.display().to_string(),
229            reason: "Empty file".into(),
230        });
231    }
232    if metadata.len() > 100 * 1024 * 1024 {
233        return Err(ExtensionError::LoadFailed {
234            name: path.display().to_string(),
235            reason: "File too large (>100MB)".into(),
236        });
237    }
238
239    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
240    let valid_ext = match std::env::consts::OS {
241        "linux" => ext == "so",
242        "macos" => ext == "dylib",
243        "windows" => ext == "dll",
244        _ => true,
245    };
246    if !valid_ext {
247        return Err(ExtensionError::LoadFailed {
248            name: path.display().to_string(),
249            reason: format!("Invalid extension: .{ext}"),
250        });
251    }
252
253    let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
254        name: path.display().to_string(),
255        reason: format!("Cannot read file: {e}"),
256    })?;
257    let checksum = format!("{:x}", sha2::Sha256::digest(&data));
258
259    Ok(ValidatedExtension {
260        path: path.to_path_buf(),
261        checksum,
262    })
263}