use std::io::{self, Read, Write};
/// Protocol version byte. `Frame::encode` writes it as the first byte of every
/// frame body and `Frame::read_from` rejects frames carrying any other value.
pub const PROTOCOL_VERSION: u8 = 1;
/// NOTE(review): not referenced by any code visible in this file — `socket_dir()`
/// derives its directory from `XDG_RUNTIME_DIR` / `std::env::temp_dir()` instead.
/// Presumably a legacy default kept for external callers; confirm before relying on it.
pub const SOCKET_DIR: &str = "/tmp";
/// Largest payload, in bytes, that fits in a frame `Frame::read_from` accepts
/// (1_048_576 bytes = 1 MiB).
pub const MAX_PAYLOAD_SIZE: usize = 1_048_576;
/// Largest frame body `Frame::read_from` accepts: the payload limit plus the two
/// header bytes (protocol version and message kind) that precede the payload.
pub const MAX_FRAME_BODY_SIZE: usize = MAX_PAYLOAD_SIZE + 2;
/// Role a client declares in its hello payload.
///
/// The discriminant is the on-wire value: `encode_hello` writes `role as u8` as
/// the first payload byte and `decode_hello` converts it back via `Role::try_from`.
///
/// NOTE(review): what each role is permitted to do is not visible in this file;
/// the per-variant notes below are inferred from the names only — confirm against
/// the server implementation.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// Wire value 1. Presumably the client that drives the session (sends input).
Writer = 1,
/// Wire value 2. Presumably a read-only viewer of the session output.
Watcher = 2,
/// Wire value 3. Presumably an observer used for monitoring/diagnostics.
Monitor = 3,
}
/// Message type tag stored in the second byte of every frame body.
///
/// Each discriminant is the exact byte that goes on the wire (`kind as u8`), so
/// the numeric values are part of the protocol and must not be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MsgKind {
    Hello = 1,
    HelloAck = 2,
    Input = 3,
    Output = 4,
    Resize = 5,
    ResizeAck = 6,
    Exit = 10,
    Shutdown = 11,
    Ping = 12,
    Pong = 13,
    Error = 127,
}

/// Converts a wire byte back into a [`MsgKind`].
///
/// A byte that matches no variant is returned unchanged as the error value so
/// the caller can report it.
impl TryFrom<u8> for MsgKind {
    type Error = u8;

    // The return type spells out `u8` instead of `Self::Error`: inside this impl
    // `Self::Error` would be ambiguous with the `MsgKind::Error` variant.
    fn try_from(v: u8) -> Result<Self, u8> {
        // Every variant, so the lookup below is driven by the enum's own discriminants.
        const KNOWN: [MsgKind; 11] = [
            MsgKind::Hello,
            MsgKind::HelloAck,
            MsgKind::Input,
            MsgKind::Output,
            MsgKind::Resize,
            MsgKind::ResizeAck,
            MsgKind::Exit,
            MsgKind::Shutdown,
            MsgKind::Ping,
            MsgKind::Pong,
            MsgKind::Error,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|kind| *kind as u8 == v)
            .ok_or(v)
    }
}
impl TryFrom<u8> for Role {
type Error = u8;
fn try_from(v: u8) -> Result<Self, u8> {
match v {
1 => Ok(Self::Writer),
2 => Ok(Self::Watcher),
3 => Ok(Self::Monitor),
other => Err(other),
}
}
}
/// Returns the directory in which session sockets are placed.
///
/// Resolution order:
/// 1. `$XDG_RUNTIME_DIR/keepty` when the variable is set, valid UTF-8 and non-empty;
/// 2. otherwise `<std::env::temp_dir()>/keepty`.
///
/// Trailing `/` characters on the base directory are stripped in both cases, so
/// the result never contains a doubled separator in front of `keepty`
/// (e.g. `XDG_RUNTIME_DIR=/run/user/1000/` yields `/run/user/1000/keepty`).
///
/// This function only computes the name; it does not create the directory.
pub fn socket_dir() -> String {
    // `env::var` fails for both an unset variable and a non-UTF-8 value; either
    // way the temp directory is the fallback, as is an explicitly empty value.
    let base = match std::env::var("XDG_RUNTIME_DIR") {
        Ok(xdg) if !xdg.is_empty() => xdg,
        _ => std::env::temp_dir().to_string_lossy().into_owned(),
    };
    format!("{}/keepty", base.trim_end_matches('/'))
}
/// Returns the socket path for `session_id`:
/// `<socket_dir()>/keepty-<session_id>.sock`.
///
/// `session_id` is inserted verbatim; no escaping or validation is performed here.
pub fn socket_path(session_id: &str) -> String {
    let mut path = socket_dir();
    path.push_str("/keepty-");
    path.push_str(session_id);
    path.push_str(".sock");
    path
}
/// One protocol message: a kind tag plus an opaque payload.
///
/// Wire layout produced by [`Frame::encode`] and consumed by [`Frame::read_from`]:
///
/// ```text
/// | body length: u32 big-endian | version: u8 | kind: u8 | payload ... |
/// ```
///
/// The body length counts the version byte, the kind byte and the payload, but
/// not the four length bytes themselves.
#[derive(Debug)]
pub struct Frame {
    /// Message type tag.
    pub kind: MsgKind,
    /// Raw payload bytes; their meaning depends on `kind`.
    pub payload: Vec<u8>,
}

