use crate::context::AppContext;
use crate::protocol::{RawRequest, Response};
use serde::Deserialize;
use serde_json::json;
/// Largest payload (1 MiB) a single `bash_write` request may send to the PTY,
/// measured after named keys have been expanded into their byte sequences.
const MAX_INPUT_BYTES: usize = 1024 * 1024;
/// The `input` field of a `bash_write` request.
///
/// Deserialized untagged: a JSON string becomes `Text` and a JSON array
/// becomes `Sequence`. serde tries untagged variants in declaration order,
/// so the variant order here is part of the wire format — do not reorder.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum BashWriteInput {
    /// Raw text written to the PTY as its UTF-8 bytes, with no escape processing.
    Text(String),
    /// Text chunks and named keys, concatenated in order.
    Sequence(Vec<SequenceItem>),
}
/// One element of a `BashWriteInput::Sequence`.
///
/// Untagged: a JSON string is literal text, and an object with a `key` field
/// (e.g. `{"key": "ctrl-c"}`) names a special key that `expand_input`
/// translates into terminal bytes. serde tries untagged variants in
/// declaration order, so keep `Text` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum SequenceItem {
    /// Literal text, sent as its UTF-8 bytes.
    Text(String),
    /// A named key such as `"enter"`, `"up"` or `"ctrl-c"`; names are matched
    /// case-insensitively by `key_to_bytes`.
    Key { key: String },
}
/// Deserialized parameters of a `bash_write` request.
#[derive(Deserialize, Debug)]
pub struct BashWriteParams {
    /// Identifier of the background task whose PTY receives the input.
    pub task_id: String,
    /// Payload to write: raw text, or a sequence of text chunks and named keys.
    pub input: BashWriteInput,
}
/// Handles a `bash_write` request by writing input bytes to a background
/// task's PTY.
///
/// Parameters are read from `req.params["params"]` when that key is present,
/// otherwise from `req.params` itself. `input` is either a plain string, sent
/// verbatim, or a sequence of text chunks and named keys that `expand_input`
/// flattens into bytes.
///
/// Error responses carry one of these codes:
/// - `invalid_request`: the params failed to deserialize, or a key name is
///   unknown;
/// - `input_too_large`: the expanded input is longer than `MAX_INPUT_BYTES`;
/// - `task_not_found`, `task_not_pty`, `task_exited`: reported by the PTY
///   writer, re-emitted with a message naming the task;
/// - `write_failed`: any other error string from the PTY writer, passed
///   through as the message.
///
/// On success the response body is `{"bytes_written": n}`, where `n` is the
/// count returned by the PTY writer.
pub fn handle(req: &RawRequest, ctx: &AppContext) -> Response {
    // Accept both a `{"params": {...}}` envelope and a bare parameter object.
    let raw_params = req
        .params
        .get("params")
        .cloned()
        .unwrap_or_else(|| req.params.clone());
    let params = match serde_json::from_value::<BashWriteParams>(raw_params) {
        Ok(params) => params,
        Err(e) => {
            return Response::error(
                &req.id,
                "invalid_request",
                format!("bash_write: invalid params: {e}"),
            );
        }
    };
    let bytes = match expand_input(&params.input) {
        Ok(bytes) => bytes,
        Err(message) => {
            return Response::error(&req.id, "invalid_request", message);
        }
    };
    // The limit applies to the bytes actually sent, i.e. after named keys have
    // been expanded into their escape sequences.
    if bytes.len() > MAX_INPUT_BYTES {
        return Response::error(
            &req.id,
            "input_too_large",
            "bash_write input exceeds 1 MiB limit",
        );
    }
    match ctx
        .bash_background()
        .write_pty(&params.task_id, req.session(), &bytes)
    {
        Ok(bytes_written) => Response::success(&req.id, json!({ "bytes_written": bytes_written })),
        Err(code) if code == "task_not_found" => Response::error(
            &req.id,
            "task_not_found",
            format!("background task not found: {}", params.task_id),
        ),
        Err(code) if code == "task_not_pty" => Response::error(
            &req.id,
            "task_not_pty",
            format!("background task is not a PTY task: {}", params.task_id),
        ),
        Err(code) if code == "task_exited" => Response::error(
            &req.id,
            "task_exited",
            format!("PTY task is no longer running: {}", params.task_id),
        ),
        Err(message) => Response::error(&req.id, "write_failed", message),
    }
}
/// Flattens a `BashWriteInput` into the raw bytes to send to the PTY.
///
/// Text is copied verbatim as UTF-8, with no escape processing. Sequence items
/// are concatenated in order, and each `{"key": ...}` item is translated
/// through `key_to_bytes`.
///
/// # Errors
///
/// Stops at the first unrecognised key name. The error message names that key
/// and lists the supported key names.
fn expand_input(input: &BashWriteInput) -> Result<Vec<u8>, String> {
    let items = match input {
        BashWriteInput::Text(text) => return Ok(text.as_bytes().to_vec()),
        BashWriteInput::Sequence(items) => items,
    };
    items.iter().try_fold(
        Vec::with_capacity(items.len() * 4),
        |mut buf, item| -> Result<Vec<u8>, String> {
            let chunk: &[u8] = match item {
                SequenceItem::Text(text) => text.as_bytes(),
                SequenceItem::Key { key } => key_to_bytes(key).ok_or_else(|| {
                    format!(
                        "bash_write: unknown key '{key}'; allowed keys: {}",
                        allowed_keys_hint()
                    )
                })?,
            };
            buf.extend_from_slice(chunk);
            Ok(buf)
        },
    )
}
/// Resolves a key name, case-insensitively, to the bytes a terminal sends for
/// that key.
///
/// Two families of names are recognised:
/// - the fixed names in `NAMED_KEYS` (editing keys, arrows, navigation keys
///   and `f1`..`f12`), which map to the sequences listed there;
/// - `ctrl-<letter>` for `a` through `z`, which map to the control bytes
///   0x01 through 0x1a.
///
/// Returns `None` for anything else. That includes `ctrl-` followed by a
/// non-letter or by more than one character.
fn key_to_bytes(name: &str) -> Option<&'static [u8]> {
    static NAMED_KEYS: &[(&str, &[u8])] = &[
        ("enter", b"\r"),
        ("return", b"\r"),
        ("tab", b"\t"),
        ("space", b" "),
        ("backspace", b"\x7f"),
        ("esc", b"\x1b"),
        ("escape", b"\x1b"),
        ("up", b"\x1b[A"),
        ("down", b"\x1b[B"),
        ("right", b"\x1b[C"),
        ("left", b"\x1b[D"),
        ("home", b"\x1b[H"),
        ("end", b"\x1b[F"),
        ("page-up", b"\x1b[5~"),
        ("page-down", b"\x1b[6~"),
        ("delete", b"\x1b[3~"),
        ("insert", b"\x1b[2~"),
        ("f1", b"\x1bOP"),
        ("f2", b"\x1bOQ"),
        ("f3", b"\x1bOR"),
        ("f4", b"\x1bOS"),
        ("f5", b"\x1b[15~"),
        ("f6", b"\x1b[17~"),
        ("f7", b"\x1b[18~"),
        ("f8", b"\x1b[19~"),
        ("f9", b"\x1b[20~"),
        ("f10", b"\x1b[21~"),
        ("f11", b"\x1b[23~"),
        ("f12", b"\x1b[24~"),
    ];

    // Fold ASCII case, borrowing the input when there is nothing to fold.
    let folded: std::borrow::Cow<'_, str> = if name.bytes().any(|b| b.is_ascii_uppercase()) {
        std::borrow::Cow::Owned(name.to_ascii_lowercase())
    } else {
        std::borrow::Cow::Borrowed(name)
    };
    let folded: &str = &folded;

    if let Some(&(_, seq)) = NAMED_KEYS.iter().find(|&&(known, _)| known == folded) {
        return Some(seq);
    }

    // `ctrl-a` -> 0x01 ... `ctrl-z` -> 0x1a; exactly one ASCII letter allowed.
    match folded.strip_prefix("ctrl-")?.as_bytes() {
        &[letter @ b'a'..=b'z'] => Some(CTRL_TABLE[usize::from(letter - b'a')]),
        _ => None,
    }
}
/// Byte emitted for each `ctrl-<letter>` chord, indexed by the letter's offset
/// from `a`. `CTRL_TABLE[0]` is `ctrl-a` (0x01) and `CTRL_TABLE[25]` is
/// `ctrl-z` (0x1a).
///
/// Each entry is a one-byte `'static` slice, so `key_to_bytes` can return it
/// without allocating.
static CTRL_TABLE: [&[u8]; 26] = [
    &[0x01], // ctrl-a
    &[0x02], // ctrl-b
    &[0x03], // ctrl-c
    &[0x04], // ctrl-d
    &[0x05], // ctrl-e
    &[0x06], // ctrl-f
    &[0x07], // ctrl-g
    &[0x08], // ctrl-h
    &[0x09], // ctrl-i
    &[0x0a], // ctrl-j
    &[0x0b], // ctrl-k
    &[0x0c], // ctrl-l
    &[0x0d], // ctrl-m
    &[0x0e], // ctrl-n
    &[0x0f], // ctrl-o
    &[0x10], // ctrl-p
    &[0x11], // ctrl-q
    &[0x12], // ctrl-r
    &[0x13], // ctrl-s
    &[0x14], // ctrl-t
    &[0x15], // ctrl-u
    &[0x16], // ctrl-v
    &[0x17], // ctrl-w
    &[0x18], // ctrl-x
    &[0x19], // ctrl-y
    &[0x1a], // ctrl-z
];
/// Human-readable list of accepted key names, embedded in the error message
/// produced for an unknown key.
fn allowed_keys_hint() -> &'static str {
    concat!(
        "enter, return, tab, space, backspace, esc, escape, ",
        "up, down, left, right, home, end, ",
        "page-up, page-down, delete, insert, f1..f12, ctrl-a..ctrl-z"
    )
}
/// Unit tests for input expansion and key-name resolution.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn text_form_passes_bytes_through_verbatim() {
        let input = BashWriteInput::Text("hello\n".into());
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"hello\n");
    }

    #[test]
    fn text_form_preserves_literal_escape_sequence_chars() {
        let input = BashWriteInput::Text(r"\u001b[31mred\u001b[0m".into());
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, br"\u001b[31mred\u001b[0m");
        assert_eq!(bytes.len(), 22);
    }

    #[test]
    fn sequence_form_expands_text_items() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Text("abc".into()),
            SequenceItem::Text("def".into()),
        ]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"abcdef");
    }

    #[test]
    fn sequence_form_expands_named_keys_to_byte_sequences() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Key { key: "esc".into() },
            SequenceItem::Key { key: "up".into() },
            SequenceItem::Key {
                key: "ctrl-c".into(),
            },
        ]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"\x1b\x1b[A\x03");
    }

    #[test]
    fn sequence_form_mixes_text_and_keys_in_order() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Text("iHello".into()),
            SequenceItem::Key { key: "esc".into() },
            SequenceItem::Text(":wq".into()),
            SequenceItem::Key {
                key: "enter".into(),
            },
        ]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"iHello\x1b:wq\r");
    }

    #[test]
    fn sequence_form_accepts_case_insensitive_key_names() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Key { key: "ESC".into() },
            SequenceItem::Key {
                key: "Ctrl-C".into(),
            },
        ]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"\x1b\x03");
    }

    #[test]
    fn sequence_form_unknown_key_returns_error_with_hint() {
        let input = BashWriteInput::Sequence(vec![SequenceItem::Key {
            key: "windows-key".into(),
        }]);
        let err = expand_input(&input).unwrap_err();
        assert!(err.contains("unknown key 'windows-key'"));
        assert!(err.contains("allowed keys:"));
    }

    #[test]
    fn ctrl_chord_table_covers_all_26_letters() {
        for (i, letter) in ('a'..='z').enumerate() {
            let name = format!("ctrl-{letter}");
            let bytes = key_to_bytes(&name).unwrap_or_else(|| panic!("missing {name}"));
            assert_eq!(bytes, &[(i as u8) + 1]);
        }
    }

    #[test]
    fn function_keys_use_documented_xterm_sequences() {
        assert_eq!(key_to_bytes("f1"), Some(b"\x1bOP".as_slice()));
        assert_eq!(key_to_bytes("f12"), Some(b"\x1b[24~".as_slice()));
    }

    #[test]
    fn empty_sequence_produces_zero_bytes() {
        let input = BashWriteInput::Sequence(vec![]);
        let bytes = expand_input(&input).unwrap();
        assert_eq!(bytes, b"");
    }

    #[test]
    fn arrows_use_normal_cursor_key_mode_sequence() {
        assert_eq!(key_to_bytes("up"), Some(b"\x1b[A".as_slice()));
        assert_eq!(key_to_bytes("down"), Some(b"\x1b[B".as_slice()));
        assert_eq!(key_to_bytes("right"), Some(b"\x1b[C".as_slice()));
        assert_eq!(key_to_bytes("left"), Some(b"\x1b[D".as_slice()));
    }

    #[test]
    fn named_key_aliases_resolve_to_identical_bytes() {
        assert_eq!(key_to_bytes("enter"), key_to_bytes("return"));
        assert_eq!(key_to_bytes("esc"), key_to_bytes("escape"));
    }

    #[test]
    fn ctrl_prefix_requires_exactly_one_ascii_letter() {
        // Empty suffix, digits, punctuation, multi-char and non-ASCII suffixes
        // are all rejected rather than mapped to some control byte.
        for name in ["ctrl-", "ctrl-1", "ctrl--", "ctrl-ab", "ctrl-é"] {
            assert_eq!(key_to_bytes(name), None, "{name} should be rejected");
        }
    }

    #[test]
    fn uppercase_named_keys_are_folded_before_lookup() {
        assert_eq!(key_to_bytes("PAGE-UP"), Some(b"\x1b[5~".as_slice()));
        assert_eq!(key_to_bytes("F5"), Some(b"\x1b[15~".as_slice()));
        assert_eq!(key_to_bytes("CTRL-Z"), Some(b"\x1a".as_slice()));
    }

    #[test]
    fn unknown_key_aborts_expansion_even_after_valid_items() {
        let input = BashWriteInput::Sequence(vec![
            SequenceItem::Text("ls".into()),
            SequenceItem::Key {
                key: "enter".into(),
            },
            SequenceItem::Key {
                key: "hyper".into(),
            },
        ]);
        let err = expand_input(&input).unwrap_err();
        assert!(err.contains("unknown key 'hyper'"));
    }

    #[test]
    fn every_fixed_key_named_in_hint_resolves() {
        // Range entries such as `f1..f12` are covered by dedicated tests.
        for name in allowed_keys_hint().split(", ").filter(|n| !n.contains("..")) {
            assert!(key_to_bytes(name).is_some(), "hint lists unknown key {name}");
        }
    }

    #[test]
    fn size_limit_matches_documented_one_mebibyte() {
        // The `input_too_large` error message hard-codes "1 MiB".
        assert_eq!(MAX_INPUT_BYTES, 1024 * 1024);
    }
}