1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol version stamped into the `v` field of every message built by
/// `Message::new` / `Message::with_payload`.
pub const PROTOCOL_VERSION: u8 = 1;

/// Terminal-message flag (bit 0).
///
/// `MessageType::flags` sets it for `ExecExited`, `ExecFailed` and `FsResponse`.
pub const FLAG_TERMINAL: u8 = 1 << 0;

/// Session-start flag (bit 1).
///
/// `MessageType::flags` sets it for `ExecRequest` and `FsRequest`.
pub const FLAG_SESSION_START: u8 = 1 << 1;

/// Shutdown flag (bit 2).
///
/// `MessageType::flags` sets it only for `Shutdown`.
pub const FLAG_SHUTDOWN: u8 = 1 << 2;

/// Size in bytes of the fixed frame header.
///
/// NOTE(review): 5 equals a 4-byte `u32` plus a 1-byte `u8`, which matches the two `Message`
/// fields (`id`, `flags`) that serde skips. The framing code is not visible here, so confirm
/// the actual layout there.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message.
///
/// Only `v`, `t` and `p` take part in serde (de)serialization. `id` and `flags` are marked
/// `#[serde(skip)]`, so they are not written and come back as `0` when deserialized.
/// NOTE(review): presumably they travel in the fixed-size frame header instead
/// (see `FRAME_HEADER_SIZE`); confirm against the framing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version; the constructors set it to `PROTOCOL_VERSION`.
    pub v: u8,

    /// Message type, serialized as its dotted wire name (e.g. `core.exec.stdout`).
    pub t: MessageType,

    /// Message identifier. Not serialized (`#[serde(skip)]`).
    /// NOTE(review): presumably correlates the messages of one session — confirm.
    #[serde(skip)]
    pub id: u32,

    /// Flag bits (`FLAG_*`); the constructors derive them from `t` via `MessageType::flags`.
    /// Not serialized (`#[serde(skip)]`).
    #[serde(skip)]
    pub flags: u8,

    /// Raw payload bytes, serialized as a byte string via `serde_bytes` rather than as a
    /// sequence of integers. `Message::with_payload` fills it with CBOR.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