impl Frame {
    /// Creates a frame from a kind and an already-encoded payload.
    pub fn new(kind: MsgKind, payload: Vec<u8>) -> Self {
        Self { kind, payload }
    }

    /// Serialises the frame into a freshly allocated buffer: four length bytes,
    /// the protocol version, the kind byte, then the payload.
    ///
    /// No size limit is enforced here; a payload larger than `MAX_PAYLOAD_SIZE`
    /// encodes fine but will be rejected by [`Frame::read_from`].
    pub fn encode(&self) -> Vec<u8> {
        let frame_len = 2 + self.payload.len();
        let mut buf = Vec::with_capacity(4 + frame_len);
        // NOTE: `as u32` truncates for bodies larger than `u32::MAX` bytes. Such a
        // frame is far beyond `MAX_FRAME_BODY_SIZE` and unreadable anyway.
        buf.extend_from_slice(&(frame_len as u32).to_be_bytes());
        buf.push(PROTOCOL_VERSION);
        buf.push(self.kind as u8);
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Reads one frame from `reader`.
    ///
    /// Returns `Ok(None)` only when the stream is at end-of-file *before* the
    /// first byte of a new frame, i.e. the peer closed the connection on a
    /// frame boundary.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` — the stream ended part-way through the length prefix
    ///   or part-way through the frame body.
    /// * `InvalidData` — the declared body length is below 2 or above
    ///   `MAX_FRAME_BODY_SIZE`, the version byte differs from
    ///   `PROTOCOL_VERSION`, or the kind byte is not a known `MsgKind`.
    /// * Any other I/O error reported by `reader` is passed through.
    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Option<Self>> {
        // The prefix is read by hand rather than with `read_exact` so that
        // "no bytes at all" (clean close) can be told apart from "some but not
        // all prefix bytes" (truncated stream), which `read_exact` reports
        // identically as `UnexpectedEof`.
        let mut len_buf = [0u8; 4];
        let mut filled = 0;
        while filled < len_buf.len() {
            match reader.read(&mut len_buf[filled..]) {
                Ok(0) if filled == 0 => return Ok(None),
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "stream ended inside frame length prefix",
                    ));
                }
                Ok(n) => filled += n,
                // Same retry policy as `read_exact`.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                // Readers that signal end-of-stream as an error, before any
                // prefix byte arrived, still count as a clean close.
                Err(e) if filled == 0 && e.kind() == io::ErrorKind::UnexpectedEof => {
                    return Ok(None);
                }
                Err(e) => return Err(e),
            }
        }

        let frame_len = u32::from_be_bytes(len_buf) as usize;
        if frame_len < 2 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "frame too short",
            ));
        }
        // Checked before allocating so a hostile length cannot force a huge buffer.
        if frame_len > MAX_FRAME_BODY_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "frame too large: {} bytes (max {})",
                    frame_len, MAX_FRAME_BODY_SIZE
                ),
            ));
        }

        // The whole body is consumed before validation, so the stream stays on a
        // frame boundary even when the header bytes turn out to be invalid.
        let mut data = vec![0u8; frame_len];
        reader.read_exact(&mut data)?;

        let version = data[0];
        if version != PROTOCOL_VERSION {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "unsupported protocol version: {} (expected {})",
                    version, PROTOCOL_VERSION
                ),
            ));
        }
        let kind = MsgKind::try_from(data[1]).map_err(|v| {
            io::Error::new(io::ErrorKind::InvalidData, format!("unknown kind: {}", v))
        })?;

        // Drop the two header bytes in place and reuse the buffer as the payload
        // instead of copying it into a second allocation.
        data.drain(..2);
        Ok(Some(Frame {
            kind,
            payload: data,
        }))
    }

    /// Encodes the frame, writes it to `writer` in full and flushes the writer.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `write_all` or `flush`.
    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_all(&self.encode())?;
        writer.flush()
    }
}
/// Builds the 5-byte hello payload: the role byte followed by `cols` and `rows`,
/// each as a big-endian `u16`.
pub fn encode_hello(role: Role, cols: u16, rows: u16) -> Vec<u8> {
    let [cols_hi, cols_lo] = cols.to_be_bytes();
    let [rows_hi, rows_lo] = rows.to_be_bytes();
    vec![role as u8, cols_hi, cols_lo, rows_hi, rows_lo]
}
/// Parses a hello payload into `(role, cols, rows)`.
///
/// Returns `None` when the payload is shorter than 5 bytes or its first byte is
/// not a valid `Role`. Bytes beyond the first five are ignored.
pub fn decode_hello(payload: &[u8]) -> Option<(Role, u16, u16)> {
    match *payload {
        [role_byte, cols_hi, cols_lo, rows_hi, rows_lo, ..] => {
            let role = Role::try_from(role_byte).ok()?;
            let cols = u16::from_be_bytes([cols_hi, cols_lo]);
            let rows = u16::from_be_bytes([rows_hi, rows_lo]);
            Some((role, cols, rows))
        }
        _ => None,
    }
}
/// Builds the 8-byte hello-ack payload: `pty_pid` as a big-endian `u32`,
/// followed by `cols` and `rows` as big-endian `u16`s.
pub fn encode_hello_ack(pty_pid: u32, cols: u16, rows: u16) -> Vec<u8> {
    let pid_bytes = pty_pid.to_be_bytes();
    let col_bytes = cols.to_be_bytes();
    let row_bytes = rows.to_be_bytes();
    [&pid_bytes[..], &col_bytes[..], &row_bytes[..]].concat()
}
/// Parses a hello-ack payload into `(pid, cols, rows)`.
///
/// Returns `None` when the payload is shorter than 8 bytes. Bytes beyond the
/// first eight are ignored.
pub fn decode_hello_ack(payload: &[u8]) -> Option<(u32, u16, u16)> {
    if let [p0, p1, p2, p3, c0, c1, r0, r1, ..] = *payload {
        Some((
            u32::from_be_bytes([p0, p1, p2, p3]),
            u16::from_be_bytes([c0, c1]),
            u16::from_be_bytes([r0, r1]),
        ))
    } else {
        None
    }
}
/// Builds the 4-byte resize payload: `cols` then `rows`, each as a big-endian `u16`.
pub fn encode_resize(cols: u16, rows: u16) -> Vec<u8> {
    [cols, rows]
        .iter()
        .flat_map(|dim| dim.to_be_bytes())
        .collect()
}
/// Parses a resize payload into `(cols, rows)`.
///
/// Returns `None` when the payload is shorter than 4 bytes. Bytes beyond the
/// first four are ignored.
pub fn decode_resize(payload: &[u8]) -> Option<(u16, u16)> {
    // `get(..4)` is the length check; the two 2-byte chunks are cols then rows.
    let mut words = payload
        .get(..4)?
        .chunks_exact(2)
        .map(|pair| u16::from_be_bytes([pair[0], pair[1]]));
    let cols = words.next()?;
    let rows = words.next()?;
    Some((cols, rows))
}
/// Builds the 8-byte resize-ack payload: `generation` as a big-endian `u32`,
/// followed by `cols` and `rows` as big-endian `u16`s.
///
/// The first parameter was previously spelled `gen`; that identifier is a
/// reserved keyword from the Rust 2024 edition onwards, so it is written out in
/// full. Positional callers are unaffected.
///
/// NOTE(review): the value is presumably a resize generation counter — confirm
/// against the code that produces it.
pub fn encode_resize_ack(generation: u32, cols: u16, rows: u16) -> Vec<u8> {
    let mut payload = Vec::with_capacity(8);
    payload.extend_from_slice(&generation.to_be_bytes());
    payload.extend_from_slice(&cols.to_be_bytes());
    payload.extend_from_slice(&rows.to_be_bytes());
    payload
}
/// Parses a resize-ack payload into `(generation, cols, rows)`.
///
/// Returns `None` when the payload is shorter than 8 bytes. Bytes beyond the
/// first eight are ignored.
pub fn decode_resize_ack(payload: &[u8]) -> Option<(u32, u16, u16)> {
    if payload.len() < 8 {
        return None;
    }
    // Named `generation` rather than `gen`: `gen` is a reserved keyword from the
    // Rust 2024 edition onwards.
    let generation = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
    let cols = u16::from_be_bytes([payload[4], payload[5]]);
    let rows = u16::from_be_bytes([payload[6], payload[7]]);
    Some((generation, cols, rows))
}
/// Builds the 4-byte exit payload: `code` as a big-endian `i32`.
pub fn encode_exit(code: i32) -> Vec<u8> {
    let bytes = code.to_be_bytes();
    Vec::from(&bytes[..])
}
/// Parses an exit payload into the exit code (big-endian `i32`).
///
/// Returns `None` when the payload is shorter than 4 bytes. Bytes beyond the
/// first four are ignored.
pub fn decode_exit(payload: &[u8]) -> Option<i32> {
    match *payload {
        [b0, b1, b2, b3, ..] => Some(i32::from_be_bytes([b0, b1, b2, b3])),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A frame survives an encode → `read_from` round trip unchanged.
    #[test]
    fn frame_roundtrip() {
        let frame = Frame::new(MsgKind::Output, b"hello world".to_vec());
        let encoded = frame.encode();
        let mut cursor = std::io::Cursor::new(encoded);
        let decoded = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded.kind, MsgKind::Output);
        assert_eq!(decoded.payload, b"hello world");
    }

    #[test]
    fn hello_roundtrip() {
        let payload = encode_hello(Role::Writer, 132, 51);
        let (role, cols, rows) = decode_hello(&payload).unwrap();
        assert_eq!(role, Role::Writer);
        assert_eq!(cols, 132);
        assert_eq!(rows, 51);
    }

    #[test]
    fn hello_ack_roundtrip() {
        let payload = encode_hello_ack(12345, 80, 24);
        let (pid, cols, rows) = decode_hello_ack(&payload).unwrap();
        assert_eq!(pid, 12345);
        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
    }

    #[test]
    fn resize_roundtrip() {
        let payload = encode_resize(80, 24);
        let (cols, rows) = decode_resize(&payload).unwrap();
        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
    }

    #[test]
    fn exit_roundtrip() {
        let payload = encode_exit(42);
        let code = decode_exit(&payload).unwrap();
        assert_eq!(code, 42);
    }

    /// The socket file lives inside the `keepty` directory and is named after the session.
    #[test]
    fn socket_path_format() {
        let path = socket_path("abc123");
        assert!(path.ends_with("/keepty-abc123.sock"), "path: {}", path);
        // Checks the directory component as well; a bare `contains("/keepty")`
        // is already implied by the file-name assertion above.
        assert!(
            path.ends_with("/keepty/keepty-abc123.sock"),
            "path should contain /keepty dir: {}",
            path
        );
    }

    /// An empty stream is a clean end-of-stream, not an error.
    #[test]
    fn eof_returns_none() {
        let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
        let result = Frame::read_from(&mut cursor).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn all_roles_roundtrip() {
        for role in [Role::Writer, Role::Watcher, Role::Monitor] {
            let v = role as u8;
            assert_eq!(Role::try_from(v).unwrap(), role);
        }
    }

    #[test]
    fn all_msg_kinds_roundtrip() {
        for kind in [
            MsgKind::Hello,
            MsgKind::HelloAck,
            MsgKind::Input,
            MsgKind::Output,
            MsgKind::Resize,
            MsgKind::ResizeAck,
            MsgKind::Exit,
            MsgKind::Shutdown,
            MsgKind::Ping,
            MsgKind::Pong,
            MsgKind::Error,
        ] {
            let v = kind as u8;
            assert_eq!(MsgKind::try_from(v).unwrap(), kind);
        }
    }

    #[test]
    fn invalid_role_returns_err() {
        assert!(Role::try_from(0).is_err());
        assert!(Role::try_from(4).is_err());
        assert!(Role::try_from(255).is_err());
    }

    #[test]
    fn resize_ack_roundtrip() {
        let payload = encode_resize_ack(42, 120, 40);
        // `generation`, not `gen`: the latter is a reserved keyword in Rust 2024.
        let (generation, cols, rows) = decode_resize_ack(&payload).unwrap();
        assert_eq!(generation, 42);
        assert_eq!(cols, 120);
        assert_eq!(rows, 40);
    }

    #[test]
    fn invalid_msg_kind_returns_err() {
        assert!(MsgKind::try_from(0).is_err());
        assert!(MsgKind::try_from(7).is_err());
        assert!(MsgKind::try_from(128).is_err());
    }

    /// A declared body length of 1 cannot hold the version and kind bytes.
    #[test]
    fn frame_too_short_is_error() {
        let data = vec![0, 0, 0, 1, 0xFF];
        let mut cursor = std::io::Cursor::new(data);
        assert!(Frame::read_from(&mut cursor).is_err());
    }

    #[test]
    fn empty_payload_frame() {
        let frame = Frame::new(MsgKind::Ping, vec![]);
        let encoded = frame.encode();
        let mut cursor = std::io::Cursor::new(encoded);
        let decoded = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded.kind, MsgKind::Ping);
        assert!(decoded.payload.is_empty());
    }

    /// The length check fires before the body is read, so no body bytes are needed.
    #[test]
    fn oversized_frame_rejected() {
        let len = (MAX_FRAME_BODY_SIZE + 1) as u32;
        let mut data = len.to_be_bytes().to_vec();
        data.push(PROTOCOL_VERSION);
        data.push(MsgKind::Output as u8);
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn max_allowed_frame_accepted() {
        let payload = vec![0u8; MAX_PAYLOAD_SIZE];
        let frame = Frame::new(MsgKind::Output, payload);
        let encoded = frame.encode();
        let mut cursor = std::io::Cursor::new(encoded);
        let decoded = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded.kind, MsgKind::Output);
        assert_eq!(decoded.payload.len(), MAX_PAYLOAD_SIZE);
    }

    #[test]
    fn wrong_version_rejected() {
        let mut data = Vec::new();
        // Body: version byte + kind byte + one payload byte.
        let frame_len: u32 = 3;
        data.extend_from_slice(&frame_len.to_be_bytes());
        data.push(99);
        data.push(MsgKind::Ping as u8);
        data.push(0);
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert!(err.to_string().contains("unsupported protocol version"));
    }

    /// A well-formed header carrying an unassigned kind byte is rejected.
    #[test]
    fn unknown_kind_rejected() {
        let mut data = 2u32.to_be_bytes().to_vec();
        data.push(PROTOCOL_VERSION);
        data.push(7);
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert!(err.to_string().contains("unknown kind"));
    }

    /// Every fixed-size decoder refuses a payload one byte short of its minimum,
    /// and `decode_hello` additionally refuses an unknown role byte.
    #[test]
    fn short_or_invalid_payloads_decode_to_none() {
        assert!(decode_hello(&[1, 0, 80, 0]).is_none());
        assert!(decode_hello(&[9, 0, 80, 0, 24]).is_none());
        assert!(decode_hello_ack(&[0; 7]).is_none());
        assert!(decode_resize(&[0; 3]).is_none());
        assert!(decode_resize_ack(&[0; 7]).is_none());
        assert!(decode_exit(&[0; 3]).is_none());
    }
}