// ati/core/cli_executor.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4
5use thiserror::Error;
6
7use crate::core::auth_generator::{self, AuthCache, GenContext};
8use crate::core::keyring::Keyring;
9use crate::core::manifest::Provider;
10
11// ---------------------------------------------------------------------------
12// Errors
13// ---------------------------------------------------------------------------
14
15#[derive(Error, Debug)]
16pub enum CliError {
17    #[error("CLI config error: {0}")]
18    Config(String),
19    #[error("Missing keyring key: {0}")]
20    MissingKey(String),
21    #[error("Failed to spawn CLI process: {0}")]
22    Spawn(String),
23    #[error("CLI timed out after {0}s")]
24    Timeout(u64),
25    #[error("CLI exited with code {code}: {stderr}")]
26    NonZeroExit { code: i32, stderr: String },
27    #[error("IO error: {0}")]
28    Io(#[from] std::io::Error),
29    #[error("Credential file error: {0}")]
30    CredentialFile(String),
31    #[error("Captured output '{path}' exceeds ATI_CLI_MAX_OUTPUT_BYTES ({limit} bytes)")]
32    OutputTooLarge { path: String, limit: u64 },
33    #[error("Captured output '{path}' was not produced by the CLI")]
34    OutputMissing { path: String },
35}
36
37// ---------------------------------------------------------------------------
38// CredentialFile — wipe-on-drop temporary credential files
39// ---------------------------------------------------------------------------
40
/// A credential file on disk whose lifetime is tied to this value.
///
/// When `wipe_on_drop` is set, dropping the value overwrites the file's
/// current contents with zero bytes, syncs it, and then deletes it. Every
/// step is best-effort: failures are ignored so that `Drop` never panics.
/// When `wipe_on_drop` is unset, the file is left untouched on disk.
pub struct CredentialFile {
    /// Location of the materialized credential.
    pub path: PathBuf,
    wipe_on_drop: bool,
}

impl Drop for CredentialFile {
    fn drop(&mut self) {
        if !self.wipe_on_drop {
            return;
        }

        // Overwrite the existing bytes in place (no truncation) before unlinking.
        let byte_count = std::fs::metadata(&self.path)
            .map(|meta| meta.len() as usize)
            .unwrap_or(0);
        if byte_count > 0 {
            if let Ok(mut handle) = std::fs::OpenOptions::new().write(true).open(&self.path) {
                use std::io::Write;
                let _ = handle.write_all(&vec![0u8; byte_count]);
                let _ = handle.sync_all();
            }
        }

        let _ = std::fs::remove_file(&self.path);
    }
}
65
66// ---------------------------------------------------------------------------
67// Credential file materialization
68// ---------------------------------------------------------------------------
69
70/// Materialize a keyring secret as a file on disk with 0600 permissions.
71///
72/// In dev mode (`wipe_on_drop = false`), uses a stable path so repeated runs
73/// reuse the same file. In prod mode (`wipe_on_drop = true`), appends a random
74/// suffix so concurrent invocations don't collide.
75pub fn materialize_credential_file(
76    key_name: &str,
77    content: &str,
78    wipe_on_drop: bool,
79    ati_dir: &Path,
80) -> Result<CredentialFile, CliError> {
81    use std::os::unix::fs::OpenOptionsExt;
82
83    let creds_dir = ati_dir.join(".creds");
84    std::fs::create_dir_all(&creds_dir).map_err(|e| {
85        CliError::CredentialFile(format!("failed to create {}: {e}", creds_dir.display()))
86    })?;
87
88    let path = if wipe_on_drop {
89        let suffix: u32 = rand::random();
90        creds_dir.join(format!("{key_name}_{suffix}"))
91    } else {
92        creds_dir.join(key_name)
93    };
94
95    let mut file = std::fs::OpenOptions::new()
96        .write(true)
97        .create(true)
98        .truncate(true)
99        .mode(0o600)
100        .open(&path)
101        .map_err(|e| {
102            CliError::CredentialFile(format!("failed to write {}: {e}", path.display()))
103        })?;
104
105    {
106        use std::io::Write;
107        file.write_all(content.as_bytes()).map_err(|e| {
108            CliError::CredentialFile(format!("failed to write {}: {e}", path.display()))
109        })?;
110        file.sync_all().map_err(|e| {
111            CliError::CredentialFile(format!("failed to sync {}: {e}", path.display()))
112        })?;
113    }
114
115    Ok(CredentialFile { path, wipe_on_drop })
116}
117
118// ---------------------------------------------------------------------------
119// Env resolution
120// ---------------------------------------------------------------------------
121
122/// Resolve `${key_ref}` placeholders in a string from the keyring.
123/// Same logic as `resolve_env_value` in `mcp_client.rs`.
124fn resolve_env_value(value: &str, keyring: &Keyring) -> Result<String, CliError> {
125    let mut result = value.to_string();
126    while let Some(start) = result.find("${") {
127        let rest = &result[start + 2..];
128        if let Some(end) = rest.find('}') {
129            let key_name = &rest[..end];
130            let replacement = keyring
131                .get(key_name)
132                .ok_or_else(|| CliError::MissingKey(key_name.to_string()))?;
133            result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
134        } else {
135            break; // No closing brace
136        }
137    }
138    Ok(result)
139}
140
141/// Resolve a provider's `cli_env` map against the keyring.
142///
143/// Three value forms:
144/// - `@{key_ref}`: materialize the keyring value as a credential file; env value = file path
145/// - `${key_ref}` (possibly inline): substitute from keyring
146/// - plain string: pass through unchanged
147///
148/// Returns the resolved env map and a vec of `CredentialFile`s whose lifetimes
149/// must span the subprocess execution (they are wiped on drop).
150pub fn resolve_cli_env(
151    env_map: &HashMap<String, String>,
152    keyring: &Keyring,
153    wipe_on_drop: bool,
154    ati_dir: &Path,
155) -> Result<(HashMap<String, String>, Vec<CredentialFile>), CliError> {
156    let mut resolved = HashMap::with_capacity(env_map.len());
157    let mut cred_files: Vec<CredentialFile> = Vec::new();
158
159    for (key, value) in env_map {
160        if let Some(key_ref) = value.strip_prefix("@{").and_then(|s| s.strip_suffix('}')) {
161            // File-materialized credential
162            let content = keyring
163                .get(key_ref)
164                .ok_or_else(|| CliError::MissingKey(key_ref.to_string()))?;
165            let cf = materialize_credential_file(key_ref, content, wipe_on_drop, ati_dir)?;
166            resolved.insert(key.clone(), cf.path.to_string_lossy().into_owned());
167            cred_files.push(cf);
168        } else if value.contains("${") {
169            // Inline keyring substitution
170            let val = resolve_env_value(value, keyring)?;
171            resolved.insert(key.clone(), val);
172        } else {
173            // Plain passthrough
174            resolved.insert(key.clone(), value.clone());
175        }
176    }
177
178    Ok((resolved, cred_files))
179}
180
181// ---------------------------------------------------------------------------
182// Output capture — rewrite agent-supplied output paths to proxy temp paths
183// ---------------------------------------------------------------------------
184
/// Default per-file cap on captured CLI output size: 500 MiB.
pub const DEFAULT_CLI_MAX_OUTPUT_BYTES: u64 = 500 * (1 << 20);

/// Per-file byte cap applied to captured outputs.
///
/// Reads `ATI_CLI_MAX_OUTPUT_BYTES`. An unset, non-UTF-8, unparsable, or zero
/// value falls back to [`DEFAULT_CLI_MAX_OUTPUT_BYTES`].
fn cli_max_output_bytes() -> u64 {
    match std::env::var("ATI_CLI_MAX_OUTPUT_BYTES").map(|raw| raw.parse::<u64>()) {
        Ok(Ok(limit)) if limit > 0 => limit,
        _ => DEFAULT_CLI_MAX_OUTPUT_BYTES,
    }
}
195
/// Pairing of an agent-visible output path with the proxy-local temp file the
/// CLI was actually pointed at.
#[derive(Clone, Debug)]
pub struct CapturedOutput {
    /// Sandbox-side path as supplied by the agent; used as the result key.
    pub original_path: String,
    /// Proxy-side temp file substituted into the CLI's argv.
    pub temp_path: PathBuf,
}
205
206/// Apply a provider's output-capture rules to a flat arg list, producing a
207/// rewritten arg list (with temp paths in place of caller paths) plus a
208/// list of captures the proxy must read back after the subprocess exits.
209///
210/// Rules applied in order:
211/// 1. Named flags from `cli_output_args`: any matching `--flag value` pair has
212///    its value rewritten to a temp path. Both `--flag value` and `--flag=value`
213///    forms are supported.
214/// 2. Positional captures from `cli_output_positional`: longest matching
215///    subcommand prefix (after stripping `cli_default_args`) wins; the
216///    configured positional index within the *remaining* args is rewritten.
217pub fn apply_output_captures(
218    provider: &Provider,
219    raw_args: &[String],
220) -> Result<(Vec<String>, Vec<CapturedOutput>), CliError> {
221    let mut rewritten: Vec<String> = raw_args.to_vec();
222    let mut captures: Vec<CapturedOutput> = Vec::new();
223    // Indices of `rewritten` that step 1 substituted with temp paths. Step 2
224    // must never touch these — otherwise an invocation that matches BOTH a
225    // named-flag rule AND a positional rule (e.g. `bb browse screenshot
226    // --output /tmp/x.png` against a manifest with both `cli_output_args`
227    // and `cli_output_positional`) would double-rewrite the same slot,
228    // capturing the temp path as a "second" output the subprocess never wrote.
229    let mut consumed: std::collections::HashSet<usize> = std::collections::HashSet::new();
230
231    // 1. Named flag rewriting
232    if !provider.cli_output_args.is_empty() {
233        let mut i = 0;
234        while i < rewritten.len() {
235            let arg = rewritten[i].clone();
236            // --flag=value form
237            if let Some(eq_idx) = arg.find('=') {
238                let (flag, value) = arg.split_at(eq_idx);
239                if provider
240                    .cli_output_args
241                    .iter()
242                    .any(|f| f.eq_ignore_ascii_case(flag))
243                {
244                    let original = value[1..].to_string();
245                    let temp = make_temp_for(&original)?;
246                    rewritten[i] = format!("{}={}", flag, temp.display());
247                    captures.push(CapturedOutput {
248                        original_path: original,
249                        temp_path: temp,
250                    });
251                    consumed.insert(i);
252                    i += 1;
253                    continue;
254                }
255            }
256            // --flag value form
257            if provider
258                .cli_output_args
259                .iter()
260                .any(|f| f.eq_ignore_ascii_case(&arg))
261                && i + 1 < rewritten.len()
262            {
263                let original = rewritten[i + 1].clone();
264                let temp = make_temp_for(&original)?;
265                rewritten[i + 1] = temp.to_string_lossy().into_owned();
266                captures.push(CapturedOutput {
267                    original_path: original,
268                    temp_path: temp,
269                });
270                consumed.insert(i);
271                consumed.insert(i + 1);
272                i += 2;
273                continue;
274            }
275            i += 1;
276        }
277    }
278
279    // 2. Positional rewriting — match the longest subcommand prefix.
280    // Skip any slot already rewritten by step 1 so we don't double-capture.
281    if !provider.cli_output_positional.is_empty() {
282        // Build the list of non-flag (positional) tokens with their indices,
283        // skipping flags, their inline values, and slots consumed by step 1.
284        let positionals: Vec<(usize, String)> = rewritten
285            .iter()
286            .enumerate()
287            .filter_map(|(idx, s)| {
288                if consumed.contains(&idx) || s.starts_with('-') {
289                    None
290                } else {
291                    Some((idx, s.clone()))
292                }
293            })
294            .collect();
295
296        // Find the longest configured prefix (by token count) that matches the
297        // start of `positionals`.
298        let mut best: Option<(usize, usize)> = None; // (prefix_token_count, output_index)
299        for (prefix, idx) in &provider.cli_output_positional {
300            let prefix_tokens: Vec<&str> = prefix.split_whitespace().collect();
301            if prefix_tokens.is_empty() {
302                continue;
303            }
304            if positionals.len() < prefix_tokens.len() + idx + 1 {
305                continue;
306            }
307            let prefix_matches = prefix_tokens
308                .iter()
309                .enumerate()
310                .all(|(i, tok)| positionals[i].1 == *tok);
311            if !prefix_matches {
312                continue;
313            }
314            let count = prefix_tokens.len();
315            if best.is_none_or(|(c, _)| count > c) {
316                best = Some((count, *idx));
317            }
318        }
319
320        if let Some((prefix_count, output_idx)) = best {
321            let target_positional_idx = prefix_count + output_idx;
322            if let Some((real_idx, original)) = positionals.get(target_positional_idx).cloned() {
323                let temp = make_temp_for(&original)?;
324                rewritten[real_idx] = temp.to_string_lossy().into_owned();
325                captures.push(CapturedOutput {
326                    original_path: original,
327                    temp_path: temp,
328                });
329            }
330        }
331    }
332
333    Ok((rewritten, captures))
334}
335
336/// Build a unique proxy-side temp path that preserves the file extension of
337/// `original_path`, so CLIs that key behavior off extension (e.g. `bb`'s
338/// `--type` defaulting via `.png`/`.jpeg`) still get the right hint.
339fn make_temp_for(original_path: &str) -> Result<PathBuf, CliError> {
340    let ext = Path::new(original_path)
341        .extension()
342        .and_then(|e| e.to_str())
343        .unwrap_or("");
344    let suffix: u64 = rand::random();
345    let pid = std::process::id();
346    let name = if ext.is_empty() {
347        format!(".ati-cli-out-{pid}-{suffix:016x}")
348    } else {
349        format!(".ati-cli-out-{pid}-{suffix:016x}.{ext}")
350    };
351    Ok(std::env::temp_dir().join(name))
352}
353
354/// Read each captured temp path, base64-encode, build the JSON map keyed by
355/// the agent's original paths. Always cleans up temp files (even on size cap
356/// violation), and never silently skips a missing file — agent supplied a
357/// path expecting a result, so missing = error.
358async fn collect_capture_results(
359    captures: &[CapturedOutput],
360) -> Result<HashMap<String, serde_json::Value>, CliError> {
361    use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
362    let max = cli_max_output_bytes();
363    let mut out = HashMap::with_capacity(captures.len());
364
365    for cap in captures {
366        let bytes_result = tokio::fs::read(&cap.temp_path).await;
367        // Cleanup happens regardless of read outcome.
368        let _ = tokio::fs::remove_file(&cap.temp_path).await;
369
370        let bytes = match bytes_result {
371            Ok(b) => b,
372            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
373                return Err(CliError::OutputMissing {
374                    path: cap.original_path.clone(),
375                });
376            }
377            Err(e) => return Err(CliError::Io(e)),
378        };
379
380        if (bytes.len() as u64) > max {
381            return Err(CliError::OutputTooLarge {
382                path: cap.original_path.clone(),
383                limit: max,
384            });
385        }
386
387        let entry = serde_json::json!({
388            "content_base64": B64.encode(&bytes),
389            "size_bytes": bytes.len(),
390            "content_type": guess_content_type(&cap.original_path),
391        });
392        out.insert(cap.original_path.clone(), entry);
393    }
394    Ok(out)
395}
396
397use crate::core::file_manager::guess_content_type;
398
399/// Best-effort cleanup — used when the subprocess errors before we get to the
400/// normal collection path.
401fn discard_captures(captures: &[CapturedOutput]) {
402    for cap in captures {
403        let _ = std::fs::remove_file(&cap.temp_path);
404    }
405}
406
407// ---------------------------------------------------------------------------
408// Execute CLI tool
409// ---------------------------------------------------------------------------
410
411/// Execute a CLI provider tool as a subprocess.
412///
413/// Builds a curated environment (only safe vars from the host + resolved
414/// provider env), spawns the CLI command with the provider's default args
415/// plus the caller's raw args, enforces a timeout, and returns stdout
416/// parsed as JSON (or as a plain string fallback).
417pub async fn execute(
418    provider: &Provider,
419    raw_args: &[String],
420    keyring: &Keyring,
421) -> Result<serde_json::Value, CliError> {
422    execute_with_gen(provider, raw_args, keyring, None, None).await
423}
424
425/// Execute a CLI provider tool, optionally using a dynamic auth generator.
426pub async fn execute_with_gen(
427    provider: &Provider,
428    raw_args: &[String],
429    keyring: &Keyring,
430    gen_ctx: Option<&GenContext>,
431    auth_cache: Option<&AuthCache>,
432) -> Result<serde_json::Value, CliError> {
433    let cli_command = provider
434        .cli_command
435        .as_deref()
436        .ok_or_else(|| CliError::Config("provider missing cli_command".into()))?;
437
438    let timeout_secs = provider.cli_timeout_secs.unwrap_or(120);
439
440    let ati_dir = std::env::var("ATI_DIR")
441        .map(PathBuf::from)
442        .unwrap_or_else(|_| {
443            std::env::var("HOME")
444                .map(PathBuf::from)
445                .unwrap_or_else(|_| PathBuf::from("/tmp"))
446                .join(".ati")
447        });
448
449    let wipe_on_drop = keyring.ephemeral;
450
451    // Resolve provider CLI env vars against keyring.
452    // cred_files must live until after the subprocess exits (Drop does cleanup).
453    let (resolved_env, cred_files) =
454        resolve_cli_env(&provider.cli_env, keyring, wipe_on_drop, &ati_dir)?;
455
456    // Build curated base env from host
457    let mut final_env: HashMap<String, String> = HashMap::new();
458    for var in &["PATH", "HOME", "TMPDIR", "LANG", "USER", "TERM"] {
459        if let Ok(val) = std::env::var(var) {
460            final_env.insert(var.to_string(), val);
461        }
462    }
463    // Layer provider-resolved env on top
464    final_env.extend(resolved_env);
465
466    // If auth_generator is configured, run it and inject into env
467    if let Some(gen) = &provider.auth_generator {
468        let default_ctx = GenContext::default();
469        let ctx = gen_ctx.unwrap_or(&default_ctx);
470        let default_cache = AuthCache::new();
471        let cache = auth_cache.unwrap_or(&default_cache);
472        match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
473            Ok(cred) => {
474                final_env.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
475                for (k, v) in &cred.extra_env {
476                    final_env.insert(k.clone(), v.clone());
477                }
478            }
479            Err(e) => {
480                return Err(CliError::Config(format!("auth_generator failed: {e}")));
481            }
482        }
483    }
484
485    // Apply output-capture rewriting BEFORE the subprocess runs. The agent's
486    // intended output paths are swapped for proxy-side temp paths; the originals
487    // are preserved on `captures` so we can map captured bytes back to them.
488    let (rewritten_args, captures) = apply_output_captures(provider, raw_args)?;
489
490    // Clone values for the blocking closure
491    let command = cli_command.to_string();
492    let default_args = provider.cli_default_args.clone();
493    let extra_args = rewritten_args;
494    let env_snapshot = final_env;
495    let timeout_dur = std::time::Duration::from_secs(timeout_secs);
496
497    // Spawn the subprocess via tokio::process so we get an async-aware child
498    // that we can kill on timeout (unlike spawn_blocking + std::process which
499    // would leave the subprocess running when the timeout fires).
500    let child = tokio::process::Command::new(&command)
501        .args(&default_args)
502        .args(&extra_args)
503        .env_clear()
504        .envs(&env_snapshot)
505        .stdout(Stdio::piped())
506        .stderr(Stdio::piped())
507        .kill_on_drop(true)
508        .spawn()
509        .map_err(|e| {
510            discard_captures(&captures);
511            CliError::Spawn(format!("{command}: {e}"))
512        })?;
513
514    // Apply timeout — kill_on_drop ensures the child is killed if we bail early
515    let output = match tokio::time::timeout(timeout_dur, child.wait_with_output()).await {
516        Ok(Ok(o)) => o,
517        Ok(Err(e)) => {
518            discard_captures(&captures);
519            return Err(CliError::Io(e));
520        }
521        Err(_) => {
522            discard_captures(&captures);
523            return Err(CliError::Timeout(timeout_secs));
524        }
525    };
526
527    // cred_files still alive here — drop explicitly after subprocess exits
528    drop(cred_files);
529
530    if !output.status.success() {
531        discard_captures(&captures);
532        let code = output.status.code().unwrap_or(-1);
533        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
534        return Err(CliError::NonZeroExit { code, stderr });
535    }
536
537    let stdout = String::from_utf8_lossy(&output.stdout);
538
539    // No captures configured → preserve the legacy response shape exactly:
540    // either parsed JSON or the trimmed stdout string.
541    if captures.is_empty() {
542        let value = match serde_json::from_str::<serde_json::Value>(stdout.trim()) {
543            Ok(v) => v,
544            Err(_) => serde_json::Value::String(stdout.trim().to_string()),
545        };
546        return Ok(value);
547    }
548
549    // Captures present → return a structured envelope so the sandbox CLI can
550    // distinguish "stdout text" from "files the agent should write to disk".
551    let outputs = collect_capture_results(&captures).await?;
552    Ok(serde_json::json!({
553        "stdout": stdout.trim().to_string(),
554        "outputs": outputs,
555    }))
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Dev mode writes to the stable `.creds/<key>` path with 0600 permissions.
    #[test]
    fn test_materialize_credential_file_dev_mode() {
        let dir = tempfile::tempdir().unwrap();
        let cred = materialize_credential_file("test_key", "secret123", false, dir.path()).unwrap();

        let expected_path = dir.path().join(".creds/test_key");
        assert_eq!(cred.path, expected_path);
        assert_eq!(fs::read_to_string(&cred.path).unwrap(), "secret123");

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let permission_bits = fs::metadata(&cred.path).unwrap().permissions().mode() & 0o777;
            assert_eq!(permission_bits, 0o600);
        }
    }

    /// Prod mode appends a random suffix, so two materializations never share a path.
    #[test]
    fn test_materialize_credential_file_prod_mode_unique() {
        let dir = tempfile::tempdir().unwrap();
        let first = materialize_credential_file("key", "val1", true, dir.path()).unwrap();
        let second = materialize_credential_file("key", "val2", true, dir.path()).unwrap();
        assert_ne!(first.path, second.path);
    }

    /// Dropping a prod-mode guard deletes the file.
    #[test]
    fn test_credential_file_wipe_on_drop() {
        let dir = tempfile::tempdir().unwrap();
        let wiped_path = {
            let cred =
                materialize_credential_file("wipe_me", "sensitive", true, dir.path()).unwrap();
            assert!(cred.path.exists());
            cred.path.clone()
        };
        assert!(!wiped_path.exists());
    }
}