1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol generation implemented by this build.
///
/// `Message::new` and `Message::with_payload` stamp this value into `Message::v`.
pub const PROTOCOL_VERSION: u8 = 5;

/// Flag bit 0: marks a terminal message (returned by `MessageType::flags` for
/// e.g. `CoreError`, `ExecExited`, `FsResponse`, `TcpClosed`).
pub const FLAG_TERMINAL: u8 = 1 << 0;

/// Flag bit 1: marks a session-opening message (returned by `MessageType::flags`
/// for `ExecRequest`, `FsRequest` and `TcpConnect`).
pub const FLAG_SESSION_START: u8 = 1 << 1;

/// Flag bit 2: marks the shutdown message (returned by `MessageType::flags` for
/// `Shutdown`).
pub const FLAG_SHUTDOWN: u8 = 1 << 2;

/// Size in bytes of the fixed frame header.
///
/// NOTE(review): 5 bytes matches a `u32` id plus a `u8` flags byte, which are the
/// two `Message` fields that serde skips. That suggests both travel in this
/// header; confirm against the framing code.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message: envelope metadata plus an opaque payload.
///
/// Only `v`, `t` and `p` go through serde. `id` and `flags` are
/// `#[serde(skip)]`, so serde does not write them and a deserialized `Message`
/// gets their `Default` values (`0`).
/// NOTE(review): they are presumably carried in the fixed frame header
/// (`FRAME_HEADER_SIZE`) and filled in by the framing layer; confirm against the
/// codec.
///
/// serde's derive serializes fields in declaration order, and the single-letter
/// field names become the serialized keys. Do not reorder or rename fields
/// without considering wire compatibility.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol generation of the sender. The constructors set it to
    /// `PROTOCOL_VERSION`.
    pub v: u8,

    /// Message type. It is serialized as its dotted wire name (see the manual
    /// `Serialize` impl for `MessageType`).
    pub t: MessageType,

    /// Caller-supplied identifier (the `id` argument of the constructors).
    /// Skipped by serde.
    /// NOTE(review): presumably correlates the messages of one session; confirm.
    #[serde(skip)]
    pub id: u32,

    /// `FLAG_*` bits. The constructors derive them from `t` via
    /// `MessageType::flags`. Skipped by serde.
    #[serde(skip)]
    pub flags: u8,

    /// Payload bytes, serialized as a byte string via `serde_bytes` rather than
    /// as a sequence of integers. `with_payload` and `payload` treat them as CBOR.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
77
/// Every message kind the protocol understands.
///
/// On the wire each variant is the dotted string in its `strum(serialize = ...)`
/// attribute. `MessageType::as_str` produces that string, `MessageType::from_wire_str`
/// parses it, and the manual `Serialize`/`Deserialize` impls use this string form.
///
/// Per-variant notes below list the frame flags returned by `MessageType::flags`
/// and the generation returned by `MessageType::min_protocol_version`.
///
/// Declaration order is the order `strum::EnumIter` yields variants, so append
/// new variants instead of reordering existing ones.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    strum::IntoStaticStr,
    strum::EnumString,
    strum::EnumIter,
)]
pub enum MessageType {
    /// No flags. Generation 1.
    #[strum(serialize = "core.ready")]
    Ready,

    /// No flags. Generation 1.
    #[strum(serialize = "core.init.resolved")]
    InitResolved,

    /// No flags. Generation 1.
    #[strum(serialize = "core.init.ack")]
    InitAck,

    /// `FLAG_SHUTDOWN`. Generation 1.
    #[strum(serialize = "core.shutdown")]
    Shutdown,

    /// No flags. Generation 1.
    #[strum(serialize = "core.relay.client.disconnected")]
    RelayClientDisconnected,

    /// No flags. Generation 1.
    #[strum(serialize = "core.clock.sync")]
    ClockSync,

    /// `FLAG_TERMINAL`. Generation 5, the newest type in this enum.
    #[strum(serialize = "core.error")]
    CoreError,

    /// `FLAG_SESSION_START`. Generation 1.
    #[strum(serialize = "core.exec.request")]
    ExecRequest,

    /// No flags. Generation 1.
    #[strum(serialize = "core.exec.started")]
    ExecStarted,

    /// No flags. Generation 1.
    #[strum(serialize = "core.exec.stdin")]
    ExecStdin,

    /// No flags (notably *not* terminal, unlike `ExecFailed`). Generation 1.
    #[strum(serialize = "core.exec.stdin.error")]
    ExecStdinError,

    /// No flags. Generation 1.
    #[strum(serialize = "core.exec.stdout")]
    ExecStdout,

    /// No flags. Generation 1.
    #[strum(serialize = "core.exec.stderr")]
    ExecStderr,

