1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol version written into [`Message::v`] by `Message::new` and
/// `Message::with_payload`.
pub const PROTOCOL_VERSION: u8 = 3;

/// Flag bit 0. `MessageType::flags` returns it for `ExecExited`, `ExecFailed`
/// and `FsResponse`.
/// NOTE(review): presumably marks the last message of an exchange — confirm.
pub const FLAG_TERMINAL: u8 = 1 << 0;

/// Flag bit 1. `MessageType::flags` returns it for `ExecRequest` and
/// `FsRequest`.
pub const FLAG_SESSION_START: u8 = 1 << 1;

/// Flag bit 2. `MessageType::flags` returns it for `Shutdown`.
pub const FLAG_SHUTDOWN: u8 = 1 << 2;

/// Size of a frame header in bytes.
/// NOTE(review): not referenced in this module. It presumably covers the
/// 4-byte `id` plus the 1-byte `flags` that `Message` excludes from serde —
/// confirm against the framing codec.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message.
///
/// Only `v`, `t` and `p` take part in the serde representation. `id` and
/// `flags` are marked `#[serde(skip)]`: `Serialize` does not write them, and
/// `Deserialize` fills them with their `Default` value (`0`).
/// NOTE(review): presumably these two travel in the binary frame header
/// (see `FRAME_HEADER_SIZE`) instead — confirm against the framing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version; the constructors in this module set it to
    /// `PROTOCOL_VERSION`.
    pub v: u8,

    /// Message type, serialized as its dotted wire string (for example
    /// `"core.exec.stdout"`).
    pub t: MessageType,

    /// Caller-supplied identifier (see `Message::new`). Not part of the serde
    /// representation.
    #[serde(skip)]
    pub id: u32,

    /// `FLAG_*` bits. The constructors derive them from `t` via
    /// `MessageType::flags`. Not part of the serde representation.
    #[serde(skip)]
    pub flags: u8,

    /// Payload bytes, serialized as a byte string through `serde_bytes`.
    /// `Message::with_payload` fills this with a CBOR encoding.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
