1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
// Wire type tags. The first byte of every frame header is one of these values
// and selects which `Frame` variant the payload is decoded into.

// 0x01-0x0D: data, terminal and agent/forwarding frames.
const TYPE_DATA: u8 = 0x01;
const TYPE_RESIZE: u8 = 0x02;
const TYPE_EXIT: u8 = 0x03;
const TYPE_DETACHED: u8 = 0x04;
const TYPE_PING: u8 = 0x05;
const TYPE_PONG: u8 = 0x06;
const TYPE_ENV: u8 = 0x07;
const TYPE_AGENT_FORWARD: u8 = 0x08;
const TYPE_AGENT_OPEN: u8 = 0x09;
const TYPE_AGENT_DATA: u8 = 0x0A;
const TYPE_AGENT_CLOSE: u8 = 0x0B;
const TYPE_OPEN_FORWARD: u8 = 0x0C;
const TYPE_OPEN_URL: u8 = 0x0D;
// 0x10-0x16: session-management frames and the `Hello` handshake.
// NOTE(review): these look like client-to-daemon requests — confirm against callers.
const TYPE_NEW_SESSION: u8 = 0x10;
const TYPE_ATTACH: u8 = 0x11;
const TYPE_LIST_SESSIONS: u8 = 0x12;
const TYPE_KILL_SESSION: u8 = 0x13;
const TYPE_KILL_SERVER: u8 = 0x14;
const TYPE_TAIL: u8 = 0x15;
const TYPE_HELLO: u8 = 0x16;
// 0x20-0x24: result frames and the `HelloAck` handshake reply.
// NOTE(review): these look like daemon-to-client responses — confirm against callers.
const TYPE_SESSION_CREATED: u8 = 0x20;
const TYPE_SESSION_INFO: u8 = 0x21;
const TYPE_OK: u8 = 0x22;
const TYPE_ERROR: u8 = 0x23;
const TYPE_HELLO_ACK: u8 = 0x24;
30
/// Size of the fixed frame header: one type byte followed by a four-byte
/// big-endian payload length.
const HEADER_LEN: usize = 5;

/// Largest payload, in bytes, that the decoder accepts (1 MiB).
const MAX_FRAME_SIZE: usize = 1 << 20;

