oxi/extensions/
loading.rs1use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::Extension;
35use crate::extensions::types::ExtensionError;
36
/// Name of the constructor every native extension must export.
///
/// It ends with a NUL byte so it is already a C-style string when passed to
/// the symbol lookup in `load_extension`.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\x00";
39
40type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// File extension, without the leading dot, used for shared libraries on the
/// target platform. Examples: `so` on Linux, `dylib` on macOS, `dll` on Windows.
///
/// This delegates to the standard library's per-target constant rather than a
/// hand-rolled `cfg!` chain. The old chain only special-cased macOS and Windows
/// and so produced `so` on other Apple targets (iOS, tvOS, ...).
pub const SHARED_LIB_EXTENSION: &str = std::env::consts::DLL_EXTENSION;
51
52fn is_shared_library(path: &Path) -> bool {
54 path.extension()
55 .and_then(|e| e.to_str())
56 .map(|e| e == SHARED_LIB_EXTENSION)
57 .unwrap_or(false)
58}
59
60pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62 let mut paths = Vec::new();
63
64 if let Some(home) = dirs::home_dir() {
66 let ext_dir = home.join(".oxi").join("extensions");
67 if ext_dir.is_dir() {
68 discover_in_dir(&ext_dir, &mut paths);
69 }
70 }
71
72 let project_ext_dir = cwd.join(".oxi").join("extensions");
74 if project_ext_dir.is_dir() {
75 discover_in_dir(&project_ext_dir, &mut paths);
76 }
77
78 for extra in extra_paths {
80 if extra.is_dir() {
81 discover_in_dir(extra, &mut paths);
82 } else if is_shared_library(extra) && extra.exists() {
83 paths.push(extra.clone());
84 }
85 }
86
87 paths.sort();
88 paths.dedup();
89 paths
90}
91
92pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94 let mut paths = Vec::new();
95 discover_in_dir(dir, &mut paths);
96 paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100 let Ok(entries) = std::fs::read_dir(dir) else {
101 return;
102 };
103 for entry in entries.flatten() {
104 let path = entry.path();
105 if path.is_file() && is_shared_library(&path) {
106 out.push(path);
107 }
108 }
109}
110
111pub fn load_extension(path: &Path) -> anyhow::Result<Arc<dyn Extension>> {
119 let path_display = path.display().to_string();
120 if std::env::var("OXI_NATIVE_EXTENSIONS").ok().as_deref() != Some("1") {
125 tracing::warn!(
126 path = %path_display,
127 "native extension skipped — set OXI_NATIVE_EXTENSIONS=1 to load unsandboxed extensions"
128 );
129 anyhow::bail!(
130 "Native extensions are disabled; set OXI_NATIVE_EXTENSIONS=1 to load '{}'",
131 path_display
132 );
133 }
134
135 if !path.exists() {
136 anyhow::bail!("Extension file not found: {}", path_display);
137 }
138
139 if !is_shared_library(path) {
140 anyhow::bail!(
141 "Not a shared library (expected .{}): {}",
142 SHARED_LIB_EXTENSION,
143 path_display
144 );
145 }
146
147 let library = unsafe { Library::new(path) }
151 .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
152
153 let create: libloading::Symbol<CreateFn> =
156 unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
157 anyhow::anyhow!(
158 "Symbol 'oxi_extension_create' not found in '{}': {}",
159 path_display,
160 e
161 )
162 })?;
163
164 let raw_ptr = unsafe { create() };
168 if raw_ptr.is_null() {
169 anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
170 }
171
172 let extension: Arc<dyn Extension> = unsafe {
176 let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
177 Arc::from(boxed)
178 };
179
180 tracing::info!(
181 name = %extension.name(),
182 path = %path_display,
183 "Extension loaded"
184 );
185
186 std::mem::forget(library);
191
192 Ok(extension)
193}
194
195pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
200 let mut loaded = Vec::new();
201 let mut errors = Vec::new();
202
203 for path in paths {
204 match load_extension(path) {
205 Ok(ext) => loaded.push(ext),
206 Err(e) => {
207 tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
208 errors.push(e);
209 }
210 }
211 }
212
213 (loaded, errors)
214}
215
/// A file that passed the pre-load checks of [`validate_extension`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// The path that was validated, exactly as passed in (not canonicalized).
    pub path: PathBuf,
    /// Lowercase hexadecimal SHA-256 digest of the file contents.
    pub checksum: String,
}
223
224pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
228 if !path.exists() {
229 return Err(ExtensionError::LoadFailed {
230 name: path.display().to_string(),
231 reason: "File not found".into(),
232 });
233 }
234
235 let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
236 name: path.display().to_string(),
237 reason: format!("Cannot read file metadata: {e}"),
238 })?;
239
240 if metadata.len() == 0 {
241 return Err(ExtensionError::LoadFailed {
242 name: path.display().to_string(),
243 reason: "Empty file".into(),
244 });
245 }
246 if metadata.len() > 100 * 1024 * 1024 {
247 return Err(ExtensionError::LoadFailed {
248 name: path.display().to_string(),
249 reason: "File too large (>100MB)".into(),
250 });
251 }
252
253 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
254 let valid_ext = match std::env::consts::OS {
255 "linux" => ext == "so",
256 "macos" => ext == "dylib",
257 "windows" => ext == "dll",
258 _ => true,
259 };
260 if !valid_ext {
261 return Err(ExtensionError::LoadFailed {
262 name: path.display().to_string(),
263 reason: format!("Invalid extension: .{ext}"),
264 });
265 }
266
267 let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
268 name: path.display().to_string(),
269 reason: format!("Cannot read file: {e}"),
270 })?;
271 let checksum = format!("{:x}", sha2::Sha256::digest(&data));
272
273 Ok(ValidatedExtension {
274 path: path.to_path_buf(),
275 checksum,
276 })
277}