use std::path::{Path, PathBuf};
use std::process::Command;
use std::thread;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use crate::compose::AgentHandle;
/// Everything a [`Supervisor`] needs to start, inspect and stop one agent.
#[derive(Debug, Clone)]
pub struct AgentSpec {
    /// Project the agent belongs to.
    pub project: String,
    /// Agent name within its project.
    pub agent: String,
    /// Name of the tmux session that hosts the agent.
    pub tmux_session: String,
    /// Launcher script run inside the session (the tmux backend passes it `project:agent`).
    pub wrapper: PathBuf,
    /// Working directory the session is started in.
    pub cwd: PathBuf,
    /// File whose contents the tmux backend expands into `env` arguments before launch.
    pub env_file: PathBuf,
}
impl AgentSpec {
    /// Builds the spec for the agent identified by `h`, rooted at `root`.
    ///
    /// * `tmux_session` is `{tmux_prefix}{project}-{agent}` with every `:` and `.`
    ///   replaced by `_`. tmux uses those characters as `-t` target separators and
    ///   does not keep them in session names (current releases rewrite them to `_`,
    ///   older ones reject the name). An unsanitized name could therefore never be
    ///   used to look the session up again.
    /// * `wrapper` is `root/bin/agent-wrapper.sh`; `cwd` is `root` itself.
    /// * `env_file` is whatever [`crate::render::env_path`] returns for this agent.
    ///
    /// NOTE(review): `-` joins project and agent, so ("a-b", "c") and ("a", "b-c")
    /// produce the same session name. Names that differ only in `:`/`.` vs `_`
    /// collide as well.
    pub fn from_handle(h: AgentHandle<'_>, root: &Path, tmux_prefix: &str) -> Self {
        let tmux_session = format!("{tmux_prefix}{}-{}", h.project, h.agent)
            .replace(|c: char| c == ':' || c == '.', "_");
        Self {
            project: h.project.into(),
            agent: h.agent.into(),
            tmux_session,
            wrapper: root.join("bin/agent-wrapper.sh"),
            cwd: root.to_path_buf(),
            env_file: crate::render::env_path(root, h.project, h.agent),
        }
    }
}
/// Liveness of an agent as last observed by a [`Supervisor`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    /// The agent's host (e.g. its tmux session) was found alive.
    Running,
    /// The agent was confirmed not to be running.
    Stopped,
    /// The state could not be determined (e.g. the supervisor's tooling failed to run).
    Unknown,
}
/// How a drain request ended.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainOutcome {
    /// The agent was observed stopped before the timeout elapsed.
    Graceful,
    /// The agent was not seen to stop in time, so it was force-stopped via
    /// [`Supervisor::down`].
    TimedOutKilled,
}
/// A backend that can start, stop and inspect agents described by [`AgentSpec`].
pub trait Supervisor {
    /// Starts the agent described by `spec`.
    fn up(&self, spec: &AgentSpec) -> Result<()>;

    /// Force-stops the agent described by `spec`; also the kill step of a drain.
    fn down(&self, spec: &AgentSpec) -> Result<()>;

    /// Reports the agent's current [`AgentState`].
    fn state(&self, spec: &AgentSpec) -> Result<AgentState>;

    /// Asks the agent to stop, force-stopping it if it does not exit within the timeout.
    ///
    /// This default has no graceful stop signal. It ignores `_timeout`, calls
    /// [`Supervisor::down`] immediately and reports [`DrainOutcome::TimedOutKilled`].
    /// Backends that can signal the agent override it, e.g. via [`orchestrate_drain`].
    fn drain(&self, spec: &AgentSpec, _timeout: Duration) -> Result<DrainOutcome> {
        self.down(spec).map(|()| DrainOutcome::TimedOutKilled)
    }

    /// Delay between state polls while draining. Defaults to 250 ms.
    fn drain_poll_interval(&self) -> Duration {
        const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(250);
        DEFAULT_POLL_INTERVAL
    }
}
/// Runs a graceful-then-forced drain of one agent.
///
/// 1. Calls `signal_fn` once to ask the agent to stop on its own.
/// 2. Polls [`Supervisor::state`] every [`Supervisor::drain_poll_interval`] until
///    it reports [`AgentState::Stopped`] or `timeout` elapses. A `state` error
///    counts as [`AgentState::Unknown`], i.e. "not stopped yet".
/// 3. On timeout, calls [`Supervisor::down`] to force the agent down.
///
/// # Errors
/// Only an error from the final `down` call is propagated.
pub fn orchestrate_drain<S, F>(
    supervisor: &S,
    spec: &AgentSpec,
    timeout: Duration,
    signal_fn: F,
) -> Result<DrainOutcome>
where
    S: Supervisor + ?Sized,
    F: FnOnce(),
{
    signal_fn();

    let interval = supervisor.drain_poll_interval();
    let observe = || match supervisor.state(spec) {
        Ok(state) => state,
        Err(_) => AgentState::Unknown,
    };

    match poll_for_stopped(timeout, interval, observe) {
        DrainOutcome::Graceful => Ok(DrainOutcome::Graceful),
        DrainOutcome::TimedOutKilled => {
            supervisor.down(spec)?;
            Ok(DrainOutcome::TimedOutKilled)
        }
    }
}
/// [`Supervisor`] that runs each agent in its own detached tmux session.
pub struct TmuxSupervisor;

