Skip to main content

batty_cli/shim/
chat.rs

1//! Chat frontend: spawns a shim subprocess, sends user messages, displays responses.
2//!
3//! This is a simple TTY application that demonstrates the shim protocol.
//! Under the hood it spawns this same executable's `shim` subcommand as a
//! subprocess, talks to it over a socketpair (the child's end becomes its fd 3),
//! and presents a simple line-based `you> ` prompt.
6
7use std::io::{self, BufRead, Write};
8use std::os::unix::io::AsRawFd;
9use std::os::unix::process::CommandExt;
10use std::path::Path;
11use std::process::{Command, Stdio};
12use std::time::Instant;
13
14use anyhow::{Context, Result};
15
16use super::classifier::AgentType;
17use super::protocol::{self, Channel, Event};
18
19// ---------------------------------------------------------------------------
20// Default command per agent type
21// ---------------------------------------------------------------------------
22
23/// Returns the default shell command used to launch each agent type.
24pub fn default_cmd(agent_type: AgentType) -> &'static str {
25    match agent_type {
26        AgentType::Claude => "claude --dangerously-skip-permissions",
27        AgentType::Codex => "codex --dangerously-bypass-approvals-and-sandbox",
28        AgentType::Kiro => "kiro-cli",
29        AgentType::Generic => "bash",
30    }
31}
32
33// ---------------------------------------------------------------------------
34// Special command parsing
35// ---------------------------------------------------------------------------
36
37/// Recognized special commands typed at the `you> ` prompt.
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub enum SpecialCommand {
40    Quit,
41    Screen,
42    State,
43    Ping,
44}
45
46/// Try to parse a line of user input as a special command.
47/// Returns `None` if the input is a regular message.
48pub fn parse_special(input: &str) -> Option<SpecialCommand> {
49    match input {
50        ":quit" | ":q" => Some(SpecialCommand::Quit),
51        ":screen" => Some(SpecialCommand::Screen),
52        ":state" => Some(SpecialCommand::State),
53        ":ping" => Some(SpecialCommand::Ping),
54        _ => None,
55    }
56}
57
58// ---------------------------------------------------------------------------
59// Chat entry point
60// ---------------------------------------------------------------------------
61
62pub fn run(agent_type: AgentType, cmd: &str, cwd: &Path, sdk_mode: bool) -> Result<()> {
63    // -- Create socketpair --
64    let (parent_sock, child_sock) = protocol::socketpair().context("socketpair failed")?;
65
66    // -- Find our own binary path (for spawning shim subprocess) --
67    let self_exe = std::env::current_exe().context("cannot determine own executable path")?;
68
69    // -- Spawn shim as child process, passing child_sock as fd 3 --
70    let child_fd = child_sock.as_raw_fd();
71    let child_fd_val = child_fd; // copy the raw fd value
72    let agent_type_str = agent_type.to_string();
73    let cmd_owned = cmd.to_string();
74    let cwd_str = cwd.display().to_string();
75
76    let mut args = vec![
77        "shim".to_string(),
78        "--id".to_string(),
79        "chat-agent".to_string(),
80        "--agent-type".to_string(),
81        agent_type_str.clone(),
82        "--cmd".to_string(),
83        cmd_owned.clone(),
84        "--cwd".to_string(),
85        cwd_str.clone(),
86    ];
87    if sdk_mode {
88        args.push("--sdk-mode".to_string());
89    }
90    let args_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
91
92    let mut child = unsafe {
93        Command::new(&self_exe)
94            .args(&args_refs)
95            .stdin(Stdio::null())
96            .stderr(Stdio::inherit())
97            .pre_exec(move || {
98                // Dup the socketpair fd to fd 3
99                if child_fd_val != 3 {
100                    let ret = libc::dup2(child_fd_val, 3);
101                    if ret < 0 {
102                        return Err(io::Error::last_os_error());
103                    }
104                }
105                Ok(())
106            })
107            .spawn()
108            .context("failed to spawn shim process")?
109    };
110
111    // Close child's end in parent
112    drop(child_sock);
113
114    // -- Set up channel --
115    let mut send_ch = Channel::new(parent_sock);
116    let mut recv_ch = send_ch.try_clone().context("failed to clone channel")?;
117
118    let t_start = Instant::now();
119    eprintln!(
120        "[chat] shim spawned (pid {}), waiting for agent to become ready...",
121        child.id()
122    );
123
124    // -- Wait for Ready event --
125    loop {
126        match recv_ch.recv::<Event>()? {
127            Some(Event::Ready) => {
128                eprintln!(
129                    "[chat] agent is ready ({:.1}s). Type a message and press Enter.",
130                    t_start.elapsed().as_secs_f64()
131                );
132                eprintln!(
133                    "[chat] Type :quit to exit, :screen to capture screen, :state to query state.\n"
134                );
135                break;
136            }
137            Some(Event::StateChanged { from, to, .. }) => {
138                eprintln!(
139                    "[chat] state: {} \u{2192} {} (+{:.1}s)",
140                    from,
141                    to,
142                    t_start.elapsed().as_secs_f64()
143                );
144            }
145            Some(Event::Error { reason, .. }) => {
146                eprintln!("[chat] error during startup: {reason}");
147                child.kill().ok();
148                return Ok(());
149            }
150            Some(Event::Died {
151                exit_code,
152                last_lines,
153            }) => {
154                eprintln!(
155                    "[chat] agent died before becoming ready (exit={:?})\n{}",
156                    exit_code, last_lines
157                );
158                return Ok(());
159            }
160            Some(other) => {
161                eprintln!("[chat] unexpected event during startup: {:?}", other);
162            }
163            None => {
164                eprintln!("[chat] shim disconnected before ready");
165                return Ok(());
166            }
167        }
168    }
169
170    // -- Main chat loop --
171    let (event_tx, event_rx) = std::sync::mpsc::channel::<Event>();
172
173    // Background thread: read events from shim
174    let recv_handle = std::thread::spawn(move || {
175        loop {
176            match recv_ch.recv::<Event>() {
177                Ok(Some(evt)) => {
178                    if event_tx.send(evt).is_err() {
179                        break; // main thread dropped receiver
180                    }
181                }
182                Ok(None) => break, // shim closed
183                Err(_) => break,
184            }
185        }
186    });
187
188    let stdin = io::stdin();
189    let mut stdout = io::stdout();
190
191    loop {
192        print!("you> ");
193        stdout.flush()?;
194
195        let mut line = String::new();
196        let n = stdin.lock().read_line(&mut line)?;
197        if n == 0 {
198            eprintln!("\n[chat] EOF, shutting down...");
199            send_ch.send(&protocol::Command::Shutdown {
200                timeout_secs: 5,
201                reason: protocol::ShutdownReason::Requested,
202            })?;
203            break;
204        }
205
206        let input = line.trim();
207        if input.is_empty() {
208            continue;
209        }
210
211        // -- Special commands --
212        match parse_special(input) {
213            Some(SpecialCommand::Quit) => {
214                send_ch.send(&protocol::Command::Shutdown {
215                    timeout_secs: 5,
216                    reason: protocol::ShutdownReason::Requested,
217                })?;
218                break;
219            }
220            Some(SpecialCommand::Screen) => {
221                send_ch.send(&protocol::Command::CaptureScreen {
222                    last_n_lines: Some(30),
223                })?;
224                if let Ok(Event::ScreenCapture {
225                    content,
226                    cursor_row,
227                    cursor_col,
228                }) = event_rx.recv()
229                {
230                    println!(
231                        "--- screen capture (cursor at {},{}) ---",
232                        cursor_row, cursor_col
233                    );
234                    println!("{content}");
235                    println!("--- end screen capture ---");
236                }
237                continue;
238            }
239            Some(SpecialCommand::State) => {
240                send_ch.send(&protocol::Command::GetState)?;
241                if let Ok(Event::State { state, since_secs }) = event_rx.recv() {
242                    println!("[state: {state}, since: {since_secs}s ago]");
243                }
244                continue;
245            }
246            Some(SpecialCommand::Ping) => {
247                send_ch.send(&protocol::Command::Ping)?;
248                if let Ok(Event::Pong) = event_rx.recv() {
249                    println!("[pong]");
250                }
251                continue;
252            }
253            None => {}
254        }
255
256        // -- Send message to agent --
257        let t_send = Instant::now();
258        send_ch.send(&protocol::Command::SendMessage {
259            from: "user".into(),
260            body: input.to_string(),
261            message_id: None,
262        })?;
263
264        // Wait for completion (or other terminal events)
265        let mut got_completion = false;
266        while !got_completion {
267            match event_rx.recv() {
268                Ok(Event::Completion {
269                    response,
270                    last_lines,
271                    ..
272                }) => {
273                    let elapsed = t_send.elapsed();
274                    let text = if !response.is_empty() {
275                        response.trim_end().to_string()
276                    } else if !last_lines.is_empty() {
277                        last_lines.trim_end().to_string()
278                    } else {
279                        String::new()
280                    };
281                    if text.is_empty() {
282                        eprintln!(
283                            "\n[completed in {:.1}s with no visible output]",
284                            elapsed.as_secs_f64()
285                        );
286                    } else {
287                        println!("\n{text}");
288                        eprintln!("[{:.1}s]", elapsed.as_secs_f64());
289                    }
290                    got_completion = true;
291                }
292                Ok(Event::StateChanged { from, to, .. }) => {
293                    let elapsed = t_send.elapsed();
294                    eprint!("[{from} \u{2192} {to} +{:.1}s] ", elapsed.as_secs_f64());
295                    io::stderr().flush().ok();
296                }
297                Ok(Event::Died {
298                    exit_code,
299                    last_lines,
300                }) => {
301                    eprintln!("\n[chat] agent died (exit={exit_code:?})");
302                    if !last_lines.is_empty() {
303                        println!("{last_lines}");
304                    }
305                    return Ok(());
306                }
307                Ok(Event::ContextExhausted { message, .. }) => {
308                    eprintln!("\n[chat] context exhausted: {message}");
309                    return Ok(());
310                }
311                Ok(Event::Error { command, reason }) => {
312                    eprintln!("\n[chat] error ({command}): {reason}");
313                    got_completion = true; // don't hang
314                }
315                Ok(other) => {
316                    eprintln!("[chat] event: {other:?}");
317                }
318                Err(_) => {
319                    eprintln!("\n[chat] channel closed");
320                    return Ok(());
321                }
322            }
323        }
324    }
325
326    // Cleanup
327    child.wait().ok();
328    recv_handle.join().ok();
329    eprintln!("[chat] done.");
330    Ok(())
331}
332
333// ---------------------------------------------------------------------------
334// Tests
335// ---------------------------------------------------------------------------
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the default launch command chosen for one agent type.
    fn assert_default(agent_type: AgentType, expected: &str) {
        assert_eq!(default_cmd(agent_type), expected);
    }

    #[test]
    fn default_cmd_claude() {
        assert_default(AgentType::Claude, "claude --dangerously-skip-permissions");
    }

    #[test]
    fn default_cmd_codex() {
        assert_default(
            AgentType::Codex,
            "codex --dangerously-bypass-approvals-and-sandbox",
        );
    }

    #[test]
    fn default_cmd_kiro() {
        assert_default(AgentType::Kiro, "kiro-cli");
    }

    #[test]
    fn default_cmd_generic() {
        assert_default(AgentType::Generic, "bash");
    }

    #[test]
    fn parse_special_quit() {
        // Both the long and the short spelling select Quit.
        for spelling in &[":quit", ":q"] {
            assert_eq!(
                parse_special(spelling),
                Some(SpecialCommand::Quit),
                "{spelling}"
            );
        }
    }

    #[test]
    fn parse_special_screen() {
        let parsed = parse_special(":screen");
        assert_eq!(parsed, Some(SpecialCommand::Screen));
    }

    #[test]
    fn parse_special_state() {
        let parsed = parse_special(":state");
        assert_eq!(parsed, Some(SpecialCommand::State));
    }

    #[test]
    fn parse_special_ping() {
        let parsed = parse_special(":ping");
        assert_eq!(parsed, Some(SpecialCommand::Ping));
    }

    #[test]
    fn parse_special_none() {
        // Plain text, empty input and unknown colon-words are all ordinary messages.
        for input in &["hello world", "", ":unknown"] {
            assert!(parse_special(input).is_none(), "{input:?}");
        }
    }
}