1use anyhow::{Context, Result};
10use extism::{CurrentPlugin, Function, UserData, Val, PTR};
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::cell::RefCell;
15use std::collections::HashMap;
16use std::io::Read as _;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20
/// Metadata an extension reports about itself, parsed from the JSON returned by its
/// `init` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
    /// Extension name; the manager uses it as the key for the loaded extension.
    pub name: String,
    /// Version string as reported by the extension.
    pub version: String,
    /// Optional human-readable description; empty when absent from the JSON.
    #[serde(default)]
    pub description: String,
    /// Permissions declared by the extension; empty when absent from the JSON.
    /// NOTE(review): no code visible in this file reads this field, hence the
    /// `dead_code` allowance.
    #[serde(default)]
    #[allow(dead_code)]
    pub permissions: Vec<String>,
}
36
/// A tool an extension exposes, as listed under `"tools"` in its `register_tools` output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmToolDef {
    /// Tool name; used to route `execute_tool` calls to the owning extension.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's parameters.
    /// NOTE(review): presumably a JSON Schema object — confirm with extension authors.
    pub schema: Value,
}
44
/// A slash command an extension exposes, as listed under `"commands"` in its
/// `register_commands` output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmCommandDef {
    /// Command name without the leading `/`.
    pub name: String,
    /// Human-readable description of the command.
    pub description: String,
}
51
/// Everything recorded about an extension after it has been successfully loaded.
#[derive(Debug)]
pub struct LoadedWasmExtension {
    /// Metadata from the extension's `init` export (or a file-stem based fallback).
    pub info: ExtensionInfo,
    /// Tools returned by `register_tools`.
    pub tools: Vec<WasmToolDef>,
    /// Commands returned by `register_commands`.
    pub commands: Vec<WasmCommandDef>,
    /// Path of the `.wasm` file the extension was loaded from.
    pub source_path: PathBuf,
}
60
61fn host_oxi_http_request(
68 plugin: &mut CurrentPlugin,
69 inputs: &[Val],
70 outputs: &mut [Val],
71 user_data: UserData<Arc<reqwest::blocking::Client>>,
72) -> Result<(), extism::Error> {
73 let result: anyhow::Result<()> = (|| {
75 let input_json: String = plugin.memory_get_val(&inputs[0])?;
76
77 #[derive(Deserialize)]
78 struct HttpReq {
79 url: String,
80 #[serde(default)]
81 method: String,
82 #[serde(default)]
83 headers: HashMap<String, String>,
84 #[serde(default)]
85 body: Option<String>,
86 }
87
88 let req: HttpReq =
89 serde_json::from_str(&input_json).context("oxi_http_request: invalid request JSON")?;
90
91 let method = if req.method.is_empty() {
92 "GET"
93 } else {
94 &req.method
95 };
96
97 if let Err(e) = validate_url(&req.url) {
99 anyhow::bail!("oxi_http_request: {}", e);
100 }
101
102 let client_arc = user_data.get()?;
104 let client = client_arc.lock().expect("wasm client lock poisoned");
105
106 let method = match method.to_uppercase().as_str() {
107 "GET" => reqwest::Method::GET,
108 "POST" => reqwest::Method::POST,
109 "PUT" => reqwest::Method::PUT,
110 "DELETE" => reqwest::Method::DELETE,
111 "PATCH" => reqwest::Method::PATCH,
112 "HEAD" => reqwest::Method::HEAD,
113 other => anyhow::bail!("oxi_http_request: unsupported method '{}'", other),
114 };
115
116 let mut rb = client.request(method, &req.url);
117 for (k, v) in &req.headers {
118 rb = rb.header(k, v);
119 }
120 if let Some(body) = &req.body {
121 rb = rb.body(body.clone());
122 }
123
124 let resp = rb
126 .send()
127 .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
128 let status = resp.status().as_u16();
129 let resp_headers: HashMap<String, String> = resp
130 .headers()
131 .iter()
132 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
133 .collect();
134 let resp_body = {
136 let max_body = 1024 * 1024; let body_bytes = resp
138 .bytes()
139 .map_err(|e| anyhow::anyhow!("Failed to read response: {}", e))?;
140 if body_bytes.len() > max_body {
141 tracing::warn!(
142 "HTTP response truncated: {} bytes > {} limit",
143 body_bytes.len(),
144 max_body
145 );
146 String::from_utf8_lossy(&body_bytes[..max_body]).to_string()
147 } else {
148 String::from_utf8_lossy(&body_bytes).to_string()
149 }
150 };
151
152 let response = serde_json::json!({
153 "status": status,
154 "headers": resp_headers,
155 "body": resp_body,
156 });
157
158 let output = serde_json::to_string(&response)?;
159 let handle = plugin.memory_new(&output)?;
160 if !outputs.is_empty() {
161 outputs[0] = plugin.memory_to_val(handle);
162 }
163 Ok(())
164 })();
165
166 result
167}
168
169fn host_oxi_log(
171 plugin: &mut CurrentPlugin,
172 inputs: &[Val],
173 _outputs: &mut [Val],
174 _user_data: UserData<()>,
175) -> Result<(), extism::Error> {
176 let message: String = plugin.memory_get_val(&inputs[0])?;
177 tracing::debug!("[WASM] {}", message);
178 Ok(())
179}
180
181fn host_oxi_read_file(
189 plugin: &mut CurrentPlugin,
190 inputs: &[Val],
191 outputs: &mut [Val],
192 _user_data: UserData<()>,
193) -> Result<(), extism::Error> {
194 let result: anyhow::Result<()> = (|| {
195 let input_json: String = plugin.memory_get_val(&inputs[0])?;
196
197 #[derive(Deserialize)]
198 struct ReadReq {
199 path: String,
200 #[serde(default)]
201 offset: Option<usize>,
202 #[serde(default = "default_limit")]
203 limit: usize,
204 }
205 fn default_limit() -> usize {
206 2000
207 }
208
209 let req: ReadReq =
210 serde_json::from_str(&input_json).context("oxi_read_file: invalid request JSON")?;
211
212 validate_path_allowed(&req.path)?;
214
215 let metadata = std::fs::metadata(&req.path);
216 match metadata {
217 Ok(m) => {
218 let max_bytes = 50 * 1024; let file_size = m.len() as usize;
220
221 let content = std::fs::read_to_string(&req.path)
222 .map_err(|e| anyhow::anyhow!("Failed to read file: {}", e))?;
223
224 let lines: Vec<&str> = content.lines().collect();
225 let total_lines = lines.len();
226
227 let offset = req.offset.unwrap_or(0).min(total_lines);
228 let end = (offset + req.limit).min(total_lines);
229 let selected: Vec<&str> = lines[offset..end].to_vec();
230 let mut result = selected.join("\n");
231
232 let truncated = result.len() > max_bytes;
233 if truncated {
234 result = result.chars().take(max_bytes).collect();
235 }
236
237 let response = serde_json::json!({
238 "success": true,
239 "content": result,
240 "truncated": truncated || end < total_lines,
241 "bytes": file_size,
242 "total_lines": total_lines,
243 "shown_lines": end - offset,
244 });
245 let output = serde_json::to_string(&response)?;
246 let handle = plugin.memory_new(&output)?;
247 if !outputs.is_empty() {
248 outputs[0] = plugin.memory_to_val(handle);
249 }
250 }
251 Err(e) => {
252 let response = serde_json::json!({
253 "success": false,
254 "error": format!("File not found: {}", e),
255 });
256 let output = serde_json::to_string(&response)?;
257 let handle = plugin.memory_new(&output)?;
258 if !outputs.is_empty() {
259 outputs[0] = plugin.memory_to_val(handle);
260 }
261 }
262 }
263 Ok(())
264 })();
265 result
266}
267
268fn host_oxi_write_file(
273 plugin: &mut CurrentPlugin,
274 inputs: &[Val],
275 outputs: &mut [Val],
276 _user_data: UserData<()>,
277) -> Result<(), extism::Error> {
278 let result: anyhow::Result<()> = (|| {
279 let input_json: String = plugin.memory_get_val(&inputs[0])?;
280
281 #[derive(Deserialize)]
282 struct WriteReq {
283 path: String,
284 content: String,
285 #[serde(default = "default_true")]
286 create_dirs: bool,
287 }
288 fn default_true() -> bool {
289 true
290 }
291
292 let req: WriteReq =
293 serde_json::from_str(&input_json).context("oxi_write_file: invalid request JSON")?;
294
295 validate_path_allowed(&req.path)?;
296
297 if req.create_dirs {
298 if let Some(parent) = std::path::Path::new(&req.path).parent() {
299 std::fs::create_dir_all(parent)
300 .map_err(|e| anyhow::anyhow!("Failed to create directories: {}", e))?;
301 }
302 }
303
304 let bytes = req.content.len();
305 std::fs::write(&req.path, &req.content)
306 .map_err(|e| anyhow::anyhow!("Failed to write file: {}", e))?;
307
308 let response = serde_json::json!({
309 "success": true,
310 "bytes_written": bytes,
311 });
312 let output = serde_json::to_string(&response)?;
313 let handle = plugin.memory_new(&output)?;
314 if !outputs.is_empty() {
315 outputs[0] = plugin.memory_to_val(handle);
316 }
317 Ok(())
318 })();
319 result
320}
321
322fn host_oxi_exec(
327 plugin: &mut CurrentPlugin,
328 inputs: &[Val],
329 outputs: &mut [Val],
330 _user_data: UserData<()>,
331) -> Result<(), extism::Error> {
332 let result: anyhow::Result<()> = (|| {
333 let input_json: String = plugin.memory_get_val(&inputs[0])?;
334
335 #[derive(Deserialize)]
336 struct ExecReq {
337 command: String,
338 #[serde(default)]
339 args: Vec<String>,
340 #[serde(default)]
341 cwd: Option<String>,
342 #[serde(default = "default_timeout")]
343 timeout: u64,
344 }
345 fn default_timeout() -> u64 {
346 30
347 }
348
349 let req: ExecReq =
350 serde_json::from_str(&input_json).context("oxi_exec: invalid request JSON")?;
351
352 let cwd = req.cwd.as_deref().unwrap_or(".");
353
354 let full_cmd = if req.args.is_empty() {
356 req.command.clone()
357 } else {
358 format!("{} {}", req.command, req.args.join(" "))
359 };
360
361 let blocked_patterns = [
363 "rm -rf /",
364 "rm -rf /*",
365 "mkfs",
366 "dd if=",
367 "format ",
368 ":(){ :|:& };:",
369 "chmod 777 /",
370 "chown root",
371 "> /etc/",
372 "> /boot/",
373 "> /dev/",
374 "dd of=/dev/",
375 "mv / /",
376 ];
377 for blocked in &blocked_patterns {
378 if full_cmd.contains(blocked) {
379 anyhow::bail!("oxi_exec: blocked dangerous command pattern");
380 }
381 }
382
383 let cmd_lower = req.command.to_lowercase();
385 if cmd_lower == "sudo"
386 || cmd_lower == "su"
387 || cmd_lower == "doas"
388 || cmd_lower.starts_with("sudo ")
389 || cmd_lower.starts_with("su ")
390 || cmd_lower.starts_with("doas ")
391 {
392 anyhow::bail!("oxi_exec: privilege escalation commands are blocked");
393 }
394
395 let timeout_ms = req.timeout.clamp(1000, 30000);
397 let timeout_dur = Duration::from_millis(timeout_ms);
398
399 let mut child = match std::process::Command::new(&req.command)
401 .args(&req.args)
402 .current_dir(cwd)
403 .stdout(std::process::Stdio::piped())
404 .stderr(std::process::Stdio::piped())
405 .spawn()
406 {
407 Ok(c) => c,
408 Err(e) => {
409 let response = serde_json::json!({
410 "success": false,
411 "error": format!("Failed to execute: {}", e),
412 "exit_code": -1,
413 });
414 let out = serde_json::to_string(&response)?;
415 let handle = plugin.memory_new(&out)?;
416 if !outputs.is_empty() {
417 outputs[0] = plugin.memory_to_val(handle);
418 }
419 return Ok(());
420 }
421 };
422
423 let start = Instant::now();
425 let mut timed_out = false;
426 let mut exit_status: Option<std::process::ExitStatus> = None;
427
428 loop {
429 match child.try_wait() {
430 Ok(Some(status)) => {
431 exit_status = Some(status);
432 break;
433 }
434 Ok(None) => {
435 if start.elapsed() >= timeout_dur {
436 tracing::warn!(
438 "oxi_exec: command '{}' timed out after {}ms",
439 req.command,
440 timeout_ms
441 );
442 let _ = child.kill();
443 let _ = child.wait(); timed_out = true;
445 break;
446 }
447 std::thread::sleep(Duration::from_millis(50));
448 }
449 Err(_) => {
450 match child.wait() {
452 Ok(status) => {
453 exit_status = Some(status);
454 }
455 Err(_) => {
456 timed_out = true;
457 }
458 }
459 break;
460 }
461 }
462 }
463
464 let mut stdout_buf = Vec::new();
466 let mut stderr_buf = Vec::new();
467 if let Some(mut out) = child.stdout.take() {
468 let _ = out.read_to_end(&mut stdout_buf);
469 }
470 if let Some(mut err) = child.stderr.take() {
471 let _ = err.read_to_end(&mut stderr_buf);
472 }
473
474 let stdout = String::from_utf8_lossy(&stdout_buf);
475 let stderr = String::from_utf8_lossy(&stderr_buf);
476 let max_output = 50 * 1024; let stdout_truncated = stdout.len() > max_output;
478 let stderr_truncated = stderr.len() > max_output;
479 let stdout_str: String = if stdout_truncated {
480 stdout.chars().take(max_output).collect()
481 } else {
482 stdout.to_string()
483 };
484 let stderr_str: String = if stderr_truncated {
485 stderr.chars().take(max_output).collect()
486 } else {
487 stderr.to_string()
488 };
489
490 let response = serde_json::json!({
491 "success": !timed_out && exit_status.map(|s| s.success()).unwrap_or(false),
492 "stdout": stdout_str,
493 "stderr": stderr_str,
494 "exit_code": if timed_out { -2 } else { exit_status.and_then(|s| s.code()).unwrap_or(-1) },
495 "stdout_truncated": stdout_truncated,
496 "stderr_truncated": stderr_truncated,
497 "timed_out": timed_out,
498 });
499 let out = serde_json::to_string(&response)?;
500 let handle = plugin.memory_new(&out)?;
501 if !outputs.is_empty() {
502 outputs[0] = plugin.memory_to_val(handle);
503 }
504 Ok(())
505 })();
506 result
507}
508
509fn host_oxi_get_env(
514 plugin: &mut CurrentPlugin,
515 inputs: &[Val],
516 outputs: &mut [Val],
517 _user_data: UserData<()>,
518) -> Result<(), extism::Error> {
519 let result: anyhow::Result<()> = (|| {
520 let input_json: String = plugin.memory_get_val(&inputs[0])?;
521
522 #[derive(Deserialize)]
523 struct EnvReq {
524 key: String,
525 }
526
527 let req: EnvReq =
528 serde_json::from_str(&input_json).context("oxi_get_env: invalid request JSON")?;
529
530 let blocked_keys = ["AWS_SECRET", "PRIVATE_KEY", "PASSWORD", "TOKEN", "SECRET"];
532 let key_upper = req.key.to_uppercase();
533 for blocked in &blocked_keys {
534 if key_upper.contains(blocked) {
535 anyhow::bail!("oxi_get_env: access to '{}' is blocked", req.key);
536 }
537 }
538
539 let value = std::env::var(&req.key).ok();
540 let response = serde_json::json!({
541 "success": value.is_some(),
542 "value": value.unwrap_or_default(),
543 });
544 let output = serde_json::to_string(&response)?;
545 let handle = plugin.memory_new(&output)?;
546 if !outputs.is_empty() {
547 outputs[0] = plugin.memory_to_val(handle);
548 }
549 Ok(())
550 })();
551 result
552}
553
554fn host_oxi_kv_get(
563 plugin: &mut CurrentPlugin,
564 inputs: &[Val],
565 outputs: &mut [Val],
566 _user_data: UserData<()>,
567) -> Result<(), extism::Error> {
568 let result: anyhow::Result<()> = (|| {
569 let input_json: String = plugin.memory_get_val(&inputs[0])?;
570
571 #[derive(Deserialize)]
572 struct KvReq {
573 key: String,
574 }
575
576 let req: KvReq =
577 serde_json::from_str(&input_json).context("oxi_kv_get: invalid request JSON")?;
578
579 let ext_name = current_extension_name();
581 let value = kv_namespaced_get(&ext_name, &req.key);
582 let response = serde_json::json!({
583 "success": value.is_some(),
584 "value": value.unwrap_or_default(),
585 });
586 let output = serde_json::to_string(&response)?;
587 let handle = plugin.memory_new(&output)?;
588 if !outputs.is_empty() {
589 outputs[0] = plugin.memory_to_val(handle);
590 }
591 Ok(())
592 })();
593 result
594}
595
596fn host_oxi_kv_set(
600 plugin: &mut CurrentPlugin,
601 inputs: &[Val],
602 _outputs: &mut [Val],
603 _user_data: UserData<()>,
604) -> Result<(), extism::Error> {
605 let result: anyhow::Result<()> = (|| {
606 let input_json: String = plugin.memory_get_val(&inputs[0])?;
607
608 #[derive(Deserialize)]
609 struct KvSetReq {
610 key: String,
611 value: String,
612 }
613
614 let req: KvSetReq =
615 serde_json::from_str(&input_json).context("oxi_kv_set: invalid request JSON")?;
616
617 let ext_name = current_extension_name();
619 kv_namespaced_set(&ext_name, &req.key, &req.value);
620 Ok(())
621 })();
622 result
623}
624
625use std::sync::LazyLock;
628
/// Process-wide key/value store behind the `oxi_kv_get` / `oxi_kv_set` host functions.
///
/// The map is created lazily on first access and lives only in memory. It is shared by
/// every extension; per-extension isolation comes solely from the namespaced keys built
/// by `kv_namespaced_get` / `kv_namespaced_set`.
static KV_STORE: LazyLock<parking_lot::RwLock<HashMap<String, String>>> =
    LazyLock::new(|| parking_lot::RwLock::new(HashMap::new()));
