Skip to main content

oxi/extensions/
wasm.rs

1//! WASM-based extension system powered by Extism.
2//!
3//! Loads `.wasm` extension files from `~/.oxi/extensions/` and project-local
4//! `.oxi/extensions/`. Each extension exports well-known functions (`init`,
5//! `register_tools`, `execute_tool`) called via Extism's JSON-in/JSON-out
6//! protocol. Extensions run inside a WASM sandbox (WASI enabled). Host access
7//! is provided through injected host functions; arbitrary command execution
8//! (`oxi_exec`) is **disabled by default** and must be opted in via the
9//! `OXI_EXTENSION_EXEC=1` environment variable.
10
11use anyhow::{Context, Result};
12use extism::{CurrentPlugin, Function, PTR, UserData, Val};
13use parking_lot::Mutex;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::cell::RefCell;
17use std::collections::HashMap;
18use std::io::Read as _;
19use std::path::{Path, PathBuf};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22
23// ── Types ────────────────────────────────────────────────────────────
24
25/// Metadata returned by an extension's `init()` function.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ExtensionInfo {
28    pub name: String,
29    pub version: String,
30    #[serde(default)]
31    pub description: String,
32    /// Permissions requested by the extension.
33    /// Supported: "fs_read", "fs_write", "exec", "env", "network"
34    #[serde(default)]
35    pub permissions: Vec<String>,
36}
37
38/// A tool definition returned by `register_tools()`.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct WasmToolDef {
41    pub name: String,
42    pub description: String,
43    pub schema: Value,
44}
45
46/// A command definition returned by `register_commands()`.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct WasmCommandDef {
49    pub name: String,
50    pub description: String,
51}
52
53/// Result of loading a WASM extension.
54#[derive(Debug)]
55pub struct LoadedWasmExtension {
56    pub info: ExtensionInfo,
57    pub tools: Vec<WasmToolDef>,
58    pub commands: Vec<WasmCommandDef>,
59    pub source_path: PathBuf,
60}
61
62// ── Host Functions ────────────────────────────────────────────────────
63
64/// Host function: `oxi_http_request(request_json) -> response_json`
65///
66/// Request JSON: `{"url": "...", "method": "GET", "headers": {...}, "body": "..."}`
67/// Response JSON: `{"status": 200, "headers": {...}, "body": "..."}`
68fn host_oxi_http_request(
69    plugin: &mut CurrentPlugin,
70    inputs: &[Val],
71    outputs: &mut [Val],
72    user_data: UserData<Arc<reqwest::blocking::Client>>,
73) -> Result<(), extism::Error> {
74    // We use anyhow internally, then convert at the boundary
75    let result: anyhow::Result<()> = (|| {
76        let input_json: String = plugin.memory_get_val(&inputs[0])?;
77
78        #[derive(Deserialize)]
79        struct HttpReq {
80            url: String,
81            #[serde(default)]
82            method: String,
83            #[serde(default)]
84            headers: HashMap<String, String>,
85            #[serde(default)]
86            body: Option<String>,
87        }
88
89        let req: HttpReq =
90            serde_json::from_str(&input_json).context("oxi_http_request: invalid request JSON")?;
91
92        let method = if req.method.is_empty() {
93            "GET"
94        } else {
95            &req.method
96        };
97
98        // SSRF protection: block internal/private network addresses
99        if let Err(e) = validate_url(&req.url) {
100            anyhow::bail!("oxi_http_request: {}", e);
101        }
102
103        // UserData::get() returns Arc<Mutex<T>>
104        let client_arc = user_data.get()?;
105        let client = client_arc.lock().expect("wasm client lock poisoned");
106
107        let method = match method.to_uppercase().as_str() {
108            "GET" => reqwest::Method::GET,
109            "POST" => reqwest::Method::POST,
110            "PUT" => reqwest::Method::PUT,
111            "DELETE" => reqwest::Method::DELETE,
112            "PATCH" => reqwest::Method::PATCH,
113            "HEAD" => reqwest::Method::HEAD,
114            other => anyhow::bail!("oxi_http_request: unsupported method '{}'", other),
115        };
116
117        let mut rb = client.request(method, &req.url);
118        for (k, v) in &req.headers {
119            rb = rb.header(k, v);
120        }
121        if let Some(body) = &req.body {
122            rb = rb.body(body.clone());
123        }
124
125        // Execute HTTP request (blocking — called from spawn_blocking in wasm_tool.rs)
126        let resp = rb
127            .send()
128            .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
129        let status = resp.status().as_u16();
130        let resp_headers: HashMap<String, String> = resp
131            .headers()
132            .iter()
133            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
134            .collect();
135        // Limit response body to 1MB to prevent memory exhaustion
136        let resp_body = {
137            let max_body = 1024 * 1024; // 1MB
138            let body_bytes = resp
139                .bytes()
140                .map_err(|e| anyhow::anyhow!("Failed to read response: {}", e))?;
141            if body_bytes.len() > max_body {
142                tracing::warn!(
143                    "HTTP response truncated: {} bytes > {} limit",
144                    body_bytes.len(),
145                    max_body
146                );
147                String::from_utf8_lossy(&body_bytes[..max_body]).to_string()
148            } else {
149                String::from_utf8_lossy(&body_bytes).to_string()
150            }
151        };
152
153        let response = serde_json::json!({
154            "status": status,
155            "headers": resp_headers,
156            "body": resp_body,
157        });
158
159        let output = serde_json::to_string(&response)?;
160        let handle = plugin.memory_new(&output)?;
161        if !outputs.is_empty() {
162            outputs[0] = plugin.memory_to_val(handle);
163        }
164        Ok(())
165    })();
166
167    result
168}
169
170/// Host function: `oxi_log(message)` — logs a debug message from WASM.
171fn host_oxi_log(
172    plugin: &mut CurrentPlugin,
173    inputs: &[Val],
174    _outputs: &mut [Val],
175    _user_data: UserData<()>,
176) -> Result<(), extism::Error> {
177    let message: String = plugin.memory_get_val(&inputs[0])?;
178    tracing::debug!("[WASM] {}", message);
179    Ok(())
180}
181
182// ── File I/O Host Functions ─────────────────────────────────────────
183
184/// Host function: `oxi_read_file(path_json) → result_json`
185///
186/// Input: `{"path": "/path/to/file", "offset": 0, "limit": 2000}`
187/// Output: `{"success": true, "content": "...", "truncated": false, "bytes": 1234}` or
188///         `{"success": false, "error": "..."}`
189fn host_oxi_read_file(
190    plugin: &mut CurrentPlugin,
191    inputs: &[Val],
192    outputs: &mut [Val],
193    _user_data: UserData<()>,
194) -> Result<(), extism::Error> {
195    let result: anyhow::Result<()> = (|| {
196        let input_json: String = plugin.memory_get_val(&inputs[0])?;
197
198        #[derive(Deserialize)]
199        struct ReadReq {
200            path: String,
201            #[serde(default)]
202            offset: Option<usize>,
203            #[serde(default = "default_limit")]
204            limit: usize,
205        }
206        fn default_limit() -> usize {
207            2000
208        }
209
210        let req: ReadReq =
211            serde_json::from_str(&input_json).context("oxi_read_file: invalid request JSON")?;
212
213        // Validate path is not outside cwd
214        validate_path_allowed(&req.path)?;
215
216        let metadata = std::fs::metadata(&req.path);
217        match metadata {
218            Ok(m) => {
219                let max_bytes = 50 * 1024; // 50KB limit
220                let file_size = m.len() as usize;
221
222                let content = std::fs::read_to_string(&req.path)
223                    .map_err(|e| anyhow::anyhow!("Failed to read file: {}", e))?;
224
225                let lines: Vec<&str> = content.lines().collect();
226                let total_lines = lines.len();
227
228                let offset = req.offset.unwrap_or(0).min(total_lines);
229                let end = (offset + req.limit).min(total_lines);
230                let selected: Vec<&str> = lines[offset..end].to_vec();
231                let mut result = selected.join("\n");
232
233                let truncated = result.len() > max_bytes;
234                if truncated {
235                    result = result.chars().take(max_bytes).collect();
236                }
237
238                let response = serde_json::json!({
239                    "success": true,
240                    "content": result,
241                    "truncated": truncated || end < total_lines,
242                    "bytes": file_size,
243                    "total_lines": total_lines,
244                    "shown_lines": end - offset,
245                });
246                let output = serde_json::to_string(&response)?;
247                let handle = plugin.memory_new(&output)?;
248                if !outputs.is_empty() {
249                    outputs[0] = plugin.memory_to_val(handle);
250                }
251            }
252            Err(e) => {
253                let response = serde_json::json!({
254                    "success": false,
255                    "error": format!("File not found: {}", e),
256                });
257                let output = serde_json::to_string(&response)?;
258                let handle = plugin.memory_new(&output)?;
259                if !outputs.is_empty() {
260                    outputs[0] = plugin.memory_to_val(handle);
261                }
262            }
263        }
264        Ok(())
265    })();
266    result
267}
268
269/// Host function: `oxi_write_file(path_json) → result_json`
270///
271/// Input: `{"path": "/path/to/file", "content": "...", "create_dirs": true}`
272/// Output: `{"success": true, "bytes_written": 1234}` or `{"success": false, "error": "..."}`
273fn host_oxi_write_file(
274    plugin: &mut CurrentPlugin,
275    inputs: &[Val],
276    outputs: &mut [Val],
277    _user_data: UserData<()>,
278) -> Result<(), extism::Error> {
279    let result: anyhow::Result<()> = (|| {
280        let input_json: String = plugin.memory_get_val(&inputs[0])?;
281
282        #[derive(Deserialize)]
283        struct WriteReq {
284            path: String,
285            content: String,
286            #[serde(default = "default_true")]
287            create_dirs: bool,
288        }
289        fn default_true() -> bool {
290            true
291        }
292
293        let req: WriteReq =
294            serde_json::from_str(&input_json).context("oxi_write_file: invalid request JSON")?;
295
296        validate_path_allowed(&req.path)?;
297
298        if req.create_dirs
299            && let Some(parent) = std::path::Path::new(&req.path).parent()
300        {
301            std::fs::create_dir_all(parent)
302                .map_err(|e| anyhow::anyhow!("Failed to create directories: {}", e))?;
303        }
304
305        let bytes = req.content.len();
306        std::fs::write(&req.path, &req.content)
307            .map_err(|e| anyhow::anyhow!("Failed to write file: {}", e))?;
308
309        let response = serde_json::json!({
310            "success": true,
311            "bytes_written": bytes,
312        });
313        let output = serde_json::to_string(&response)?;
314        let handle = plugin.memory_new(&output)?;
315        if !outputs.is_empty() {
316            outputs[0] = plugin.memory_to_val(handle);
317        }
318        Ok(())
319    })();
320    result
321}
322
323/// Host function: `oxi_exec(exec_json) → result_json`
324///
325/// Input: `{"command": "git status", "args": [], "cwd": ".", "timeout": 30}`
326/// Output: `{"success": true, "stdout": "...", "stderr": "...", "exit_code": 0}` or error.
327fn host_oxi_exec(
328    plugin: &mut CurrentPlugin,
329    inputs: &[Val],
330    outputs: &mut [Val],
331    _user_data: UserData<()>,
332) -> Result<(), extism::Error> {
333    let result: anyhow::Result<()> = (|| {
334        // Security: arbitrary command execution is OFF by default. The host
335        // must opt in explicitly via `OXI_EXTENSION_EXEC=1`. This restores a
336        // least-privilege default — extensions cannot run commands unless the
337        // user enables it.
338        if std::env::var("OXI_EXTENSION_EXEC").ok().as_deref() != Some("1") {
339            let response = serde_json::json!({
340                "success": false,
341                "error": "oxi_exec is disabled; set OXI_EXTENSION_EXEC=1 to allow extensions to run commands",
342                "exit_code": -1,
343            });
344            let output = serde_json::to_string(&response)?;
345            let handle = plugin.memory_new(&output)?;
346            if !outputs.is_empty() {
347                outputs[0] = plugin.memory_to_val(handle);
348            }
349            return Ok(());
350        }
351        let input_json: String = plugin.memory_get_val(&inputs[0])?;
352
353        #[derive(Deserialize)]
354        struct ExecReq {
355            command: String,
356            #[serde(default)]
357            args: Vec<String>,
358            #[serde(default)]
359            cwd: Option<String>,
360            #[serde(default = "default_timeout")]
361            timeout: u64,
362        }
363        fn default_timeout() -> u64 {
364            30
365        }
366
367        let req: ExecReq =
368            serde_json::from_str(&input_json).context("oxi_exec: invalid request JSON")?;
369
370        let cwd = req.cwd.as_deref().unwrap_or(".");
371
372        // Build full command string for deny-list checking (command + args)
373        let full_cmd = if req.args.is_empty() {
374            req.command.clone()
375        } else {
376            format!("{} {}", req.command, req.args.join(" "))
377        };
378
379        // Block dangerous commands — deny-list (checks combined command+args)
380        let blocked_patterns = [
381            "rm -rf /",
382            "rm -rf /*",
383            "mkfs",
384            "dd if=",
385            "format ",
386            ":(){ :|:& };:",
387            "chmod 777 /",
388            "chown root",
389            "> /etc/",
390            "> /boot/",
391            "> /dev/",
392            "dd of=/dev/",
393            "mv / /",
394        ];
395        for blocked in &blocked_patterns {
396            if full_cmd.contains(blocked) {
397                anyhow::bail!("oxi_exec: blocked dangerous command pattern");
398            }
399        }
400
401        // Block obvious privilege escalation
402        let cmd_lower = req.command.to_lowercase();
403        if cmd_lower == "sudo"
404            || cmd_lower == "su"
405            || cmd_lower == "doas"
406            || cmd_lower.starts_with("sudo ")
407            || cmd_lower.starts_with("su ")
408            || cmd_lower.starts_with("doas ")
409        {
410            anyhow::bail!("oxi_exec: privilege escalation commands are blocked");
411        }
412
413        // Clamp timeout to 1-30 seconds to prevent abuse
414        let timeout_ms = req.timeout.clamp(1000, 30000);
415        let timeout_dur = Duration::from_millis(timeout_ms);
416
417        // Spawn child process with piped stdout/stderr
418        let mut child = match std::process::Command::new(&req.command)
419            .args(&req.args)
420            .current_dir(cwd)
421            .stdout(std::process::Stdio::piped())
422            .stderr(std::process::Stdio::piped())
423            .spawn()
424        {
425            Ok(c) => c,
426            Err(e) => {
427                let response = serde_json::json!({
428                    "success": false,
429                    "error": format!("Failed to execute: {}", e),
430                    "exit_code": -1,
431                });
432                let out = serde_json::to_string(&response)?;
433                let handle = plugin.memory_new(&out)?;
434                if !outputs.is_empty() {
435                    outputs[0] = plugin.memory_to_val(handle);
436                }
437                return Ok(());
438            }
439        };
440
441        // Poll for completion with timeout enforcement
442        let start = Instant::now();
443        let mut timed_out = false;
444        let mut exit_status: Option<std::process::ExitStatus> = None;
445
446        loop {
447            match child.try_wait() {
448                Ok(Some(status)) => {
449                    exit_status = Some(status);
450                    break;
451                }
452                Ok(None) => {
453                    if start.elapsed() >= timeout_dur {
454                        // Timeout reached — kill the process and clean up
455                        tracing::warn!(
456                            "oxi_exec: command '{}' timed out after {}ms",
457                            req.command,
458                            timeout_ms
459                        );
460                        let _ = child.kill();
461                        let _ = child.wait(); // reap zombie
462                        timed_out = true;
463                        break;
464                    }
465                    std::thread::sleep(Duration::from_millis(50));
466                }
467                Err(_) => {
468                    // try_wait failed — fall back to blocking wait
469                    match child.wait() {
470                        Ok(status) => {
471                            exit_status = Some(status);
472                        }
473                        Err(_) => {
474                            timed_out = true;
475                        }
476                    }
477                    break;
478                }
479            }
480        }
481
482        // Collect stdout/stderr from the child pipes (read what was buffered)
483        let mut stdout_buf = Vec::new();
484        let mut stderr_buf = Vec::new();
485        if let Some(mut out) = child.stdout.take() {
486            let _ = out.read_to_end(&mut stdout_buf);
487        }
488        if let Some(mut err) = child.stderr.take() {
489            let _ = err.read_to_end(&mut stderr_buf);
490        }
491
492        let stdout = String::from_utf8_lossy(&stdout_buf);
493        let stderr = String::from_utf8_lossy(&stderr_buf);
494        let max_output = 50 * 1024; // 50KB
495        let stdout_truncated = stdout.len() > max_output;
496        let stderr_truncated = stderr.len() > max_output;
497        let stdout_str: String = if stdout_truncated {
498            stdout.chars().take(max_output).collect()
499        } else {
500            stdout.to_string()
501        };
502        let stderr_str: String = if stderr_truncated {
503            stderr.chars().take(max_output).collect()
504        } else {
505            stderr.to_string()
506        };
507
508        let response = serde_json::json!({
509            "success": !timed_out && exit_status.map(|s| s.success()).unwrap_or(false),
510            "stdout": stdout_str,
511            "stderr": stderr_str,
512            "exit_code": if timed_out { -2 } else { exit_status.and_then(|s| s.code()).unwrap_or(-1) },
513            "stdout_truncated": stdout_truncated,
514            "stderr_truncated": stderr_truncated,
515            "timed_out": timed_out,
516        });
517        let out = serde_json::to_string(&response)?;
518        let handle = plugin.memory_new(&out)?;
519        if !outputs.is_empty() {
520            outputs[0] = plugin.memory_to_val(handle);
521        }
522        Ok(())
523    })();
524    result
525}
526
527/// Host function: `oxi_get_env(key_json) → result_json`
528///
529/// Input: `{"key": "HOME"}`
530/// Output: `{"success": true, "value": "/home/user"}` or `{"success": false, "error": "not found"}`
531fn host_oxi_get_env(
532    plugin: &mut CurrentPlugin,
533    inputs: &[Val],
534    outputs: &mut [Val],
535    _user_data: UserData<()>,
536) -> Result<(), extism::Error> {
537    let result: anyhow::Result<()> = (|| {
538        let input_json: String = plugin.memory_get_val(&inputs[0])?;
539
540        #[derive(Deserialize)]
541        struct EnvReq {
542            key: String,
543        }
544
545        let req: EnvReq =
546            serde_json::from_str(&input_json).context("oxi_get_env: invalid request JSON")?;
547
548        // Block sensitive env vars
549        let blocked_keys = ["AWS_SECRET", "PRIVATE_KEY", "PASSWORD", "TOKEN", "SECRET"];
550        let key_upper = req.key.to_uppercase();
551        for blocked in &blocked_keys {
552            if key_upper.contains(blocked) {
553                anyhow::bail!("oxi_get_env: access to '{}' is blocked", req.key);
554            }
555        }
556
557        let value = std::env::var(&req.key).ok();
558        let response = serde_json::json!({
559            "success": value.is_some(),
560            "value": value.unwrap_or_default(),
561        });
562        let output = serde_json::to_string(&response)?;
563        let handle = plugin.memory_new(&output)?;
564        if !outputs.is_empty() {
565            outputs[0] = plugin.memory_to_val(handle);
566        }
567        Ok(())
568    })();
569    result
570}
571
572// ── KV Store Host Functions ─────────────────────────────────────────
573
574/// Host function: `oxi_kv_get(key_json) → result_json`
575///
576/// Persistent key-value store for extension state.
577/// Keys are namespaced per extension.
578/// Input: `{"key": "my_state"}`
579/// Output: `{"success": true, "value": "..."}` or `{"success": false}`
580fn host_oxi_kv_get(
581    plugin: &mut CurrentPlugin,
582    inputs: &[Val],
583    outputs: &mut [Val],
584    _user_data: UserData<()>,
585) -> Result<(), extism::Error> {
586    let result: anyhow::Result<()> = (|| {
587        let input_json: String = plugin.memory_get_val(&inputs[0])?;
588
589        #[derive(Deserialize)]
590        struct KvReq {
591            key: String,
592        }
593
594        let req: KvReq =
595            serde_json::from_str(&input_json).context("oxi_kv_get: invalid request JSON")?;
596
597        // Namespace the key with the current extension identity
598        let ext_name = current_extension_name();
599        let value = kv_namespaced_get(&ext_name, &req.key);
600        let response = serde_json::json!({
601            "success": value.is_some(),
602            "value": value.unwrap_or_default(),
603        });
604        let output = serde_json::to_string(&response)?;
605        let handle = plugin.memory_new(&output)?;
606        if !outputs.is_empty() {
607            outputs[0] = plugin.memory_to_val(handle);
608        }
609        Ok(())
610    })();
611    result
612}
613
614/// Host function: `oxi_kv_set(set_json)`
615///
616/// Input: `{"key": "my_state", "value": "saved_data"}`
617fn host_oxi_kv_set(
618    plugin: &mut CurrentPlugin,
619    inputs: &[Val],
620    _outputs: &mut [Val],
621    _user_data: UserData<()>,
622) -> Result<(), extism::Error> {
623    let result: anyhow::Result<()> = (|| {
624        let input_json: String = plugin.memory_get_val(&inputs[0])?;
625
626        #[derive(Deserialize)]
627        struct KvSetReq {
628            key: String,
629            value: String,
630        }
631
632        let req: KvSetReq =
633            serde_json::from_str(&input_json).context("oxi_kv_set: invalid request JSON")?;
634
635        // Namespace the key with the current extension identity
636        let ext_name = current_extension_name();
637        kv_namespaced_set(&ext_name, &req.key, &req.value);
638        Ok(())
639    })();
640    result
641}
642
643// ── KV Store Implementation ─────────────────────────────────────────
644
645use std::sync::LazyLock;
646
647static KV_STORE: LazyLock<parking_lot::RwLock<HashMap<String, String>>> =
648    LazyLock::new(|| parking_lot::RwLock::new(HashMap::new()));
649
// Thread-local holding the name of the extension currently executing on this
// thread. Read by the KV host functions (via `current_extension_name`) to
// namespace keys.
thread_local! {
    static CURRENT_EXTENSION: RefCell<Option<String>> = const { RefCell::new(None) };
}

