1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Wire protocol version stamped into every [`Message`] built by
/// [`Message::new`] or [`Message::with_payload`].
pub const PROTOCOL_VERSION: u8 = 1;

/// Flag bit 0: terminal message type.
///
/// [`MessageType::flags`] reports it for `ExecExited`, `ExecFailed` and
/// `FsResponse`.
pub const FLAG_TERMINAL: u8 = 1 << 0;

/// Flag bit 1: session-opening message type.
///
/// [`MessageType::flags`] reports it for `ExecRequest` and `FsRequest`.
pub const FLAG_SESSION_START: u8 = 1 << 1;

/// Flag bit 2: reported by [`MessageType::flags`] for `Shutdown` only.
pub const FLAG_SHUTDOWN: u8 = 1 << 2;

/// Size in bytes of the fixed per-frame header.
///
/// NOTE(review): the header layout is defined by the framing code, not in
/// this file. Five bytes presumably hold the serde-skipped `Message::id`
/// (`u32`) plus `Message::flags` (`u8`) — confirm against the codec.
pub const FRAME_HEADER_SIZE: usize = 5;
33
/// A single protocol message.
///
/// Only `v`, `t` and `p` go through serde. `id` and `flags` are marked
/// `#[serde(skip)]`, so they are not part of the serialized body. When a
/// `Message` is deserialized they come back as `0`, their `Default` value.
/// NOTE(review): presumably the framing layer transmits them separately
/// (see `FRAME_HEADER_SIZE`) — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version. [`Message::new`] and [`Message::with_payload`] set
    /// it to [`PROTOCOL_VERSION`].
    pub v: u8,

    /// Message type, serialized as its dotted wire string
    /// (e.g. `"core.exec.stdout"`).
    pub t: MessageType,

    /// Message identifier. It is not serialized by serde (see the type docs).
    /// NOTE(review): looks like a session/stream id that correlates requests
    /// with their replies — confirm with callers.
    #[serde(skip)]
    pub id: u32,

    /// Bitset of `FLAG_*` constants. It is not serialized by serde. The
    /// constructors fill it from [`MessageType::flags`].
    #[serde(skip)]
    pub flags: u8,

    /// Raw payload bytes. [`Message::with_payload`] and [`Message::payload`]
    /// encode and decode them as CBOR. `serde_bytes` makes serde treat the
    /// field as a single byte string instead of a sequence of integers.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
