Skip to main content

aft/commands/
bash_write.rs

1use crate::context::AppContext;
2use crate::protocol::{RawRequest, Response};
3use serde::Deserialize;
4use serde_json::json;
5
/// Largest payload (1 MiB) that a single `bash_write` call may send to a PTY,
/// measured on the bytes produced after named keys have been expanded.
const MAX_INPUT_BYTES: usize = 1 << 20;
7
8/// Input payload accepted by `bash_write`.
9///
10/// Two forms:
11/// * `String` — verbatim bytes written to the PTY. Backward-compatible with the
12///   v0.30 phase 1b shape; existing callers see no change.
13/// * `Sequence` — array of items, each either a plain string (text bytes) or a
14///   `{ "key": "<name>" }` object that expands to a known control-byte sequence
15///   (ESC, arrows, Ctrl chords, function keys, …). Items are concatenated into
16///   one atomic write so the PTY sees the whole sequence as one input chunk.
17///
18/// The agent never has to encode escape characters inside the `input` string —
19/// they use named keys instead. The string form remains the right choice when
20/// the agent wants to write literal `\u001b` characters (e.g. source code).
21#[derive(Debug, Deserialize)]
22#[serde(untagged)]
23pub enum BashWriteInput {
24    Text(String),
25    Sequence(Vec<SequenceItem>),
26}
27
28#[derive(Debug, Deserialize)]
29#[serde(untagged)]
30pub enum SequenceItem {
31    Text(String),
32    Key { key: String },
33}
34
35#[derive(Debug, Deserialize)]
36pub struct BashWriteParams {
37    pub task_id: String,
38    pub input: BashWriteInput,
39}
40
41pub fn handle(req: &RawRequest, ctx: &AppContext) -> Response {
42    let raw_params = req
43        .params
44        .get("params")
45        .cloned()
46        .unwrap_or_else(|| req.params.clone());
47    let params = match serde_json::from_value::<BashWriteParams>(raw_params) {
48        Ok(params) => params,
49        Err(e) => {
50            return Response::error(
51                &req.id,
52                "invalid_request",
53                format!("bash_write: invalid params: {e}"),
54            );
55        }
56    };
57
58    let bytes = match expand_input(&params.input) {
59        Ok(bytes) => bytes,
60        Err(message) => {
61            return Response::error(&req.id, "invalid_request", message);
62        }
63    };
64
65    if bytes.len() > MAX_INPUT_BYTES {
66        return Response::error(
67            &req.id,
68            "input_too_large",
69            "bash_write input exceeds 1 MiB limit",
70        );
71    }
72
73    match ctx
74        .bash_background()
75        .write_pty(&params.task_id, req.session(), &bytes)
76    {
77        Ok(bytes_written) => Response::success(&req.id, json!({ "bytes_written": bytes_written })),
78        Err(code) if code == "task_not_found" => Response::error(
79            &req.id,
80            "task_not_found",
81            format!("background task not found: {}", params.task_id),
82        ),
83        Err(code) if code == "task_not_pty" => Response::error(
84            &req.id,
85            "task_not_pty",
86            format!("background task is not a PTY task: {}", params.task_id),
87        ),
88        Err(code) if code == "task_exited" => Response::error(
89            &req.id,
90            "task_exited",
91            format!("PTY task is no longer running: {}", params.task_id),
92        ),
93        Err(message) => Response::error(&req.id, "write_failed", message),
94    }
95}
96
97fn expand_input(input: &BashWriteInput) -> Result<Vec<u8>, String> {
98    match input {
99        BashWriteInput::Text(s) => Ok(s.as_bytes().to_vec()),
100        BashWriteInput::Sequence(items) => {
101            let mut out: Vec<u8> = Vec::with_capacity(items.len() * 4);
102            for item in items {
103                match item {
104                    SequenceItem::Text(s) => out.extend_from_slice(s.as_bytes()),
105                    SequenceItem::Key { key } => {
106                        let bytes = key_to_bytes(key).ok_or_else(|| {
107                            format!(
108                                "bash_write: unknown key '{key}'; allowed keys: {}",
109                                allowed_keys_hint()
110                            )
111                        })?;
112                        out.extend_from_slice(bytes);
113                    }
114                }
115            }
116            Ok(out)
117        }
118    }
119}
120
/// Map a named key to the byte sequence a terminal sends when that key is pressed.
///
/// Returns `None` when the name is not recognised. The caller turns that into
/// an "unknown key" error that lists [`allowed_keys_hint`].
///
/// Implementation notes:
/// * Matching is ASCII case-insensitive: the caller-supplied name is lowercased
///   before lookup. The name is borrowed unchanged when it already contains
///   only lowercase letters, digits and `-`, so the common case does not
///   allocate.
/// * `ctrl-<c>` chords use caret notation: the byte is the character's ASCII
///   code masked to its low five bits (`c & 0x1f`). This gives:
///   - `ctrl-a`..`ctrl-z` → `0x01..=0x1a`;
///   - `ctrl-@` → `0x00` (alias `ctrl-space`);
///   - `ctrl-[` → `0x1b` (the same byte as ESC);
///   - `ctrl-\` → `0x1c`, `ctrl-]` → `0x1d`, `ctrl-^` → `0x1e`, `ctrl-_` → `0x1f`.
/// * Function keys use xterm's default encoding. F1–F4 are the SS3 forms
///   `ESC O P`..`ESC O S`; F5–F12 are the `ESC [ n ~` forms.
/// * Arrow keys and Home/End use the normal cursor-key mode sequences
///   (`ESC [ X`). A terminal sends these unless a program has switched it into
///   application cursor-key mode (DECCKM), where it would send `ESC O X`.
/// * `shift-tab` sends xterm's back-tab sequence `ESC [ Z`.
fn key_to_bytes(name: &str) -> Option<&'static [u8]> {
    // Borrow when the name is already lowercase; only allocate when something
    // actually needs lowercasing.
    let canonical: std::borrow::Cow<'_, str> = if name
        .chars()
        .all(|c| c.is_ascii_lowercase() || c == '-' || c.is_ascii_digit())
    {
        std::borrow::Cow::Borrowed(name)
    } else {
        std::borrow::Cow::Owned(name.to_ascii_lowercase())
    };
    let canonical: &str = &canonical;

    static TABLE: &[(&str, &[u8])] = &[
        // Line / whitespace
        //
        // ENTER maps to CR (\r, 0x0D), the byte a real terminal sends when the
        // user presses Enter. In cooked mode the line discipline translates
        // CR→LF (`icrnl`), so shells and REPLs accept `\r` too. Raw-mode TUIs
        // (opencode TUI, vim insert mode, fzf, htop) see `\r` directly and
        // treat it as submit. LF was wrong for the raw-mode case: the opencode
        // TUI would treat `\n` as multi-line input.
        ("enter", b"\r"),
        ("return", b"\r"),
        ("tab", b"\t"),
        // Back-tab (xterm `CSI Z`), commonly used for reverse focus in TUIs.
        ("shift-tab", b"\x1b[Z"),
        ("space", b" "),
        ("backspace", b"\x7f"),
        // Escape
        ("esc", b"\x1b"),
        ("escape", b"\x1b"),
        // Arrows (normal cursor-key mode)
        ("up", b"\x1b[A"),
        ("down", b"\x1b[B"),
        ("right", b"\x1b[C"),
        ("left", b"\x1b[D"),
        // Navigation
        ("home", b"\x1b[H"),
        ("end", b"\x1b[F"),
        ("page-up", b"\x1b[5~"),
        ("page-down", b"\x1b[6~"),
        ("delete", b"\x1b[3~"),
        ("insert", b"\x1b[2~"),
        // Function keys (xterm defaults)
        ("f1", b"\x1bOP"),
        ("f2", b"\x1bOQ"),
        ("f3", b"\x1bOR"),
        ("f4", b"\x1bOS"),
        ("f5", b"\x1b[15~"),
        ("f6", b"\x1b[17~"),
        ("f7", b"\x1b[18~"),
        ("f8", b"\x1b[19~"),
        ("f9", b"\x1b[20~"),
        ("f10", b"\x1b[21~"),
        ("f11", b"\x1b[23~"),
        ("f12", b"\x1b[24~"),
        // Ctrl-Space traditionally sends NUL, the same byte as ctrl-@.
        ("ctrl-space", b"\x00"),
    ];

    if let Some(&(_, bytes)) = TABLE.iter().find(|(n, _)| *n == canonical) {
        return Some(bytes);
    }

    // Ctrl chords: exactly one character after the `ctrl-` prefix.
    let rest = canonical.strip_prefix("ctrl-")?;
    let mut chars = rest.chars();
    match (chars.next(), chars.next()) {
        (Some(c @ ('a'..='z' | '@' | '[' | '\\' | ']' | '^' | '_')), None) => {
            // Every accepted character is ASCII, so `c as u8` is lossless.
            // Masking to the low five bits always yields an index below 32.
            let code = usize::from((c as u8) & 0x1f);
            Some(&CTRL_TABLE[code..=code])
        }
        _ => None,
    }
}

