oxi/extensions/
loading.rs1use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::Extension;
35use crate::extensions::types::ExtensionError;
36
/// Name of the constructor every native extension must export.
///
/// Stored with a trailing NUL so it can be passed to the symbol lookup as a
/// ready-made C string.
const ENTRY_SYMBOL: &[u8] = "oxi_extension_create\0".as_bytes();
39
40type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// File extension (without the dot) of shared libraries on this platform.
#[cfg(target_os = "macos")]
pub const SHARED_LIB_EXTENSION: &str = "dylib";

/// File extension (without the dot) of shared libraries on this platform.
#[cfg(target_os = "windows")]
pub const SHARED_LIB_EXTENSION: &str = "dll";

/// File extension (without the dot) of shared libraries on this platform.
/// Every target other than macOS and Windows uses `so`.
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
pub const SHARED_LIB_EXTENSION: &str = "so";
51
52fn is_shared_library(path: &Path) -> bool {
54 path.extension()
55 .and_then(|e| e.to_str())
56 .map(|e| e == SHARED_LIB_EXTENSION)
57 .unwrap_or(false)
58}
59
60pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62 let mut paths = Vec::new();
63
64 if let Some(home) = dirs::home_dir() {
66 let ext_dir = home.join(".oxi").join("extensions");
67 if ext_dir.is_dir() {
68 discover_in_dir(&ext_dir, &mut paths);
69 }
70 }
71
72 let project_ext_dir = cwd.join(".oxi").join("extensions");
74 if project_ext_dir.is_dir() {
75 discover_in_dir(&project_ext_dir, &mut paths);
76 }
77
78 for extra in extra_paths {
80 if extra.is_dir() {
81 discover_in_dir(extra, &mut paths);
82 } else if is_shared_library(extra) && extra.exists() {
83 paths.push(extra.clone());
84 }
85 }
86
87 paths.sort();
88 paths.dedup();
89 paths
90}
91
92pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94 let mut paths = Vec::new();
95 discover_in_dir(dir, &mut paths);
96 paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100 let Ok(entries) = std::fs::read_dir(dir) else {
101 return;
102 };
103 for entry in entries.flatten() {
104 let path = entry.path();
105 if path.is_file() && is_shared_library(&path) {
106 out.push(path);
107 }
108 }
109}
110
111pub fn load_extension(
136 path: &Path,
137 expected_checksum: Option<&str>,
138) -> anyhow::Result<Arc<dyn Extension>> {
139 let path_display = path.display().to_string();
140 if std::env::var("OXI_NATIVE_EXTENSIONS").ok().as_deref() != Some("1") {
145 tracing::warn!(
146 path = %path_display,
147 "native extension skipped — set OXI_NATIVE_EXTENSIONS=1 to load unsandboxed extensions"
148 );
149 anyhow::bail!(
150 "Native extensions are disabled; set OXI_NATIVE_EXTENSIONS=1 to load '{}'",
151 path_display
152 );
153 }
154
155 if !path.exists() {
156 anyhow::bail!("Extension file not found: {}", path_display);
157 }
158
159 if !is_shared_library(path) {
160 anyhow::bail!(
161 "Not a shared library (expected .{}): {}",
162 SHARED_LIB_EXTENSION,
163 path_display
164 );
165 }
166
167 let validated = validate_extension(path).map_err(|e| {
175 anyhow::anyhow!(
176 "native extension pre-load validation failed for '{}': {}",
177 path_display,
178 e
179 )
180 })?;
181 if let Some(expected) = expected_checksum {
182 if !validated.checksum.eq_ignore_ascii_case(expected) {
183 anyhow::bail!(
184 "native extension checksum mismatch for '{}': expected sha256-{expected}, got sha256-{}",
185 path_display,
186 validated.checksum
187 );
188 }
189 tracing::debug!(
190 path = %path_display,
191 checksum = %validated.checksum,
192 "native extension integrity verified"
193 );
194 } else {
195 tracing::warn!(
196 path = %path_display,
197 "loading native extension WITHOUT integrity verification — caller passed None"
198 );
199 }
200
201 let library = unsafe { Library::new(path) }
206 .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
207
208 let create: libloading::Symbol<CreateFn> =
211 unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
212 anyhow::anyhow!(
213 "Symbol 'oxi_extension_create' not found in '{}': {}",
214 path_display,
215 e
216 )
217 })?;
218
219 let raw_ptr = unsafe { create() };
223 if raw_ptr.is_null() {
224 anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
225 }
226
227 let extension: Arc<dyn Extension> = unsafe {
231 let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
232 Arc::from(boxed)
233 };
234
235 tracing::info!(
236 name = %extension.name(),
237 path = %path_display,
238 "Extension loaded"
239 );
240
241 std::mem::forget(library);
246
247 Ok(extension)
248}
249
250pub fn load_extensions(
261 paths: &[&Path],
262 checksums: &[Option<&str>],
263) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
264 assert_eq!(
265 paths.len(),
266 checksums.len(),
267 "load_extensions: paths and checksums must be parallel slices"
268 );
269 let mut loaded = Vec::new();
270 let mut errors = Vec::new();
271
272 for (path, expected) in paths.iter().zip(checksums.iter()) {
273 match load_extension(path, *expected) {
274 Ok(ext) => loaded.push(ext),
275 Err(e) => {
276 tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
277 errors.push(e);
278 }
279 }
280 }
281
282 (loaded, errors)
283}
284
/// Outcome of a successful [`validate_extension`] check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// The path that was validated, exactly as passed in.
    pub path: PathBuf,
    /// Lowercase hex SHA-256 of the file contents (64 characters).
    pub checksum: String,
}
293
294pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
298 if !path.exists() {
299 return Err(ExtensionError::LoadFailed {
300 name: path.display().to_string(),
301 reason: "File not found".into(),
302 });
303 }
304
305 let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
306 name: path.display().to_string(),
307 reason: format!("Cannot read file metadata: {e}"),
308 })?;
309
310 if metadata.len() == 0 {
311 return Err(ExtensionError::LoadFailed {
312 name: path.display().to_string(),
313 reason: "Empty file".into(),
314 });
315 }
316 if metadata.len() > 100 * 1024 * 1024 {
317 return Err(ExtensionError::LoadFailed {
318 name: path.display().to_string(),
319 reason: "File too large (>100MB)".into(),
320 });
321 }
322
323 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
324 let valid_ext = match std::env::consts::OS {
325 "linux" => ext == "so",
326 "macos" => ext == "dylib",
327 "windows" => ext == "dll",
328 _ => true,
329 };
330 if !valid_ext {
331 return Err(ExtensionError::LoadFailed {
332 name: path.display().to_string(),
333 reason: format!("Invalid extension: .{ext}"),
334 });
335 }
336
337 let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
338 name: path.display().to_string(),
339 reason: format!("Cannot read file: {e}"),
340 })?;
341 let checksum = format!("{:x}", sha2::Sha256::digest(&data));
342
343 Ok(ValidatedExtension {
344 path: path.to_path_buf(),
345 checksum,
346 })
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes `payload` to `path`, creating or truncating the file.
    fn write_fake_ext(path: &Path, payload: &[u8]) {
        let mut f = std::fs::File::create(path).unwrap();
        f.write_all(payload).unwrap();
    }

    /// Hashing the same file twice yields the same 64-char lowercase hex digest.
    #[test]
    fn validate_extension_is_deterministic() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_path = tmp.path().join(format!("lib.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_path, b"deterministic test payload");

        let v1 = validate_extension(&ext_path).expect("validate should succeed");
        let v2 = validate_extension(&ext_path).expect("validate should succeed");
        assert_eq!(v1.checksum, v2.checksum);
        assert_eq!(v1.checksum.len(), 64);
        assert!(
            v1.checksum
                .chars()
                .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase())
        );
    }

    /// Different contents produce different checksums.
    #[test]
    fn validate_extension_distinguishes_content() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_a = tmp.path().join(format!("a.{}", SHARED_LIB_EXTENSION));
        let ext_b = tmp.path().join(format!("b.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_a, b"alpha");
        write_fake_ext(&ext_b, b"beta");

        let v_a = validate_extension(&ext_a).unwrap();
        let v_b = validate_extension(&ext_b).unwrap();
        assert_ne!(v_a.checksum, v_b.checksum);
    }

    /// Pins the algorithm: SHA-256("abc") is the FIPS 180-2 test vector.
    #[test]
    fn validate_extension_matches_known_sha256_vector() {
        let tmp = tempfile::tempdir().unwrap();
        let ext_path = tmp.path().join(format!("abc.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&ext_path, b"abc");

        let v = validate_extension(&ext_path).expect("validate should succeed");
        assert_eq!(
            v.checksum,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
        assert_eq!(v.path, ext_path);
    }

    #[test]
    fn validate_extension_rejects_empty_file() {
        let tmp = tempfile::tempdir().unwrap();
        let empty = tmp.path().join(format!("empty.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&empty, b"");
        let err = validate_extension(&empty).expect_err("empty file must fail");
        let msg = format!("{err}");
        assert!(msg.contains("Empty file"), "unexpected err: {msg}");
    }

    /// On macOS a Linux-style `.so` must be rejected.
    #[test]
    #[cfg(target_os = "macos")]
    fn validate_extension_rejects_wrong_platform_ext_on_macos() {
        let tmp = tempfile::tempdir().unwrap();
        let wrong = tmp.path().join("lib.so");
        write_fake_ext(&wrong, b"x");
        let err = validate_extension(&wrong).expect_err("wrong platform ext must fail");
        let msg = format!("{err}");
        assert!(msg.contains("Invalid extension"), "unexpected err: {msg}");
    }

    #[test]
    fn validate_extension_handles_missing_path() {
        let tmp = tempfile::tempdir().unwrap();
        let missing = tmp.path().join("does-not-exist.dylib");
        let err = validate_extension(&missing).expect_err("missing path must fail");
        assert!(format!("{err}").contains("File not found"));
    }

    #[test]
    fn is_shared_library_checks_platform_extension() {
        let lib = format!("x.{}", SHARED_LIB_EXTENSION);
        assert!(is_shared_library(Path::new(&lib)));
        assert!(!is_shared_library(Path::new("x.txt")));
        assert!(!is_shared_library(Path::new("no_extension")));
    }

    /// Only regular files with the library extension are discovered; other
    /// files and library-named directories are skipped.
    #[test]
    fn discover_extensions_in_dir_only_returns_shared_libraries() {
        let tmp = tempfile::tempdir().unwrap();
        let lib = tmp.path().join(format!("plugin.{}", SHARED_LIB_EXTENSION));
        write_fake_ext(&lib, b"x");
        write_fake_ext(&tmp.path().join("notes.txt"), b"x");
        std::fs::create_dir(tmp.path().join(format!("nested.{}", SHARED_LIB_EXTENSION)))
            .unwrap();

        let mut found = discover_extensions_in_dir(tmp.path());
        found.sort();
        assert_eq!(found, vec![lib]);
    }
}