71
/// Kind of a [`Message`], identified on the wire by a dotted string.
///
/// The string mapping lives in [`MessageType::as_str`] and
/// [`MessageType::from_wire_str`]. `Serialize`/`Deserialize` are implemented
/// by hand on top of that mapping rather than derived.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Wire string `core.ready`. Carries no flags.
    Ready,

    /// Wire string `core.shutdown`. Carries [`FLAG_SHUTDOWN`].
    Shutdown,

    /// Wire string `core.relay.client.disconnected`. Carries no flags.
    RelayClientDisconnected,

    /// Wire string `core.exec.request`. Carries [`FLAG_SESSION_START`].
    ExecRequest,

    /// Wire string `core.exec.started`. Carries no flags.
    ExecStarted,

    /// Wire string `core.exec.stdin`. Carries no flags.
    ExecStdin,

    /// Wire string `core.exec.stdin.error`. Carries no flags, so unlike
    /// `ExecFailed` it is not marked terminal.
    ExecStdinError,

    /// Wire string `core.exec.stdout`. Carries no flags.
    ExecStdout,

    /// Wire string `core.exec.stderr`. Carries no flags.
    ExecStderr,

    /// Wire string `core.exec.exited`. Carries [`FLAG_TERMINAL`].
    ExecExited,

    /// Wire string `core.exec.failed`. Carries [`FLAG_TERMINAL`].
    ExecFailed,

    /// Wire string `core.exec.resize`. Carries no flags.
    ExecResize,

    /// Wire string `core.exec.signal`. Carries no flags.
    ExecSignal,

    /// Wire string `core.fs.request`. Carries [`FLAG_SESSION_START`].
    FsRequest,

    /// Wire string `core.fs.response`. Carries [`FLAG_TERMINAL`].
    FsResponse,

    /// Wire string `core.fs.data`. Carries no flags.
    FsData,
}
128
129impl Message {
134 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
136 let flags = t.flags();
137 Self {
138 v: PROTOCOL_VERSION,
139 t,
140 id,
141 flags,
142 p,
143 }
144 }
145
146 pub fn with_payload<T: Serialize>(
148 t: MessageType,
149 id: u32,
150 payload: &T,
151 ) -> ProtocolResult<Self> {
152 let mut p = Vec::new();
153 ciborium::into_writer(payload, &mut p)?;
154 let flags = t.flags();
155 Ok(Self {
156 v: PROTOCOL_VERSION,
157 t,
158 id,
159 flags,
160 p,
161 })
162 }
163
164 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
166 Ok(ciborium::from_reader(&self.p[..])?)
167 }
168}
169
170impl MessageType {
171 pub fn flags(&self) -> u8 {
173 match self {
174 Self::ExecExited | Self::ExecFailed | Self::FsResponse => FLAG_TERMINAL,
175 Self::ExecRequest | Self::FsRequest => FLAG_SESSION_START,
176 Self::Shutdown => FLAG_SHUTDOWN,
177 _ => 0,
178 }
179 }
180
181 pub fn as_str(&self) -> &'static str {
183 match self {
184 Self::Ready => "core.ready",
185 Self::Shutdown => "core.shutdown",
186 Self::RelayClientDisconnected => "core.relay.client.disconnected",
187 Self::ExecRequest => "core.exec.request",
188 Self::ExecStarted => "core.exec.started",
189 Self::ExecStdin => "core.exec.stdin",
190 Self::ExecStdinError => "core.exec.stdin.error",
191 Self::ExecStdout => "core.exec.stdout",
192 Self::ExecStderr => "core.exec.stderr",
193 Self::ExecExited => "core.exec.exited",
194 Self::ExecFailed => "core.exec.failed",
195 Self::ExecResize => "core.exec.resize",
196 Self::ExecSignal => "core.exec.signal",
197 Self::FsRequest => "core.fs.request",
198 Self::FsResponse => "core.fs.response",
199 Self::FsData => "core.fs.data",
200 }
201 }
202
203 pub fn from_wire_str(s: &str) -> Option<Self> {
205 match s {
206 "core.ready" => Some(Self::Ready),
207 "core.shutdown" => Some(Self::Shutdown),
208 "core.relay.client.disconnected" => Some(Self::RelayClientDisconnected),
209 "core.exec.request" => Some(Self::ExecRequest),
210 "core.exec.started" => Some(Self::ExecStarted),
211 "core.exec.stdin" => Some(Self::ExecStdin),
212 "core.exec.stdin.error" => Some(Self::ExecStdinError),
213 "core.exec.stdout" => Some(Self::ExecStdout),
214 "core.exec.stderr" => Some(Self::ExecStderr),
215 "core.exec.exited" => Some(Self::ExecExited),
216 "core.exec.failed" => Some(Self::ExecFailed),
217 "core.exec.resize" => Some(Self::ExecResize),
218 "core.exec.signal" => Some(Self::ExecSignal),
219 "core.fs.request" => Some(Self::FsRequest),
220 "core.fs.response" => Some(Self::FsResponse),
221 "core.fs.data" => Some(Self::FsData),
222 _ => None,
223 }
224 }
225}
226
227impl Serialize for MessageType {
232 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
233 where
234 S: serde::Serializer,
235 {
236 serializer.serialize_str(self.as_str())
237 }
238}
239
240impl<'de> Deserialize<'de> for MessageType {
241 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
242 where
243 D: serde::Deserializer<'de>,
244 {
245 let s = String::deserialize(deserializer)?;
246 Self::from_wire_str(&s)
247 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    // Every `MessageType` variant paired with its wire string. Both
    // round-trip tests share this table, so a variant cannot be covered by
    // one and forgotten by the other. Keep it in sync with the enum.
    static ALL_TYPES: [(MessageType, &str); 16] = [
        (MessageType::Ready, "core.ready"),
        (MessageType::Shutdown, "core.shutdown"),
        (
            MessageType::RelayClientDisconnected,
            "core.relay.client.disconnected",
        ),
        (MessageType::ExecRequest, "core.exec.request"),
        (MessageType::ExecStarted, "core.exec.started"),
        (MessageType::ExecStdin, "core.exec.stdin"),
        (MessageType::ExecStdinError, "core.exec.stdin.error"),
        (MessageType::ExecStdout, "core.exec.stdout"),
        (MessageType::ExecStderr, "core.exec.stderr"),
        (MessageType::ExecExited, "core.exec.exited"),
        (MessageType::ExecFailed, "core.exec.failed"),
        (MessageType::ExecResize, "core.exec.resize"),
        (MessageType::ExecSignal, "core.exec.signal"),
        (MessageType::FsRequest, "core.fs.request"),
        (MessageType::FsResponse, "core.fs.response"),
        (MessageType::FsData, "core.fs.data"),
    ];

    #[test]
    fn test_all_types_table_has_no_duplicates() {
        use std::collections::HashSet;

        let variants: HashSet<&MessageType> = ALL_TYPES.iter().map(|(mt, _)| mt).collect();
        let wire_names: HashSet<&str> = ALL_TYPES.iter().map(|(mt, _)| mt.as_str()).collect();
        assert_eq!(variants.len(), ALL_TYPES.len());
        assert_eq!(wire_names.len(), ALL_TYPES.len());
    }

    #[test]
    fn test_message_type_roundtrip() {
        for (mt, expected_str) in &ALL_TYPES {
            assert_eq!(mt.as_str(), *expected_str);
            assert_eq!(MessageType::from_wire_str(expected_str).unwrap(), *mt);
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for (mt, wire) in &ALL_TYPES {
            let mut buf = Vec::new();
            ciborium::into_writer(mt, &mut buf).unwrap();

            // The CBOR form is the plain wire string.
            let as_text: String = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(as_text, *wire);

            let decoded: MessageType = ciborium::from_reader(&buf[..]).unwrap();
            assert_eq!(&decoded, mt);
        }
    }

    #[test]
    fn test_unknown_message_type() {
        assert!(MessageType::from_wire_str("core.unknown").is_none());
        assert!(MessageType::from_wire_str("").is_none());
    }

    #[test]
    fn test_unknown_message_type_deserialize_error() {
        let mut buf = Vec::new();
        ciborium::into_writer(&"core.unknown", &mut buf).unwrap();

        let err = ciborium::from_reader::<MessageType, _>(&buf[..]).unwrap_err();
        assert!(format!("{err:?}").contains("unknown message type: core.unknown"));
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let msg =
            Message::with_payload(MessageType::ExecExited, 7, &ExecExited { code: 42 }).unwrap();

        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let payload: ExecExited = msg.payload().unwrap();
        assert_eq!(payload.code, 42);
    }

    #[test]
    fn test_message_serde_skips_id_and_flags() {
        let msg = Message::new(MessageType::ExecRequest, 9, vec![1u8, 2, 3]);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let mut buf = Vec::new();
        ciborium::into_writer(&msg, &mut buf).unwrap();
        let decoded: Message = ciborium::from_reader(&buf[..]).unwrap();

        assert_eq!(decoded.v, PROTOCOL_VERSION);
        assert_eq!(decoded.t, MessageType::ExecRequest);
        assert_eq!(decoded.p, vec![1u8, 2, 3]);
        // `#[serde(skip)]` fields fall back to their `Default` on decode.
        assert_eq!(decoded.id, 0);
        assert_eq!(decoded.flags, 0);
    }

    #[test]
    fn test_message_type_flags() {
        assert_eq!(MessageType::ExecExited.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecFailed.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::FsResponse.flags(), FLAG_TERMINAL);
        assert_eq!(MessageType::ExecRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::FsRequest.flags(), FLAG_SESSION_START);
        assert_eq!(MessageType::Shutdown.flags(), FLAG_SHUTDOWN);
        assert_eq!(MessageType::Ready.flags(), 0);
        assert_eq!(MessageType::RelayClientDisconnected.flags(), 0);
        assert_eq!(MessageType::ExecStarted.flags(), 0);
        assert_eq!(MessageType::ExecStdin.flags(), 0);
        assert_eq!(MessageType::ExecStdinError.flags(), 0);
        assert_eq!(MessageType::ExecStdout.flags(), 0);
        assert_eq!(MessageType::ExecStderr.flags(), 0);
        assert_eq!(MessageType::ExecResize.flags(), 0);
        assert_eq!(MessageType::ExecSignal.flags(), 0);
        assert_eq!(MessageType::FsData.flags(), 0);
    }

    #[test]
    fn test_message_new_computes_flags() {
        let msg = Message::new(MessageType::ExecRequest, 1, Vec::new());
        assert_eq!(msg.v, PROTOCOL_VERSION);
        assert_eq!(msg.flags, FLAG_SESSION_START);

        let msg = Message::new(MessageType::ExecStdout, 1, Vec::new());
        assert_eq!(msg.flags, 0);

        let msg = Message::new(MessageType::Shutdown, 3, vec![0xAB]);
        assert_eq!(msg.flags, FLAG_SHUTDOWN);
        assert_eq!(msg.id, 3);
        assert_eq!(msg.p, vec![0xAB]);
    }
}