1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use std::io;
3use tokio_util::codec::{Decoder, Encoder};
4
// Frame type codes: the first byte of every frame header.

// Stream and forwarding frames (0x01–0x0D).
const TYPE_DATA: u8 = 0x01;
const TYPE_RESIZE: u8 = 0x02;
const TYPE_EXIT: u8 = 0x03;
const TYPE_DETACHED: u8 = 0x04;
const TYPE_PING: u8 = 0x05;
const TYPE_PONG: u8 = 0x06;
const TYPE_ENV: u8 = 0x07;
const TYPE_AGENT_FORWARD: u8 = 0x08;
const TYPE_AGENT_OPEN: u8 = 0x09;
const TYPE_AGENT_DATA: u8 = 0x0a;
const TYPE_AGENT_CLOSE: u8 = 0x0b;
const TYPE_OPEN_FORWARD: u8 = 0x0c;
const TYPE_OPEN_URL: u8 = 0x0d;

// Session-management frames (0x10–0x15).
const TYPE_NEW_SESSION: u8 = 0x10;
const TYPE_ATTACH: u8 = 0x11;
const TYPE_LIST_SESSIONS: u8 = 0x12;
const TYPE_KILL_SESSION: u8 = 0x13;
const TYPE_KILL_SERVER: u8 = 0x14;
const TYPE_TAIL: u8 = 0x15;

// Result / status frames (0x20–0x23).
const TYPE_SESSION_CREATED: u8 = 0x20;
const TYPE_SESSION_INFO: u8 = 0x21;
const TYPE_OK: u8 = 0x22;
const TYPE_ERROR: u8 = 0x23;
28
/// Bytes in every frame header: one type byte plus a big-endian `u32` payload length.
const HEADER_LEN: usize = 5;

/// Largest payload length (1 MiB) the decoder accepts; a header declaring more
/// is rejected with `InvalidData`.
const MAX_FRAME_SIZE: usize = 1024 * 1024;

