1use std::collections::VecDeque;
8use std::fs;
9use std::io::{Read, Write as IoWrite};
10use std::path::{Path, PathBuf};
11use std::process::Command as ProcessCommand;
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15use std::time::Duration;
16use std::time::Instant;
17
18use anyhow::{Context, Result};
19use portable_pty::{Child, CommandBuilder, PtySize};
20
21use super::classifier::{self, AgentType, ScreenVerdict};
22use super::common::{self, QueuedMessage};
23use super::protocol::{Channel, Command, Event, ShimState};
24use super::pty_log::PtyLogWriter;
25use crate::prompt::strip_ansi;
26
/// Fallback PTY height used when `ShimArgs::rows` is 0.
const DEFAULT_ROWS: u16 = 50;
/// Fallback PTY width used when `ShimArgs::cols` is 0.
const DEFAULT_COLS: u16 = 220;
/// Scrollback length handed to the vt100 parser.
const SCROLLBACK_LINES: usize = 5000;

/// Sleep interval (ms) for the startup wait loop and the Kiro idle watcher.
const POLL_INTERVAL_MS: u64 = 250;

/// Minimum time (ms) in `Working` before the PTY reader accepts an
/// idle-looking screen as the end of the work.
const WORKING_DWELL_MS: u64 = 300;

/// Quiet period (ms) without PTY output before a Kiro agent whose screen
/// looks idle is settled to `Idle` by the idle watcher.
const KIRO_IDLE_SETTLE_MS: u64 = 1200;

/// Seconds the shim waits for the agent to reach `Idle` during startup.
const READY_TIMEOUT_SECS: u64 = 120;
use common::MAX_QUEUE_DEPTH;
use common::SESSION_STATS_INTERVAL_SECS;

