1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol generation stamped into [`Message::v`] by [`Message::new`] and
/// [`Message::with_payload`].
///
/// The `core.tcp.*` message types report this generation (4) as their
/// `min_protocol_version`.
pub const PROTOCOL_VERSION: u8 = 0x04;

/// Flag bit for terminal message types. Per [`MessageType::flags`] these are `ExecExited`,
/// `ExecFailed`, `FsResponse`, `TcpClosed` and `TcpFailed`.
pub const FLAG_TERMINAL: u8 = 0x01;

/// Flag bit for session-starting message types. Per [`MessageType::flags`] these are
/// `ExecRequest`, `FsRequest` and `TcpConnect`.
pub const FLAG_SESSION_START: u8 = 0x02;

/// Flag bit carried only by [`MessageType::Shutdown`] (see [`MessageType::flags`]).
pub const FLAG_SHUTDOWN: u8 = 0x04;

/// Size of the frame header in bytes.
///
/// NOTE(review): presumably the 4-byte [`Message::id`] plus the 1-byte [`Message::flags`].
/// Both are `#[serde(skip)]`ped from the encoded body. Confirm against the framing code.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message: a typed envelope around an opaque payload.
///
/// Only `v`, `t` and `p` are part of the serde-encoded body. `id` and `flags` are
/// `#[serde(skip)]`, so a freshly deserialized `Message` has them at their `Default` (`0`).
/// NOTE(review): presumably they travel in the fixed-size frame header
/// (`FRAME_HEADER_SIZE`) instead. Confirm against the framing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version; the constructors set it to `PROTOCOL_VERSION`.
    pub v: u8,

    /// Message type, encoded as its dotted wire string (e.g. `"core.exec.request"`) by the
    /// hand-written `Serialize` impl for `MessageType`.
    pub t: MessageType,

    /// Message identifier. Not serialized into the body; it comes back as `0` after
    /// deserialization.
    #[serde(skip)]
    pub id: u32,

    /// `FLAG_*` bits; the constructors derive them from `t` via `MessageType::flags`.
    /// Not serialized into the body; it comes back as `0` after deserialization.
    #[serde(skip)]
    pub flags: u8,

    /// Raw payload bytes. `serde_bytes` encodes them as a byte string rather than a
    /// sequence of integers. `Message::with_payload` fills this with CBOR.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
77
/// Every message type known to this protocol implementation.
///
/// Each variant's wire name is its `#[strum(serialize = "...")]` attribute.
/// [`MessageType::as_str`] and [`MessageType::from_wire_str`] convert between variant and
/// name, and the hand-written `Serialize`/`Deserialize` impls put that string on the wire.
/// [`MessageType::min_protocol_version`] reports the generation that supports each type.
/// [`MessageType::flags`] reports the `FLAG_*` bits each type carries.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    strum::IntoStaticStr,
    strum::EnumString,
    strum::EnumIter,
)]
pub enum MessageType {
    // --- Core control messages (min protocol generation 1) ---
    #[strum(serialize = "core.ready")]
    Ready,

    #[strum(serialize = "core.init.resolved")]
    InitResolved,

    #[strum(serialize = "core.init.ack")]
    InitAck,

    // Carries FLAG_SHUTDOWN.
    #[strum(serialize = "core.shutdown")]
    Shutdown,

    #[strum(serialize = "core.relay.client.disconnected")]
    RelayClientDisconnected,

    #[strum(serialize = "core.clock.sync")]
    ClockSync,

    // --- core.exec.* (min protocol generation 1) ---
    // Carries FLAG_SESSION_START.
    #[strum(serialize = "core.exec.request")]
    ExecRequest,

    #[strum(serialize = "core.exec.started")]
    ExecStarted,

    #[strum(serialize = "core.exec.stdin")]
    ExecStdin,

    #[strum(serialize = "core.exec.stdin.error")]
    ExecStdinError,

    #[strum(serialize = "core.exec.stdout")]
    ExecStdout,

    #[strum(serialize = "core.exec.stderr")]
    ExecStderr,

    // Carries FLAG_TERMINAL.
    #[strum(serialize = "core.exec.exited")]
    ExecExited,

    // Carries FLAG_TERMINAL.
    #[strum(serialize = "core.exec.failed")]
    ExecFailed,

