1use anyhow::{Context, Result};
12use extism::{CurrentPlugin, Function, PTR, UserData, Val};
13use parking_lot::Mutex;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::cell::RefCell;
17use std::collections::HashMap;
18use std::io::Read as _;
19use std::path::{Path, PathBuf};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22
/// Metadata an extension reports about itself as JSON from its `init` export.
///
/// When `init` is missing or fails, `WasmExtensionManager::load` synthesizes a value whose
/// name is the `.wasm` file stem and whose version is `"0.0.0"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
    /// Extension name; the manager uses it as the key for the plugin, its tools, and its
    /// KV namespace.
    pub name: String,
    /// Free-form version string (not parsed or validated in this module).
    pub version: String,
    /// Human-readable description; empty when omitted from the JSON.
    #[serde(default)]
    pub description: String,
    /// Permissions the extension asks for; empty when omitted from the JSON.
    ///
    /// NOTE(review): `load` only logs these (warning about `"exec"`); no code in this module
    /// enforces them.
    #[serde(default)]
    pub permissions: Vec<String>,
}
37
/// One tool exposed by an extension, as listed in the `tools` array returned by its
/// `register_tools` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmToolDef {
    /// Tool name; `WasmExtensionManager::execute_tool` routes calls by this name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter schema kept as raw JSON. NOTE(review): presumably a JSON Schema object
    /// (the tests use `{"type":"object"}`); it is not validated here.
    pub schema: Value,
}
45
/// One slash command exposed by an extension, as listed in the `commands` array returned
/// by its `register_commands` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmCommandDef {
    /// Command name without the leading `/`; matched by `WasmExtensionManager::execute_command`.
    pub name: String,
    /// Human-readable description of the command.
    pub description: String,
}
52
/// Bookkeeping for one loaded extension. The live Extism plugin instance is stored
/// separately in `WasmExtensionManager::plugins` under the same extension name.
#[derive(Debug)]
pub struct LoadedWasmExtension {
    /// Info from the extension's `init` export, or synthesized defaults when it has none.
    pub info: ExtensionInfo,
    /// Tools parsed from `register_tools` (entries that failed to deserialize are dropped).
    pub tools: Vec<WasmToolDef>,
    /// Slash commands parsed from `register_commands`.
    pub commands: Vec<WasmCommandDef>,
    /// The `.wasm` file the extension was loaded from.
    pub source_path: PathBuf,
}
61
62fn host_oxi_http_request(
69 plugin: &mut CurrentPlugin,
70 inputs: &[Val],
71 outputs: &mut [Val],
72 user_data: UserData<Arc<reqwest::blocking::Client>>,
73) -> Result<(), extism::Error> {
74 let result: anyhow::Result<()> = (|| {
76 let input_json: String = plugin.memory_get_val(&inputs[0])?;
77
78 #[derive(Deserialize)]
79 struct HttpReq {
80 url: String,
81 #[serde(default)]
82 method: String,
83 #[serde(default)]
84 headers: HashMap<String, String>,
85 #[serde(default)]
86 body: Option<String>,
87 }
88
89 let req: HttpReq =
90 serde_json::from_str(&input_json).context("oxi_http_request: invalid request JSON")?;
91
92 let method = if req.method.is_empty() {
93 "GET"
94 } else {
95 &req.method
96 };
97
98 if let Err(e) = validate_url(&req.url) {
100 anyhow::bail!("oxi_http_request: {}", e);
101 }
102
103 let client_arc = user_data.get()?;
105 let client = client_arc.lock().expect("wasm client lock poisoned");
106
107 let method = match method.to_uppercase().as_str() {
108 "GET" => reqwest::Method::GET,
109 "POST" => reqwest::Method::POST,
110 "PUT" => reqwest::Method::PUT,
111 "DELETE" => reqwest::Method::DELETE,
112 "PATCH" => reqwest::Method::PATCH,
113 "HEAD" => reqwest::Method::HEAD,
114 other => anyhow::bail!("oxi_http_request: unsupported method '{}'", other),
115 };
116
117 let mut rb = client.request(method, &req.url);
118 for (k, v) in &req.headers {
119 rb = rb.header(k, v);
120 }
121 if let Some(body) = &req.body {
122 rb = rb.body(body.clone());
123 }
124
125 let resp = rb
127 .send()
128 .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
129 let status = resp.status().as_u16();
130 let resp_headers: HashMap<String, String> = resp
131 .headers()
132 .iter()
133 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
134 .collect();
135 let resp_body = {
137 let max_body = 1024 * 1024; let body_bytes = resp
139 .bytes()
140 .map_err(|e| anyhow::anyhow!("Failed to read response: {}", e))?;
141 if body_bytes.len() > max_body {
142 tracing::warn!(
143 "HTTP response truncated: {} bytes > {} limit",
144 body_bytes.len(),
145 max_body
146 );
147 String::from_utf8_lossy(&body_bytes[..max_body]).to_string()
148 } else {
149 String::from_utf8_lossy(&body_bytes).to_string()
150 }
151 };
152
153 let response = serde_json::json!({
154 "status": status,
155 "headers": resp_headers,
156 "body": resp_body,
157 });
158
159 let output = serde_json::to_string(&response)?;
160 let handle = plugin.memory_new(&output)?;
161 if !outputs.is_empty() {
162 outputs[0] = plugin.memory_to_val(handle);
163 }
164 Ok(())
165 })();
166
167 result
168}
169
170fn host_oxi_log(
172 plugin: &mut CurrentPlugin,
173 inputs: &[Val],
174 _outputs: &mut [Val],
175 _user_data: UserData<()>,
176) -> Result<(), extism::Error> {
177 let message: String = plugin.memory_get_val(&inputs[0])?;
178 tracing::debug!("[WASM] {}", message);
179 Ok(())
180}
181
182fn host_oxi_read_file(
190 plugin: &mut CurrentPlugin,
191 inputs: &[Val],
192 outputs: &mut [Val],
193 _user_data: UserData<()>,
194) -> Result<(), extism::Error> {
195 let result: anyhow::Result<()> = (|| {
196 let input_json: String = plugin.memory_get_val(&inputs[0])?;
197
198 #[derive(Deserialize)]
199 struct ReadReq {
200 path: String,
201 #[serde(default)]
202 offset: Option<usize>,
203 #[serde(default = "default_limit")]
204 limit: usize,
205 }
206 fn default_limit() -> usize {
207 2000
208 }
209
210 let req: ReadReq =
211 serde_json::from_str(&input_json).context("oxi_read_file: invalid request JSON")?;
212
213 validate_path_allowed(&req.path)?;
215
216 let metadata = std::fs::metadata(&req.path);
217 match metadata {
218 Ok(m) => {
219 let max_bytes = 50 * 1024; let file_size = m.len() as usize;
221
222 let content = std::fs::read_to_string(&req.path)
223 .map_err(|e| anyhow::anyhow!("Failed to read file: {}", e))?;
224
225 let lines: Vec<&str> = content.lines().collect();
226 let total_lines = lines.len();
227
228 let offset = req.offset.unwrap_or(0).min(total_lines);
229 let end = (offset + req.limit).min(total_lines);
230 let selected: Vec<&str> = lines[offset..end].to_vec();
231 let mut result = selected.join("\n");
232
233 let truncated = result.len() > max_bytes;
234 if truncated {
235 result = result.chars().take(max_bytes).collect();
236 }
237
238 let response = serde_json::json!({
239 "success": true,
240 "content": result,
241 "truncated": truncated || end < total_lines,
242 "bytes": file_size,
243 "total_lines": total_lines,
244 "shown_lines": end - offset,
245 });
246 let output = serde_json::to_string(&response)?;
247 let handle = plugin.memory_new(&output)?;
248 if !outputs.is_empty() {
249 outputs[0] = plugin.memory_to_val(handle);
250 }
251 }
252 Err(e) => {
253 let response = serde_json::json!({
254 "success": false,
255 "error": format!("File not found: {}", e),
256 });
257 let output = serde_json::to_string(&response)?;
258 let handle = plugin.memory_new(&output)?;
259 if !outputs.is_empty() {
260 outputs[0] = plugin.memory_to_val(handle);
261 }
262 }
263 }
264 Ok(())
265 })();
266 result
267}
268
269fn host_oxi_write_file(
274 plugin: &mut CurrentPlugin,
275 inputs: &[Val],
276 outputs: &mut [Val],
277 _user_data: UserData<()>,
278) -> Result<(), extism::Error> {
279 let result: anyhow::Result<()> = (|| {
280 let input_json: String = plugin.memory_get_val(&inputs[0])?;
281
282 #[derive(Deserialize)]
283 struct WriteReq {
284 path: String,
285 content: String,
286 #[serde(default = "default_true")]
287 create_dirs: bool,
288 }
289 fn default_true() -> bool {
290 true
291 }
292
293 let req: WriteReq =
294 serde_json::from_str(&input_json).context("oxi_write_file: invalid request JSON")?;
295
296 validate_path_allowed(&req.path)?;
297
298 if req.create_dirs
299 && let Some(parent) = std::path::Path::new(&req.path).parent()
300 {
301 std::fs::create_dir_all(parent)
302 .map_err(|e| anyhow::anyhow!("Failed to create directories: {}", e))?;
303 }
304
305 let bytes = req.content.len();
306 std::fs::write(&req.path, &req.content)
307 .map_err(|e| anyhow::anyhow!("Failed to write file: {}", e))?;
308
309 let response = serde_json::json!({
310 "success": true,
311 "bytes_written": bytes,
312 });
313 let output = serde_json::to_string(&response)?;
314 let handle = plugin.memory_new(&output)?;
315 if !outputs.is_empty() {
316 outputs[0] = plugin.memory_to_val(handle);
317 }
318 Ok(())
319 })();
320 result
321}
322
323fn host_oxi_exec(
328 plugin: &mut CurrentPlugin,
329 inputs: &[Val],
330 outputs: &mut [Val],
331 _user_data: UserData<()>,
332) -> Result<(), extism::Error> {
333 let result: anyhow::Result<()> = (|| {
334 if std::env::var("OXI_EXTENSION_EXEC").ok().as_deref() != Some("1") {
339 let response = serde_json::json!({
340 "success": false,
341 "error": "oxi_exec is disabled; set OXI_EXTENSION_EXEC=1 to allow extensions to run commands",
342 "exit_code": -1,
343 });
344 let output = serde_json::to_string(&response)?;
345 let handle = plugin.memory_new(&output)?;
346 if !outputs.is_empty() {
347 outputs[0] = plugin.memory_to_val(handle);
348 }
349 return Ok(());
350 }
351 let input_json: String = plugin.memory_get_val(&inputs[0])?;
352
353 #[derive(Deserialize)]
354 struct ExecReq {
355 command: String,
356 #[serde(default)]
357 args: Vec<String>,
358 #[serde(default)]
359 cwd: Option<String>,
360 #[serde(default = "default_timeout")]
361 timeout: u64,
362 }
363 fn default_timeout() -> u64 {
364 30
365 }
366
367 let req: ExecReq =
368 serde_json::from_str(&input_json).context("oxi_exec: invalid request JSON")?;
369
370 let cwd = req.cwd.as_deref().unwrap_or(".");
371
372 let full_cmd = if req.args.is_empty() {
374 req.command.clone()
375 } else {
376 format!("{} {}", req.command, req.args.join(" "))
377 };
378
379 let blocked_patterns = [
381 "rm -rf /",
382 "rm -rf /*",
383 "mkfs",
384 "dd if=",
385 "format ",
386 ":(){ :|:& };:",
387 "chmod 777 /",
388 "chown root",
389 "> /etc/",
390 "> /boot/",
391 "> /dev/",
392 "dd of=/dev/",
393 "mv / /",
394 ];
395 for blocked in &blocked_patterns {
396 if full_cmd.contains(blocked) {
397 anyhow::bail!("oxi_exec: blocked dangerous command pattern");
398 }
399 }
400
401 let cmd_lower = req.command.to_lowercase();
403 if cmd_lower == "sudo"
404 || cmd_lower == "su"
405 || cmd_lower == "doas"
406 || cmd_lower.starts_with("sudo ")
407 || cmd_lower.starts_with("su ")
408 || cmd_lower.starts_with("doas ")
409 {
410 anyhow::bail!("oxi_exec: privilege escalation commands are blocked");
411 }
412
413 let timeout_ms = req.timeout.clamp(1000, 30000);
415 let timeout_dur = Duration::from_millis(timeout_ms);
416
417 let mut child = match std::process::Command::new(&req.command)
419 .args(&req.args)
420 .current_dir(cwd)
421 .stdout(std::process::Stdio::piped())
422 .stderr(std::process::Stdio::piped())
423 .spawn()
424 {
425 Ok(c) => c,
426 Err(e) => {
427 let response = serde_json::json!({
428 "success": false,
429 "error": format!("Failed to execute: {}", e),
430 "exit_code": -1,
431 });
432 let out = serde_json::to_string(&response)?;
433 let handle = plugin.memory_new(&out)?;
434 if !outputs.is_empty() {
435 outputs[0] = plugin.memory_to_val(handle);
436 }
437 return Ok(());
438 }
439 };
440
441 let start = Instant::now();
443 let mut timed_out = false;
444 let mut exit_status: Option<std::process::ExitStatus> = None;
445
446 loop {
447 match child.try_wait() {
448 Ok(Some(status)) => {
449 exit_status = Some(status);
450 break;
451 }
452 Ok(None) => {
453 if start.elapsed() >= timeout_dur {
454 tracing::warn!(
456 "oxi_exec: command '{}' timed out after {}ms",
457 req.command,
458 timeout_ms
459 );
460 let _ = child.kill();
461 let _ = child.wait(); timed_out = true;
463 break;
464 }
465 std::thread::sleep(Duration::from_millis(50));
466 }
467 Err(_) => {
468 match child.wait() {
470 Ok(status) => {
471 exit_status = Some(status);
472 }
473 Err(_) => {
474 timed_out = true;
475 }
476 }
477 break;
478 }
479 }
480 }
481
482 let mut stdout_buf = Vec::new();
484 let mut stderr_buf = Vec::new();
485 if let Some(mut out) = child.stdout.take() {
486 let _ = out.read_to_end(&mut stdout_buf);
487 }
488 if let Some(mut err) = child.stderr.take() {
489 let _ = err.read_to_end(&mut stderr_buf);
490 }
491
492 let stdout = String::from_utf8_lossy(&stdout_buf);
493 let stderr = String::from_utf8_lossy(&stderr_buf);
494 let max_output = 50 * 1024; let stdout_truncated = stdout.len() > max_output;
496 let stderr_truncated = stderr.len() > max_output;
497 let stdout_str: String = if stdout_truncated {
498 stdout.chars().take(max_output).collect()
499 } else {
500 stdout.to_string()
501 };
502 let stderr_str: String = if stderr_truncated {
503 stderr.chars().take(max_output).collect()
504 } else {
505 stderr.to_string()
506 };
507
508 let response = serde_json::json!({
509 "success": !timed_out && exit_status.map(|s| s.success()).unwrap_or(false),
510 "stdout": stdout_str,
511 "stderr": stderr_str,
512 "exit_code": if timed_out { -2 } else { exit_status.and_then(|s| s.code()).unwrap_or(-1) },
513 "stdout_truncated": stdout_truncated,
514 "stderr_truncated": stderr_truncated,
515 "timed_out": timed_out,
516 });
517 let out = serde_json::to_string(&response)?;
518 let handle = plugin.memory_new(&out)?;
519 if !outputs.is_empty() {
520 outputs[0] = plugin.memory_to_val(handle);
521 }
522 Ok(())
523 })();
524 result
525}
526
527fn host_oxi_get_env(
532 plugin: &mut CurrentPlugin,
533 inputs: &[Val],
534 outputs: &mut [Val],
535 _user_data: UserData<()>,
536) -> Result<(), extism::Error> {
537 let result: anyhow::Result<()> = (|| {
538 let input_json: String = plugin.memory_get_val(&inputs[0])?;
539
540 #[derive(Deserialize)]
541 struct EnvReq {
542 key: String,
543 }
544
545 let req: EnvReq =
546 serde_json::from_str(&input_json).context("oxi_get_env: invalid request JSON")?;
547
548 let blocked_keys = ["AWS_SECRET", "PRIVATE_KEY", "PASSWORD", "TOKEN", "SECRET"];
550 let key_upper = req.key.to_uppercase();
551 for blocked in &blocked_keys {
552 if key_upper.contains(blocked) {
553 anyhow::bail!("oxi_get_env: access to '{}' is blocked", req.key);
554 }
555 }
556
557 let value = std::env::var(&req.key).ok();
558 let response = serde_json::json!({
559 "success": value.is_some(),
560 "value": value.unwrap_or_default(),
561 });
562 let output = serde_json::to_string(&response)?;
563 let handle = plugin.memory_new(&output)?;
564 if !outputs.is_empty() {
565 outputs[0] = plugin.memory_to_val(handle);
566 }
567 Ok(())
568 })();
569 result
570}
571
572fn host_oxi_kv_get(
581 plugin: &mut CurrentPlugin,
582 inputs: &[Val],
583 outputs: &mut [Val],
584 _user_data: UserData<()>,
585) -> Result<(), extism::Error> {
586 let result: anyhow::Result<()> = (|| {
587 let input_json: String = plugin.memory_get_val(&inputs[0])?;
588
589 #[derive(Deserialize)]
590 struct KvReq {
591 key: String,
592 }
593
594 let req: KvReq =
595 serde_json::from_str(&input_json).context("oxi_kv_get: invalid request JSON")?;
596
597 let ext_name = current_extension_name();
599 let value = kv_namespaced_get(&ext_name, &req.key);
600 let response = serde_json::json!({
601 "success": value.is_some(),
602 "value": value.unwrap_or_default(),
603 });
604 let output = serde_json::to_string(&response)?;
605 let handle = plugin.memory_new(&output)?;
606 if !outputs.is_empty() {
607 outputs[0] = plugin.memory_to_val(handle);
608 }
609 Ok(())
610 })();
611 result
612}
613
614fn host_oxi_kv_set(
618 plugin: &mut CurrentPlugin,
619 inputs: &[Val],
620 _outputs: &mut [Val],
621 _user_data: UserData<()>,
622) -> Result<(), extism::Error> {
623 let result: anyhow::Result<()> = (|| {
624 let input_json: String = plugin.memory_get_val(&inputs[0])?;
625
626 #[derive(Deserialize)]
627 struct KvSetReq {
628 key: String,
629 value: String,
630 }
631
632 let req: KvSetReq =
633 serde_json::from_str(&input_json).context("oxi_kv_set: invalid request JSON")?;
634
635 let ext_name = current_extension_name();
637 kv_namespaced_set(&ext_name, &req.key, &req.value);
638 Ok(())
639 })();
640 result
641}
642
643use std::sync::LazyLock;
646
647static KV_STORE: LazyLock<parking_lot::RwLock<HashMap<String, String>>> =
648 LazyLock::new(|| parking_lot::RwLock::new(HashMap::new()));
649
thread_local! {
    /// Name of the extension whose guest code is currently running on this thread, if any.
    ///
    /// Host functions read it through `current_extension_name` to pick the KV namespace.
    /// The manager sets it around the `register_*`, `execute_tool` and `execute_command`
    /// guest calls and clears it afterwards. It is not set while `init` runs.
    static CURRENT_EXTENSION: RefCell<Option<String>> = const { RefCell::new(None) };
}
656
657#[allow(dead_code)]
659fn with_extension_context<F, R>(ext_name: &str, f: F) -> R
660where
661 F: FnOnce() -> R,
662{
663 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.to_string()));
664 let result = f();
665 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
666 result
667}
668
669fn current_extension_name() -> String {
672 CURRENT_EXTENSION.with(|cell| {
673 cell.borrow()
674 .clone()
675 .unwrap_or_else(|| "__unknown__".to_string())
676 })
677}
678
679fn kv_store_get(key: &str) -> Option<String> {
680 KV_STORE.read().get(key).cloned()
681}
682
683fn kv_store_set(key: &str, value: &str) {
684 KV_STORE.write().insert(key.to_string(), value.to_string());
685}
686
687fn kv_namespaced_get(extension: &str, key: &str) -> Option<String> {
690 let namespaced = format!("{}:{}", extension, key);
691 kv_store_get(&namespaced)
692}
693
694fn kv_namespaced_set(extension: &str, key: &str, value: &str) {
695 let namespaced = format!("{}:{}", extension, key);
696 kv_store_set(&namespaced, value);
697}
698
699fn validate_path_allowed(path: &str) -> Result<()> {
704 let p = std::path::Path::new(path);
705
706 let abs = if p.is_absolute() {
708 p.to_path_buf()
709 } else {
710 std::env::current_dir().unwrap_or_default().join(p)
711 };
712
713 let resolved = if abs.exists() {
715 abs.canonicalize().unwrap_or(abs)
716 } else {
717 if let Some(parent) = abs.parent() {
719 if parent.exists() {
720 let canon_parent = parent
721 .canonicalize()
722 .unwrap_or_else(|_| parent.to_path_buf());
723 canon_parent.join(abs.file_name().unwrap_or_default())
724 } else {
725 abs
726 }
727 } else {
728 abs
729 }
730 };
731
732 let abs_str = resolved.to_string_lossy();
733
734 let blocked_prefixes = [
736 "/etc",
737 "/sys",
738 "/proc",
739 "/dev",
740 "/boot",
741 "/root",
742 "/System",
743 "/Library/System",
744 "/usr/bin",
745 "/usr/sbin",
746 "/bin",
747 "/sbin",
748 ];
749 for prefix in &blocked_prefixes {
750 if abs_str.starts_with(prefix) {
751 anyhow::bail!("Path '{}' is in a protected system directory", path);
752 }
753 }
754
755 if let Some(home) = dirs::home_dir() {
757 let home_str = home.to_string_lossy();
758 if abs_str.starts_with(&*home_str) {
759 let blocked_home_suffixes = [
760 "/.ssh/",
761 "/.gnupg/",
762 "/.aws/",
763 "/.config/gcloud/",
764 "/.kube/",
765 "/.docker/",
766 "/.npmrc",
767 "/.netrc",
768 ];
769 for suffix in &blocked_home_suffixes {
770 if abs_str.contains(suffix) {
771 anyhow::bail!("Path '{}' is in a protected directory", path);
772 }
773 }
774 }
775 }
776
777 Ok(())
778}
779
780fn validate_url(url: &str) -> Result<(), String> {
784 let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
785 let host = parsed.host_str().unwrap_or("").to_lowercase();
786
787 let blocked = [
789 "localhost",
790 "127.0.0.1",
791 "0.0.0.0",
792 "::1",
793 "[::1]",
794 "169.254.169.254", "metadata.google.internal",
796 ];
797 for &b in &blocked {
798 if host == b || host.starts_with(b) {
799 return Err(format!("Blocked internal address: {}", host));
800 }
801 }
802
803 if host.starts_with("10.") || host.starts_with("192.168.") || is_172_private(&host) {
805 return Err(format!("Blocked private address: {}", host));
806 }
807
808 Ok(())
809}
810
/// Returns `true` when `host` looks like an address in `172.16.0.0/12`. That means it
/// starts with `"172."` and its second dot-separated label parses as a number from 16 to 31.
///
/// Only the first two labels are inspected, so a hostname such as `"172.20.example"` also
/// matches.
fn is_172_private(host: &str) -> bool {
    let Some(rest) = host.strip_prefix("172.") else {
        return false;
    };
    let second_label = rest.split('.').next().unwrap_or_default();
    matches!(second_label.parse::<u8>(), Ok(16..=31))
}
826
/// Loads WASM extensions through Extism and routes tool and slash-command calls to them.
pub struct WasmExtensionManager {
    /// Metadata of every loaded extension, keyed by extension name.
    extensions: HashMap<String, LoadedWasmExtension>,
    /// Live plugin instances keyed by extension name. The mutex is held for the whole
    /// duration of each `execute_tool` / `execute_command` guest call.
    pub(crate) plugins: Arc<parking_lot::Mutex<HashMap<String, extism::Plugin>>>,
    /// Tool name -> name of the extension that currently owns it.
    tool_to_ext: HashMap<String, String>,
    /// HTTP client handed to every plugin's `oxi_http_request` host function.
    http_client: Arc<reqwest::blocking::Client>,
    /// Per-extension permission sets. NOTE(review): never read in this module, hence the
    /// `allow`.
    #[allow(dead_code, unused)]
    permissions: HashMap<String, std::collections::HashSet<String>>,
}
850
851impl Default for WasmExtensionManager {
852 fn default() -> Self {
853 Self::new()
854 }
855}
856
857impl WasmExtensionManager {
858 pub fn new() -> Self {
860 Self {
861 extensions: HashMap::new(),
862 plugins: Arc::new(Mutex::new(HashMap::new())),
863 tool_to_ext: HashMap::new(),
864 http_client: Arc::new(
865 reqwest::blocking::Client::builder()
866 .timeout(std::time::Duration::from_secs(30))
867 .connect_timeout(std::time::Duration::from_secs(10))
868 .no_proxy() .build()
870 .expect("Failed to build HTTP client"),
871 ),
872 permissions: HashMap::new(),
873 }
874 }
875
876 pub fn with_http_client(client: reqwest::blocking::Client) -> Self {
878 Self {
879 extensions: HashMap::new(),
880 plugins: Arc::new(Mutex::new(HashMap::new())),
881 tool_to_ext: HashMap::new(),
882 http_client: Arc::new(client),
883 permissions: HashMap::new(),
884 }
885 }
886
887 pub fn discover(cwd: &Path) -> Vec<PathBuf> {
891 let mut paths = Vec::new();
892
893 if let Some(home) = dirs::home_dir() {
895 let dir = home.join(".oxi").join("extensions");
896 if dir.is_dir() {
897 Self::discover_in_dir(&dir, &mut paths);
898 }
899 }
900
901 let local_dir = cwd.join(".oxi").join("extensions");
903 if local_dir.is_dir() {
904 Self::discover_in_dir(&local_dir, &mut paths);
905 }
906
907 paths.sort();
908 paths.dedup();
909 paths
910 }
911
912 fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
913 let Ok(entries) = std::fs::read_dir(dir) else {
914 return;
915 };
916 for entry in entries.flatten() {
917 let path = entry.path();
918 if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("wasm") {
919 out.push(path);
920 }
921 }
922 }
923
924 fn host_functions(http_client: &Arc<reqwest::blocking::Client>) -> Vec<Function> {
928 let http_fn = Function::new(
929 "oxi_http_request",
930 [PTR],
931 [PTR],
932 UserData::new(http_client.clone()),
933 host_oxi_http_request,
934 );
935
936 let log_fn = Function::new("oxi_log", [PTR], [], UserData::new(()), host_oxi_log);
937
938 let read_fn = Function::new(
939 "oxi_read_file",
940 [PTR],
941 [PTR],
942 UserData::new(()),
943 host_oxi_read_file,
944 );
945
946 let write_fn = Function::new(
947 "oxi_write_file",
948 [PTR],
949 [PTR],
950 UserData::new(()),
951 host_oxi_write_file,
952 );
953
954 let exec_fn = Function::new("oxi_exec", [PTR], [PTR], UserData::new(()), host_oxi_exec);
955
956 let get_env_fn = Function::new(
957 "oxi_get_env",
958 [PTR],
959 [PTR],
960 UserData::new(()),
961 host_oxi_get_env,
962 );
963
964 let kv_get_fn = Function::new(
965 "oxi_kv_get",
966 [PTR],
967 [PTR],
968 UserData::new(()),
969 host_oxi_kv_get,
970 );
971
972 let kv_set_fn = Function::new("oxi_kv_set", [PTR], [], UserData::new(()), host_oxi_kv_set);
973
974 vec![
975 http_fn, log_fn, read_fn, write_fn, exec_fn, get_env_fn, kv_get_fn, kv_set_fn,
976 ]
977 }
978
979 pub fn load(&mut self, path: &Path) -> Result<ExtensionInfo> {
981 let path_display = path.display().to_string();
982 tracing::info!("Loading WASM extension: {}", path_display);
983
984 let wasm_bytes = std::fs::read(path)
985 .with_context(|| format!("Failed to read extension: {}", path_display))?;
986
987 let wasm = extism::Wasm::data(wasm_bytes);
988 let manifest = extism::Manifest::new([wasm]).with_memory_max(64);
990 let mut plugin = extism::PluginBuilder::new(manifest)
991 .with_wasi(true)
992 .with_functions(Self::host_functions(&self.http_client))
993 .build()
994 .with_context(|| format!("Failed to create Extism plugin from {}", path_display))?;
995
996 let info: ExtensionInfo = match plugin.call::<&str, &str>("init", "{}") {
998 Ok(output) => serde_json::from_str(output)
999 .with_context(|| format!("init() returned invalid JSON: {}", output))?,
1000 Err(_) => {
1001 let name = path
1003 .file_stem()
1004 .and_then(|s| s.to_str())
1005 .unwrap_or("unknown")
1006 .to_string();
1007 ExtensionInfo {
1008 name,
1009 version: "0.0.0".to_string(),
1010 description: String::new(),
1011 permissions: vec![],
1012 }
1013 }
1014 };
1015 if !info.permissions.is_empty() {
1018 tracing::info!(name = %info.name, perms = ?info.permissions, "extension permissions requested");
1019 }
1020 if info.permissions.iter().any(|p| p == "exec") {
1021 tracing::warn!(
1022 name = %info.name,
1023 "extension requested 'exec' permission — oxi_exec stays disabled unless OXI_EXTENSION_EXEC=1 is set"
1024 );
1025 }
1026
1027 let ext_name_for_ctx = info.name.clone();
1030 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name_for_ctx));
1031 let tools: Vec<WasmToolDef> = match plugin.call::<&str, &str>("register_tools", "{}") {
1032 Ok(output) => {
1033 let resp: Value = serde_json::from_str(output)
1034 .with_context(|| format!("register_tools() invalid JSON: {}", output))?;
1035 resp.get("tools")
1036 .cloned()
1037 .unwrap_or(Value::Array(vec![]))
1038 .as_array()
1039 .map(|arr| {
1040 arr.iter()
1041 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1042 .collect()
1043 })
1044 .unwrap_or_default()
1045 }
1046 Err(_) => vec![], };
1048
1049 let commands: Vec<WasmCommandDef> =
1051 match plugin.call::<&str, &str>("register_commands", "{}") {
1052 Ok(output) => {
1053 let resp: Value = serde_json::from_str(output)
1054 .with_context(|| format!("register_commands() invalid JSON: {}", output))?;
1055 resp.get("commands")
1056 .cloned()
1057 .unwrap_or(Value::Array(vec![]))
1058 .as_array()
1059 .map(|arr| {
1060 arr.iter()
1061 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1062 .collect()
1063 })
1064 .unwrap_or_default()
1065 }
1066 Err(_) => vec![], };
1068
1069 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1071
1072 let ext_name = info.name.clone();
1073
1074 if self.extensions.contains_key(&ext_name) {
1076 tracing::warn!(
1077 "Extension '{}' already loaded, replacing with '{}'",
1078 ext_name,
1079 path_display
1080 );
1081 self.tool_to_ext.retain(|_, v| v != &ext_name);
1083 self.plugins.lock().remove(&ext_name);
1085 }
1086
1087 for tool in &tools {
1088 self.tool_to_ext.insert(tool.name.clone(), ext_name.clone());
1089 }
1090
1091 let loaded = LoadedWasmExtension {
1092 info: info.clone(),
1093 tools,
1094 commands,
1095 source_path: path.to_path_buf(),
1096 };
1097
1098 self.extensions.insert(ext_name.clone(), loaded);
1099 self.plugins.lock().insert(ext_name, plugin);
1100
1101 tracing::info!(
1102 name = %info.name,
1103 version = %info.version,
1104 tools = self.tool_to_ext.len(),
1105 "WASM extension loaded"
1106 );
1107
1108 Ok(info)
1109 }
1110
1111 pub fn load_all(&mut self, paths: &[PathBuf]) -> (Vec<ExtensionInfo>, Vec<anyhow::Error>) {
1113 let mut loaded = Vec::new();
1114 let mut errors = Vec::new();
1115
1116 for path in paths {
1117 match self.load(path) {
1118 Ok(info) => loaded.push(info),
1119 Err(e) => {
1120 tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
1121 errors.push(e);
1122 }
1123 }
1124 }
1125
1126 (loaded, errors)
1127 }
1128
1129 pub fn execute_tool(&self, tool_name: &str, params: Value) -> Result<Value> {
1133 let ext_name = self
1134 .tool_to_ext
1135 .get(tool_name)
1136 .with_context(|| format!("No extension registered for tool: {}", tool_name))?
1137 .clone();
1138
1139 let mut plugins = self.plugins.lock();
1140 let plugin = plugins
1141 .get_mut(&ext_name)
1142 .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1143
1144 let input = serde_json::json!({
1145 "tool": tool_name,
1146 "params": params,
1147 });
1148 let input_str = serde_json::to_string(&input)?;
1149
1150 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1152 let call_result = plugin.call("execute_tool", &input_str);
1153 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1154
1155 let output: &str = call_result
1156 .with_context(|| format!("execute_tool('{}') failed in '{}'", tool_name, ext_name))?;
1157
1158 let result: Value = serde_json::from_str(output)
1159 .with_context(|| format!("execute_tool() returned invalid JSON: {}", output))?;
1160
1161 Ok(result)
1162 }
1163
1164 pub fn all_tool_defs(&self) -> Vec<&WasmToolDef> {
1168 self.extensions
1169 .values()
1170 .flat_map(|e| e.tools.iter())
1171 .collect()
1172 }
1173
1174 pub fn is_wasm_tool(&self, tool_name: &str) -> bool {
1176 self.tool_to_ext.contains_key(tool_name)
1177 }
1178
1179 pub fn extension_names(&self) -> impl Iterator<Item = &str> {
1181 self.extensions.keys().map(|s| s.as_str())
1182 }
1183
1184 pub fn get_info(&self, name: &str) -> Option<&ExtensionInfo> {
1186 self.extensions.get(name).map(|e| &e.info)
1187 }
1188
1189 pub fn len(&self) -> usize {
1191 self.extensions.len()
1192 }
1193
1194 pub fn is_empty(&self) -> bool {
1196 self.extensions.is_empty()
1197 }
1198
1199 pub fn all_command_defs(&self) -> Vec<(&str, &WasmCommandDef)> {
1203 let mut cmds = Vec::new();
1204 for ext in self.extensions.values() {
1205 for cmd in &ext.commands {
1206 cmds.push((ext.info.name.as_str(), cmd));
1207 }
1208 }
1209 cmds
1210 }
1211
1212 pub fn execute_command(&self, command_name: &str, args: &str) -> Result<String> {
1215 let ext_name = self
1217 .extensions
1218 .iter()
1219 .find(|(_, ext)| ext.commands.iter().any(|c| c.name == command_name))
1220 .map(|(name, _)| name.clone())
1221 .with_context(|| format!("No extension registered for command: /{}", command_name))?;
1222
1223 let mut plugins = self.plugins.lock();
1224 let plugin = plugins
1225 .get_mut(&ext_name)
1226 .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1227
1228 let input = serde_json::json!({
1229 "command": command_name,
1230 "args": args,
1231 });
1232 let input_str = serde_json::to_string(&input)?;
1233
1234 let output: &str = {
1235 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1236 let result = plugin.call("execute_command", &input_str);
1237 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1238 result
1239 }
1240 .with_context(|| {
1241 format!(
1242 "execute_command('/{}') failed in '{}'",
1243 command_name, ext_name
1244 )
1245 })?;
1246
1247 let result: Value =
1249 serde_json::from_str(output).unwrap_or_else(|_| serde_json::json!({"output": output}));
1250
1251 Ok(result
1252 .get("output")
1253 .and_then(|v| v.as_str())
1254 .unwrap_or(output)
1255 .to_string())
1256 }
1257}
1258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_discover_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        let paths = WasmExtensionManager::discover(dir.path());
        // `discover` also scans `~/.oxi/extensions`, which may be populated on a developer
        // machine. Only the (empty) project directory is under this test's control.
        assert!(paths.iter().all(|p| !p.starts_with(dir.path())));
    }

    #[test]
    fn test_discover_finds_project_local_extensions() {
        let dir = tempfile::tempdir().unwrap();
        let ext_dir = dir.path().join(".oxi").join("extensions");
        std::fs::create_dir_all(&ext_dir).unwrap();
        std::fs::write(ext_dir.join("local_ext.wasm"), b"\x00asm").unwrap();
        std::fs::write(ext_dir.join("notes.md"), b"not an extension").unwrap();

        let local: Vec<PathBuf> = WasmExtensionManager::discover(dir.path())
            .into_iter()
            .filter(|p| p.starts_with(dir.path()))
            .collect();
        assert_eq!(local.len(), 1);
        assert!(local[0].ends_with("local_ext.wasm"));
    }

    #[test]
    fn test_discover_finds_wasm_files() {
        let dir = tempfile::tempdir().unwrap();
        let wasm_path = dir.path().join("test_ext.wasm");
        std::fs::write(&wasm_path, b"\x00asm").unwrap();
        std::fs::write(dir.path().join("readme.txt"), b"hello").unwrap();

        let mut paths = Vec::new();
        WasmExtensionManager::discover_in_dir(dir.path(), &mut paths);
        assert_eq!(paths.len(), 1);
        assert!(paths[0].ends_with("test_ext.wasm"));
    }

    #[test]
    fn test_discover_in_dir_ignores_directories() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir(dir.path().join("looks_like.wasm")).unwrap();

        let mut paths = Vec::new();
        WasmExtensionManager::discover_in_dir(dir.path(), &mut paths);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_discover_in_dir_missing_dir_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let mut paths = Vec::new();
        WasmExtensionManager::discover_in_dir(&dir.path().join("missing"), &mut paths);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_extension_info_parse() {
        let json = r#"{"name":"my_ext","version":"1.0.0","description":"Test"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.name, "my_ext");
        assert_eq!(info.version, "1.0.0");
    }

    #[test]
    fn test_tool_def_parse() {
        let json = r#"{"name":"search","description":"Search","schema":{"type":"object"}}"#;
        let tool: WasmToolDef = serde_json::from_str(json).unwrap();
        assert_eq!(tool.name, "search");
    }

    #[test]
    fn test_manager_new_is_empty() {
        let mgr = WasmExtensionManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    #[test]
    fn test_is_wasm_tool_false() {
        let mgr = WasmExtensionManager::new();
        assert!(!mgr.is_wasm_tool("anything"));
    }

    #[test]
    fn test_extension_info_default_description() {
        let json = r#"{"name":"test","version":"0.1"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.description, "");
        assert!(info.permissions.is_empty());
    }

    #[test]
    fn test_ssrf_blocks_localhost() {
        assert!(validate_url("http://localhost/admin").is_err());
        assert!(validate_url("http://127.0.0.1/secret").is_err());
        assert!(validate_url("http://10.0.0.1/internal").is_err());
        assert!(validate_url("http://192.168.1.1/router").is_err());
        assert!(validate_url("http://172.16.0.1/corp").is_err());
        assert!(validate_url("http://169.254.169.254/metadata").is_err());
        assert!(validate_url("http://[::1]/ipv6").is_err());
        assert!(validate_url("http://0.0.0.0/admin").is_err());
    }

    #[test]
    fn test_ssrf_allows_public() {
        assert!(validate_url("https://api.github.com/repos/test").is_ok());
        assert!(validate_url("https://example.com/api").is_ok());
        assert!(validate_url("https://search.brave.com/api/search?q=test").is_ok());
    }

    #[test]
    fn test_ssrf_172_range() {
        assert!(validate_url("http://172.16.0.1/test").is_err());
        assert!(validate_url("http://172.31.255.255/test").is_err());
        assert!(validate_url("http://172.15.0.1/test").is_ok());
        assert!(validate_url("http://172.32.0.1/test").is_ok());
    }

    #[test]
    fn test_ssrf_rejects_malformed_url() {
        assert!(validate_url("not a url").is_err());
    }

    #[test]
    fn test_is_172_private() {
        assert!(is_172_private("172.16.0.1"));
        assert!(is_172_private("172.31.0.1"));
        assert!(!is_172_private("172.15.0.1"));
        assert!(!is_172_private("172.32.0.1"));
        assert!(!is_172_private("172."));
        assert!(!is_172_private("10.0.0.1"));
    }

    #[test]
    fn test_kv_namespaces_are_isolated() {
        kv_namespaced_set("tests_ext_a", "shared_key", "value-a");
        assert_eq!(
            kv_namespaced_get("tests_ext_a", "shared_key").as_deref(),
            Some("value-a")
        );
        assert_eq!(kv_namespaced_get("tests_ext_b", "shared_key"), None);
    }

    #[test]
    fn test_extension_context_is_scoped() {
        assert_eq!(current_extension_name(), "__unknown__");
        let inside = with_extension_context("ctx_ext", current_extension_name);
        assert_eq!(inside, "ctx_ext");
        assert_eq!(current_extension_name(), "__unknown__");
    }

    #[cfg(unix)]
    #[test]
    fn test_validate_path_blocks_system_dirs() {
        assert!(validate_path_allowed("/proc/self/status").is_err());
    }

    #[test]
    fn test_validate_path_allows_temp_files() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("out.txt");
        assert!(validate_path_allowed(target.to_str().unwrap()).is_ok());
    }
}