/// Poll interval (ms) while waiting for the agent process to exit.
const PROCESS_EXIT_POLL_MS: u64 = 100;
/// Interval (s) at which the supervisor sidecar checks that the shim is alive.
const PARENT_DEATH_POLL_SECS: u64 = 1;
/// Grace period (s) between SIGTERM and SIGKILL when tearing down the agent.
const GROUP_TERM_GRACE_SECS: u64 = 2;
/// File name of the handoff summary written into the worktree.
pub(crate) const HANDOFF_FILE_NAME: &str = "handoff.md";
/// Commit message used for auto-save commits before a restart/kill.
const AUTO_COMMIT_MESSAGE: &str = "wip: auto-save before restart [batty]";
/// Default for `BATTY_GRACEFUL_SHUTDOWN_TIMEOUT_SECS` (seconds).
const AUTO_COMMIT_TIMEOUT_SECS: u64 = 5;
59
60pub(crate) fn preserve_handoff(worktree: &Path, recent_output: Option<&str>) -> Result<()> {
64 let diff_stat = git_capture(worktree, &["diff", "--stat"]).unwrap_or_default();
65 let recent_commits = git_capture(worktree, &["log", "--oneline", "-5"]).unwrap_or_default();
66 let tests_run = recent_output
67 .map(extract_test_commands)
68 .unwrap_or_default()
69 .join("\n");
70 let recent_activity = recent_output
71 .map(summarize_recent_activity)
72 .unwrap_or_default();
73
74 let handoff = format!(
75 "# Handoff\n## Modified Files\n{}\n\n## Tests Run\n{}\n\n## Recent Activity\n{}\n\n## Recent Commits\n{}\n",
76 empty_section_fallback(&diff_stat),
77 empty_section_fallback(&tests_run),
78 empty_section_fallback(&recent_activity),
79 empty_section_fallback(&recent_commits)
80 );
81 fs::write(worktree.join(HANDOFF_FILE_NAME), handoff)
82 .with_context(|| format!("failed to write handoff file in {}", worktree.display()))?;
83 Ok(())
84}
85
86fn git_capture(worktree: &Path, args: &[&str]) -> Result<String> {
87 let output = ProcessCommand::new("git")
88 .args(args)
89 .current_dir(worktree)
90 .env_remove("GIT_DIR")
91 .env_remove("GIT_WORK_TREE")
92 .output()
93 .with_context(|| {
94 format!(
95 "failed to run `git {}` in {}",
96 args.join(" "),
97 worktree.display()
98 )
99 })?;
100 if !output.status.success() {
101 anyhow::bail!(
102 "`git {}` failed in {}: {}",
103 args.join(" "),
104 worktree.display(),
105 String::from_utf8_lossy(&output.stderr).trim()
106 );
107 }
108 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
109}
110
/// Returns `content` unchanged, or the placeholder `"(none)"` when it is blank
/// (empty or whitespace-only). This keeps handoff sections from being empty.
fn empty_section_fallback(content: &str) -> &str {
    match content.trim() {
        "" => "(none)",
        _ => content,
    }
}
118
119fn summarize_recent_activity(output: &str) -> String {
120 let cleaned = strip_ansi(output);
121 let lines: Vec<&str> = cleaned
122 .lines()
123 .map(str::trim_end)
124 .filter(|line| !line.trim().is_empty())
125 .collect();
126 let start = lines.len().saturating_sub(40);
127 lines[start..].join("\n")
128}
129
130fn extract_test_commands(output: &str) -> Vec<String> {
131 let cleaned = strip_ansi(output);
132 let mut commands = Vec::new();
133
134 for line in cleaned.lines() {
135 let trimmed = line.trim();
136 if trimmed.is_empty() {
137 continue;
138 }
139 let lower = trimmed.to_ascii_lowercase();
140 if lower.contains("cargo test")
141 || lower.contains("cargo nextest")
142 || lower.contains("pytest")
143 || lower.contains("npm test")
144 || lower.contains("pnpm test")
145 || lower.contains("yarn test")
146 || lower.contains("go test")
147 || lower.contains("bundle exec rspec")
148 || lower.contains("mix test")
149 {
150 if !commands.iter().any(|existing| existing == trimmed) {
151 commands.push(trimmed.to_string());
152 }
153 }
154 }
155
156 commands
157}
158
159fn format_injected_message(sender: &str, body: &str) -> String {
160 common::format_injected_message(sender, body)
161}
162
/// Escapes `input` so it can sit inside a single-quoted POSIX shell string.
///
/// Each `'` becomes `'\''`, which closes the quote, emits an escaped quote, and
/// reopens the quote. The surrounding quotes are *not* added; the caller
/// writes them.
fn shell_single_quote(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        if ch == '\'' {
            escaped.push_str("'\\''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
166
/// Builds the script passed to `bash -lc` that starts the agent under a
/// watchdog.
///
/// The script does three things:
/// - stores its own PID (`$$`) as both the agent root PID and the agent
///   process-group id;
/// - launches a detached `setsid sh -c` sidecar in the background, with all
///   stdio on `/dev/null`;
/// - finally `exec`s `bash -lc '<command>'`, so the agent keeps the wrapper's
///   PID.
///
/// NOTE(review): using `$$` as the pgid assumes the wrapper shell leads its
/// own process group (portable_pty presumably spawns it in a new session) —
/// confirm.
///
/// The sidecar polls `kill -0 <shim_pid>` every `PARENT_DEATH_POLL_SECS`.
/// Once the shim is gone it:
/// 1. snapshots the agent's descendants (recursive `pgrep -P`);
/// 2. sends SIGTERM to the agent's process group and to each snapshotted PID;
/// 3. sleeps `GROUP_TERM_GRACE_SECS`;
/// 4. repeats step 2 with SIGKILL.
///
/// Signalling the descendants individually presumably also reaches processes
/// that left the agent's process group.
///
/// `command` is escaped with `shell_single_quote` because it is spliced
/// between single quotes.
fn build_supervised_agent_command(command: &str, shim_pid: u32) -> String {
    let escaped_command = shell_single_quote(command);
    // The `\` line continuations join the script into a single line whose
    // commands are separated by `; `.
    format!(
        "shim_pid={shim_pid}; \
         agent_root_pid=$$; \
         agent_pgid=$$; \
         setsid sh -c ' \
         shim_pid=\"$1\"; \
         agent_pgid=\"$2\"; \
         agent_root_pid=\"$3\"; \
         collect_descendants() {{ \
         parent_pid=\"$1\"; \
         for child_pid in $(pgrep -P \"$parent_pid\" 2>/dev/null); do \
         printf \"%s\\n\" \"$child_pid\"; \
         collect_descendants \"$child_pid\"; \
         done; \
         }}; \
         while kill -0 \"$shim_pid\" 2>/dev/null; do sleep {PARENT_DEATH_POLL_SECS}; done; \
         descendant_pids=$(collect_descendants \"$agent_root_pid\"); \
         kill -TERM -- -\"$agent_pgid\" >/dev/null 2>&1 || true; \
         for descendant_pid in $descendant_pids; do kill -TERM \"$descendant_pid\" >/dev/null 2>&1 || true; done; \
         sleep {GROUP_TERM_GRACE_SECS}; \
         kill -KILL -- -\"$agent_pgid\" >/dev/null 2>&1 || true; \
         for descendant_pid in $descendant_pids; do kill -KILL \"$descendant_pid\" >/dev/null 2>&1 || true; done \
         ' _ \"$shim_pid\" \"$agent_pgid\" \"$agent_root_pid\" >/dev/null 2>&1 < /dev/null & \
         exec bash -lc '{escaped_command}'"
    )
}
195
196#[cfg(unix)]
197fn signal_process_group(child: &dyn Child, signal: libc::c_int) -> std::io::Result<()> {
198 let pid = child
199 .process_id()
200 .ok_or_else(|| std::io::Error::other("child process id unavailable"))?;
201 let result = unsafe { libc::killpg(pid as libc::pid_t, signal) };
202 if result == 0 {
203 Ok(())
204 } else {
205 Err(std::io::Error::last_os_error())
206 }
207}
208
209fn terminate_agent_group(
210 child: &mut Box<dyn Child + Send + Sync>,
211 sigterm_grace: Duration,
212) -> std::io::Result<()> {
213 #[cfg(unix)]
214 {
215 signal_process_group(child.as_ref(), libc::SIGTERM)?;
216 let deadline = Instant::now() + sigterm_grace;
217 while Instant::now() <= deadline {
218 if let Ok(Some(_)) = child.try_wait() {
219 return Ok(());
220 }
221 thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS));
222 }
223
224 signal_process_group(child.as_ref(), libc::SIGKILL)?;
225 return Ok(());
226 }
227
228 #[allow(unreachable_code)]
229 child.kill()
230}
231
232fn graceful_shutdown_timeout() -> Duration {
233 let secs = std::env::var("BATTY_GRACEFUL_SHUTDOWN_TIMEOUT_SECS")
234 .ok()
235 .and_then(|value| value.parse::<u64>().ok())
236 .unwrap_or(AUTO_COMMIT_TIMEOUT_SECS);
237 Duration::from_secs(secs)
238}
239
/// Whether uncommitted work should be auto-committed before a restart/kill.
///
/// Reads `BATTY_AUTO_COMMIT_ON_RESTART`. Auto-commit stays enabled unless the
/// variable holds a falsy value: `0`, `false`, `no` or `off`. The comparison
/// is ASCII case-insensitive and ignores surrounding whitespace. Previously
/// only `0`/`false`/`FALSE` were recognised, so `False` or `off` silently left
/// auto-commit on. An unset or non-UTF-8 value keeps it enabled.
fn auto_commit_on_restart_enabled() -> bool {
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];
    match std::env::var("BATTY_AUTO_COMMIT_ON_RESTART") {
        Ok(value) => {
            let value = value.trim();
            !FALSY.iter().any(|falsy| value.eq_ignore_ascii_case(falsy))
        }
        Err(_) => true,
    }
}
245
246fn preserve_work_before_kill_with<F>(
247 worktree_path: &Path,
248 timeout: Duration,
249 enabled: bool,
250 commit_fn: F,
251) -> Result<bool>
252where
253 F: FnOnce(PathBuf) -> Result<bool> + Send + 'static,
254{
255 if !enabled {
256 return Ok(false);
257 }
258
259 let (tx, rx) = mpsc::channel();
260 let path = worktree_path.to_path_buf();
261 thread::spawn(move || {
262 let _ = tx.send(commit_fn(path));
263 });
264
265 match rx.recv_timeout(timeout) {
266 Ok(result) => result,
267 Err(mpsc::RecvTimeoutError::Timeout) => Ok(false),
268 Err(mpsc::RecvTimeoutError::Disconnected) => Ok(false),
269 }
270}
271
272pub(crate) fn preserve_work_before_kill(worktree_path: &Path) -> Result<bool> {
273 let timeout = graceful_shutdown_timeout();
274 preserve_work_before_kill_with(
275 worktree_path,
276 timeout,
277 auto_commit_on_restart_enabled(),
278 move |path| {
279 crate::team::git_cmd::auto_commit_if_dirty(&path, AUTO_COMMIT_MESSAGE, timeout)
280 .map_err(anyhow::Error::from)
281 },
282 )
283}
284
285fn pty_write_paced(
289 pty_writer: &Arc<Mutex<Box<dyn std::io::Write + Send>>>,
290 agent_type: AgentType,
291 body: &[u8],
292 enter: &[u8],
293) -> std::io::Result<()> {
294 match agent_type {
302 AgentType::Generic => {
303 let mut writer = pty_writer.lock().unwrap();
305 writer.write_all(body)?;
306 writer.write_all(enter)?;
307 writer.flush()?;
308 }
309 _ => {
310 let mut writer = pty_writer.lock().unwrap();
312 writer.write_all(b"\x1b[200~")?;
313 writer.write_all(body)?;
314 writer.write_all(b"\x1b[201~")?;
315 writer.flush()?;
316 drop(writer);
317
318 std::thread::sleep(std::time::Duration::from_millis(200));
320
321 let mut writer = pty_writer.lock().unwrap();
322 writer.write_all(enter)?;
323 writer.flush()?;
324 }
325 }
326 Ok(())
327}
328
329fn enter_seq(agent_type: AgentType) -> &'static str {
333 match agent_type {
334 AgentType::Generic => "\n",
335 _ => "\r", }
337}
338
/// Configuration for one shim process, consumed by `run`.
#[derive(Debug, Clone)]
pub struct ShimArgs {
    /// Identifier used to tag this shim's stderr log lines (`[shim <id>]`).
    pub id: String,
    /// Which agent CLI is being driven; selects classifier rules and input handling.
    pub agent_type: AgentType,
    /// Shell command that launches the agent. It is wrapped by
    /// `build_supervised_agent_command` and run via `bash -lc`.
    pub cmd: String,
    /// Working directory of the agent; also the worktree auto-saved before kill/shutdown.
    pub cwd: PathBuf,
    /// PTY height; 0 selects `DEFAULT_ROWS`.
    pub rows: u16,
    /// PTY width; 0 selects `DEFAULT_COLS`.
    pub cols: u16,
    /// Optional file that receives the raw PTY output.
    pub pty_log_path: Option<PathBuf>,
    /// Timeout (seconds) applied to each git step of the pre-kill auto-save.
    pub graceful_shutdown_timeout_secs: u64,
    /// Whether uncommitted work is auto-committed before kill/shutdown.
    pub auto_commit_on_restart: bool,
}
357
358impl ShimArgs {
359 fn preserve_work_before_kill(&self, worktree_path: &Path) -> Result<bool> {
360 if !self.auto_commit_on_restart {
361 return Ok(false);
362 }
363
364 let status = ProcessCommand::new("git")
365 .arg("-C")
366 .arg(worktree_path)
367 .args(["status", "--porcelain"])
368 .output()
369 .with_context(|| {
370 format!(
371 "failed to inspect git status in {}",
372 worktree_path.display()
373 )
374 })?;
375 if !status.status.success() {
376 anyhow::bail!("git status failed in {}", worktree_path.display());
377 }
378
379 let dirty = String::from_utf8_lossy(&status.stdout)
380 .lines()
381 .any(|line| !line.starts_with("?? .batty/"));
382 if !dirty {
383 return Ok(false);
384 }
385
386 let timeout = Duration::from_secs(self.graceful_shutdown_timeout_secs);
387 run_git_preserve_with_timeout(worktree_path, &["add", "-A"], timeout)?;
388 run_git_preserve_with_timeout(
389 worktree_path,
390 &["commit", "-m", "wip: auto-save before restart [batty]"],
391 timeout,
392 )?;
393 Ok(true)
394 }
395}
396
397fn run_git_preserve_with_timeout(
398 worktree_path: &Path,
399 args: &[&str],
400 timeout: Duration,
401) -> Result<()> {
402 let mut child = ProcessCommand::new("git")
403 .arg("-C")
404 .arg(worktree_path)
405 .args(args)
406 .spawn()
407 .with_context(|| {
408 format!(
409 "failed to launch `git {}` in {}",
410 args.join(" "),
411 worktree_path.display()
412 )
413 })?;
414 let deadline = Instant::now() + timeout;
415 loop {
416 if let Some(status) = child.try_wait()? {
417 if status.success() {
418 return Ok(());
419 }
420 anyhow::bail!(
421 "`git {}` failed in {} with status {}",
422 args.join(" "),
423 worktree_path.display(),
424 status
425 );
426 }
427
428 if Instant::now() >= deadline {
429 let _ = child.kill();
430 let _ = child.wait();
431 anyhow::bail!(
432 "`git {}` timed out after {}s in {}",
433 args.join(" "),
434 timeout.as_secs(),
435 worktree_path.display()
436 );
437 }
438
439 thread::sleep(Duration::from_millis(50));
440 }
441}
442
/// Mutable shim state shared (as `Arc<Mutex<ShimInner>>`) between the PTY
/// reader, the background watcher threads, and the command loop in `run`.
struct ShimInner {
    /// Terminal emulator fed with raw PTY output; source of all screen snapshots.
    parser: vt100::Parser,
    /// Current lifecycle state reported to the orchestrator.
    state: ShimState,
    /// When `state` last changed.
    state_changed_at: Instant,
    /// `content_hash` of the last classified screen; unchanged screens are skipped.
    last_screen_hash: u64,
    /// Time of the most recent PTY read.
    last_pty_output_at: Instant,
    /// When the shim started; reported as uptime in session stats.
    started_at: Instant,
    /// Total bytes read from the PTY.
    cumulative_output_bytes: u64,
    /// Screen captured just before a message was injected; baseline for
    /// extracting the agent's response.
    pre_injection_content: String,
    /// Id of the message currently being worked on, reported in its `Completion`.
    pending_message_id: Option<String>,
    /// Agent flavour; selects classifier rules, Enter sequence, and paste handling.
    agent_type: AgentType,
    /// Messages received while `Working`; popped front-first on Working→Idle transitions.
    message_queue: VecDeque<QueuedMessage>,
    /// Number of startup dialogs auto-dismissed so far.
    dialogs_dismissed: u8,
    /// Most recent screen observed while `Working`; fallback source for the response.
    last_working_screen: String,
}
470
471impl ShimInner {
472 fn screen_contents(&self) -> String {
473 self.parser.screen().contents()
474 }
475
476 fn last_n_lines(&self, n: usize) -> String {
477 let content = self.parser.screen().contents();
478 let lines: Vec<&str> = content.lines().collect();
479 let start = lines.len().saturating_sub(n);
480 lines[start..].join("\n")
481 }
482
483 fn cursor_position(&self) -> (u16, u16) {
484 self.parser.screen().cursor_position()
485 }
486}
487
/// 64-bit FNV-1a hash of `s`, used to cheaply detect whether the rendered
/// screen changed between PTY reads.
fn content_hash(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
500
501pub fn run(args: ShimArgs, channel: Channel) -> Result<()> {
509 let rows = if args.rows > 0 {
510 args.rows
511 } else {
512 DEFAULT_ROWS
513 };
514 let cols = if args.cols > 0 {
515 args.cols
516 } else {
517 DEFAULT_COLS
518 };
519
520 let pty_system = portable_pty::native_pty_system();
522 let pty_pair = pty_system
523 .openpty(PtySize {
524 rows,
525 cols,
526 pixel_width: 0,
527 pixel_height: 0,
528 })
529 .context("failed to create PTY")?;
530
531 let shim_pid = std::process::id();
533 let supervised_cmd = build_supervised_agent_command(&args.cmd, shim_pid);
534
535 let mut cmd = CommandBuilder::new("bash");
536 cmd.args(["-lc", &supervised_cmd]);
537 cmd.cwd(&args.cwd);
538 cmd.env_remove("CLAUDECODE"); cmd.env("TERM", "xterm-256color");
540 cmd.env("COLORTERM", "truecolor");
541
542 let mut child = pty_pair
543 .slave
544 .spawn_command(cmd)
545 .context("failed to spawn agent CLI")?;
546
547 drop(pty_pair.slave);
549
550 let mut pty_reader = pty_pair
551 .master
552 .try_clone_reader()
553 .context("failed to clone PTY reader")?;
554
555 let pty_writer = pty_pair
556 .master
557 .take_writer()
558 .context("failed to take PTY writer")?;
559
560 let inner = Arc::new(Mutex::new(ShimInner {
562 parser: vt100::Parser::new(rows, cols, SCROLLBACK_LINES),
563 state: ShimState::Starting,
564 state_changed_at: Instant::now(),
565 last_screen_hash: 0,
566 last_pty_output_at: Instant::now(),
567 started_at: Instant::now(),
568 cumulative_output_bytes: 0,
569 pre_injection_content: String::new(),
570 pending_message_id: None,
571 agent_type: args.agent_type,
572 message_queue: VecDeque::new(),
573 dialogs_dismissed: 0,
574 last_working_screen: String::new(),
575 }));
576
577 let pty_log: Option<Mutex<PtyLogWriter>> = args
579 .pty_log_path
580 .as_deref()
581 .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
582 .transpose()?
583 .map(Mutex::new);
584 let pty_log = pty_log.map(Arc::new);
585
586 let pty_writer = Arc::new(Mutex::new(pty_writer));
588
589 let mut cmd_channel = channel;
591 let mut evt_channel = cmd_channel.try_clone().context("failed to clone channel")?;
592
593 let inner_pty = Arc::clone(&inner);
595 let log_handle = pty_log.clone();
596 let pty_writer_pty = Arc::clone(&pty_writer);
597 let pty_handle = std::thread::spawn(move || {
598 let mut buf = [0u8; 4096];
599 loop {
600 match pty_reader.read(&mut buf) {
601 Ok(0) => break, Ok(n) => {
603 if let Some(ref log) = log_handle {
605 let _ = log.lock().unwrap().write(&buf[..n]);
606 }
607
608 let mut inner = inner_pty.lock().unwrap();
609 inner.last_pty_output_at = Instant::now();
610 inner.cumulative_output_bytes =
611 inner.cumulative_output_bytes.saturating_add(n as u64);
612 inner.parser.process(&buf[..n]);
613
614 let content = inner.parser.screen().contents();
621 let hash = content_hash(&content);
622 if hash == inner.last_screen_hash {
623 continue; }
625 inner.last_screen_hash = hash;
626
627 let verdict = classifier::classify(inner.agent_type, inner.parser.screen());
628 let old_state = inner.state;
629
630 if old_state == ShimState::Working {
634 inner.last_working_screen = content.clone();
635 }
636
637 let working_too_short = old_state == ShimState::Working
641 && inner.state_changed_at.elapsed().as_millis() < WORKING_DWELL_MS as u128;
642 let new_state = match (old_state, verdict) {
643 (ShimState::Starting, ScreenVerdict::AgentIdle) => Some(ShimState::Idle),
644 (ShimState::Idle, ScreenVerdict::AgentIdle) => None,
645 (ShimState::Working, ScreenVerdict::AgentIdle) if working_too_short => None,
646 (ShimState::Working, ScreenVerdict::AgentIdle)
647 if inner.agent_type == AgentType::Kiro =>
648 {
649 None
650 }
651 (ShimState::Working, ScreenVerdict::AgentIdle) => Some(ShimState::Idle),
652 (ShimState::Working, ScreenVerdict::AgentWorking) => None,
653 (_, ScreenVerdict::ContextExhausted) => Some(ShimState::ContextExhausted),
654 (_, ScreenVerdict::Unknown) => None,
655 (ShimState::Idle, ScreenVerdict::AgentWorking) => Some(ShimState::Working),
656 (ShimState::Starting, ScreenVerdict::AgentWorking) => {
657 Some(ShimState::Working)
658 }
659 _ => None,
660 };
661
662 if let Some(new) = new_state {
663 let summary = inner.last_n_lines(5);
664 inner.state = new;
665 inner.state_changed_at = Instant::now();
666
667 let pre_content = inner.pre_injection_content.clone();
668 let current_content = inner.screen_contents();
669 let working_screen = inner.last_working_screen.clone();
670 let msg_id = inner.pending_message_id.take();
671
672 let drain_errors =
674 if new == ShimState::Dead || new == ShimState::ContextExhausted {
675 drain_queue_errors(&mut inner.message_queue, new)
676 } else {
677 Vec::new()
678 };
679
680 let queued_msg = if old_state == ShimState::Working
682 && new == ShimState::Idle
683 && !inner.message_queue.is_empty()
684 {
685 inner.message_queue.pop_front()
686 } else {
687 None
688 };
689
690 if let Some(ref msg) = queued_msg {
692 inner.pre_injection_content = inner.screen_contents();
693 inner.pending_message_id = msg.message_id.clone();
694 inner.state = ShimState::Working;
695 inner.state_changed_at = Instant::now();
696 }
697
698 let queue_depth = inner.message_queue.len();
699 let agent_type_for_enter = inner.agent_type;
700 let queued_injected = queued_msg
701 .as_ref()
702 .map(|msg| format_injected_message(&msg.from, &msg.body));
703
704 drop(inner); let events = build_transition_events(
707 old_state,
708 new,
709 &summary,
710 &pre_content,
711 ¤t_content,
712 &working_screen,
713 msg_id,
714 );
715
716 for event in events {
717 if evt_channel.send(&event).is_err() {
718 return; }
720 }
721
722 for event in drain_errors {
724 if evt_channel.send(&event).is_err() {
725 return;
726 }
727 }
728
729 if let Some(msg) = queued_msg {
731 let enter = enter_seq(agent_type_for_enter);
732 let injected = queued_injected.as_deref().unwrap_or(msg.body.as_str());
733 if let Err(e) = pty_write_paced(
734 &pty_writer_pty,
735 agent_type_for_enter,
736 injected.as_bytes(),
737 enter.as_bytes(),
738 ) {
739 let _ = evt_channel.send(&Event::Error {
740 command: "SendMessage".into(),
741 reason: format!("PTY write failed for queued message: {e}"),
742 });
743 }
744
745 let _ = evt_channel.send(&Event::StateChanged {
747 from: ShimState::Idle,
748 to: ShimState::Working,
749 summary: format!(
750 "delivering queued message ({} remaining)",
751 queue_depth
752 ),
753 });
754 }
755 }
756 }
757 Err(_) => break, }
759 }
760
761 let mut inner = inner_pty.lock().unwrap();
763 let last_lines = inner.last_n_lines(10);
764 let old = inner.state;
765 inner.state = ShimState::Dead;
766
767 let drain_errors = drain_queue_errors(&mut inner.message_queue, ShimState::Dead);
769 drop(inner);
770
771 let _ = evt_channel.send(&Event::StateChanged {
772 from: old,
773 to: ShimState::Dead,
774 summary: last_lines.clone(),
775 });
776
777 let _ = evt_channel.send(&Event::Died {
778 exit_code: None,
779 last_lines,
780 });
781
782 for event in drain_errors {
783 let _ = evt_channel.send(&event);
784 }
785 });
786
787 let inner_idle = Arc::clone(&inner);
791 let pty_writer_idle = Arc::clone(&pty_writer);
792 let mut idle_channel = cmd_channel.try_clone().context("failed to clone channel")?;
793 std::thread::spawn(move || {
794 loop {
795 std::thread::sleep(std::time::Duration::from_millis(POLL_INTERVAL_MS));
796
797 let mut inner = inner_idle.lock().unwrap();
798 if inner.agent_type != AgentType::Kiro || inner.state != ShimState::Working {
799 continue;
800 }
801 if inner.last_pty_output_at.elapsed().as_millis() < KIRO_IDLE_SETTLE_MS as u128 {
802 continue;
803 }
804 if classifier::classify(inner.agent_type, inner.parser.screen())
805 != ScreenVerdict::AgentIdle
806 {
807 continue;
808 }
809
810 let summary = inner.last_n_lines(5);
811 let pre_content = inner.pre_injection_content.clone();
812 let current_content = inner.screen_contents();
813 let working_screen = inner.last_working_screen.clone();
814 let msg_id = inner.pending_message_id.take();
815
816 inner.state = ShimState::Idle;
817 inner.state_changed_at = Instant::now();
818
819 let queued_msg = if !inner.message_queue.is_empty() {
820 inner.message_queue.pop_front()
821 } else {
822 None
823 };
824
825 if let Some(ref msg) = queued_msg {
826 inner.pre_injection_content = inner.screen_contents();
827 inner.pending_message_id = msg.message_id.clone();
828 inner.state = ShimState::Working;
829 inner.state_changed_at = Instant::now();
830 }
831
832 let queue_depth = inner.message_queue.len();
833 let agent_type_for_enter = inner.agent_type;
834 let queued_injected = queued_msg
835 .as_ref()
836 .map(|msg| format_injected_message(&msg.from, &msg.body));
837 drop(inner);
838
839 for event in build_transition_events(
840 ShimState::Working,
841 ShimState::Idle,
842 &summary,
843 &pre_content,
844 ¤t_content,
845 &working_screen,
846 msg_id,
847 ) {
848 if idle_channel.send(&event).is_err() {
849 return;
850 }
851 }
852
853 if let Some(msg) = queued_msg {
854 let enter = enter_seq(agent_type_for_enter);
855 let injected = queued_injected.as_deref().unwrap_or(msg.body.as_str());
856 if let Err(e) = pty_write_paced(
857 &pty_writer_idle,
858 agent_type_for_enter,
859 injected.as_bytes(),
860 enter.as_bytes(),
861 ) {
862 let _ = idle_channel.send(&Event::Error {
863 command: "SendMessage".into(),
864 reason: format!("PTY write failed for queued message: {e}"),
865 });
866 continue;
867 }
868
869 let _ = idle_channel.send(&Event::StateChanged {
870 from: ShimState::Idle,
871 to: ShimState::Working,
872 summary: format!("delivering queued message ({} remaining)", queue_depth),
873 });
874 }
875 }
876 });
877
878 let inner_poll = Arc::clone(&inner);
884 let mut poll_channel = cmd_channel
885 .try_clone()
886 .context("failed to clone channel for poll thread")?;
887 std::thread::spawn(move || {
888 loop {
889 std::thread::sleep(std::time::Duration::from_secs(5));
890 let mut inner = inner_poll.lock().unwrap();
891 if inner.state != ShimState::Working {
892 continue;
893 }
894 if inner.last_pty_output_at.elapsed().as_secs() < 2 {
896 continue;
897 }
898 let verdict = classifier::classify(inner.agent_type, inner.parser.screen());
899 if verdict == classifier::ScreenVerdict::AgentIdle {
900 let summary = inner.last_n_lines(5);
901 inner.state = ShimState::Idle;
902 inner.state_changed_at = Instant::now();
903 drop(inner);
904
905 let _ = poll_channel.send(&Event::StateChanged {
908 from: ShimState::Working,
909 to: ShimState::Idle,
910 summary,
911 });
912 }
913 }
914 });
915
916 let inner_stats = Arc::clone(&inner);
917 let mut stats_channel = cmd_channel
918 .try_clone()
919 .context("failed to clone channel for stats thread")?;
920 std::thread::spawn(move || {
921 loop {
922 std::thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
923 let inner = inner_stats.lock().unwrap();
924 if inner.state == ShimState::Dead {
925 return;
926 }
927 let output_bytes = inner.cumulative_output_bytes;
928 let uptime_secs = inner.started_at.elapsed().as_secs();
929 drop(inner);
930
931 if stats_channel
932 .send(&Event::SessionStats {
933 output_bytes,
934 uptime_secs,
935 })
936 .is_err()
937 {
938 return;
939 }
940 }
941 });
942
943 let inner_cmd = Arc::clone(&inner);
945
946 let start = Instant::now();
950 loop {
951 let mut inner = inner_cmd.lock().unwrap();
952 let state = inner.state;
953 match state {
954 ShimState::Starting => {
955 if inner.dialogs_dismissed < 10 {
957 let content = inner.screen_contents();
958 if classifier::detect_startup_dialog(&content) {
959 let attempt = inner.dialogs_dismissed + 1;
960 let enter = enter_seq(inner.agent_type);
961 inner.dialogs_dismissed = attempt;
962 drop(inner);
963 eprintln!(
964 "[shim {}] auto-dismissing startup dialog (attempt {attempt})",
965 args.id
966 );
967 let mut writer = pty_writer.lock().unwrap();
968 writer.write_all(enter.as_bytes()).ok();
969 writer.flush().ok();
970 std::thread::sleep(std::time::Duration::from_millis(POLL_INTERVAL_MS));
971 continue;
972 }
973 }
974 drop(inner);
975
976 if start.elapsed().as_secs() > READY_TIMEOUT_SECS {
977 let last = inner_cmd.lock().unwrap().last_n_lines(10);
978 cmd_channel.send(&Event::Error {
979 command: "startup".into(),
980 reason: format!(
981 "agent did not show prompt within {}s. Last lines:\n{}",
982 READY_TIMEOUT_SECS, last,
983 ),
984 })?;
985 terminate_agent_group(&mut child, Duration::from_secs(GROUP_TERM_GRACE_SECS))
986 .ok();
987 return Ok(());
988 }
989 thread::sleep(Duration::from_millis(POLL_INTERVAL_MS));
990 }
991 ShimState::Dead => {
992 drop(inner);
993 return Ok(());
994 }
995 ShimState::Idle => {
996 drop(inner);
997 cmd_channel.send(&Event::Ready)?;
998 break;
999 }
1000 _ => {
1001 drop(inner);
1004 if start.elapsed().as_secs() > READY_TIMEOUT_SECS {
1005 let last = inner_cmd.lock().unwrap().last_n_lines(10);
1006 cmd_channel.send(&Event::Error {
1007 command: "startup".into(),
1008 reason: format!(
1009 "agent did not reach idle within {}s (state: {}). Last lines:\n{}",
1010 READY_TIMEOUT_SECS, state, last,
1011 ),
1012 })?;
1013 terminate_agent_group(&mut child, Duration::from_secs(GROUP_TERM_GRACE_SECS))
1014 .ok();
1015 return Ok(());
1016 }
1017 thread::sleep(Duration::from_millis(POLL_INTERVAL_MS));
1018 }
1019 }
1020 }
1021
1022 loop {
1024 let cmd = match cmd_channel.recv::<Command>() {
1025 Ok(Some(c)) => c,
1026 Ok(None) => {
1027 eprintln!(
1028 "[shim {}] orchestrator disconnected, shutting down",
1029 args.id
1030 );
1031 terminate_agent_group(&mut child, Duration::from_secs(GROUP_TERM_GRACE_SECS)).ok();
1032 break;
1033 }
1034 Err(e) => {
1035 eprintln!("[shim {}] channel error: {e}", args.id);
1036 terminate_agent_group(&mut child, Duration::from_secs(GROUP_TERM_GRACE_SECS)).ok();
1037 break;
1038 }
1039 };
1040
1041 match cmd {
1042 Command::SendMessage {
1043 from,
1044 body,
1045 message_id,
1046 } => {
1047 let mut inner = inner_cmd.lock().unwrap();
1048 match inner.state {
1049 ShimState::Idle => {
1050 inner.pre_injection_content = inner.screen_contents();
1051 inner.pending_message_id = message_id;
1052 let agent_type = inner.agent_type;
1053 let enter = enter_seq(agent_type);
1054 let injected = format_injected_message(&from, &body);
1055 drop(inner);
1056 if let Err(e) = pty_write_paced(
1060 &pty_writer,
1061 agent_type,
1062 injected.as_bytes(),
1063 enter.as_bytes(),
1064 ) {
1065 cmd_channel.send(&Event::Error {
1066 command: "SendMessage".into(),
1067 reason: format!("PTY write failed: {e}"),
1068 })?;
1069 continue;
1071 }
1072 let mut inner = inner_cmd.lock().unwrap();
1073
1074 let old = inner.state;
1075 inner.state = ShimState::Working;
1076 inner.state_changed_at = Instant::now();
1077 let summary = inner.last_n_lines(3);
1078 drop(inner);
1079
1080 cmd_channel.send(&Event::StateChanged {
1081 from: old,
1082 to: ShimState::Working,
1083 summary,
1084 })?;
1085 }
1086 ShimState::Working => {
1087 if inner.message_queue.len() >= MAX_QUEUE_DEPTH {
1089 let dropped = inner.message_queue.pop_front();
1090 let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
1091 inner.message_queue.push_back(QueuedMessage {
1092 from,
1093 body,
1094 message_id,
1095 });
1096 let depth = inner.message_queue.len();
1097 drop(inner);
1098
1099 cmd_channel.send(&Event::Error {
1100 command: "SendMessage".into(),
1101 reason: format!(
1102 "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
1103 dropped_id
1104 .map(|id| format!(" (id: {id})"))
1105 .unwrap_or_default(),
1106 ),
1107 })?;
1108 cmd_channel.send(&Event::Warning {
1109 message: format!(
1110 "message queued while agent working (depth: {depth})"
1111 ),
1112 idle_secs: None,
1113 })?;
1114 } else {
1115 inner.message_queue.push_back(QueuedMessage {
1116 from,
1117 body,
1118 message_id,
1119 });
1120 let depth = inner.message_queue.len();
1121 drop(inner);
1122
1123 cmd_channel.send(&Event::Warning {
1124 message: format!(
1125 "message queued while agent working (depth: {depth})"
1126 ),
1127 idle_secs: None,
1128 })?;
1129 }
1130 }
1131 other => {
1132 cmd_channel.send(&Event::Error {
1133 command: "SendMessage".into(),
1134 reason: format!("agent in {other} state, cannot accept message"),
1135 })?;
1136 }
1137 }
1138 }
1139
1140 Command::CaptureScreen { last_n_lines } => {
1141 let inner = inner_cmd.lock().unwrap();
1142 let content = match last_n_lines {
1143 Some(n) => inner.last_n_lines(n),
1144 None => inner.screen_contents(),
1145 };
1146 let (row, col) = inner.cursor_position();
1147 drop(inner);
1148 cmd_channel.send(&Event::ScreenCapture {
1149 content,
1150 cursor_row: row,
1151 cursor_col: col,
1152 })?;
1153 }
1154
1155 Command::GetState => {
1156 let inner = inner_cmd.lock().unwrap();
1157 let since = inner.state_changed_at.elapsed().as_secs();
1158 let state = inner.state;
1159 drop(inner);
1160 cmd_channel.send(&Event::State {
1161 state,
1162 since_secs: since,
1163 })?;
1164 }
1165
1166 Command::Resize { rows, cols } => {
1167 pty_pair
1168 .master
1169 .resize(PtySize {
1170 rows,
1171 cols,
1172 pixel_width: 0,
1173 pixel_height: 0,
1174 })
1175 .ok();
1176 let mut inner = inner_cmd.lock().unwrap();
1177 inner.parser.set_size(rows, cols);
1178 }
1179
1180 Command::Ping => {
1181 cmd_channel.send(&Event::Pong)?;
1182 }
1183
1184 Command::Shutdown { timeout_secs } => {
1185 eprintln!(
1186 "[shim {}] shutdown requested (timeout: {}s)",
1187 args.id, timeout_secs
1188 );
1189 if let Err(error) = args.preserve_work_before_kill(&args.cwd) {
1190 eprintln!(
1191 "[shim {}] auto-save before shutdown failed: {}",
1192 args.id, error
1193 );
1194 }
1195 {
1196 let mut writer = pty_writer.lock().unwrap();
1197 writer.write_all(b"\x03").ok(); writer.flush().ok();
1199 }
1200 let deadline = Instant::now() + Duration::from_secs(timeout_secs as u64);
1201 loop {
1202 if Instant::now() > deadline {
1203 terminate_agent_group(
1204 &mut child,
1205 Duration::from_secs(GROUP_TERM_GRACE_SECS),
1206 )
1207 .ok();
1208 break;
1209 }
1210 if let Ok(Some(_)) = child.try_wait() {
1211 break;
1212 }
1213 thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS));
1214 }
1215 break;
1216 }
1217
1218 Command::Kill => {
1219 if let Err(error) = args.preserve_work_before_kill(&args.cwd) {
1220 eprintln!("[shim {}] auto-save before kill failed: {}", args.id, error);
1221 }
1222 terminate_agent_group(&mut child, Duration::from_secs(GROUP_TERM_GRACE_SECS)).ok();
1223 break;
1224 }
1225 }
1226 }
1227
1228 pty_handle.join().ok();
1229 Ok(())
1230}
1231
1232fn drain_queue_errors(
1233 queue: &mut VecDeque<QueuedMessage>,
1234 terminal_state: ShimState,
1235) -> Vec<Event> {
1236 common::drain_queue_errors(queue, terminal_state)
1237}
1238
1239fn build_transition_events(
1244 from: ShimState,
1245 to: ShimState,
1246 summary: &str,
1247 pre_injection_content: &str,
1248 current_content: &str,
1249 last_working_screen: &str,
1250 message_id: Option<String>,
1251) -> Vec<Event> {
1252 let summary = sanitize_summary(summary);
1253 let mut events = vec![Event::StateChanged {
1254 from,
1255 to,
1256 summary: summary.clone(),
1257 }];
1258
1259 if from == ShimState::Working && to == ShimState::Idle && !pre_injection_content.is_empty() {
1263 let mut response = extract_response(pre_injection_content, current_content);
1267 if response.is_empty() && !last_working_screen.is_empty() {
1268 response = extract_response(pre_injection_content, last_working_screen);
1269 }
1270 events.push(Event::Completion {
1271 message_id,
1272 response,
1273 last_lines: summary.clone(),
1274 });
1275 }
1276
1277 if to == ShimState::ContextExhausted {
1279 events.push(Event::ContextExhausted {
1280 message: "Agent reported context exhaustion".to_string(),
1281 last_lines: summary,
1282 });
1283 }
1284
1285 events
1286}
1287
1288fn sanitize_summary(summary: &str) -> String {
1289 let cleaned: Vec<String> = summary
1290 .lines()
1291 .filter_map(|line| {
1292 let trimmed = line.trim();
1293 if trimmed.is_empty() || is_tui_chrome(line) || is_prompt_line(trimmed) {
1294 return None;
1295 }
1296 Some(strip_claude_bullets(trimmed))
1297 })
1298 .collect();
1299
1300 if cleaned.is_empty() {
1301 String::new()
1302 } else {
1303 cleaned.join("\n")
1304 }
1305}
1306
1307fn extract_response(pre: &str, current: &str) -> String {
1311 let pre_lines: Vec<&str> = pre.lines().collect();
1312 let cur_lines: Vec<&str> = current.lines().collect();
1313
1314 let overlap = pre_lines.len().min(cur_lines.len());
1315 let mut diverge_at = 0;
1316 for i in 0..overlap {
1317 if pre_lines[i] != cur_lines[i] {
1318 break;
1319 }
1320 diverge_at = i + 1;
1321 }
1322
1323 let response_lines = &cur_lines[diverge_at..];
1324 if response_lines.is_empty() {
1325 return String::new();
1326 }
1327
1328 let filtered: Vec<&str> = response_lines
1330 .iter()
1331 .filter(|line| !is_tui_chrome(line))
1332 .copied()
1333 .collect();
1334
1335 if filtered.is_empty() {
1336 return String::new();
1337 }
1338
1339 let mut end = filtered.len();
1341 while end > 0 && filtered[end - 1].trim().is_empty() {
1342 end -= 1;
1343 }
1344 while end > 0 && is_prompt_line(filtered[end - 1].trim()) {
1345 end -= 1;
1346 }
1347 while end > 0 && filtered[end - 1].trim().is_empty() {
1348 end -= 1;
1349 }
1350
1351 let mut start = 0;
1353 while start < end {
1354 let trimmed = filtered[start].trim();
1355 if trimmed.is_empty() {
1356 start += 1;
1357 } else if trimmed.starts_with('\u{276F}')
1358 && !trimmed['\u{276F}'.len_utf8()..].trim().is_empty()
1359 {
1360 start += 1;
1362 } else {
1363 break;
1364 }
1365 }
1366
1367 let cleaned: Vec<String> = filtered[start..end]
1369 .iter()
1370 .map(|line| strip_claude_bullets(line))
1371 .collect();
1372
1373 cleaned.join("\n")
1374}
1375
/// Removes a leading Claude response bullet (`⏺`, U+23FA) from `line`.
///
/// The line's original leading whitespace is kept verbatim. Whitespace between
/// the bullet and the text is dropped, so `"  ⏺ text"` becomes `"  text"`.
/// Lines without a leading bullet are returned unchanged.
fn strip_claude_bullets(line: &str) -> String {
    let content = line.trim_start();
    // `content` is a suffix of `line`, so this slice is exactly the original
    // indentation. It is kept as-is rather than rebuilt as one space per byte,
    // which over-counted multi-byte whitespace (e.g. U+3000) and turned tabs
    // into spaces.
    let indent = &line[..line.len() - content.len()];
    match content.strip_prefix('\u{23FA}') {
        Some(text) => format!("{}{}", indent, text.trim_start()),
        None => line.to_owned(),
    }
}
1388
/// Heuristic filter for TUI decoration that should never be reported as
/// agent output.
///
/// Blank or whitespace-only lines are *not* chrome. A non-blank line is
/// chrome when it consists solely of box-drawing rule glyphs, or when it is
/// one of the following:
/// - the permission/status bar;
/// - a `$…token…` cost line;
/// - a braille spinner (more than five U+2800–U+28FF glyphs);
/// - a Kiro banner or status line;
/// - a box-frame edge;
/// - a tip or a `⚠` limit warning;
/// - input placeholder text.
fn is_tui_chrome(line: &str) -> bool {
    // Separator glyphs; a line made only of these is a drawn rule.
    const RULE_GLYPHS: &[char] = &[
        '─', '━', '═', '╌', '╍', '┄', '┅', '╶', '╴', '╸', '╺', '│', '┃', '╎', '╏', '┊', '┋',
    ];
    // Case-insensitive phrases from banners and input placeholders.
    const CHROME_PHRASES: &[&str] = &[
        "welcome to the new kiro",
        "/feedback command",
        "ask a question",
        "describe a task",
    ];

    let text = line.trim();
    if text.is_empty() {
        return false;
    }

    // Case-sensitive checks first, so the lowercase copy below is only built
    // for lines that survive them.
    let is_braille = |ch: &char| ('\u{2800}'..='\u{28FF}').contains(ch);
    if text.chars().all(|ch| RULE_GLYPHS.contains(&ch))
        || text.contains("\u{23F5}\u{23F5}")
        || text.contains("bypass permissions")
        || (text.contains("shift+tab") && text.len() < 80)
        || (text.starts_with('$') && text.contains("token"))
        || text.chars().filter(is_braille).count() > 5
        || text.starts_with('╭')
        || text.starts_with('╰')
        || text.starts_with('│')
    {
        return true;
    }

    let lower = text.to_lowercase();
    CHROME_PHRASES.iter().any(|phrase| lower.contains(phrase))
        || (lower.starts_with("kiro") && lower.contains('\u{25D4}'))
        || lower.starts_with("tip:")
        || (text.starts_with('⚠') && lower.contains("limit"))
}
1474
/// Returns `true` when `line` looks like an idle input prompt.
///
/// Recognises the following prompts:
/// - Claude: `❯`
/// - Codex: `›`
/// - Kiro: `>` and `Kiro>`
/// - shell: a line ending in `$` or `%`, optionally followed by one space.
///
/// Callers in this module pass an already-trimmed line.
///
/// NOTE: the shell-suffix rule is purely textual, so ordinary output that
/// happens to end in `$` or `%` (e.g. `"coverage: 100%"`) also matches.
fn is_prompt_line(line: &str) -> bool {
    const EXACT: [&str; 3] = ["\u{276F}", "\u{203A}", ">"];
    const PREFIXES: [&str; 3] = ["\u{276F} ", "\u{203A} ", "Kiro>"];
    const SUFFIXES: [&str; 4] = ["$ ", "$", "% ", "%"];

    EXACT.contains(&line)
        || PREFIXES.iter().any(|prefix| line.starts_with(prefix))
        || SUFFIXES.iter().any(|suffix| line.ends_with(suffix))
}
1487
#[cfg(test)]
mod tests {
    //! Unit tests for the shim's screen-scraping helpers, transition events,
    //! the message queue, handoff preservation and the pre-kill auto-save.
    use super::*;