impl TmuxSupervisor {
    /// `-t` argument matching *exactly* the session `name`, for target-session commands.
    ///
    /// A bare `-t name` makes tmux fall back to prefix and fnmatch matching when
    /// no session is called exactly `name`. For example, `-t p-a` would hit a
    /// running `p-ab`. A leading `=` restricts resolution to an exact name.
    fn exact_session(name: &str) -> String {
        format!("={name}")
    }

    /// `-t` argument for the current pane of *exactly* the session `name`, for
    /// target-pane commands such as `send-keys`.
    ///
    /// The trailing `:` makes tmux parse the whole argument as a session rather
    /// than a pane or window.
    fn exact_session_pane(name: &str) -> String {
        format!("={name}:")
    }
}

impl Supervisor for TmuxSupervisor {
    /// Starts the agent in a new detached tmux session, unless a session with
    /// exactly this name is already running.
    ///
    /// Inside the session, `sh -c` runs `env $(cat <env_file>) <wrapper> 'project:agent'`
    /// from `spec.cwd`.
    ///
    /// # Errors
    /// Fails if a path or name contains a NUL byte, if `tmux` cannot be spawned,
    /// or if `tmux new-session` exits unsuccessfully.
    fn up(&self, spec: &AgentSpec) -> Result<()> {
        if self.state(spec)? == AgentState::Running {
            return Ok(());
        }
        let env = shlex::try_quote(&spec.env_file.display().to_string())?;
        let wrapper = shlex::try_quote(&spec.wrapper.display().to_string())?;
        // Quote the `project:agent` argument too: unquoted, spaces or shell
        // metacharacters in either name would split the argument or run commands.
        let target = shlex::try_quote(&format!("{}:{}", spec.project, spec.agent))?;
        // `$(cat …)` is intentionally unquoted so its output is word-split into
        // separate `env` arguments. NOTE(review): this assumes whitespace-free
        // KEY=VALUE entries; a value containing spaces would be split apart.
        let cmd = format!("env $(cat {env}) {wrapper} {target}");
        let status = Command::new("tmux")
            .args(["new-session", "-d", "-s", &spec.tmux_session, "-c"])
            // Passed as an OsStr so a non-UTF-8 cwd is not mangled by `display()`.
            .arg(&spec.cwd)
            .args(["sh", "-c", &cmd])
            .status()
            .context("spawn tmux new-session")?;
        anyhow::ensure!(status.success(), "tmux new-session exited {status}");
        Ok(())
    }

    /// Kills the agent's session. This is best-effort and always returns `Ok(())`.
    fn down(&self, spec: &AgentSpec) -> Result<()> {
        // The exit status (and any spawn failure) is deliberately ignored. For
        // example, the session may already be gone, which is the desired end state.
        let _ = Command::new("tmux")
            .args(["kill-session", "-t", &Self::exact_session(&spec.tmux_session)])
            .status();
        Ok(())
    }

    /// Reports the agent's state via `tmux has-session`:
    /// - success → [`AgentState::Running`];
    /// - non-zero exit → [`AgentState::Stopped`];
    /// - failure to run `tmux` → [`AgentState::Unknown`].
    ///
    /// Never returns `Err`.
    fn state(&self, spec: &AgentSpec) -> Result<AgentState> {
        // `output()` rather than `status()` so tmux's stderr is captured, not printed.
        let out = Command::new("tmux")
            .args(["has-session", "-t", &Self::exact_session(&spec.tmux_session)])
            .output();
        Ok(match out {
            Ok(o) if o.status.success() => AgentState::Running,
            Ok(_) => AgentState::Stopped,
            Err(_) => AgentState::Unknown,
        })
    }

