1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol version that `Message::new` and `Message::with_payload` stamp into the `v` field.
pub const PROTOCOL_VERSION: u8 = 1;

/// Flag bit 0. `MessageType::flags` returns it for `ExecExited`, `ExecFailed` and `FsResponse`.
pub const FLAG_TERMINAL: u8 = 1 << 0;

/// Flag bit 1. `MessageType::flags` returns it for `ExecRequest` and `FsRequest`.
pub const FLAG_SESSION_START: u8 = 1 << 1;

/// Flag bit 2. `MessageType::flags` returns it for `Shutdown`.
pub const FLAG_SHUTDOWN: u8 = 1 << 2;

/// Size of a frame header in bytes.
///
/// NOTE(review): 5 presumably equals a `u32` id plus a `u8` flags byte. These are the two
/// `Message` fields that serde skips. Confirm against the framing code.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message.
///
/// Only `v`, `t` and `p` go through serde. `id` and `flags` are marked `#[serde(skip)]`, so
/// they are omitted from the serialized body and are set to their `Default` (`0`) when a
/// `Message` is deserialized.
/// NOTE(review): they presumably travel in the fixed-size frame header instead
/// (`FRAME_HEADER_SIZE` = 5 = one `u32` + one `u8`). Confirm against the framing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version. The constructors set it to `PROTOCOL_VERSION`.
    pub v: u8,

    /// Message type, serialized as its dotted wire name (e.g. `"core.exec.stdout"`).
    pub t: MessageType,

    /// Message identifier. Skipped by serde; deserializes as `0`.
    #[serde(skip)]
    pub id: u32,

    /// `FLAG_*` bits. The constructors derive them from `t` via `MessageType::flags`.
    /// Skipped by serde; deserializes as `0`.
    #[serde(skip)]
    pub flags: u8,

    /// Payload bytes.
    ///
    /// `with_payload` fills them with CBOR and `payload` decodes them as CBOR. `new` stores
    /// them verbatim. `serde_bytes` makes serde encode them as a byte string rather than as a
    /// sequence of integers.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
