Skip to main content

gritty/
protocol.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
// Frame type tags: the first byte of every frame header.

// Per-session stream frames (data, resize, exit, detach, heartbeats, environment).
const TYPE_DATA: u8 = 0x01;
const TYPE_RESIZE: u8 = 0x02;
const TYPE_EXIT: u8 = 0x03;
const TYPE_DETACHED: u8 = 0x04;
const TYPE_PING: u8 = 0x05;
const TYPE_PONG: u8 = 0x06;
const TYPE_ENV: u8 = 0x07;

// Control requests.
const TYPE_NEW_SESSION: u8 = 0x10;
const TYPE_ATTACH: u8 = 0x11;
const TYPE_LIST_SESSIONS: u8 = 0x12;
const TYPE_KILL_SESSION: u8 = 0x13;
const TYPE_KILL_SERVER: u8 = 0x14;

// Control responses.
const TYPE_SESSION_CREATED: u8 = 0x20;
const TYPE_SESSION_INFO: u8 = 0x21;
const TYPE_OK: u8 = 0x22;
const TYPE_ERROR: u8 = 0x23;

/// Fixed header size: a 1-byte type tag followed by a 4-byte big-endian payload length.
const HEADER_LEN: usize = 1 + 4;
/// Largest payload, in bytes, the decoder accepts (1 MiB).
const MAX_FRAME_SIZE: usize = 1024 * 1024;
24
/// Description of a single session, as carried in a `SessionInfo` frame.
///
/// On the wire every entry is one tab-separated line, so string fields should
/// avoid `\t` and `\n`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionEntry {
    /// Session identifier (presumably the id handed out in `SessionCreated` — confirm).
    pub id: String,
    /// Session name (presumably the one requested via `NewSession` — confirm).
    pub name: String,
    /// Path of the session's PTY (inferred from the field name — confirm).
    pub pty_path: String,
    /// PID of the session's shell; decodes as 0 when the wire value does not parse.
    pub shell_pid: u32,
    /// Creation time; unit not visible here (NOTE(review): likely Unix seconds). Decodes as 0 if unparsable.
    pub created_at: u64,
    /// Whether a client is attached; sent as `"1"`/`"0"`, and anything but `"1"` decodes as `false`.
    pub attached: bool,
    /// Time of the most recent heartbeat; unit not visible here. Decodes as 0 if unparsable.
    pub last_heartbeat: u64,
}
36
37#[derive(Debug, Clone, PartialEq, Eq)]
38pub enum Frame {
39    Data(Bytes),
40    Resize {
41        cols: u16,
42        rows: u16,
43    },
44    Exit {
45        code: i32,
46    },
47    /// Sent to a client when another client takes over the session.
48    Detached,
49    /// Heartbeat request (client → server).
50    Ping,
51    /// Heartbeat reply (server → client).
52    Pong,
53    /// Environment variables (client → server, sent before first Resize on new session).
54    Env {
55        vars: Vec<(String, String)>,
56    },
57    // Control requests
58    NewSession {
59        name: String,
60    },
61    Attach {
62        session: String,
63    },
64    ListSessions,
65    KillSession {
66        session: String,
67    },
68    KillServer,
69    // Control responses
70    SessionCreated {
71        id: String,
72    },
73    SessionInfo {
74        sessions: Vec<SessionEntry>,
75    },
76    Ok,
77    Error {
78        message: String,
79    },
80}
81
82impl Frame {
83    /// Extract a Frame from a `framed.next().await` result, converting
84    /// the common None / Some(Err) cases into descriptive errors.
85    pub fn expect_from(result: Option<Result<Frame, io::Error>>) -> anyhow::Result<Frame> {
86        match result {
87            Some(Ok(frame)) => Ok(frame),
88            Some(Err(e)) => Err(anyhow::anyhow!("daemon protocol error: {e}")),
89            None => Err(anyhow::anyhow!("daemon closed connection")),
90        }
91    }
92}
93
/// Stateless [`Decoder`]/[`Encoder`] for [`Frame`]s.
///
/// Every frame is a 1-byte type tag, a big-endian `u32` payload length, and
/// the payload itself. Being a unit struct, it is trivially `Copy` and `Default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct FrameCodec;
95
96fn encode_empty(dst: &mut BytesMut, ty: u8) {
97    dst.put_u8(ty);
98    dst.put_u32(0);
99}
100
101fn encode_str(dst: &mut BytesMut, ty: u8, s: &str) {
102    dst.put_u8(ty);
103    dst.put_u32(s.len() as u32);
104    dst.extend_from_slice(s.as_bytes());
105}
106
107fn decode_string(payload: BytesMut) -> Result<String, io::Error> {
108    String::from_utf8(payload.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
109}
110
111impl Decoder for FrameCodec {
112    type Item = Frame;
113    type Error = io::Error;
114
115    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, io::Error> {
116        if src.len() < HEADER_LEN {
117            return Ok(None);
118        }
119
120        let frame_type = src[0];
121        let payload_len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
122
123        if payload_len > MAX_FRAME_SIZE {
124            return Err(io::Error::new(
125                io::ErrorKind::InvalidData,
126                format!("frame payload too large: {payload_len} bytes (max {MAX_FRAME_SIZE})"),
127            ));
128        }
129
130        if src.len() < HEADER_LEN + payload_len {
131            src.reserve(HEADER_LEN + payload_len - src.len());
132            return Ok(None);
133        }
134
135        src.advance(HEADER_LEN);
136        let payload = src.split_to(payload_len);
137
138        match frame_type {
139            TYPE_DATA => Ok(Some(Frame::Data(payload.freeze()))),
140            TYPE_RESIZE => {
141                if payload.len() != 4 {
142                    return Err(io::Error::new(
143                        io::ErrorKind::InvalidData,
144                        "resize frame must be 4 bytes",
145                    ));
146                }
147                let cols = u16::from_be_bytes([payload[0], payload[1]]);
148                let rows = u16::from_be_bytes([payload[2], payload[3]]);
149                Ok(Some(Frame::Resize { cols, rows }))
150            }
151            TYPE_EXIT => {
152                if payload.len() != 4 {
153                    return Err(io::Error::new(
154                        io::ErrorKind::InvalidData,
155                        "exit frame must be 4 bytes",
156                    ));
157                }
158                let code = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
159                Ok(Some(Frame::Exit { code }))
160            }
161            TYPE_DETACHED => Ok(Some(Frame::Detached)),
162            TYPE_PING => Ok(Some(Frame::Ping)),
163            TYPE_PONG => Ok(Some(Frame::Pong)),
164            TYPE_ENV => {
165                let text = String::from_utf8(payload.to_vec())
166                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
167                let vars = if text.is_empty() {
168                    Vec::new()
169                } else {
170                    text.lines()
171                        .filter_map(|line| {
172                            let (k, v) = line.split_once('=')?;
173                            Some((k.to_string(), v.to_string()))
174                        })
175                        .collect()
176                };
177                Ok(Some(Frame::Env { vars }))
178            }
179            TYPE_NEW_SESSION => Ok(Some(Frame::NewSession { name: decode_string(payload)? })),
180            TYPE_ATTACH => Ok(Some(Frame::Attach { session: decode_string(payload)? })),
181            TYPE_LIST_SESSIONS => Ok(Some(Frame::ListSessions)),
182            TYPE_KILL_SESSION => Ok(Some(Frame::KillSession { session: decode_string(payload)? })),
183            TYPE_KILL_SERVER => Ok(Some(Frame::KillServer)),
184            TYPE_SESSION_CREATED => Ok(Some(Frame::SessionCreated { id: decode_string(payload)? })),
185            TYPE_SESSION_INFO => {
186                let text = String::from_utf8(payload.to_vec())
187                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
188                let sessions = if text.is_empty() {
189                    Vec::new()
190                } else {
191                    text.lines()
192                        .filter_map(|line| {
193                            let parts: Vec<&str> = line.split('\t').collect();
194                            if parts.len() == 7 {
195                                Some(SessionEntry {
196                                    id: parts[0].to_string(),
197                                    name: parts[1].to_string(),
198                                    pty_path: parts[2].to_string(),
199                                    shell_pid: parts[3].parse().unwrap_or(0),
200                                    created_at: parts[4].parse().unwrap_or(0),
201                                    attached: parts[5] == "1",
202                                    last_heartbeat: parts[6].parse().unwrap_or(0),
203                                })
204                            } else {
205                                None
206                            }
207                        })
208                        .collect()
209                };
210                Ok(Some(Frame::SessionInfo { sessions }))
211            }
212            TYPE_OK => Ok(Some(Frame::Ok)),
213            TYPE_ERROR => Ok(Some(Frame::Error { message: decode_string(payload)? })),
214            _ => Err(io::Error::new(
215                io::ErrorKind::InvalidData,
216                format!("unknown frame type: 0x{frame_type:02x}"),
217            )),
218        }
219    }
220}
221
222impl Encoder<Frame> for FrameCodec {
223    type Error = io::Error;
224
225    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), io::Error> {
226        match frame {
227            Frame::Data(data) => {
228                dst.put_u8(TYPE_DATA);
229                dst.put_u32(data.len() as u32);
230                dst.extend_from_slice(&data);
231            }
232            Frame::Resize { cols, rows } => {
233                dst.put_u8(TYPE_RESIZE);
234                dst.put_u32(4);
235                dst.put_u16(cols);
236                dst.put_u16(rows);
237            }
238            Frame::Exit { code } => {
239                dst.put_u8(TYPE_EXIT);
240                dst.put_u32(4);
241                dst.put_i32(code);
242            }
243            Frame::Detached => encode_empty(dst, TYPE_DETACHED),
244            Frame::Ping => encode_empty(dst, TYPE_PING),
245            Frame::Pong => encode_empty(dst, TYPE_PONG),
246            Frame::Env { vars } => {
247                let text: String =
248                    vars.iter().map(|(k, v)| format!("{k}={v}")).collect::<Vec<_>>().join("\n");
249                dst.put_u8(TYPE_ENV);
250                dst.put_u32(text.len() as u32);
251                dst.extend_from_slice(text.as_bytes());
252            }
253            Frame::NewSession { name } => encode_str(dst, TYPE_NEW_SESSION, &name),
254            Frame::Attach { session } => encode_str(dst, TYPE_ATTACH, &session),
255            Frame::ListSessions => encode_empty(dst, TYPE_LIST_SESSIONS),
256            Frame::KillSession { session } => encode_str(dst, TYPE_KILL_SESSION, &session),
257            Frame::KillServer => encode_empty(dst, TYPE_KILL_SERVER),
258            Frame::SessionCreated { id } => encode_str(dst, TYPE_SESSION_CREATED, &id),
259            Frame::SessionInfo { sessions } => {
260                let text: String = sessions
261                    .iter()
262                    .map(|e| {
263                        format!(
264                            "{}\t{}\t{}\t{}\t{}\t{}\t{}",
265                            e.id,
266                            e.name,
267                            e.pty_path,
268                            e.shell_pid,
269                            e.created_at,
270                            if e.attached { "1" } else { "0" },
271                            e.last_heartbeat
272                        )
273                    })
274                    .collect::<Vec<_>>()
275                    .join("\n");
276                dst.put_u8(TYPE_SESSION_INFO);
277                dst.put_u32(text.len() as u32);
278                dst.extend_from_slice(text.as_bytes());
279            }
280            Frame::Ok => encode_empty(dst, TYPE_OK),
281            Frame::Error { message } => encode_str(dst, TYPE_ERROR, &message),
282        }
283        Ok(())
284    }
285}