Skip to main content

gritty/
protocol.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
// Frame type tags: the first byte of every frame on the wire.

// Session traffic, heartbeats and forwarding channels.
const TYPE_DATA: u8 = 0x01;
const TYPE_RESIZE: u8 = 0x02;
const TYPE_EXIT: u8 = 0x03;
const TYPE_DETACHED: u8 = 0x04;
const TYPE_PING: u8 = 0x05;
const TYPE_PONG: u8 = 0x06;
const TYPE_ENV: u8 = 0x07;
const TYPE_AGENT_FORWARD: u8 = 0x08;
const TYPE_AGENT_OPEN: u8 = 0x09;
const TYPE_AGENT_DATA: u8 = 0x0A;
const TYPE_AGENT_CLOSE: u8 = 0x0B;
const TYPE_OPEN_FORWARD: u8 = 0x0C;
const TYPE_OPEN_URL: u8 = 0x0D;

// Control requests.
const TYPE_NEW_SESSION: u8 = 0x10;
const TYPE_ATTACH: u8 = 0x11;
const TYPE_LIST_SESSIONS: u8 = 0x12;
const TYPE_KILL_SESSION: u8 = 0x13;
const TYPE_KILL_SERVER: u8 = 0x14;
const TYPE_TAIL: u8 = 0x15;

// Control responses.
const TYPE_SESSION_CREATED: u8 = 0x20;
const TYPE_SESSION_INFO: u8 = 0x21;
const TYPE_OK: u8 = 0x22;
const TYPE_ERROR: u8 = 0x23;