    /// Sends `C-c` to the session, then waits up to `timeout` for it to exit
    /// before killing it (see [`orchestrate_drain`]).
    fn drain(&self, spec: &AgentSpec, timeout: Duration) -> Result<DrainOutcome> {
        let pane = Self::exact_session_pane(&spec.tmux_session);
        orchestrate_drain(self, spec, timeout, || {
            // Best-effort signal. If it is not delivered, the timeout path kills the session.
            let _ = Command::new("tmux")
                .args(["send-keys", "-t", &pane, "C-c"])
                .status();
        })
    }
}
/// Polls `observe_state` until it reports [`AgentState::Stopped`] or `timeout` elapses.
///
/// The state is always observed at least once, even with a zero timeout.
/// Between observations the thread sleeps for `interval`, shortened so it never
/// sleeps past the deadline. The last observation therefore happens at (or just
/// after) the deadline, not up to a whole interval later. A `timeout` too large
/// to add to the current [`Instant`] (e.g. `Duration::MAX`) means "no deadline"
/// instead of panicking.
///
/// Returns [`DrainOutcome::Graceful`] if a stop was observed, otherwise
/// [`DrainOutcome::TimedOutKilled`]. This function only reports; it kills nothing.
fn poll_for_stopped<F: FnMut() -> AgentState>(
    timeout: Duration,
    interval: Duration,
    mut observe_state: F,
) -> DrainOutcome {
    // `Instant + Duration` panics on overflow; `None` here means "wait indefinitely".
    let deadline = Instant::now().checked_add(timeout);
    loop {
        if observe_state() == AgentState::Stopped {
            return DrainOutcome::Graceful;
        }
        let nap = match deadline {
            Some(deadline) => {
                let now = Instant::now();
                if now >= deadline {
                    return DrainOutcome::TimedOutKilled;
                }
                interval.min(deadline - now)
            }
            None => interval,
        };
        thread::sleep(nap);
    }
}
#[cfg(test)]
mod drain_tests {
    // Unit tests for `poll_for_stopped` and `orchestrate_drain`. They use
    // millisecond-scale timeouts and an in-memory `MockSupervisor` instead of tmux.
    use super::*;
    use std::cell::RefCell;

    /// The probe reports `Running` on the first call and `Stopped` from the
    /// second call onward, well within the 50 ms timeout, so the outcome is graceful.
    #[test]
    fn poll_returns_graceful_when_stopped_observed_in_time() {
        let calls = RefCell::new(0u32);
        let outcome = poll_for_stopped(Duration::from_millis(50), Duration::from_millis(1), || {
            let mut n = calls.borrow_mut();
            *n += 1;
            if *n >= 2 {
                AgentState::Stopped
            } else {
                AgentState::Running
            }
        });
        assert_eq!(outcome, DrainOutcome::Graceful);
    }

