Skip to main content

batty_cli/shim/
chat.rs

1//! Chat frontend: spawns a shim subprocess, sends user messages, displays responses.
2//!
3//! This is a simple TTY application that demonstrates the shim protocol.
4//! Under the hood it forks a shim subprocess, communicates via socketpair,
5//! and presents a readline-style prompt.
6
7use std::io::{self, BufRead, Write};
8use std::os::unix::io::AsRawFd;
9use std::os::unix::process::CommandExt;
10use std::path::Path;
11use std::process::{Command, Stdio};
12use std::time::Instant;
13
14use anyhow::{Context, Result};
15
16use super::classifier::AgentType;
17use super::protocol::{self, Channel, Event};
18
19// ---------------------------------------------------------------------------
20// Default command per agent type
21// ---------------------------------------------------------------------------
22
23/// Returns the default shell command used to launch each agent type.
24pub fn default_cmd(agent_type: AgentType) -> &'static str {
25    match agent_type {
26        AgentType::Claude => "claude --dangerously-skip-permissions",
27        AgentType::Codex => "codex --dangerously-bypass-approvals-and-sandbox",
28        AgentType::Kiro => "kiro-cli",
29        AgentType::Generic => "bash",
30    }
31}
32
33// ---------------------------------------------------------------------------
34// Special command parsing
35// ---------------------------------------------------------------------------
36
/// Recognized special commands typed at the `you> ` prompt.
///
/// These are handled by the chat frontend itself instead of being forwarded
/// to the agent as a message. The enum is fieldless, so it is cheap to copy
/// and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialCommand {
    /// `:quit` / `:q` — ask the shim to shut down and leave the chat loop.
    Quit,
    /// `:screen` — request a capture of the agent's screen and print it.
    Screen,
    /// `:state` — query the agent's current state and print it.
    State,
    /// `:ping` — send a ping to the shim and report the pong.
    Ping,
}
45
46/// Try to parse a line of user input as a special command.
47/// Returns `None` if the input is a regular message.
48pub fn parse_special(input: &str) -> Option<SpecialCommand> {
49    match input {
50        ":quit" | ":q" => Some(SpecialCommand::Quit),
51        ":screen" => Some(SpecialCommand::Screen),
52        ":state" => Some(SpecialCommand::State),
53        ":ping" => Some(SpecialCommand::Ping),
54        _ => None,
55    }
56}
57
58// ---------------------------------------------------------------------------
59// Chat entry point
60// ---------------------------------------------------------------------------
61
62pub fn run(agent_type: AgentType, cmd: &str, cwd: &Path, sdk_mode: bool) -> Result<()> {
63    // -- Create socketpair --
64    let (parent_sock, child_sock) = protocol::socketpair().context("socketpair failed")?;
65
66    // -- Find our own binary path (for spawning shim subprocess) --
67    let self_exe = std::env::current_exe().context("cannot determine own executable path")?;
68
69    // -- Spawn shim as child process, passing child_sock as fd 3 --
70    let child_fd = child_sock.as_raw_fd();
71    let child_fd_val = child_fd; // copy the raw fd value
72    let agent_type_str = agent_type.to_string();
73    let cmd_owned = cmd.to_string();
74    let cwd_str = cwd.display().to_string();
75
76    let mut args = vec![
77        "shim".to_string(),
78        "--id".to_string(),
79        "chat-agent".to_string(),
80        "--agent-type".to_string(),
81        agent_type_str.clone(),
82        "--cmd".to_string(),
83        cmd_owned.clone(),
84        "--cwd".to_string(),
85        cwd_str.clone(),
86    ];
87    if sdk_mode {
88        args.push("--sdk-mode".to_string());
89    }
90    let args_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
91
92    let mut child = unsafe {
93        Command::new(&self_exe)
94            .args(&args_refs)
95            .stdin(Stdio::null())
96            .stderr(Stdio::inherit())
97            .pre_exec(move || {
98                // Dup the socketpair fd to fd 3
99                if child_fd_val != 3 {
100                    let ret = libc::dup2(child_fd_val, 3);
101                    if ret < 0 {
102                        return Err(io::Error::last_os_error());
103                    }
104                }
105                Ok(())
106            })
107            .spawn()
108            .context("failed to spawn shim process")?
109    };
110
111    // Close child's end in parent
112    drop(child_sock);
113
114    // -- Set up channel --
115    let mut send_ch = Channel::new(parent_sock);
116    let mut recv_ch = send_ch.try_clone().context("failed to clone channel")?;
117
118    let t_start = Instant::now();
119    eprintln!(
120        "[chat] shim spawned (pid {}), waiting for agent to become ready...",
121        child.id()
122    );
123
124    // -- Wait for Ready event --
125    loop {
126        match recv_ch.recv::<Event>()? {
127            Some(Event::Ready) => {
128                eprintln!(
129                    "[chat] agent is ready ({:.1}s). Type a message and press Enter.",
130                    t_start.elapsed().as_secs_f64()
131                );
132                eprintln!(
133                    "[chat] Type :quit to exit, :screen to capture screen, :state to query state.\n"
134                );
135                break;
136            }
137            Some(Event::StateChanged { from, to, .. }) => {
138                eprintln!(
139                    "[chat] state: {} \u{2192} {} (+{:.1}s)",
140                    from,
141                    to,
142                    t_start.elapsed().as_secs_f64()
143                );
144            }
145            Some(Event::Error { reason, .. }) => {
146                eprintln!("[chat] error during startup: {reason}");
147                child.kill().ok();
148                return Ok(());
149            }
150            Some(Event::Died {
151                exit_code,
152                last_lines,
153            }) => {
154                eprintln!(
155                    "[chat] agent died before becoming ready (exit={:?})\n{}",
156                    exit_code, last_lines
157                );
158                return Ok(());
159            }
160            Some(other) => {
161                eprintln!("[chat] unexpected event during startup: {:?}", other);
162            }
163            None => {
164                eprintln!("[chat] shim disconnected before ready");
165                return Ok(());
166            }
167        }
168    }
169
170    // -- Main chat loop --
171    let (event_tx, event_rx) = std::sync::mpsc::channel::<Event>();
172
173    // Background thread: read events from shim
174    let recv_handle = std::thread::spawn(move || {
175        loop {
176            match recv_ch.recv::<Event>() {
177                Ok(Some(evt)) => {
178                    if event_tx.send(evt).is_err() {
179                        break; // main thread dropped receiver
180                    }
181                }
182                Ok(None) => break, // shim closed
183                Err(_) => break,
184            }
185        }
186    });
187
188    let stdin = io::stdin();
189    let mut stdout = io::stdout();
190
191    loop {
192        print!("you> ");
193        stdout.flush()?;
194
195        let mut line = String::new();
196        let n = stdin.lock().read_line(&mut line)?;
197        if n == 0 {
198            eprintln!("\n[chat] EOF, shutting down...");
199            send_ch.send(&protocol::Command::Shutdown { timeout_secs: 5 })?;
200            break;
201        }
202
203        let input = line.trim();
204        if input.is_empty() {
205            continue;
206        }
207
208        // -- Special commands --
209        match parse_special(input) {
210            Some(SpecialCommand::Quit) => {
211                send_ch.send(&protocol::Command::Shutdown { timeout_secs: 5 })?;
212                break;
213            }
214            Some(SpecialCommand::Screen) => {
215                send_ch.send(&protocol::Command::CaptureScreen {
216                    last_n_lines: Some(30),
217                })?;
218                if let Ok(Event::ScreenCapture {
219                    content,
220                    cursor_row,
221                    cursor_col,
222                }) = event_rx.recv()
223                {
224                    println!(
225                        "--- screen capture (cursor at {},{}) ---",
226                        cursor_row, cursor_col
227                    );
228                    println!("{content}");
229                    println!("--- end screen capture ---");
230                }
231                continue;
232            }
233            Some(SpecialCommand::State) => {
234                send_ch.send(&protocol::Command::GetState)?;
235                if let Ok(Event::State { state, since_secs }) = event_rx.recv() {
236                    println!("[state: {state}, since: {since_secs}s ago]");
237                }
238                continue;
239            }
240            Some(SpecialCommand::Ping) => {
241                send_ch.send(&protocol::Command::Ping)?;
242                if let Ok(Event::Pong) = event_rx.recv() {
243                    println!("[pong]");
244                }
245                continue;
246            }
247            None => {}
248        }
249
250        // -- Send message to agent --
251        let t_send = Instant::now();
252        send_ch.send(&protocol::Command::SendMessage {
253            from: "user".into(),
254            body: input.to_string(),
255            message_id: None,
256        })?;
257
258        // Wait for completion (or other terminal events)
259        let mut got_completion = false;
260        while !got_completion {
261            match event_rx.recv() {
262                Ok(Event::Completion {
263                    response,
264                    last_lines,
265                    ..
266                }) => {
267                    let elapsed = t_send.elapsed();
268                    let text = if !response.is_empty() {
269                        response.trim_end().to_string()
270                    } else if !last_lines.is_empty() {
271                        last_lines.trim_end().to_string()
272                    } else {
273                        String::new()
274                    };
275                    if text.is_empty() {
276                        eprintln!(
277                            "\n[completed in {:.1}s with no visible output]",
278                            elapsed.as_secs_f64()
279                        );
280                    } else {
281                        println!("\n{text}");
282                        eprintln!("[{:.1}s]", elapsed.as_secs_f64());
283                    }
284                    got_completion = true;
285                }
286                Ok(Event::StateChanged { from, to, .. }) => {
287                    let elapsed = t_send.elapsed();
288                    eprint!("[{from} \u{2192} {to} +{:.1}s] ", elapsed.as_secs_f64());
289                    io::stderr().flush().ok();
290                }
291                Ok(Event::Died {
292                    exit_code,
293                    last_lines,
294                }) => {
295                    eprintln!("\n[chat] agent died (exit={exit_code:?})");
296                    if !last_lines.is_empty() {
297                        println!("{last_lines}");
298                    }
299                    return Ok(());
300                }
301                Ok(Event::ContextExhausted { message, .. }) => {
302                    eprintln!("\n[chat] context exhausted: {message}");
303                    return Ok(());
304                }
305                Ok(Event::Error { command, reason }) => {
306                    eprintln!("\n[chat] error ({command}): {reason}");
307                    got_completion = true; // don't hang
308                }
309                Ok(other) => {
310                    eprintln!("[chat] event: {other:?}");
311                }
312                Err(_) => {
313                    eprintln!("\n[chat] channel closed");
314                    return Ok(());
315                }
316            }
317        }
318    }
319
320    // Cleanup
321    child.wait().ok();
322    recv_handle.join().ok();
323    eprintln!("[chat] done.");
324    Ok(())
325}
326
327// ---------------------------------------------------------------------------
328// Tests
329// ---------------------------------------------------------------------------
330
#[cfg(test)]
mod tests {
    use super::*;

