1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol version written into the `v` field of every [`Message`] built by
/// [`Message::new`] / [`Message::with_payload`].
pub const PROTOCOL_VERSION: u8 = 1;

/// Flag bit 0. Reported by [`MessageType::flags`] for `ExecExited` and
/// `FsResponse`; presumably marks the last message of a session (confirm
/// against the session/framing code).
pub const FLAG_TERMINAL: u8 = 1 << 0;

/// Flag bit 1. Reported by [`MessageType::flags`] for `ExecRequest` and
/// `FsRequest`; presumably marks the message that opens a session.
pub const FLAG_SESSION_START: u8 = 1 << 1;

/// Flag bit 2. Reported by [`MessageType::flags`] for `Shutdown`.
pub const FLAG_SHUTDOWN: u8 = 1 << 2;

/// Size in bytes of the fixed frame header.
///
/// NOTE(review): the header layout is not defined in this module. 5 bytes
/// would fit the `id: u32` + `flags: u8` fields that [`Message`] excludes
/// from its serde encoding — confirm against the framing code.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message.
///
/// Only `v`, `t` and `p` take part in the serde encoding. `id` and `flags` are
/// `#[serde(skip)]`: they are never written, and they come back as `0`
/// (their `Default`) when a `Message` is deserialized. Presumably they travel
/// in the frame header instead (see `FRAME_HEADER_SIZE`) — confirm against
/// the framing code.
///
/// Field order matters: derived `Serialize` emits fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version; the constructors in this module set it to `PROTOCOL_VERSION`.
    pub v: u8,

    /// Message type, serialized as its wire string (e.g. `"core.exec.request"`).
    pub t: MessageType,

    /// Message identifier (not part of the serde encoding).
    /// NOTE(review): presumably correlates the messages of one session — confirm.
    #[serde(skip)]
    pub id: u32,

    /// `FLAG_*` bits; the constructors derive them from `t` via `MessageType::flags`.
    /// Not part of the serde encoding.
    #[serde(skip)]
    pub flags: u8,

    /// Payload bytes, encoded as a byte string via `serde_bytes`.
    /// `Message::with_payload` fills this with CBOR produced by `ciborium`.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
