Skip to main content

keepty_protocol/
lib.rs

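//! keepty wire protocol
//!
//! Length-prefixed binary frames over Unix sockets.
//! Frame: [u32 len (BE)][u8 version][u8 kind][payload...]
//!
//! `len` counts the version byte, the kind byte and the payload, but not the
//! 4-byte length prefix itself. Multi-byte payload fields are big-endian too.
//!
//! Zero external dependencies — implement a client in any language
//! by following this specification.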
8
9use std::io::{self, Read, Write};
10
/// Wire protocol version written into every frame header.
pub const PROTOCOL_VERSION: u8 = 1;

/// Legacy socket directory.
///
/// Deprecated in favour of `socket_dir()`, which resolves the directory at
/// runtime. Retained only so existing imports of this constant keep compiling.
pub const SOCKET_DIR: &str = "/tmp";

/// Largest payload a single frame may carry: 1 MiB (2^20 bytes), sized to
/// match the ring buffer capacity.
pub const MAX_PAYLOAD_SIZE: usize = 1 << 20;
/// Largest frame body accepted on the wire: the payload plus the one-byte
/// version and one-byte kind header.
pub const MAX_FRAME_BODY_SIZE: usize = MAX_PAYLOAD_SIZE + 2;
21
/// Role a client requests when it connects to a keepty broker.
///
/// The client announces its role in the `Hello` handshake and the broker
/// enforces what that connection may do. Each discriminant is the byte sent
/// on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Role {
    /// Exclusive controller: may send input and resize the terminal.
    /// A broker session admits only one Writer at a time.
    Writer = 1,
    /// Read-only observer: receives the output stream but cannot send input.
    /// Any number of Watchers may be attached simultaneously.
    Watcher = 2,
    /// Watcher-like observer intended for programmatic screen analysis
    /// (agents, health checks).
    Monitor = 3,
}
39
/// Message kinds of the keepty protocol. Each discriminant is the kind byte
/// carried in the frame header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum MsgKind {
    /// Client handshake. Payload: `[role: u8, cols: u16 BE, rows: u16 BE]`.
    Hello = 1,
    /// Broker reply to `Hello`.
    /// Payload: `[pty_pid: u32 BE, cols: u16 BE, rows: u16 BE]`.
    HelloAck = 2,
    /// Raw keystroke bytes sent from a client to the PTY. Payload: raw bytes.
    Input = 3,
    /// Raw PTY output bytes broadcast to every client. Payload: raw bytes.
    Output = 4,
    /// Terminal resize request. Payload: `[cols: u16 BE, rows: u16 BE]`.
    Resize = 5,
    /// Broker confirmation of a resize.
    /// Payload: `[resize_gen: u32 BE, cols: u16 BE, rows: u16 BE]`.
    ///
    /// Broadcast to all clients once the broker has applied the resize to the
    /// PTY, so it works as an in-band fence: output received before it was
    /// produced under the old geometry.
    ResizeAck = 6,
    /// Process exit notification. Payload: `[exit_code: i32 BE]`.
    Exit = 10,
    /// Request for a graceful shutdown.
    Shutdown = 11,
    /// Keepalive probe (client -> broker).
    Ping = 12,
    /// Keepalive reply (broker -> client).
    Pong = 13,
    /// Error report. Payload: UTF-8 error string.
    Error = 127,
}
69
70impl TryFrom<u8> for MsgKind {
71    type Error = u8;
72    fn try_from(v: u8) -> Result<Self, u8> {
73        match v {
74            1 => Ok(Self::Hello),
75            2 => Ok(Self::HelloAck),
76            3 => Ok(Self::Input),
77            4 => Ok(Self::Output),
78            5 => Ok(Self::Resize),
79            6 => Ok(Self::ResizeAck),
80            10 => Ok(Self::Exit),
81            11 => Ok(Self::Shutdown),
82            12 => Ok(Self::Ping),
83            13 => Ok(Self::Pong),
84            127 => Ok(Self::Error),
85            other => Err(other),
86        }
87    }
88}
89
90impl TryFrom<u8> for Role {
91    type Error = u8;
92    fn try_from(v: u8) -> Result<Self, u8> {
93        match v {
94            1 => Ok(Self::Writer),
95            2 => Ok(Self::Watcher),
96            3 => Ok(Self::Monitor),
97            other => Err(other),
98        }
99    }
100}
101
/// Resolve the runtime directory that holds broker sockets.
///
/// Priority:
/// 1. `$XDG_RUNTIME_DIR/keepty`, when the variable is set to an absolute path
///    (on Linux this is per-user, `/run/user/$UID`).
/// 2. `<temp dir>/keepty`, using [`std::env::temp_dir`] (which honours
///    `$TMPDIR`; on macOS that is already per-user, `/var/folders/.../T/`).
///
/// The XDG Base Directory specification says a relative path in these
/// variables is invalid and must be ignored, so an empty or relative
/// `$XDG_RUNTIME_DIR` falls back to the temp dir. Trailing slashes on the
/// chosen base are stripped so the join never produces `//keepty`.
/// The directory itself is not created here. Uses only `std`.
pub fn socket_dir() -> String {
    let xdg = std::env::var("XDG_RUNTIME_DIR")
        .ok()
        // Also rejects "", which is never an absolute path.
        .filter(|dir| std::path::Path::new(dir).is_absolute());
    let base = match xdg {
        Some(dir) => dir,
        None => std::env::temp_dir().to_string_lossy().into_owned(),
    };
    format!("{}/keepty", base.trim_end_matches('/'))
}
117
118/// Construct the broker socket path for a session.
119pub fn socket_path(session_id: &str) -> String {
120    format!("{}/keepty-{}.sock", socket_dir(), session_id)
121}
122
123/// A parsed keepty protocol frame.
124#[derive(Debug)]
125pub struct Frame {
126    pub kind: MsgKind,
127    pub payload: Vec<u8>,
128}
129
130impl Frame {
131    pub fn new(kind: MsgKind, payload: Vec<u8>) -> Self {
132        Self { kind, payload }
133    }
134
135    /// Encode frame to wire format: [u32 len BE][u8 version][u8 kind][payload]
136    pub fn encode(&self) -> Vec<u8> {
137        let payload_len = self.payload.len();
138        let frame_len = 2 + payload_len; // version + kind + payload
139        let mut buf = Vec::with_capacity(4 + frame_len);
140        buf.extend_from_slice(&(frame_len as u32).to_be_bytes());
141        buf.push(PROTOCOL_VERSION);
142        buf.push(self.kind as u8);
143        buf.extend_from_slice(&self.payload);
144        buf
145    }
146
147    /// Read one frame from a reader. Returns None on EOF.
148    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Option<Self>> {
149        let mut len_buf = [0u8; 4];
150        match reader.read_exact(&mut len_buf) {
151            Ok(()) => {}
152            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
153            Err(e) => return Err(e),
154        }
155        let frame_len = u32::from_be_bytes(len_buf) as usize;
156        if frame_len < 2 {
157            return Err(io::Error::new(
158                io::ErrorKind::InvalidData,
159                "frame too short",
160            ));
161        }
162        if frame_len > MAX_FRAME_BODY_SIZE {
163            return Err(io::Error::new(
164                io::ErrorKind::InvalidData,
165                format!(
166                    "frame too large: {} bytes (max {})",
167                    frame_len, MAX_FRAME_BODY_SIZE
168                ),
169            ));
170        }
171        let mut data = vec![0u8; frame_len];
172        reader.read_exact(&mut data)?;
173        let version = data[0];
174        if version != PROTOCOL_VERSION {
175            return Err(io::Error::new(
176                io::ErrorKind::InvalidData,
177                format!(
178                    "unsupported protocol version: {} (expected {})",
179                    version, PROTOCOL_VERSION
180                ),
181            ));
182        }
183        let kind = MsgKind::try_from(data[1]).map_err(|v| {
184            io::Error::new(io::ErrorKind::InvalidData, format!("unknown kind: {}", v))
185        })?;
186        let payload = data[2..].to_vec();
187        Ok(Some(Frame { kind, payload }))
188    }
189
190    /// Write this frame to a writer.
191    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
192        writer.write_all(&self.encode())?;
193        writer.flush()
194    }
195}
196
197// --- Hello message helpers ---
198
199pub fn encode_hello(role: Role, cols: u16, rows: u16) -> Vec<u8> {
200    let mut payload = Vec::with_capacity(5);
201    payload.push(role as u8);
202    payload.extend_from_slice(&cols.to_be_bytes());
203    payload.extend_from_slice(&rows.to_be_bytes());
204    payload
205}
206
207pub fn decode_hello(payload: &[u8]) -> Option<(Role, u16, u16)> {
208    if payload.len() < 5 {
209        return None;
210    }
211    let role = Role::try_from(payload[0]).ok()?;
212    let cols = u16::from_be_bytes([payload[1], payload[2]]);
213    let rows = u16::from_be_bytes([payload[3], payload[4]]);
214    Some((role, cols, rows))
215}
216
217// --- HelloAck helpers ---
218
/// Build a `HelloAck` payload:
/// `[pty_pid: u32 BE][cols: u16 BE][rows: u16 BE]` (8 bytes).
pub fn encode_hello_ack(pty_pid: u32, cols: u16, rows: u16) -> Vec<u8> {
    pty_pid
        .to_be_bytes()
        .iter()
        .chain(cols.to_be_bytes().iter())
        .chain(rows.to_be_bytes().iter())
        .copied()
        .collect()
}
226
/// Parse a `HelloAck` payload into `(pty_pid, cols, rows)`.
///
/// Returns `None` if fewer than 8 bytes are present; extra bytes are ignored.
pub fn decode_hello_ack(payload: &[u8]) -> Option<(u32, u16, u16)> {
    let b = payload.get(..8)?;
    Some((
        u32::from_be_bytes([b[0], b[1], b[2], b[3]]),
        u16::from_be_bytes([b[4], b[5]]),
        u16::from_be_bytes([b[6], b[7]]),
    ))
}
236
237// --- Resize helpers ---
238
/// Build a `Resize` payload: `[cols: u16 BE][rows: u16 BE]` (4 bytes).
pub fn encode_resize(cols: u16, rows: u16) -> Vec<u8> {
    let [cols_hi, cols_lo] = cols.to_be_bytes();
    let [rows_hi, rows_lo] = rows.to_be_bytes();
    vec![cols_hi, cols_lo, rows_hi, rows_lo]
}
245
/// Parse a `Resize` payload into `(cols, rows)`.
///
/// Returns `None` if fewer than 4 bytes are present; extra bytes are ignored.
pub fn decode_resize(payload: &[u8]) -> Option<(u16, u16)> {
    match *payload {
        [cols_hi, cols_lo, rows_hi, rows_lo, ..] => Some((
            u16::from_be_bytes([cols_hi, cols_lo]),
            u16::from_be_bytes([rows_hi, rows_lo]),
        )),
        _ => None,
    }
}
254
255// --- ResizeAck helpers ---
256
/// Build a `ResizeAck` payload:
/// `[resize_gen: u32 BE][cols: u16 BE][rows: u16 BE]` (8 bytes).
///
/// The parameter is named `generation` because `gen` is a reserved keyword
/// in Rust 2024; argument names are not part of a Rust call site.
pub fn encode_resize_ack(generation: u32, cols: u16, rows: u16) -> Vec<u8> {
    let mut out = Vec::with_capacity(8);
    for field in [
        &generation.to_be_bytes()[..],
        &cols.to_be_bytes()[..],
        &rows.to_be_bytes()[..],
    ] {
        out.extend_from_slice(field);
    }
    out
}
264
/// Parse a `ResizeAck` payload into `(resize_gen, cols, rows)`.
///
/// Returns `None` if fewer than 8 bytes are present; extra bytes are ignored.
pub fn decode_resize_ack(payload: &[u8]) -> Option<(u32, u16, u16)> {
    if let [g0, g1, g2, g3, cols_hi, cols_lo, rows_hi, rows_lo, ..] = *payload {
        let generation = u32::from_be_bytes([g0, g1, g2, g3]);
        let cols = u16::from_be_bytes([cols_hi, cols_lo]);
        let rows = u16::from_be_bytes([rows_hi, rows_lo]);
        Some((generation, cols, rows))
    } else {
        None
    }
}
274
275// --- Exit helpers ---
276
/// Build an `Exit` payload: the exit code as a big-endian `i32` (4 bytes).
pub fn encode_exit(code: i32) -> Vec<u8> {
    Vec::from(code.to_be_bytes())
}
280
/// Parse an `Exit` payload into the process exit code.
///
/// Returns `None` if fewer than 4 bytes are present; extra bytes are ignored.
pub fn decode_exit(payload: &[u8]) -> Option<i32> {
    payload
        .get(..4)
        .map(|b| i32::from_be_bytes([b[0], b[1], b[2], b[3]]))
}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn frame_roundtrip() {
        let frame = Frame::new(MsgKind::Output, b"hello world".to_vec());
        let encoded = frame.encode();
        let mut cursor = std::io::Cursor::new(encoded);
        let decoded = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded.kind, MsgKind::Output);
        assert_eq!(decoded.payload, b"hello world");
    }

    #[test]
    fn length_prefix_counts_version_and_kind() {
        let encoded = Frame::new(MsgKind::Input, b"abc".to_vec()).encode();
        assert_eq!(&encoded[..4], &5u32.to_be_bytes());
        assert_eq!(encoded[4], PROTOCOL_VERSION);
        assert_eq!(encoded[5], MsgKind::Input as u8);
        assert_eq!(&encoded[6..], b"abc");
    }

    #[test]
    fn multiple_frames_read_in_order() {
        let mut wire = Vec::new();
        Frame::new(MsgKind::Hello, encode_hello(Role::Watcher, 80, 24))
            .write_to(&mut wire)
            .unwrap();
        Frame::new(MsgKind::Ping, vec![]).write_to(&mut wire).unwrap();
        let mut cursor = std::io::Cursor::new(wire);
        let first = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(first.kind, MsgKind::Hello);
        assert_eq!(decode_hello(&first.payload), Some((Role::Watcher, 80, 24)));
        let second = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(second.kind, MsgKind::Ping);
        assert!(Frame::read_from(&mut cursor).unwrap().is_none());
    }

    #[test]
    fn hello_roundtrip() {
        let payload = encode_hello(Role::Writer, 132, 51);
        let (role, cols, rows) = decode_hello(&payload).unwrap();
        assert_eq!(role, Role::Writer);
        assert_eq!(cols, 132);
        assert_eq!(rows, 51);
    }

    #[test]
    fn decode_hello_rejects_unknown_role() {
        assert!(decode_hello(&[0, 0, 80, 0, 24]).is_none());
        assert!(decode_hello(&[4, 0, 80, 0, 24]).is_none());
    }

    #[test]
    fn hello_ack_roundtrip() {
        let payload = encode_hello_ack(12345, 80, 24);
        let (pid, cols, rows) = decode_hello_ack(&payload).unwrap();
        assert_eq!(pid, 12345);
        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
    }

    #[test]
    fn resize_roundtrip() {
        let payload = encode_resize(80, 24);
        let (cols, rows) = decode_resize(&payload).unwrap();
        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
    }

    #[test]
    fn exit_roundtrip() {
        let payload = encode_exit(42);
        let code = decode_exit(&payload).unwrap();
        assert_eq!(code, 42);
    }

    #[test]
    fn exit_negative_code_roundtrip() {
        assert_eq!(decode_exit(&encode_exit(-1)), Some(-1));
        assert_eq!(decode_exit(&encode_exit(i32::MIN)), Some(i32::MIN));
    }

    #[test]
    fn short_payloads_decode_to_none() {
        assert!(decode_hello(&[1, 0, 80, 0]).is_none());
        assert!(decode_hello_ack(&[0; 7]).is_none());
        assert!(decode_resize(&[0; 3]).is_none());
        assert!(decode_resize_ack(&[0; 7]).is_none());
        assert!(decode_exit(&[0; 3]).is_none());
        assert!(decode_exit(&[]).is_none());
    }

    #[test]
    fn socket_path_format() {
        let path = socket_path("abc123");
        assert!(path.ends_with("/keepty-abc123.sock"), "path: {}", path);
        assert!(
            path.contains("/keepty"),
            "path should contain /keepty dir: {}",
            path
        );
    }

    #[test]
    fn socket_path_is_inside_socket_dir() {
        assert_eq!(socket_path("s1"), format!("{}/keepty-s1.sock", socket_dir()));
    }

    #[test]
    fn eof_returns_none() {
        let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
        let result = Frame::read_from(&mut cursor).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn all_roles_roundtrip() {
        for role in [Role::Writer, Role::Watcher, Role::Monitor] {
            let v = role as u8;
            assert_eq!(Role::try_from(v).unwrap(), role);
        }
    }

    #[test]
    fn all_msg_kinds_roundtrip() {
        for kind in [
            MsgKind::Hello,
            MsgKind::HelloAck,
            MsgKind::Input,
            MsgKind::Output,
            MsgKind::Resize,
            MsgKind::ResizeAck,
            MsgKind::Exit,
            MsgKind::Shutdown,
            MsgKind::Ping,
            MsgKind::Pong,
            MsgKind::Error,
        ] {
            let v = kind as u8;
            assert_eq!(MsgKind::try_from(v).unwrap(), kind);
        }
    }

    #[test]
    fn invalid_role_returns_err() {
        assert!(Role::try_from(0).is_err());
        assert!(Role::try_from(4).is_err());
        assert!(Role::try_from(255).is_err());
    }

    #[test]
    fn resize_ack_roundtrip() {
        let payload = encode_resize_ack(42, 120, 40);
        // `gen` is a reserved keyword in Rust 2024, so use a longer name.
        let (generation, cols, rows) = decode_resize_ack(&payload).unwrap();
        assert_eq!(generation, 42);
        assert_eq!(cols, 120);
        assert_eq!(rows, 40);
    }

    #[test]
    fn invalid_msg_kind_returns_err() {
        assert!(MsgKind::try_from(0).is_err());
        assert!(MsgKind::try_from(7).is_err());
        assert!(MsgKind::try_from(128).is_err());
    }

    #[test]
    fn frame_too_short_is_error() {
        // Length field says 1 byte, but we need at least 2 (version + kind)
        let data = vec![0, 0, 0, 1, 0xFF];
        let mut cursor = std::io::Cursor::new(data);
        assert!(Frame::read_from(&mut cursor).is_err());
    }

    #[test]
    fn truncated_body_is_error() {
        // Prefix promises 5 body bytes but only 3 follow.
        let data = vec![0, 0, 0, 5, PROTOCOL_VERSION, MsgKind::Output as u8, b'a'];
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn unknown_kind_rejected() {
        let data = vec![0, 0, 0, 2, PROTOCOL_VERSION, 200];
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert!(err.to_string().contains("unknown kind: 200"));
    }

    #[test]
    fn empty_payload_frame() {
        let frame = Frame::new(MsgKind::Ping, vec![]);
        let encoded = frame.encode();
        let mut cursor = std::io::Cursor::new(encoded);
        let decoded = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded.kind, MsgKind::Ping);
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn oversized_frame_rejected() {
        // Length prefix is one byte over MAX_FRAME_BODY_SIZE (about 1 MiB);
        // it must be rejected from the prefix alone, before allocating the body.
        let len = (MAX_FRAME_BODY_SIZE + 1) as u32;
        let mut data = len.to_be_bytes().to_vec();
        data.push(PROTOCOL_VERSION);
        data.push(MsgKind::Output as u8);
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn max_allowed_frame_accepted() {
        // Frame at exactly MAX_FRAME_BODY_SIZE should be accepted
        let payload = vec![0u8; MAX_PAYLOAD_SIZE];
        let frame = Frame::new(MsgKind::Output, payload);
        let encoded = frame.encode();
        let mut cursor = std::io::Cursor::new(encoded);
        let decoded = Frame::read_from(&mut cursor).unwrap().unwrap();
        assert_eq!(decoded.kind, MsgKind::Output);
        assert_eq!(decoded.payload.len(), MAX_PAYLOAD_SIZE);
    }

    #[test]
    fn wrong_version_rejected() {
        // Valid frame structure but wrong protocol version
        let mut data = Vec::new();
        let frame_len: u32 = 3; // version + kind + 1 byte payload
        data.extend_from_slice(&frame_len.to_be_bytes());
        data.push(99); // wrong version
        data.push(MsgKind::Ping as u8);
        data.push(0); // payload byte
        let mut cursor = std::io::Cursor::new(data);
        let err = Frame::read_from(&mut cursor).unwrap_err();
        assert!(err.to_string().contains("unsupported protocol version"));
    }
}