/// Protocol version number exposed to the rest of the crate.
pub const PROTOCOL_VERSION: u16 = 1;
36
/// One session record carried by a `Frame::SessionInfo` frame.
///
/// On the wire each entry is a single line of seven tab-separated fields in
/// the order the fields are declared here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionEntry {
    /// Session identifier (first field of the record).
    pub id: String,
    /// Session name (second field of the record).
    pub name: String,
    /// PTY path; tabs and newlines are replaced with spaces when encoded.
    pub pty_path: String,
    /// Decoded as `0` when the wire value does not parse as a `u32`.
    /// NOTE(review): presumably the PID of the session's shell — confirm.
    pub shell_pid: u32,
    /// Decoded as `0` when the wire value does not parse as a `u64`.
    /// NOTE(review): presumably a creation timestamp; unit not visible here.
    pub created_at: u64,
    /// Encoded as `"1"` / `"0"`; any value other than `"1"` decodes to `false`.
    pub attached: bool,
    /// Decoded as `0` when the wire value does not parse as a `u64`.
    /// NOTE(review): presumably a timestamp of the last heartbeat; unit not visible here.
    pub last_heartbeat: u64,
}
48
/// A single protocol message.
///
/// Every frame travels as a 5-byte header (1 type byte + 4-byte big-endian
/// payload length) followed by the payload described on each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Type `0x01`: the payload is the raw bytes, carried verbatim.
    Data(Bytes),
    /// Type `0x02`: exactly 4 payload bytes, two big-endian `u16`s (`cols`, then `rows`).
    Resize {
        cols: u16,
        rows: u16,
    },
    /// Type `0x03`: exactly 4 payload bytes, a big-endian `i32`.
    Exit {
        code: i32,
    },
    /// Type `0x04`: encoded with an empty payload.
    Detached,
    /// Type `0x05`: encoded with an empty payload.
    Ping,
    /// Type `0x06`: encoded with an empty payload.
    Pong,
    /// Type `0x07`: UTF-8 text, one `key=value` pair per line.
    ///
    /// Newlines are stripped from keys and values when encoding; lines without
    /// an `=` are skipped when decoding.
    Env {
        vars: Vec<(String, String)>,
    },
    /// Type `0x08`: encoded with an empty payload.
    AgentForward,
    /// Type `0x09`: exactly 4 payload bytes, a big-endian `u32` channel id.
    AgentOpen {
        channel_id: u32,
    },
    /// Type `0x0A`: a big-endian `u32` channel id followed by the data bytes.
    AgentData {
        channel_id: u32,
        data: Bytes,
    },
    /// Type `0x0B`: exactly 4 payload bytes, a big-endian `u32` channel id.
    AgentClose {
        channel_id: u32,
    },
    /// Type `0x0C`: encoded with an empty payload.
    OpenForward,
    /// Type `0x0D`: the payload is the URL as UTF-8.
    OpenUrl {
        url: String,
    },
    /// Type `0x16`: exactly 2 payload bytes, a big-endian `u16` version.
    Hello {
        version: u16,
    },
    /// Type `0x24`: exactly 2 payload bytes, a big-endian `u16` version.
    HelloAck {
        version: u16,
    },
    /// Type `0x10`: the payload is the name as UTF-8.
    NewSession {
        name: String,
    },
    /// Type `0x11`: the payload is the session string as UTF-8.
    Attach {
        session: String,
    },
    /// Type `0x15`: the payload is the session string as UTF-8.
    Tail {
        session: String,
    },
    /// Type `0x12`: encoded with an empty payload.
    ListSessions,
    /// Type `0x13`: the payload is the session string as UTF-8.
    KillSession {
        session: String,
    },
    /// Type `0x14`: encoded with an empty payload.
    KillServer,
    /// Type `0x20`: the payload is the id as UTF-8.
    SessionCreated {
        id: String,
    },
    /// Type `0x21`: UTF-8 text, one tab-separated `SessionEntry` record per line.
    ///
    /// Lines that do not have exactly seven fields are skipped when decoding.
    SessionInfo {
        sessions: Vec<SessionEntry>,
    },
    /// Type `0x22`: encoded with an empty payload.
    Ok,
    /// Type `0x23`: the payload is the message as UTF-8.
    Error {
        message: String,
    },
}
126
127impl Frame {
128 pub fn expect_from(result: Option<Result<Frame, io::Error>>) -> anyhow::Result<Frame> {
131 match result {
132 Some(Ok(frame)) => Ok(frame),
133 Some(Err(e)) => Err(anyhow::anyhow!("daemon protocol error: {e}")),
134 None => Err(anyhow::anyhow!("daemon closed connection")),
135 }
136 }
137}
138
/// Stateless codec that converts between [`Frame`] values and their wire form:
/// a 1-byte type tag, a 4-byte big-endian payload length, then the payload.
pub struct FrameCodec;
140
141fn encode_empty(dst: &mut BytesMut, ty: u8) {
142 dst.put_u8(ty);
143 dst.put_u32(0);
144}
145
146fn encode_str(dst: &mut BytesMut, ty: u8, s: &str) {
147 dst.put_u8(ty);
148 dst.put_u32(s.len() as u32);
149 dst.extend_from_slice(s.as_bytes());
150}
151
152fn decode_string(payload: BytesMut) -> Result<String, io::Error> {
153 String::from_utf8(payload.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
154}
155
156impl Decoder for FrameCodec {
157 type Item = Frame;
158 type Error = io::Error;
159
160 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, io::Error> {
161 if src.len() < HEADER_LEN {
162 return Ok(None);
163 }
164
165 let frame_type = src[0];
166 let payload_len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
167
168 if payload_len > MAX_FRAME_SIZE {
169 return Err(io::Error::new(
170 io::ErrorKind::InvalidData,
171 format!("frame payload too large: {payload_len} bytes (max {MAX_FRAME_SIZE})"),
172 ));
173 }
174
175 if src.len() < HEADER_LEN + payload_len {
176 src.reserve(HEADER_LEN + payload_len - src.len());
177 return Ok(None);
178 }
179
180 src.advance(HEADER_LEN);
181 let payload = src.split_to(payload_len);
182
183 match frame_type {
184 TYPE_DATA => Ok(Some(Frame::Data(payload.freeze()))),
185 TYPE_RESIZE => {
186 if payload.len() != 4 {
187 return Err(io::Error::new(
188 io::ErrorKind::InvalidData,
189 "resize frame must be 4 bytes",
190 ));
191 }
192 let cols = u16::from_be_bytes([payload[0], payload[1]]);
193 let rows = u16::from_be_bytes([payload[2], payload[3]]);
194 Ok(Some(Frame::Resize { cols, rows }))
195 }
196 TYPE_EXIT => {
197 if payload.len() != 4 {
198 return Err(io::Error::new(
199 io::ErrorKind::InvalidData,
200 "exit frame must be 4 bytes",
201 ));
202 }
203 let code = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
204 Ok(Some(Frame::Exit { code }))
205 }
206 TYPE_DETACHED => Ok(Some(Frame::Detached)),
207 TYPE_PING => Ok(Some(Frame::Ping)),
208 TYPE_PONG => Ok(Some(Frame::Pong)),
209 TYPE_ENV => {
210 let text = String::from_utf8(payload.to_vec())
211 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
212 let vars = if text.is_empty() {
213 Vec::new()
214 } else {
215 text.lines()
216 .filter_map(|line| {
217 let (k, v) = line.split_once('=')?;
218 Some((k.to_string(), v.to_string()))
219 })
220 .collect()
221 };
222 Ok(Some(Frame::Env { vars }))
223 }
224 TYPE_AGENT_FORWARD => Ok(Some(Frame::AgentForward)),
225 TYPE_AGENT_OPEN => {
226 if payload.len() != 4 {
227 return Err(io::Error::new(
228 io::ErrorKind::InvalidData,
229 "agent open frame must be 4 bytes",
230 ));
231 }
232 let channel_id =
233 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
234 Ok(Some(Frame::AgentOpen { channel_id }))
235 }
236 TYPE_AGENT_DATA => {
237 if payload.len() < 4 {
238 return Err(io::Error::new(
239 io::ErrorKind::InvalidData,
240 "agent data frame must be at least 4 bytes",
241 ));
242 }
243 let channel_id =
244 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
245 let data = payload.freeze().slice(4..);
246 Ok(Some(Frame::AgentData { channel_id, data }))
247 }
248 TYPE_AGENT_CLOSE => {
249 if payload.len() != 4 {
250 return Err(io::Error::new(
251 io::ErrorKind::InvalidData,
252 "agent close frame must be 4 bytes",
253 ));
254 }
255 let channel_id =
256 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
257 Ok(Some(Frame::AgentClose { channel_id }))
258 }
259 TYPE_OPEN_FORWARD => Ok(Some(Frame::OpenForward)),
260 TYPE_OPEN_URL => Ok(Some(Frame::OpenUrl { url: decode_string(payload)? })),
261 TYPE_HELLO => {
262 if payload.len() != 2 {
263 return Err(io::Error::new(
264 io::ErrorKind::InvalidData,
265 "hello frame must be 2 bytes",
266 ));
267 }
268 let version = u16::from_be_bytes([payload[0], payload[1]]);
269 Ok(Some(Frame::Hello { version }))
270 }
271 TYPE_HELLO_ACK => {
272 if payload.len() != 2 {
273 return Err(io::Error::new(
274 io::ErrorKind::InvalidData,
275 "hello ack frame must be 2 bytes",
276 ));
277 }
278 let version = u16::from_be_bytes([payload[0], payload[1]]);
279 Ok(Some(Frame::HelloAck { version }))
280 }
281 TYPE_NEW_SESSION => Ok(Some(Frame::NewSession { name: decode_string(payload)? })),
282 TYPE_ATTACH => Ok(Some(Frame::Attach { session: decode_string(payload)? })),
283 TYPE_TAIL => Ok(Some(Frame::Tail { session: decode_string(payload)? })),
284 TYPE_LIST_SESSIONS => Ok(Some(Frame::ListSessions)),
285 TYPE_KILL_SESSION => Ok(Some(Frame::KillSession { session: decode_string(payload)? })),
286 TYPE_KILL_SERVER => Ok(Some(Frame::KillServer)),
287 TYPE_SESSION_CREATED => Ok(Some(Frame::SessionCreated { id: decode_string(payload)? })),
288 TYPE_SESSION_INFO => {
289 let text = String::from_utf8(payload.to_vec())
290 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
291 let sessions = if text.is_empty() {
292 Vec::new()
293 } else {
294 text.lines()
295 .filter_map(|line| {
296 let parts: Vec<&str> = line.split('\t').collect();
297 if parts.len() == 7 {
298 Some(SessionEntry {
299 id: parts[0].to_string(),
300 name: parts[1].to_string(),
301 pty_path: parts[2].to_string(),
302 shell_pid: parts[3].parse().unwrap_or(0),
303 created_at: parts[4].parse().unwrap_or(0),
304 attached: parts[5] == "1",
305 last_heartbeat: parts[6].parse().unwrap_or(0),
306 })
307 } else {
308 None
309 }
310 })
311 .collect()
312 };
313 Ok(Some(Frame::SessionInfo { sessions }))
314 }
315 TYPE_OK => Ok(Some(Frame::Ok)),
316 TYPE_ERROR => Ok(Some(Frame::Error { message: decode_string(payload)? })),
317 _ => Err(io::Error::new(
318 io::ErrorKind::InvalidData,
319 format!("unknown frame type: 0x{frame_type:02x}"),
320 )),
321 }
322 }
323}
324
325impl Encoder<Frame> for FrameCodec {
326 type Error = io::Error;
327
328 fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), io::Error> {
329 match frame {
330 Frame::Data(data) => {
331 dst.put_u8(TYPE_DATA);
332 dst.put_u32(data.len() as u32);
333 dst.extend_from_slice(&data);
334 }
335 Frame::Resize { cols, rows } => {
336 dst.put_u8(TYPE_RESIZE);
337 dst.put_u32(4);
338 dst.put_u16(cols);
339 dst.put_u16(rows);
340 }
341 Frame::Exit { code } => {
342 dst.put_u8(TYPE_EXIT);
343 dst.put_u32(4);
344 dst.put_i32(code);
345 }
346 Frame::Detached => encode_empty(dst, TYPE_DETACHED),
347 Frame::Ping => encode_empty(dst, TYPE_PING),
348 Frame::Pong => encode_empty(dst, TYPE_PONG),
349 Frame::Env { vars } => {
350 let text: String = vars
353 .iter()
354 .map(|(k, v)| {
355 let k = k.replace('\n', "");
356 let v = v.replace('\n', "");
357 format!("{k}={v}")
358 })
359 .collect::<Vec<_>>()
360 .join("\n");
361 dst.put_u8(TYPE_ENV);
362 dst.put_u32(text.len() as u32);
363 dst.extend_from_slice(text.as_bytes());
364 }
365 Frame::AgentForward => encode_empty(dst, TYPE_AGENT_FORWARD),
366 Frame::AgentOpen { channel_id } => {
367 dst.put_u8(TYPE_AGENT_OPEN);
368 dst.put_u32(4);
369 dst.put_u32(channel_id);
370 }
371 Frame::AgentData { channel_id, data } => {
372 dst.put_u8(TYPE_AGENT_DATA);
373 dst.put_u32(4 + data.len() as u32);
374 dst.put_u32(channel_id);
375 dst.extend_from_slice(&data);
376 }
377 Frame::AgentClose { channel_id } => {
378 dst.put_u8(TYPE_AGENT_CLOSE);
379 dst.put_u32(4);
380 dst.put_u32(channel_id);
381 }
382 Frame::OpenForward => encode_empty(dst, TYPE_OPEN_FORWARD),
383 Frame::OpenUrl { url } => encode_str(dst, TYPE_OPEN_URL, &url),
384 Frame::Hello { version } => {
385 dst.put_u8(TYPE_HELLO);
386 dst.put_u32(2);
387 dst.put_u16(version);
388 }
389 Frame::HelloAck { version } => {
390 dst.put_u8(TYPE_HELLO_ACK);
391 dst.put_u32(2);
392 dst.put_u16(version);
393 }
394 Frame::NewSession { name } => encode_str(dst, TYPE_NEW_SESSION, &name),
395 Frame::Attach { session } => encode_str(dst, TYPE_ATTACH, &session),
396 Frame::Tail { session } => encode_str(dst, TYPE_TAIL, &session),
397 Frame::ListSessions => encode_empty(dst, TYPE_LIST_SESSIONS),
398 Frame::KillSession { session } => encode_str(dst, TYPE_KILL_SESSION, &session),
399 Frame::KillServer => encode_empty(dst, TYPE_KILL_SERVER),
400 Frame::SessionCreated { id } => encode_str(dst, TYPE_SESSION_CREATED, &id),
401 Frame::SessionInfo { sessions } => {
402 let text: String = sessions
403 .iter()
404 .map(|e| {
405 let safe_pty = e.pty_path.replace(['\t', '\n'], " ");
406 format!(
407 "{}\t{}\t{}\t{}\t{}\t{}\t{}",
408 e.id,
409 e.name,
410 safe_pty,
411 e.shell_pid,
412 e.created_at,
413 if e.attached { "1" } else { "0" },
414 e.last_heartbeat
415 )
416 })
417 .collect::<Vec<_>>()
418 .join("\n");
419 dst.put_u8(TYPE_SESSION_INFO);
420 dst.put_u32(text.len() as u32);
421 dst.extend_from_slice(text.as_bytes());
422 }
423 Frame::Ok => encode_empty(dst, TYPE_OK),
424 Frame::Error { message } => encode_str(dst, TYPE_ERROR, &message),
425 }
426 Ok(())
427 }
428}