/// Run `f` with `ext_name` recorded as the current extension for this thread.
///
/// The value that was set before the call is restored afterwards, so nested
/// calls unwind correctly. Restoration also happens if `f` panics, so a
/// failed call cannot leak its identity into later work on the same thread.
#[allow(dead_code)]
fn with_extension_context<F, R>(ext_name: &str, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Puts the previous context back when dropped (normal return or unwind).
    struct Restore(Option<String>);

    impl Drop for Restore {
        fn drop(&mut self) {
            let previous = self.0.take();
            // `try_with` avoids a second panic if the thread-local has
            // already been torn down during thread exit.
            let _ = CURRENT_EXTENSION.try_with(|cell| *cell.borrow_mut() = previous);
        }
    }

    let previous = CURRENT_EXTENSION.with(|cell| cell.replace(Some(ext_name.to_string())));
    let _restore = Restore(previous);
    f()
}

/// Get the current extension name from thread-local storage.
/// Returns `"__unknown__"` if no extension context is set on this thread.
fn current_extension_name() -> String {
    CURRENT_EXTENSION.with(|cell| {
        cell.borrow()
            .clone()
            .unwrap_or_else(|| "__unknown__".to_string())
    })
}
678
679fn kv_store_get(key: &str) -> Option<String> {
680    KV_STORE.read().get(key).cloned()
681}
682
683fn kv_store_set(key: &str, value: &str) {
684    KV_STORE.write().insert(key.to_string(), value.to_string());
685}
686
687/// Namespaced KV access — prefixes key with `extension:` to prevent
688/// cross-extension key collision.
689fn kv_namespaced_get(extension: &str, key: &str) -> Option<String> {
690    let namespaced = format!("{}:{}", extension, key);
691    kv_store_get(&namespaced)
692}
693
694fn kv_namespaced_set(extension: &str, key: &str, value: &str) {
695    let namespaced = format!("{}:{}", extension, key);
696    kv_store_set(&namespaced, value);
697}
698
699// ── Path Validation ─────────────────────────────────────────────────
700
701/// Validate that a path is within the allowed directories.
702/// Extensions can only read/write outside protected system paths.
703fn validate_path_allowed(path: &str) -> Result<()> {
704    let p = std::path::Path::new(path);
705
706    // Resolve to absolute, then canonicalize to eliminate ../ and symlinks
707    let abs = if p.is_absolute() {
708        p.to_path_buf()
709    } else {
710        std::env::current_dir().unwrap_or_default().join(p)
711    };
712
713    // Try canonicalize for existing paths; fall back to resolved absolute
714    let resolved = if abs.exists() {
715        abs.canonicalize().unwrap_or(abs)
716    } else {
717        // For new files, resolve parent if it exists
718        if let Some(parent) = abs.parent() {
719            if parent.exists() {
720                let canon_parent = parent
721                    .canonicalize()
722                    .unwrap_or_else(|_| parent.to_path_buf());
723                canon_parent.join(abs.file_name().unwrap_or_default())
724            } else {
725                abs
726            }
727        } else {
728            abs
729        }
730    };
731
732    let abs_str = resolved.to_string_lossy();
733
734    // Block sensitive system paths
735    let blocked_prefixes = [
736        "/etc",
737        "/sys",
738        "/proc",
739        "/dev",
740        "/boot",
741        "/root",
742        "/System",
743        "/Library/System",
744        "/usr/bin",
745        "/usr/sbin",
746        "/bin",
747        "/sbin",
748    ];
749    for prefix in &blocked_prefixes {
750        if abs_str.starts_with(prefix) {
751            anyhow::bail!("Path '{}' is in a protected system directory", path);
752        }
753    }
754
755    // Block hidden sensitive paths in home directory
756    if let Some(home) = dirs::home_dir() {
757        let home_str = home.to_string_lossy();
758        if abs_str.starts_with(&*home_str) {
759            let blocked_home_suffixes = [
760                "/.ssh/",
761                "/.gnupg/",
762                "/.aws/",
763                "/.config/gcloud/",
764                "/.kube/",
765                "/.docker/",
766                "/.npmrc",
767                "/.netrc",
768            ];
769            for suffix in &blocked_home_suffixes {
770                if abs_str.contains(suffix) {
771                    anyhow::bail!("Path '{}' is in a protected directory", path);
772                }
773            }
774        }
775    }
776
777    Ok(())
778}
779
780// ── SSRF Protection ─────────────────────────────────────────────────
781
782/// Validate that a URL does not target internal/private network addresses.
783fn validate_url(url: &str) -> Result<(), String> {
784    let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
785    let host = parsed.host_str().unwrap_or("").to_lowercase();
786
787    // Block private IPs, localhost, link-local, and internal hostnames
788    let blocked = [
789        "localhost",
790        "127.0.0.1",
791        "0.0.0.0",
792        "::1",
793        "[::1]",
794        "169.254.169.254", // cloud metadata
795        "metadata.google.internal",
796    ];
797    for &b in &blocked {
798        if host == b || host.starts_with(b) {
799            return Err(format!("Blocked internal address: {}", host));
800        }
801    }
802
803    // Block private IP ranges (10.x, 172.16-31.x, 192.168.x)
804    if host.starts_with("10.") || host.starts_with("192.168.") || is_172_private(&host) {
805        return Err(format!("Blocked private address: {}", host));
806    }
807
808    Ok(())
809}
810
/// Report whether `host` textually falls in 172.16.0.0/12: it must begin
/// with `172.` and its second dot-separated field must parse as a number
/// from 16 through 31.
fn is_172_private(host: &str) -> bool {
    host.strip_prefix("172.")
        .and_then(|rest| rest.split('.').next())
        .and_then(|octet| octet.parse::<u8>().ok())
        .is_some_and(|octet| (16..=31).contains(&octet))
}
826
827// ── Manager ──────────────────────────────────────────────────────────
828
829/// Manages WASM extensions: discovery, loading, and tool execution.
830///
831/// # Thread Safety
832/// `plugins` is wrapped in `Arc<Mutex<>>` for exclusive thread-safe access.
833/// All host functions run inside `spawn_blocking`, so WASM execution
834/// never blocks the async runtime.
835pub struct WasmExtensionManager {
836    extensions: HashMap<String, LoadedWasmExtension>,
837    /// Raw Extism plugin references — needed for execute_tool calls.
838    /// Wrapped in Mutex for exclusive thread-safe access (required because
839    /// extism::Plugin's Send+Sync bounds are not publicly documented).
840    /// All access goes through `plugins.lock()` which ensures mutual exclusion.
841    pub(crate) plugins: Arc<parking_lot::Mutex<HashMap<String, extism::Plugin>>>,
842    /// Maps tool name → extension name.
843    tool_to_ext: HashMap<String, String>,
844    /// HTTP client shared by all extensions for oxi_http_request.
845    http_client: Arc<reqwest::blocking::Client>,
846    /// Permissions granted to each extension (ext_name → set of permission strings).
847    #[allow(dead_code, unused)]
848    permissions: HashMap<String, std::collections::HashSet<String>>,
849}
850
851impl Default for WasmExtensionManager {
852    fn default() -> Self {
853        Self::new()
854    }
855}
856
857impl WasmExtensionManager {
858    /// Create an empty manager.
859    pub fn new() -> Self {
860        Self {
861            extensions: HashMap::new(),
862            plugins: Arc::new(Mutex::new(HashMap::new())),
863            tool_to_ext: HashMap::new(),
864            http_client: Arc::new(
865                reqwest::blocking::Client::builder()
866                    .timeout(std::time::Duration::from_secs(30))
867                    .connect_timeout(std::time::Duration::from_secs(10))
868                    .no_proxy() // Prevent proxy-based SSRF
869                    .build()
870                    .expect("Failed to build HTTP client"),
871            ),
872            permissions: HashMap::new(),
873        }
874    }
875
876    /// Create with a custom HTTP client (useful for testing with a mock client).
877    pub fn with_http_client(client: reqwest::blocking::Client) -> Self {
878        Self {
879            extensions: HashMap::new(),
880            plugins: Arc::new(Mutex::new(HashMap::new())),
881            tool_to_ext: HashMap::new(),
882            http_client: Arc::new(client),
883            permissions: HashMap::new(),
884        }
885    }
886
887    // ── Discovery ──────────────────────────────────────────────────
888
889    /// Discover `.wasm` files in standard extension directories.
890    pub fn discover(cwd: &Path) -> Vec<PathBuf> {
891        let mut paths = Vec::new();
892
893        // ~/.oxi/extensions/
894        if let Some(home) = dirs::home_dir() {
895            let dir = home.join(".oxi").join("extensions");
896            if dir.is_dir() {
897                Self::discover_in_dir(&dir, &mut paths);
898            }
899        }
900
901        // .oxi/extensions/ (project-local)
902        let local_dir = cwd.join(".oxi").join("extensions");
903        if local_dir.is_dir() {
904            Self::discover_in_dir(&local_dir, &mut paths);
905        }
906
907        paths.sort();
908        paths.dedup();
909        paths
910    }
911
912    fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
913        let Ok(entries) = std::fs::read_dir(dir) else {
914            return;
915        };
916        for entry in entries.flatten() {
917            let path = entry.path();
918            if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("wasm") {
919                out.push(path);
920            }
921        }
922    }
923
924    // ── Loading ────────────────────────────────────────────────────
925
926    /// Build the host functions available to all WASM extensions.
927    fn host_functions(http_client: &Arc<reqwest::blocking::Client>) -> Vec<Function> {
928        let http_fn = Function::new(
929            "oxi_http_request",
930            [PTR],
931            [PTR],
932            UserData::new(http_client.clone()),
933            host_oxi_http_request,
934        );
935
936        let log_fn = Function::new("oxi_log", [PTR], [], UserData::new(()), host_oxi_log);
937
938        let read_fn = Function::new(
939            "oxi_read_file",
940            [PTR],
941            [PTR],
942            UserData::new(()),
943            host_oxi_read_file,
944        );
945
946        let write_fn = Function::new(
947            "oxi_write_file",
948            [PTR],
949            [PTR],
950            UserData::new(()),
951            host_oxi_write_file,
952        );
953
954        let exec_fn = Function::new("oxi_exec", [PTR], [PTR], UserData::new(()), host_oxi_exec);
955
956        let get_env_fn = Function::new(
957            "oxi_get_env",
958            [PTR],
959            [PTR],
960            UserData::new(()),
961            host_oxi_get_env,
962        );
963
964        let kv_get_fn = Function::new(
965            "oxi_kv_get",
966            [PTR],
967            [PTR],
968            UserData::new(()),
969            host_oxi_kv_get,
970        );
971
972        let kv_set_fn = Function::new("oxi_kv_set", [PTR], [], UserData::new(()), host_oxi_kv_set);
973
974        vec![
975            http_fn, log_fn, read_fn, write_fn, exec_fn, get_env_fn, kv_get_fn, kv_set_fn,
976        ]
977    }
978
979    /// Load a single `.wasm` extension.
980    pub fn load(&mut self, path: &Path) -> Result<ExtensionInfo> {
981        let path_display = path.display().to_string();
982        tracing::info!("Loading WASM extension: {}", path_display);
983
984        let wasm_bytes = std::fs::read(path)
985            .with_context(|| format!("Failed to read extension: {}", path_display))?;
986
987        let wasm = extism::Wasm::data(wasm_bytes);
988        // Limit WASM memory to 64 pages (4MB) to prevent unbounded allocation
989        let manifest = extism::Manifest::new([wasm]).with_memory_max(64);
990        let mut plugin = extism::PluginBuilder::new(manifest)
991            .with_wasi(true)
992            .with_functions(Self::host_functions(&self.http_client))
993            .build()
994            .with_context(|| format!("Failed to create Extism plugin from {}", path_display))?;
995
996        // Call init()
997        let info: ExtensionInfo = match plugin.call::<&str, &str>("init", "{}") {
998            Ok(output) => serde_json::from_str(output)
999                .with_context(|| format!("init() returned invalid JSON: {}", output))?,
1000            Err(_) => {
1001                // No init function — derive name from filename
1002                let name = path
1003                    .file_stem()
1004                    .and_then(|s| s.to_str())
1005                    .unwrap_or("unknown")
1006                    .to_string();
1007                ExtensionInfo {
1008                    name,
1009                    version: "0.0.0".to_string(),
1010                    description: String::new(),
1011                    permissions: vec![],
1012                }
1013            }
1014        };
1015        // Surface requested permissions and warn on privileged asks. `exec`
1016        // stays call-time gated by OXI_EXTENSION_EXEC (see host_oxi_exec).
1017        if !info.permissions.is_empty() {
1018            tracing::info!(name = %info.name, perms = ?info.permissions, "extension permissions requested");
1019        }
1020        if info.permissions.iter().any(|p| p == "exec") {
1021            tracing::warn!(
1022                name = %info.name,
1023                "extension requested 'exec' permission — oxi_exec stays disabled unless OXI_EXTENSION_EXEC=1 is set"
1024            );
1025        }
1026
1027        // Set extension context so KV operations during register_tools/register_commands
1028        // are properly namespaced
1029        let ext_name_for_ctx = info.name.clone();
1030        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name_for_ctx));
1031        let tools: Vec<WasmToolDef> = match plugin.call::<&str, &str>("register_tools", "{}") {
1032            Ok(output) => {
1033                let resp: Value = serde_json::from_str(output)
1034                    .with_context(|| format!("register_tools() invalid JSON: {}", output))?;
1035                resp.get("tools")
1036                    .cloned()
1037                    .unwrap_or(Value::Array(vec![]))
1038                    .as_array()
1039                    .map(|arr| {
1040                        arr.iter()
1041                            .filter_map(|v| serde_json::from_value(v.clone()).ok())
1042                            .collect()
1043                    })
1044                    .unwrap_or_default()
1045            }
1046            Err(_) => vec![], // No tools — event-only extension
1047        };
1048
1049        // Call register_commands() — optional
1050        let commands: Vec<WasmCommandDef> =
1051            match plugin.call::<&str, &str>("register_commands", "{}") {
1052                Ok(output) => {
1053                    let resp: Value = serde_json::from_str(output)
1054                        .with_context(|| format!("register_commands() invalid JSON: {}", output))?;
1055                    resp.get("commands")
1056                        .cloned()
1057                        .unwrap_or(Value::Array(vec![]))
1058                        .as_array()
1059                        .map(|arr| {
1060                            arr.iter()
1061                                .filter_map(|v| serde_json::from_value(v.clone()).ok())
1062                                .collect()
1063                        })
1064                        .unwrap_or_default()
1065                }
1066                Err(_) => vec![], // No commands
1067            };
1068
1069        // Clear extension context set during register_tools/register_commands
1070        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1071
1072        let ext_name = info.name.clone();
1073
1074        // Warn on name collision — clean up old extension fully
1075        if self.extensions.contains_key(&ext_name) {
1076            tracing::warn!(
1077                "Extension '{}' already loaded, replacing with '{}'",
1078                ext_name,
1079                path_display
1080            );
1081            // Remove old tool mappings
1082            self.tool_to_ext.retain(|_, v| v != &ext_name);
1083            // Remove old plugin instance
1084            self.plugins.lock().remove(&ext_name);
1085        }
1086
1087        for tool in &tools {
1088            self.tool_to_ext.insert(tool.name.clone(), ext_name.clone());
1089        }
1090
1091        let loaded = LoadedWasmExtension {
1092            info: info.clone(),
1093            tools,
1094            commands,
1095            source_path: path.to_path_buf(),
1096        };
1097
1098        self.extensions.insert(ext_name.clone(), loaded);
1099        self.plugins.lock().insert(ext_name, plugin);
1100
1101        tracing::info!(
1102            name = %info.name,
1103            version = %info.version,
1104            tools = self.tool_to_ext.len(),
1105            "WASM extension loaded"
1106        );
1107
1108        Ok(info)
1109    }
1110
1111    /// Load all discovered extensions, collecting errors.
1112    pub fn load_all(&mut self, paths: &[PathBuf]) -> (Vec<ExtensionInfo>, Vec<anyhow::Error>) {
1113        let mut loaded = Vec::new();
1114        let mut errors = Vec::new();
1115
1116        for path in paths {
1117            match self.load(path) {
1118                Ok(info) => loaded.push(info),
1119                Err(e) => {
1120                    tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
1121                    errors.push(e);
1122                }
1123            }
1124        }
1125
1126        (loaded, errors)
1127    }
1128
1129    // ── Execution ──────────────────────────────────────────────────
1130
1131    /// Execute a tool via the WASM extension.
1132    pub fn execute_tool(&self, tool_name: &str, params: Value) -> Result<Value> {
1133        let ext_name = self
1134            .tool_to_ext
1135            .get(tool_name)
1136            .with_context(|| format!("No extension registered for tool: {}", tool_name))?
1137            .clone();
1138
1139        let mut plugins = self.plugins.lock();
1140        let plugin = plugins
1141            .get_mut(&ext_name)
1142            .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1143
1144        let input = serde_json::json!({
1145            "tool": tool_name,
1146            "params": params,
1147        });
1148        let input_str = serde_json::to_string(&input)?;
1149
1150        // Set thread-local extension context so KV host functions can namespace keys
1151        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1152        let call_result = plugin.call("execute_tool", &input_str);
1153        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1154
1155        let output: &str = call_result
1156            .with_context(|| format!("execute_tool('{}') failed in '{}'", tool_name, ext_name))?;
1157
1158        let result: Value = serde_json::from_str(output)
1159            .with_context(|| format!("execute_tool() returned invalid JSON: {}", output))?;
1160
1161        Ok(result)
1162    }
1163
1164    // ── Accessors ──────────────────────────────────────────────────
1165
1166    /// Get all tool definitions from all loaded extensions.
1167    pub fn all_tool_defs(&self) -> Vec<&WasmToolDef> {
1168        self.extensions
1169            .values()
1170            .flat_map(|e| e.tools.iter())
1171            .collect()
1172    }
1173
    /// Check if a tool name belongs to a WASM extension.
    ///
    /// Consults only the tool → extension index that `load` fills, so it
    /// answers for tools of currently loaded extensions.
    pub fn is_wasm_tool(&self, tool_name: &str) -> bool {
        self.tool_to_ext.contains_key(tool_name)
    }