631
thread_local! {
    /// Name of the extension whose export is currently running on this thread, or `None`.
    ///
    /// The KV host functions receive only `UserData<()>`, so they cannot tell which
    /// extension called them. The manager therefore sets this before calling into a
    /// plugin and clears it afterwards, and `current_extension_name` reads it to choose
    /// the KV namespace.
    static CURRENT_EXTENSION: RefCell<Option<String>> = const { RefCell::new(None) };
}
638
639#[allow(dead_code)]
641fn with_extension_context<F, R>(ext_name: &str, f: F) -> R
642where
643 F: FnOnce() -> R,
644{
645 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.to_string()));
646 let result = f();
647 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
648 result
649}
650
651fn current_extension_name() -> String {
654 CURRENT_EXTENSION.with(|cell| {
655 cell.borrow()
656 .clone()
657 .unwrap_or_else(|| "__unknown__".to_string())
658 })
659}
660
661fn kv_store_get(key: &str) -> Option<String> {
662 KV_STORE.read().get(key).cloned()
663}
664
665fn kv_store_set(key: &str, value: &str) {
666 KV_STORE.write().insert(key.to_string(), value.to_string());
667}
668
669fn kv_namespaced_get(extension: &str, key: &str) -> Option<String> {
672 let namespaced = format!("{}:{}", extension, key);
673 kv_store_get(&namespaced)
674}
675
676fn kv_namespaced_set(extension: &str, key: &str, value: &str) {
677 let namespaced = format!("{}:{}", extension, key);
678 kv_store_set(&namespaced, value);
679}
680
681fn validate_path_allowed(path: &str) -> Result<()> {
686 let p = std::path::Path::new(path);
687
688 let abs = if p.is_absolute() {
690 p.to_path_buf()
691 } else {
692 std::env::current_dir().unwrap_or_default().join(p)
693 };
694
695 let resolved = if abs.exists() {
697 abs.canonicalize().unwrap_or(abs)
698 } else {
699 if let Some(parent) = abs.parent() {
701 if parent.exists() {
702 let canon_parent = parent
703 .canonicalize()
704 .unwrap_or_else(|_| parent.to_path_buf());
705 canon_parent.join(abs.file_name().unwrap_or_default())
706 } else {
707 abs
708 }
709 } else {
710 abs
711 }
712 };
713
714 let abs_str = resolved.to_string_lossy();
715
716 let blocked_prefixes = [
718 "/etc",
719 "/sys",
720 "/proc",
721 "/dev",
722 "/boot",
723 "/root",
724 "/System",
725 "/Library/System",
726 "/usr/bin",
727 "/usr/sbin",
728 "/bin",
729 "/sbin",
730 ];
731 for prefix in &blocked_prefixes {
732 if abs_str.starts_with(prefix) {
733 anyhow::bail!("Path '{}' is in a protected system directory", path);
734 }
735 }
736
737 if let Some(home) = dirs::home_dir() {
739 let home_str = home.to_string_lossy();
740 if abs_str.starts_with(&*home_str) {
741 let blocked_home_suffixes = [
742 "/.ssh/",
743 "/.gnupg/",
744 "/.aws/",
745 "/.config/gcloud/",
746 "/.kube/",
747 "/.docker/",
748 "/.npmrc",
749 "/.netrc",
750 ];
751 for suffix in &blocked_home_suffixes {
752 if abs_str.contains(suffix) {
753 anyhow::bail!("Path '{}' is in a protected directory", path);
754 }
755 }
756 }
757 }
758
759 Ok(())
760}
761
762fn validate_url(url: &str) -> Result<(), String> {
766 let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
767 let host = parsed.host_str().unwrap_or("").to_lowercase();
768
769 let blocked = [
771 "localhost",
772 "127.0.0.1",
773 "0.0.0.0",
774 "::1",
775 "[::1]",
776 "169.254.169.254", "metadata.google.internal",
778 ];
779 for &b in &blocked {
780 if host == b || host.starts_with(b) {
781 return Err(format!("Blocked internal address: {}", host));
782 }
783 }
784
785 if host.starts_with("10.") || host.starts_with("192.168.") || is_172_private(&host) {
787 return Err(format!("Blocked private address: {}", host));
788 }
789
790 Ok(())
791}
792
/// Returns `true` when `host` is textually in `172.16.0.0/12`: it starts with `172.` and
/// its second dot-separated field parses as a `u8` in `16..=31`.
fn is_172_private(host: &str) -> bool {
    let Some(rest) = host.strip_prefix("172.") else {
        return false;
    };
    let second_octet = rest.split('.').next().unwrap_or("");
    matches!(second_octet.parse::<u8>(), Ok(16..=31))
}
808
/// Loads WASM extensions through Extism and routes tool and command calls to them.
pub struct WasmExtensionManager {
    /// Loaded extensions keyed by `ExtensionInfo::name`.
    extensions: HashMap<String, LoadedWasmExtension>,
    /// Instantiated Extism plugins keyed by extension name, behind a shared mutex so
    /// calls into a plugin are serialized.
    pub(crate) plugins: Arc<parking_lot::Mutex<HashMap<String, extism::Plugin>>>,
    /// Tool name -> name of the extension that handles it.
    tool_to_ext: HashMap<String, String>,
    /// HTTP client handed to the `oxi_http_request` host function of every plugin.
    http_client: Arc<reqwest::blocking::Client>,
    /// Per-extension permission sets.
    /// NOTE(review): only ever initialized empty here; nothing visible reads it, hence
    /// the lint allowance.
    #[allow(dead_code, unused)]
    permissions: HashMap<String, std::collections::HashSet<String>>,
}
832
833impl Default for WasmExtensionManager {
834 fn default() -> Self {
835 Self::new()
836 }
837}
838
839impl WasmExtensionManager {
840 pub fn new() -> Self {
842 Self {
843 extensions: HashMap::new(),
844 plugins: Arc::new(Mutex::new(HashMap::new())),
845 tool_to_ext: HashMap::new(),
846 http_client: Arc::new(
847 reqwest::blocking::Client::builder()
848 .timeout(std::time::Duration::from_secs(30))
849 .connect_timeout(std::time::Duration::from_secs(10))
850 .no_proxy() .build()
852 .expect("Failed to build HTTP client"),
853 ),
854 permissions: HashMap::new(),
855 }
856 }
857
858 pub fn with_http_client(client: reqwest::blocking::Client) -> Self {
860 Self {
861 extensions: HashMap::new(),
862 plugins: Arc::new(Mutex::new(HashMap::new())),
863 tool_to_ext: HashMap::new(),
864 http_client: Arc::new(client),
865 permissions: HashMap::new(),
866 }
867 }
868
869 pub fn discover(cwd: &Path) -> Vec<PathBuf> {
873 let mut paths = Vec::new();
874
875 if let Some(home) = dirs::home_dir() {
877 let dir = home.join(".oxi").join("extensions");
878 if dir.is_dir() {
879 Self::discover_in_dir(&dir, &mut paths);
880 }
881 }
882
883 let local_dir = cwd.join(".oxi").join("extensions");
885 if local_dir.is_dir() {
886 Self::discover_in_dir(&local_dir, &mut paths);
887 }
888
889 paths.sort();
890 paths.dedup();
891 paths
892 }
893
894 fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
895 let Ok(entries) = std::fs::read_dir(dir) else {
896 return;
897 };
898 for entry in entries.flatten() {
899 let path = entry.path();
900 if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("wasm") {
901 out.push(path);
902 }
903 }
904 }
905
906 fn host_functions(http_client: &Arc<reqwest::blocking::Client>) -> Vec<Function> {
910 let http_fn = Function::new(
911 "oxi_http_request",
912 [PTR],
913 [PTR],
914 UserData::new(http_client.clone()),
915 host_oxi_http_request,
916 );
917
918 let log_fn = Function::new("oxi_log", [PTR], [], UserData::new(()), host_oxi_log);
919
920 let read_fn = Function::new(
921 "oxi_read_file",
922 [PTR],
923 [PTR],
924 UserData::new(()),
925 host_oxi_read_file,
926 );
927
928 let write_fn = Function::new(
929 "oxi_write_file",
930 [PTR],
931 [PTR],
932 UserData::new(()),
933 host_oxi_write_file,
934 );
935
936 let exec_fn = Function::new("oxi_exec", [PTR], [PTR], UserData::new(()), host_oxi_exec);
937
938 let get_env_fn = Function::new(
939 "oxi_get_env",
940 [PTR],
941 [PTR],
942 UserData::new(()),
943 host_oxi_get_env,
944 );
945
946 let kv_get_fn = Function::new(
947 "oxi_kv_get",
948 [PTR],
949 [PTR],
950 UserData::new(()),
951 host_oxi_kv_get,
952 );
953
954 let kv_set_fn = Function::new("oxi_kv_set", [PTR], [], UserData::new(()), host_oxi_kv_set);
955
956 vec![
957 http_fn, log_fn, read_fn, write_fn, exec_fn, get_env_fn, kv_get_fn, kv_set_fn,
958 ]
959 }
960
961 pub fn load(&mut self, path: &Path) -> Result<ExtensionInfo> {
963 let path_display = path.display().to_string();
964 tracing::info!("Loading WASM extension: {}", path_display);
965
966 let wasm_bytes = std::fs::read(path)
967 .with_context(|| format!("Failed to read extension: {}", path_display))?;
968
969 let wasm = extism::Wasm::data(wasm_bytes);
970 let manifest = extism::Manifest::new([wasm]).with_memory_max(64);
972 let mut plugin = extism::PluginBuilder::new(manifest)
973 .with_wasi(true)
974 .with_functions(Self::host_functions(&self.http_client))
975 .build()
976 .with_context(|| format!("Failed to create Extism plugin from {}", path_display))?;
977
978 let info: ExtensionInfo = match plugin.call::<&str, &str>("init", "{}") {
980 Ok(output) => serde_json::from_str(output)
981 .with_context(|| format!("init() returned invalid JSON: {}", output))?,
982 Err(_) => {
983 let name = path
985 .file_stem()
986 .and_then(|s| s.to_str())
987 .unwrap_or("unknown")
988 .to_string();
989 ExtensionInfo {
990 name,
991 version: "0.0.0".to_string(),
992 description: String::new(),
993 permissions: vec![],
994 }
995 }
996 };
997
998 let ext_name_for_ctx = info.name.clone();
1001 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name_for_ctx));
1002 let tools: Vec<WasmToolDef> = match plugin.call::<&str, &str>("register_tools", "{}") {
1003 Ok(output) => {
1004 let resp: Value = serde_json::from_str(output)
1005 .with_context(|| format!("register_tools() invalid JSON: {}", output))?;
1006 resp.get("tools")
1007 .cloned()
1008 .unwrap_or(Value::Array(vec![]))
1009 .as_array()
1010 .map(|arr| {
1011 arr.iter()
1012 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1013 .collect()
1014 })
1015 .unwrap_or_default()
1016 }
1017 Err(_) => vec![], };
1019
1020 let commands: Vec<WasmCommandDef> =
1022 match plugin.call::<&str, &str>("register_commands", "{}") {
1023 Ok(output) => {
1024 let resp: Value = serde_json::from_str(output)
1025 .with_context(|| format!("register_commands() invalid JSON: {}", output))?;
1026 resp.get("commands")
1027 .cloned()
1028 .unwrap_or(Value::Array(vec![]))
1029 .as_array()
1030 .map(|arr| {
1031 arr.iter()
1032 .filter_map(|v| serde_json::from_value(v.clone()).ok())
1033 .collect()
1034 })
1035 .unwrap_or_default()
1036 }
1037 Err(_) => vec![], };
1039
1040 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1042
1043 let ext_name = info.name.clone();
1044
1045 if self.extensions.contains_key(&ext_name) {
1047 tracing::warn!(
1048 "Extension '{}' already loaded, replacing with '{}'",
1049 ext_name,
1050 path_display
1051 );
1052 self.tool_to_ext.retain(|_, v| v != &ext_name);
1054 self.plugins.lock().remove(&ext_name);
1056 }
1057
1058 for tool in &tools {
1059 self.tool_to_ext.insert(tool.name.clone(), ext_name.clone());
1060 }
1061
1062 let loaded = LoadedWasmExtension {
1063 info: info.clone(),
1064 tools,
1065 commands,
1066 source_path: path.to_path_buf(),
1067 };
1068
1069 self.extensions.insert(ext_name.clone(), loaded);
1070 self.plugins.lock().insert(ext_name, plugin);
1071
1072 tracing::info!(
1073 name = %info.name,
1074 version = %info.version,
1075 tools = self.tool_to_ext.len(),
1076 "WASM extension loaded"
1077 );
1078
1079 Ok(info)
1080 }
1081
1082 pub fn load_all(&mut self, paths: &[PathBuf]) -> (Vec<ExtensionInfo>, Vec<anyhow::Error>) {
1084 let mut loaded = Vec::new();
1085 let mut errors = Vec::new();
1086
1087 for path in paths {
1088 match self.load(path) {
1089 Ok(info) => loaded.push(info),
1090 Err(e) => {
1091 tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
1092 errors.push(e);
1093 }
1094 }
1095 }
1096
1097 (loaded, errors)
1098 }
1099
1100 pub fn execute_tool(&self, tool_name: &str, params: Value) -> Result<Value> {
1104 let ext_name = self
1105 .tool_to_ext
1106 .get(tool_name)
1107 .with_context(|| format!("No extension registered for tool: {}", tool_name))?
1108 .clone();
1109
1110 let mut plugins = self.plugins.lock();
1111 let plugin = plugins
1112 .get_mut(&ext_name)
1113 .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1114
1115 let input = serde_json::json!({
1116 "tool": tool_name,
1117 "params": params,
1118 });
1119 let input_str = serde_json::to_string(&input)?;
1120
1121 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1123 let call_result = plugin.call("execute_tool", &input_str);
1124 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1125
1126 let output: &str = call_result
1127 .with_context(|| format!("execute_tool('{}') failed in '{}'", tool_name, ext_name))?;
1128
1129 let result: Value = serde_json::from_str(output)
1130 .with_context(|| format!("execute_tool() returned invalid JSON: {}", output))?;
1131
1132 Ok(result)
1133 }
1134
1135 pub fn all_tool_defs(&self) -> Vec<&WasmToolDef> {
1139 self.extensions
1140 .values()
1141 .flat_map(|e| e.tools.iter())
1142 .collect()
1143 }
1144
1145 pub fn is_wasm_tool(&self, tool_name: &str) -> bool {
1147 self.tool_to_ext.contains_key(tool_name)
1148 }
1149
1150 pub fn extension_names(&self) -> impl Iterator<Item = &str> {
1152 self.extensions.keys().map(|s| s.as_str())
1153 }
1154
1155 pub fn get_info(&self, name: &str) -> Option<&ExtensionInfo> {
1157 self.extensions.get(name).map(|e| &e.info)
1158 }
1159
1160 pub fn len(&self) -> usize {
1162 self.extensions.len()
1163 }
1164
1165 pub fn is_empty(&self) -> bool {
1167 self.extensions.is_empty()
1168 }
1169
1170 pub fn all_command_defs(&self) -> Vec<(&str, &WasmCommandDef)> {
1174 let mut cmds = Vec::new();
1175 for ext in self.extensions.values() {
1176 for cmd in &ext.commands {
1177 cmds.push((ext.info.name.as_str(), cmd));
1178 }
1179 }
1180 cmds
1181 }
1182
1183 pub fn execute_command(&self, command_name: &str, args: &str) -> Result<String> {
1186 let ext_name = self
1188 .extensions
1189 .iter()
1190 .find(|(_, ext)| ext.commands.iter().any(|c| c.name == command_name))
1191 .map(|(name, _)| name.clone())
1192 .with_context(|| format!("No extension registered for command: /{}", command_name))?;
1193
1194 let mut plugins = self.plugins.lock();
1195 let plugin = plugins
1196 .get_mut(&ext_name)
1197 .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1198
1199 let input = serde_json::json!({
1200 "command": command_name,
1201 "args": args,
1202 });
1203 let input_str = serde_json::to_string(&input)?;
1204
1205 let output: &str = {
1206 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1207 let result = plugin.call("execute_command", &input_str);
1208 CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1209 result
1210 }
1211 .with_context(|| {
1212 format!(
1213 "execute_command('/{}') failed in '{}'",
1214 command_name, ext_name
1215 )
1216 })?;
1217
1218 let result: Value =
1220 serde_json::from_str(output).unwrap_or_else(|_| serde_json::json!({"output": output}));
1221
1222 Ok(result
1223 .get("output")
1224 .and_then(|v| v.as_str())
1225 .unwrap_or(output)
1226 .to_string())
1227 }
1228}
1229
// Unit tests for discovery, manifest parsing, manager construction and the SSRF URL
// filter. None of them instantiates a real WASM module.
#[cfg(test)]
mod tests {
    use super::*;

