Skip to main content

batty_cli/shim/
chat.rs

1//! Chat frontend: spawns a shim subprocess, sends user messages, displays responses.
2//!
3//! This is a simple TTY application that demonstrates the shim protocol.
4//! Under the hood it forks a shim subprocess, communicates via socketpair,
5//! and presents a readline-style prompt.
6
7use std::io::{self, BufRead, Write};
8use std::os::unix::io::AsRawFd;
9use std::os::unix::process::CommandExt;
10use std::path::Path;
11use std::process::{Command, Stdio};
12
13use anyhow::{Context, Result};
14
15use super::classifier::AgentType;
16use super::protocol::{self, Channel, Event};
17
18// ---------------------------------------------------------------------------
19// Default command per agent type
20// ---------------------------------------------------------------------------
21
22/// Returns the default shell command used to launch each agent type.
23pub fn default_cmd(agent_type: AgentType) -> &'static str {
24    match agent_type {
25        AgentType::Claude => "claude --dangerously-skip-permissions",
26        AgentType::Codex => "codex",
27        AgentType::Kiro => "kiro",
28        AgentType::Generic => "bash",
29    }
30}
31
32// ---------------------------------------------------------------------------
33// Special command parsing
34// ---------------------------------------------------------------------------
35
/// Colon-prefixed commands understood at the `you> ` prompt.
///
/// Anything that does not parse into one of these variants is treated as a
/// regular chat message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SpecialCommand {
    /// `:quit` / `:q` — end the chat session.
    Quit,
    /// `:screen` — request a capture of the agent's screen.
    Screen,
    /// `:state` — query the agent's current state.
    State,
    /// `:ping` — round-trip liveness check.
    Ping,
}
44
45/// Try to parse a line of user input as a special command.
46/// Returns `None` if the input is a regular message.
47pub fn parse_special(input: &str) -> Option<SpecialCommand> {
48    match input {
49        ":quit" | ":q" => Some(SpecialCommand::Quit),
50        ":screen" => Some(SpecialCommand::Screen),
51        ":state" => Some(SpecialCommand::State),
52        ":ping" => Some(SpecialCommand::Ping),
53        _ => None,
54    }
55}
56
57// ---------------------------------------------------------------------------
58// Chat entry point
59// ---------------------------------------------------------------------------
60
61pub fn run(agent_type: AgentType, cmd: &str, cwd: &Path) -> Result<()> {
62    // -- Create socketpair --
63    let (parent_sock, child_sock) = protocol::socketpair().context("socketpair failed")?;
64
65    // -- Find our own binary path (for spawning shim subprocess) --
66    let self_exe = std::env::current_exe().context("cannot determine own executable path")?;
67
68    // -- Spawn shim as child process, passing child_sock as fd 3 --
69    let child_fd = child_sock.as_raw_fd();
70    let child_fd_val = child_fd; // copy the raw fd value
71    let agent_type_str = agent_type.to_string();
72    let cmd_owned = cmd.to_string();
73    let cwd_str = cwd.display().to_string();
74
75    let mut child = unsafe {
76        Command::new(&self_exe)
77            .args([
78                "shim",
79                "--id",
80                "chat-agent",
81                "--agent-type",
82                &agent_type_str,
83                "--cmd",
84                &cmd_owned,
85                "--cwd",
86                &cwd_str,
87            ])
88            .stdin(Stdio::null())
89            .stderr(Stdio::inherit())
90            .pre_exec(move || {
91                // Dup the socketpair fd to fd 3
92                if child_fd_val != 3 {
93                    let ret = libc::dup2(child_fd_val, 3);
94                    if ret < 0 {
95                        return Err(io::Error::last_os_error());
96                    }
97                }
98                Ok(())
99            })
100            .spawn()
101            .context("failed to spawn shim process")?
102    };
103
104    // Close child's end in parent
105    drop(child_sock);
106
107    // -- Set up channel --
108    let mut send_ch = Channel::new(parent_sock);
109    let mut recv_ch = send_ch.try_clone().context("failed to clone channel")?;
110
111    eprintln!(
112        "[chat] shim spawned (pid {}), waiting for agent to become ready...",
113        child.id()
114    );
115
116    // -- Wait for Ready event --
117    loop {
118        match recv_ch.recv::<Event>()? {
119            Some(Event::Ready) => {
120                eprintln!("[chat] agent is ready. Type a message and press Enter.");
121                eprintln!(
122                    "[chat] Type :quit to exit, :screen to capture screen, :state to query state.\n"
123                );
124                break;
125            }
126            Some(Event::StateChanged { from, to, .. }) => {
127                eprintln!("[chat] state: {} \u{2192} {}", from, to);
128            }
129            Some(Event::Error { reason, .. }) => {
130                eprintln!("[chat] error during startup: {reason}");
131                child.kill().ok();
132                return Ok(());
133            }
134            Some(Event::Died {
135                exit_code,
136                last_lines,
137            }) => {
138                eprintln!(
139                    "[chat] agent died before becoming ready (exit={:?})\n{}",
140                    exit_code, last_lines
141                );
142                return Ok(());
143            }
144            Some(other) => {
145                eprintln!("[chat] unexpected event during startup: {:?}", other);
146            }
147            None => {
148                eprintln!("[chat] shim disconnected before ready");
149                return Ok(());
150            }
151        }
152    }
153
154    // -- Main chat loop --
155    let (event_tx, event_rx) = std::sync::mpsc::channel::<Event>();
156
157    // Background thread: read events from shim
158    let recv_handle = std::thread::spawn(move || {
159        loop {
160            match recv_ch.recv::<Event>() {
161                Ok(Some(evt)) => {
162                    if event_tx.send(evt).is_err() {
163                        break; // main thread dropped receiver
164                    }
165                }
166                Ok(None) => break, // shim closed
167                Err(_) => break,
168            }
169        }
170    });
171
172    let stdin = io::stdin();
173    let mut stdout = io::stdout();
174
175    loop {
176        print!("you> ");
177        stdout.flush()?;
178
179        let mut line = String::new();
180        let n = stdin.lock().read_line(&mut line)?;
181        if n == 0 {
182            eprintln!("\n[chat] EOF, shutting down...");
183            send_ch.send(&protocol::Command::Shutdown { timeout_secs: 5 })?;
184            break;
185        }
186
187        let input = line.trim();
188        if input.is_empty() {
189            continue;
190        }
191
192        // -- Special commands --
193        match parse_special(input) {
194            Some(SpecialCommand::Quit) => {
195                send_ch.send(&protocol::Command::Shutdown { timeout_secs: 5 })?;
196                break;
197            }
198            Some(SpecialCommand::Screen) => {
199                send_ch.send(&protocol::Command::CaptureScreen {
200                    last_n_lines: Some(30),
201                })?;
202                if let Ok(Event::ScreenCapture {
203                    content,
204                    cursor_row,
205                    cursor_col,
206                }) = event_rx.recv()
207                {
208                    println!(
209                        "--- screen capture (cursor at {},{}) ---",
210                        cursor_row, cursor_col
211                    );
212                    println!("{content}");
213                    println!("--- end screen capture ---");
214                }
215                continue;
216            }
217            Some(SpecialCommand::State) => {
218                send_ch.send(&protocol::Command::GetState)?;
219                if let Ok(Event::State { state, since_secs }) = event_rx.recv() {
220                    println!("[state: {state}, since: {since_secs}s ago]");
221                }
222                continue;
223            }
224            Some(SpecialCommand::Ping) => {
225                send_ch.send(&protocol::Command::Ping)?;
226                if let Ok(Event::Pong) = event_rx.recv() {
227                    println!("[pong]");
228                }
229                continue;
230            }
231            None => {}
232        }
233
234        // -- Send message to agent --
235        send_ch.send(&protocol::Command::SendMessage {
236            from: "user".into(),
237            body: input.to_string(),
238            message_id: None,
239        })?;
240
241        // Wait for completion (or other terminal events)
242        let mut got_completion = false;
243        while !got_completion {
244            match event_rx.recv() {
245                Ok(Event::Completion {
246                    response,
247                    last_lines,
248                    ..
249                }) => {
250                    if !response.is_empty() {
251                        println!("\n{response}");
252                    } else if !last_lines.is_empty() {
253                        println!("\n{last_lines}");
254                    } else {
255                        println!("\n[agent completed with no visible output]");
256                    }
257                    got_completion = true;
258                }
259                Ok(Event::StateChanged { from, to, .. }) => {
260                    eprint!("[{from} \u{2192} {to}] ");
261                    io::stderr().flush().ok();
262                }
263                Ok(Event::Died {
264                    exit_code,
265                    last_lines,
266                }) => {
267                    eprintln!("\n[chat] agent died (exit={exit_code:?})");
268                    if !last_lines.is_empty() {
269                        println!("{last_lines}");
270                    }
271                    return Ok(());
272                }
273                Ok(Event::ContextExhausted { message, .. }) => {
274                    eprintln!("\n[chat] context exhausted: {message}");
275                    return Ok(());
276                }
277                Ok(Event::Error { command, reason }) => {
278                    eprintln!("\n[chat] error ({command}): {reason}");
279                    got_completion = true; // don't hang
280                }
281                Ok(other) => {
282                    eprintln!("[chat] event: {other:?}");
283                }
284                Err(_) => {
285                    eprintln!("\n[chat] channel closed");
286                    return Ok(());
287                }
288            }
289        }
290    }
291
292    // Cleanup
293    child.wait().ok();
294    recv_handle.join().ok();
295    eprintln!("[chat] done.");
296    Ok(())
297}
298
299// ---------------------------------------------------------------------------
300// Tests
301// ---------------------------------------------------------------------------
302
#[cfg(test)]
mod tests {
    use super::*;

