Skip to main content

gritty/
protocol.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
// Frame type codes: the first byte of every frame header.

// Session frames: terminal data, control signals, heartbeats, and
// agent / URL forwarding (0x01–0x0D).
const TYPE_DATA: u8 = 0x01;
const TYPE_RESIZE: u8 = 0x02;
const TYPE_EXIT: u8 = 0x03;
const TYPE_DETACHED: u8 = 0x04;
const TYPE_PING: u8 = 0x05;
const TYPE_PONG: u8 = 0x06;
const TYPE_ENV: u8 = 0x07;
const TYPE_AGENT_FORWARD: u8 = 0x08;
const TYPE_AGENT_OPEN: u8 = 0x09;
const TYPE_AGENT_DATA: u8 = 0x0a;
const TYPE_AGENT_CLOSE: u8 = 0x0b;
const TYPE_OPEN_FORWARD: u8 = 0x0c;
const TYPE_OPEN_URL: u8 = 0x0d;

// Control requests, client → server (0x10–0x16).
const TYPE_NEW_SESSION: u8 = 0x10;
const TYPE_ATTACH: u8 = 0x11;
const TYPE_LIST_SESSIONS: u8 = 0x12;
const TYPE_KILL_SESSION: u8 = 0x13;
const TYPE_KILL_SERVER: u8 = 0x14;
const TYPE_TAIL: u8 = 0x15;
const TYPE_HELLO: u8 = 0x16;

// Control responses, server → client (0x20–0x24).
const TYPE_SESSION_CREATED: u8 = 0x20;
const TYPE_SESSION_INFO: u8 = 0x21;
const TYPE_OK: u8 = 0x22;
const TYPE_ERROR: u8 = 0x23;
const TYPE_HELLO_ACK: u8 = 0x24;

/// Header size: a one-byte frame type followed by a big-endian `u32` payload length.
const HEADER_LEN: usize = 1 + 4;
/// Largest payload length the decoder accepts, in bytes (1 MiB).
const MAX_FRAME_SIZE: usize = 1024 * 1024;