/// One session record as carried in a `Frame::SessionInfo` payload.
///
/// On the wire each record is one line of seven tab-separated fields in
/// declaration order, with `attached` written as `1` or `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionEntry {
    /// Session identifier.
    pub id: String,
    /// Session name.
    pub name: String,
    /// Path of the session's PTY as reported by the sender.
    pub pty_path: String,
    /// PID of the session's shell process (decoded as `0` if unparsable).
    pub shell_pid: u32,
    /// Creation timestamp (decoded as `0` if unparsable); NOTE(review): unit
    /// not visible here — presumably Unix seconds, confirm.
    pub created_at: u64,
    /// Attached flag; encoded as `1`/`0`, decoded as `true` only for `1`.
    pub attached: bool,
    /// Last heartbeat timestamp (decoded as `0` if unparsable); presumably the
    /// same unit as `created_at` — confirm.
    pub last_heartbeat: u64,
}
43
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum Frame {
46 Data(Bytes),
47 Resize {
48 cols: u16,
49 rows: u16,
50 },
51 Exit {
52 code: i32,
53 },
54 Detached,
56 Ping,
58 Pong,
60 Env {
62 vars: Vec<(String, String)>,
63 },
64 AgentForward,
66 AgentOpen {
68 channel_id: u32,
69 },
70 AgentData {
72 channel_id: u32,
73 data: Bytes,
74 },
75 AgentClose {
77 channel_id: u32,
78 },
79 OpenForward,
81 OpenUrl {
83 url: String,
84 },
85 NewSession {
87 name: String,
88 },
89 Attach {
90 session: String,
91 },
92 Tail {
94 session: String,
95 },
96 ListSessions,
97 KillSession {
98 session: String,
99 },
100 KillServer,
101 SessionCreated {
103 id: String,
104 },
105 SessionInfo {
106 sessions: Vec<SessionEntry>,
107 },
108 Ok,
109 Error {
110 message: String,
111 },
112}
113
114impl Frame {
115 pub fn expect_from(result: Option<Result<Frame, io::Error>>) -> anyhow::Result<Frame> {
118 match result {
119 Some(Ok(frame)) => Ok(frame),
120 Some(Err(e)) => Err(anyhow::anyhow!("daemon protocol error: {e}")),
121 None => Err(anyhow::anyhow!("daemon closed connection")),
122 }
123 }
124}
125
/// Stateless `tokio_util` codec for the [`Frame`] protocol.
///
/// Wire format of every frame: one type byte, the payload length as a
/// big-endian `u32`, then the payload bytes.
#[derive(Debug, Default, Clone, Copy)]
pub struct FrameCodec;
127
128fn encode_empty(dst: &mut BytesMut, ty: u8) {
129 dst.put_u8(ty);
130 dst.put_u32(0);
131}
132
133fn encode_str(dst: &mut BytesMut, ty: u8, s: &str) {
134 dst.put_u8(ty);
135 dst.put_u32(s.len() as u32);
136 dst.extend_from_slice(s.as_bytes());
137}
138
139fn decode_string(payload: BytesMut) -> Result<String, io::Error> {
140 String::from_utf8(payload.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
141}
142
143impl Decoder for FrameCodec {
144 type Item = Frame;
145 type Error = io::Error;
146
147 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, io::Error> {
148 if src.len() < HEADER_LEN {
149 return Ok(None);
150 }
151
152 let frame_type = src[0];
153 let payload_len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
154
155 if payload_len > MAX_FRAME_SIZE {
156 return Err(io::Error::new(
157 io::ErrorKind::InvalidData,
158 format!("frame payload too large: {payload_len} bytes (max {MAX_FRAME_SIZE})"),
159 ));
160 }
161
162 if src.len() < HEADER_LEN + payload_len {
163 src.reserve(HEADER_LEN + payload_len - src.len());
164 return Ok(None);
165 }
166
167 src.advance(HEADER_LEN);
168 let payload = src.split_to(payload_len);
169
170 match frame_type {
171 TYPE_DATA => Ok(Some(Frame::Data(payload.freeze()))),
172 TYPE_RESIZE => {
173 if payload.len() != 4 {
174 return Err(io::Error::new(
175 io::ErrorKind::InvalidData,
176 "resize frame must be 4 bytes",
177 ));
178 }
179 let cols = u16::from_be_bytes([payload[0], payload[1]]);
180 let rows = u16::from_be_bytes([payload[2], payload[3]]);
181 Ok(Some(Frame::Resize { cols, rows }))
182 }
183 TYPE_EXIT => {
184 if payload.len() != 4 {
185 return Err(io::Error::new(
186 io::ErrorKind::InvalidData,
187 "exit frame must be 4 bytes",
188 ));
189 }
190 let code = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
191 Ok(Some(Frame::Exit { code }))
192 }
193 TYPE_DETACHED => Ok(Some(Frame::Detached)),
194 TYPE_PING => Ok(Some(Frame::Ping)),
195 TYPE_PONG => Ok(Some(Frame::Pong)),
196 TYPE_ENV => {
197 let text = String::from_utf8(payload.to_vec())
198 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
199 let vars = if text.is_empty() {
200 Vec::new()
201 } else {
202 text.lines()
203 .filter_map(|line| {
204 let (k, v) = line.split_once('=')?;
205 Some((k.to_string(), v.to_string()))
206 })
207 .collect()
208 };
209 Ok(Some(Frame::Env { vars }))
210 }
211 TYPE_AGENT_FORWARD => Ok(Some(Frame::AgentForward)),
212 TYPE_AGENT_OPEN => {
213 if payload.len() != 4 {
214 return Err(io::Error::new(
215 io::ErrorKind::InvalidData,
216 "agent open frame must be 4 bytes",
217 ));
218 }
219 let channel_id =
220 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
221 Ok(Some(Frame::AgentOpen { channel_id }))
222 }
223 TYPE_AGENT_DATA => {
224 if payload.len() < 4 {
225 return Err(io::Error::new(
226 io::ErrorKind::InvalidData,
227 "agent data frame must be at least 4 bytes",
228 ));
229 }
230 let channel_id =
231 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
232 let data = payload.freeze().slice(4..);
233 Ok(Some(Frame::AgentData { channel_id, data }))
234 }
235 TYPE_AGENT_CLOSE => {
236 if payload.len() != 4 {
237 return Err(io::Error::new(
238 io::ErrorKind::InvalidData,
239 "agent close frame must be 4 bytes",
240 ));
241 }
242 let channel_id =
243 u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
244 Ok(Some(Frame::AgentClose { channel_id }))
245 }
246 TYPE_OPEN_FORWARD => Ok(Some(Frame::OpenForward)),
247 TYPE_OPEN_URL => Ok(Some(Frame::OpenUrl { url: decode_string(payload)? })),
248 TYPE_NEW_SESSION => Ok(Some(Frame::NewSession { name: decode_string(payload)? })),
249 TYPE_ATTACH => Ok(Some(Frame::Attach { session: decode_string(payload)? })),
250 TYPE_TAIL => Ok(Some(Frame::Tail { session: decode_string(payload)? })),
251 TYPE_LIST_SESSIONS => Ok(Some(Frame::ListSessions)),
252 TYPE_KILL_SESSION => Ok(Some(Frame::KillSession { session: decode_string(payload)? })),
253 TYPE_KILL_SERVER => Ok(Some(Frame::KillServer)),
254 TYPE_SESSION_CREATED => Ok(Some(Frame::SessionCreated { id: decode_string(payload)? })),
255 TYPE_SESSION_INFO => {
256 let text = String::from_utf8(payload.to_vec())
257 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
258 let sessions = if text.is_empty() {
259 Vec::new()
260 } else {
261 text.lines()
262 .filter_map(|line| {
263 let parts: Vec<&str> = line.split('\t').collect();
264 if parts.len() == 7 {
265 Some(SessionEntry {
266 id: parts[0].to_string(),
267 name: parts[1].to_string(),
268 pty_path: parts[2].to_string(),
269 shell_pid: parts[3].parse().unwrap_or(0),
270 created_at: parts[4].parse().unwrap_or(0),
271 attached: parts[5] == "1",
272 last_heartbeat: parts[6].parse().unwrap_or(0),
273 })
274 } else {
275 None
276 }
277 })
278 .collect()
279 };
280 Ok(Some(Frame::SessionInfo { sessions }))
281 }
282 TYPE_OK => Ok(Some(Frame::Ok)),
283 TYPE_ERROR => Ok(Some(Frame::Error { message: decode_string(payload)? })),
284 _ => Err(io::Error::new(
285 io::ErrorKind::InvalidData,
286 format!("unknown frame type: 0x{frame_type:02x}"),
287 )),
288 }
289 }
290}
291
292impl Encoder<Frame> for FrameCodec {
293 type Error = io::Error;
294
295 fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), io::Error> {
296 match frame {
297 Frame::Data(data) => {
298 dst.put_u8(TYPE_DATA);
299 dst.put_u32(data.len() as u32);
300 dst.extend_from_slice(&data);
301 }
302 Frame::Resize { cols, rows } => {
303 dst.put_u8(TYPE_RESIZE);
304 dst.put_u32(4);
305 dst.put_u16(cols);
306 dst.put_u16(rows);
307 }
308 Frame::Exit { code } => {
309 dst.put_u8(TYPE_EXIT);
310 dst.put_u32(4);
311 dst.put_i32(code);
312 }
313 Frame::Detached => encode_empty(dst, TYPE_DETACHED),
314 Frame::Ping => encode_empty(dst, TYPE_PING),
315 Frame::Pong => encode_empty(dst, TYPE_PONG),
316 Frame::Env { vars } => {
317 let text: String = vars
320 .iter()
321 .map(|(k, v)| {
322 let k = k.replace('\n', "");
323 let v = v.replace('\n', "");
324 format!("{k}={v}")
325 })
326 .collect::<Vec<_>>()
327 .join("\n");
328 dst.put_u8(TYPE_ENV);
329 dst.put_u32(text.len() as u32);
330 dst.extend_from_slice(text.as_bytes());
331 }
332 Frame::AgentForward => encode_empty(dst, TYPE_AGENT_FORWARD),
333 Frame::AgentOpen { channel_id } => {
334 dst.put_u8(TYPE_AGENT_OPEN);
335 dst.put_u32(4);
336 dst.put_u32(channel_id);
337 }
338 Frame::AgentData { channel_id, data } => {
339 dst.put_u8(TYPE_AGENT_DATA);
340 dst.put_u32(4 + data.len() as u32);
341 dst.put_u32(channel_id);
342 dst.extend_from_slice(&data);
343 }
344 Frame::AgentClose { channel_id } => {
345 dst.put_u8(TYPE_AGENT_CLOSE);
346 dst.put_u32(4);
347 dst.put_u32(channel_id);
348 }
349 Frame::OpenForward => encode_empty(dst, TYPE_OPEN_FORWARD),
350 Frame::OpenUrl { url } => encode_str(dst, TYPE_OPEN_URL, &url),
351 Frame::NewSession { name } => encode_str(dst, TYPE_NEW_SESSION, &name),
352 Frame::Attach { session } => encode_str(dst, TYPE_ATTACH, &session),
353 Frame::Tail { session } => encode_str(dst, TYPE_TAIL, &session),
354 Frame::ListSessions => encode_empty(dst, TYPE_LIST_SESSIONS),
355 Frame::KillSession { session } => encode_str(dst, TYPE_KILL_SESSION, &session),
356 Frame::KillServer => encode_empty(dst, TYPE_KILL_SERVER),
357 Frame::SessionCreated { id } => encode_str(dst, TYPE_SESSION_CREATED, &id),
358 Frame::SessionInfo { sessions } => {
359 let text: String = sessions
360 .iter()
361 .map(|e| {
362 format!(
363 "{}\t{}\t{}\t{}\t{}\t{}\t{}",
364 e.id,
365 e.name,
366 e.pty_path,
367 e.shell_pid,
368 e.created_at,
369 if e.attached { "1" } else { "0" },
370 e.last_heartbeat
371 )
372 })
373 .collect::<Vec<_>>()
374 .join("\n");
375 dst.put_u8(TYPE_SESSION_INFO);
376 dst.put_u32(text.len() as u32);
377 dst.extend_from_slice(text.as_bytes());
378 }
379 Frame::Ok => encode_empty(dst, TYPE_OK),
380 Frame::Error { message } => encode_str(dst, TYPE_ERROR, &message),
381 }
382 Ok(())
383 }
384}