Skip to main content

fude/
acp.rs

1//! ACP (Agent Client Protocol) client.
2//!
3//! Registered by [`crate::App::with_acp`]. Consumers pass a list of
4//! [`AcpAdapterConfig`] entries describing which ACP agents this app
5//! supports and how to locate their binaries; fude handles the
6//! JSON-RPC framing, the pending-request map, permission prompts, and
7//! sandboxed `fs/read_text_file` / `fs/write_text_file` responses so the
8//! agent cannot escape the user's allow-list.
9
10use std::collections::HashMap;
11use std::path::Path;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14
15use serde_json::Value;
16use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
17use tokio::process::{Child, ChildStdin};
18use tokio::sync::{oneshot, Mutex};
19
20use crate::events::EventEmitter;
21use crate::sandbox::{validate_path, SharedList};
22
23type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, Value>>>>>;
24
/// Describes one ACP agent the application supports.
///
/// Two configs compare equal when both the name and the full list of
/// candidate binary names match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AcpAdapterConfig {
    /// Adapter identifier; [`AcpState::find_adapter`] matches on it exactly.
    pub name: String,
    /// Executable names that [`resolve_acp_bin`] looks for under
    /// `node_modules/.bin` directories.
    pub candidate_bin_names: Vec<String>,
}
30
31pub struct AcpProcess {
32    stdin: Arc<Mutex<ChildStdin>>,
33    pending: PendingMap,
34    next_id: AtomicU64,
35    child: Arc<Mutex<Option<Child>>>,
36}
37
38fn pick_permission_option(params: &Value, allow: bool) -> String {
39    let want_prefixes: &[&str] = if allow {
40        &[
41            "allow_once",
42            "allow-once",
43            "allow_always",
44            "allow-always",
45            "allow",
46        ]
47    } else {
48        &[
49            "reject_once",
50            "reject-once",
51            "reject_always",
52            "reject-always",
53            "reject",
54            "deny",
55        ]
56    };
57    if let Some(options) = params.get("options").and_then(|v| v.as_array()) {
58        for want in want_prefixes {
59            for opt in options {
60                if let Some(k) = opt.get("kind").and_then(|v| v.as_str()) {
61                    if k.eq_ignore_ascii_case(want) {
62                        if let Some(id) = opt.get("optionId").and_then(|v| v.as_str()) {
63                            return id.to_string();
64                        }
65                    }
66                }
67            }
68        }
69        for want in want_prefixes {
70            for opt in options {
71                if let Some(id) = opt.get("optionId").and_then(|v| v.as_str()) {
72                    if id
73                        .to_ascii_lowercase()
74                        .starts_with(&want.to_ascii_lowercase())
75                    {
76                        return id.to_string();
77                    }
78                }
79            }
80        }
81    }
82    if allow {
83        "allow-once".to_string()
84    } else {
85        "reject-once".to_string()
86    }
87}
88
89fn check_path_access(
90    path: &str,
91    allowed_paths: &SharedList,
92    allowed_dirs: &SharedList,
93) -> Result<std::path::PathBuf, String> {
94    validate_path(path)?;
95    let canonical = std::fs::canonicalize(path).map_err(|_| "Invalid file path".to_string())?;
96    let canonical_str = canonical.to_string_lossy().to_string();
97
98    {
99        let guard = allowed_paths.lock().unwrap_or_else(|e| e.into_inner());
100        if guard.contains(&canonical_str) {
101            return Ok(canonical);
102        }
103    }
104    {
105        let guard = allowed_dirs.lock().unwrap_or_else(|e| e.into_inner());
106        if guard.iter().any(|d| canonical.starts_with(Path::new(d))) {
107            return Ok(canonical);
108        }
109    }
110    Err("Access denied: path not whitelisted".to_string())
111}
112
113fn resolve_write_target(
114    path: &str,
115    allowed_paths: &SharedList,
116    allowed_dirs: &SharedList,
117) -> Result<(std::path::PathBuf, std::path::PathBuf), String> {
118    validate_path(path)?;
119    if let Ok(canonical) = check_path_access(path, allowed_paths, allowed_dirs) {
120        let parent = canonical
121            .parent()
122            .ok_or_else(|| "Invalid file path".to_string())?
123            .to_path_buf();
124        return Ok((canonical, parent));
125    }
126    let p = Path::new(path);
127    let parent = p.parent().ok_or_else(|| "Invalid file path".to_string())?;
128    let filename = p
129        .file_name()
130        .ok_or_else(|| "Invalid file name".to_string())?;
131    let canonical_parent =
132        std::fs::canonicalize(parent).map_err(|_| "Parent directory does not exist".to_string())?;
133    let dir_ok = {
134        let guard = allowed_dirs.lock().unwrap_or_else(|e| e.into_inner());
135        guard
136            .iter()
137            .any(|d| canonical_parent.starts_with(Path::new(d)))
138    };
139    if !dir_ok {
140        return Err("Access denied: parent directory not whitelisted".to_string());
141    }
142    Ok((canonical_parent.join(filename), canonical_parent))
143}
144
145fn acp_write_file(
146    path: &str,
147    content: &str,
148    allowed_paths: &SharedList,
149    allowed_dirs: &SharedList,
150) -> Result<(), String> {
151    let (target, canonical_parent) = resolve_write_target(path, allowed_paths, allowed_dirs)?;
152    std::fs::write(&target, content).map_err(|e| format!("Failed to write file: {}", e))?;
153    let final_canonical = match std::fs::canonicalize(&target) {
154        Ok(p) => p,
155        Err(e) => {
156            let _ = std::fs::remove_file(&target);
157            return Err(format!("Cannot resolve written file: {}", e));
158        }
159    };
160    if validate_path(&final_canonical.to_string_lossy()).is_err()
161        || !final_canonical.starts_with(&canonical_parent)
162    {
163        let _ = std::fs::remove_file(&target);
164        return Err("Write rejected: symlink escape detected".to_string());
165    }
166    Ok(())
167}
168
169impl AcpProcess {
170    pub async fn spawn(
171        bin: &str,
172        emitter: EventEmitter,
173        allowed_paths: SharedList,
174        allowed_dirs: SharedList,
175    ) -> Result<Arc<Self>, String> {
176        use tokio::process::Command;
177
178        let mut cmd = Command::new(bin);
179        cmd.stdin(std::process::Stdio::piped());
180        cmd.stdout(std::process::Stdio::piped());
181        cmd.stderr(std::process::Stdio::inherit());
182
183        if let Ok(path) = std::env::var("PATH") {
184            cmd.env("PATH", path);
185        }
186        if let Ok(home) = std::env::var("HOME") {
187            cmd.env("HOME", home);
188        }
189
190        let mut child = cmd
191            .spawn()
192            .map_err(|e| format!("Failed to spawn ACP process: {}", e))?;
193
194        let stdin = child.stdin.take().ok_or("Failed to capture ACP stdin")?;
195        let stdout = child.stdout.take().ok_or("Failed to capture ACP stdout")?;
196
197        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
198        let acp = Arc::new(Self {
199            stdin: Arc::new(Mutex::new(stdin)),
200            pending: pending.clone(),
201            next_id: AtomicU64::new(1),
202            child: Arc::new(Mutex::new(Some(child))),
203        });
204
205        let pending_for_reader = pending.clone();
206        let stdin_for_reader = acp.stdin.clone();
207        let paths_for_reader = allowed_paths.clone();
208        let dirs_for_reader = allowed_dirs.clone();
209        tokio::spawn(async move {
210            let mut reader = BufReader::new(stdout).lines();
211            while let Ok(Some(line)) = reader.next_line().await {
212                if line.is_empty() {
213                    continue;
214                }
215                let msg: Value = match serde_json::from_str(&line) {
216                    Ok(v) => v,
217                    Err(_) => continue,
218                };
219
220                if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) {
221                    if msg.get("result").is_some() || msg.get("error").is_some() {
222                        let mut map = pending_for_reader.lock().await;
223                        if let Some(tx) = map.remove(&id) {
224                            if let Some(err) = msg.get("error") {
225                                let _ = tx.send(Err(err.clone()));
226                            } else {
227                                let _ =
228                                    tx.send(Ok(msg.get("result").cloned().unwrap_or(Value::Null)));
229                            }
230                        }
231                        continue;
232                    }
233                }
234
235                if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
236                    let params = msg.get("params").cloned().unwrap_or(Value::Null);
237                    match method {
238                        "session/update" => {
239                            emitter.emit("acp:session-update", params);
240                        }
241                        "session/request_permission" => {
242                            if let Some(id) = msg.get("id") {
243                                emitter.emit(
244                                    "acp:permission-request",
245                                    serde_json::json!({ "requestId": id, "params": params }),
246                                );
247                                let kind = params
248                                    .pointer("/toolCall/kind")
249                                    .and_then(|v| v.as_str())
250                                    .unwrap_or("");
251                                let safe_kind =
252                                    matches!(kind, "read" | "edit" | "think" | "search");
253                                let option_id = pick_permission_option(&params, safe_kind);
254                                let response = serde_json::json!({
255                                    "jsonrpc": "2.0",
256                                    "id": id,
257                                    "result": {
258                                        "outcome": { "outcome": "selected", "optionId": option_id }
259                                    }
260                                });
261                                let mut data = serde_json::to_string(&response).unwrap();
262                                data.push('\n');
263                                let mut w = stdin_for_reader.lock().await;
264                                let _ = w.write_all(data.as_bytes()).await;
265                                let _ = w.flush().await;
266                            }
267                        }
268                        "fs/read_text_file" => {
269                            if let Some(id) = msg.get("id") {
270                                let path =
271                                    params.get("path").and_then(|p| p.as_str()).unwrap_or("");
272                                let response = match check_path_access(
273                                    path,
274                                    &paths_for_reader,
275                                    &dirs_for_reader,
276                                ) {
277                                    Ok(canonical) => match std::fs::read_to_string(&canonical) {
278                                        Ok(content) => serde_json::json!({
279                                            "jsonrpc": "2.0", "id": id,
280                                            "result": { "content": content }
281                                        }),
282                                        Err(e) => serde_json::json!({
283                                            "jsonrpc": "2.0", "id": id,
284                                            "error": { "code": -32603, "message": format!("Cannot read file: {}", e) }
285                                        }),
286                                    },
287                                    Err(m) => serde_json::json!({
288                                        "jsonrpc": "2.0", "id": id,
289                                        "error": { "code": -32602, "message": m }
290                                    }),
291                                };
292                                let mut data = serde_json::to_string(&response).unwrap();
293                                data.push('\n');
294                                let mut w = stdin_for_reader.lock().await;
295                                let _ = w.write_all(data.as_bytes()).await;
296                                let _ = w.flush().await;
297                            }
298                        }
299                        "fs/write_text_file" => {
300                            if let Some(id) = msg.get("id") {
301                                let path =
302                                    params.get("path").and_then(|p| p.as_str()).unwrap_or("");
303                                let content =
304                                    params.get("content").and_then(|c| c.as_str()).unwrap_or("");
305                                let response = match acp_write_file(
306                                    path,
307                                    content,
308                                    &paths_for_reader,
309                                    &dirs_for_reader,
310                                ) {
311                                    Ok(()) => {
312                                        serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} })
313                                    }
314                                    Err(m) => serde_json::json!({
315                                        "jsonrpc": "2.0", "id": id,
316                                        "error": { "code": -32602, "message": m }
317                                    }),
318                                };
319                                let mut data = serde_json::to_string(&response).unwrap();
320                                data.push('\n');
321                                let mut w = stdin_for_reader.lock().await;
322                                let _ = w.write_all(data.as_bytes()).await;
323                                let _ = w.flush().await;
324                            }
325                        }
326                        _ => {
327                            if let Some(id) = msg.get("id") {
328                                let response = serde_json::json!({
329                                    "jsonrpc": "2.0", "id": id,
330                                    "error": { "code": -32601, "message": format!("Method not found: {}", method) }
331                                });
332                                let mut data = serde_json::to_string(&response).unwrap();
333                                data.push('\n');
334                                let mut w = stdin_for_reader.lock().await;
335                                let _ = w.write_all(data.as_bytes()).await;
336                                let _ = w.flush().await;
337                            }
338                        }
339                    }
340                }
341            }
342        });
343
344        Ok(acp)
345    }
346
347    pub async fn request(&self, method: &str, params: Value) -> Result<Value, String> {
348        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
349        let msg =
350            serde_json::json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params });
351        let (tx, rx) = oneshot::channel();
352        {
353            let mut map = self.pending.lock().await;
354            map.insert(id, tx);
355        }
356        let mut data = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
357        data.push('\n');
358        {
359            let mut w = self.stdin.lock().await;
360            w.write_all(data.as_bytes())
361                .await
362                .map_err(|e| format!("ACP write error: {}", e))?;
363            w.flush()
364                .await
365                .map_err(|e| format!("ACP flush error: {}", e))?;
366        }
367        match rx.await {
368            Ok(Ok(result)) => Ok(result),
369            Ok(Err(err)) => Err(format!(
370                "ACP error: {}",
371                err.get("message")
372                    .and_then(|m| m.as_str())
373                    .unwrap_or("unknown")
374            )),
375            Err(_) => Err("ACP response channel closed".to_string()),
376        }
377    }
378
379    pub async fn notify(&self, method: &str, params: Value) -> Result<(), String> {
380        let msg = serde_json::json!({ "jsonrpc": "2.0", "method": method, "params": params });
381        let mut data = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
382        data.push('\n');
383        let mut w = self.stdin.lock().await;
384        w.write_all(data.as_bytes())
385            .await
386            .map_err(|e| format!("ACP write error: {}", e))?;
387        w.flush()
388            .await
389            .map_err(|e| format!("ACP flush error: {}", e))?;
390        Ok(())
391    }
392
393    pub async fn kill(&self) {
394        let mut guard = self.child.lock().await;
395        if let Some(ref mut child) = *guard {
396            let _ = child.kill().await;
397        }
398        *guard = None;
399    }
400}
401
402pub struct AcpState {
403    pub process: Mutex<Option<Arc<AcpProcess>>>,
404    pub adapter: Mutex<String>,
405    pub adapters: Vec<AcpAdapterConfig>,
406    pub client_name: String,
407    pub client_version: String,
408}
409
410impl AcpState {
411    pub fn new(
412        adapters: Vec<AcpAdapterConfig>,
413        client_name: String,
414        client_version: String,
415    ) -> Self {
416        let default = adapters.first().map(|a| a.name.clone()).unwrap_or_default();
417        Self {
418            process: Mutex::new(None),
419            adapter: Mutex::new(default),
420            adapters,
421            client_name,
422            client_version,
423        }
424    }
425
426    pub fn find_adapter(&self, name: &str) -> Option<&AcpAdapterConfig> {
427        self.adapters.iter().find(|a| a.name == name)
428    }
429}
430
431pub fn resolve_acp_bin(adapter: &AcpAdapterConfig) -> Result<String, String> {
432    let bin_names: Vec<&str> = adapter
433        .candidate_bin_names
434        .iter()
435        .map(|s| s.as_str())
436        .collect();
437    let mut candidates: Vec<std::path::PathBuf> = Vec::new();
438
439    if let Ok(cwd) = std::env::current_dir() {
440        for b in &bin_names {
441            candidates.push(cwd.join("node_modules/.bin").join(b));
442        }
443        if let Some(parent) = cwd.parent() {
444            for b in &bin_names {
445                candidates.push(parent.join("node_modules/.bin").join(b));
446            }
447        }
448    }
449    if let Ok(exe) = std::env::current_exe() {
450        if let Some(dir) = exe.parent() {
451            for b in &bin_names {
452                candidates.push(dir.join("../Resources/node_modules/.bin").join(b));
453                candidates.push(dir.join("node_modules/.bin").join(b));
454            }
455            // Walk up looking for node_modules/.bin so `cargo run` from any
456            // sub-crate finds the workspace-level install.
457            let mut cur: Option<&Path> = Some(dir);
458            while let Some(d) = cur {
459                for b in &bin_names {
460                    candidates.push(d.join("node_modules/.bin").join(b));
461                }
462                cur = d.parent();
463            }
464        }
465    }
466
467    for c in &candidates {
468        if c.exists() {
469            return Ok(c.to_string_lossy().to_string());
470        }
471    }
472    Err(format!("ACP binary not found for {}", adapter.name))
473}
474pub async fn ensure_acp(
475    state: &AcpState,
476    emitter: EventEmitter,
477    allowed_paths: SharedList,
478    allowed_dirs: SharedList,
479) -> Result<Arc<AcpProcess>, String> {
480    let name = state.adapter.lock().await.clone();
481    let mut guard = state.process.lock().await;
482    if let Some(ref acp) = *guard {
483        return Ok(Arc::clone(acp));
484    }
485    let adapter = state
486        .find_adapter(&name)
487        .ok_or_else(|| format!("Unknown ACP adapter: {}", name))?;
488    let bin = resolve_acp_bin(adapter)?;
489    let acp = AcpProcess::spawn(&bin, emitter, allowed_paths, allowed_dirs).await?;
490    *guard = Some(Arc::clone(&acp));
491    Ok(acp)
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- helpers ----------------------------------------------------------

    /// Runs `f` to completion on a fresh current-thread Tokio runtime so the
    /// synchronous tests can await the async mutexes in [`AcpState`].
    fn tokio_block_on<F: std::future::Future<Output = T>, T>(f: F) -> T {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(f)
    }

    /// Two adapters, `claude-code` first.
    fn adapters() -> Vec<AcpAdapterConfig> {
        [("claude-code", "claude-code-acp"), ("codex", "codex-acp")]
            .iter()
            .map(|(name, bin)| AcpAdapterConfig {
                name: (*name).into(),
                candidate_bin_names: vec![(*bin).into()],
            })
            .collect()
    }

    /// Reads the currently selected adapter name out of `state`.
    fn selected_adapter(state: &AcpState) -> String {
        tokio_block_on(async { state.adapter.lock().await.clone() })
    }

    // --- pick_permission_option -------------------------------------------

    #[test]
    fn picks_allow_kind_when_present() {
        let request = json!({
            "options": [
                { "kind": "reject_once", "optionId": "r1" },
                { "kind": "allow_once",  "optionId": "a1" }
            ]
        });
        let chosen = pick_permission_option(&request, true);
        assert_eq!(chosen, "a1");
    }

    #[test]
    fn picks_reject_kind_when_denied() {
        let request = json!({
            "options": [
                { "kind": "allow_once",  "optionId": "a1" },
                { "kind": "reject_once", "optionId": "r1" }
            ]
        });
        let chosen = pick_permission_option(&request, false);
        assert_eq!(chosen, "r1");
    }

    #[test]
    fn falls_back_to_prefix_match_on_option_id() {
        // Without a `kind` field the optionId prefix is the only signal.
        let request = json!({
            "options": [
                { "optionId": "allow-always" },
                { "optionId": "deny-always" }
            ]
        });
        for (allow, expected) in [(true, "allow-always"), (false, "deny-always")] {
            assert_eq!(pick_permission_option(&request, allow), expected);
        }
    }

    #[test]
    fn default_when_no_options() {
        let request = json!({});
        for (allow, expected) in [(true, "allow-once"), (false, "reject-once")] {
            assert_eq!(pick_permission_option(&request, allow), expected);
        }
    }

    #[test]
    fn prefers_hyphen_or_underscore_variant_equally() {
        let request = json!({
            "options": [
                { "kind": "allow-once", "optionId": "a-hyphen" }
            ]
        });
        let chosen = pick_permission_option(&request, true);
        assert_eq!(chosen, "a-hyphen");
    }

    // --- AcpState / AcpAdapterConfig --------------------------------------

    #[test]
    fn acp_state_default_adapter_is_first_entry() {
        let state = AcpState::new(adapters(), "test".into(), "0.0.0".into());
        assert_eq!(selected_adapter(&state), "claude-code");
    }

    #[test]
    fn acp_state_default_is_empty_when_no_adapters() {
        let state = AcpState::new(vec![], "test".into(), "0.0.0".into());
        assert_eq!(selected_adapter(&state), "");
    }

    #[test]
    fn find_adapter_matches_by_name() {
        let state = AcpState::new(adapters(), "c".into(), "0.0.0".into());
        let hit = state.find_adapter("codex").unwrap();
        assert_eq!(hit.name, "codex");
    }

    #[test]
    fn find_adapter_returns_none_for_unknown() {
        let state = AcpState::new(adapters(), "c".into(), "0.0.0".into());
        let miss = state.find_adapter("does-not-exist");
        assert!(miss.is_none());
    }

    // --- resolve_acp_bin --------------------------------------------------

    #[test]
    fn resolve_acp_bin_errors_when_nothing_found() {
        let adapter = AcpAdapterConfig {
            name: "nope".into(),
            candidate_bin_names: vec!["definitely-not-installed-anywhere-xyz".into()],
        };
        let err = resolve_acp_bin(&adapter).unwrap_err();
        assert!(err.contains("not found"), "got: {err}");
    }
}