use std::path::PathBuf;
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::state::TokenUsage;
use super::{
subprocess::{self, SubprocessOutcome},
Agent, AgentEvent, AgentOutcome, AgentRequest, StopReason,
};
/// Program name used when no explicit binary path is configured; resolution of
/// a bare name is left to `Command`.
const DEFAULT_BINARY: &str = "gemini";
/// How many trailing stderr lines are kept for inclusion in error messages.
const ERROR_TAIL_LINES: usize = 8;

/// [`Agent`] implementation that runs the `gemini` CLI as a subprocess.
#[derive(Debug, Clone)]
pub struct GeminiAgent {
    /// Executable to spawn.
    binary: PathBuf,
    /// Additional CLI arguments, appended after the built-in flags and before
    /// `--prompt`.
    extra_args: Vec<String>,
    /// When set, used instead of the model named in each request.
    model_override: Option<String>,
}

impl GeminiAgent {
    /// Creates an agent that runs the default `gemini` binary.
    pub fn new() -> Self {
        Self::with_binary(DEFAULT_BINARY)
    }

    /// Creates an agent that runs `binary`, with no extra args and no model override.
    pub fn with_binary(binary: impl Into<PathBuf>) -> Self {
        let binary = binary.into();
        Self {
            binary,
            extra_args: Vec::new(),
            model_override: None,
        }
    }

    /// Replaces the extra CLI arguments passed on every invocation.
    pub fn with_extra_args(self, args: Vec<String>) -> Self {
        Self {
            extra_args: args,
            ..self
        }
    }

    /// Forces every run to use `model`, ignoring the request's model.
    pub fn with_model_override(self, model: impl Into<String>) -> Self {
        Self {
            model_override: Some(model.into()),
            ..self
        }
    }

    /// Path of the executable this agent spawns.
    pub fn binary(&self) -> &PathBuf {
        &self.binary
    }
}