    // -- default_cmd: one test per agent type --

    #[test]
    fn default_cmd_claude() {
        let launch_line = default_cmd(AgentType::Claude);
        assert_eq!(launch_line, "claude --dangerously-skip-permissions");
    }

    #[test]
    fn default_cmd_codex() {
        let launch_line = default_cmd(AgentType::Codex);
        assert_eq!(
            launch_line,
            "codex --dangerously-bypass-approvals-and-sandbox"
        );
    }

    #[test]
    fn default_cmd_kiro() {
        let launch_line = default_cmd(AgentType::Kiro);
        assert_eq!(launch_line, "kiro-cli");
    }

    #[test]
    fn default_cmd_generic() {
        let launch_line = default_cmd(AgentType::Generic);
        assert_eq!(launch_line, "bash");
    }

    // -- parse_special: recognized commands --

    #[test]
    fn parse_special_quit() {
        // Both the long and the short spelling map to Quit.
        for alias in [":quit", ":q"] {
            assert_eq!(parse_special(alias), Some(SpecialCommand::Quit));
        }
    }

    #[test]
    fn parse_special_screen() {
        let parsed = parse_special(":screen");
        assert_eq!(parsed, Some(SpecialCommand::Screen));
    }

    #[test]
    fn parse_special_state() {
        let parsed = parse_special(":state");
        assert_eq!(parsed, Some(SpecialCommand::State));
    }

    #[test]
    fn parse_special_ping() {
        let parsed = parse_special(":ping");
        assert_eq!(parsed, Some(SpecialCommand::Ping));
    }

    // -- parse_special: everything else is a regular message --

    #[test]
    fn parse_special_none() {
        for regular in ["hello world", "", ":unknown"] {
            assert_eq!(parse_special(regular), None);
        }
    }
}