    /// `FLAG_TERMINAL`. Generation 1.
    #[strum(serialize = "core.exec.exited")]
    ExecExited,

    /// `FLAG_TERMINAL`. Generation 1.
    #[strum(serialize = "core.exec.failed")]
    ExecFailed,

    /// No flags. Generation 1.
    #[strum(serialize = "core.exec.resize")]
    ExecResize,

    /// No flags. Generation 1.
    #[strum(serialize = "core.exec.signal")]
    ExecSignal,

    /// `FLAG_SESSION_START`. Generation 2.
    #[strum(serialize = "core.fs.request")]
    FsRequest,

    /// `FLAG_TERMINAL`. Generation 2.
    #[strum(serialize = "core.fs.response")]
    FsResponse,

    /// No flags. Generation 2.
    #[strum(serialize = "core.fs.data")]
    FsData,

    /// `FLAG_SESSION_START`. Generation 4.
    #[strum(serialize = "core.tcp.connect")]
    TcpConnect,

    /// No flags. Generation 4.
    #[strum(serialize = "core.tcp.connected")]
    TcpConnected,

    /// No flags. Generation 4.
    #[strum(serialize = "core.tcp.data")]
    TcpData,

    /// No flags. Generation 4.
    #[strum(serialize = "core.tcp.eof")]
    TcpEof,

    /// No flags (the terminal counterpart is `TcpClosed`). Generation 4.
    #[strum(serialize = "core.tcp.close")]
    TcpClose,

    /// `FLAG_TERMINAL`. Generation 4.
    #[strum(serialize = "core.tcp.closed")]
    TcpClosed,