    /// A probe that always reports `Running` exhausts the 8 ms timeout.
    #[test]
    fn poll_falls_through_to_kill_when_agent_never_stops() {
        let outcome = poll_for_stopped(Duration::from_millis(8), Duration::from_millis(2), || {
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
    }

    /// With a zero timeout the state is still observed exactly once before giving up.
    #[test]
    fn poll_zero_timeout_only_checks_once_then_kills() {
        let mut calls: u32 = 0;
        let outcome = poll_for_stopped(Duration::from_millis(0), Duration::from_millis(1), || {
            calls += 1;
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(calls, 1, "single state observation before timeout");
    }

    /// Test double for `Supervisor` that records every trait call it receives.
    #[derive(Default)]
    struct MockSupervisor {
        /// Names of the invoked trait methods, in call order.
        calls: RefCell<Vec<&'static str>>,
        /// `state()` reports `Stopped` from its `stop_after`-th call onward;
        /// 0 means it always reports `Running`.
        stop_after: u32,
        /// Number of `state()` calls made so far.
        state_calls: RefCell<u32>,
        /// Value returned from `drain_poll_interval()` (zero under `Default`).
        poll_interval: Duration,
    }
    impl MockSupervisor {
        /// Appends `op` to the call log.
        fn record(&self, op: &'static str) {
            self.calls.borrow_mut().push(op);
        }
    }
    impl Supervisor for MockSupervisor {
        fn up(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("up");
            Ok(())
        }
        fn down(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("down");
            Ok(())
        }
        fn state(&self, _spec: &AgentSpec) -> Result<AgentState> {
            self.record("state");
            let mut n = self.state_calls.borrow_mut();
            *n += 1;
            if self.stop_after > 0 && *n >= self.stop_after {
                Ok(AgentState::Stopped)
            } else {
                Ok(AgentState::Running)
            }
        }
        fn drain_poll_interval(&self) -> Duration {
            self.poll_interval
        }
    }
    /// Minimal spec; the mock ignores every field.
    fn fake_spec() -> AgentSpec {
        AgentSpec {
            project: "p".into(),
            agent: "a".into(),
            tmux_session: "p-a".into(),
            wrapper: PathBuf::from("/dev/null"),
            cwd: PathBuf::from("/tmp"),
            env_file: PathBuf::from("/dev/null"),
        }
    }
    /// Zero timeout: `signal_fn` runs, the state is observed once, then `down` is called.
    #[test]
    fn drain_with_zero_timeout_returns_timed_out_killed_and_calls_down() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            ..Default::default()
        };
        let spec = fake_spec();
        let signaled = RefCell::new(false);
        let outcome = orchestrate_drain(&mock, &spec, Duration::ZERO, || {
            *signaled.borrow_mut() = true;
        })
        .unwrap();
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert!(*signaled.borrow(), "signal_fn must run before the poll");
        // Exact log: no extra polls and exactly one kill.
        assert_eq!(
            mock.calls.borrow().as_slice(),
            &["state", "down"],
            "zero-timeout: one state observation then kill"
        );
    }
    /// The mock stops on its second `state()` call, so the drain must not kill it.
    #[test]
    fn drain_with_graceful_stop_does_not_call_down() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            stop_after: 2, ..Default::default()
        };
        let spec = fake_spec();
        let outcome = orchestrate_drain(&mock, &spec, Duration::from_millis(100), || {}).unwrap();
        assert_eq!(outcome, DrainOutcome::Graceful);
        assert!(
            !mock.calls.borrow().contains(&"down"),
            "graceful drain must not call down(); calls: {:?}",
            mock.calls.borrow()
        );
    }
    /// A supervisor that does not override `drain_poll_interval` gets the trait default.
    #[test]
    fn drain_poll_interval_default_is_250ms() {
        struct Default250;
        impl Supervisor for Default250 {
            fn up(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn down(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn state(&self, _: &AgentSpec) -> Result<AgentState> {
                Ok(AgentState::Stopped)
            }
        }
        assert_eq!(Default250.drain_poll_interval(), Duration::from_millis(250));
    }
    /// `orchestrate_drain` must poll at the supervisor's own interval. A 2 ms
    /// cadence over an 8 ms timeout yields several observations. The loose 60 ms
    /// bound would catch the 250 ms trait default being used instead.
    #[test]
    fn drain_poll_interval_override_is_used_by_orchestrator() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(2),
            stop_after: 0,
            ..Default::default()
        };
        let spec = fake_spec();
        let start = Instant::now();
        let _ = orchestrate_drain(&mock, &spec, Duration::from_millis(8), || {});
        let elapsed = start.elapsed();
        let states = mock
            .calls
            .borrow()
            .iter()
            .filter(|c| **c == "state")
            .count();
        assert!(
            states >= 2,
            "expected several state observations at 2ms cadence, got {states}"
        );
        assert!(
            elapsed < Duration::from_millis(60),
            "drain with 2ms interval finished too slowly ({elapsed:?})"
        );
    }
}
/// Minimal POSIX `sh` quoting helper (a tiny stand-in for the `shlex` crate's `try_quote`).
mod shlex {
    /// Quotes `s` so that `sh` reads it as exactly one literal word.
    ///
    /// The whole string is wrapped in single quotes. Nothing can be escaped
    /// inside a single-quoted `sh` string, so each embedded `'` is emitted as
    /// `'\''`: close the quote, add an escaped quote, reopen the quote.
    ///
    /// # Errors
    /// Fails if `s` contains a NUL byte, which no shell argument can carry.
    pub fn try_quote(s: &str) -> anyhow::Result<String> {
        if s.contains('\0') {
            anyhow::bail!("null byte in shell arg");
        }
        let mut quoted = String::with_capacity(s.len() + 2);
        quoted.push('\'');
        for ch in s.chars() {
            if ch == '\'' {
                quoted.push_str(r"'\''");
            } else {
                quoted.push(ch);
            }
        }
        quoted.push('\'');
        Ok(quoted)
    }

    #[cfg(test)]
    mod tests {
        use super::try_quote;

        #[test]
        fn quotes_plain_path() {
            let quoted = try_quote("/a/b.sh").unwrap();
            assert_eq!(quoted, "'/a/b.sh'");
        }

        #[test]
        fn escapes_embedded_single_quote() {
            let quoted = try_quote("x'y").unwrap();
            assert_eq!(quoted, r"'x'\''y'");
        }

        #[test]
        fn quotes_empty_string() {
            assert_eq!(try_quote("").unwrap(), "''");
        }

        #[test]
        fn rejects_nul_byte() {
            assert!(try_quote("a\0b").is_err());
        }
    }
}