1178
1179    /// List loaded extension names.
1180    pub fn extension_names(&self) -> impl Iterator<Item = &str> {
1181        self.extensions.keys().map(|s| s.as_str())
1182    }
1183
1184    /// Get extension info by name.
1185    pub fn get_info(&self, name: &str) -> Option<&ExtensionInfo> {
1186        self.extensions.get(name).map(|e| &e.info)
1187    }
1188
    /// Number of loaded extensions (distinct extension names).
    pub fn len(&self) -> usize {
        self.extensions.len()
    }
1193
    /// Returns `true` when no extensions are loaded.
    pub fn is_empty(&self) -> bool {
        self.extensions.is_empty()
    }
1198
1199    // ── Commands ──────────────────────────────────────────────────
1200
1201    /// Get all command definitions from all loaded extensions.
1202    pub fn all_command_defs(&self) -> Vec<(&str, &WasmCommandDef)> {
1203        let mut cmds = Vec::new();
1204        for ext in self.extensions.values() {
1205            for cmd in &ext.commands {
1206                cmds.push((ext.info.name.as_str(), cmd));
1207            }
1208        }
1209        cmds
1210    }
1211
1212    /// Execute a command via the WASM extension.
1213    /// Returns the output text to display to the user.
1214    pub fn execute_command(&self, command_name: &str, args: &str) -> Result<String> {
1215        // Find which extension owns this command
1216        let ext_name = self
1217            .extensions
1218            .iter()
1219            .find(|(_, ext)| ext.commands.iter().any(|c| c.name == command_name))
1220            .map(|(name, _)| name.clone())
1221            .with_context(|| format!("No extension registered for command: /{}", command_name))?;
1222
1223        let mut plugins = self.plugins.lock();
1224        let plugin = plugins
1225            .get_mut(&ext_name)
1226            .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1227
1228        let input = serde_json::json!({
1229            "command": command_name,
1230            "args": args,
1231        });
1232        let input_str = serde_json::to_string(&input)?;
1233
1234        let output: &str = {
1235            CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1236            let result = plugin.call("execute_command", &input_str);
1237            CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1238            result
1239        }
1240        .with_context(|| {
1241            format!(
1242                "execute_command('/{}') failed in '{}'",
1243                command_name, ext_name
1244            )
1245        })?;
1246
1247        // Parse response — extension returns {"output": "..."} or plain string
1248        let result: Value =
1249            serde_json::from_str(output).unwrap_or_else(|_| serde_json::json!({"output": output}));
1250
1251        Ok(result
1252            .get("output")
1253            .and_then(|v| v.as_str())
1254            .unwrap_or(output)
1255            .to_string())
1256    }
1257}
1258
1259// ── Tests ────────────────────────────────────────────────────────────
1260
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty project directory contributes no extensions.
    #[test]
    fn test_discover_empty_dir() {
        let dir = tempfile::tempdir().unwrap();

        // `discover` also scans `~/.oxi/extensions`, so a machine with
        // installed extensions legitimately returns paths here. Assert only
        // on what the temporary project directory contributes.
        let paths = WasmExtensionManager::discover(dir.path());
        assert!(paths.iter().all(|p| !p.starts_with(dir.path())));

        let mut local = Vec::new();
        WasmExtensionManager::discover_in_dir(dir.path(), &mut local);
        assert!(local.is_empty());
    }

    /// Files under `<cwd>/.oxi/extensions/` are picked up by `discover`.
    #[test]
    fn test_discover_project_local_dir() {
        let dir = tempfile::tempdir().unwrap();
        let ext_dir = dir.path().join(".oxi").join("extensions");
        std::fs::create_dir_all(&ext_dir).unwrap();
        let wasm_path = ext_dir.join("local_ext.wasm");
        std::fs::write(&wasm_path, b"\x00asm").unwrap();

        let paths = WasmExtensionManager::discover(dir.path());
        assert!(paths.contains(&wasm_path));
    }

    /// Only `.wasm` files are collected from a directory.
    #[test]
    fn test_discover_finds_wasm_files() {
        let dir = tempfile::tempdir().unwrap();
        let wasm_path = dir.path().join("test_ext.wasm");
        std::fs::write(&wasm_path, b"\x00asm").unwrap();
        // Create a non-wasm file that should be ignored
        std::fs::write(dir.path().join("readme.txt"), b"hello").unwrap();

        let mut paths = Vec::new();
        WasmExtensionManager::discover_in_dir(dir.path(), &mut paths);
        assert_eq!(paths.len(), 1);
        assert!(paths[0].ends_with("test_ext.wasm"));
    }

    /// `init()` metadata parses, and omitted permissions default to empty.
    #[test]
    fn test_extension_info_parse() {
        let json = r#"{"name":"my_ext","version":"1.0.0","description":"Test"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.name, "my_ext");
        assert_eq!(info.version, "1.0.0");
        assert_eq!(info.description, "Test");
        assert!(info.permissions.is_empty());
    }

    /// A tool definition parses from its JSON form.
    #[test]
    fn test_tool_def_parse() {
        let json = r#"{"name":"search","description":"Search","schema":{"type":"object"}}"#;
        let tool: WasmToolDef = serde_json::from_str(json).unwrap();
        assert_eq!(tool.name, "search");
    }

    /// A fresh manager holds no extensions.
    #[test]
    fn test_manager_new_is_empty() {
        let mgr = WasmExtensionManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    /// Unknown tool names are not reported as WASM tools.
    #[test]
    fn test_is_wasm_tool_false() {
        let mgr = WasmExtensionManager::new();
        assert!(!mgr.is_wasm_tool("anything"));
    }

    /// A missing `description` falls back to the empty string.
    #[test]
    fn test_extension_info_default_description() {
        let json = r#"{"name":"test","version":"0.1"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.description, "");
    }

    /// Loopback, private, link-local and unspecified addresses are rejected.
    #[test]
    fn test_ssrf_blocks_localhost() {
        assert!(validate_url("http://localhost/admin").is_err());
        assert!(validate_url("http://127.0.0.1/secret").is_err());
        assert!(validate_url("http://10.0.0.1/internal").is_err());
        assert!(validate_url("http://192.168.1.1/router").is_err());
        assert!(validate_url("http://172.16.0.1/corp").is_err());
        assert!(validate_url("http://169.254.169.254/metadata").is_err());
        // Bracketed IPv6 loopback
        assert!(validate_url("http://[::1]/ipv6").is_err());
        // Unspecified IPv4 address
        assert!(validate_url("http://0.0.0.0/admin").is_err());
    }

    /// Ordinary public hosts pass validation.
    #[test]
    fn test_ssrf_allows_public() {
        assert!(validate_url("https://api.github.com/repos/test").is_ok());
        assert!(validate_url("https://example.com/api").is_ok());
        assert!(validate_url("https://search.brave.com/api/search?q=test").is_ok());
    }

    /// Only 172.16.0.0–172.31.255.255 is treated as private.
    #[test]
    fn test_ssrf_172_range() {
        assert!(validate_url("http://172.16.0.1/test").is_err());
        assert!(validate_url("http://172.31.255.255/test").is_err());
        assert!(validate_url("http://172.15.0.1/test").is_ok());
        assert!(validate_url("http://172.32.0.1/test").is_ok());
    }
}