use serde::{Deserialize, Serialize};
use std::io;
use std::path::PathBuf;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Largest JSON payload, in bytes, that a single frame may carry (1 MiB).
///
/// [`write_message`] refuses to send a larger payload, and [`read_message`]
/// rejects a length prefix above this value before allocating a buffer for it.
pub const MAX_PAYLOAD_BYTES: usize = 1024 * 1024;
/// A request message carried by the length-prefixed JSON framing.
///
/// Encoded as an internally tagged JSON object: the variant name, converted to
/// snake_case, is stored under a `"kind"` key next to the variant's own fields,
/// e.g. `{"kind":"get_env","names":["HOME"]}`. Unit variants carry only the tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum Request {
    /// Liveness check; carries no data.
    Ping,
    /// Asks the peer to shut down. What shutting down entails is decided by
    /// the handler, which is not part of this module.
    Shutdown,
    /// Request to read `path`.
    Read {
        path: PathBuf,
        /// Optional cap on how many bytes to return. `None` presumably means
        /// "no request-specific cap"; confirm against the handler.
        max_bytes: Option<usize>,
    },
    /// Request to write `content` to `path`.
    ///
    /// NOTE: serde_json encodes `Vec<u8>` as a JSON array of integers, so the
    /// usable content size is well below [`MAX_PAYLOAD_BYTES`].
    Write {
        path: PathBuf,
        content: Vec<u8>,
    },
    /// Request to replace `old_string` with `new_string` in `path`.
    /// How many occurrences get replaced is up to the handler.
    Edit {
        path: PathBuf,
        old_string: String,
        new_string: String,
    },
    /// Request to match `pattern` against paths under `root`.
    Glob {
        pattern: String,
        root: PathBuf,
    },
    /// Request to search under `root` for `pattern`. `include` presumably
    /// narrows which files are searched (confirm with the handler).
    Grep {
        pattern: String,
        root: PathBuf,
        include: Option<String>,
    },
    /// Request for metadata about `path`.
    Stat {
        path: PathBuf,
    },
    /// Request to look up the named environment variables.
    GetEnv {
        names: Vec<String>,
    },
}
/// A response message carried by the length-prefixed JSON framing.
///
/// Encoded the same way as the request enum: an internally tagged JSON object
/// whose `"kind"` key holds the snake_case variant name, e.g.
/// `{"kind":"error","code":"policy_denied","message":"..."}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum Response {
    /// Reply to a liveness check; carries no data.
    Pong,
    /// Bytes read by the peer.
    ///
    /// NOTE: serde_json encodes `Vec<u8>` as a JSON array of integers, so the
    /// usable content size is well below [`MAX_PAYLOAD_BYTES`].
    Read {
        content: Vec<u8>,
    },
    /// Number of bytes the peer reports having written.
    Write {
        bytes_written: usize,
    },
    /// Number of replacements the peer reports having made.
    Edit {
        replacements: usize,
    },
    /// Paths reported by the peer for a glob request.
    Glob {
        paths: Vec<PathBuf>,
    },
    /// Matches reported by the peer for a grep request.
    Grep {
        matches: Vec<GrepMatch>,
    },
    /// Metadata for a path. `size` is presumably in bytes; whether `size` and
    /// `is_dir` follow symlinks depends on the handler (confirm).
    Stat {
        size: u64,
        is_dir: bool,
        is_symlink: bool,
    },
    /// Environment lookups, presumably one entry per requested name in request
    /// order, with `None` for unset variables (confirm with the handler).
    GetEnv {
        values: Vec<Option<String>>,
    },
    /// Failure report: `code` classifies the failure, `message` gives detail.
    Error {
        code: ErrorCode,
        message: String,
    },
}
/// Machine-readable failure category carried in an error response.
///
/// Serialized as a bare snake_case JSON string, e.g. `"policy_denied"`.
/// The precise conditions behind each code are decided by the handlers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorCode {
    /// Presumably: the requested operation is not implemented by the peer.
    Unimplemented,
    /// Presumably: a policy check refused the operation.
    PolicyDenied,
    /// Presumably: an I/O operation failed while serving the request.
    Io,
    /// Presumably: the request violated the protocol.
    Protocol,
    /// Presumably: any other unexpected failure on the peer.
    Internal,
}
/// A single hit in a grep response.
///
/// Serialized as a plain JSON object with `path`, `line`, and `text` keys.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GrepMatch {
    /// File in which the match was found.
    pub path: PathBuf,
    /// Line number of the match. Whether it is 0- or 1-based is set by the
    /// handler (confirm).
    pub line: usize,
    /// Presumably the text of the matching line.
    pub text: String,
}
/// Reads one length-prefixed JSON frame from `reader` and decodes it as `T`.
///
/// A frame is a 4-byte big-endian `u32` payload length followed by exactly
/// that many bytes of JSON.
///
/// # Returns
///
/// - `Ok(None)` if the stream ends cleanly before the first byte of a frame.
/// - `Ok(Some(msg))` once a complete frame has been read and decoded.
///
/// # Errors
///
/// - `UnexpectedEof` if the stream ends partway through the length prefix or
///   the payload.
/// - `InvalidData` if the announced length exceeds [`MAX_PAYLOAD_BYTES`]
///   (checked before the payload buffer is allocated) or if the payload is not
///   valid JSON for `T`.
/// - Any other I/O error reported by `reader`.
pub async fn read_message<R, T>(reader: &mut R) -> io::Result<Option<T>>
where
    R: AsyncRead + Unpin + Send,
    T: serde::de::DeserializeOwned + Send,
{
    // The prefix is filled by hand instead of with `read_exact` so that EOF at
    // a frame boundary (a clean close) can be told apart from EOF mid-prefix.
    let mut prefix = [0u8; 4];
    let mut filled = 0usize;
    while filled < prefix.len() {
        match reader.read(&mut prefix[filled..]).await? {
            0 if filled == 0 => return Ok(None),
            0 => {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    format!("EOF after {filled} of 4 length-prefix bytes"),
                ))
            }
            n => filled += n,
        }
    }

    let len = u32::from_be_bytes(prefix) as usize;
    if len > MAX_PAYLOAD_BYTES {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("payload size {len} exceeds max {MAX_PAYLOAD_BYTES}"),
        ));
    }

    // `read_exact` reports a short payload as `UnexpectedEof`.
    let mut body = vec![0u8; len];
    reader.read_exact(&mut body).await?;

    serde_json::from_slice::<T>(&body).map(Some).map_err(|e| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("malformed JSON payload: {e}"),
        )
    })
}
/// Serializes `msg` as JSON and writes it to `writer` as one length-prefixed frame.
///
/// Frame layout: a 4-byte big-endian `u32` payload length followed by the JSON
/// payload. The whole frame is assembled in memory and handed to the writer in
/// a single `write_all`, and then the writer is flushed. Sending the 4-byte
/// header and the payload as two separate writes would cost two syscalls on an
/// unbuffered stream. On TCP it can also stall on the Nagle/delayed-ACK
/// interaction.
///
/// # Errors
///
/// - `InvalidData` if `msg` fails to serialize, or if the serialized payload is
///   larger than [`MAX_PAYLOAD_BYTES`]. Both checks happen before anything is
///   written.
/// - Any I/O error from writing to or flushing `writer`.
pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> io::Result<()>
where
    W: AsyncWrite + Unpin + Send,
    T: serde::Serialize + Sync,
{
    // The `as u32` cast below is lossless only while the cap fits in the
    // prefix; fail the build instead of silently truncating if it ever grows.
    const _: () = assert!(MAX_PAYLOAD_BYTES <= u32::MAX as usize);
    const PREFIX_LEN: usize = 4;

    // Reserve a placeholder prefix and serialize directly after it, so the
    // frame goes out in one write without copying the payload a second time.
    let mut frame = vec![0u8; PREFIX_LEN];
    serde_json::to_writer(&mut frame, msg)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("serialize: {e}")))?;

    let payload_len = frame.len() - PREFIX_LEN;
    if payload_len > MAX_PAYLOAD_BYTES {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "outgoing payload size {} exceeds max {}",
                payload_len, MAX_PAYLOAD_BYTES
            ),
        ));
    }

    frame[..PREFIX_LEN].copy_from_slice(&(payload_len as u32).to_be_bytes());
    writer.write_all(&frame).await?;
    writer.flush().await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::duplex;

    #[tokio::test]
    async fn ping_roundtrips_through_framing() {
        let (mut a, mut b) = duplex(64);
        write_message(&mut a, &Request::Ping).await.unwrap();
        let got: Request = read_message(&mut b).await.unwrap().unwrap();
        assert_eq!(got, Request::Ping);
    }

    #[tokio::test]
    async fn pong_roundtrips_through_framing() {
        let (mut a, mut b) = duplex(64);
        write_message(&mut a, &Response::Pong).await.unwrap();
        let got: Response = read_message(&mut b).await.unwrap().unwrap();
        assert_eq!(got, Response::Pong);
    }

    #[tokio::test]
    async fn error_response_roundtrips() {
        let (mut a, mut b) = duplex(256);
        let msg = Response::Error {
            code: ErrorCode::PolicyDenied,
            message: "deny file-write* /etc/passwd".into(),
        };
        write_message(&mut a, &msg).await.unwrap();
        let got: Response = read_message(&mut b).await.unwrap().unwrap();
        assert_eq!(got, msg);
    }

    #[tokio::test]
    async fn complex_request_roundtrips() {
        let (mut a, mut b) = duplex(512);
        let req = Request::Edit {
            path: PathBuf::from("/work/src/main.rs"),
            old_string: "let x = 1;\nlet y = 2;".into(),
            new_string: "let x = 42;".into(),
        };
        write_message(&mut a, &req).await.unwrap();
        let got: Request = read_message(&mut b).await.unwrap().unwrap();
        assert_eq!(got, req);
    }

    #[tokio::test]
    async fn multiple_messages_back_to_back_parse_correctly() {
        let (mut a, mut b) = duplex(1024);
        write_message(&mut a, &Request::Ping).await.unwrap();
        write_message(
            &mut a,
            &Request::Read {
                path: "/a".into(),
                max_bytes: Some(1024),
            },
        )
        .await
        .unwrap();
        let m1: Request = read_message(&mut b).await.unwrap().unwrap();
        let m2: Request = read_message(&mut b).await.unwrap().unwrap();
        assert_eq!(m1, Request::Ping);
        assert_eq!(
            m2,
            Request::Read {
                path: "/a".into(),
                max_bytes: Some(1024)
            }
        );
    }

    /// Round-trip tests still pass if writer and reader drift together. Pin
    /// the exact bytes a separately built peer would see on the wire.
    #[tokio::test]
    async fn wire_format_is_be_length_prefix_then_tagged_json() {
        let mut wire: Vec<u8> = Vec::new();
        write_message(
            &mut wire,
            &Request::GetEnv {
                names: vec!["HOME".into()],
            },
        )
        .await
        .unwrap();
        let body: &[u8] = br#"{"kind":"get_env","names":["HOME"]}"#;
        assert_eq!(wire.len(), 4 + body.len());
        assert_eq!(&wire[..4], &(body.len() as u32).to_be_bytes());
        assert_eq!(&wire[4..], body);

        // A frame built by hand in the same layout decodes to the same value.
        let got: Request = read_message(&mut Cursor::new(wire)).await.unwrap().unwrap();
        assert_eq!(
            got,
            Request::GetEnv {
                names: vec!["HOME".into()]
            }
        );
    }

    #[test]
    fn error_code_serializes_as_snake_case_string() {
        let json = serde_json::to_string(&Response::Error {
            code: ErrorCode::PolicyDenied,
            message: "nope".into(),
        })
        .unwrap();
        assert_eq!(
            json,
            r#"{"kind":"error","code":"policy_denied","message":"nope"}"#
        );
    }

    #[tokio::test]
    async fn clean_eof_before_any_bytes_returns_none() {
        let mut empty = Cursor::new(Vec::<u8>::new());
        let got: Option<Request> = read_message(&mut empty).await.unwrap();
        assert!(got.is_none());
    }

    #[tokio::test]
    async fn eof_mid_length_prefix_is_unexpected_eof() {
        let mut partial = Cursor::new(vec![0u8, 0u8]);
        let err = read_message::<_, Request>(&mut partial).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn eof_mid_payload_is_unexpected_eof() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&100u32.to_be_bytes());
        buf.extend_from_slice(b"abcd");
        let mut cur = Cursor::new(buf);
        let err = read_message::<_, Request>(&mut cur).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn zero_length_payload_is_invalid_data() {
        let mut cur = Cursor::new(0u32.to_be_bytes().to_vec());
        let err = read_message::<_, Request>(&mut cur).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("malformed JSON"));
    }

    #[tokio::test]
    async fn payload_size_above_cap_is_rejected() {
        let oversize = (MAX_PAYLOAD_BYTES as u32 + 1).to_be_bytes();
        let mut cur = Cursor::new(oversize.to_vec());
        let err = read_message::<_, Request>(&mut cur).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("exceeds max"));
    }

    #[tokio::test]
    async fn malformed_json_payload_is_invalid_data() {
        let mut buf = Vec::new();
        let body = b"this is not json";
        buf.extend_from_slice(&(body.len() as u32).to_be_bytes());
        buf.extend_from_slice(body);
        let mut cur = Cursor::new(buf);
        let err = read_message::<_, Request>(&mut cur).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("malformed JSON"));
    }

    #[tokio::test]
    async fn write_rejects_oversize_payload_locally() {
        let huge = "x".repeat(MAX_PAYLOAD_BYTES + 100);
        let req = Request::Write {
            path: "/a".into(),
            content: huge.into_bytes(),
        };
        let (mut a, _b) = duplex(MAX_PAYLOAD_BYTES * 2);
        let err = write_message(&mut a, &req).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}