/// Every C0 control byte `0x00..=0x1f`, in order. `key_to_bytes` slices one
/// element out of this table, so it can return a `&'static [u8]` for any ctrl
/// chord without allocating.
static CTRL_TABLE: [u8; 32] = [
    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
    0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d,
    0x1e, 0x1f,
];

/// Human-readable list of the key names `key_to_bytes` accepts. It is embedded
/// in the "unknown key" error message.
fn allowed_keys_hint() -> &'static str {
    "enter, return, tab, shift-tab, space, backspace, esc, escape, up, down, left, right, home, \
     end, page-up, page-down, delete, insert, f1..f12, ctrl-a..ctrl-z, ctrl-@ (or ctrl-space), \
     ctrl-[, ctrl-\\, ctrl-], ctrl-^, ctrl-_"
}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn text_form_passes_bytes_through_verbatim() {
        let input = BashWriteInput::Text("hello\n".into());
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"hello\n");
    }

    #[test]
    fn text_form_preserves_literal_escape_sequence_chars() {
        // Critical backward-compat case: the agent wants to write the literal
        // 6 characters \u001b (e.g. into source code), NOT an ESC byte.
        let input = BashWriteInput::Text(r"\u001b[31mred\u001b[0m".into());
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, br"\u001b[31mred\u001b[0m");
        // Sanity: the byte count matches the literal char count
        // (6 + 4 + 3 + 6 + 3 = 22). It is not 12, which is what you would get
        // if each `\u001b` had been decoded into a single ESC byte
        // (1 + 4 + 3 + 1 + 3), and it is not some auto-stripped variant.
        assert_eq!(bytes.len(), 22);
    }

    #[test]
    fn sequence_form_expands_text_items() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Text("abc".into()),
            SequenceItem::Text("def".into()),
        ]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"abcdef");
    }

    #[test]
    fn sequence_form_expands_named_keys_to_byte_sequences() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Key { key: "esc".into() },
            SequenceItem::Key { key: "up".into() },
            SequenceItem::Key {
                key: "ctrl-c".into(),
            },
        ]);
        let bytes = expand_input(&input).unwrap();
        // ESC (\x1b) + arrow-up (\x1b[A) + ctrl-c (\x03)
        assert_eq!(bytes, b"\x1b\x1b[A\x03");
    }

    #[test]
    fn sequence_form_mixes_text_and_keys_in_order() {
        // The vim "type some text, exit insert, save+quit" idiom.
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Text("iHello".into()),
            SequenceItem::Key { key: "esc".into() },
            SequenceItem::Text(":wq".into()),
            SequenceItem::Key {
                key: "enter".into(),
            },
        ]);
        let bytes = expand_input(&input).unwrap();
        // ENTER maps to CR (\r) for raw-mode TUI compatibility, not LF (\n).
        // See key_to_bytes "Line / whitespace" docs for the rationale.
        assert_eq!(bytes, b"iHello\x1b:wq\r");
    }

    #[test]
    fn sequence_form_accepts_case_insensitive_key_names() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Key { key: "ESC".into() },
            SequenceItem::Key {
                key: "Ctrl-C".into(),
            },
        ]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"\x1b\x03");
    }

    #[test]
    fn sequence_form_unknown_key_returns_error_with_hint() {
        let input = BashWriteInput::Sequence(vec![SequenceItem::Key {
            key: "windows-key".into(),
        }]);
        let err = expand_input(&input).unwrap_err();
        assert!(err.contains("unknown key 'windows-key'"));
        assert!(err.contains("allowed keys:"));
    }

    #[test]
    fn ctrl_chord_table_covers_all_26_letters() {
        for (i, letter) in ('a'..='z').enumerate() {
            let name = format!("ctrl-{letter}");
            let bytes = key_to_bytes(&name).unwrap_or_else(|| panic!("missing {name}"));
            assert_eq!(bytes, &[(i as u8) + 1]);
        }
    }

    #[test]
    fn function_keys_use_documented_xterm_sequences() {
        assert_eq!(key_to_bytes("f1"), Some(b"\x1bOP".as_slice()));
        assert_eq!(key_to_bytes("f12"), Some(b"\x1b[24~".as_slice()));
    }

    #[test]
    fn empty_sequence_produces_zero_bytes() {
        let input = BashWriteInput::Sequence(vec![]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"");
    }

    #[test]
    fn arrows_use_normal_cursor_key_mode_sequence() {
        // ESC [ A/B/C/D form, not ESC O A/B/C/D (application mode).
        assert_eq!(key_to_bytes("up"), Some(b"\x1b[A".as_slice()));
        assert_eq!(key_to_bytes("down"), Some(b"\x1b[B".as_slice()));
        assert_eq!(key_to_bytes("right"), Some(b"\x1b[C".as_slice()));
        assert_eq!(key_to_bytes("left"), Some(b"\x1b[D".as_slice()));
    }

    // Wire format: the untagged serde enums must pick the right variant from
    // the JSON shape the agent actually sends.

    #[test]
    fn wire_string_deserializes_to_text_variant() {
        let input: BashWriteInput = serde_json::from_value(serde_json::json!("ls\n")).unwrap();
        assert!(matches!(input, BashWriteInput::Text(_)));
        assert_eq!(expand_input(&input).unwrap(), b"ls\n");
    }

    #[test]
    fn wire_json_unicode_escape_is_decoded_by_the_json_layer() {
        // A JSON `\u001b` escape is decoded by the JSON parser into a real ESC
        // byte before bash_write sees it. An escaped backslash (`\\u001b`)
        // survives as the six literal characters.
        let decoded: BashWriteInput = serde_json::from_str(r#""\u001b""#).unwrap();
        assert_eq!(expand_input(&decoded).unwrap(), b"\x1b");
        let literal: BashWriteInput = serde_json::from_str(r#""\\u001b""#).unwrap();
        assert_eq!(expand_input(&literal).unwrap(), br"\u001b");
    }

    #[test]
    fn wire_array_deserializes_to_sequence_variant() {
        let input: BashWriteInput =
            serde_json::from_value(serde_json::json!(["q", { "key": "Enter" }])).unwrap();
        assert!(matches!(input, BashWriteInput::Sequence(ref items) if items.len() == 2));
        assert_eq!(expand_input(&input).unwrap(), b"q\r");
    }

    #[test]
    fn wire_params_decode_task_id_and_input() {
        let params: BashWriteParams = serde_json::from_value(serde_json::json!({
            "task_id": "t1",
            "input": [{ "key": "ctrl-c" }],
        }))
        .unwrap();
        assert_eq!(params.task_id, "t1");
        assert_eq!(expand_input(&params.input).unwrap(), b"\x03");
        assert!(
            serde_json::from_value::<BashWriteParams>(serde_json::json!({ "input": "x" }))
                .is_err()
        );
    }

    #[test]
    fn wire_rejects_unsupported_input_shapes() {
        assert!(serde_json::from_value::<BashWriteInput>(serde_json::json!(42)).is_err());
        assert!(serde_json::from_value::<BashWriteInput>(serde_json::json!(null)).is_err());
        assert!(
            serde_json::from_value::<BashWriteInput>(serde_json::json!([{ "nope": 1 }])).is_err()
        );
    }
}