71
/// The kind of a [`Message`].
///
/// Encoded on the wire as a dotted string (see [`MessageType::as_str`]), not
/// as a variant index, so the order of variants here does not affect the
/// encoding. The enum is fieldless, so it is `Copy` and cheap to pass by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Wire name `core.ready`.
    Ready,

    /// Wire name `core.shutdown`; carries `FLAG_SHUTDOWN`.
    Shutdown,

    /// Wire name `core.exec.request`; carries `FLAG_SESSION_START`.
    ExecRequest,

    /// Wire name `core.exec.started`.
    ExecStarted,

    /// Wire name `core.exec.stdin`.
    ExecStdin,

    /// Wire name `core.exec.stdout`.
    ExecStdout,

    /// Wire name `core.exec.stderr`.
    ExecStderr,

    /// Wire name `core.exec.exited`; carries `FLAG_TERMINAL`.
    ExecExited,

    /// Wire name `core.exec.resize`.
    ExecResize,

    /// Wire name `core.exec.signal`.
    ExecSignal,

    /// Wire name `core.fs.request`; carries `FLAG_SESSION_START`.
    FsRequest,

    /// Wire name `core.fs.response`; carries `FLAG_TERMINAL`.
    FsResponse,

    /// Wire name `core.fs.data`.
    FsData,
}
114
115impl Message {
120 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
122 let flags = t.flags();
123 Self {
124 v: PROTOCOL_VERSION,
125 t,
126 id,
127 flags,
128 p,
129 }
130 }
131
132 pub fn with_payload<T: Serialize>(
134 t: MessageType,
135 id: u32,
136 payload: &T,
137 ) -> ProtocolResult<Self> {
138 let mut p = Vec::new();
139 ciborium::into_writer(payload, &mut p)?;
140 let flags = t.flags();
141 Ok(Self {
142 v: PROTOCOL_VERSION,
143 t,
144 id,
145 flags,
146 p,
147 })
148 }
149
150 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
152 Ok(ciborium::from_reader(&self.p[..])?)
153 }
154}
155
156impl MessageType {
157 pub fn flags(&self) -> u8 {
159 match self {
160 Self::ExecExited | Self::FsResponse => FLAG_TERMINAL,
161 Self::ExecRequest | Self::FsRequest => FLAG_SESSION_START,
162 Self::Shutdown => FLAG_SHUTDOWN,
163 _ => 0,
164 }
165 }
166
167 pub fn as_str(&self) -> &'static str {
169 match self {
170 Self::Ready => "core.ready",
171 Self::Shutdown => "core.shutdown",
172 Self::ExecRequest => "core.exec.request",
173 Self::ExecStarted => "core.exec.started",
174 Self::ExecStdin => "core.exec.stdin",
175 Self::ExecStdout => "core.exec.stdout",
176 Self::ExecStderr => "core.exec.stderr",
177 Self::ExecExited => "core.exec.exited",
178 Self::ExecResize => "core.exec.resize",
179 Self::ExecSignal => "core.exec.signal",
180 Self::FsRequest => "core.fs.request",
181 Self::FsResponse => "core.fs.response",
182 Self::FsData => "core.fs.data",
183 }
184 }
185
186 pub fn from_wire_str(s: &str) -> Option<Self> {
188 match s {
189 "core.ready" => Some(Self::Ready),
190 "core.shutdown" => Some(Self::Shutdown),
191 "core.exec.request" => Some(Self::ExecRequest),
192 "core.exec.started" => Some(Self::ExecStarted),
193 "core.exec.stdin" => Some(Self::ExecStdin),
194 "core.exec.stdout" => Some(Self::ExecStdout),
195 "core.exec.stderr" => Some(Self::ExecStderr),
196 "core.exec.exited" => Some(Self::ExecExited),
197 "core.exec.resize" => Some(Self::ExecResize),
198 "core.exec.signal" => Some(Self::ExecSignal),
199 "core.fs.request" => Some(Self::FsRequest),
200 "core.fs.response" => Some(Self::FsResponse),
201 "core.fs.data" => Some(Self::FsData),
202 _ => None,
203 }
204 }
205}
206
207impl Serialize for MessageType {
212 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
213 where
214 S: serde::Serializer,
215 {
216 serializer.serialize_str(self.as_str())
217 }
218}
219
220impl<'de> Deserialize<'de> for MessageType {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: serde::Deserializer<'de>,
224 {
225 let s = String::deserialize(deserializer)?;
226 Self::from_wire_str(&s)
227 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Every message type paired with its expected wire string. Shared by the
    /// round-trip tests so the variant list is maintained in one place.
    const CASES: [(MessageType, &str); 13] = [
        (MessageType::Ready, "core.ready"),
        (MessageType::Shutdown, "core.shutdown"),
        (MessageType::ExecRequest, "core.exec.request"),
        (MessageType::ExecStarted, "core.exec.started"),
        (MessageType::ExecStdin, "core.exec.stdin"),
        (MessageType::ExecStdout, "core.exec.stdout"),
        (MessageType::ExecStderr, "core.exec.stderr"),
        (MessageType::ExecExited, "core.exec.exited"),
        (MessageType::ExecResize, "core.exec.resize"),
        (MessageType::ExecSignal, "core.exec.signal"),
        (MessageType::FsRequest, "core.fs.request"),
        (MessageType::FsResponse, "core.fs.response"),
        (MessageType::FsData, "core.fs.data"),
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for (mt, expected_str) in &CASES {
            assert_eq!(mt.as_str(), *expected_str);
            assert_eq!(MessageType::from_wire_str(expected_str).unwrap(), *mt);
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for (mt, wire) in &CASES {
            let mut buf = Vec::new();
            ciborium::into_writer(mt, &mut buf).unwrap();

            // The encoding must be the plain wire string, not a variant index.
            let as_string: String = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(as_string, *wire);

            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(&decoded, mt);
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());
        assert!(MessageType::from_wire_str("").is_none());
    }

    #[test]
    fn test_unknown_message_type_rejected_by_serde() {
        // An unrecognised wire string must fail to deserialize.
        let mut buf = Vec::new();
        ciborium::into_writer("core.unknown", &mut buf).unwrap();
        let decoded: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        assert!(decoded.is_err());

        // A non-string tag must fail as well.
        let mut buf = Vec::new();
        ciborium::into_writer(&7u8, &mut buf).unwrap();
        let decoded: Result<MessageType, _> = ciborium::from_reader(&buf[..]);
        assert!(decoded.is_err());
    }

    #[test]
    fn test_message_serde_skips_id_and_flags() {
        let a = Message::new(MessageType::ExecRequest, 1, vec![1, 2, 3]);
        let b = Message::new(MessageType::ExecRequest, 99, vec![1, 2, 3]);
        assert_eq!(a.flags, FLAG_SESSION_START);

        let mut buf_a = Vec::new();
        let mut buf_b = Vec::new();
        ciborium::into_writer(&a, &mut buf_a).unwrap();
        ciborium::into_writer(&b, &mut buf_b).unwrap();
        // `id` is #[serde(skip)], so differing ids encode identically.
        assert_eq!(buf_a, buf_b);

        let decoded: Message = ciborium::from_reader(&buf_a[..]).unwrap();
        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecRequest);
        assert_eq!(decoded.p, vec![1, 2, 3]);
        // Skipped fields come back as their defaults.
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_flag_bits_are_distinct() {
        assert_eq!(FLAG_TERMINAL.count_ones(), 1);
        assert_eq!(FLAG_SESSION_START.count_ones(), 1);
        assert_eq!(FLAG_SHUTDOWN.count_ones(), 1);
        assert_eq!(
            (FLAG_TERMINAL | FLAG_SESSION_START | FLAG_SHUTDOWN).count_ones(),
            3
        );
    }

    #[test]
    fn test_message_type_flags() {
        assert_eq!(MessageType::ExecExited.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::FsResponse.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::FsRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::Ready.flags(), 0);
        assert_eq!(MessageType::Shutdown.flags(), FLAG_SHUTDOWN);
        assert_eq!(MessageType::ExecStarted.flags(), 0);
        assert_eq!(MessageType::ExecStdin.flags(), 0);
        assert_eq!(MessageType::ExecStdout.flags(), 0);
        assert_eq!(MessageType::ExecStderr.flags(), 0);
        assert_eq!(MessageType::ExecResize.flags(), 0);
        assert_eq!(MessageType::ExecSignal.flags(), 0);
        assert_eq!(MessageType::FsData.flags(), 0);
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let msg = Message::new(MessageType::ExecStdout, 1, Vec::new());
        assert_eq!(msg.flags, 0);
    }
}