1use std::io::{self, Read, Write};
10
/// Version byte written into every frame body; frames carrying any other value
/// are rejected when read.
pub const PROTOCOL_VERSION: u8 = 1;

/// A fixed socket base directory.
///
/// NOTE(review): `socket_dir()` does not use this constant (it picks
/// `$XDG_RUNTIME_DIR` or the system temp dir); it may still be referenced
/// elsewhere in the crate — confirm before removing.
pub const SOCKET_DIR: &str = "/tmp";

/// Largest payload a single frame may carry: 1 MiB.
pub const MAX_PAYLOAD_SIZE: usize = 1 << 20;

/// Largest accepted frame-body length (the value of the 4-byte length prefix):
/// the payload limit plus the version byte and the kind byte that precede it.
pub const MAX_FRAME_BODY_SIZE: usize = MAX_PAYLOAD_SIZE + 2;
21
/// Role byte carried as the first byte of a `Hello` payload (see `encode_hello`).
///
/// Each discriminant is the on-wire byte value; `Role::try_from(u8)` maps a byte
/// back to a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Role {
    /// Wire value `1`. NOTE(review): presumably the client that may send input — confirm.
    Writer = 1,
    /// Wire value `2`.
    Watcher = 2,
    /// Wire value `3`.
    Monitor = 3,
}
39
/// Message type tag, sent as the second byte of every frame body.
///
/// Each discriminant is the on-wire byte value. Values 7–9, and any other value
/// not listed here, are unassigned and rejected by `MsgKind::try_from(u8)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MsgKind {
    /// Handshake; payload built by `encode_hello` (role byte, cols, rows).
    Hello = 1,
    /// Handshake reply; payload built by `encode_hello_ack` (pty pid, cols, rows).
    HelloAck = 2,
    /// NOTE(review): payload presumably raw terminal input bytes — confirm.
    Input = 3,
    /// NOTE(review): payload presumably raw terminal output bytes — confirm.
    Output = 4,
    /// Size change; payload built by `encode_resize` (cols, rows).
    Resize = 5,
    /// Payload built by `encode_resize_ack` (a `u32` generation value, cols, rows).
    ResizeAck = 6,
    /// Payload built by `encode_exit` (a big-endian `i32` exit code).
    Exit = 10,
    /// Payload layout is not defined in this module.
    Shutdown = 11,
    /// Payload layout is not defined in this module.
    Ping = 12,
    /// Payload layout is not defined in this module.
    Pong = 13,
    /// Payload layout is not defined in this module.
    Error = 127,
}
69
70impl TryFrom<u8> for MsgKind {
71 type Error = u8;
72 fn try_from(v: u8) -> Result<Self, u8> {
73 match v {
74 1 => Ok(Self::Hello),
75 2 => Ok(Self::HelloAck),
76 3 => Ok(Self::Input),
77 4 => Ok(Self::Output),
78 5 => Ok(Self::Resize),
79 6 => Ok(Self::ResizeAck),
80 10 => Ok(Self::Exit),
81 11 => Ok(Self::Shutdown),
82 12 => Ok(Self::Ping),
83 13 => Ok(Self::Pong),
84 127 => Ok(Self::Error),
85 other => Err(other),
86 }
87 }
88}
89
90impl TryFrom<u8> for Role {
91 type Error = u8;
92 fn try_from(v: u8) -> Result<Self, u8> {
93 match v {
94 1 => Ok(Self::Writer),
95 2 => Ok(Self::Watcher),
96 3 => Ok(Self::Monitor),
97 other => Err(other),
98 }
99 }
100}
101
/// Returns the directory in which session sockets live.
///
/// Returns `$XDG_RUNTIME_DIR/keepty` when `XDG_RUNTIME_DIR` is set to an absolute
/// path, and `<temp dir>/keepty` otherwise (see [`std::env::temp_dir`]). An unset,
/// empty, non-Unicode, or relative `XDG_RUNTIME_DIR` falls back to the temp dir.
/// In both branches a trailing `/` on the base directory is dropped, so no doubled
/// separator appears before `keepty`.
///
/// The temp-dir path is converted with `to_string_lossy`, so a non-UTF-8 temp dir
/// yields a path containing replacement characters.
pub fn socket_dir() -> String {
    if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
        // The XDG Base Directory spec requires absolute paths and says a relative
        // value must be treated as invalid. Honoring one would make the socket
        // location depend on the caller's working directory. An empty string is
        // not absolute, so this also covers the "set but empty" case.
        if std::path::Path::new(&xdg).is_absolute() {
            return format!("{}/keepty", xdg.trim_end_matches('/'));
        }
    }
    let tmp = std::env::temp_dir();
    format!("{}/keepty", tmp.to_string_lossy().trim_end_matches('/'))
}
117
118pub fn socket_path(session_id: &str) -> String {
120 format!("{}/keepty-{}.sock", socket_dir(), session_id)
121}
122
123#[derive(Debug)]
125pub struct Frame {
126 pub kind: MsgKind,
127 pub payload: Vec<u8>,
128}
129
130impl Frame {
131 pub fn new(kind: MsgKind, payload: Vec<u8>) -> Self {
132 Self { kind, payload }
133 }
134
135 pub fn encode(&self) -> Vec<u8> {
137 let payload_len = self.payload.len();
138 let frame_len = 2 + payload_len; let mut buf = Vec::with_capacity(4 + frame_len);
140 buf.extend_from_slice(&(frame_len as u32).to_be_bytes());
141 buf.push(PROTOCOL_VERSION);
142 buf.push(self.kind as u8);
143 buf.extend_from_slice(&self.payload);
144 buf
145 }
146
147 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Option<Self>> {
149 let mut len_buf = [0u8; 4];
150 match reader.read_exact(&mut len_buf) {
151 Ok(()) => {}
152 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
153 Err(e) => return Err(e),
154 }
155 let frame_len = u32::from_be_bytes(len_buf) as usize;
156 if frame_len < 2 {
157 return Err(io::Error::new(
158 io::ErrorKind::InvalidData,
159 "frame too short",
160 ));
161 }
162 if frame_len > MAX_FRAME_BODY_SIZE {
163 return Err(io::Error::new(
164 io::ErrorKind::InvalidData,
165 format!(
166 "frame too large: {} bytes (max {})",
167 frame_len, MAX_FRAME_BODY_SIZE
168 ),
169 ));
170 }
171 let mut data = vec![0u8; frame_len];
172 reader.read_exact(&mut data)?;
173 let version = data[0];
174 if version != PROTOCOL_VERSION {
175 return Err(io::Error::new(
176 io::ErrorKind::InvalidData,
177 format!(
178 "unsupported protocol version: {} (expected {})",
179 version, PROTOCOL_VERSION
180 ),
181 ));
182 }
183 let kind = MsgKind::try_from(data[1]).map_err(|v| {
184 io::Error::new(io::ErrorKind::InvalidData, format!("unknown kind: {}", v))
185 })?;
186 let payload = data[2..].to_vec();
187 Ok(Some(Frame { kind, payload }))
188 }
189
190 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
192 writer.write_all(&self.encode())?;
193 writer.flush()
194 }
195}
196
197pub fn encode_hello(role: Role, cols: u16, rows: u16) -> Vec<u8> {
200 let mut payload = Vec::with_capacity(5);
201 payload.push(role as u8);
202 payload.extend_from_slice(&cols.to_be_bytes());
203 payload.extend_from_slice(&rows.to_be_bytes());
204 payload
205}
206
207pub fn decode_hello(payload: &[u8]) -> Option<(Role, u16, u16)> {
208 if payload.len() < 5 {
209 return None;
210 }
211 let role = Role::try_from(payload[0]).ok()?;
212 let cols = u16::from_be_bytes([payload[1], payload[2]]);
213 let rows = u16::from_be_bytes([payload[3], payload[4]]);
214 Some((role, cols, rows))
215}
216
/// Builds a `HelloAck` payload: `pty_pid` as a big-endian `u32`, then `cols` and
/// `rows` as big-endian `u16`s (8 bytes total).
pub fn encode_hello_ack(pty_pid: u32, cols: u16, rows: u16) -> Vec<u8> {
    let [p0, p1, p2, p3] = pty_pid.to_be_bytes();
    let [c0, c1] = cols.to_be_bytes();
    let [r0, r1] = rows.to_be_bytes();
    vec![p0, p1, p2, p3, c0, c1, r0, r1]
}
226
/// Parses a `HelloAck` payload laid out as by `encode_hello_ack`.
///
/// Returns `(pty_pid, cols, rows)`, or `None` if the payload is shorter than 8
/// bytes. Bytes after the eighth are ignored.
pub fn decode_hello_ack(payload: &[u8]) -> Option<(u32, u16, u16)> {
    if let [p0, p1, p2, p3, c0, c1, r0, r1, ..] = *payload {
        Some((
            u32::from_be_bytes([p0, p1, p2, p3]),
            u16::from_be_bytes([c0, c1]),
            u16::from_be_bytes([r0, r1]),
        ))
    } else {
        None
    }
}
236
/// Builds a `Resize` payload: `cols` then `rows`, each a big-endian `u16` (4 bytes total).
pub fn encode_resize(cols: u16, rows: u16) -> Vec<u8> {
    let [c0, c1] = cols.to_be_bytes();
    let [r0, r1] = rows.to_be_bytes();
    vec![c0, c1, r0, r1]
}
245
/// Parses a `Resize` payload laid out as by `encode_resize`.
///
/// Returns `(cols, rows)`, or `None` if the payload is shorter than 4 bytes.
/// Bytes after the fourth are ignored.
pub fn decode_resize(payload: &[u8]) -> Option<(u16, u16)> {
    match *payload {
        [c0, c1, r0, r1, ..] => Some((u16::from_be_bytes([c0, c1]), u16::from_be_bytes([r0, r1]))),
        _ => None,
    }
}
254
/// Builds a `ResizeAck` payload: `gen` as a big-endian `u32`, then `cols` and
/// `rows` as big-endian `u16`s (8 bytes total).
pub fn encode_resize_ack(gen: u32, cols: u16, rows: u16) -> Vec<u8> {
    let [g0, g1, g2, g3] = gen.to_be_bytes();
    let [c0, c1] = cols.to_be_bytes();
    let [r0, r1] = rows.to_be_bytes();
    vec![g0, g1, g2, g3, c0, c1, r0, r1]
}
264
/// Parses a `ResizeAck` payload laid out as by `encode_resize_ack`.
///
/// Returns `(gen, cols, rows)`, or `None` if the payload is shorter than 8 bytes.
/// Bytes after the eighth are ignored.
pub fn decode_resize_ack(payload: &[u8]) -> Option<(u32, u16, u16)> {
    match *payload {
        [g0, g1, g2, g3, c0, c1, r0, r1, ..] => {
            let generation = u32::from_be_bytes([g0, g1, g2, g3]);
            Some((generation, u16::from_be_bytes([c0, c1]), u16::from_be_bytes([r0, r1])))
        }
        _ => None,
    }
}
274
/// Builds an `Exit` payload: `code` as a big-endian `i32` (4 bytes).
pub fn encode_exit(code: i32) -> Vec<u8> {
    Vec::from(code.to_be_bytes())
}
280
/// Parses an `Exit` payload laid out as by `encode_exit`.
///
/// Returns `None` if the payload is shorter than 4 bytes; bytes after the fourth
/// are ignored.
pub fn decode_exit(payload: &[u8]) -> Option<i32> {
    if let [b0, b1, b2, b3, ..] = *payload {
        Some(i32::from_be_bytes([b0, b1, b2, b3]))
    } else {
        None
    }
}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `frame` and decodes it again; panics if decoding errors or hits EOF.
    fn roundtrip(frame: &Frame) -> Frame {
        let mut cursor = std::io::Cursor::new(frame.encode());
        Frame::read_from(&mut cursor)
            .expect("re-reading an encoded frame must not error")
            .expect("an encoded frame must not read as end-of-stream")
    }

    /// Feeds raw wire bytes to `Frame::read_from`.
    fn read_bytes(bytes: Vec<u8>) -> std::io::Result<Option<Frame>> {
        let mut cursor = std::io::Cursor::new(bytes);
        Frame::read_from(&mut cursor)
    }

    #[test]
    fn frame_roundtrip() {
        let decoded = roundtrip(&Frame::new(MsgKind::Output, b"hello world".to_vec()));
        assert_eq!(decoded.kind, MsgKind::Output);
        assert_eq!(decoded.payload, b"hello world");
    }

    #[test]
    fn hello_roundtrip() {
        let payload = encode_hello(Role::Writer, 132, 51);
        assert_eq!(decode_hello(&payload), Some((Role::Writer, 132, 51)));
    }

    #[test]
    fn hello_ack_roundtrip() {
        let payload = encode_hello_ack(12345, 80, 24);
        assert_eq!(decode_hello_ack(&payload), Some((12345, 80, 24)));
    }

    #[test]
    fn resize_roundtrip() {
        assert_eq!(decode_resize(&encode_resize(80, 24)), Some((80, 24)));
    }

    #[test]
    fn exit_roundtrip() {
        assert_eq!(decode_exit(&encode_exit(42)), Some(42));
    }

    #[test]
    fn socket_path_format() {
        let path = socket_path("abc123");
        assert!(path.ends_with("/keepty-abc123.sock"), "path: {}", path);
        assert!(
            path.contains("/keepty"),
            "path should contain /keepty dir: {}",
            path
        );
    }

    #[test]
    fn eof_returns_none() {
        let result = read_bytes(Vec::new()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn all_roles_roundtrip() {
        for role in [Role::Writer, Role::Watcher, Role::Monitor] {
            assert_eq!(Role::try_from(role as u8), Ok(role));
        }
    }

    #[test]
    fn all_msg_kinds_roundtrip() {
        let kinds = [
            MsgKind::Hello,
            MsgKind::HelloAck,
            MsgKind::Input,
            MsgKind::Output,
            MsgKind::Resize,
            MsgKind::ResizeAck,
            MsgKind::Exit,
            MsgKind::Shutdown,
            MsgKind::Ping,
            MsgKind::Pong,
            MsgKind::Error,
        ];
        for &kind in kinds.iter() {
            assert_eq!(MsgKind::try_from(kind as u8), Ok(kind));
        }
    }

    #[test]
    fn invalid_role_returns_err() {
        for raw in [0u8, 4, 255] {
            assert!(Role::try_from(raw).is_err());
        }
    }

    #[test]
    fn resize_ack_roundtrip() {
        let payload = encode_resize_ack(42, 120, 40);
        assert_eq!(decode_resize_ack(&payload), Some((42, 120, 40)));
    }

    #[test]
    fn invalid_msg_kind_returns_err() {
        for raw in [0u8, 7, 128] {
            assert!(MsgKind::try_from(raw).is_err());
        }
    }

    #[test]
    fn frame_too_short_is_error() {
        // Length prefix of 1: too small to hold the version and kind bytes.
        assert!(read_bytes(vec![0, 0, 0, 1, 0xFF]).is_err());
    }

    #[test]
    fn empty_payload_frame() {
        let decoded = roundtrip(&Frame::new(MsgKind::Ping, Vec::new()));
        assert_eq!(decoded.kind, MsgKind::Ping);
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn oversized_frame_rejected() {
        // Only the version and kind bytes follow the prefix. Getting "too large"
        // (not an EOF error) shows the length check runs before the body is read.
        let mut data = ((MAX_FRAME_BODY_SIZE + 1) as u32).to_be_bytes().to_vec();
        data.extend_from_slice(&[PROTOCOL_VERSION, MsgKind::Output as u8]);
        let err = read_bytes(data).unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn max_allowed_frame_accepted() {
        let decoded = roundtrip(&Frame::new(MsgKind::Output, vec![0u8; MAX_PAYLOAD_SIZE]));
        assert_eq!(decoded.kind, MsgKind::Output);
        assert_eq!(decoded.payload.len(), MAX_PAYLOAD_SIZE);
    }

    #[test]
    fn wrong_version_rejected() {
        // Body length 3: bogus version byte 99, a valid kind, one payload byte.
        let mut data = 3u32.to_be_bytes().to_vec();
        data.extend_from_slice(&[99, MsgKind::Ping as u8, 0]);
        let err = read_bytes(data).unwrap_err();
        assert!(err.to_string().contains("unsupported protocol version"));
    }
}