Skip to main content

oxi/extensions/
wasm.rs

1//! WASM-based extension system powered by Extism.
2//!
3//! Loads `.wasm` extension files from `~/.oxi/extensions/` and project-local
4//! `.oxi/extensions/`. Each extension exports well-known functions (`init`,
5//! `register_tools`, `execute_tool`) called via Extism's JSON-in/JSON-out
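//! protocol. Extensions run inside a WASM sandbox with no direct host access;
//! anything outside the sandbox goes through the `oxi_*` host functions
//! defined in this file (HTTP requests, logging, file read/write, process
//! execution, environment lookup, and a key-value store).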
8
9use anyhow::{Context, Result};
10use extism::{CurrentPlugin, Function, UserData, Val, PTR};
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::cell::RefCell;
15use std::collections::HashMap;
16use std::io::Read as _;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20
21// ── Types ────────────────────────────────────────────────────────────
22
/// Metadata returned by an extension's `init()` function.
///
/// Deserialized from the guest's JSON; `description` and `permissions`
/// default to empty when the guest omits them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
    /// Extension name as reported by the guest.
    pub name: String,
    /// Version string as reported by the guest (format is not validated here).
    pub version: String,
    /// Human-readable description; empty when omitted.
    #[serde(default)]
    pub description: String,
    /// Permissions requested by the extension.
    /// Supported: "fs_read", "fs_write", "exec", "env", "network"
    ///
    /// NOTE(review): none of the `oxi_*` host functions in this section check
    /// these permissions before acting — confirm enforcement happens elsewhere.
    #[serde(default)]
    #[allow(dead_code)]
    pub permissions: Vec<String>,
}
36
/// A tool definition returned by `register_tools()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmToolDef {
    /// Tool name; the manager maps it back to the owning extension.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Arbitrary JSON supplied by the guest — presumably a JSON Schema for
    /// the tool's arguments (TODO confirm against the tool registry).
    pub schema: Value,
}
44
/// A command definition returned by `register_commands()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmCommandDef {
    /// Command name as reported by the guest.
    pub name: String,
    /// Human-readable description of the command.
    pub description: String,
}
51
/// Result of loading a WASM extension.
///
/// Bundles the guest-reported metadata and definitions with the `.wasm`
/// file they came from.
#[derive(Debug)]
pub struct LoadedWasmExtension {
    /// Metadata from the extension's `init()`.
    pub info: ExtensionInfo,
    /// Tools from `register_tools()`.
    pub tools: Vec<WasmToolDef>,
    /// Commands from `register_commands()`.
    pub commands: Vec<WasmCommandDef>,
    /// Path of the `.wasm` file the extension was loaded from.
    pub source_path: PathBuf,
}
60
61// ── Host Functions ────────────────────────────────────────────────────
62
63/// Host function: `oxi_http_request(request_json) -> response_json`
64///
65/// Request JSON: `{"url": "...", "method": "GET", "headers": {...}, "body": "..."}`
66/// Response JSON: `{"status": 200, "headers": {...}, "body": "..."}`
67fn host_oxi_http_request(
68    plugin: &mut CurrentPlugin,
69    inputs: &[Val],
70    outputs: &mut [Val],
71    user_data: UserData<Arc<reqwest::blocking::Client>>,
72) -> Result<(), extism::Error> {
73    // We use anyhow internally, then convert at the boundary
74    let result: anyhow::Result<()> = (|| {
75        let input_json: String = plugin.memory_get_val(&inputs[0])?;
76
77        #[derive(Deserialize)]
78        struct HttpReq {
79            url: String,
80            #[serde(default)]
81            method: String,
82            #[serde(default)]
83            headers: HashMap<String, String>,
84            #[serde(default)]
85            body: Option<String>,
86        }
87
88        let req: HttpReq =
89            serde_json::from_str(&input_json).context("oxi_http_request: invalid request JSON")?;
90
91        let method = if req.method.is_empty() {
92            "GET"
93        } else {
94            &req.method
95        };
96
97        // SSRF protection: block internal/private network addresses
98        if let Err(e) = validate_url(&req.url) {
99            anyhow::bail!("oxi_http_request: {}", e);
100        }
101
102        // UserData::get() returns Arc<Mutex<T>>
103        let client_arc = user_data.get()?;
104        let client = client_arc.lock().expect("wasm client lock poisoned");
105
106        let method = match method.to_uppercase().as_str() {
107            "GET" => reqwest::Method::GET,
108            "POST" => reqwest::Method::POST,
109            "PUT" => reqwest::Method::PUT,
110            "DELETE" => reqwest::Method::DELETE,
111            "PATCH" => reqwest::Method::PATCH,
112            "HEAD" => reqwest::Method::HEAD,
113            other => anyhow::bail!("oxi_http_request: unsupported method '{}'", other),
114        };
115
116        let mut rb = client.request(method, &req.url);
117        for (k, v) in &req.headers {
118            rb = rb.header(k, v);
119        }
120        if let Some(body) = &req.body {
121            rb = rb.body(body.clone());
122        }
123
124        // Execute HTTP request (blocking — called from spawn_blocking in wasm_tool.rs)
125        let resp = rb
126            .send()
127            .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
128        let status = resp.status().as_u16();
129        let resp_headers: HashMap<String, String> = resp
130            .headers()
131            .iter()
132            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
133            .collect();
134        // Limit response body to 1MB to prevent memory exhaustion
135        let resp_body = {
136            let max_body = 1024 * 1024; // 1MB
137            let body_bytes = resp
138                .bytes()
139                .map_err(|e| anyhow::anyhow!("Failed to read response: {}", e))?;
140            if body_bytes.len() > max_body {
141                tracing::warn!(
142                    "HTTP response truncated: {} bytes > {} limit",
143                    body_bytes.len(),
144                    max_body
145                );
146                String::from_utf8_lossy(&body_bytes[..max_body]).to_string()
147            } else {
148                String::from_utf8_lossy(&body_bytes).to_string()
149            }
150        };
151
152        let response = serde_json::json!({
153            "status": status,
154            "headers": resp_headers,
155            "body": resp_body,
156        });
157
158        let output = serde_json::to_string(&response)?;
159        let handle = plugin.memory_new(&output)?;
160        if !outputs.is_empty() {
161            outputs[0] = plugin.memory_to_val(handle);
162        }
163        Ok(())
164    })();
165
166    result
167}
168
169/// Host function: `oxi_log(message)` — logs a debug message from WASM.
170fn host_oxi_log(
171    plugin: &mut CurrentPlugin,
172    inputs: &[Val],
173    _outputs: &mut [Val],
174    _user_data: UserData<()>,
175) -> Result<(), extism::Error> {
176    let message: String = plugin.memory_get_val(&inputs[0])?;
177    tracing::debug!("[WASM] {}", message);
178    Ok(())
179}
180
181// ── File I/O Host Functions ─────────────────────────────────────────
182
183/// Host function: `oxi_read_file(path_json) → result_json`
184///
185/// Input: `{"path": "/path/to/file", "offset": 0, "limit": 2000}`
186/// Output: `{"success": true, "content": "...", "truncated": false, "bytes": 1234}` or
187///         `{"success": false, "error": "..."}`
188fn host_oxi_read_file(
189    plugin: &mut CurrentPlugin,
190    inputs: &[Val],
191    outputs: &mut [Val],
192    _user_data: UserData<()>,
193) -> Result<(), extism::Error> {
194    let result: anyhow::Result<()> = (|| {
195        let input_json: String = plugin.memory_get_val(&inputs[0])?;
196
197        #[derive(Deserialize)]
198        struct ReadReq {
199            path: String,
200            #[serde(default)]
201            offset: Option<usize>,
202            #[serde(default = "default_limit")]
203            limit: usize,
204        }
205        fn default_limit() -> usize {
206            2000
207        }
208
209        let req: ReadReq =
210            serde_json::from_str(&input_json).context("oxi_read_file: invalid request JSON")?;
211
212        // Validate path is not outside cwd
213        validate_path_allowed(&req.path)?;
214
215        let metadata = std::fs::metadata(&req.path);
216        match metadata {
217            Ok(m) => {
218                let max_bytes = 50 * 1024; // 50KB limit
219                let file_size = m.len() as usize;
220
221                let content = std::fs::read_to_string(&req.path)
222                    .map_err(|e| anyhow::anyhow!("Failed to read file: {}", e))?;
223
224                let lines: Vec<&str> = content.lines().collect();
225                let total_lines = lines.len();
226
227                let offset = req.offset.unwrap_or(0).min(total_lines);
228                let end = (offset + req.limit).min(total_lines);
229                let selected: Vec<&str> = lines[offset..end].to_vec();
230                let mut result = selected.join("\n");
231
232                let truncated = result.len() > max_bytes;
233                if truncated {
234                    result = result.chars().take(max_bytes).collect();
235                }
236
237                let response = serde_json::json!({
238                    "success": true,
239                    "content": result,
240                    "truncated": truncated || end < total_lines,
241                    "bytes": file_size,
242                    "total_lines": total_lines,
243                    "shown_lines": end - offset,
244                });
245                let output = serde_json::to_string(&response)?;
246                let handle = plugin.memory_new(&output)?;
247                if !outputs.is_empty() {
248                    outputs[0] = plugin.memory_to_val(handle);
249                }
250            }
251            Err(e) => {
252                let response = serde_json::json!({
253                    "success": false,
254                    "error": format!("File not found: {}", e),
255                });
256                let output = serde_json::to_string(&response)?;
257                let handle = plugin.memory_new(&output)?;
258                if !outputs.is_empty() {
259                    outputs[0] = plugin.memory_to_val(handle);
260                }
261            }
262        }
263        Ok(())
264    })();
265    result
266}
267
268/// Host function: `oxi_write_file(path_json) → result_json`
269///
270/// Input: `{"path": "/path/to/file", "content": "...", "create_dirs": true}`
271/// Output: `{"success": true, "bytes_written": 1234}` or `{"success": false, "error": "..."}`
272fn host_oxi_write_file(
273    plugin: &mut CurrentPlugin,
274    inputs: &[Val],
275    outputs: &mut [Val],
276    _user_data: UserData<()>,
277) -> Result<(), extism::Error> {
278    let result: anyhow::Result<()> = (|| {
279        let input_json: String = plugin.memory_get_val(&inputs[0])?;
280
281        #[derive(Deserialize)]
282        struct WriteReq {
283            path: String,
284            content: String,
285            #[serde(default = "default_true")]
286            create_dirs: bool,
287        }
288        fn default_true() -> bool {
289            true
290        }
291
292        let req: WriteReq =
293            serde_json::from_str(&input_json).context("oxi_write_file: invalid request JSON")?;
294
295        validate_path_allowed(&req.path)?;
296
297        if req.create_dirs {
298            if let Some(parent) = std::path::Path::new(&req.path).parent() {
299                std::fs::create_dir_all(parent)
300                    .map_err(|e| anyhow::anyhow!("Failed to create directories: {}", e))?;
301            }
302        }
303
304        let bytes = req.content.len();
305        std::fs::write(&req.path, &req.content)
306            .map_err(|e| anyhow::anyhow!("Failed to write file: {}", e))?;
307
308        let response = serde_json::json!({
309            "success": true,
310            "bytes_written": bytes,
311        });
312        let output = serde_json::to_string(&response)?;
313        let handle = plugin.memory_new(&output)?;
314        if !outputs.is_empty() {
315            outputs[0] = plugin.memory_to_val(handle);
316        }
317        Ok(())
318    })();
319    result
320}
321
322/// Host function: `oxi_exec(exec_json) → result_json`
323///
324/// Input: `{"command": "git status", "args": [], "cwd": ".", "timeout": 30}`
325/// Output: `{"success": true, "stdout": "...", "stderr": "...", "exit_code": 0}` or error.
326fn host_oxi_exec(
327    plugin: &mut CurrentPlugin,
328    inputs: &[Val],
329    outputs: &mut [Val],
330    _user_data: UserData<()>,
331) -> Result<(), extism::Error> {
332    let result: anyhow::Result<()> = (|| {
333        let input_json: String = plugin.memory_get_val(&inputs[0])?;
334
335        #[derive(Deserialize)]
336        struct ExecReq {
337            command: String,
338            #[serde(default)]
339            args: Vec<String>,
340            #[serde(default)]
341            cwd: Option<String>,
342            #[serde(default = "default_timeout")]
343            timeout: u64,
344        }
345        fn default_timeout() -> u64 {
346            30
347        }
348
349        let req: ExecReq =
350            serde_json::from_str(&input_json).context("oxi_exec: invalid request JSON")?;
351
352        let cwd = req.cwd.as_deref().unwrap_or(".");
353
354        // Build full command string for deny-list checking (command + args)
355        let full_cmd = if req.args.is_empty() {
356            req.command.clone()
357        } else {
358            format!("{} {}", req.command, req.args.join(" "))
359        };
360
361        // Block dangerous commands — deny-list (checks combined command+args)
362        let blocked_patterns = [
363            "rm -rf /",
364            "rm -rf /*",
365            "mkfs",
366            "dd if=",
367            "format ",
368            ":(){ :|:& };:",
369            "chmod 777 /",
370            "chown root",
371            "> /etc/",
372            "> /boot/",
373            "> /dev/",
374            "dd of=/dev/",
375            "mv / /",
376        ];
377        for blocked in &blocked_patterns {
378            if full_cmd.contains(blocked) {
379                anyhow::bail!("oxi_exec: blocked dangerous command pattern");
380            }
381        }
382
383        // Block obvious privilege escalation
384        let cmd_lower = req.command.to_lowercase();
385        if cmd_lower == "sudo"
386            || cmd_lower == "su"
387            || cmd_lower == "doas"
388            || cmd_lower.starts_with("sudo ")
389            || cmd_lower.starts_with("su ")
390            || cmd_lower.starts_with("doas ")
391        {
392            anyhow::bail!("oxi_exec: privilege escalation commands are blocked");
393        }
394
395        // Clamp timeout to 1-30 seconds to prevent abuse
396        let timeout_ms = req.timeout.clamp(1000, 30000);
397        let timeout_dur = Duration::from_millis(timeout_ms);
398
399        // Spawn child process with piped stdout/stderr
400        let mut child = match std::process::Command::new(&req.command)
401            .args(&req.args)
402            .current_dir(cwd)
403            .stdout(std::process::Stdio::piped())
404            .stderr(std::process::Stdio::piped())
405            .spawn()
406        {
407            Ok(c) => c,
408            Err(e) => {
409                let response = serde_json::json!({
410                    "success": false,
411                    "error": format!("Failed to execute: {}", e),
412                    "exit_code": -1,
413                });
414                let out = serde_json::to_string(&response)?;
415                let handle = plugin.memory_new(&out)?;
416                if !outputs.is_empty() {
417                    outputs[0] = plugin.memory_to_val(handle);
418                }
419                return Ok(());
420            }
421        };
422
423        // Poll for completion with timeout enforcement
424        let start = Instant::now();
425        let mut timed_out = false;
426        let mut exit_status: Option<std::process::ExitStatus> = None;
427
428        loop {
429            match child.try_wait() {
430                Ok(Some(status)) => {
431                    exit_status = Some(status);
432                    break;
433                }
434                Ok(None) => {
435                    if start.elapsed() >= timeout_dur {
436                        // Timeout reached — kill the process and clean up
437                        tracing::warn!(
438                            "oxi_exec: command '{}' timed out after {}ms",
439                            req.command,
440                            timeout_ms
441                        );
442                        let _ = child.kill();
443                        let _ = child.wait(); // reap zombie
444                        timed_out = true;
445                        break;
446                    }
447                    std::thread::sleep(Duration::from_millis(50));
448                }
449                Err(_) => {
450                    // try_wait failed — fall back to blocking wait
451                    match child.wait() {
452                        Ok(status) => {
453                            exit_status = Some(status);
454                        }
455                        Err(_) => {
456                            timed_out = true;
457                        }
458                    }
459                    break;
460                }
461            }
462        }
463
464        // Collect stdout/stderr from the child pipes (read what was buffered)
465        let mut stdout_buf = Vec::new();
466        let mut stderr_buf = Vec::new();
467        if let Some(mut out) = child.stdout.take() {
468            let _ = out.read_to_end(&mut stdout_buf);
469        }
470        if let Some(mut err) = child.stderr.take() {
471            let _ = err.read_to_end(&mut stderr_buf);
472        }
473
474        let stdout = String::from_utf8_lossy(&stdout_buf);
475        let stderr = String::from_utf8_lossy(&stderr_buf);
476        let max_output = 50 * 1024; // 50KB
477        let stdout_truncated = stdout.len() > max_output;
478        let stderr_truncated = stderr.len() > max_output;
479        let stdout_str: String = if stdout_truncated {
480            stdout.chars().take(max_output).collect()
481        } else {
482            stdout.to_string()
483        };
484        let stderr_str: String = if stderr_truncated {
485            stderr.chars().take(max_output).collect()
486        } else {
487            stderr.to_string()
488        };
489
490        let response = serde_json::json!({
491            "success": !timed_out && exit_status.map(|s| s.success()).unwrap_or(false),
492            "stdout": stdout_str,
493            "stderr": stderr_str,
494            "exit_code": if timed_out { -2 } else { exit_status.and_then(|s| s.code()).unwrap_or(-1) },
495            "stdout_truncated": stdout_truncated,
496            "stderr_truncated": stderr_truncated,
497            "timed_out": timed_out,
498        });
499        let out = serde_json::to_string(&response)?;
500        let handle = plugin.memory_new(&out)?;
501        if !outputs.is_empty() {
502            outputs[0] = plugin.memory_to_val(handle);
503        }
504        Ok(())
505    })();
506    result
507}
508
509/// Host function: `oxi_get_env(key_json) → result_json`
510///
511/// Input: `{"key": "HOME"}`
512/// Output: `{"success": true, "value": "/home/user"}` or `{"success": false, "error": "not found"}`
513fn host_oxi_get_env(
514    plugin: &mut CurrentPlugin,
515    inputs: &[Val],
516    outputs: &mut [Val],
517    _user_data: UserData<()>,
518) -> Result<(), extism::Error> {
519    let result: anyhow::Result<()> = (|| {
520        let input_json: String = plugin.memory_get_val(&inputs[0])?;
521
522        #[derive(Deserialize)]
523        struct EnvReq {
524            key: String,
525        }
526
527        let req: EnvReq =
528            serde_json::from_str(&input_json).context("oxi_get_env: invalid request JSON")?;
529
530        // Block sensitive env vars
531        let blocked_keys = ["AWS_SECRET", "PRIVATE_KEY", "PASSWORD", "TOKEN", "SECRET"];
532        let key_upper = req.key.to_uppercase();
533        for blocked in &blocked_keys {
534            if key_upper.contains(blocked) {
535                anyhow::bail!("oxi_get_env: access to '{}' is blocked", req.key);
536            }
537        }
538
539        let value = std::env::var(&req.key).ok();
540        let response = serde_json::json!({
541            "success": value.is_some(),
542            "value": value.unwrap_or_default(),
543        });
544        let output = serde_json::to_string(&response)?;
545        let handle = plugin.memory_new(&output)?;
546        if !outputs.is_empty() {
547            outputs[0] = plugin.memory_to_val(handle);
548        }
549        Ok(())
550    })();
551    result
552}
553
554// ── KV Store Host Functions ─────────────────────────────────────────
555
556/// Host function: `oxi_kv_get(key_json) → result_json`
557///
558/// Persistent key-value store for extension state.
559/// Keys are namespaced per extension.
560/// Input: `{"key": "my_state"}`
561/// Output: `{"success": true, "value": "..."}` or `{"success": false}`
562fn host_oxi_kv_get(
563    plugin: &mut CurrentPlugin,
564    inputs: &[Val],
565    outputs: &mut [Val],
566    _user_data: UserData<()>,
567) -> Result<(), extism::Error> {
568    let result: anyhow::Result<()> = (|| {
569        let input_json: String = plugin.memory_get_val(&inputs[0])?;
570
571        #[derive(Deserialize)]
572        struct KvReq {
573            key: String,
574        }
575
576        let req: KvReq =
577            serde_json::from_str(&input_json).context("oxi_kv_get: invalid request JSON")?;
578
579        // Namespace the key with the current extension identity
580        let ext_name = current_extension_name();
581        let value = kv_namespaced_get(&ext_name, &req.key);
582        let response = serde_json::json!({
583            "success": value.is_some(),
584            "value": value.unwrap_or_default(),
585        });
586        let output = serde_json::to_string(&response)?;
587        let handle = plugin.memory_new(&output)?;
588        if !outputs.is_empty() {
589            outputs[0] = plugin.memory_to_val(handle);
590        }
591        Ok(())
592    })();
593    result
594}
595
596/// Host function: `oxi_kv_set(set_json)`
597///
598/// Input: `{"key": "my_state", "value": "saved_data"}`
599fn host_oxi_kv_set(
600    plugin: &mut CurrentPlugin,
601    inputs: &[Val],
602    _outputs: &mut [Val],
603    _user_data: UserData<()>,
604) -> Result<(), extism::Error> {
605    let result: anyhow::Result<()> = (|| {
606        let input_json: String = plugin.memory_get_val(&inputs[0])?;
607
608        #[derive(Deserialize)]
609        struct KvSetReq {
610            key: String,
611            value: String,
612        }
613
614        let req: KvSetReq =
615            serde_json::from_str(&input_json).context("oxi_kv_set: invalid request JSON")?;
616
617        // Namespace the key with the current extension identity
618        let ext_name = current_extension_name();
619        kv_namespaced_set(&ext_name, &req.key, &req.value);
620        Ok(())
621    })();
622    result
623}
624
625// ── KV Store Implementation ─────────────────────────────────────────
626
627use std::sync::LazyLock;
628
/// Process-wide, in-memory key/value store backing `oxi_kv_get` / `oxi_kv_set`.
///
/// Entries are keyed by the already-namespaced key built by
/// `kv_namespaced_get` / `kv_namespaced_set`. Contents exist only for the
/// lifetime of the process; nothing in this section persists them to disk,
/// and the number or size of entries is not bounded here.
static KV_STORE: LazyLock<parking_lot::RwLock<HashMap<String, String>>> =
    LazyLock::new(|| parking_lot::RwLock::new(HashMap::new()));
631
// Thread-local tracking the currently executing extension name.
// Read by the KV host functions (via `current_extension_name`) to namespace
// keys. When it is unset they fall back to the shared "__unknown__" namespace.
// Because it is thread-local, it must be set on the same thread that runs the
// plugin call. The intent is that `execute_tool` / `execute_command` / `load`
// set it before invoking plugin calls.
// NOTE(review): `with_extension_context` is marked `#[allow(dead_code)]`;
// confirm those callers really set this, otherwise every extension shares
// the "__unknown__" KV namespace.
thread_local! {
    static CURRENT_EXTENSION: RefCell<Option<String>> = const { RefCell::new(None) };
}
638
639/// Run a closure with the current extension name set in thread-local storage.
640#[allow(dead_code)]
641fn with_extension_context<F, R>(ext_name: &str, f: F) -> R
642where
643    F: FnOnce() -> R,
644{
645    CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.to_string()));
646    let result = f();
647    CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
648    result
649}
650
651/// Get the current extension name from thread-local storage.
652/// Returns `"__unknown__"` if not set (e.g., during `load()` before `init()`).
653fn current_extension_name() -> String {
654    CURRENT_EXTENSION.with(|cell| {
655        cell.borrow()
656            .clone()
657            .unwrap_or_else(|| "__unknown__".to_string())
658    })
659}
660
661fn kv_store_get(key: &str) -> Option<String> {
662    KV_STORE.read().get(key).cloned()
663}
664
665fn kv_store_set(key: &str, value: &str) {
666    KV_STORE.write().insert(key.to_string(), value.to_string());
667}
668
669/// Namespaced KV access — prefixes key with `extension:` to prevent
670/// cross-extension key collision.
671fn kv_namespaced_get(extension: &str, key: &str) -> Option<String> {
672    let namespaced = format!("{}:{}", extension, key);
673    kv_store_get(&namespaced)
674}
675
676fn kv_namespaced_set(extension: &str, key: &str, value: &str) {
677    let namespaced = format!("{}:{}", extension, key);
678    kv_store_set(&namespaced, value);
679}
680
681// ── Path Validation ─────────────────────────────────────────────────
682
683/// Validate that a path is within the allowed directories.
684/// Extensions can only read/write outside protected system paths.
685fn validate_path_allowed(path: &str) -> Result<()> {
686    let p = std::path::Path::new(path);
687
688    // Resolve to absolute, then canonicalize to eliminate ../ and symlinks
689    let abs = if p.is_absolute() {
690        p.to_path_buf()
691    } else {
692        std::env::current_dir().unwrap_or_default().join(p)
693    };
694
695    // Try canonicalize for existing paths; fall back to resolved absolute
696    let resolved = if abs.exists() {
697        abs.canonicalize().unwrap_or(abs)
698    } else {
699        // For new files, resolve parent if it exists
700        if let Some(parent) = abs.parent() {
701            if parent.exists() {
702                let canon_parent = parent
703                    .canonicalize()
704                    .unwrap_or_else(|_| parent.to_path_buf());
705                canon_parent.join(abs.file_name().unwrap_or_default())
706            } else {
707                abs
708            }
709        } else {
710            abs
711        }
712    };
713
714    let abs_str = resolved.to_string_lossy();
715
716    // Block sensitive system paths
717    let blocked_prefixes = [
718        "/etc",
719        "/sys",
720        "/proc",
721        "/dev",
722        "/boot",
723        "/root",
724        "/System",
725        "/Library/System",
726        "/usr/bin",
727        "/usr/sbin",
728        "/bin",
729        "/sbin",
730    ];
731    for prefix in &blocked_prefixes {
732        if abs_str.starts_with(prefix) {
733            anyhow::bail!("Path '{}' is in a protected system directory", path);
734        }
735    }
736
737    // Block hidden sensitive paths in home directory
738    if let Some(home) = dirs::home_dir() {
739        let home_str = home.to_string_lossy();
740        if abs_str.starts_with(&*home_str) {
741            let blocked_home_suffixes = [
742                "/.ssh/",
743                "/.gnupg/",
744                "/.aws/",
745                "/.config/gcloud/",
746                "/.kube/",
747                "/.docker/",
748                "/.npmrc",
749                "/.netrc",
750            ];
751            for suffix in &blocked_home_suffixes {
752                if abs_str.contains(suffix) {
753                    anyhow::bail!("Path '{}' is in a protected directory", path);
754                }
755            }
756        }
757    }
758
759    Ok(())
760}
761
762// ── SSRF Protection ─────────────────────────────────────────────────
763
764/// Validate that a URL does not target internal/private network addresses.
765fn validate_url(url: &str) -> Result<(), String> {
766    let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
767    let host = parsed.host_str().unwrap_or("").to_lowercase();
768
769    // Block private IPs, localhost, link-local, and internal hostnames
770    let blocked = [
771        "localhost",
772        "127.0.0.1",
773        "0.0.0.0",
774        "::1",
775        "[::1]",
776        "169.254.169.254", // cloud metadata
777        "metadata.google.internal",
778    ];
779    for &b in &blocked {
780        if host == b || host.starts_with(b) {
781            return Err(format!("Blocked internal address: {}", host));
782        }
783    }
784
785    // Block private IP ranges (10.x, 172.16-31.x, 192.168.x)
786    if host.starts_with("10.") || host.starts_with("192.168.") || is_172_private(&host) {
787        return Err(format!("Blocked private address: {}", host));
788    }
789
790    Ok(())
791}
792
/// Return `true` when `host` looks like a dotted IPv4 address in
/// `172.16.0.0/12`: it starts with `172.` and its second octet parses as a
/// `u8` in `16..=31`. Only the second octet is inspected.
fn is_172_private(host: &str) -> bool {
    host.strip_prefix("172.")
        .and_then(|rest| rest.split('.').next())
        .and_then(|octet| octet.parse::<u8>().ok())
        .is_some_and(|octet| (16..=31).contains(&octet))
}
808
809// ── Manager ──────────────────────────────────────────────────────────
810
/// Manages WASM extensions: discovery, loading, and tool execution.
///
/// # Thread Safety
/// `plugins` is wrapped in `Arc<Mutex<>>` for exclusive thread-safe access.
/// Host functions perform blocking I/O (HTTP, files, processes). The comment
/// in `host_oxi_http_request` says calls come from `spawn_blocking` in
/// `wasm_tool.rs`, keeping that work off the async runtime.
/// NOTE(review): confirm this holds for every plugin call path.
pub struct WasmExtensionManager {
    /// Loaded extensions keyed by name (presumably `ExtensionInfo::name` — confirm in `load`).
    extensions: HashMap<String, LoadedWasmExtension>,
    /// Raw Extism plugin references — needed for execute_tool calls.
    /// Wrapped in Mutex for exclusive thread-safe access (required because
    /// extism::Plugin's Send+Sync bounds are not publicly documented).
    /// All access goes through `plugins.lock()` which ensures mutual exclusion.
    pub(crate) plugins: Arc<parking_lot::Mutex<HashMap<String, extism::Plugin>>>,
    /// Maps tool name → extension name.
    tool_to_ext: HashMap<String, String>,
    /// HTTP client shared by all extensions for oxi_http_request.
    http_client: Arc<reqwest::blocking::Client>,
    /// Permissions granted to each extension (ext_name → set of permission strings).
    /// NOTE(review): marked dead_code, and the `oxi_*` host functions do not
    /// consult it, so requested permissions do not appear to be enforced yet.
    #[allow(dead_code, unused)]
    permissions: HashMap<String, std::collections::HashSet<String>>,
}
832
833impl Default for WasmExtensionManager {
834    fn default() -> Self {
835        Self::new()
836    }
837}
838
839impl WasmExtensionManager {
840    /// Create an empty manager.
841    pub fn new() -> Self {
842        Self {
843            extensions: HashMap::new(),
844            plugins: Arc::new(Mutex::new(HashMap::new())),
845            tool_to_ext: HashMap::new(),
846            http_client: Arc::new(
847                reqwest::blocking::Client::builder()
848                    .timeout(std::time::Duration::from_secs(30))
849                    .connect_timeout(std::time::Duration::from_secs(10))
850                    .no_proxy() // Prevent proxy-based SSRF
851                    .build()
852                    .expect("Failed to build HTTP client"),
853            ),
854            permissions: HashMap::new(),
855        }
856    }
857
858    /// Create with a custom HTTP client (useful for testing with a mock client).
859    pub fn with_http_client(client: reqwest::blocking::Client) -> Self {
860        Self {
861            extensions: HashMap::new(),
862            plugins: Arc::new(Mutex::new(HashMap::new())),
863            tool_to_ext: HashMap::new(),
864            http_client: Arc::new(client),
865            permissions: HashMap::new(),
866        }
867    }
868
869    // ── Discovery ──────────────────────────────────────────────────
870
871    /// Discover `.wasm` files in standard extension directories.
872    pub fn discover(cwd: &Path) -> Vec<PathBuf> {
873        let mut paths = Vec::new();
874
875        // ~/.oxi/extensions/
876        if let Some(home) = dirs::home_dir() {
877            let dir = home.join(".oxi").join("extensions");
878            if dir.is_dir() {
879                Self::discover_in_dir(&dir, &mut paths);
880            }
881        }
882
883        // .oxi/extensions/ (project-local)
884        let local_dir = cwd.join(".oxi").join("extensions");
885        if local_dir.is_dir() {
886            Self::discover_in_dir(&local_dir, &mut paths);
887        }
888
889        paths.sort();
890        paths.dedup();
891        paths
892    }
893
894    fn discover_in_dir(dir: &Path, out: &mut Vec<PathBuf>) {
895        let Ok(entries) = std::fs::read_dir(dir) else {
896            return;
897        };
898        for entry in entries.flatten() {
899            let path = entry.path();
900            if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("wasm") {
901                out.push(path);
902            }
903        }
904    }
905
906    // ── Loading ────────────────────────────────────────────────────
907
908    /// Build the host functions available to all WASM extensions.
909    fn host_functions(http_client: &Arc<reqwest::blocking::Client>) -> Vec<Function> {
910        let http_fn = Function::new(
911            "oxi_http_request",
912            [PTR],
913            [PTR],
914            UserData::new(http_client.clone()),
915            host_oxi_http_request,
916        );
917
918        let log_fn = Function::new("oxi_log", [PTR], [], UserData::new(()), host_oxi_log);
919
920        let read_fn = Function::new(
921            "oxi_read_file",
922            [PTR],
923            [PTR],
924            UserData::new(()),
925            host_oxi_read_file,
926        );
927
928        let write_fn = Function::new(
929            "oxi_write_file",
930            [PTR],
931            [PTR],
932            UserData::new(()),
933            host_oxi_write_file,
934        );
935
936        let exec_fn = Function::new("oxi_exec", [PTR], [PTR], UserData::new(()), host_oxi_exec);
937
938        let get_env_fn = Function::new(
939            "oxi_get_env",
940            [PTR],
941            [PTR],
942            UserData::new(()),
943            host_oxi_get_env,
944        );
945
946        let kv_get_fn = Function::new(
947            "oxi_kv_get",
948            [PTR],
949            [PTR],
950            UserData::new(()),
951            host_oxi_kv_get,
952        );
953
954        let kv_set_fn = Function::new("oxi_kv_set", [PTR], [], UserData::new(()), host_oxi_kv_set);
955
956        vec![
957            http_fn, log_fn, read_fn, write_fn, exec_fn, get_env_fn, kv_get_fn, kv_set_fn,
958        ]
959    }
960
961    /// Load a single `.wasm` extension.
962    pub fn load(&mut self, path: &Path) -> Result<ExtensionInfo> {
963        let path_display = path.display().to_string();
964        tracing::info!("Loading WASM extension: {}", path_display);
965
966        let wasm_bytes = std::fs::read(path)
967            .with_context(|| format!("Failed to read extension: {}", path_display))?;
968
969        let wasm = extism::Wasm::data(wasm_bytes);
970        // Limit WASM memory to 64 pages (4MB) to prevent unbounded allocation
971        let manifest = extism::Manifest::new([wasm]).with_memory_max(64);
972        let mut plugin = extism::PluginBuilder::new(manifest)
973            .with_wasi(true)
974            .with_functions(Self::host_functions(&self.http_client))
975            .build()
976            .with_context(|| format!("Failed to create Extism plugin from {}", path_display))?;
977
978        // Call init()
979        let info: ExtensionInfo = match plugin.call::<&str, &str>("init", "{}") {
980            Ok(output) => serde_json::from_str(output)
981                .with_context(|| format!("init() returned invalid JSON: {}", output))?,
982            Err(_) => {
983                // No init function — derive name from filename
984                let name = path
985                    .file_stem()
986                    .and_then(|s| s.to_str())
987                    .unwrap_or("unknown")
988                    .to_string();
989                ExtensionInfo {
990                    name,
991                    version: "0.0.0".to_string(),
992                    description: String::new(),
993                    permissions: vec![],
994                }
995            }
996        };
997
998        // Set extension context so KV operations during register_tools/register_commands
999        // are properly namespaced
1000        let ext_name_for_ctx = info.name.clone();
1001        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name_for_ctx));
1002        let tools: Vec<WasmToolDef> = match plugin.call::<&str, &str>("register_tools", "{}") {
1003            Ok(output) => {
1004                let resp: Value = serde_json::from_str(output)
1005                    .with_context(|| format!("register_tools() invalid JSON: {}", output))?;
1006                resp.get("tools")
1007                    .cloned()
1008                    .unwrap_or(Value::Array(vec![]))
1009                    .as_array()
1010                    .map(|arr| {
1011                        arr.iter()
1012                            .filter_map(|v| serde_json::from_value(v.clone()).ok())
1013                            .collect()
1014                    })
1015                    .unwrap_or_default()
1016            }
1017            Err(_) => vec![], // No tools — event-only extension
1018        };
1019
1020        // Call register_commands() — optional
1021        let commands: Vec<WasmCommandDef> =
1022            match plugin.call::<&str, &str>("register_commands", "{}") {
1023                Ok(output) => {
1024                    let resp: Value = serde_json::from_str(output)
1025                        .with_context(|| format!("register_commands() invalid JSON: {}", output))?;
1026                    resp.get("commands")
1027                        .cloned()
1028                        .unwrap_or(Value::Array(vec![]))
1029                        .as_array()
1030                        .map(|arr| {
1031                            arr.iter()
1032                                .filter_map(|v| serde_json::from_value(v.clone()).ok())
1033                                .collect()
1034                        })
1035                        .unwrap_or_default()
1036                }
1037                Err(_) => vec![], // No commands
1038            };
1039
1040        // Clear extension context set during register_tools/register_commands
1041        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1042
1043        let ext_name = info.name.clone();
1044
1045        // Warn on name collision — clean up old extension fully
1046        if self.extensions.contains_key(&ext_name) {
1047            tracing::warn!(
1048                "Extension '{}' already loaded, replacing with '{}'",
1049                ext_name,
1050                path_display
1051            );
1052            // Remove old tool mappings
1053            self.tool_to_ext.retain(|_, v| v != &ext_name);
1054            // Remove old plugin instance
1055            self.plugins.lock().remove(&ext_name);
1056        }
1057
1058        for tool in &tools {
1059            self.tool_to_ext.insert(tool.name.clone(), ext_name.clone());
1060        }
1061
1062        let loaded = LoadedWasmExtension {
1063            info: info.clone(),
1064            tools,
1065            commands,
1066            source_path: path.to_path_buf(),
1067        };
1068
1069        self.extensions.insert(ext_name.clone(), loaded);
1070        self.plugins.lock().insert(ext_name, plugin);
1071
1072        tracing::info!(
1073            name = %info.name,
1074            version = %info.version,
1075            tools = self.tool_to_ext.len(),
1076            "WASM extension loaded"
1077        );
1078
1079        Ok(info)
1080    }
1081
1082    /// Load all discovered extensions, collecting errors.
1083    pub fn load_all(&mut self, paths: &[PathBuf]) -> (Vec<ExtensionInfo>, Vec<anyhow::Error>) {
1084        let mut loaded = Vec::new();
1085        let mut errors = Vec::new();
1086
1087        for path in paths {
1088            match self.load(path) {
1089                Ok(info) => loaded.push(info),
1090                Err(e) => {
1091                    tracing::warn!("Failed to load extension '{}': {}", path.display(), e);
1092                    errors.push(e);
1093                }
1094            }
1095        }
1096
1097        (loaded, errors)
1098    }
1099
1100    // ── Execution ──────────────────────────────────────────────────
1101
1102    /// Execute a tool via the WASM extension.
1103    pub fn execute_tool(&self, tool_name: &str, params: Value) -> Result<Value> {
1104        let ext_name = self
1105            .tool_to_ext
1106            .get(tool_name)
1107            .with_context(|| format!("No extension registered for tool: {}", tool_name))?
1108            .clone();
1109
1110        let mut plugins = self.plugins.lock();
1111        let plugin = plugins
1112            .get_mut(&ext_name)
1113            .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1114
1115        let input = serde_json::json!({
1116            "tool": tool_name,
1117            "params": params,
1118        });
1119        let input_str = serde_json::to_string(&input)?;
1120
1121        // Set thread-local extension context so KV host functions can namespace keys
1122        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1123        let call_result = plugin.call("execute_tool", &input_str);
1124        CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1125
1126        let output: &str = call_result
1127            .with_context(|| format!("execute_tool('{}') failed in '{}'", tool_name, ext_name))?;
1128
1129        let result: Value = serde_json::from_str(output)
1130            .with_context(|| format!("execute_tool() returned invalid JSON: {}", output))?;
1131
1132        Ok(result)
1133    }
1134
1135    // ── Accessors ──────────────────────────────────────────────────
1136
1137    /// Get all tool definitions from all loaded extensions.
1138    pub fn all_tool_defs(&self) -> Vec<&WasmToolDef> {
1139        self.extensions
1140            .values()
1141            .flat_map(|e| e.tools.iter())
1142            .collect()
1143    }
1144
1145    /// Check if a tool name belongs to a WASM extension.
1146    pub fn is_wasm_tool(&self, tool_name: &str) -> bool {
1147        self.tool_to_ext.contains_key(tool_name)
1148    }
1149
1150    /// List loaded extension names.
1151    pub fn extension_names(&self) -> impl Iterator<Item = &str> {
1152        self.extensions.keys().map(|s| s.as_str())
1153    }
1154
1155    /// Get extension info by name.
1156    pub fn get_info(&self, name: &str) -> Option<&ExtensionInfo> {
1157        self.extensions.get(name).map(|e| &e.info)
1158    }
1159
1160    /// Number of loaded extensions.
1161    pub fn len(&self) -> usize {
1162        self.extensions.len()
1163    }
1164
1165    /// Whether any extensions are loaded.
1166    pub fn is_empty(&self) -> bool {
1167        self.extensions.is_empty()
1168    }
1169
1170    // ── Commands ──────────────────────────────────────────────────
1171
1172    /// Get all command definitions from all loaded extensions.
1173    pub fn all_command_defs(&self) -> Vec<(&str, &WasmCommandDef)> {
1174        let mut cmds = Vec::new();
1175        for ext in self.extensions.values() {
1176            for cmd in &ext.commands {
1177                cmds.push((ext.info.name.as_str(), cmd));
1178            }
1179        }
1180        cmds
1181    }
1182
1183    /// Execute a command via the WASM extension.
1184    /// Returns the output text to display to the user.
1185    pub fn execute_command(&self, command_name: &str, args: &str) -> Result<String> {
1186        // Find which extension owns this command
1187        let ext_name = self
1188            .extensions
1189            .iter()
1190            .find(|(_, ext)| ext.commands.iter().any(|c| c.name == command_name))
1191            .map(|(name, _)| name.clone())
1192            .with_context(|| format!("No extension registered for command: /{}", command_name))?;
1193
1194        let mut plugins = self.plugins.lock();
1195        let plugin = plugins
1196            .get_mut(&ext_name)
1197            .with_context(|| format!("Extension '{}' not loaded", ext_name))?;
1198
1199        let input = serde_json::json!({
1200            "command": command_name,
1201            "args": args,
1202        });
1203        let input_str = serde_json::to_string(&input)?;
1204
1205        let output: &str = {
1206            CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = Some(ext_name.clone()));
1207            let result = plugin.call("execute_command", &input_str);
1208            CURRENT_EXTENSION.with(|cell| *cell.borrow_mut() = None);
1209            result
1210        }
1211        .with_context(|| {
1212            format!(
1213                "execute_command('/{}') failed in '{}'",
1214                command_name, ext_name
1215            )
1216        })?;
1217
1218        // Parse response — extension returns {"output": "..."} or plain string
1219        let result: Value =
1220            serde_json::from_str(output).unwrap_or_else(|_| serde_json::json!({"output": output}));
1221
1222        Ok(result
1223            .get("output")
1224            .and_then(|v| v.as_str())
1225            .unwrap_or(output)
1226            .to_string())
1227    }
1228}
1229
1230// ── Tests ────────────────────────────────────────────────────────────
1231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_discover_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        let paths = WasmExtensionManager::discover(dir.path());
        // `discover` also scans `~/.oxi/extensions/`, which may hold real
        // extensions on a developer machine. Only assert on what the (empty)
        // project directory contributes so the test stays hermetic.
        assert!(paths.iter().all(|p| !p.starts_with(dir.path())));
    }

    #[test]
    fn test_discover_finds_project_local_extensions() {
        let dir = tempfile::tempdir().unwrap();
        let ext_dir = dir.path().join(".oxi").join("extensions");
        std::fs::create_dir_all(&ext_dir).unwrap();
        std::fs::write(ext_dir.join("b_ext.wasm"), b"\x00asm").unwrap();
        std::fs::write(ext_dir.join("a_ext.wasm"), b"\x00asm").unwrap();
        std::fs::write(ext_dir.join("notes.txt"), b"ignored").unwrap();

        let local: Vec<PathBuf> = WasmExtensionManager::discover(dir.path())
            .into_iter()
            .filter(|p| p.starts_with(&ext_dir))
            .collect();
        // Results are sorted and non-`.wasm` files are ignored.
        let expected = vec![ext_dir.join("a_ext.wasm"), ext_dir.join("b_ext.wasm")];
        assert_eq!(local, expected);
    }

    #[test]
    fn test_discover_finds_wasm_files() {
        let dir = tempfile::tempdir().unwrap();
        let wasm_path = dir.path().join("test_ext.wasm");
        std::fs::write(&wasm_path, b"\x00asm").unwrap();
        // Create a non-wasm file that should be ignored
        std::fs::write(dir.path().join("readme.txt"), b"hello").unwrap();

        let mut paths = Vec::new();
        WasmExtensionManager::discover_in_dir(dir.path(), &mut paths);
        assert_eq!(paths.len(), 1);
        assert!(paths[0].ends_with("test_ext.wasm"));
    }

    #[test]
    fn test_extension_info_parse() {
        let json = r#"{"name":"my_ext","version":"1.0.0","description":"Test"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.name, "my_ext");
        assert_eq!(info.version, "1.0.0");
    }

    #[test]
    fn test_extension_info_permissions_default_empty() {
        let json = r#"{"name":"my_ext","version":"1.0.0"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert!(info.permissions.is_empty());
    }

    #[test]
    fn test_tool_def_parse() {
        let json = r#"{"name":"search","description":"Search","schema":{"type":"object"}}"#;
        let tool: WasmToolDef = serde_json::from_str(json).unwrap();
        assert_eq!(tool.name, "search");
    }

    #[test]
    fn test_command_def_parse() {
        let json = r#"{"name":"greet","description":"Say hello"}"#;
        let cmd: WasmCommandDef = serde_json::from_str(json).unwrap();
        assert_eq!(cmd.name, "greet");
        assert_eq!(cmd.description, "Say hello");
    }

    #[test]
    fn test_manager_new_is_empty() {
        let mgr = WasmExtensionManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        assert!(mgr.all_tool_defs().is_empty());
        assert!(mgr.all_command_defs().is_empty());
        assert!(mgr.extension_names().next().is_none());
        assert!(mgr.get_info("anything").is_none());
    }

    #[test]
    fn test_is_wasm_tool_false() {
        let mgr = WasmExtensionManager::new();
        assert!(!mgr.is_wasm_tool("anything"));
    }

    #[test]
    fn test_execute_unknown_tool_and_command_fail() {
        let mgr = WasmExtensionManager::new();
        let tool_err = mgr.execute_tool("nope", Value::Null).unwrap_err();
        assert!(tool_err
            .to_string()
            .contains("No extension registered for tool: nope"));
        let cmd_err = mgr.execute_command("nope", "").unwrap_err();
        assert!(cmd_err
            .to_string()
            .contains("No extension registered for command: /nope"));
    }

    #[test]
    fn test_extension_info_default_description() {
        let json = r#"{"name":"test","version":"0.1"}"#;
        let info: ExtensionInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.description, "");
    }

    #[test]
    fn test_ssrf_blocks_localhost() {
        assert!(validate_url("http://localhost/admin").is_err());
        assert!(validate_url("http://127.0.0.1/secret").is_err());
        assert!(validate_url("http://10.0.0.1/internal").is_err());
        assert!(validate_url("http://192.168.1.1/router").is_err());
        assert!(validate_url("http://172.16.0.1/corp").is_err());
        assert!(validate_url("http://169.254.169.254/metadata").is_err());
        assert!(validate_url("http://[::1]/ipv6").is_err());
        // Unspecified IPv4 address
        assert!(validate_url("http://0.0.0.0/admin").is_err());
    }

    #[test]
    fn test_ssrf_allows_public() {
        assert!(validate_url("https://api.github.com/repos/test").is_ok());
        assert!(validate_url("https://example.com/api").is_ok());
        assert!(validate_url("https://search.brave.com/api/search?q=test").is_ok());
    }

    #[test]
    fn test_ssrf_172_range() {
        assert!(validate_url("http://172.16.0.1/test").is_err());
        assert!(validate_url("http://172.31.255.255/test").is_err());
        assert!(validate_url("http://172.15.0.1/test").is_ok());
        assert!(validate_url("http://172.32.0.1/test").is_ok());
    }
}