    // A cwd without `.oxi/extensions` contributes nothing. Note that `discover` also
    // scans `~/.oxi/extensions`, so this assumes the test user has none installed.
    #[test]
    fn test_discover_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        let paths = WasmExtensionManager::discover(dir.path());
        assert!(paths.is_empty());
    }

    // Only `.wasm` files are picked up; other files in the directory are ignored.
    #[test]
    fn test_discover_finds_wasm_files() {
        let dir = tempfile::tempdir().unwrap();
        let wasm_path = dir.path().join("test_ext.wasm");
        std::fs::write(&wasm_path, b"\x00asm").unwrap();
        std::fs::write(dir.path().join("readme.txt"), b"hello").unwrap();

        let mut paths = Vec::new();
        WasmExtensionManager::discover_in_dir(dir.path(), &mut paths);
        assert_eq!(paths.len(), 1);
        assert!(paths[0].ends_with("test_ext.wasm"));
    }

    #[test]
    fn test_extension_info_parse() {
        let json = r#"{"name":"my_ext","version":"1.0.0","description":"Test"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.name, "my_ext");
        assert_eq!(info.version, "1.0.0");
    }

    #[test]
    fn test_tool_def_parse() {
        let json = r#"{"name":"search","description":"Search","schema":{"type":"object"}}"#;
        let tool: WasmToolDef = serde_json::from_str(json).unwrap();
        assert_eq!(tool.name, "search");
    }

    #[test]
    fn test_manager_new_is_empty() {
        let mgr = WasmExtensionManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    #[test]
    fn test_is_wasm_tool_false() {
        let mgr = WasmExtensionManager::new();
        assert!(!mgr.is_wasm_tool("anything"));
    }

    // `description` is `#[serde(default)]`, so it may be omitted.
    #[test]
    fn test_extension_info_default_description() {
        let json = r#"{"name":"test","version":"0.1"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.description, "");
    }

    // Loopback, RFC 1918, link-local metadata and unspecified hosts are rejected.
    #[test]
    fn test_ssrf_blocks_localhost() {
        assert!(validate_url("http://localhost/admin").is_err());
        assert!(validate_url("http://127.0.0.1/secret").is_err());
        assert!(validate_url("http://10.0.0.1/internal").is_err());
        assert!(validate_url("http://192.168.1.1/router").is_err());
        assert!(validate_url("http://172.16.0.1/corp").is_err());
        assert!(validate_url("http://169.254.169.254/metadata").is_err());
        assert!(validate_url("http://[::1]/ipv6").is_err());
        assert!(validate_url("http://0.0.0.0/admin").is_err());
    }

    #[test]
    fn test_ssrf_allows_public() {
        assert!(validate_url("https://api.github.com/repos/test").is_ok());
        assert!(validate_url("https://example.com/api").is_ok());
        assert!(validate_url("https://search.brave.com/api/search?q=test").is_ok());
    }

    // Boundaries of 172.16.0.0/12: .16 and .31 are private, .15 and .32 are not.
    #[test]
    fn test_ssrf_172_range() {
        assert!(validate_url("http://172.16.0.1/test").is_err());
        assert!(validate_url("http://172.31.255.255/test").is_err());
        assert!(validate_url("http://172.15.0.1/test").is_ok());
        assert!(validate_url("http://172.32.0.1/test").is_ok());
    }
}