    // --- extract_response: diff of the post-reply screen vs. pre-injection ---

    #[test]
    fn extract_response_basic() {
        let pre = "line1\nline2\n$ ";
        let cur = "line1\nline2\nhello world\n$ ";
        assert_eq!(extract_response(pre, cur), "hello world");
    }

    #[test]
    fn extract_response_multiline() {
        let pre = "$ ";
        let cur = "$ echo hi\nhi\n$ ";
        let resp = extract_response(pre, cur);
        assert!(resp.contains("echo hi"));
        assert!(resp.contains("hi"));
    }

    #[test]
    fn extract_response_empty() {
        let pre = "$ ";
        let cur = "$ ";
        assert_eq!(extract_response(pre, cur), "");
    }

    #[test]
    fn content_hash_deterministic() {
        assert_eq!(content_hash("hello"), content_hash("hello"));
        assert_ne!(content_hash("hello"), content_hash("world"));
    }

    #[test]
    fn shell_single_quote_escapes_embedded_quote() {
        assert_eq!(shell_single_quote("fix user's bug"), "fix user'\\''s bug");
    }

    // The supervised command wraps the agent in a watchdog; pin the key
    // fragments of the generated shell script.
    #[test]
    fn supervised_command_contains_watchdog_and_exec() {
        let command = build_supervised_agent_command("kiro-cli chat 'hello'", 4242);
        assert!(command.contains("shim_pid=4242"));
        assert!(command.contains("agent_root_pid=$$"));
        assert!(command.contains("agent_pgid=$$"));
        assert!(command.contains("setsid sh -c"));
        assert!(command.contains("shim_pid=\"$1\""));
        assert!(command.contains("agent_pgid=\"$2\""));
        assert!(command.contains("agent_root_pid=\"$3\""));
        assert!(command.contains("collect_descendants()"));
        assert!(command.contains("pgrep -P \"$parent_pid\""));
        assert!(command.contains("descendant_pids=$(collect_descendants \"$agent_root_pid\")"));
        assert!(command.contains("kill -TERM -- -\"$agent_pgid\""));
        assert!(command.contains("kill -TERM \"$descendant_pid\""));
        assert!(command.contains("kill -KILL -- -\"$agent_pgid\""));
        assert!(command.contains("kill -KILL \"$descendant_pid\""));
        assert!(command.contains("' _ \"$shim_pid\" \"$agent_pgid\" \"$agent_root_pid\""));
        assert!(command.contains("exec bash -lc 'kiro-cli chat '\\''hello'\\'''"));
    }

