Skip to main content

gritty/
protocol.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
// Frame type tags. Each frame on the wire is `type(1) | length(4, big-endian) | payload`.

// Terminal session traffic.
const TYPE_DATA: u8 = 0x01;
const TYPE_RESIZE: u8 = 0x02;
const TYPE_EXIT: u8 = 0x03;
const TYPE_DETACHED: u8 = 0x04;
const TYPE_PING: u8 = 0x05;
const TYPE_PONG: u8 = 0x06;
const TYPE_ENV: u8 = 0x07;

// Agent and URL-open forwarding.
const TYPE_AGENT_FORWARD: u8 = 0x08;
const TYPE_AGENT_OPEN: u8 = 0x09;
const TYPE_AGENT_DATA: u8 = 0x0A;
const TYPE_AGENT_CLOSE: u8 = 0x0B;
const TYPE_OPEN_FORWARD: u8 = 0x0C;
const TYPE_OPEN_URL: u8 = 0x0D;

// Control requests.
const TYPE_NEW_SESSION: u8 = 0x10;
const TYPE_ATTACH: u8 = 0x11;
const TYPE_LIST_SESSIONS: u8 = 0x12;
const TYPE_KILL_SESSION: u8 = 0x13;
const TYPE_KILL_SERVER: u8 = 0x14;

// Control responses.
const TYPE_SESSION_CREATED: u8 = 0x20;
const TYPE_SESSION_INFO: u8 = 0x21;
const TYPE_OK: u8 = 0x22;
const TYPE_ERROR: u8 = 0x23;