71
/// Kind of a [`Message`].
///
/// On the wire each variant is encoded as a dotted string (see
/// `MessageType::as_str` and `MessageType::from_wire_str`). The variant also
/// determines the message's `FLAG_*` bits via `MessageType::flags`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Wire name `core.ready`. No flags.
    Ready,

    /// Wire name `core.init.resolved`. No flags.
    InitResolved,

    /// Wire name `core.init.ack`. No flags.
    InitAck,

    /// Wire name `core.shutdown`. Carries `FLAG_SHUTDOWN`.
    Shutdown,

    /// Wire name `core.relay.client.disconnected`. No flags.
    RelayClientDisconnected,

    /// Wire name `core.clock.sync`. No flags.
    ClockSync,

    /// Wire name `core.exec.request`. Carries `FLAG_SESSION_START`.
    ExecRequest,

    /// Wire name `core.exec.started`. No flags.
    ExecStarted,

    /// Wire name `core.exec.stdin`. No flags.
    ExecStdin,

    /// Wire name `core.exec.stdin.error`. No flags.
    ExecStdinError,

    /// Wire name `core.exec.stdout`. No flags.
    ExecStdout,

    /// Wire name `core.exec.stderr`. No flags.
    ExecStderr,

    /// Wire name `core.exec.exited`. Carries `FLAG_TERMINAL`.
    ExecExited,

    /// Wire name `core.exec.failed`. Carries `FLAG_TERMINAL`.
    ExecFailed,

    /// Wire name `core.exec.resize`. No flags.
    ExecResize,

    /// Wire name `core.exec.signal`. No flags.
    ExecSignal,

    /// Wire name `core.fs.request`. Carries `FLAG_SESSION_START`.
    FsRequest,

    /// Wire name `core.fs.response`. Carries `FLAG_TERMINAL`.
    FsResponse,

    /// Wire name `core.fs.data`. No flags.
    FsData,
}
137
138impl Message {
143 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
145 let flags = t.flags();
146 Self {
147 v: PROTOCOL_VERSION,
148 t,
149 id,
150 flags,
151 p,
152 }
153 }
154
155 pub fn with_payload<T: Serialize>(
157 t: MessageType,
158 id: u32,
159 payload: &T,
160 ) -> ProtocolResult<Self> {
161 let mut p = Vec::new();
162 ciborium::into_writer(payload, &mut p)?;
163 let flags = t.flags();
164 Ok(Self {
165 v: PROTOCOL_VERSION,
166 t,
167 id,
168 flags,
169 p,
170 })
171 }
172
173 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
175 Ok(ciborium::from_reader(&self.p[..])?)
176 }
177}
178
179impl MessageType {
180 pub fn flags(&self) -> u8 {
182 match self {
183 Self::ExecExited | Self::ExecFailed | Self::FsResponse => FLAG_TERMINAL,
184 Self::ExecRequest | Self::FsRequest => FLAG_SESSION_START,
185 Self::Shutdown => FLAG_SHUTDOWN,
186 _ => 0,
187 }
188 }
189
190 pub fn as_str(&self) -> &'static str {
192 match self {
193 Self::Ready => "core.ready",
194 Self::InitResolved => "core.init.resolved",
195 Self::InitAck => "core.init.ack",
196 Self::Shutdown => "core.shutdown",
197 Self::RelayClientDisconnected => "core.relay.client.disconnected",
198 Self::ClockSync => "core.clock.sync",
199 Self::ExecRequest => "core.exec.request",
200 Self::ExecStarted => "core.exec.started",
201 Self::ExecStdin => "core.exec.stdin",
202 Self::ExecStdinError => "core.exec.stdin.error",
203 Self::ExecStdout => "core.exec.stdout",
204 Self::ExecStderr => "core.exec.stderr",
205 Self::ExecExited => "core.exec.exited",
206 Self::ExecFailed => "core.exec.failed",
207 Self::ExecResize => "core.exec.resize",
208 Self::ExecSignal => "core.exec.signal",
209 Self::FsRequest => "core.fs.request",
210 Self::FsResponse => "core.fs.response",
211 Self::FsData => "core.fs.data",
212 }
213 }
214
215 pub fn from_wire_str(s: &str) -> Option<Self> {
217 match s {
218 "core.ready" => Some(Self::Ready),
219 "core.init.resolved" => Some(Self::InitResolved),
220 "core.init.ack" => Some(Self::InitAck),
221 "core.shutdown" => Some(Self::Shutdown),
222 "core.relay.client.disconnected" => Some(Self::RelayClientDisconnected),
223 "core.clock.sync" => Some(Self::ClockSync),
224 "core.exec.request" => Some(Self::ExecRequest),
225 "core.exec.started" => Some(Self::ExecStarted),
226 "core.exec.stdin" => Some(Self::ExecStdin),
227 "core.exec.stdin.error" => Some(Self::ExecStdinError),
228 "core.exec.stdout" => Some(Self::ExecStdout),
229 "core.exec.stderr" => Some(Self::ExecStderr),
230 "core.exec.exited" => Some(Self::ExecExited),
231 "core.exec.failed" => Some(Self::ExecFailed),
232 "core.exec.resize" => Some(Self::ExecResize),
233 "core.exec.signal" => Some(Self::ExecSignal),
234 "core.fs.request" => Some(Self::FsRequest),
235 "core.fs.response" => Some(Self::FsResponse),
236 "core.fs.data" => Some(Self::FsData),
237 _ => None,
238 }
239 }
240}
241
242impl Serialize for MessageType {
247 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
248 where
249 S: serde::Serializer,
250 {
251 serializer.serialize_str(self.as_str())
252 }
253}
254
255impl<'de> Deserialize<'de> for MessageType {
256 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
257 where
258 D: serde::Deserializer<'de>,
259 {
260 let s = String::deserialize(deserializer)?;
261 Self::from_wire_str(&s)
262 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every variant paired with its expected wire string. The tests below
    /// share this table so they all cover the same complete set of types.
    const ALL_TYPES: [(MessageType, &str); 19] = [
        (MessageType::Ready, "core.ready"),
        (MessageType::InitResolved, "core.init.resolved"),
        (MessageType::InitAck, "core.init.ack"),
        (MessageType::Shutdown, "core.shutdown"),
        (
            MessageType::RelayClientDisconnected,
            "core.relay.client.disconnected",
        ),
        (MessageType::ClockSync, "core.clock.sync"),
        (MessageType::ExecRequest, "core.exec.request"),
        (MessageType::ExecStarted, "core.exec.started"),
        (MessageType::ExecStdin, "core.exec.stdin"),
        (MessageType::ExecStdinError, "core.exec.stdin.error"),
        (MessageType::ExecStdout, "core.exec.stdout"),
        (MessageType::ExecStderr, "core.exec.stderr"),
        (MessageType::ExecExited, "core.exec.exited"),
        (MessageType::ExecFailed, "core.exec.failed"),
        (MessageType::ExecResize, "core.exec.resize"),
        (MessageType::ExecSignal, "core.exec.signal"),
        (MessageType::FsRequest, "core.fs.request"),
        (MessageType::FsResponse, "core.fs.response"),
        (MessageType::FsData, "core.fs.data"),
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for (mt, wire) in &ALL_TYPES {
            assert_eq!(mt.as_str(), *wire);
            assert_eq!(MessageType::from_wire_str(wire).unwrap(), *mt);
        }
    }

    #[test]
    fn test_wire_names_and_variants_are_unique() {
        // Catches copy-paste mistakes where two variants share a wire name.
        let names: HashSet<&str> = ALL_TYPES.iter().map(|(_, wire)| *wire).collect();
        assert_eq!(names.len(), ALL_TYPES.len(), "duplicate wire name");

        let types: HashSet<MessageType> = ALL_TYPES.iter().map(|(mt, _)| mt.clone()).collect();
        assert_eq!(types.len(), ALL_TYPES.len(), "duplicate variant in table");
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for (mt, wire) in &ALL_TYPES {
            let mut buf = Vec::new();
            ciborium::into_writer(mt, &mut buf).unwrap();

            // The encoded form must be a plain CBOR text string equal to `as_str`.
            let as_text: String = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(as_text, *wire);

            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(&decoded, mt);
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());
        assert!(MessageType::from_wire_str("").is_none());
        // Wire names are matched case-sensitively.
        assert!(MessageType::from_wire_str("CORE.READY").is_none());

        // Unknown names must also be rejected on the serde path.
        let mut buf = Vec::new();
        ciborium::into_writer(&"core.unknown", &mut buf).unwrap();
        let decoded: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        assert!(decoded.is_err());
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_flag_bits_are_distinct() {
        assert_eq!(FLAG_TERMINAL & FLAG_SESSION_START, 0);
        assert_eq!(FLAG_TERMINAL & FLAG_SHUTDOWN, 0);
        assert_eq!(FLAG_SESSION_START & FLAG_SHUTDOWN, 0);
    }

    #[test]
    fn test_message_type_flags() {
        // One entry per variant, so a new variant without an expectation
        // trips the length check below.
        let expected = [
            (MessageType::Ready, 0),
            (MessageType::InitResolved, 0),
            (MessageType::InitAck, 0),
            (MessageType::Shutdown, FLAG_SHUTDOWN),
            (MessageType::RelayClientDisconnected, 0),
            (MessageType::ClockSync, 0),
            (MessageType::ExecRequest, FLAG_SESSION_START),
            (MessageType::ExecStarted, 0),
            (MessageType::ExecStdin, 0),
            (MessageType::ExecStdinError, 0),
            (MessageType::ExecStdout, 0),
            (MessageType::ExecStderr, 0),
            (MessageType::ExecExited, FLAG_TERMINAL),
            (MessageType::ExecFailed, FLAG_TERMINAL),
            (MessageType::ExecResize, 0),
            (MessageType::ExecSignal, 0),
            (MessageType::FsRequest, FLAG_SESSION_START),
            (MessageType::FsResponse, FLAG_TERMINAL),
            (MessageType::FsData, 0),
        ];
        assert_eq!(expected.len(), ALL_TYPES.len());

        for (mt, flags) in &expected {
            assert_eq!(mt.flags(), *flags, "unexpected flags for {mt:?}");
        }
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let msg = Message::new(MessageType::ExecStdout, 1, vec![1, 2, 3]);
        assert_eq!(msg.flags, 0);
        assert_eq!(msg.p, [1, 2, 3]);
    }

    #[test]
    fn test_message_serde_skips_header_fields() {
        // `id` and `flags` are `#[serde(skip)]`: they are not encoded and
        // decode back to their defaults. The other fields round-trip.
        let msg = Message::new(MessageType::ExecRequest, 9, vec![7, 8]);
        let mut buf = Vec::new();
        ciborium::into_writer(&msg, &mut buf).unwrap();

        let decoded: Message = ciborium::from_reader(&buf[..]).unwrap();
        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecRequest);
        assert_eq!(decoded.p, [7, 8]);
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }
}