use std::io::Write as _;
use std::path::{Path, PathBuf};
use anyhow::anyhow;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{Duration, timeout};
use crate::error::{Error, Result};
/// AI command-line backends this module knows how to launch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Launched through the `claude` executable.
    Claude,
    /// Launched through the `gemini` executable.
    Gemini,
    /// Launched through the `codex` executable.
    Codex,
}
impl Backend {
    /// Name of the executable that is looked up for this backend.
    pub fn cli_name(self) -> &'static str {
        match self {
            Self::Claude => "claude",
            Self::Gemini => "gemini",
            Self::Codex => "codex",
        }
    }

    /// Capitalized, human-readable name of this backend.
    pub fn label(self) -> &'static str {
        match self {
            Self::Claude => "Claude",
            Self::Gemini => "Gemini",
            Self::Codex => "Codex",
        }
    }

    /// Returns `true` when [`resolve_cli`] can locate this backend's executable.
    pub fn is_available(self) -> bool {
        let located = resolve_cli(self.cli_name());
        located.is_some()
    }
}
/// How to launch a located CLI: the program to execute plus any arguments that
/// must come before the caller's own.
#[derive(Debug, Clone)]
pub struct ResolvedCli {
    /// Executable to spawn: the located CLI itself, or a PowerShell host when
    /// the CLI is a `.ps1` script.
    pub program: PathBuf,
    /// Arguments placed ahead of the caller's own; empty unless the CLI is run
    /// through PowerShell.
    pub prefix_args: Vec<String>,
}
pub fn resolve_cli(name: &str) -> Option<ResolvedCli> {
if let Ok(p) = which::which(name) {
return Some(wrap_if_powershell(p));
}
#[cfg(windows)]
{
if let Ok(p) = which::which(format!("{name}.ps1")) {
return Some(wrap_if_powershell(p));
}
}
None
}
/// Turns a located executable path into a [`ResolvedCli`].
///
/// Paths whose extension is `ps1` (ASCII case-insensitive) are wrapped in a
/// PowerShell invocation: `pwsh.exe` when `which::which("pwsh")` succeeds,
/// otherwise `powershell.exe`, with
/// `-NoProfile -ExecutionPolicy Bypass -File <path>` as prefix arguments.
/// Every other path is returned as the program itself with no prefix arguments.
fn wrap_if_powershell(path: PathBuf) -> ResolvedCli {
    let extension = path.extension().and_then(|ext| ext.to_str());
    if !matches!(extension, Some(ext) if ext.eq_ignore_ascii_case("ps1")) {
        return ResolvedCli {
            program: path,
            prefix_args: Vec::new(),
        };
    }

    let ps_exe = match which::which("pwsh") {
        Ok(_) => "pwsh.exe",
        Err(_) => "powershell.exe",
    };

    // Fixed host flags first, then the script path as the value of `-File`.
    let script = path.to_string_lossy().into_owned();
    let prefix_args: Vec<String> = ["-NoProfile", "-ExecutionPolicy", "Bypass", "-File"]
        .iter()
        .copied()
        .map(String::from)
        .chain(std::iter::once(script))
        .collect();

    ResolvedCli {
        program: PathBuf::from(ps_exe),
        prefix_args,
    }
}
pub fn ensure_cli_installed(backend: Backend) -> Result<ResolvedCli> {
if let Some(r) = resolve_cli(backend.cli_name()) {
return Ok(r);
}
let cli = backend.cli_name();
let hint = match backend {
Backend::Claude => "https://docs.claude.com/claude-code",
Backend::Gemini => "https://ai.google.dev/gemini-api/docs/cli",
Backend::Codex => "https://github.com/openai/codex",
};
Err(Error::Other(anyhow!(
"AI backend `{cli}` is not on PATH. Install it ({hint}) or pass a different `--ai` flag."
)))
}
/// Timeout in seconds used when [`TIMEOUT_ENV`] is unset or is not a valid `u64`.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Environment variable that overrides the AI CLI timeout, in whole seconds.
const TIMEOUT_ENV: &str = "KATA_AI_TIMEOUT_SECS";

