oxi/extensions/
loading.rs1use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::types::ExtensionError;
35use crate::extensions::Extension;
36
/// Name of the entry-point symbol every extension library must export.
///
/// The trailing NUL byte makes the name usable directly as a C string when it
/// is handed to `libloading` for the platform symbol lookup.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";

/// Signature of the exported `oxi_extension_create` entry point.
///
/// It is called with the Rust ABI and returns a fat `*mut dyn Extension`
/// pointer, which `load_extension` takes ownership of via `Box::from_raw`.
/// NOTE(review): neither the Rust ABI nor the layout of `dyn` fat pointers is
/// stable, so this presumably requires extensions to be built with the same
/// compiler and the same `Extension` trait definition as the host — confirm.
type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// File extension (without the leading dot) used by shared libraries on the
/// current target: `so` on Linux, `dylib` on macOS, `dll` on Windows.
///
/// Taken from [`std::env::consts::DLL_EXTENSION`] rather than a hand-written
/// `cfg!` chain, so targets beyond those three (e.g. iOS, which uses `dylib`)
/// get the correct value instead of falling back to `so`.
pub const SHARED_LIB_EXTENSION: &str = std::env::consts::DLL_EXTENSION;
51
52fn is_shared_library(path: &Path) -> bool {
54 path.extension()
55 .and_then(|e| e.to_str())
56 .map(|e| e == SHARED_LIB_EXTENSION)
57 .unwrap_or(false)
58}
59
60pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62 let mut paths = Vec::new();
63
64 if let Some(home) = dirs::home_dir() {
66 let ext_dir = home.join(".oxi").join("extensions");
67 if ext_dir.is_dir() {
68 discover_in_dir(&ext_dir, &mut paths);
69 }
70 }
71
72 let project_ext_dir = cwd.join(".oxi").join("extensions");
74 if project_ext_dir.is_dir() {
75 discover_in_dir(&project_ext_dir, &mut paths);
76 }
77
78 for extra in extra_paths {
80 if extra.is_dir() {
81 discover_in_dir(extra, &mut paths);
82 } else if is_shared_library(extra) && extra.exists() {
83 paths.push(extra.clone());
84 }
85 }
86
87 paths.sort();
88 paths.dedup();
89 paths
90}
91
92pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94 let mut paths = Vec::new();
95 discover_in_dir(dir, &mut paths);
96 paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100 let Ok(entries) = std::fs::read_dir(dir) else {
101 return;
102 };
103 for entry in entries.flatten() {
104 let path = entry.path();
105 if path.is_file() && is_shared_library(&path) {
106 out.push(path);
107 }
108 }
109}
110
111pub fn load_extension(path: &Path) -> anyhow::Result<Arc<dyn Extension>> {
119 let path_display = path.display().to_string();
120
121 if !path.exists() {
122 anyhow::bail!("Extension file not found: {}", path_display);
123 }
124
125 if !is_shared_library(path) {
126 anyhow::bail!(
127 "Not a shared library (expected .{}): {}",
128 SHARED_LIB_EXTENSION,
129 path_display
130 );
131 }
132
133 let library = unsafe { Library::new(path) }
135 .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;
136
137 let create: libloading::Symbol<CreateFn> =
139 unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
140 anyhow::anyhow!(
141 "Symbol 'oxi_extension_create' not found in '{}': {}",
142 path_display,
143 e
144 )
145 })?;
146
147 let raw_ptr = unsafe { create() };
149 if raw_ptr.is_null() {
150 anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
151 }
152
153 let extension: Arc<dyn Extension> = unsafe {
155 let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
156 Arc::from(boxed)
157 };
158
159 tracing::info!(
160 name = %extension.name(),
161 path = %path_display,
162 "Extension loaded"
163 );
164
165 std::mem::forget(library);
170
171 Ok(extension)
172}
173
174pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
179 let mut loaded = Vec::new();
180 let mut errors = Vec::new();
181
182 for path in paths {
183 match load_extension(path) {
184 Ok(ext) => loaded.push(ext),
185 Err(e) => {
186 tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
187 errors.push(e);
188 }
189 }
190 }
191
192 (loaded, errors)
193}
194
/// An extension file that passed the checks in `validate_extension`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// Path of the validated library file.
    pub path: PathBuf,
    /// SHA-256 digest of the file contents, as a lowercase hex string when
    /// produced by `validate_extension`.
    pub checksum: String,
}
202
203pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
207 if !path.exists() {
208 return Err(ExtensionError::LoadFailed {
209 name: path.display().to_string(),
210 reason: "File not found".into(),
211 });
212 }
213
214 let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
215 name: path.display().to_string(),
216 reason: format!("Cannot read file metadata: {e}"),
217 })?;
218
219 if metadata.len() == 0 {
220 return Err(ExtensionError::LoadFailed {
221 name: path.display().to_string(),
222 reason: "Empty file".into(),
223 });
224 }
225 if metadata.len() > 100 * 1024 * 1024 {
226 return Err(ExtensionError::LoadFailed {
227 name: path.display().to_string(),
228 reason: "File too large (>100MB)".into(),
229 });
230 }
231
232 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
233 let valid_ext = match std::env::consts::OS {
234 "linux" => ext == "so",
235 "macos" => ext == "dylib",
236 "windows" => ext == "dll",
237 _ => true,
238 };
239 if !valid_ext {
240 return Err(ExtensionError::LoadFailed {
241 name: path.display().to_string(),
242 reason: format!("Invalid extension: .{ext}"),
243 });
244 }
245
246 let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
247 name: path.display().to_string(),
248 reason: format!("Cannot read file: {e}"),
249 })?;
250 let checksum = format!("{:x}", sha2::Sha256::digest(&data));
251
252 Ok(ValidatedExtension {
253 path: path.to_path_buf(),
254 checksum,
255 })
256}