    #[strum(serialize = "core.exec.resize")]
    ExecResize,

    #[strum(serialize = "core.exec.signal")]
    ExecSignal,

    // --- core.fs.* (min protocol generation 2) ---
    // Carries FLAG_SESSION_START.
    #[strum(serialize = "core.fs.request")]
    FsRequest,

    // Carries FLAG_TERMINAL.
    #[strum(serialize = "core.fs.response")]
    FsResponse,

    #[strum(serialize = "core.fs.data")]
    FsData,

    // --- core.tcp.* (min protocol generation 4) ---
    // Carries FLAG_SESSION_START.
    #[strum(serialize = "core.tcp.connect")]
    TcpConnect,

    #[strum(serialize = "core.tcp.connected")]
    TcpConnected,

    #[strum(serialize = "core.tcp.data")]
    TcpData,

    #[strum(serialize = "core.tcp.eof")]
    TcpEof,

    // No flags (unlike TcpClosed below).
    #[strum(serialize = "core.tcp.close")]
    TcpClose,

    // Carries FLAG_TERMINAL.
    #[strum(serialize = "core.tcp.closed")]
    TcpClosed,

    // Carries FLAG_TERMINAL.
    #[strum(serialize = "core.tcp.failed")]
    TcpFailed,
}
206
207impl Message {
212 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
214 let flags = t.flags();
215 Self {
216 v: PROTOCOL_VERSION,
217 t,
218 id,
219 flags,
220 p,
221 }
222 }
223
224 pub fn with_payload<T: Serialize>(
226 t: MessageType,
227 id: u32,
228 payload: &T,
229 ) -> ProtocolResult<Self> {
230 let mut p = Vec::new();
231 ciborium::into_writer(payload, &mut p)?;
232 let flags = t.flags();
233 Ok(Self {
234 v: PROTOCOL_VERSION,
235 t,
236 id,
237 flags,
238 p,
239 })
240 }
241
242 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
244 Ok(ciborium::from_reader(&self.p[..])?)
245 }
246}
247
248impl MessageType {
249 pub fn flags(&self) -> u8 {
251 match self {
252 Self::ExecExited
253 | Self::ExecFailed
254 | Self::FsResponse
255 | Self::TcpClosed
256 | Self::TcpFailed => FLAG_TERMINAL,
257 Self::ExecRequest | Self::FsRequest | Self::TcpConnect => FLAG_SESSION_START,
258 Self::Shutdown => FLAG_SHUTDOWN,
259 _ => 0,
260 }
261 }
262
263 pub fn min_protocol_version(&self) -> u8 {
281 match self {
282 Self::Ready
283 | Self::InitResolved
284 | Self::InitAck
285 | Self::Shutdown
286 | Self::RelayClientDisconnected
287 | Self::ClockSync
288 | Self::ExecRequest
289 | Self::ExecStarted
290 | Self::ExecStdin
291 | Self::ExecStdinError
292 | Self::ExecStdout
293 | Self::ExecStderr
294 | Self::ExecExited
295 | Self::ExecFailed
296 | Self::ExecResize
297 | Self::ExecSignal => 1,
298 Self::FsRequest | Self::FsResponse | Self::FsData => 2,
299 Self::TcpConnect
300 | Self::TcpConnected
301 | Self::TcpData
302 | Self::TcpEof
303 | Self::TcpClose
304 | Self::TcpClosed
305 | Self::TcpFailed => 4,
306 }
307 }
308
309 pub fn is_available_at(&self, peer_generation: u8) -> bool {
318 self.min_protocol_version() <= peer_generation
319 }
320
321 pub fn as_str(&self) -> &'static str {
326 (*self).into()
327 }
328
329 pub fn from_wire_str(s: &str) -> Option<Self> {
332 s.parse().ok()
333 }
334}
335
336impl Serialize for MessageType {
341 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
342 where
343 S: serde::Serializer,
344 {
345 serializer.serialize_str(self.as_str())
346 }
347}
348
349impl<'de> Deserialize<'de> for MessageType {
350 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
351 where
352 D: serde::Deserializer<'de>,
353 {
354 let s = String::deserialize(deserializer)?;
355 Self::from_wire_str(&s)
356 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `MessageType::iter()` (from `#[derive(strum::EnumIter)]`) into scope so the
    // tests below cover every variant, including ones added later.
    use strum::IntoEnumIterator;

    #[test]
    fn test_message_type_roundtrip() {
        let types = [
            (MessageType::Ready, "core.ready"),
            (MessageType::InitResolved, "core.init.resolved"),
            (MessageType::InitAck, "core.init.ack"),
            (MessageType::Shutdown, "core.shutdown"),
            (
                MessageType::RelayClientDisconnected,
                "core.relay.client.disconnected",
            ),
            (MessageType::ClockSync, "core.clock.sync"),
            (MessageType::ExecRequest, "core.exec.request"),
            (MessageType::ExecStarted, "core.exec.started"),
            (MessageType::ExecStdin, "core.exec.stdin"),
            (MessageType::ExecStdinError, "core.exec.stdin.error"),
            (MessageType::ExecStdout, "core.exec.stdout"),
            (MessageType::ExecStderr, "core.exec.stderr"),
            (MessageType::ExecExited, "core.exec.exited"),
            (MessageType::ExecFailed, "core.exec.failed"),
            (MessageType::ExecResize, "core.exec.resize"),
            (MessageType::ExecSignal, "core.exec.signal"),
            (MessageType::FsRequest, "core.fs.request"),
            (MessageType::FsResponse, "core.fs.response"),
            (MessageType::FsData, "core.fs.data"),
            (MessageType::TcpConnect, "core.tcp.connect"),
            (MessageType::TcpConnected, "core.tcp.connected"),
            (MessageType::TcpData, "core.tcp.data"),
            (MessageType::TcpEof, "core.tcp.eof"),
            (MessageType::TcpClose, "core.tcp.close"),
            (MessageType::TcpClosed, "core.tcp.closed"),
            (MessageType::TcpFailed, "core.tcp.failed"),
        ];

        for (mt, expected_str) in &types {
            assert_eq!(mt.as_str(), *expected_str);
            assert_eq!(MessageType::from_wire_str(expected_str).unwrap(), *mt);
        }

        // Guard the table itself: every variant must have its wire name pinned above.
        assert_eq!(types.len(), MessageType::iter().count());
        for mt in MessageType::iter() {
            assert!(
                types.iter().any(|(listed, _)| *listed == mt),
                "{mt:?} is missing from the wire-name table"
            );
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for mt in MessageType::iter() {
            let mut buf = Vec::new();
            ciborium::into_writer(&mt, &mut buf).unwrap();
            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(decoded, mt);

            // On the wire a type is exactly its text wire name.
            let mut as_text = Vec::new();
            ciborium::into_writer(&mt.as_str(), &mut as_text).unwrap();
            assert_eq!(buf, as_text, "{mt:?} should encode as its wire string");
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());

        // The serde path must reject unknown names too, with a descriptive error.
        let mut buf = Vec::new();
        ciborium::into_writer(&"core.unknown", &mut buf).unwrap();
        let result: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        let err = result.expect_err("unknown wire name must not deserialize");
        assert!(format!("{err:?}").contains("unknown message type: core.unknown"));
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_message_serde_roundtrip_skips_id_and_flags() {
        let msg = Message::new(MessageType::ExecRequest, 9, vec![1, 2, 3]);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let mut buf = Vec::new();
        ciborium::into_writer(&msg, &mut buf).unwrap();
        let decoded: Message = ciborium::from_reader(&buf[..]).unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecRequest);
        assert_eq!(decoded.p, vec![1, 2, 3]);
        // `id` and `flags` are `#[serde(skip)]`, so they come back as their defaults.
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }

    #[test]
    fn test_message_type_flags() {
        assert_eq!(MessageType::ExecExited.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::FsResponse.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::TcpClosed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::TcpFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::FsRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::TcpConnect.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::Ready.flags(), 0);
        assert_eq!(MessageType::InitResolved.flags(), 0);
        assert_eq!(MessageType::InitAck.flags(), 0);
        assert_eq!(MessageType::Shutdown.flags(), FLAG_SHUTDOWN);
        assert_eq!(MessageType::RelayClientDisconnected.flags(), 0);
        assert_eq!(MessageType::ClockSync.flags(), 0);
        assert_eq!(MessageType::ExecStarted.flags(), 0);
        assert_eq!(MessageType::ExecStdin.flags(), 0);
        assert_eq!(MessageType::ExecStdinError.flags(), 0);
        assert_eq!(MessageType::ExecStdout.flags(), 0);
        assert_eq!(MessageType::ExecStderr.flags(), 0);
        assert_eq!(MessageType::ExecResize.flags(), 0);
        assert_eq!(MessageType::ExecSignal.flags(), 0);
        assert_eq!(MessageType::FsData.flags(), 0);
        assert_eq!(MessageType::TcpConnected.flags(), 0);
        assert_eq!(MessageType::TcpData.flags(), 0);
        assert_eq!(MessageType::TcpEof.flags(), 0);
        assert_eq!(MessageType::TcpClose.flags(), 0);
    }

    #[test]
    fn test_additive_fields_keep_old_and_new_compatible() {
        use serde::{Deserialize, Serialize};

        #[derive(Serialize, Deserialize)]
        struct Old {
            a: u32,
            b: u32,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct New {
            a: u32,
            b: u32,
            #[serde(default)]
            c: u32,
        }

        let mut new_bytes = Vec::new();
        ciborium::into_writer(&New { a: 1, b: 2, c: 3 }, &mut new_bytes).unwrap();
        let as_old: Old = ciborium::from_reader(&new_bytes[..]).unwrap();
        assert_eq!((as_old.a, as_old.b), (1, 2));

        let mut old_bytes = Vec::new();
        ciborium::into_writer(&Old { a: 1, b: 2 }, &mut old_bytes).unwrap();
        let as_new: New = ciborium::from_reader(&old_bytes[..]).unwrap();
        assert_eq!(as_new, New { a: 1, b: 2, c: 0 });
    }

    #[test]
    fn test_is_available_at() {
        assert!(MessageType::ExecRequest.is_available_at(1));
        assert!(MessageType::ExecRequest.is_available_at(2));
        assert!(MessageType::ExecRequest.is_available_at(PROTOCOL_VERSION));
        assert!(!MessageType::FsRequest.is_available_at(1));
        assert!(MessageType::FsRequest.is_available_at(2));
        assert!(MessageType::FsRequest.is_available_at(PROTOCOL_VERSION));
        assert!(!MessageType::TcpConnect.is_available_at(3));
        assert!(MessageType::TcpConnect.is_available_at(4));
        assert!(MessageType::TcpConnect.is_available_at(PROTOCOL_VERSION));
    }

    #[test]
    fn test_min_protocol_version_per_type() {
        let baseline = [
            MessageType::Ready,
            MessageType::InitResolved,
            MessageType::InitAck,
            MessageType::Shutdown,
            MessageType::RelayClientDisconnected,
            MessageType::ClockSync,
            MessageType::ExecRequest,
            MessageType::ExecStarted,
            MessageType::ExecStdin,
            MessageType::ExecStdinError,
            MessageType::ExecStdout,
            MessageType::ExecStderr,
            MessageType::ExecExited,
            MessageType::ExecFailed,
            MessageType::ExecResize,
            MessageType::ExecSignal,
        ];
        for mt in &baseline {
            assert_eq!(mt.min_protocol_version(), 1, "{mt:?} should be v1 baseline");
        }

        for mt in [
            MessageType::FsRequest,
            MessageType::FsResponse,
            MessageType::FsData,
        ] {
            assert_eq!(mt.min_protocol_version(), 2, "{mt:?} should require gen 2");
        }

        for mt in [
            MessageType::TcpConnect,
            MessageType::TcpConnected,
            MessageType::TcpData,
            MessageType::TcpEof,
            MessageType::TcpClose,
            MessageType::TcpClosed,
            MessageType::TcpFailed,
        ] {
            assert_eq!(mt.min_protocol_version(), 4, "{mt:?} should require gen 4");
        }

        // No type may require a generation newer than the one this build speaks.
        for mt in MessageType::iter() {
            assert!(
                mt.min_protocol_version() <= PROTOCOL_VERSION,
                "{mt:?} requires a newer protocol than PROTOCOL_VERSION"
            );
        }
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.id, 1);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let msg = Message::new(MessageType::ExecStdout, 1, Vec::new());
        assert_eq!(msg.flags, 0);
    }
}
590}