/// Reads the AI CLI timeout from [`TIMEOUT_ENV`].
///
/// Falls back to [`DEFAULT_TIMEOUT_SECS`] when the variable is missing, is not
/// valid Unicode, or does not parse as a `u64`.
fn resolve_timeout() -> Duration {
    let secs = match std::env::var(TIMEOUT_ENV) {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_TIMEOUT_SECS),
        Err(_) => DEFAULT_TIMEOUT_SECS,
    };
    Duration::from_secs(secs)
}
pub async fn invoke_chat(backend: Backend, prompt_text: &str) -> Result<String> {
let resolved = ensure_cli_installed(backend)?;
tracing::debug!(
backend = backend.cli_name(),
bytes = prompt_text.len(),
lines = prompt_text.lines().count(),
"invoke_chat: piping prompt to agent stdin",
);
let mut cmd = Command::new(&resolved.program);
cmd.args(&resolved.prefix_args);
match backend {
Backend::Claude | Backend::Gemini => {
cmd.arg("-p").arg("-");
}
Backend::Codex => {
cmd.arg("exec").arg("-");
}
}
cmd.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let mut child = cmd.spawn().map_err(|e| {
Error::Other(anyhow::Error::from(e).context(format!(
"failed to spawn AI CLI `{}` (is it installed and on PATH?)",
backend.cli_name()
)))
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(prompt_text.as_bytes()).await.map_err(|e| {
Error::Other(anyhow::Error::from(e).context("writing prompt to AI CLI stdin"))
})?;
}
let to = resolve_timeout();
let waited: std::io::Result<std::process::Output> = match timeout(to, child.wait_with_output())
.await
{
Ok(r) => r,
Err(_) => {
return Err(Error::Other(anyhow!(
"AI CLI `{}` timed out after {}s. Set {TIMEOUT_ENV}=600 (or higher) if your network is slow.",
backend.cli_name(),
to.as_secs(),
)));
}
};
let output = waited.map_err(|e| {
Error::Other(anyhow::Error::from(e).context(format!(
"AI CLI `{}` failed to produce output",
backend.cli_name()
)))
})?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::Other(anyhow!(
"AI CLI `{}` exited with status {}: {}",
backend.cli_name(),
output.status,
stderr.trim()
)));
}
Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
/// How the opening message is handed to a backend CLI when it is started
/// interactively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FirstMessageStrategy {
    /// Pass the message as a bare positional argument.
    Positional,
    /// Pass the message as the value of the `-i` flag.
    InteractiveFlag,
}
pub(crate) fn first_message_strategy(backend: Backend) -> FirstMessageStrategy {
match backend {
Backend::Claude | Backend::Codex => FirstMessageStrategy::Positional,
Backend::Gemini => FirstMessageStrategy::InteractiveFlag,
}
}
pub async fn run_handoff(backend: Backend, prompt_text: &str, dst_hint: &Path) -> Result<()> {
let resolved = ensure_cli_installed(backend)?;
let mut tmp = tempfile::Builder::new()
.prefix("kata-ai-prompt-")
.suffix(".md")
.tempfile()
.map_err(|e| {
Error::Other(anyhow::Error::from(e).context("creating tmp prompt file for handoff"))
})?;
tmp.write_all(prompt_text.as_bytes()).map_err(|e| {
Error::Other(anyhow::Error::from(e).context("writing tmp prompt file for handoff"))
})?;
let tmp_path: PathBuf = tmp.into_temp_path().keep().map_err(|e| {
Error::Other(anyhow::Error::from(e).context("persisting tmp prompt file for handoff"))
})?;
let path_str = tmp_path.to_string_lossy().into_owned();
let dst_str = dst_hint.to_string_lossy().into_owned();
let first_message = format!(
"Read the file at {path_str} for kata context about {dst_str}. \
Summarize what it contains in 1-2 sentences. \
Do NOT apply, edit, or write any files yet. \
Wait for my next instruction before running Edit or Write tools."
);
eprintln!();
eprintln!("kata: handoff prompt saved to {path_str}");
let strategy = first_message_strategy(backend);
eprintln!(
"kata: starting `{}` interactively (kata will not re-import the result).",
backend.cli_name()
);
eprintln!();
let label = backend.cli_name().to_string();
let program = resolved.program.clone();
let prefix_args = resolved.prefix_args.clone();
tokio::task::spawn_blocking(move || -> Result<()> {
let mut cmd = std::process::Command::new(&program);
cmd.args(&prefix_args);
match strategy {
FirstMessageStrategy::Positional => {
cmd.arg(&first_message);
}
FirstMessageStrategy::InteractiveFlag => {
cmd.arg("-i").arg(&first_message);
}
}
let status = cmd
.stdin(std::process::Stdio::inherit())
.stdout(std::process::Stdio::inherit())
.stderr(std::process::Stdio::inherit())
.status()
.map_err(|e| {
Error::Other(
anyhow::Error::from(e).context(format!("failed to spawn AI CLI `{label}`")),
)
})?;
let _ = status;
Ok(())
})
.await
.map_err(|e| Error::Other(anyhow::Error::from(e).context("joining handoff task")))??;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // The executable names are what gets looked up, so they must not drift.
    #[test]
    fn cli_name_is_stable_per_backend() {
        let expected = [
            (Backend::Claude, "claude"),
            (Backend::Gemini, "gemini"),
            (Backend::Codex, "codex"),
        ];
        for (backend, name) in expected {
            assert_eq!(backend.cli_name(), name);
        }
    }

    // Display labels are pinned per backend.
    #[test]
    fn label_is_stable_per_backend() {
        let expected = [
            (Backend::Claude, "Claude"),
            (Backend::Gemini, "Gemini"),
            (Backend::Codex, "Codex"),
        ];
        for (backend, label) in expected {
            assert_eq!(backend.label(), label);
        }
    }

    // Each backend maps to the strategy its CLI accepts for the first message.
    #[test]
    fn first_message_strategy_per_backend() {
        let expected = [
            (Backend::Claude, FirstMessageStrategy::Positional),
            (Backend::Codex, FirstMessageStrategy::Positional),
            (Backend::Gemini, FirstMessageStrategy::InteractiveFlag),
        ];
        for (backend, strategy) in expected {
            assert_eq!(first_message_strategy(backend), strategy);
        }
    }

    // A `.ps1` path must be launched through a PowerShell host with `-File`.
    #[test]
    fn wrap_if_powershell_wraps_ps1_path() {
        let resolved = wrap_if_powershell(PathBuf::from("C:/foo/gemini.ps1"));
        let prog = resolved.program.to_string_lossy().to_ascii_lowercase();
        assert!(
            matches!(prog.as_str(), "pwsh.exe" | "powershell.exe"),
            "expected pwsh.exe or powershell.exe, got {prog}",
        );
        for flag in ["-NoProfile", "Bypass", "-File"] {
            assert!(resolved.prefix_args.iter().any(|a| a == flag));
        }
        assert!(resolved
            .prefix_args
            .iter()
            .any(|a| a.contains("gemini.ps1")));
    }

    // Anything that is not a `.ps1` script is returned untouched.
    #[test]
    fn wrap_if_powershell_passes_exe_through_unchanged() {
        let exe = PathBuf::from("C:/foo/claude.exe");
        let ResolvedCli {
            program,
            prefix_args,
        } = wrap_if_powershell(exe.clone());
        assert_eq!(program, exe);
        assert!(prefix_args.is_empty());
    }

    // Only meaningful on machines without any AI CLI; otherwise it is skipped.
    #[test]
    fn ensure_cli_installed_errors_when_missing() {
        let any_installed = [Backend::Claude, Backend::Codex, Backend::Gemini]
            .iter()
            .any(|backend| backend.is_available());
        if any_installed {
            return;
        }

        let err = ensure_cli_installed(Backend::Claude)
            .err()
            .unwrap_or_else(|| panic!("expected error when no AI CLI is installed"));
        let msg = err.to_string();
        assert!(msg.contains("not on PATH"), "unexpected error: {msg}");
        assert!(msg.contains("claude"), "missing backend name: {msg}");
    }

    // Pins the `u64` parsing that the timeout override relies on.
    #[test]
    fn timeout_env_override_is_parsed() {
        assert_eq!("120".parse::<u64>().ok(), Some(120));
        assert_eq!("not-a-number".parse::<u64>().ok(), None);
    }
}