Skip to main content

tam_daemon/
agent.rs

1use std::os::fd::{AsRawFd, OwnedFd};
2use std::os::unix::process::CommandExt;
3use std::path::{Path, PathBuf};
4use std::process::{Child, Stdio};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::thread::JoinHandle;
8use std::time::Instant;
9
10use anyhow::{Context, Result};
11use nix::pty::openpty;
12use nix::sys::signal::{kill, Signal};
13use nix::unistd::Pid;
14use tam_proto::{AgentInfo, AgentState};
15use tokio::sync::broadcast;
16use tracing::error;
17
18use crate::provider::Provider;
19use crate::scrollback::ScrollbackBuffer;
20
/// Lightweight metadata for context refresh — collected under lock, IO done outside.
///
/// Every field is plain owned data, so a job can be cloned, compared and
/// debug-printed freely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextRefreshJob {
    /// Identifier of the agent this job refers to.
    pub id: String,
    /// Process id of the agent's child process.
    pub pid: u32,
    /// Directory associated with the agent.
    pub dir: PathBuf,
    /// Name of the provider that runs the agent.
    pub provider: String,
}
28
/// A spawned agent process attached to a PTY, plus the bookkeeping kept for it:
/// scrollback, live-output broadcast, viewer count, state tracking and cached
/// context usage.
pub struct Agent {
    /// Provider that built the command line; also consulted for output-based
    /// state detection and hook-event mapping.
    provider: Arc<dyn Provider>,
    /// Directory passed to `spawn` and handed to the provider's command builder.
    dir: PathBuf,
    /// Stored state — used as fallback when provider doesn't do output detection (e.g. Claude with hooks).
    state: AgentState,
    /// Handle to the spawned child process.
    child: Child,
    /// Master side of the PTY; the `Arc` is also held by the reader thread.
    pty_master: Arc<OwnedFd>,
    /// Buffer of PTY output, written by the reader thread.
    scrollback: Arc<Mutex<ScrollbackBuffer>>,
    /// Broadcast channel on which every PTY output chunk is sent.
    output_tx: broadcast::Sender<Vec<u8>>,
    /// Viewer counter; handed out via `viewers()` so attach sessions can
    /// increment/decrement it.
    viewers: Arc<AtomicUsize>,
    /// When the agent was spawned; used to compute uptime.
    started_at: Instant,
    /// Updated by the PTY reader thread on every output chunk.
    last_output_at: Arc<Mutex<Instant>>,
    /// Last state reported via events. Used to detect transitions.
    reported_state: AgentState,
    /// Cached context usage percentage, updated periodically.
    context_percent: Option<u8>,
    /// Join handle of the PTY reader thread; stored but never joined in this file.
    _reader_handle: JoinHandle<()>,
}
48
49impl Agent {
50    /// Spawn a new agent process attached to a PTY.
51    /// `env_vars` are set on the child process (e.g. TAM_AGENT_ID, TAM_SOCKET).
52    pub fn spawn(
53        provider: Arc<dyn Provider>,
54        dir: &Path,
55        args: &[String],
56        resume_session: Option<&str>,
57        prompt: Option<&str>,
58        env_vars: &[(&str, &str)],
59    ) -> Result<Self> {
60        // Verify directory exists
61        anyhow::ensure!(dir.is_dir(), "directory does not exist: {}", dir.display());
62
63        let mut cmd = provider.build_command(dir, args, resume_session, prompt);
64        for (key, val) in env_vars {
65            cmd.env(key, val);
66        }
67
68        // Create PTY pair
69        let pty = openpty(None, None).context("failed to create PTY")?;
70        let master = pty.master;
71        let slave = pty.slave;
72
73        // Grab raw fd before slave is consumed (valid in child after fork)
74        let slave_raw_fd = slave.as_raw_fd();
75
76        // Create stdio from slave PTY
77        let stdin_fd = slave.try_clone().context("failed to clone slave fd")?;
78        let stdout_fd = slave.try_clone().context("failed to clone slave fd")?;
79        let stderr_fd = slave; // consumes original
80
81        let child = unsafe {
82            cmd.stdin(Stdio::from(stdin_fd))
83                .stdout(Stdio::from(stdout_fd))
84                .stderr(Stdio::from(stderr_fd))
85                .pre_exec(move || {
86                    // Create new session so the agent is detached from our terminal
87                    nix::unistd::setsid()
88                        .map_err(|e| std::io::Error::from_raw_os_error(e as i32))?;
89                    // Set the slave PTY as the controlling terminal
90                    if libc::ioctl(slave_raw_fd, libc::TIOCSCTTY as libc::c_ulong, 0) < 0 {
91                        return Err(std::io::Error::last_os_error());
92                    }
93                    Ok(())
94                })
95                .spawn()
96                .with_context(|| {
97                    format!("failed to spawn '{}' in {}", provider.name(), dir.display())
98                })?
99        };
100
101        let master = Arc::new(master);
102        let scrollback = Arc::new(Mutex::new(ScrollbackBuffer::default()));
103        let (output_tx, _) = broadcast::channel(64);
104        let last_output_at = Arc::new(Mutex::new(Instant::now()));
105
106        // Spawn a reader thread that drains PTY output into scrollback + broadcast
107        let reader_handle = {
108            let master = master.clone();
109            let scrollback = scrollback.clone();
110            let output_tx = output_tx.clone();
111            let last_output_at = last_output_at.clone();
112            std::thread::Builder::new()
113                .name("pty-reader".to_string())
114                .spawn(move || {
115                    pty_reader_loop(master.as_raw_fd(), scrollback, output_tx, last_output_at);
116                })
117                .context("failed to spawn PTY reader thread")?
118        };
119
120        Ok(Self {
121            provider,
122            dir: dir.to_path_buf(),
123            state: AgentState::Working,
124            child,
125            pty_master: master,
126            scrollback,
127            output_tx,
128            viewers: Arc::new(AtomicUsize::new(0)),
129            started_at: Instant::now(),
130            last_output_at,
131            reported_state: AgentState::Working,
132            context_percent: None,
133            _reader_handle: reader_handle,
134        })
135    }
136
137    /// Check if the child process has exited. Returns Some(exit_code) if so.
138    /// Agents that exit are cleaned up by the daemon — exit is an event, not a state.
139    pub fn check_exited(&mut self) -> Option<i32> {
140        match self.child.try_wait() {
141            Ok(Some(status)) => Some(status.code().unwrap_or(-1)),
142            Ok(None) => None,
143            Err(_) => Some(-1),
144        }
145    }
146
147    /// Send SIGTERM, wait briefly, then SIGKILL if needed.
148    pub fn kill(&mut self) -> Result<()> {
149        let pid = Pid::from_raw(self.child.id() as i32);
150
151        // Try graceful shutdown first
152        let _ = kill(pid, Signal::SIGTERM);
153
154        // Wait up to 200ms for graceful exit
155        for _ in 0..20 {
156            std::thread::sleep(std::time::Duration::from_millis(10));
157            if matches!(self.child.try_wait(), Ok(Some(_))) {
158                return Ok(());
159            }
160        }
161
162        // Force kill
163        let _ = kill(pid, Signal::SIGKILL);
164
165        // Wait up to 2s for forced exit (handles slow process cleanup)
166        for _ in 0..200 {
167            std::thread::sleep(std::time::Duration::from_millis(10));
168            if matches!(self.child.try_wait(), Ok(Some(_))) {
169                return Ok(());
170            }
171        }
172
173        error!(pid = %self.child.id(), "process did not exit after SIGKILL, abandoning");
174        Ok(())
175    }
176
177    /// Kill the agent and drop it. Intended for use in background tasks
178    /// where ownership is transferred (e.g. `spawn_blocking`).
179    pub fn kill_and_drop(mut self) {
180        let _ = self.kill();
181    }
182
183    /// Set the stored state directly (used by hook-based providers).
184    pub fn set_state(&mut self, state: AgentState) {
185        self.state = state;
186    }
187
188    /// Map a hook event to a state via the provider, and update stored state.
189    /// Returns the new state if the event was recognized, None otherwise.
190    pub fn handle_hook_event(&mut self, event: &str) -> Option<AgentState> {
191        let new_state = self.provider.map_hook_event(event)?;
192        self.set_state(new_state);
193        Some(new_state)
194    }
195
196    /// Check for a state transition. Returns Some((old, new)) if state changed.
197    /// Updates reported_state so the same transition isn't reported twice.
198    pub fn check_state_change(&mut self) -> Option<(AgentState, AgentState)> {
199        let current = self.current_state();
200        if current != self.reported_state {
201            let old = self.reported_state;
202            self.reported_state = current;
203            Some((old, current))
204        } else {
205            None
206        }
207    }
208
209    /// Compute current state: ask the provider first (output heuristic),
210    /// fall back to the stored state (set by hooks or default).
211    pub fn current_state(&self) -> AgentState {
212        let idle_duration = self.last_output_at.lock().unwrap().elapsed();
213        self.provider
214            .detect_state_from_output(&[], idle_duration)
215            .unwrap_or(self.state)
216    }
217
218    /// Build an AgentInfo snapshot for reporting to clients.
219    pub fn info(&self, id: &str) -> AgentInfo {
220        AgentInfo {
221            id: id.to_string(),
222            provider: self.provider.name().to_string(),
223            dir: self.dir.clone(),
224            state: self.current_state(),
225            pid: Some(self.child.id()),
226            uptime_secs: self.started_at.elapsed().as_secs(),
227            viewers: self.viewers.load(Ordering::Relaxed),
228            context_percent: self.context_percent,
229            task: Some(id.to_string()),
230        }
231    }
232
233    /// Collect lightweight metadata for two-phase context refresh.
234    /// This is cheap (no IO) and can be called under the lock.
235    pub fn context_refresh_job(&self, id: &str) -> ContextRefreshJob {
236        ContextRefreshJob {
237            id: id.to_string(),
238            pid: self.child.id(),
239            dir: self.dir.clone(),
240            provider: self.provider.name().to_string(),
241        }
242    }
243
244    /// Set context percent. Returns true if the value changed.
245    pub fn set_context_percent(&mut self, pct: Option<u8>) -> bool {
246        let changed = self.context_percent != pct;
247        self.context_percent = pct;
248        changed
249    }
250
251    pub fn context_percent(&self) -> Option<u8> {
252        self.context_percent
253    }
254
255    /// Get the viewer count handle for increment/decrement by attach sessions.
256    pub fn viewers(&self) -> Arc<AtomicUsize> {
257        self.viewers.clone()
258    }
259
260    /// Subscribe to live PTY output. Returns a broadcast receiver.
261    pub fn subscribe(&self) -> broadcast::Receiver<Vec<u8>> {
262        self.output_tx.subscribe()
263    }
264
265    /// Get a copy of the current scrollback buffer contents.
266    pub fn scrollback_contents(&self) -> Vec<u8> {
267        self.scrollback.lock().unwrap().to_vec()
268    }
269
270    /// Get a clone of the PTY master fd (kept alive by Arc).
271    pub fn pty_master(&self) -> Arc<OwnedFd> {
272        self.pty_master.clone()
273    }
274
275    /// Resize the agent's PTY and notify the agent process.
276    pub fn resize(&self, cols: u16, rows: u16) {
277        let ws = libc::winsize {
278            ws_col: cols,
279            ws_row: rows,
280            ws_xpixel: 0,
281            ws_ypixel: 0,
282        };
283        unsafe {
284            libc::ioctl(
285                self.pty_master.as_raw_fd(),
286                libc::TIOCSWINSZ as libc::c_ulong,
287                &ws,
288            );
289        }
290        // Notify the agent process of the resize
291        let _ = kill(Pid::from_raw(self.child.id() as i32), Signal::SIGWINCH);
292    }
293}
294
295/// Blocking loop that reads PTY master output, stores it in the scrollback buffer,
296/// broadcasts it to any attached clients, and tracks when output last arrived.
297/// Exits when the PTY slave side is closed (agent exits).
298fn pty_reader_loop(
299    master_fd: i32,
300    scrollback: Arc<Mutex<ScrollbackBuffer>>,
301    output_tx: broadcast::Sender<Vec<u8>>,
302    last_output_at: Arc<Mutex<Instant>>,
303) {
304    let mut buf = [0u8; 4096];
305    loop {
306        match nix::unistd::read(master_fd, &mut buf) {
307            Ok(0) => break,
308            Ok(n) => {
309                let data = buf[..n].to_vec();
310                if let Ok(mut sb) = scrollback.lock() {
311                    sb.write(&data);
312                }
313                if let Ok(mut ts) = last_output_at.lock() {
314                    *ts = Instant::now();
315                }
316                // Ignore send errors (no receivers is fine)
317                let _ = output_tx.send(data);
318            }
319            Err(nix::errno::Errno::EIO) => break, // PTY closed
320            Err(nix::errno::Errno::EINTR) => continue,
321            Err(e) => {
322                error!("PTY read error: {}", e);
323                break;
324            }
325        }
326    }
327}