    // -- default_cmd: one test per agent type --

    #[test]
    fn default_cmd_claude() {
        let expected = "claude --dangerously-skip-permissions";
        assert_eq!(default_cmd(AgentType::Claude), expected);
    }

    #[test]
    fn default_cmd_codex() {
        let expected = "codex";
        assert_eq!(default_cmd(AgentType::Codex), expected);
    }

    #[test]
    fn default_cmd_kiro() {
        let expected = "kiro";
        assert_eq!(default_cmd(AgentType::Kiro), expected);
    }

    #[test]
    fn default_cmd_generic() {
        let expected = "bash";
        assert_eq!(default_cmd(AgentType::Generic), expected);
    }

    // -- parse_special: every command spelling, then the fall-through cases --

    #[test]
    fn parse_special_quit() {
        // Both the long form and the short alias map to Quit.
        for alias in [":quit", ":q"] {
            assert_eq!(parse_special(alias), Some(SpecialCommand::Quit));
        }
    }

    #[test]
    fn parse_special_screen() {
        let parsed = parse_special(":screen");
        assert_eq!(parsed, Some(SpecialCommand::Screen));
    }

    #[test]
    fn parse_special_state() {
        let parsed = parse_special(":state");
        assert_eq!(parsed, Some(SpecialCommand::State));
    }

    #[test]
    fn parse_special_ping() {
        let parsed = parse_special(":ping");
        assert_eq!(parsed, Some(SpecialCommand::Ping));
    }

    #[test]
    fn parse_special_none() {
        // Plain text, empty input, and unknown colon-words are all regular messages.
        for input in ["hello world", "", ":unknown"] {
            assert_eq!(parse_special(input), None);
        }
    }
}