    // --- is_prompt_line: per-agent prompt recognition ---

    #[test]
    fn is_prompt_line_shell_dollar() {
        assert!(is_prompt_line("user@host:~$ "));
        assert!(is_prompt_line("$"));
    }

    #[test]
    fn is_prompt_line_claude() {
        assert!(is_prompt_line("\u{276F}"));
        assert!(is_prompt_line("\u{276F} "));
    }

    #[test]
    fn is_prompt_line_codex() {
        assert!(is_prompt_line("\u{203A}"));
        assert!(is_prompt_line("\u{203A} "));
    }

    #[test]
    fn is_prompt_line_kiro() {
        assert!(is_prompt_line("Kiro>"));
        assert!(is_prompt_line(">"));
    }

    #[test]
    fn is_prompt_line_not_prompt() {
        assert!(!is_prompt_line("hello world"));
        assert!(!is_prompt_line("some output here"));
    }

    // --- build_transition_events: which events each transition emits ---

    #[test]
    fn build_transition_events_working_to_idle() {
        let events = build_transition_events(
            ShimState::Working,
            ShimState::Idle,
            "summary",
            "pre\n$ ",
            "pre\nhello\n$ ",
            "",
            Some("msg-1".into()),
        );
        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], Event::StateChanged { .. }));
        assert!(matches!(&events[1], Event::Completion { .. }));
    }

    #[test]
    fn build_transition_events_to_context_exhausted() {
        let events = build_transition_events(
            ShimState::Working,
            ShimState::ContextExhausted,
            "summary",
            "",
            "",
            "",
            None,
        );
        assert_eq!(events.len(), 2);
        assert!(matches!(&events[1], Event::ContextExhausted { .. }));
    }

    #[test]
    fn build_transition_events_starting_to_idle() {
        let events = build_transition_events(
            ShimState::Starting,
            ShimState::Idle,
            "summary",
            "",
            "",
            "",
            None,
        );
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], Event::StateChanged { .. }));
    }

    // Test helper: a queued message from "user" with the given id and body.
    fn make_queued_msg(id: &str, body: &str) -> QueuedMessage {
        QueuedMessage {
            from: "user".into(),
            body: body.into(),
            message_id: Some(id.into()),
        }
    }

    // --- message queue behaviour ---

    #[test]
    fn queue_enqueue_basic() {
        let mut queue: VecDeque<QueuedMessage> = VecDeque::new();
        queue.push_back(make_queued_msg("m1", "hello"));
        queue.push_back(make_queued_msg("m2", "world"));
        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn queue_fifo_order() {
        let mut queue: VecDeque<QueuedMessage> = VecDeque::new();
        queue.push_back(make_queued_msg("m1", "first"));
        queue.push_back(make_queued_msg("m2", "second"));
        queue.push_back(make_queued_msg("m3", "third"));

        let msg = queue.pop_front().unwrap();
        assert_eq!(msg.message_id.as_deref(), Some("m1"));
        assert_eq!(msg.body, "first");

        let msg = queue.pop_front().unwrap();
        assert_eq!(msg.message_id.as_deref(), Some("m2"));
        assert_eq!(msg.body, "second");

        let msg = queue.pop_front().unwrap();
        assert_eq!(msg.message_id.as_deref(), Some("m3"));
        assert_eq!(msg.body, "third");

        assert!(queue.is_empty());
    }

    // Simulates the overflow policy directly on a VecDeque: once the queue is
    // full, the oldest entry is popped before the new one is pushed.
    #[test]
    fn queue_overflow_drops_oldest() {
        let mut queue: VecDeque<QueuedMessage> = VecDeque::new();

        // Fill the queue to capacity with m0..m{MAX_QUEUE_DEPTH - 1}.
        for i in 0..MAX_QUEUE_DEPTH {
            queue.push_back(make_queued_msg(&format!("m{i}"), &format!("msg {i}")));
        }
        assert_eq!(queue.len(), MAX_QUEUE_DEPTH);

        // Full: drop the oldest (m0) to make room.
        assert!(queue.len() >= MAX_QUEUE_DEPTH);
        let dropped = queue.pop_front().unwrap();
        assert_eq!(dropped.message_id.as_deref(), Some("m0"));
        // Enqueue the new message; depth stays at the cap.
        queue.push_back(make_queued_msg("m_new", "new message"));
        assert_eq!(queue.len(), MAX_QUEUE_DEPTH);

        // The next message to be delivered is now m1.
        let first = queue.pop_front().unwrap();
        assert_eq!(first.message_id.as_deref(), Some("m1"));
    }

    // --- drain_queue_errors: failing queued messages on terminal states ---

    #[test]
    fn drain_queue_errors_empty() {
        let mut queue: VecDeque<QueuedMessage> = VecDeque::new();
        let events = drain_queue_errors(&mut queue, ShimState::Dead);
        assert!(events.is_empty());
    }

    #[test]
    fn drain_queue_errors_with_messages() {
        let mut queue: VecDeque<QueuedMessage> = VecDeque::new();
        queue.push_back(make_queued_msg("m1", "hello"));
        queue.push_back(make_queued_msg("m2", "world"));
        queue.push_back(QueuedMessage {
            from: "user".into(),
            body: "no id".into(),
            message_id: None,
        });

        let events = drain_queue_errors(&mut queue, ShimState::Dead);
        assert_eq!(events.len(), 3);
        assert!(queue.is_empty());

        // One Error event per drained message.
        for event in &events {
            assert!(matches!(event, Event::Error { .. }));
        }

        // The reason names the terminal state and the message id.
        if let Event::Error { reason, .. } = &events[0] {
            assert!(reason.contains("dead"));
            assert!(reason.contains("m1"));
        }

        // A message without an id gets no "(id: ...)" suffix.
        if let Event::Error { reason, .. } = &events[2] {
            assert!(!reason.contains("(id:"));
        }
    }

    #[test]
    fn drain_queue_errors_context_exhausted() {
        let mut queue: VecDeque<QueuedMessage> = VecDeque::new();
        queue.push_back(make_queued_msg("m1", "hello"));

        let events = drain_queue_errors(&mut queue, ShimState::ContextExhausted);
        assert_eq!(events.len(), 1);
        if let Event::Error { reason, .. } = &events[0] {
            assert!(reason.contains("context_exhausted"));
        }
    }

    #[test]
    fn queued_message_preserves_fields() {
        let msg = QueuedMessage {
            from: "manager".into(),
            body: "do this task".into(),
            message_id: Some("msg-42".into()),
        };
        assert_eq!(msg.from, "manager");
        assert_eq!(msg.body, "do this task");
        assert_eq!(msg.message_id.as_deref(), Some("msg-42"));
    }

    #[test]
    fn queued_message_none_id() {
        let msg = QueuedMessage {
            from: "user".into(),
            body: "anonymous".into(),
            message_id: None,
        };
        assert!(msg.message_id.is_none());
    }

    #[test]
    fn max_queue_depth_is_16() {
        assert_eq!(MAX_QUEUE_DEPTH, 16);
    }

    // --- format_injected_message: header and reply instructions ---

    #[test]
    fn format_injected_message_includes_sender_and_reply_target() {
        let formatted = format_injected_message("human", "what is 2+2?");
        assert!(formatted.contains("--- Message from human ---"));
        assert!(formatted.contains("Reply-To: human"));
        assert!(formatted.contains("batty send human"));
        assert!(formatted.ends_with("what is 2+2?"));
    }

    #[test]
    fn format_injected_message_uses_sender_as_reply_target() {
        let formatted = format_injected_message("manager", "status?");
        assert!(formatted.contains("Reply-To: manager"));
        assert!(formatted.contains("batty send manager"));
    }

    // --- sanitize_summary ---

    #[test]
    fn sanitize_summary_strips_tui_chrome_and_prompt_lines() {
        let summary = "────────────────────\n❯ \n ⏵⏵ bypass permissions on\nThe answer is 4\n";
        assert_eq!(sanitize_summary(summary), "The answer is 4");
    }

    #[test]
    fn sanitize_summary_keeps_multiline_meaningful_content() {
        let summary = " Root cause: stale resume id\n\n Fix: retry with fresh start\n";
        assert_eq!(
            sanitize_summary(summary),
            "Root cause: stale resume id\nFix: retry with fresh start"
        );
    }

    // --- is_tui_chrome ---

    #[test]
    fn is_tui_chrome_horizontal_rule() {
        assert!(is_tui_chrome("────────────────────────────────────"));
        assert!(is_tui_chrome(" ───────── "));
        assert!(is_tui_chrome("━━━━━━━━━━━━━━━━━━━━"));
    }

    #[test]
    fn is_tui_chrome_status_bar() {
        assert!(is_tui_chrome(
            " \u{23F5}\u{23F5} bypass permissions on (shift+tab to toggle)"
        ));
        assert!(is_tui_chrome(" bypass permissions on"));
        assert!(is_tui_chrome(" shift+tab"));
    }

    #[test]
    fn is_tui_chrome_cost_line() {
        assert!(is_tui_chrome("$0.01 · 2.3k tokens"));
    }

    #[test]
    fn is_tui_chrome_not_content() {
        assert!(!is_tui_chrome("Hello, world!"));
        assert!(!is_tui_chrome("The answer is 4"));
        // Blank lines are never chrome.
        assert!(!is_tui_chrome(""));
        assert!(!is_tui_chrome(" some output "));
    }

    // --- extract_response against realistic TUI screens ---

    #[test]
    fn extract_response_strips_chrome() {
        let pre = "idle screen\n\u{276F} ";
        let cur = "\u{276F} Hello\n\nThe answer is 42\n\n\
                   ────────────────────\n\
                   \u{23F5}\u{23F5} bypass permissions on\n\
                   \u{276F} ";
        let resp = extract_response(pre, cur);
        assert!(resp.contains("42"), "should contain the answer: {resp}");
        assert!(
            !resp.contains("────"),
            "should strip horizontal rule: {resp}"
        );
        assert!(!resp.contains("bypass"), "should strip status bar: {resp}");
    }

    #[test]
    fn extract_response_strips_echoed_input() {
        let pre = "\u{276F} ";
        let cur = "\u{276F} What is 2+2?\n\n4\n\n\u{276F} ";
        let resp = extract_response(pre, cur);
        assert!(resp.contains('4'), "should contain answer: {resp}");
        assert!(
            !resp.contains("What is 2+2"),
            "should strip echoed input: {resp}"
        );
    }

    #[test]
    fn extract_response_tui_full_rewrite() {
        // The whole screen is redrawn, so there is no common prefix with `pre`.
        let pre = "Welcome to Claude\n\n\u{276F} ";
        let cur = "\u{276F} Hello\n\nHello! How can I help?\n\n\
                   ────────────────────\n\
                   \u{276F} ";
        let resp = extract_response(pre, cur);
        assert!(
            resp.contains("Hello! How can I help?"),
            "should extract response from TUI rewrite: {resp}"
        );
    }

    #[test]
    fn strip_claude_bullets_removes_marker() {
        assert_eq!(strip_claude_bullets("\u{23FA} 4"), "4");
        assert_eq!(
            strip_claude_bullets(" \u{23FA} hello world"),
            " hello world"
        );
        assert_eq!(strip_claude_bullets("no bullet here"), "no bullet here");
        assert_eq!(strip_claude_bullets(""), "");
    }

    #[test]
    fn extract_response_strips_claude_bullets() {
        let pre = "\u{276F} ";
        let cur = "\u{276F} question\n\n\u{23FA} 42\n\n\u{276F} ";
        let resp = extract_response(pre, cur);
        assert!(resp.contains("42"), "should contain answer: {resp}");
        assert!(
            !resp.contains('\u{23FA}'),
            "should strip bullet marker: {resp}"
        );
    }

    // --- handoff preservation (uses a real temporary git repository) ---

    #[test]
    fn preserve_handoff_writes_diff_and_commit_summary() {
        let repo = tempfile::tempdir().unwrap();
        init_test_git_repo(repo.path());

        std::fs::write(repo.path().join("tracked.txt"), "one\n").unwrap();
        run_test_git(repo.path(), &["add", "tracked.txt"]);
        run_test_git(repo.path(), &["commit", "-m", "initial commit"]);
        // Leave an uncommitted change so `git diff --stat` is non-empty.
        std::fs::write(repo.path().join("tracked.txt"), "one\ntwo\n").unwrap();

        let recent_output = "\
running cargo test --lib\n\
test result: ok\n\
editing src/lib.rs\n";
        preserve_handoff(repo.path(), Some(recent_output)).unwrap();

        let handoff = std::fs::read_to_string(repo.path().join(HANDOFF_FILE_NAME)).unwrap();
        assert!(handoff.contains("# Handoff"));
        assert!(handoff.contains("## Modified Files"));
        assert!(handoff.contains("tracked.txt"));
        assert!(handoff.contains("## Tests Run"));
        assert!(handoff.contains("cargo test --lib"));
        assert!(handoff.contains("## Recent Activity"));
        assert!(handoff.contains("editing src/lib.rs"));
        assert!(handoff.contains("## Recent Commits"));
        assert!(handoff.contains("initial commit"));
    }

    #[test]
    fn preserve_handoff_uses_none_when_repo_has_no_changes_or_commits() {
        let repo = tempfile::tempdir().unwrap();
        init_test_git_repo(repo.path());

        preserve_handoff(repo.path(), None).unwrap();

        let handoff = std::fs::read_to_string(repo.path().join(HANDOFF_FILE_NAME)).unwrap();
        assert!(handoff.contains("## Modified Files\n(none)"));
        assert!(handoff.contains("## Tests Run\n(none)"));
        assert!(handoff.contains("## Recent Activity\n(none)"));
        assert!(handoff.contains("## Recent Commits\n(none)"));
    }

    // ANSI-colored and repeated invocations collapse to one entry each,
    // in first-seen order.
    #[test]
    fn extract_test_commands_deduplicates_known_test_invocations() {
        let output = "\
\u{1b}[31mcargo test --lib\u{1b}[0m\n\
pytest tests/test_api.py\n\
cargo test --lib\n\
plain output\n";
        let tests = extract_test_commands(output);
        assert_eq!(
            tests,
            vec![
                "cargo test --lib".to_string(),
                "pytest tests/test_api.py".to_string()
            ]
        );
    }

    // --- pre-kill auto-save ---

    #[test]
    fn preserve_work_before_kill_respects_config_toggle() {
        let tmp = tempfile::tempdir().unwrap();
        let preserved =
            preserve_work_before_kill_with(tmp.path(), Duration::from_millis(10), false, |_path| {
                panic!("commit should not run when disabled")
            })
            .unwrap();

        assert!(!preserved);
    }

    // The commit closure sleeps longer than the 10 ms timeout, so the
    // auto-save must report that nothing was preserved.
    #[test]
    fn preserve_work_before_kill_times_out() {
        let tmp = tempfile::tempdir().unwrap();
        let preserved =
            preserve_work_before_kill_with(tmp.path(), Duration::from_millis(10), true, |_path| {
                std::thread::sleep(Duration::from_millis(50));
                Ok(true)
            })
            .unwrap();

        assert!(!preserved);
    }

    // Test helper: `git init` plus a local identity so commits succeed.
    fn init_test_git_repo(path: &Path) {
        run_test_git(path, &["init"]);
        run_test_git(path, &["config", "user.name", "Batty Tests"]);
        run_test_git(path, &["config", "user.email", "batty-tests@example.com"]);
    }

    // Test helper: runs `git <args>` in `path`, panicking with stderr on failure.
    fn run_test_git(path: &Path, args: &[&str]) {
        use std::process::Command;
        let output = Command::new("git")
            .args(args)
            .current_dir(path)
            .output()
            .unwrap();
        assert!(
            output.status.success(),
            "git {} failed: {}",
            args.join(" "),
            String::from_utf8_lossy(&output.stderr)
        );
    }
}