impl Default for GeminiAgent {
    /// Equivalent to [`GeminiAgent::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Agent for GeminiAgent {
fn name(&self) -> &str {
"gemini"
}
async fn run(
&self,
req: AgentRequest,
events: mpsc::Sender<AgentEvent>,
cancel: CancellationToken,
) -> Result<AgentOutcome> {
let log_path = req.log_path.clone();
let cmd = self.build_command(&req);
let (raw_tx, mut raw_rx) = mpsc::channel::<AgentEvent>(64);
let outbound = events.clone();
let forwarder = tokio::spawn(async move {
let mut stdout_buf = String::new();
let mut stderr_tail: Vec<String> = Vec::new();
while let Some(ev) = raw_rx.recv().await {
match ev {
AgentEvent::Stdout(line) => {
if !stdout_buf.is_empty() {
stdout_buf.push('\n');
}
stdout_buf.push_str(&line);
}
AgentEvent::Stderr(line) => {
push_tail(&mut stderr_tail, line.clone(), ERROR_TAIL_LINES);
let _ = outbound.send(AgentEvent::Stderr(line)).await;
}
other => {
let _ = outbound.send(other).await;
}
}
}
let parsed = parse_gemini_output(&stdout_buf);
let mut tokens = TokenUsage::default();
let mut error_message: Option<String> = None;
match parsed {
ParsedOutput::Success {
response,
tools,
token_usage,
} => {
if let Some(text) = response {
if !text.is_empty() {
let _ = outbound.send(AgentEvent::Stdout(text)).await;
}
}
for tool in tools {
let _ = outbound.send(AgentEvent::ToolUse(tool)).await;
}
tokens = token_usage;
}
ParsedOutput::Error { message } => {
error_message = Some(message);
}
ParsedOutput::Unparseable => {
if !stdout_buf.is_empty() {
let _ = outbound.send(AgentEvent::Stdout(stdout_buf.clone())).await;
}
}
}
if tokens.input > 0 || tokens.output > 0 {
let _ = outbound.send(AgentEvent::TokenDelta(tokens.clone())).await;
}
ForwarderResult {
tokens,
error_message,
stderr_tail,
}
});
let sub_outcome: SubprocessOutcome =
subprocess::run_logged(cmd, &log_path, raw_tx, cancel, req.timeout).await?;
let ForwarderResult {
mut tokens,
error_message,
stderr_tail,
} = forwarder.await.unwrap_or(ForwarderResult {
tokens: TokenUsage::default(),
error_message: None,
stderr_tail: Vec::new(),
});
if tokens.input > 0 || tokens.output > 0 {
tokens
.by_role
.entry(req.role.as_str().to_string())
.or_default();
let entry = tokens
.by_role
.get_mut(req.role.as_str())
.expect("just inserted");
entry.input = tokens.input;
entry.output = tokens.output;
}
let stop_reason = match sub_outcome.stop_reason {
StopReason::Completed => {
if sub_outcome.exit_code == 0 && error_message.is_none() {
StopReason::Completed
} else {
StopReason::Error(format_error_message(
sub_outcome.exit_code,
error_message.as_deref(),
&stderr_tail,
))
}
}
other => other,
};
Ok(AgentOutcome {
exit_code: sub_outcome.exit_code,
stop_reason,
tokens,
log_path,
})
}
}
impl GeminiAgent {
    /// Assembles the gemini CLI invocation for `req`.
    ///
    /// Argument order:
    /// 1. `--yolo --output-format json --model <model>`;
    /// 2. any configured extra args;
    /// 3. `--prompt <payload>` last.
    ///
    /// The model is the override if one is set, otherwise the request's model.
    /// The command runs in `req.workdir` with `req.env` added to the environment.
    fn build_command(&self, req: &AgentRequest) -> Command {
        let model = match self.model_override.as_deref() {
            Some(forced) => forced,
            None => &req.model,
        };
        let payload = build_prompt_payload(req);

        let mut command = Command::new(&self.binary);
        command
            .current_dir(&req.workdir)
            .args(["--yolo", "--output-format", "json"])
            .args(["--model", model])
            .args(&self.extra_args)
            .arg("--prompt")
            .arg(payload);
        if !req.env.is_empty() {
            command.envs(req.env.iter());
        }
        command
    }
}
/// Builds the text passed to `--prompt`.
///
/// If there is a system prompt, the result is the system prompt, a blank
/// line, then the user prompt. Otherwise it is just the user prompt.
fn build_prompt_payload(req: &AgentRequest) -> String {
    if req.system_prompt.is_empty() {
        req.user_prompt.to_string()
    } else {
        format!("{}\n\n{}", req.system_prompt, req.user_prompt)
    }
}
/// What the output-forwarding task hands back once the subprocess stream ends.
#[derive(Debug)]
struct ForwarderResult {
    /// Token usage parsed from the CLI output (all zero when absent).
    tokens: TokenUsage,
    /// Failure detail to fold into the stop reason, e.g. the message from the
    /// CLI's JSON `error` entry.
    error_message: Option<String>,
    /// Most recent stderr lines, oldest first.
    stderr_tail: Vec<String>,
}

