1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4
5use thiserror::Error;
6
7use crate::core::auth_generator::{self, AuthCache, GenContext};
8use crate::core::keyring::Keyring;
9use crate::core::manifest::Provider;
10
/// Errors produced while preparing, running, or collecting results from a CLI provider.
#[derive(Error, Debug)]
pub enum CliError {
    /// Invalid or missing provider CLI configuration (e.g. no `cli_command`), or a failed
    /// auth generator.
    #[error("CLI config error: {0}")]
    Config(String),
    /// A `${KEY}` / `@{KEY}` reference names a key that is absent from the keyring.
    #[error("Missing keyring key: {0}")]
    MissingKey(String),
    /// The child process could not be started; the message includes the command name.
    #[error("Failed to spawn CLI process: {0}")]
    Spawn(String),
    /// The child did not finish within the configured timeout (value is in seconds).
    #[error("CLI timed out after {0}s")]
    Timeout(u64),
    /// The child exited unsuccessfully. `code` is -1 when the process reported no exit code.
    #[error("CLI exited with code {code}: {stderr}")]
    NonZeroExit { code: i32, stderr: String },
    /// Any other I/O failure (waiting on the child, reading captured output, ...).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Creating, writing, or syncing a materialized credential file failed.
    #[error("Credential file error: {0}")]
    CredentialFile(String),
    /// A captured output file exceeded the configured byte limit.
    #[error("Captured output '{path}' exceeds ATI_CLI_MAX_OUTPUT_BYTES ({limit} bytes)")]
    OutputTooLarge { path: String, limit: u64 },
    /// The CLI exited without creating the temp file an output path was redirected to.
    #[error("Captured output '{path}' was not produced by the CLI")]
    OutputMissing { path: String },
}
36
/// Guard for a credential that was written to disk for a CLI child process.
///
/// When `wipe_on_drop` is set, dropping the guard overwrites the file's current contents
/// with zero bytes, syncs it, and then deletes it. Every step is best-effort: failures are
/// ignored so that `Drop` never panics.
pub struct CredentialFile {
    /// Location of the materialized credential.
    pub path: PathBuf,
    /// Whether `Drop` should zero and remove the file.
    wipe_on_drop: bool,
}

impl Drop for CredentialFile {
    fn drop(&mut self) {
        if !self.wipe_on_drop {
            return;
        }

        // A missing/unreadable file reports size 0, which skips the overwrite step.
        let byte_count = std::fs::metadata(&self.path).map_or(0, |m| m.len() as usize);
        if byte_count != 0 {
            if let Ok(handle) = std::fs::OpenOptions::new().write(true).open(&self.path) {
                use std::io::Write;
                let mut writer = &handle;
                let _ = writer.write_all(&vec![0u8; byte_count]);
                let _ = handle.sync_all();
            }
        }

        let _ = std::fs::remove_file(&self.path);
    }
}
65
66pub fn materialize_credential_file(
76 key_name: &str,
77 content: &str,
78 wipe_on_drop: bool,
79 ati_dir: &Path,
80) -> Result<CredentialFile, CliError> {
81 use std::os::unix::fs::OpenOptionsExt;
82
83 let creds_dir = ati_dir.join(".creds");
84 std::fs::create_dir_all(&creds_dir).map_err(|e| {
85 CliError::CredentialFile(format!("failed to create {}: {e}", creds_dir.display()))
86 })?;
87
88 let path = if wipe_on_drop {
89 let suffix: u32 = rand::random();
90 creds_dir.join(format!("{key_name}_{suffix}"))
91 } else {
92 creds_dir.join(key_name)
93 };
94
95 let mut file = std::fs::OpenOptions::new()
96 .write(true)
97 .create(true)
98 .truncate(true)
99 .mode(0o600)
100 .open(&path)
101 .map_err(|e| {
102 CliError::CredentialFile(format!("failed to write {}: {e}", path.display()))
103 })?;
104
105 {
106 use std::io::Write;
107 file.write_all(content.as_bytes()).map_err(|e| {
108 CliError::CredentialFile(format!("failed to write {}: {e}", path.display()))
109 })?;
110 file.sync_all().map_err(|e| {
111 CliError::CredentialFile(format!("failed to sync {}: {e}", path.display()))
112 })?;
113 }
114
115 Ok(CredentialFile { path, wipe_on_drop })
116}
117
118fn resolve_env_value(value: &str, keyring: &Keyring) -> Result<String, CliError> {
125 let mut result = value.to_string();
126 while let Some(start) = result.find("${") {
127 let rest = &result[start + 2..];
128 if let Some(end) = rest.find('}') {
129 let key_name = &rest[..end];
130 let replacement = keyring
131 .get(key_name)
132 .ok_or_else(|| CliError::MissingKey(key_name.to_string()))?;
133 result = format!("{}{}{}", &result[..start], replacement, &rest[end + 1..]);
134 } else {
135 break; }
137 }
138 Ok(result)
139}
140
141pub fn resolve_cli_env(
151 env_map: &HashMap<String, String>,
152 keyring: &Keyring,
153 wipe_on_drop: bool,
154 ati_dir: &Path,
155) -> Result<(HashMap<String, String>, Vec<CredentialFile>), CliError> {
156 let mut resolved = HashMap::with_capacity(env_map.len());
157 let mut cred_files: Vec<CredentialFile> = Vec::new();
158
159 for (key, value) in env_map {
160 if let Some(key_ref) = value.strip_prefix("@{").and_then(|s| s.strip_suffix('}')) {
161 let content = keyring
163 .get(key_ref)
164 .ok_or_else(|| CliError::MissingKey(key_ref.to_string()))?;
165 let cf = materialize_credential_file(key_ref, content, wipe_on_drop, ati_dir)?;
166 resolved.insert(key.clone(), cf.path.to_string_lossy().into_owned());
167 cred_files.push(cf);
168 } else if value.contains("${") {
169 let val = resolve_env_value(value, keyring)?;
171 resolved.insert(key.clone(), val);
172 } else {
173 resolved.insert(key.clone(), value.clone());
175 }
176 }
177
178 Ok((resolved, cred_files))
179}
180
/// Default cap, 500 MiB, on the size of any single captured CLI output file.
pub const DEFAULT_CLI_MAX_OUTPUT_BYTES: u64 = 500 * 1024 * 1024;