/// Protocol version for handshake negotiation.
pub const PROTOCOL_VERSION: u16 = 1;
36
/// Metadata for one session, carried inside a `SessionInfo` frame.
///
/// On the wire each entry is one line of seven tab-separated fields, in the
/// same order as the fields below.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionEntry {
    /// Session identifier.
    pub id: String,
    /// Session name.
    pub name: String,
    /// Path of the session's PTY device.
    pub pty_path: String,
    /// PID of the session's shell process.
    pub shell_pid: u32,
    /// Creation time (NOTE(review): unit/epoch not visible here — presumably Unix seconds; confirm).
    pub created_at: u64,
    /// Whether a client is currently attached; sent as "1" / "0".
    pub attached: bool,
    /// Time of the last heartbeat (NOTE(review): same unit as `created_at` — confirm).
    pub last_heartbeat: u64,
}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub enum Frame {
51    Data(Bytes),
52    Resize {
53        cols: u16,
54        rows: u16,
55    },
56    Exit {
57        code: i32,
58    },
59    /// Sent to a client when another client takes over the session.
60    Detached,
61    /// Heartbeat request (client → server).
62    Ping,
63    /// Heartbeat reply (server → client).
64    Pong,
65    /// Environment variables (client → server, sent before first Resize on new session).
66    Env {
67        vars: Vec<(String, String)>,
68    },
69    /// Client signals it can handle agent forwarding (client → server).
70    AgentForward,
71    /// New agent connection on the remote side (server → client).
72    AgentOpen {
73        channel_id: u32,
74    },
75    /// Agent protocol data (bidirectional).
76    AgentData {
77        channel_id: u32,
78        data: Bytes,
79    },
80    /// Close an agent channel (bidirectional).
81    AgentClose {
82        channel_id: u32,
83    },
84    /// Client signals it can handle URL open forwarding (client → server).
85    OpenForward,
86    /// URL to open on the client machine (server → client).
87    OpenUrl {
88        url: String,
89    },
90    /// Protocol version handshake (client → server, first frame on connection).
91    Hello {
92        version: u16,
93    },
94    /// Protocol version acknowledgement (server → client).
95    HelloAck {
96        version: u16,
97    },
98    // Control requests
99    NewSession {
100        name: String,
101    },
102    Attach {
103        session: String,
104    },
105    /// Read-only tail of a session's PTY output (client → server).
106    Tail {
107        session: String,
108    },
109    ListSessions,
110    KillSession {
111        session: String,
112    },
113    KillServer,
114    // Control responses
115    SessionCreated {
116        id: String,
117    },
118    SessionInfo {
119        sessions: Vec<SessionEntry>,
120    },
121    Ok,
122    Error {
123        message: String,
124    },
125}
126
127impl Frame {
128    /// Extract a Frame from a `framed.next().await` result, converting
129    /// the common None / Some(Err) cases into descriptive errors.
130    pub fn expect_from(result: Option<Result<Frame, io::Error>>) -> anyhow::Result<Frame> {
131        match result {
132            Some(Ok(frame)) => Ok(frame),
133            Some(Err(e)) => Err(anyhow::anyhow!("daemon protocol error: {e}")),
134            None => Err(anyhow::anyhow!("daemon closed connection")),
135        }
136    }
137}
138
/// Stateless codec between [`Frame`] values and the length-prefixed wire
/// format `type(1) | length(4, big-endian) | payload`.
///
/// Being a unit struct, it derives `Default`/`Clone`/`Copy` at no cost, and
/// `Debug` so that types embedding it (e.g. a framed transport) stay debuggable.
#[derive(Debug, Default, Clone, Copy)]
pub struct FrameCodec;
140
141fn encode_empty(dst: &mut BytesMut, ty: u8) {
142    dst.put_u8(ty);
143    dst.put_u32(0);
144}
145
146fn encode_str(dst: &mut BytesMut, ty: u8, s: &str) {
147    dst.put_u8(ty);
148    dst.put_u32(s.len() as u32);
149    dst.extend_from_slice(s.as_bytes());
150}
151
152fn decode_string(payload: BytesMut) -> Result<String, io::Error> {
153    String::from_utf8(payload.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
154}
155
156impl Decoder for FrameCodec {
157    type Item = Frame;
158    type Error = io::Error;
159
160    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, io::Error> {
161        if src.len() < HEADER_LEN {
162            return Ok(None);
163        }
164
165        let frame_type = src[0];
166        let payload_len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
167
168        if payload_len > MAX_FRAME_SIZE {
169            return Err(io::Error::new(
170                io::ErrorKind::InvalidData,
171                format!("frame payload too large: {payload_len} bytes (max {MAX_FRAME_SIZE})"),
172            ));
173        }
174
175        if src.len() < HEADER_LEN + payload_len {
176            src.reserve(HEADER_LEN + payload_len - src.len());
177            return Ok(None);
178        }
179
180        src.advance(HEADER_LEN);
181        let payload = src.split_to(payload_len);
182
183        match frame_type {
184            TYPE_DATA => Ok(Some(Frame::Data(payload.freeze()))),
185            TYPE_RESIZE => {
186                if payload.len() != 4 {
187                    return Err(io::Error::new(
188                        io::ErrorKind::InvalidData,
189                        "resize frame must be 4 bytes",
190                    ));
191                }
192                let cols = u16::from_be_bytes([payload[0], payload[1]]);
193                let rows = u16::from_be_bytes([payload[2], payload[3]]);
194                Ok(Some(Frame::Resize { cols, rows }))
195            }
196            TYPE_EXIT => {
197                if payload.len() != 4 {
198                    return Err(io::Error::new(
199                        io::ErrorKind::InvalidData,
200                        "exit frame must be 4 bytes",
201                    ));
202                }
203                let code = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
204                Ok(Some(Frame::Exit { code }))
205            }
206            TYPE_DETACHED => Ok(Some(Frame::Detached)),
207            TYPE_PING => Ok(Some(Frame::Ping)),
208            TYPE_PONG => Ok(Some(Frame::Pong)),
209            TYPE_ENV => {
210                let text = String::from_utf8(payload.to_vec())
211                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
212                let vars = if text.is_empty() {
213                    Vec::new()
214                } else {
215                    text.lines()
216                        .filter_map(|line| {
217                            let (k, v) = line.split_once('=')?;
218                            Some((k.to_string(), v.to_string()))
219                        })
220                        .collect()
221                };
222                Ok(Some(Frame::Env { vars }))
223            }
224            TYPE_AGENT_FORWARD => Ok(Some(Frame::AgentForward)),
225            TYPE_AGENT_OPEN => {
226                if payload.len() != 4 {
227                    return Err(io::Error::new(
228                        io::ErrorKind::InvalidData,
229                        "agent open frame must be 4 bytes",
230                    ));
231                }
232                let channel_id =
233                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
234                Ok(Some(Frame::AgentOpen { channel_id }))
235            }
236            TYPE_AGENT_DATA => {
237                if payload.len() < 4 {
238                    return Err(io::Error::new(
239                        io::ErrorKind::InvalidData,
240                        "agent data frame must be at least 4 bytes",
241                    ));
242                }
243                let channel_id =
244                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
245                let data = payload.freeze().slice(4..);
246                Ok(Some(Frame::AgentData { channel_id, data }))
247            }
248            TYPE_AGENT_CLOSE => {
249                if payload.len() != 4 {
250                    return Err(io::Error::new(
251                        io::ErrorKind::InvalidData,
252                        "agent close frame must be 4 bytes",
253                    ));
254                }
255                let channel_id =
256                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
257                Ok(Some(Frame::AgentClose { channel_id }))
258            }
259            TYPE_OPEN_FORWARD => Ok(Some(Frame::OpenForward)),
260            TYPE_OPEN_URL => Ok(Some(Frame::OpenUrl { url: decode_string(payload)? })),
261            TYPE_HELLO => {
262                if payload.len() != 2 {
263                    return Err(io::Error::new(
264                        io::ErrorKind::InvalidData,
265                        "hello frame must be 2 bytes",
266                    ));
267                }
268                let version = u16::from_be_bytes([payload[0], payload[1]]);
269                Ok(Some(Frame::Hello { version }))
270            }
271            TYPE_HELLO_ACK => {
272                if payload.len() != 2 {
273                    return Err(io::Error::new(
274                        io::ErrorKind::InvalidData,
275                        "hello ack frame must be 2 bytes",
276                    ));
277                }
278                let version = u16::from_be_bytes([payload[0], payload[1]]);
279                Ok(Some(Frame::HelloAck { version }))
280            }
281            TYPE_NEW_SESSION => Ok(Some(Frame::NewSession { name: decode_string(payload)? })),
282            TYPE_ATTACH => Ok(Some(Frame::Attach { session: decode_string(payload)? })),
283            TYPE_TAIL => Ok(Some(Frame::Tail { session: decode_string(payload)? })),
284            TYPE_LIST_SESSIONS => Ok(Some(Frame::ListSessions)),
285            TYPE_KILL_SESSION => Ok(Some(Frame::KillSession { session: decode_string(payload)? })),
286            TYPE_KILL_SERVER => Ok(Some(Frame::KillServer)),
287            TYPE_SESSION_CREATED => Ok(Some(Frame::SessionCreated { id: decode_string(payload)? })),
288            TYPE_SESSION_INFO => {
289                let text = String::from_utf8(payload.to_vec())
290                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
291                let sessions = if text.is_empty() {
292                    Vec::new()
293                } else {
294                    text.lines()
295                        .filter_map(|line| {
296                            let parts: Vec<&str> = line.split('\t').collect();
297                            if parts.len() == 7 {
298                                Some(SessionEntry {
299                                    id: parts[0].to_string(),
300                                    name: parts[1].to_string(),
301                                    pty_path: parts[2].to_string(),
302                                    shell_pid: parts[3].parse().unwrap_or(0),
303                                    created_at: parts[4].parse().unwrap_or(0),
304                                    attached: parts[5] == "1",
305                                    last_heartbeat: parts[6].parse().unwrap_or(0),
306                                })
307                            } else {
308                                None
309                            }
310                        })
311                        .collect()
312                };
313                Ok(Some(Frame::SessionInfo { sessions }))
314            }
315            TYPE_OK => Ok(Some(Frame::Ok)),
316            TYPE_ERROR => Ok(Some(Frame::Error { message: decode_string(payload)? })),
317            _ => Err(io::Error::new(
318                io::ErrorKind::InvalidData,
319                format!("unknown frame type: 0x{frame_type:02x}"),
320            )),
321        }
322    }
323}
324
325impl Encoder<Frame> for FrameCodec {
326    type Error = io::Error;
327
328    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), io::Error> {
329        match frame {
330            Frame::Data(data) => {
331                dst.put_u8(TYPE_DATA);
332                dst.put_u32(data.len() as u32);
333                dst.extend_from_slice(&data);
334            }
335            Frame::Resize { cols, rows } => {
336                dst.put_u8(TYPE_RESIZE);
337                dst.put_u32(4);
338                dst.put_u16(cols);
339                dst.put_u16(rows);
340            }
341            Frame::Exit { code } => {
342                dst.put_u8(TYPE_EXIT);
343                dst.put_u32(4);
344                dst.put_i32(code);
345            }
346            Frame::Detached => encode_empty(dst, TYPE_DETACHED),
347            Frame::Ping => encode_empty(dst, TYPE_PING),
348            Frame::Pong => encode_empty(dst, TYPE_PONG),
349            Frame::Env { vars } => {
350                // Strip newlines from keys/values to prevent injection of extra
351                // key=value pairs via the newline-delimited wire format.
352                let text: String = vars
353                    .iter()
354                    .map(|(k, v)| {
355                        let k = k.replace('\n', "");
356                        let v = v.replace('\n', "");
357                        format!("{k}={v}")
358                    })
359                    .collect::<Vec<_>>()
360                    .join("\n");
361                dst.put_u8(TYPE_ENV);
362                dst.put_u32(text.len() as u32);
363                dst.extend_from_slice(text.as_bytes());
364            }
365            Frame::AgentForward => encode_empty(dst, TYPE_AGENT_FORWARD),
366            Frame::AgentOpen { channel_id } => {
367                dst.put_u8(TYPE_AGENT_OPEN);
368                dst.put_u32(4);
369                dst.put_u32(channel_id);
370            }
371            Frame::AgentData { channel_id, data } => {
372                dst.put_u8(TYPE_AGENT_DATA);
373                dst.put_u32(4 + data.len() as u32);
374                dst.put_u32(channel_id);
375                dst.extend_from_slice(&data);
376            }
377            Frame::AgentClose { channel_id } => {
378                dst.put_u8(TYPE_AGENT_CLOSE);
379                dst.put_u32(4);
380                dst.put_u32(channel_id);
381            }
382            Frame::OpenForward => encode_empty(dst, TYPE_OPEN_FORWARD),
383            Frame::OpenUrl { url } => encode_str(dst, TYPE_OPEN_URL, &url),
384            Frame::Hello { version } => {
385                dst.put_u8(TYPE_HELLO);
386                dst.put_u32(2);
387                dst.put_u16(version);
388            }
389            Frame::HelloAck { version } => {
390                dst.put_u8(TYPE_HELLO_ACK);
391                dst.put_u32(2);
392                dst.put_u16(version);
393            }
394            Frame::NewSession { name } => encode_str(dst, TYPE_NEW_SESSION, &name),
395            Frame::Attach { session } => encode_str(dst, TYPE_ATTACH, &session),
396            Frame::Tail { session } => encode_str(dst, TYPE_TAIL, &session),
397            Frame::ListSessions => encode_empty(dst, TYPE_LIST_SESSIONS),
398            Frame::KillSession { session } => encode_str(dst, TYPE_KILL_SESSION, &session),
399            Frame::KillServer => encode_empty(dst, TYPE_KILL_SERVER),
400            Frame::SessionCreated { id } => encode_str(dst, TYPE_SESSION_CREATED, &id),
401            Frame::SessionInfo { sessions } => {
402                let text: String = sessions
403                    .iter()
404                    .map(|e| {
405                        let safe_pty = e.pty_path.replace(['\t', '\n'], " ");
406                        format!(
407                            "{}\t{}\t{}\t{}\t{}\t{}\t{}",
408                            e.id,
409                            e.name,
410                            safe_pty,
411                            e.shell_pid,
412                            e.created_at,
413                            if e.attached { "1" } else { "0" },
414                            e.last_heartbeat
415                        )
416                    })
417                    .collect::<Vec<_>>()
418                    .join("\n");
419                dst.put_u8(TYPE_SESSION_INFO);
420                dst.put_u32(text.len() as u32);
421                dst.extend_from_slice(text.as_bytes());
422            }
423            Frame::Ok => encode_empty(dst, TYPE_OK),
424            Frame::Error { message } => encode_str(dst, TYPE_ERROR, &message),
425        }
426        Ok(())
427    }
428}