/// Interpretation of the CLI's buffered stdout.
#[derive(Debug)]
enum ParsedOutput {
    /// A JSON document that reported no error.
    Success {
        /// The `response` text, if present.
        response: Option<String>,
        /// One entry per tool invocation recorded in the stats.
        tools: Vec<String>,
        /// Token totals summed across models.
        token_usage: TokenUsage,
    },
    /// A JSON document that carried an `error` entry.
    Error {
        /// The reported error message (or a generic fallback).
        message: String,
    },
    /// Stdout that could not be read as the CLI's JSON envelope (e.g. empty or
    /// not JSON).
    Unparseable,
}
/// Parses the complete stdout of a `--output-format json` gemini run.
///
/// Returns:
/// - [`ParsedOutput::Unparseable`] when the buffer is blank, is not JSON, or
///   is JSON but not an object. The last case covers plain-text output that
///   happens to be valid JSON, such as `42`; the caller can then forward it
///   verbatim instead of dropping it.
/// - [`ParsedOutput::Error`] when the object has a non-null `error` entry.
///   The message comes from `error.message`, or from `error` itself when it
///   is a bare string, with a generic fallback.
/// - [`ParsedOutput::Success`] otherwise, carrying the optional `response`
///   text, tool calls and token usage.
fn parse_gemini_output(buf: &str) -> ParsedOutput {
    let trimmed = buf.trim();
    if trimmed.is_empty() {
        return ParsedOutput::Unparseable;
    }
    let value: Value = match serde_json::from_str(trimmed) {
        Ok(v) => v,
        Err(_) => return ParsedOutput::Unparseable,
    };
    if !value.is_object() {
        return ParsedOutput::Unparseable;
    }
    // `"error": null` means "no error"; only a present, non-null value signals failure.
    if let Some(err_obj) = value.get("error").filter(|e| !e.is_null()) {
        let message = err_obj
            .get("message")
            .and_then(Value::as_str)
            .or_else(|| err_obj.as_str())
            .map(str::to_string)
            .unwrap_or_else(|| "gemini reported an error".to_string());
        return ParsedOutput::Error { message };
    }
    let response = value
        .get("response")
        .and_then(Value::as_str)
        .map(str::to_string);
    let tools = extract_tool_calls(&value);
    let token_usage = extract_token_usage(&value);
    ParsedOutput::Success {
        response,
        tools,
        token_usage,
    }
}
/// Expands `stats.tools.byName` into one entry per recorded tool call.
///
/// Each tool name is repeated `count` times. A missing or zero `count`
/// still yields one entry. Returns an empty list when the section is absent
/// or is not an object.
fn extract_tool_calls(value: &Value) -> Vec<String> {
    let per_tool = value
        .get("stats")
        .and_then(|stats| stats.get("tools"))
        .and_then(|tools| tools.get("byName"))
        .and_then(Value::as_object);

    let mut calls = Vec::new();
    if let Some(per_tool) = per_tool {
        calls.reserve(per_tool.len());
        for (tool_name, stats) in per_tool {
            let repeats = stats
                .get("count")
                .and_then(Value::as_u64)
                .map_or(1, |n| n.max(1));
            calls.extend((0..repeats).map(|_| tool_name.clone()));
        }
    }
    calls
}
/// Sums token counts over every model listed under `stats.models`.
///
/// Input tokens are `prompt + cached`; output tokens are `candidates + thoughts`.
/// Missing sections, models without `tokens`, or missing fields count as zero.
///
/// NOTE(review): the Gemini API documents `promptTokenCount` as already
/// including cached-content tokens. If the CLI's `prompt` figure follows that
/// convention, adding `cached` double-counts them. Confirm against real CLI
/// output before changing; existing tests pin the summed value.
fn extract_token_usage(value: &Value) -> TokenUsage {
    let models = value
        .get("stats")
        .and_then(|stats| stats.get("models"))
        .and_then(Value::as_object);

    let token_blocks = models
        .into_iter()
        .flat_map(|by_model| by_model.values())
        .filter_map(|model| model.get("tokens"));

    let mut usage = TokenUsage::default();
    for counts in token_blocks {
        usage.input += read_u64(counts, "prompt") + read_u64(counts, "cached");
        usage.output += read_u64(counts, "candidates") + read_u64(counts, "thoughts");
    }
    usage
}
fn read_u64(v: &Value, key: &str) -> u64 {
v.get(key).and_then(Value::as_u64).unwrap_or(0)
}
/// Appends `line` to `buf`, keeping only the `max` most recent lines.
///
/// The oldest lines are evicted first. With `max == 0` nothing is retained.
/// The previous version panicked here, calling `remove(0)` on an empty Vec.
/// A buffer that already exceeds `max` is trimmed back down rather than left
/// to grow.
fn push_tail(buf: &mut Vec<String>, line: String, max: usize) {
    if max == 0 {
        buf.clear();
        return;
    }
    if buf.len() >= max {
        // Make room for exactly one new line.
        let excess = buf.len() + 1 - max;
        buf.drain(..excess);
    }
    buf.push(line);
}
/// Builds a human-readable failure description for a gemini run.
///
/// The headline has one of these forms:
/// - `gemini: <msg> (<label>, exit N)`;
/// - `gemini: <msg> (exit N)`;
/// - `gemini exited with code N (<label>)`;
/// - `gemini exited with code N`.
///
/// Which form is used depends on whether a non-empty parsed message exists
/// and whether the exit code has a known label. When stderr lines were
/// captured, they follow under a `stderr tail:` header, one per line.
fn format_error_message(exit_code: i32, parsed: Option<&str>, stderr_tail: &[String]) -> String {
    // An empty parsed message is treated the same as no message.
    let detail = parsed.filter(|msg| !msg.is_empty());
    let mut message = match (detail, exit_code_label(exit_code)) {
        (Some(m), Some(l)) => format!("gemini: {} ({}, exit {})", m, l, exit_code),
        (Some(m), None) => format!("gemini: {} (exit {})", m, exit_code),
        (None, Some(l)) => format!("gemini exited with code {} ({})", exit_code, l),
        (None, None) => format!("gemini exited with code {}", exit_code),
    };
    if stderr_tail.is_empty() {
        return message;
    }
    message.push_str("\nstderr tail:\n");
    for tail_line in stderr_tail {
        message.push_str(tail_line);
        message.push('\n');
    }
    message
}
/// Returns a short category for a known gemini CLI exit code, or `None`.
///
/// NOTE(review): these meanings are not verified here. Confirm them against
/// the exit codes documented for the installed gemini-cli release; upstream
/// headless-mode docs may describe some codes (e.g. 42, 53) differently.
fn exit_code_label(code: i32) -> Option<&'static str> {
    const KNOWN_CODES: [(i32, &str); 5] = [
        (41, "usage error"),
        (42, "authentication error"),
        (43, "quota exceeded"),
        (44, "network error"),
        (53, "tool error"),
    ];
    KNOWN_CODES
        .iter()
        .find(|(known, _)| *known == code)
        .map(|(_, label)| *label)
}
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use crate::agent::Role;
    use std::path::PathBuf;
    use std::time::Duration;

    /// Path to a fake gemini script under `tests/fixtures/gemini`.
    fn fixture_path(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join("gemini")
            .join(name)
    }

    /// Baseline implementer request. Tests override individual fields with
    /// struct-update syntax.
    fn req_with_log(log_path: PathBuf, timeout: Duration) -> AgentRequest {
        AgentRequest {
            role: Role::Implementer,
            model: "gemini-2.5-pro".into(),
            system_prompt: "be brief".into(),
            user_prompt: "say hi".into(),
            workdir: std::env::temp_dir(),
            log_path,
            timeout,
            env: std::collections::HashMap::new(),
        }
    }

    /// Collects every value until all senders have been dropped.
    async fn drain<T>(mut rx: mpsc::Receiver<T>) -> Vec<T> {
        let mut out = Vec::new();
        while let Some(v) = rx.recv().await {
            out.push(v);
        }
        out
    }

    /// Runs `agent` on `req` and returns the outcome plus every emitted event.
    async fn run_agent(
        agent: &GeminiAgent,
        req: AgentRequest,
        cancel: CancellationToken,
    ) -> (AgentOutcome, Vec<AgentEvent>) {
        let (tx, rx) = mpsc::channel(64);
        let outcome = agent.run(req, tx, cancel).await.unwrap();
        (outcome, drain(rx).await)
    }

    /// Text of every `Stdout` event, in emission order.
    fn stdout_texts(evs: &[AgentEvent]) -> Vec<&str> {
        evs.iter()
            .filter_map(|e| match e {
                AgentEvent::Stdout(s) => Some(s.as_str()),
                _ => None,
            })
            .collect()
    }

    /// Tool name of every `ToolUse` event, in emission order.
    fn tool_use_names(evs: &[AgentEvent]) -> Vec<&str> {
        evs.iter()
            .filter_map(|e| match e {
                AgentEvent::ToolUse(s) => Some(s.as_str()),
                _ => None,
            })
            .collect()
    }

    /// Payload of every `TokenDelta` event, in emission order.
    fn token_delta_events(evs: &[AgentEvent]) -> Vec<&TokenUsage> {
        evs.iter()
            .filter_map(|e| match e {
                AgentEvent::TokenDelta(t) => Some(t),
                _ => None,
            })
            .collect()
    }

    /// Arguments of `cmd`, lossily converted to UTF-8.
    fn command_args(cmd: &Command) -> Vec<String> {
        cmd.as_std()
            .get_args()
            .map(|a| a.to_string_lossy().into_owned())
            .collect()
    }

    #[tokio::test]
    async fn parses_response_tool_calls_and_token_stats() {
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let agent = GeminiAgent::with_binary(fixture_path("fake-gemini-success.sh"));
        let (outcome, evs) = run_agent(
            &agent,
            req_with_log(log.clone(), Duration::from_secs(5)),
            CancellationToken::new(),
        )
        .await;
        assert_eq!(outcome.stop_reason, StopReason::Completed);
        assert_eq!(outcome.exit_code, 0);

        let stdouts = stdout_texts(&evs);
        let tool_uses = tool_use_names(&evs);
        let token_deltas = token_delta_events(&evs);
        assert!(
            stdouts.iter().any(|s| s.contains("Hello from Gemini")),
            "missing assistant text: {stdouts:?}"
        );
        assert_eq!(tool_uses.len(), 3);
        assert!(tool_uses.contains(&"list_directory"));
        assert_eq!(
            tool_uses.iter().filter(|t| **t == "edit_file").count(),
            2,
            "expected two edit_file tool-use events, got {tool_uses:?}"
        );
        assert_eq!(token_deltas.len(), 1);
        let total = token_deltas[0];
        assert_eq!(total.input, 1200);
        assert_eq!(total.output, 800);
        assert_eq!(outcome.tokens.input, 1200);
        assert_eq!(outcome.tokens.output, 800);
        let role_usage = outcome
            .tokens
            .by_role
            .get("implementer")
            .expect("implementer role usage");
        assert_eq!(role_usage.input, 1200);
        assert_eq!(role_usage.output, 800);

        let log_text = std::fs::read_to_string(&log).unwrap();
        assert!(log_text.contains("\"response\""), "{log_text}");
        assert!(log_text.contains("edit_file"), "{log_text}");
    }

    #[tokio::test]
    async fn partial_output_with_no_stats_still_completes() {
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let agent = GeminiAgent::with_binary(fixture_path("fake-gemini-partial.sh"));
        let (outcome, evs) = run_agent(
            &agent,
            req_with_log(log, Duration::from_secs(5)),
            CancellationToken::new(),
        )
        .await;
        assert_eq!(outcome.stop_reason, StopReason::Completed);

        let stdouts = stdout_texts(&evs);
        let tool_uses = tool_use_names(&evs);
        let token_deltas = token_delta_events(&evs);
        assert!(
            stdouts.iter().any(|s| s.contains("Nothing to change")),
            "expected response text, got {stdouts:?}"
        );
        assert!(
            tool_uses.is_empty(),
            "partial run should produce no tool-use events, got {tool_uses:?}"
        );
        assert!(
            token_deltas.is_empty(),
            "partial run should produce no token deltas, got {token_deltas:?}"
        );
        assert_eq!(outcome.tokens.input, 0);
        assert_eq!(outcome.tokens.output, 0);
    }

    #[tokio::test]
    async fn error_event_maps_to_error_stop_reason_with_exit_label() {
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let agent = GeminiAgent::with_binary(fixture_path("fake-gemini-error.sh"));
        let (outcome, _evs) = run_agent(
            &agent,
            req_with_log(log, Duration::from_secs(5)),
            CancellationToken::new(),
        )
        .await;
        match outcome.stop_reason {
            StopReason::Error(msg) => {
                assert!(
                    msg.contains("GEMINI_API_KEY"),
                    "expected embedded message, got: {msg}"
                );
                assert!(
                    msg.contains("authentication error"),
                    "expected exit-code label, got: {msg}"
                );
            }
            other => panic!("expected Error, got {other:?}"),
        }
        assert_eq!(outcome.exit_code, 42);
    }

    #[tokio::test]
    async fn nonzero_exit_without_json_falls_back_to_stderr_tail() {
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let agent = GeminiAgent::with_binary(fixture_path("fake-gemini-crash.sh"));
        let (outcome, _evs) = run_agent(
            &agent,
            req_with_log(log, Duration::from_secs(5)),
            CancellationToken::new(),
        )
        .await;
        match outcome.stop_reason {
            StopReason::Error(msg) => {
                assert!(msg.contains("exit"), "{msg}");
                assert!(
                    msg.contains("settings file"),
                    "expected stderr tail in error message, got: {msg}"
                );
            }
            other => panic!("expected Error, got {other:?}"),
        }
        assert_eq!(outcome.exit_code, 1);
    }

    #[tokio::test]
    async fn cancellation_propagates_to_child_process() {
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let agent = GeminiAgent::with_binary(fixture_path("fake-gemini-hang.sh"));
        let cancel = CancellationToken::new();
        let canceler = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(80)).await;
            canceler.cancel();
        });
        let (outcome, _evs) =
            run_agent(&agent, req_with_log(log, Duration::from_secs(30)), cancel).await;
        assert_eq!(outcome.stop_reason, StopReason::Cancelled);
        assert_eq!(outcome.exit_code, -1);
    }

    #[test]
    fn build_command_includes_required_flags_and_workdir() {
        let agent = GeminiAgent::with_binary("/usr/local/bin/gemini")
            .with_extra_args(vec!["--include-directories".into(), "src".into()])
            .with_model_override("gemini-2.5-flash");
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let req = AgentRequest {
            role: Role::Auditor,
            model: "ignored-because-override".into(),
            system_prompt: "system body".into(),
            user_prompt: "user body".into(),
            workdir: dir.path().to_path_buf(),
            ..req_with_log(log, Duration::from_secs(1))
        };
        let cmd = agent.build_command(&req);
        let args = command_args(&cmd);
        assert!(args.iter().any(|a| a == "--yolo"));
        assert!(args
            .windows(2)
            .any(|w| w[0] == "--output-format" && w[1] == "json"));
        assert!(args
            .windows(2)
            .any(|w| w[0] == "--model" && w[1] == "gemini-2.5-flash"));
        assert!(!args.iter().any(|a| a == "ignored-because-override"));
        assert!(args
            .windows(2)
            .any(|w| w[0] == "--include-directories" && w[1] == "src"));
        let prompt_idx = args
            .iter()
            .position(|a| a == "--prompt")
            .expect("--prompt flag must be present");
        let body = &args[prompt_idx + 1];
        assert!(body.starts_with("system body\n\n"));
        assert!(body.ends_with("user body"));
        let std_cmd = cmd.as_std();
        assert_eq!(std_cmd.get_program(), "/usr/local/bin/gemini");
        assert_eq!(std_cmd.get_current_dir(), Some(dir.path()));
    }

    #[test]
    fn build_command_uses_request_model_when_no_override() {
        let agent = GeminiAgent::with_binary("gemini");
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let req = AgentRequest {
            system_prompt: String::new(),
            user_prompt: "u".into(),
            workdir: dir.path().to_path_buf(),
            ..req_with_log(log, Duration::from_secs(1))
        };
        let args = command_args(&agent.build_command(&req));
        assert!(args
            .windows(2)
            .any(|w| w[0] == "--model" && w[1] == "gemini-2.5-pro"));
    }

    #[test]
    fn build_prompt_payload_concatenates_system_and_user_with_blank_line() {
        let req = AgentRequest {
            model: "x".into(),
            system_prompt: "you are a careful engineer".into(),
            user_prompt: "implement phase 04".into(),
            ..req_with_log(
                std::env::temp_dir().join("never.log"),
                Duration::from_secs(1),
            )
        };
        let payload = build_prompt_payload(&req);
        assert!(payload.starts_with("you are a careful engineer\n\n"));
        assert!(payload.contains("implement phase 04"));
    }

    #[test]
    fn build_prompt_payload_omits_system_when_empty() {
        let req = AgentRequest {
            model: "x".into(),
            system_prompt: String::new(),
            user_prompt: "just the user body".into(),
            ..req_with_log(
                std::env::temp_dir().join("never.log"),
                Duration::from_secs(1),
            )
        };
        assert_eq!(build_prompt_payload(&req), "just the user body");
    }

    #[test]
    fn parse_gemini_output_handles_success_shape() {
        let buf = r#"{"response":"hi","stats":{"models":{"gemini-2.5-pro":{"tokens":{"prompt":10,"candidates":20,"cached":5,"thoughts":3}}},"tools":{"byName":{"a":{"count":1},"b":{"count":2}}}}}"#;
        match parse_gemini_output(buf) {
            ParsedOutput::Success {
                response,
                tools,
                token_usage,
            } => {
                assert_eq!(response.as_deref(), Some("hi"));
                assert_eq!(token_usage.input, 15);
                assert_eq!(token_usage.output, 23);
                assert_eq!(tools.len(), 3);
                assert!(tools.contains(&"a".to_string()));
                assert_eq!(tools.iter().filter(|t| t.as_str() == "b").count(), 2);
            }
            other => panic!("expected Success, got {:?}", std::mem::discriminant(&other)),
        }
    }

    #[test]
    fn parse_gemini_output_handles_error_shape() {
        let buf = r#"{"error":{"type":"AuthError","message":"missing key"}}"#;
        match parse_gemini_output(buf) {
            ParsedOutput::Error { message } => assert_eq!(message, "missing key"),
            _ => panic!("expected Error variant"),
        }
    }

    #[test]
    fn parse_gemini_output_treats_non_json_as_unparseable() {
        assert!(
            matches!(parse_gemini_output("not json at all"), ParsedOutput::Unparseable),
            "expected Unparseable variant for non-JSON input"
        );
        assert!(
            matches!(parse_gemini_output(""), ParsedOutput::Unparseable),
            "expected Unparseable variant for empty input"
        );
    }

    #[test]
    fn exit_code_label_covers_known_buckets() {
        assert_eq!(exit_code_label(41), Some("usage error"));
        assert_eq!(exit_code_label(42), Some("authentication error"));
        assert_eq!(exit_code_label(43), Some("quota exceeded"));
        assert_eq!(exit_code_label(44), Some("network error"));
        assert_eq!(exit_code_label(53), Some("tool error"));
        assert_eq!(exit_code_label(1), None);
        assert_eq!(exit_code_label(99), None);
    }

    #[tokio::test]
    async fn real_gemini_smoke_test() {
        if std::env::var("PITBOSS_REAL_AGENT_TESTS").ok().as_deref() != Some("1") {
            eprintln!("skipping real_gemini_smoke_test (set PITBOSS_REAL_AGENT_TESTS=1 to run)");
            return;
        }
        let dir = tempfile::tempdir().unwrap();
        let log = dir.path().join("run.log");
        let agent = GeminiAgent::new();
        let req = AgentRequest {
            system_prompt: String::new(),
            user_prompt: "respond with the single word OK".into(),
            workdir: dir.path().to_path_buf(),
            ..req_with_log(log, Duration::from_secs(120))
        };
        let (outcome, _evs) = run_agent(&agent, req, CancellationToken::new()).await;
        assert!(
            matches!(outcome.stop_reason, StopReason::Completed),
            "real gemini run did not complete: {:?}",
            outcome.stop_reason
        );
        assert_eq!(outcome.exit_code, 0);
    }
}