/// Effective capture-size limit in bytes.
///
/// Uses `ATI_CLI_MAX_OUTPUT_BYTES` when it parses as a positive `u64`. When the variable is
/// unset, non-numeric, or zero, falls back to [`DEFAULT_CLI_MAX_OUTPUT_BYTES`].
fn cli_max_output_bytes() -> u64 {
    match std::env::var("ATI_CLI_MAX_OUTPUT_BYTES").map(|raw| raw.parse::<u64>()) {
        Ok(Ok(limit)) if limit > 0 => limit,
        _ => DEFAULT_CLI_MAX_OUTPUT_BYTES,
    }
}
195
/// One output file redirected by [`apply_output_captures`].
///
/// The CLI is told to write to `temp_path`. The resulting bytes are later reported under
/// `original_path`.
#[derive(Debug, Clone)]
pub struct CapturedOutput {
    /// Path the caller originally passed to the CLI; used as the key of the results map.
    pub original_path: String,
    /// File in the system temp directory that the CLI was actually told to write.
    pub temp_path: PathBuf,
}
205
206pub fn apply_output_captures(
218 provider: &Provider,
219 raw_args: &[String],
220) -> Result<(Vec<String>, Vec<CapturedOutput>), CliError> {
221 let mut rewritten: Vec<String> = raw_args.to_vec();
222 let mut captures: Vec<CapturedOutput> = Vec::new();
223 let mut consumed: std::collections::HashSet<usize> = std::collections::HashSet::new();
230
231 if !provider.cli_output_args.is_empty() {
233 let mut i = 0;
234 while i < rewritten.len() {
235 let arg = rewritten[i].clone();
236 if let Some(eq_idx) = arg.find('=') {
238 let (flag, value) = arg.split_at(eq_idx);
239 if provider
240 .cli_output_args
241 .iter()
242 .any(|f| f.eq_ignore_ascii_case(flag))
243 {
244 let original = value[1..].to_string();
245 let temp = make_temp_for(&original)?;
246 rewritten[i] = format!("{}={}", flag, temp.display());
247 captures.push(CapturedOutput {
248 original_path: original,
249 temp_path: temp,
250 });
251 consumed.insert(i);
252 i += 1;
253 continue;
254 }
255 }
256 if provider
258 .cli_output_args
259 .iter()
260 .any(|f| f.eq_ignore_ascii_case(&arg))
261 && i + 1 < rewritten.len()
262 {
263 let original = rewritten[i + 1].clone();
264 let temp = make_temp_for(&original)?;
265 rewritten[i + 1] = temp.to_string_lossy().into_owned();
266 captures.push(CapturedOutput {
267 original_path: original,
268 temp_path: temp,
269 });
270 consumed.insert(i);
271 consumed.insert(i + 1);
272 i += 2;
273 continue;
274 }
275 i += 1;
276 }
277 }
278
279 if !provider.cli_output_positional.is_empty() {
282 let positionals: Vec<(usize, String)> = rewritten
285 .iter()
286 .enumerate()
287 .filter_map(|(idx, s)| {
288 if consumed.contains(&idx) || s.starts_with('-') {
289 None
290 } else {
291 Some((idx, s.clone()))
292 }
293 })
294 .collect();
295
296 let mut best: Option<(usize, usize)> = None; for (prefix, idx) in &provider.cli_output_positional {
300 let prefix_tokens: Vec<&str> = prefix.split_whitespace().collect();
301 if prefix_tokens.is_empty() {
302 continue;
303 }
304 if positionals.len() < prefix_tokens.len() + idx + 1 {
305 continue;
306 }
307 let prefix_matches = prefix_tokens
308 .iter()
309 .enumerate()
310 .all(|(i, tok)| positionals[i].1 == *tok);
311 if !prefix_matches {
312 continue;
313 }
314 let count = prefix_tokens.len();
315 if best.is_none_or(|(c, _)| count > c) {
316 best = Some((count, *idx));
317 }
318 }
319
320 if let Some((prefix_count, output_idx)) = best {
321 let target_positional_idx = prefix_count + output_idx;
322 if let Some((real_idx, original)) = positionals.get(target_positional_idx).cloned() {
323 let temp = make_temp_for(&original)?;
324 rewritten[real_idx] = temp.to_string_lossy().into_owned();
325 captures.push(CapturedOutput {
326 original_path: original,
327 temp_path: temp,
328 });
329 }
330 }
331 }
332
333 Ok((rewritten, captures))
334}
335
336fn make_temp_for(original_path: &str) -> Result<PathBuf, CliError> {
340 let ext = Path::new(original_path)
341 .extension()
342 .and_then(|e| e.to_str())
343 .unwrap_or("");
344 let suffix: u64 = rand::random();
345 let pid = std::process::id();
346 let name = if ext.is_empty() {
347 format!(".ati-cli-out-{pid}-{suffix:016x}")
348 } else {
349 format!(".ati-cli-out-{pid}-{suffix:016x}.{ext}")
350 };
351 Ok(std::env::temp_dir().join(name))
352}
353
354async fn collect_capture_results(
359 captures: &[CapturedOutput],
360) -> Result<HashMap<String, serde_json::Value>, CliError> {
361 use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
362 let max = cli_max_output_bytes();
363 let mut out = HashMap::with_capacity(captures.len());
364
365 for cap in captures {
366 let bytes_result = tokio::fs::read(&cap.temp_path).await;
367 let _ = tokio::fs::remove_file(&cap.temp_path).await;
369
370 let bytes = match bytes_result {
371 Ok(b) => b,
372 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
373 return Err(CliError::OutputMissing {
374 path: cap.original_path.clone(),
375 });
376 }
377 Err(e) => return Err(CliError::Io(e)),
378 };
379
380 if (bytes.len() as u64) > max {
381 return Err(CliError::OutputTooLarge {
382 path: cap.original_path.clone(),
383 limit: max,
384 });
385 }
386
387 let entry = serde_json::json!({
388 "content_base64": B64.encode(&bytes),
389 "size_bytes": bytes.len(),
390 "content_type": guess_content_type(&cap.original_path),
391 });
392 out.insert(cap.original_path.clone(), entry);
393 }
394 Ok(out)
395}
396
397use crate::core::file_manager::guess_content_type;
398
399fn discard_captures(captures: &[CapturedOutput]) {
402 for cap in captures {
403 let _ = std::fs::remove_file(&cap.temp_path);
404 }
405}
406
407pub async fn execute(
418 provider: &Provider,
419 raw_args: &[String],
420 keyring: &Keyring,
421) -> Result<serde_json::Value, CliError> {
422 execute_with_gen(provider, raw_args, keyring, None, None).await
423}
424
425pub async fn execute_with_gen(
427 provider: &Provider,
428 raw_args: &[String],
429 keyring: &Keyring,
430 gen_ctx: Option<&GenContext>,
431 auth_cache: Option<&AuthCache>,
432) -> Result<serde_json::Value, CliError> {
433 let cli_command = provider
434 .cli_command
435 .as_deref()
436 .ok_or_else(|| CliError::Config("provider missing cli_command".into()))?;
437
438 let timeout_secs = provider.cli_timeout_secs.unwrap_or(120);
439
440 let ati_dir = std::env::var("ATI_DIR")
441 .map(PathBuf::from)
442 .unwrap_or_else(|_| {
443 std::env::var("HOME")
444 .map(PathBuf::from)
445 .unwrap_or_else(|_| PathBuf::from("/tmp"))
446 .join(".ati")
447 });
448
449 let wipe_on_drop = keyring.ephemeral;
450
451 let (resolved_env, cred_files) =
454 resolve_cli_env(&provider.cli_env, keyring, wipe_on_drop, &ati_dir)?;
455
456 let mut final_env: HashMap<String, String> = HashMap::new();
458 for var in &["PATH", "HOME", "TMPDIR", "LANG", "USER", "TERM"] {
459 if let Ok(val) = std::env::var(var) {
460 final_env.insert(var.to_string(), val);
461 }
462 }
463 final_env.extend(resolved_env);
465
466 if let Some(gen) = &provider.auth_generator {
468 let default_ctx = GenContext::default();
469 let ctx = gen_ctx.unwrap_or(&default_ctx);
470 let default_cache = AuthCache::new();
471 let cache = auth_cache.unwrap_or(&default_cache);
472 match auth_generator::generate(provider, gen, ctx, keyring, cache).await {
473 Ok(cred) => {
474 final_env.insert("ATI_AUTH_TOKEN".to_string(), cred.value);
475 for (k, v) in &cred.extra_env {
476 final_env.insert(k.clone(), v.clone());
477 }
478 }
479 Err(e) => {
480 return Err(CliError::Config(format!("auth_generator failed: {e}")));
481 }
482 }
483 }
484
485 let (rewritten_args, captures) = apply_output_captures(provider, raw_args)?;
489
490 let command = cli_command.to_string();
492 let default_args = provider.cli_default_args.clone();
493 let extra_args = rewritten_args;
494 let env_snapshot = final_env;
495 let timeout_dur = std::time::Duration::from_secs(timeout_secs);
496
497 let child = tokio::process::Command::new(&command)
501 .args(&default_args)
502 .args(&extra_args)
503 .env_clear()
504 .envs(&env_snapshot)
505 .stdout(Stdio::piped())
506 .stderr(Stdio::piped())
507 .kill_on_drop(true)
508 .spawn()
509 .map_err(|e| {
510 discard_captures(&captures);
511 CliError::Spawn(format!("{command}: {e}"))
512 })?;
513
514 let output = match tokio::time::timeout(timeout_dur, child.wait_with_output()).await {
516 Ok(Ok(o)) => o,
517 Ok(Err(e)) => {
518 discard_captures(&captures);
519 return Err(CliError::Io(e));
520 }
521 Err(_) => {
522 discard_captures(&captures);
523 return Err(CliError::Timeout(timeout_secs));
524 }
525 };
526
527 drop(cred_files);
529
530 if !output.status.success() {
531 discard_captures(&captures);
532 let code = output.status.code().unwrap_or(-1);
533 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
534 return Err(CliError::NonZeroExit { code, stderr });
535 }
536
537 let stdout = String::from_utf8_lossy(&output.stdout);
538
539 if captures.is_empty() {
542 let value = match serde_json::from_str::<serde_json::Value>(stdout.trim()) {
543 Ok(v) => v,
544 Err(_) => serde_json::Value::String(stdout.trim().to_string()),
545 };
546 return Ok(value);
547 }
548
549 let outputs = collect_capture_results(&captures).await?;
552 Ok(serde_json::json!({
553 "stdout": stdout.trim().to_string(),
554 "outputs": outputs,
555 }))
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Persistent mode writes `.creds/<key>` with the exact content and mode 0600.
    #[test]
    fn test_materialize_credential_file_dev_mode() {
        let dir = tempfile::tempdir().unwrap();
        let cred =
            materialize_credential_file("test_key", "secret123", false, dir.path()).unwrap();

        let expected = dir.path().join(".creds/test_key");
        assert_eq!(cred.path, expected);
        assert_eq!(fs::read_to_string(&cred.path).unwrap(), "secret123");

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let perms = fs::metadata(&cred.path).unwrap().permissions();
            assert_eq!(perms.mode() & 0o777, 0o600);
        }
    }

    /// Ephemeral mode gives two materializations of the same key distinct paths.
    #[test]
    fn test_materialize_credential_file_prod_mode_unique() {
        let dir = tempfile::tempdir().unwrap();
        let first = materialize_credential_file("key", "val1", true, dir.path()).unwrap();
        let second = materialize_credential_file("key", "val2", true, dir.path()).unwrap();
        assert_ne!(first.path, second.path);
    }

    /// Dropping an ephemeral guard removes its file.
    #[test]
    fn test_credential_file_wipe_on_drop() {
        let dir = tempfile::tempdir().unwrap();
        let path = {
            let cred =
                materialize_credential_file("wipe_me", "sensitive", true, dir.path()).unwrap();
            assert!(cred.path.exists());
            cred.path.clone()
        };
        assert!(!path.exists());
    }
}