71
/// Kind of a protocol message.
///
/// On the wire each variant is encoded as a dotted string (see `as_str` / `from_wire_str` and
/// the hand-written `Serialize` / `Deserialize` impls), not as a numeric index.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Wire name `core.ready`. Carries no flags.
    Ready,

    /// Wire name `core.shutdown`. Carries `FLAG_SHUTDOWN`.
    Shutdown,

    /// Wire name `core.exec.request`. Carries `FLAG_SESSION_START`.
    ExecRequest,

    /// Wire name `core.exec.started`. Carries no flags.
    ExecStarted,

    /// Wire name `core.exec.stdin`. Carries no flags.
    ExecStdin,

    /// Wire name `core.exec.stdout`. Carries no flags.
    ExecStdout,

    /// Wire name `core.exec.stderr`. Carries no flags.
    ExecStderr,

    /// Wire name `core.exec.exited`. Carries `FLAG_TERMINAL`.
    ExecExited,

    /// Wire name `core.exec.failed`. Carries `FLAG_TERMINAL`.
    ExecFailed,

    /// Wire name `core.exec.resize`. Carries no flags.
    ExecResize,

    /// Wire name `core.exec.signal`. Carries no flags.
    ExecSignal,

    /// Wire name `core.fs.request`. Carries `FLAG_SESSION_START`.
    FsRequest,

    /// Wire name `core.fs.response`. Carries `FLAG_TERMINAL`.
    FsResponse,

    /// Wire name `core.fs.data`. Carries no flags.
    FsData,
}
119
120impl Message {
125 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
127 let flags = t.flags();
128 Self {
129 v: PROTOCOL_VERSION,
130 t,
131 id,
132 flags,
133 p,
134 }
135 }
136
137 pub fn with_payload<T: Serialize>(
139 t: MessageType,
140 id: u32,
141 payload: &T,
142 ) -> ProtocolResult<Self> {
143 let mut p = Vec::new();
144 ciborium::into_writer(payload, &mut p)?;
145 let flags = t.flags();
146 Ok(Self {
147 v: PROTOCOL_VERSION,
148 t,
149 id,
150 flags,
151 p,
152 })
153 }
154
155 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
157 Ok(ciborium::from_reader(&self.p[..])?)
158 }
159}
160
161impl MessageType {
162 pub fn flags(&self) -> u8 {
164 match self {
165 Self::ExecExited | Self::ExecFailed | Self::FsResponse => FLAG_TERMINAL,
166 Self::ExecRequest | Self::FsRequest => FLAG_SESSION_START,
167 Self::Shutdown => FLAG_SHUTDOWN,
168 _ => 0,
169 }
170 }
171
172 pub fn as_str(&self) -> &'static str {
174 match self {
175 Self::Ready => "core.ready",
176 Self::Shutdown => "core.shutdown",
177 Self::ExecRequest => "core.exec.request",
178 Self::ExecStarted => "core.exec.started",
179 Self::ExecStdin => "core.exec.stdin",
180 Self::ExecStdout => "core.exec.stdout",
181 Self::ExecStderr => "core.exec.stderr",
182 Self::ExecExited => "core.exec.exited",
183 Self::ExecFailed => "core.exec.failed",
184 Self::ExecResize => "core.exec.resize",
185 Self::ExecSignal => "core.exec.signal",
186 Self::FsRequest => "core.fs.request",
187 Self::FsResponse => "core.fs.response",
188 Self::FsData => "core.fs.data",
189 }
190 }
191
192 pub fn from_wire_str(s: &str) -> Option<Self> {
194 match s {
195 "core.ready" => Some(Self::Ready),
196 "core.shutdown" => Some(Self::Shutdown),
197 "core.exec.request" => Some(Self::ExecRequest),
198 "core.exec.started" => Some(Self::ExecStarted),
199 "core.exec.stdin" => Some(Self::ExecStdin),
200 "core.exec.stdout" => Some(Self::ExecStdout),
201 "core.exec.stderr" => Some(Self::ExecStderr),
202 "core.exec.exited" => Some(Self::ExecExited),
203 "core.exec.failed" => Some(Self::ExecFailed),
204 "core.exec.resize" => Some(Self::ExecResize),
205 "core.exec.signal" => Some(Self::ExecSignal),
206 "core.fs.request" => Some(Self::FsRequest),
207 "core.fs.response" => Some(Self::FsResponse),
208 "core.fs.data" => Some(Self::FsData),
209 _ => None,
210 }
211 }
212}
213
214impl Serialize for MessageType {
219 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
220 where
221 S: serde::Serializer,
222 {
223 serializer.serialize_str(self.as_str())
224 }
225}
226
227impl<'de> Deserialize<'de> for MessageType {
228 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229 where
230 D: serde::Deserializer<'de>,
231 {
232 let s = String::deserialize(deserializer)?;
233 Self::from_wire_str(&s)
234 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected wire name, shared by the table-driven tests so
    /// the list is maintained in one place.
    const ALL_TYPES: [(MessageType, &str); 14] = [
        (MessageType::Ready, "core.ready"),
        (MessageType::Shutdown, "core.shutdown"),
        (MessageType::ExecRequest, "core.exec.request"),
        (MessageType::ExecStarted, "core.exec.started"),
        (MessageType::ExecStdin, "core.exec.stdin"),
        (MessageType::ExecStdout, "core.exec.stdout"),
        (MessageType::ExecStderr, "core.exec.stderr"),
        (MessageType::ExecExited, "core.exec.exited"),
        (MessageType::ExecFailed, "core.exec.failed"),
        (MessageType::ExecResize, "core.exec.resize"),
        (MessageType::ExecSignal, "core.exec.signal"),
        (MessageType::FsRequest, "core.fs.request"),
        (MessageType::FsResponse, "core.fs.response"),
        (MessageType::FsData, "core.fs.data"),
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for (mt, expected_str) in &ALL_TYPES {
            assert_eq!(mt.as_str(), *expected_str);
            assert_eq!(MessageType::from_wire_str(expected_str).unwrap(), *mt);
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for (mt, expected_str) in &ALL_TYPES {
            let mut buf = Vec::new();
            ciborium::into_writer(mt, &mut buf).unwrap();

            // The type must travel as a CBOR text string holding its wire name.
            let as_text: String = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(as_text, *expected_str);

            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(&decoded, mt);
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());
    }

    #[test]
    fn test_unknown_message_type_rejected_by_serde() {
        let mut buf = Vec::new();
        ciborium::into_writer(&"core.unknown", &mut buf).unwrap();
        let result: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        assert!(result.is_err());
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_message_serde_skips_header_fields() {
        let msg = Message::new(MessageType::ExecExited, 9, vec![1, 2, 3]);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let mut buf = Vec::new();
        ciborium::into_writer(&msg, &mut buf).unwrap();
        let decoded: Message = ciborium::from_reader(&buf[..]).unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecExited);
        assert_eq!(decoded.p, vec![1u8, 2, 3]);
        // `id` and `flags` are `#[serde(skip)]`: they are not encoded and decode to defaults.
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }

    #[test]
    fn test_message_type_flags() {
        assert_eq!(MessageType::ExecExited.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::FsResponse.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::FsRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::Ready.flags(), 0);
        assert_eq!(MessageType::Shutdown.flags(), FLAG_SHUTDOWN);
        assert_eq!(MessageType::ExecStarted.flags(), 0);
        assert_eq!(MessageType::ExecStdin.flags(), 0);
        assert_eq!(MessageType::ExecStdout.flags(), 0);
        assert_eq!(MessageType::ExecStderr.flags(), 0);
        assert_eq!(MessageType::ExecResize.flags(), 0);
        assert_eq!(MessageType::ExecSignal.flags(), 0);
        assert_eq!(MessageType::FsData.flags(), 0);
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let msg = Message::new(MessageType::ExecStdout, 1, Vec::new());
        assert_eq!(msg.flags, 0);
    }
}