    /// `FLAG_TERMINAL`. Generation 4.
    #[strum(serialize = "core.tcp.failed")]
    TcpFailed,
}
210
211impl Message {
216 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
218 let flags = t.flags();
219 Self {
220 v: PROTOCOL_VERSION,
221 t,
222 id,
223 flags,
224 p,
225 }
226 }
227
228 pub fn with_payload<T: Serialize>(
230 t: MessageType,
231 id: u32,
232 payload: &T,
233 ) -> ProtocolResult<Self> {
234 let mut p = Vec::new();
235 ciborium::into_writer(payload, &mut p)?;
236 let flags = t.flags();
237 Ok(Self {
238 v: PROTOCOL_VERSION,
239 t,
240 id,
241 flags,
242 p,
243 })
244 }
245
246 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
248 Ok(ciborium::from_reader(&self.p[..])?)
249 }
250}
251
252impl MessageType {
253 pub fn flags(&self) -> u8 {
255 match self {
256 Self::CoreError
257 | Self::ExecExited
258 | Self::ExecFailed
259 | Self::FsResponse
260 | Self::TcpClosed
261 | Self::TcpFailed => FLAG_TERMINAL,
262 Self::ExecRequest | Self::FsRequest | Self::TcpConnect => FLAG_SESSION_START,
263 Self::Shutdown => FLAG_SHUTDOWN,
264 _ => 0,
265 }
266 }
267
268 pub fn min_protocol_version(&self) -> u8 {
288 match self {
289 Self::Ready
290 | Self::InitResolved
291 | Self::InitAck
292 | Self::Shutdown
293 | Self::RelayClientDisconnected
294 | Self::ClockSync
295 | Self::ExecRequest
296 | Self::ExecStarted
297 | Self::ExecStdin
298 | Self::ExecStdinError
299 | Self::ExecStdout
300 | Self::ExecStderr
301 | Self::ExecExited
302 | Self::ExecFailed
303 | Self::ExecResize
304 | Self::ExecSignal => 1,
305 Self::FsRequest | Self::FsResponse | Self::FsData => 2,
306 Self::CoreError => 5,
307 Self::TcpConnect
308 | Self::TcpConnected
309 | Self::TcpData
310 | Self::TcpEof
311 | Self::TcpClose
312 | Self::TcpClosed
313 | Self::TcpFailed => 4,
314 }
315 }
316
317 pub fn is_available_at(&self, peer_generation: u8) -> bool {
326 self.min_protocol_version() <= peer_generation
327 }
328
329 pub fn as_str(&self) -> &'static str {
334 (*self).into()
335 }
336
337 pub fn from_wire_str(s: &str) -> Option<Self> {
340 s.parse().ok()
341 }
342}
343
344impl Serialize for MessageType {
349 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
350 where
351 S: serde::Serializer,
352 {
353 serializer.serialize_str(self.as_str())
354 }
355}
356
357impl<'de> Deserialize<'de> for MessageType {
358 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
359 where
360 D: serde::Deserializer<'de>,
361 {
362 let s = String::deserialize(deserializer)?;
363 Self::from_wire_str(&s)
364 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use strum::IntoEnumIterator;

    #[test]
    fn test_message_type_roundtrip() {
        let types = [
            (MessageType::Ready, "core.ready"),
            (MessageType::InitResolved, "core.init.resolved"),
            (MessageType::InitAck, "core.init.ack"),
            (MessageType::Shutdown, "core.shutdown"),
            (
                MessageType::RelayClientDisconnected,
                "core.relay.client.disconnected",
            ),
            (MessageType::ClockSync, "core.clock.sync"),
            (MessageType::CoreError, "core.error"),
            (MessageType::ExecRequest, "core.exec.request"),
            (MessageType::ExecStarted, "core.exec.started"),
            (MessageType::ExecStdin, "core.exec.stdin"),
            (MessageType::ExecStdinError, "core.exec.stdin.error"),
            (MessageType::ExecStdout, "core.exec.stdout"),
            (MessageType::ExecStderr, "core.exec.stderr"),
            (MessageType::ExecExited, "core.exec.exited"),
            (MessageType::ExecFailed, "core.exec.failed"),
            (MessageType::ExecResize, "core.exec.resize"),
            (MessageType::ExecSignal, "core.exec.signal"),
            (MessageType::FsRequest, "core.fs.request"),
            (MessageType::FsResponse, "core.fs.response"),
            (MessageType::FsData, "core.fs.data"),
            (MessageType::TcpConnect, "core.tcp.connect"),
            (MessageType::TcpConnected, "core.tcp.connected"),
            (MessageType::TcpData, "core.tcp.data"),
            (MessageType::TcpEof, "core.tcp.eof"),
            (MessageType::TcpClose, "core.tcp.close"),
            (MessageType::TcpClosed, "core.tcp.closed"),
            (MessageType::TcpFailed, "core.tcp.failed"),
        ];

        // Guard against this hand-written table silently falling behind the enum.
        assert_eq!(
            types.len(),
            MessageType::iter().count(),
            "roundtrip table must list every MessageType variant"
        );

        for (mt, expected_str) in &types {
            assert_eq!(mt.as_str(), *expected_str);
            assert_eq!(MessageType::from_wire_str(expected_str).unwrap(), *mt);
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        // Iterate every variant so newly added ones are covered automatically.
        for mt in MessageType::iter() {
            let mut buf = Vec::new();
            ciborium::into_writer(&mt, &mut buf).unwrap();
            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(decoded, mt);
        }
    }

    #[test]
    fn test_wire_names_are_unique() {
        let mut seen = HashSet::new();
        for mt in MessageType::iter() {
            assert!(seen.insert(mt.as_str()), "duplicate wire name for {mt:?}");
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());
        assert!(MessageType::from_wire_str("").is_none());
    }

    #[test]
    fn test_deserialize_unknown_message_type_fails() {
        let mut buf = Vec::new();
        ciborium::into_writer(&"core.unknown", &mut buf).unwrap();
        let decoded: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        assert!(decoded.is_err());
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_message_serde_skips_id_and_flags() {
        // `id` and `flags` are `#[serde(skip)]`; only v/t/p survive a serde roundtrip.
        let msg = Message::new(MessageType::ExecRequest, 42, vec![0xde, 0xad]);
        let mut buf = Vec::new();
        ciborium::into_writer(&msg, &mut buf).unwrap();
        let decoded: Message = ciborium::from_reader(&buf[..]).unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecRequest);
        assert_eq!(decoded.p, vec![0xde, 0xad]);
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }

    #[test]
    fn test_message_type_flags() {
        assert_eq!(MessageType::ExecExited.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::FsResponse.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::TcpClosed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::TcpFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::CoreError.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::FsRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::TcpConnect.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::Ready.flags(), 0);
        assert_eq!(MessageType::InitResolved.flags(), 0);
        assert_eq!(MessageType::InitAck.flags(), 0);
        assert_eq!(MessageType::Shutdown.flags(), FLAG_SHUTDOWN);
        assert_eq!(MessageType::RelayClientDisconnected.flags(), 0);
        assert_eq!(MessageType::ClockSync.flags(), 0);
        assert_eq!(MessageType::ExecStarted.flags(), 0);
        assert_eq!(MessageType::ExecStdin.flags(), 0);
        assert_eq!(MessageType::ExecStdinError.flags(), 0);
        assert_eq!(MessageType::ExecStdout.flags(), 0);
        assert_eq!(MessageType::ExecStderr.flags(), 0);
        assert_eq!(MessageType::ExecResize.flags(), 0);
        assert_eq!(MessageType::ExecSignal.flags(), 0);
        assert_eq!(MessageType::FsData.flags(), 0);
        assert_eq!(MessageType::TcpConnected.flags(), 0);
        assert_eq!(MessageType::TcpData.flags(), 0);
        assert_eq!(MessageType::TcpEof.flags(), 0);
        assert_eq!(MessageType::TcpClose.flags(), 0);
    }

    #[test]
    fn test_flag_bits_are_distinct_and_single() {
        assert_eq!(FLAG_TERMINAL & FLAG_SESSION_START, 0);
        assert_eq!(FLAG_TERMINAL & FLAG_SHUTDOWN, 0);
        assert_eq!(FLAG_SESSION_START & FLAG_SHUTDOWN, 0);
        for mt in MessageType::iter() {
            assert!(
                mt.flags().count_ones() <= 1,
                "{mt:?} sets more than one flag bit"
            );
        }
    }

    #[test]
    fn test_additive_fields_keep_old_and_new_compatible() {
        use serde::{Deserialize, Serialize};

        #[derive(Serialize, Deserialize)]
        struct Old {
            a: u32,
            b: u32,
        }

        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct New {
            a: u32,
            b: u32,
            #[serde(default)]
            c: u32,
        }

        let mut new_bytes = Vec::new();
        ciborium::into_writer(&New { a: 1, b: 2, c: 3 }, &mut new_bytes).unwrap();
        let as_old: Old = ciborium::from_reader(&new_bytes[..]).unwrap();
        assert_eq!((as_old.a, as_old.b), (1, 2));

        let mut old_bytes = Vec::new();
        ciborium::into_writer(&Old { a: 1, b: 2 }, &mut old_bytes).unwrap();
        let as_new: New = ciborium::from_reader(&old_bytes[..]).unwrap();
        assert_eq!(as_new, New { a: 1, b: 2, c: 0 });
    }

    #[test]
    fn test_is_available_at() {
        assert!(MessageType::ExecRequest.is_available_at(1));
        assert!(MessageType::ExecRequest.is_available_at(2));
        assert!(MessageType::ExecRequest.is_available_at(PROTOCOL_VERSION));
        assert!(!MessageType::FsRequest.is_available_at(1));
        assert!(MessageType::FsRequest.is_available_at(2));
        assert!(MessageType::FsRequest.is_available_at(PROTOCOL_VERSION));
        assert!(!MessageType::TcpConnect.is_available_at(3));
        assert!(MessageType::TcpConnect.is_available_at(4));
        assert!(!MessageType::CoreError.is_available_at(4));
        assert!(MessageType::CoreError.is_available_at(5));
    }

    #[test]
    fn test_min_protocol_version_per_type() {
        let baseline = [
            MessageType::Ready,
            MessageType::InitResolved,
            MessageType::InitAck,
            MessageType::Shutdown,
            MessageType::RelayClientDisconnected,
            MessageType::ClockSync,
            MessageType::ExecRequest,
            MessageType::ExecStarted,
            MessageType::ExecStdin,
            MessageType::ExecStdinError,
            MessageType::ExecStdout,
            MessageType::ExecStderr,
            MessageType::ExecExited,
            MessageType::ExecFailed,
            MessageType::ExecResize,
            MessageType::ExecSignal,
        ];
        for mt in &baseline {
            assert_eq!(mt.min_protocol_version(), 1, "{mt:?} should be v1 baseline");
        }

        for mt in [
            MessageType::FsRequest,
            MessageType::FsResponse,
            MessageType::FsData,
        ] {
            assert_eq!(mt.min_protocol_version(), 2, "{mt:?} should require gen 2");
        }

        for mt in [
            MessageType::TcpConnect,
            MessageType::TcpConnected,
            MessageType::TcpData,
            MessageType::TcpEof,
            MessageType::TcpClose,
            MessageType::TcpClosed,
            MessageType::TcpFailed,
        ] {
            assert_eq!(mt.min_protocol_version(), 4, "{mt:?} should require gen 4");
        }

        assert_eq!(MessageType::CoreError.min_protocol_version(), 5);

        // No type may require a generation this build cannot speak.
        for mt in MessageType::iter() {
            let min = mt.min_protocol_version();
            assert!(
                (1..=PROTOCOL_VERSION).contains(&min),
                "{mt:?} has out-of-range min generation {min}"
            );
        }
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.flags, FLAG_SESSION_START);
        assert_eq!(msg.v, PROTOCOL_VERSION);

        let msg = Message::new(MessageType::ExecStdout, 1, Vec::new());
        assert_eq!(msg.flags, 0);

        let msg = Message::new(MessageType::Shutdown, 9, vec![1, 2, 3]);
        assert_eq!(msg.flags, FLAG_SHUTDOWN);
        assert_eq!(msg.id, 9);
        assert_eq!(msg.p, vec![1, 2, 3]);
    }
}