/// Fixed header size: a 1-byte type tag followed by a 4-byte big-endian payload length.
const HEADER_LEN: usize = 1 + 4;
/// Largest payload (1 MiB) the decoder accepts; larger declared lengths are rejected.
const MAX_FRAME_SIZE: usize = 1024 * 1024;
30
/// One session row carried by a `SessionInfo` reply.
///
/// On the wire every entry is a single line of seven tab-separated fields, so text
/// fields containing `\t` or `\n` do not survive decoding (the line is dropped).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SessionEntry {
    /// Session identifier.
    pub id: String,
    /// Session name.
    pub name: String,
    /// PTY path reported for the session.
    pub pty_path: String,
    /// PID of the session's shell (decoded as 0 if the field does not parse).
    pub shell_pid: u32,
    /// Creation timestamp (decoded as 0 if unparseable); units presumably seconds — TODO confirm.
    pub created_at: u64,
    /// Attachment flag, sent as `1` / `0`.
    pub attached: bool,
    /// Last heartbeat timestamp (decoded as 0 if unparseable); same units as `created_at`.
    pub last_heartbeat: u64,
}
42
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub enum Frame {
45    Data(Bytes),
46    Resize {
47        cols: u16,
48        rows: u16,
49    },
50    Exit {
51        code: i32,
52    },
53    /// Sent to a client when another client takes over the session.
54    Detached,
55    /// Heartbeat request (client → server).
56    Ping,
57    /// Heartbeat reply (server → client).
58    Pong,
59    /// Environment variables (client → server, sent before first Resize on new session).
60    Env {
61        vars: Vec<(String, String)>,
62    },
63    /// Client signals it can handle agent forwarding (client → server).
64    AgentForward,
65    /// New agent connection on the remote side (server → client).
66    AgentOpen {
67        channel_id: u32,
68    },
69    /// Agent protocol data (bidirectional).
70    AgentData {
71        channel_id: u32,
72        data: Bytes,
73    },
74    /// Close an agent channel (bidirectional).
75    AgentClose {
76        channel_id: u32,
77    },
78    /// Client signals it can handle URL open forwarding (client → server).
79    OpenForward,
80    /// URL to open on the client machine (server → client).
81    OpenUrl {
82        url: String,
83    },
84    // Control requests
85    NewSession {
86        name: String,
87    },
88    Attach {
89        session: String,
90    },
91    ListSessions,
92    KillSession {
93        session: String,
94    },
95    KillServer,
96    // Control responses
97    SessionCreated {
98        id: String,
99    },
100    SessionInfo {
101        sessions: Vec<SessionEntry>,
102    },
103    Ok,
104    Error {
105        message: String,
106    },
107}
108
109impl Frame {
110    /// Extract a Frame from a `framed.next().await` result, converting
111    /// the common None / Some(Err) cases into descriptive errors.
112    pub fn expect_from(result: Option<Result<Frame, io::Error>>) -> anyhow::Result<Frame> {
113        match result {
114            Some(Ok(frame)) => Ok(frame),
115            Some(Err(e)) => Err(anyhow::anyhow!("daemon protocol error: {e}")),
116            None => Err(anyhow::anyhow!("daemon closed connection")),
117        }
118    }
119}
120
/// Stateless codec for [`Frame`]s in the `type(1) | length(4, big-endian) | payload` format.
///
/// Being a unit struct, it is trivially `Copy`/`Default`; `Debug` lets it appear in the
/// debug output of wrappers such as `Framed`.
#[derive(Debug, Default, Clone, Copy)]
pub struct FrameCodec;
122
123fn encode_empty(dst: &mut BytesMut, ty: u8) {
124    dst.put_u8(ty);
125    dst.put_u32(0);
126}
127
128fn encode_str(dst: &mut BytesMut, ty: u8, s: &str) {
129    dst.put_u8(ty);
130    dst.put_u32(s.len() as u32);
131    dst.extend_from_slice(s.as_bytes());
132}
133
134fn decode_string(payload: BytesMut) -> Result<String, io::Error> {
135    String::from_utf8(payload.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
136}
137
138impl Decoder for FrameCodec {
139    type Item = Frame;
140    type Error = io::Error;
141
142    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, io::Error> {
143        if src.len() < HEADER_LEN {
144            return Ok(None);
145        }
146
147        let frame_type = src[0];
148        let payload_len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
149
150        if payload_len > MAX_FRAME_SIZE {
151            return Err(io::Error::new(
152                io::ErrorKind::InvalidData,
153                format!("frame payload too large: {payload_len} bytes (max {MAX_FRAME_SIZE})"),
154            ));
155        }
156
157        if src.len() < HEADER_LEN + payload_len {
158            src.reserve(HEADER_LEN + payload_len - src.len());
159            return Ok(None);
160        }
161
162        src.advance(HEADER_LEN);
163        let payload = src.split_to(payload_len);
164
165        match frame_type {
166            TYPE_DATA => Ok(Some(Frame::Data(payload.freeze()))),
167            TYPE_RESIZE => {
168                if payload.len() != 4 {
169                    return Err(io::Error::new(
170                        io::ErrorKind::InvalidData,
171                        "resize frame must be 4 bytes",
172                    ));
173                }
174                let cols = u16::from_be_bytes([payload[0], payload[1]]);
175                let rows = u16::from_be_bytes([payload[2], payload[3]]);
176                Ok(Some(Frame::Resize { cols, rows }))
177            }
178            TYPE_EXIT => {
179                if payload.len() != 4 {
180                    return Err(io::Error::new(
181                        io::ErrorKind::InvalidData,
182                        "exit frame must be 4 bytes",
183                    ));
184                }
185                let code = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
186                Ok(Some(Frame::Exit { code }))
187            }
188            TYPE_DETACHED => Ok(Some(Frame::Detached)),
189            TYPE_PING => Ok(Some(Frame::Ping)),
190            TYPE_PONG => Ok(Some(Frame::Pong)),
191            TYPE_ENV => {
192                let text = String::from_utf8(payload.to_vec())
193                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
194                let vars = if text.is_empty() {
195                    Vec::new()
196                } else {
197                    text.lines()
198                        .filter_map(|line| {
199                            let (k, v) = line.split_once('=')?;
200                            Some((k.to_string(), v.to_string()))
201                        })
202                        .collect()
203                };
204                Ok(Some(Frame::Env { vars }))
205            }
206            TYPE_AGENT_FORWARD => Ok(Some(Frame::AgentForward)),
207            TYPE_AGENT_OPEN => {
208                if payload.len() != 4 {
209                    return Err(io::Error::new(
210                        io::ErrorKind::InvalidData,
211                        "agent open frame must be 4 bytes",
212                    ));
213                }
214                let channel_id =
215                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
216                Ok(Some(Frame::AgentOpen { channel_id }))
217            }
218            TYPE_AGENT_DATA => {
219                if payload.len() < 4 {
220                    return Err(io::Error::new(
221                        io::ErrorKind::InvalidData,
222                        "agent data frame must be at least 4 bytes",
223                    ));
224                }
225                let channel_id =
226                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
227                let data = payload.freeze().slice(4..);
228                Ok(Some(Frame::AgentData { channel_id, data }))
229            }
230            TYPE_AGENT_CLOSE => {
231                if payload.len() != 4 {
232                    return Err(io::Error::new(
233                        io::ErrorKind::InvalidData,
234                        "agent close frame must be 4 bytes",
235                    ));
236                }
237                let channel_id =
238                    u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
239                Ok(Some(Frame::AgentClose { channel_id }))
240            }
241            TYPE_OPEN_FORWARD => Ok(Some(Frame::OpenForward)),
242            TYPE_OPEN_URL => Ok(Some(Frame::OpenUrl { url: decode_string(payload)? })),
243            TYPE_NEW_SESSION => Ok(Some(Frame::NewSession { name: decode_string(payload)? })),
244            TYPE_ATTACH => Ok(Some(Frame::Attach { session: decode_string(payload)? })),
245            TYPE_LIST_SESSIONS => Ok(Some(Frame::ListSessions)),
246            TYPE_KILL_SESSION => Ok(Some(Frame::KillSession { session: decode_string(payload)? })),
247            TYPE_KILL_SERVER => Ok(Some(Frame::KillServer)),
248            TYPE_SESSION_CREATED => Ok(Some(Frame::SessionCreated { id: decode_string(payload)? })),
249            TYPE_SESSION_INFO => {
250                let text = String::from_utf8(payload.to_vec())
251                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
252                let sessions = if text.is_empty() {
253                    Vec::new()
254                } else {
255                    text.lines()
256                        .filter_map(|line| {
257                            let parts: Vec<&str> = line.split('\t').collect();
258                            if parts.len() == 7 {
259                                Some(SessionEntry {
260                                    id: parts[0].to_string(),
261                                    name: parts[1].to_string(),
262                                    pty_path: parts[2].to_string(),
263                                    shell_pid: parts[3].parse().unwrap_or(0),
264                                    created_at: parts[4].parse().unwrap_or(0),
265                                    attached: parts[5] == "1",
266                                    last_heartbeat: parts[6].parse().unwrap_or(0),
267                                })
268                            } else {
269                                None
270                            }
271                        })
272                        .collect()
273                };
274                Ok(Some(Frame::SessionInfo { sessions }))
275            }
276            TYPE_OK => Ok(Some(Frame::Ok)),
277            TYPE_ERROR => Ok(Some(Frame::Error { message: decode_string(payload)? })),
278            _ => Err(io::Error::new(
279                io::ErrorKind::InvalidData,
280                format!("unknown frame type: 0x{frame_type:02x}"),
281            )),
282        }
283    }
284}
285
286impl Encoder<Frame> for FrameCodec {
287    type Error = io::Error;
288
289    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), io::Error> {
290        match frame {
291            Frame::Data(data) => {
292                dst.put_u8(TYPE_DATA);
293                dst.put_u32(data.len() as u32);
294                dst.extend_from_slice(&data);
295            }
296            Frame::Resize { cols, rows } => {
297                dst.put_u8(TYPE_RESIZE);
298                dst.put_u32(4);
299                dst.put_u16(cols);
300                dst.put_u16(rows);
301            }
302            Frame::Exit { code } => {
303                dst.put_u8(TYPE_EXIT);
304                dst.put_u32(4);
305                dst.put_i32(code);
306            }
307            Frame::Detached => encode_empty(dst, TYPE_DETACHED),
308            Frame::Ping => encode_empty(dst, TYPE_PING),
309            Frame::Pong => encode_empty(dst, TYPE_PONG),
310            Frame::Env { vars } => {
311                let text: String =
312                    vars.iter().map(|(k, v)| format!("{k}={v}")).collect::<Vec<_>>().join("\n");
313                dst.put_u8(TYPE_ENV);
314                dst.put_u32(text.len() as u32);
315                dst.extend_from_slice(text.as_bytes());
316            }
317            Frame::AgentForward => encode_empty(dst, TYPE_AGENT_FORWARD),
318            Frame::AgentOpen { channel_id } => {
319                dst.put_u8(TYPE_AGENT_OPEN);
320                dst.put_u32(4);
321                dst.put_u32(channel_id);
322            }
323            Frame::AgentData { channel_id, data } => {
324                dst.put_u8(TYPE_AGENT_DATA);
325                dst.put_u32(4 + data.len() as u32);
326                dst.put_u32(channel_id);
327                dst.extend_from_slice(&data);
328            }
329            Frame::AgentClose { channel_id } => {
330                dst.put_u8(TYPE_AGENT_CLOSE);
331                dst.put_u32(4);
332                dst.put_u32(channel_id);
333            }
334            Frame::OpenForward => encode_empty(dst, TYPE_OPEN_FORWARD),
335            Frame::OpenUrl { url } => encode_str(dst, TYPE_OPEN_URL, &url),
336            Frame::NewSession { name } => encode_str(dst, TYPE_NEW_SESSION, &name),
337            Frame::Attach { session } => encode_str(dst, TYPE_ATTACH, &session),
338            Frame::ListSessions => encode_empty(dst, TYPE_LIST_SESSIONS),
339            Frame::KillSession { session } => encode_str(dst, TYPE_KILL_SESSION, &session),
340            Frame::KillServer => encode_empty(dst, TYPE_KILL_SERVER),
341            Frame::SessionCreated { id } => encode_str(dst, TYPE_SESSION_CREATED, &id),
342            Frame::SessionInfo { sessions } => {
343                let text: String = sessions
344                    .iter()
345                    .map(|e| {
346                        format!(
347                            "{}\t{}\t{}\t{}\t{}\t{}\t{}",
348                            e.id,
349                            e.name,
350                            e.pty_path,
351                            e.shell_pid,
352                            e.created_at,
353                            if e.attached { "1" } else { "0" },
354                            e.last_heartbeat
355                        )
356                    })
357                    .collect::<Vec<_>>()
358                    .join("\n");
359                dst.put_u8(TYPE_SESSION_INFO);
360                dst.put_u32(text.len() as u32);
361                dst.extend_from_slice(text.as_bytes());
362            }
363            Frame::Ok => encode_empty(dst, TYPE_OK),
364            Frame::Error { message } => encode_str(dst, TYPE_ERROR, &message),
365        }
366        Ok(())
367    }
368}