/// Fixed header size: a 1-byte type tag followed by a 4-byte big-endian payload length.
const HEADER_LEN: usize = 1 + 4;
/// Largest payload length accepted on the wire (1 MiB).
const MAX_FRAME_SIZE: usize = 1024 * 1024;
31
/// Metadata for one session, returned in SessionInfo.
///
/// On the wire each entry is a single tab-separated line with seven fields, in
/// declaration order; `attached` is written as `1` or `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionEntry {
    /// Session identifier (as returned in `SessionCreated`).
    pub id: String,
    /// Session name (as supplied in `NewSession`).
    pub name: String,
    /// Path of the session's PTY device.
    pub pty_path: String,
    /// PID of the session's shell process.
    pub shell_pid: u32,
    /// Creation timestamp. NOTE(review): unit/epoch not visible here — confirm.
    pub created_at: u64,
    /// Whether a client is currently attached.
    pub attached: bool,
    /// Timestamp of the most recent heartbeat. NOTE(review): same unit as `created_at`? confirm.
    pub last_heartbeat: u64,
}
43
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum Frame {
46    Data(Bytes),
47    Resize {
48        cols: u16,
49        rows: u16,
50    },
51    Exit {
52        code: i32,
53    },
54    /// Sent to a client when another client takes over the session.
55    Detached,
56    /// Heartbeat request (client → server).
57    Ping,
58    /// Heartbeat reply (server → client).
59    Pong,
60    /// Environment variables (client → server, sent before first Resize on new session).
61    Env {
62        vars: Vec<(String, String)>,
63    },
64    /// Client signals it can handle agent forwarding (client → server).
65    AgentForward,
66    /// New agent connection on the remote side (server → client).
67    AgentOpen {
68        channel_id: u32,
69    },
70    /// Agent protocol data (bidirectional).
71    AgentData {
72        channel_id: u32,
73        data: Bytes,
74    },
75    /// Close an agent channel (bidirectional).
76    AgentClose {
77        channel_id: u32,
78    },
79    /// Client signals it can handle URL open forwarding (client → server).
80    OpenForward,
81    /// URL to open on the client machine (server → client).
82    OpenUrl {
83        url: String,
84    },
85    // Control requests
86    NewSession {
87        name: String,
88    },
89    Attach {
90        session: String,
91    },
92    /// Read-only tail of a session's PTY output (client → server).
93    Tail {
94        session: String,
95    },
96    ListSessions,
97    KillSession {
98        session: String,
99    },
100    KillServer,
101    // Control responses
102    SessionCreated {
103        id: String,
104    },
105    SessionInfo {
106        sessions: Vec<SessionEntry>,
107    },
108    Ok,
109    Error {
110        message: String,
111    },
112}
113
114impl Frame {
115    /// Extract a Frame from a `framed.next().await` result, converting
116    /// the common None / Some(Err) cases into descriptive errors.
117    pub fn expect_from(result: Option<Result<Frame, io::Error>>) -> anyhow::Result<Frame> {
118        match result {
119            Some(Ok(frame)) => Ok(frame),
120            Some(Err(e)) => Err(anyhow::anyhow!("daemon protocol error: {e}")),
121            None => Err(anyhow::anyhow!("daemon closed connection")),
122        }
123    }
124}
125
/// `tokio_util` codec for [`Frame`]: a 1-byte type tag, a 4-byte big-endian
/// payload length, then the payload.
///
/// The codec keeps no state between calls, so it is freely copyable and has a
/// trivial `Default`; `Debug` lets `Framed<_, FrameCodec>` be debug-printed.
#[derive(Debug, Clone, Copy, Default)]
pub struct FrameCodec;
127
128fn encode_empty(dst: &mut BytesMut, ty: u8) {
129    dst.put_u8(ty);
130    dst.put_u32(0);
131}
132
133fn encode_str(dst: &mut BytesMut, ty: u8, s: &str) {
134    dst.put_u8(ty);
135    dst.put_u32(s.len() as u32);
136    dst.extend_from_slice(s.as_bytes());
137}
138
139fn decode_string(payload: BytesMut) -> Result<String, io::Error> {
140    String::from_utf8(payload.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
141}
142
143impl Decoder for FrameCodec {
144    type Item = Frame;
145    type Error = io::Error;
146
147    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, io::Error> {
148        if src.len() < HEADER_LEN {
149            return Ok(None);
150        }
151
152        let frame_type = src[0];
153        let payload_len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
154
155        if payload_len > MAX_FRAME_SIZE {
156            return Err(io::Error::new(
157                io::ErrorKind::InvalidData,
158                format!("frame payload too large: {payload_len} bytes (max {MAX_FRAME_SIZE})"),
159            ));
160        }
161
162        if src.len() < HEADER_LEN + payload_len {
163            src.reserve(HEADER_LEN + payload_len - src.len());
164            return Ok(None);
165        }
166
167        src.advance(HEADER_LEN);
168        let payload = src.split_to(payload_len);
169
170        match frame_type {
171            TYPE_DATA => Ok(Some(Frame::Data(payload.freeze()))),
172            TYPE_RESIZE => {
173                if payload.len() != 4 {
174                    return Err(io::Error::new(
175                        io::ErrorKind::InvalidData,
176                        "resize frame must be 4 bytes",
177                    ));
178                }
179                let cols = u16::from_be_bytes([payload[0], payload[1]]);
180                let rows = u16::from_be_bytes([payload[2], payload[3]]);
181                Ok(Some(Frame::Resize { cols, rows }))
182            }
183            TYPE_EXIT => {
184                if payload.len() != 4 {
185                    return Err(io::Error::new(
186                        io::ErrorKind::InvalidData,
187                        "exit frame must be 4 bytes",
188                    ));
189                }
190                let code = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
191                Ok(Some(Frame::Exit { code }))
192            }
193            TYPE_DETACHED => Ok(Some(Frame::Detached)),
194            TYPE_PING => Ok(Some(Frame::Ping)),
195            TYPE_PONG => Ok(Some(Frame::Pong)),
196            TYPE_ENV => {
197                let text = String::from_utf8(payload.to_vec())
198                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
199                let vars = if text.is_empty() {
200                    Vec::new()
201                } else {
202                    text.lines()
203                        .filter_map(|line| {
204                            let (k, v) = line.split_once('=')?;
205                            Some((k.to_string(), v.to_string()))
206                        })
207                        .collect()
208                };
209                Ok(Some(Frame::Env { vars }))
210            }
211            TYPE_AGENT_FORWARD => Ok(Some(Frame::AgentForward)),
212            TYPE_AGENT_OPEN => {
213                if payload.len() != 4 {
214                    return Err(io::Error::new(
215                        io::ErrorKind::InvalidData,
216                        "agent open frame must be 4 bytes",
217                    ));
218                }
219                let channel_id =
220                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
221                Ok(Some(Frame::AgentOpen { channel_id }))
222            }
223            TYPE_AGENT_DATA => {
224                if payload.len() < 4 {
225                    return Err(io::Error::new(
226                        io::ErrorKind::InvalidData,
227                        "agent data frame must be at least 4 bytes",
228                    ));
229                }
230                let channel_id =
231                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
232                let data = payload.freeze().slice(4..);
233                Ok(Some(Frame::AgentData { channel_id, data }))
234            }
235            TYPE_AGENT_CLOSE => {
236                if payload.len() != 4 {
237                    return Err(io::Error::new(
238                        io::ErrorKind::InvalidData,
239                        "agent close frame must be 4 bytes",
240                    ));
241                }
242                let channel_id =
243                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
244                Ok(Some(Frame::AgentClose { channel_id }))
245            }
246            TYPE_OPEN_FORWARD => Ok(Some(Frame::OpenForward)),
247            TYPE_OPEN_URL => Ok(Some(Frame::OpenUrl { url: decode_string(payload)? })),
248            TYPE_NEW_SESSION => Ok(Some(Frame::NewSession { name: decode_string(payload)? })),
249            TYPE_ATTACH => Ok(Some(Frame::Attach { session: decode_string(payload)? })),
250            TYPE_TAIL => Ok(Some(Frame::Tail { session: decode_string(payload)? })),
251            TYPE_LIST_SESSIONS => Ok(Some(Frame::ListSessions)),
252            TYPE_KILL_SESSION => Ok(Some(Frame::KillSession { session: decode_string(payload)? })),
253            TYPE_KILL_SERVER => Ok(Some(Frame::KillServer)),
254            TYPE_SESSION_CREATED => Ok(Some(Frame::SessionCreated { id: decode_string(payload)? })),
255            TYPE_SESSION_INFO => {
256                let text = String::from_utf8(payload.to_vec())
257                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
258                let sessions = if text.is_empty() {
259                    Vec::new()
260                } else {
261                    text.lines()
262                        .filter_map(|line| {
263                            let parts: Vec<&str> = line.split('\t').collect();
264                            if parts.len() == 7 {
265                                Some(SessionEntry {
266                                    id: parts[0].to_string(),
267                                    name: parts[1].to_string(),
268                                    pty_path: parts[2].to_string(),
269                                    shell_pid: parts[3].parse().unwrap_or(0),
270                                    created_at: parts[4].parse().unwrap_or(0),
271                                    attached: parts[5] == "1",
272                                    last_heartbeat: parts[6].parse().unwrap_or(0),
273                                })
274                            } else {
275                                None
276                            }
277                        })
278                        .collect()
279                };
280                Ok(Some(Frame::SessionInfo { sessions }))
281            }
282            TYPE_OK => Ok(Some(Frame::Ok)),
283            TYPE_ERROR => Ok(Some(Frame::Error { message: decode_string(payload)? })),
284            _ => Err(io::Error::new(
285                io::ErrorKind::InvalidData,
286                format!("unknown frame type: 0x{frame_type:02x}"),
287            )),
288        }
289    }
290}
291
292impl Encoder<Frame> for FrameCodec {
293    type Error = io::Error;
294
295    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), io::Error> {
296        match frame {
297            Frame::Data(data) => {
298                dst.put_u8(TYPE_DATA);
299                dst.put_u32(data.len() as u32);
300                dst.extend_from_slice(&data);
301            }
302            Frame::Resize { cols, rows } => {
303                dst.put_u8(TYPE_RESIZE);
304                dst.put_u32(4);
305                dst.put_u16(cols);
306                dst.put_u16(rows);
307            }
308            Frame::Exit { code } => {
309                dst.put_u8(TYPE_EXIT);
310                dst.put_u32(4);
311                dst.put_i32(code);
312            }
313            Frame::Detached => encode_empty(dst, TYPE_DETACHED),
314            Frame::Ping => encode_empty(dst, TYPE_PING),
315            Frame::Pong => encode_empty(dst, TYPE_PONG),
316            Frame::Env { vars } => {
317                // Strip newlines from keys/values to prevent injection of extra
318                // key=value pairs via the newline-delimited wire format.
319                let text: String = vars
320                    .iter()
321                    .map(|(k, v)| {
322                        let k = k.replace('\n', "");
323                        let v = v.replace('\n', "");
324                        format!("{k}={v}")
325                    })
326                    .collect::<Vec<_>>()
327                    .join("\n");
328                dst.put_u8(TYPE_ENV);
329                dst.put_u32(text.len() as u32);
330                dst.extend_from_slice(text.as_bytes());
331            }
332            Frame::AgentForward => encode_empty(dst, TYPE_AGENT_FORWARD),
333            Frame::AgentOpen { channel_id } => {
334                dst.put_u8(TYPE_AGENT_OPEN);
335                dst.put_u32(4);
336                dst.put_u32(channel_id);
337            }
338            Frame::AgentData { channel_id, data } => {
339                dst.put_u8(TYPE_AGENT_DATA);
340                dst.put_u32(4 + data.len() as u32);
341                dst.put_u32(channel_id);
342                dst.extend_from_slice(&data);
343            }
344            Frame::AgentClose { channel_id } => {
345                dst.put_u8(TYPE_AGENT_CLOSE);
346                dst.put_u32(4);
347                dst.put_u32(channel_id);
348            }
349            Frame::OpenForward => encode_empty(dst, TYPE_OPEN_FORWARD),
350            Frame::OpenUrl { url } => encode_str(dst, TYPE_OPEN_URL, &url),
351            Frame::NewSession { name } => encode_str(dst, TYPE_NEW_SESSION, &name),
352            Frame::Attach { session } => encode_str(dst, TYPE_ATTACH, &session),
353            Frame::Tail { session } => encode_str(dst, TYPE_TAIL, &session),
354            Frame::ListSessions => encode_empty(dst, TYPE_LIST_SESSIONS),
355            Frame::KillSession { session } => encode_str(dst, TYPE_KILL_SESSION, &session),
356            Frame::KillServer => encode_empty(dst, TYPE_KILL_SERVER),
357            Frame::SessionCreated { id } => encode_str(dst, TYPE_SESSION_CREATED, &id),
358            Frame::SessionInfo { sessions } => {
359                let text: String = sessions
360                    .iter()
361                    .map(|e| {
362                        format!(
363                            "{}\t{}\t{}\t{}\t{}\t{}\t{}",
364                            e.id,
365                            e.name,
366                            e.pty_path,
367                            e.shell_pid,
368                            e.created_at,
369                            if e.attached { "1" } else { "0" },
370                            e.last_heartbeat
371                        )
372                    })
373                    .collect::<Vec<_>>()
374                    .join("\n");
375                dst.put_u8(TYPE_SESSION_INFO);
376                dst.put_u32(text.len() as u32);
377                dst.extend_from_slice(text.as_bytes());
378            }
379            Frame::Ok => encode_empty(dst, TYPE_OK),
380            Frame::Error { message } => encode_str(dst, TYPE_ERROR, &message),
381        }
382        Ok(())
383    }
384}