oxi/extensions/
loading.rs1use std::path::{Path, PathBuf};
29use std::sync::Arc;
30
31use libloading::Library;
32use sha2::Digest;
33
34use crate::extensions::types::ExtensionError;
35use crate::extensions::Extension;
36
/// NUL-terminated name of the constructor symbol every extension library must
/// export; it is the byte string passed to `Library::get` when loading.
const ENTRY_SYMBOL: &[u8] = b"oxi_extension_create\0";

/// Signature expected for the exported `oxi_extension_create` symbol: it takes
/// no arguments and returns an owned, heap-allocated extension as a raw
/// pointer, with null signalling failure to the loader.
// NOTE(review): this is the Rust ABI (`unsafe fn`), not `extern "C"`. The
// extension must export its constructor with exactly this ABI, and both sides
// must agree on the layout of `dyn Extension` (in practice: same compiler
// version and same `Extension` trait definition). If extensions are written
// as `extern "C" fn`, this alias is a mismatch — confirm against the
// extension-side template.
type CreateFn = unsafe fn() -> *mut dyn Extension;
42
/// File extension (without the leading dot) used by shared libraries on the
/// target platform: `so` on Linux, `dylib` on macOS, `dll` on Windows.
///
/// Taken from [`std::env::consts::DLL_EXTENSION`] instead of hand-rolled
/// `cfg!` checks, so targets beyond those three (e.g. iOS, which uses `dylib`)
/// get their native extension rather than falling back to `so`.
pub const SHARED_LIB_EXTENSION: &str = std::env::consts::DLL_EXTENSION;
51
52fn is_shared_library(path: &Path) -> bool {
54 path.extension()
55 .and_then(|e| e.to_str())
56 .map(|e| e == SHARED_LIB_EXTENSION)
57 .unwrap_or(false)
58}
59
60pub fn discover_extensions(cwd: &Path, extra_paths: &[PathBuf]) -> Vec<PathBuf> {
62 let mut paths = Vec::new();
63
64 if let Some(home) = dirs::home_dir() {
66 let ext_dir = home.join(".oxi").join("extensions");
67 if ext_dir.is_dir() {
68 discover_in_dir(&ext_dir, &mut paths);
69 }
70 }
71
72 let project_ext_dir = cwd.join(".oxi").join("extensions");
74 if project_ext_dir.is_dir() {
75 discover_in_dir(&project_ext_dir, &mut paths);
76 }
77
78 for extra in extra_paths {
80 if extra.is_dir() {
81 discover_in_dir(extra, &mut paths);
82 } else if is_shared_library(extra) && extra.exists() {
83 paths.push(extra.clone());
84 }
85 }
86
87 paths.sort();
88 paths.dedup();
89 paths
90}
91
92pub fn discover_extensions_in_dir(dir: &Path) -> Vec<PathBuf> {
94 let mut paths = Vec::new();
95 discover_in_dir(dir, &mut paths);
96 paths
97}
98
99fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
100 let Ok(entries) = std::fs::read_dir(dir) else {
101 return;
102 };
103 for entry in entries.flatten() {
104 let path = entry.path();
105 if path.is_file() && is_shared_library(&path) {
106 out.push(path);
107 }
108 }
109}
110
/// Loads a single extension from the shared library at `path`.
///
/// The library must export an `oxi_extension_create` symbol matching
/// [`CreateFn`]. It is called once and must return a non-null pointer that can
/// be reclaimed with `Box::from_raw` as a `Box<dyn Extension>`; ownership of
/// that value passes to the host.
///
/// On success the library handle is intentionally leaked (never unloaded), so
/// the extension's code and vtable stay mapped for as long as the returned
/// `Arc`, or any clone of it, is alive.
///
/// # Errors
///
/// Fails if the file does not exist, lacks the platform's shared-library
/// extension ([`SHARED_LIB_EXTENSION`]), cannot be opened by the dynamic
/// loader, does not export the entry symbol, or the entry point returns null.
pub fn load_extension(path: &Path) -> anyhow::Result<Arc<dyn Extension>> {
    let path_display = path.display().to_string();

    if !path.exists() {
        anyhow::bail!("Extension file not found: {}", path_display);
    }

    if !is_shared_library(path) {
        anyhow::bail!(
            "Not a shared library (expected .{}): {}",
            SHARED_LIB_EXTENSION,
            path_display
        );
    }

    // SAFETY: opening a library may run its initialisation code; the file at
    // `path` is trusted to be a well-behaved oxi extension.
    let library = unsafe { Library::new(path) }
        .map_err(|e| anyhow::anyhow!("Failed to load library '{}': {}", path_display, e))?;

    // SAFETY: nothing can verify the symbol's real type here; `CreateFn` must
    // match the exported function exactly, including its ABI.
    let create: libloading::Symbol<CreateFn> =
        unsafe { library.get(ENTRY_SYMBOL) }.map_err(|e| {
            anyhow::anyhow!(
                "Symbol 'oxi_extension_create' not found in '{}': {}",
                path_display,
                e
            )
        })?;

    // SAFETY: `library` is still loaded here, so the function pointer is valid.
    let raw_ptr = unsafe { create() };
    if raw_ptr.is_null() {
        anyhow::bail!("oxi_extension_create returned null in '{}'", path_display);
    }

    // SAFETY: `raw_ptr` is non-null (checked above) and is reclaimed exactly
    // once, here. It must come from `Box::into_raw` in the extension.
    // NOTE(review): `Arc::from(Box)` moves the value into a new host allocation
    // and frees the extension's allocation with the host allocator — sound only
    // if both sides use the same global allocator; confirm for extensions that
    // set `#[global_allocator]`.
    let extension: Arc<dyn Extension> = unsafe {
        let boxed: Box<dyn Extension> = Box::from_raw(raw_ptr);
        Arc::from(boxed)
    };

    tracing::info!(
        name = %extension.name(),
        path = %path_display,
        "Extension loaded"
    );

    // Deliberately leak the handle: dropping `library` would unload the code
    // and vtable that `extension` (and every clone of it) still points into.
    std::mem::forget(library);

    Ok(extension)
}
180
181pub fn load_extensions(paths: &[&Path]) -> (Vec<Arc<dyn Extension>>, Vec<anyhow::Error>) {
186 let mut loaded = Vec::new();
187 let mut errors = Vec::new();
188
189 for path in paths {
190 match load_extension(path) {
191 Ok(ext) => loaded.push(ext),
192 Err(e) => {
193 tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
194 errors.push(e);
195 }
196 }
197 }
198
199 (loaded, errors)
200}
201
/// Result of `validate_extension`: an extension file that passed the pre-load
/// sanity checks, together with a checksum of its contents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedExtension {
    /// The path that was validated, exactly as passed in.
    pub path: PathBuf,
    /// Lowercase hexadecimal SHA-256 digest of the file's contents at the time
    /// of validation.
    pub checksum: String,
}
209
210pub fn validate_extension(path: &Path) -> Result<ValidatedExtension, ExtensionError> {
214 if !path.exists() {
215 return Err(ExtensionError::LoadFailed {
216 name: path.display().to_string(),
217 reason: "File not found".into(),
218 });
219 }
220
221 let metadata = std::fs::metadata(path).map_err(|e| ExtensionError::LoadFailed {
222 name: path.display().to_string(),
223 reason: format!("Cannot read file metadata: {e}"),
224 })?;
225
226 if metadata.len() == 0 {
227 return Err(ExtensionError::LoadFailed {
228 name: path.display().to_string(),
229 reason: "Empty file".into(),
230 });
231 }
232 if metadata.len() > 100 * 1024 * 1024 {
233 return Err(ExtensionError::LoadFailed {
234 name: path.display().to_string(),
235 reason: "File too large (>100MB)".into(),
236 });
237 }
238
239 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
240 let valid_ext = match std::env::consts::OS {
241 "linux" => ext == "so",
242 "macos" => ext == "dylib",
243 "windows" => ext == "dll",
244 _ => true,
245 };
246 if !valid_ext {
247 return Err(ExtensionError::LoadFailed {
248 name: path.display().to_string(),
249 reason: format!("Invalid extension: .{ext}"),
250 });
251 }
252
253 let data = std::fs::read(path).map_err(|e| ExtensionError::LoadFailed {
254 name: path.display().to_string(),
255 reason: format!("Cannot read file: {e}"),
256 })?;
257 let checksum = format!("{:x}", sha2::Sha256::digest(&data));
258
259 Ok(ValidatedExtension {
260 path: path.to_path_buf(),
261 checksum,
262 })
263}