71
/// Kind of a [`Message`].
///
/// Serialized as a dotted wire string rather than a serde enum representation. The string
/// mapping is `MessageType::as_str` / `MessageType::from_wire_str`, and the per-type flag
/// bits come from `MessageType::flags`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Wire name `core.ready`. No flags.
    Ready,

    /// Wire name `core.shutdown`. Carries `FLAG_SHUTDOWN`.
    Shutdown,

    /// Wire name `core.exec.request`. Carries `FLAG_SESSION_START`.
    ExecRequest,

    /// Wire name `core.exec.started`. No flags.
    ExecStarted,

    /// Wire name `core.exec.stdin`. No flags.
    ExecStdin,

    /// Wire name `core.exec.stdin.error`. No flags.
    /// NOTE(review): presumably reports a failure delivering stdin data. Confirm with the
    /// exec module.
    ExecStdinError,

    /// Wire name `core.exec.stdout`. No flags.
    ExecStdout,

    /// Wire name `core.exec.stderr`. No flags.
    ExecStderr,

    /// Wire name `core.exec.exited`. Carries `FLAG_TERMINAL`.
    ExecExited,

    /// Wire name `core.exec.failed`. Carries `FLAG_TERMINAL`.
    /// NOTE(review): presumably sent when the process could not be run at all, as opposed to
    /// `ExecExited`. Confirm.
    ExecFailed,

    /// Wire name `core.exec.resize`. No flags.
    ExecResize,

    /// Wire name `core.exec.signal`. No flags.
    ExecSignal,

    /// Wire name `core.fs.request`. Carries `FLAG_SESSION_START`.
    FsRequest,

    /// Wire name `core.fs.response`. Carries `FLAG_TERMINAL`.
    FsResponse,

    /// Wire name `core.fs.data`. No flags.
    FsData,
}
125
126impl Message {
131 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
133 let flags = t.flags();
134 Self {
135 v: PROTOCOL_VERSION,
136 t,
137 id,
138 flags,
139 p,
140 }
141 }
142
143 pub fn with_payload<T: Serialize>(
145 t: MessageType,
146 id: u32,
147 payload: &T,
148 ) -> ProtocolResult<Self> {
149 let mut p = Vec::new();
150 ciborium::into_writer(payload, &mut p)?;
151 let flags = t.flags();
152 Ok(Self {
153 v: PROTOCOL_VERSION,
154 t,
155 id,
156 flags,
157 p,
158 })
159 }
160
161 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
163 Ok(ciborium::from_reader(&self.p[..])?)
164 }
165}
166
167impl MessageType {
168 pub fn flags(&self) -> u8 {
170 match self {
171 Self::ExecExited | Self::ExecFailed | Self::FsResponse => FLAG_TERMINAL,
172 Self::ExecRequest | Self::FsRequest => FLAG_SESSION_START,
173 Self::Shutdown => FLAG_SHUTDOWN,
174 _ => 0,
175 }
176 }
177
178 pub fn as_str(&self) -> &'static str {
180 match self {
181 Self::Ready => "core.ready",
182 Self::Shutdown => "core.shutdown",
183 Self::ExecRequest => "core.exec.request",
184 Self::ExecStarted => "core.exec.started",
185 Self::ExecStdin => "core.exec.stdin",
186 Self::ExecStdinError => "core.exec.stdin.error",
187 Self::ExecStdout => "core.exec.stdout",
188 Self::ExecStderr => "core.exec.stderr",
189 Self::ExecExited => "core.exec.exited",
190 Self::ExecFailed => "core.exec.failed",
191 Self::ExecResize => "core.exec.resize",
192 Self::ExecSignal => "core.exec.signal",
193 Self::FsRequest => "core.fs.request",
194 Self::FsResponse => "core.fs.response",
195 Self::FsData => "core.fs.data",
196 }
197 }
198
199 pub fn from_wire_str(s: &str) -> Option<Self> {
201 match s {
202 "core.ready" => Some(Self::Ready),
203 "core.shutdown" => Some(Self::Shutdown),
204 "core.exec.request" => Some(Self::ExecRequest),
205 "core.exec.started" => Some(Self::ExecStarted),
206 "core.exec.stdin" => Some(Self::ExecStdin),
207 "core.exec.stdin.error" => Some(Self::ExecStdinError),
208 "core.exec.stdout" => Some(Self::ExecStdout),
209 "core.exec.stderr" => Some(Self::ExecStderr),
210 "core.exec.exited" => Some(Self::ExecExited),
211 "core.exec.failed" => Some(Self::ExecFailed),
212 "core.exec.resize" => Some(Self::ExecResize),
213 "core.exec.signal" => Some(Self::ExecSignal),
214 "core.fs.request" => Some(Self::FsRequest),
215 "core.fs.response" => Some(Self::FsResponse),
216 "core.fs.data" => Some(Self::FsData),
217 _ => None,
218 }
219 }
220}
221
222impl Serialize for MessageType {
227 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
228 where
229 S: serde::Serializer,
230 {
231 serializer.serialize_str(self.as_str())
232 }
233}
234
235impl<'de> Deserialize<'de> for MessageType {
236 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237 where
238 D: serde::Deserializer<'de>,
239 {
240 let s = String::deserialize(deserializer)?;
241 Self::from_wire_str(&s)
242 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected wire name. The roundtrip tests share this table.
    /// Keep it in sync with `MessageType` when adding variants.
    const ALL_TYPES: [(MessageType, &str); 15] = [
        (MessageType::Ready, "core.ready"),
        (MessageType::Shutdown, "core.shutdown"),
        (MessageType::ExecRequest, "core.exec.request"),
        (MessageType::ExecStarted, "core.exec.started"),
        (MessageType::ExecStdin, "core.exec.stdin"),
        (MessageType::ExecStdinError, "core.exec.stdin.error"),
        (MessageType::ExecStdout, "core.exec.stdout"),
        (MessageType::ExecStderr, "core.exec.stderr"),
        (MessageType::ExecExited, "core.exec.exited"),
        (MessageType::ExecFailed, "core.exec.failed"),
        (MessageType::ExecResize, "core.exec.resize"),
        (MessageType::ExecSignal, "core.exec.signal"),
        (MessageType::FsRequest, "core.fs.request"),
        (MessageType::FsResponse, "core.fs.response"),
        (MessageType::FsData, "core.fs.data"),
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for (mt, expected_str) in ALL_TYPES.iter() {
            assert_eq!(mt.as_str(), *expected_str);
            assert_eq!(MessageType::from_wire_str(expected_str).unwrap(), *mt);
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for (mt, wire) in ALL_TYPES.iter() {
            let mut buf = Vec::new();
            ciborium::into_writer(mt, &mut buf).unwrap();

            // The encoding must be exactly the CBOR text string of the wire name.
            let mut expected = Vec::new();
            ciborium::into_writer(*wire, &mut expected).unwrap();
            assert_eq!(buf, expected, "{mt:?}");

            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(&decoded, mt);
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());
        assert!(MessageType::from_wire_str("").is_none());
    }

    #[test]
    fn test_unknown_message_type_deserialize_errors() {
        let mut buf = Vec::new();
        ciborium::into_writer("core.unknown", &mut buf).unwrap();
        let result: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        assert!(result.is_err());
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_message_serde_skips_id_and_flags() {
        let msg = Message::new(MessageType::ExecRequest, 9, vec![1, 2, 3]);
        let mut buf = Vec::new();
        ciborium::into_writer(&msg, &mut buf).unwrap();

        let decoded: Message = ciborium::from_reader(&buf[..]).unwrap();
        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecRequest);
        assert_eq!(decoded.p, vec![1, 2, 3]);
        // `id` and `flags` are `#[serde(skip)]`, so they come back as their defaults.
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }

    #[test]
    fn test_message_type_flags() {
        assert_eq!(MessageType::ExecExited.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::FsResponse.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::FsRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::Ready.flags(), 0);
        assert_eq!(MessageType::Shutdown.flags(), FLAG_SHUTDOWN);
        assert_eq!(MessageType::ExecStarted.flags(), 0);
        assert_eq!(MessageType::ExecStdin.flags(), 0);
        assert_eq!(MessageType::ExecStdinError.flags(), 0);
        assert_eq!(MessageType::ExecStdout.flags(), 0);
        assert_eq!(MessageType::ExecStderr.flags(), 0);
        assert_eq!(MessageType::ExecResize.flags(), 0);
        assert_eq!(MessageType::ExecSignal.flags(), 0);
        assert_eq!(MessageType::FsData.flags(), 0);
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let msg = Message::new(MessageType::ExecStdout, 1, Vec::new());
        assert_eq!(msg.flags, 0);
    }
}