1use serde::{Deserialize, Serialize, de::DeserializeOwned};
4
5use crate::error::ProtocolResult;
6
/// Protocol version written into the `v` field of every [`Message`] built by
/// `Message::new` or `Message::with_payload`.
pub const PROTOCOL_VERSION: u8 = 1;
13
/// A single protocol message: version, type tag, identifier and raw payload.
///
/// The short field names are the keys the serde derive emits, so renaming a
/// field changes the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version; the constructors on this type set it to
    /// [`PROTOCOL_VERSION`].
    pub v: u8,

    /// Message type, (de)serialized through `MessageType`'s serde impls.
    pub t: MessageType,

    /// Caller-supplied identifier.
    /// NOTE(review): presumably correlates related messages — confirm.
    pub id: u32,

    /// Payload bytes, (de)serialized via `serde_bytes`.
    /// `Message::with_payload` stores a CBOR-encoded value here.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
37
/// Type tag carried in the `t` field of a `Message`.
///
/// Each variant maps to a dotted wire name (see `MessageType::as_str`), and the
/// serde impls for this type encode it as that string.
///
/// The enum is a fieldless tag, so it derives `Copy` in addition to `Clone`:
/// values can be passed and compared by value without explicit `.clone()`
/// calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Wire name `core.ready`.
    Ready,

    /// Wire name `core.shutdown`.
    Shutdown,

    /// Wire name `core.exec.request`.
    ExecRequest,

    /// Wire name `core.exec.started`.
    ExecStarted,

    /// Wire name `core.exec.stdin`.
    ExecStdin,

    /// Wire name `core.exec.stdout`.
    ExecStdout,

    /// Wire name `core.exec.stderr`.
    ExecStderr,

    /// Wire name `core.exec.exited`.
    ExecExited,

    /// Wire name `core.exec.resize`.
    ExecResize,

    /// Wire name `core.exec.signal`.
    ExecSignal,

    /// Wire name `core.fs.request`.
    FsRequest,

    /// Wire name `core.fs.response`.
    FsResponse,

    /// Wire name `core.fs.data`.
    FsData,
}
80
81impl Message {
86 pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
88 Self {
89 v: PROTOCOL_VERSION,
90 t,
91 id,
92 p,
93 }
94 }
95
96 pub fn with_payload<T: Serialize>(
98 t: MessageType,
99 id: u32,
100 payload: &T,
101 ) -> ProtocolResult<Self> {
102 let mut p = Vec::new();
103 ciborium::into_writer(payload, &mut p)?;
104 Ok(Self {
105 v: PROTOCOL_VERSION,
106 t,
107 id,
108 p,
109 })
110 }
111
112 pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
114 Ok(ciborium::from_reader(&self.p[..])?)
115 }
116}
117
118impl MessageType {
119 pub fn as_str(&self) -> &'static str {
121 match self {
122 Self::Ready => "core.ready",
123 Self::Shutdown => "core.shutdown",
124 Self::ExecRequest => "core.exec.request",
125 Self::ExecStarted => "core.exec.started",
126 Self::ExecStdin => "core.exec.stdin",
127 Self::ExecStdout => "core.exec.stdout",
128 Self::ExecStderr => "core.exec.stderr",
129 Self::ExecExited => "core.exec.exited",
130 Self::ExecResize => "core.exec.resize",
131 Self::ExecSignal => "core.exec.signal",
132 Self::FsRequest => "core.fs.request",
133 Self::FsResponse => "core.fs.response",
134 Self::FsData => "core.fs.data",
135 }
136 }
137
138 pub fn from_wire_str(s: &str) -> Option<Self> {
140 match s {
141 "core.ready" => Some(Self::Ready),
142 "core.shutdown" => Some(Self::Shutdown),
143 "core.exec.request" => Some(Self::ExecRequest),
144 "core.exec.started" => Some(Self::ExecStarted),
145 "core.exec.stdin" => Some(Self::ExecStdin),
146 "core.exec.stdout" => Some(Self::ExecStdout),
147 "core.exec.stderr" => Some(Self::ExecStderr),
148 "core.exec.exited" => Some(Self::ExecExited),
149 "core.exec.resize" => Some(Self::ExecResize),
150 "core.exec.signal" => Some(Self::ExecSignal),
151 "core.fs.request" => Some(Self::FsRequest),
152 "core.fs.response" => Some(Self::FsResponse),
153 "core.fs.data" => Some(Self::FsData),
154 _ => None,
155 }
156 }
157}
158
159impl Serialize for MessageType {
164 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
165 where
166 S: serde::Serializer,
167 {
168 serializer.serialize_str(self.as_str())
169 }
170}
171
172impl<'de> Deserialize<'de> for MessageType {
173 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174 where
175 D: serde::Deserializer<'de>,
176 {
177 let s = String::deserialize(deserializer)?;
178 Self::from_wire_str(&s)
179 .ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Every message type paired with the wire name it must map to.
    fn wire_table() -> [(MessageType, &'static str); 13] {
        [
            (MessageType::Ready, "core.ready"),
            (MessageType::Shutdown, "core.shutdown"),
            (MessageType::ExecRequest, "core.exec.request"),
            (MessageType::ExecStarted, "core.exec.started"),
            (MessageType::ExecStdin, "core.exec.stdin"),
            (MessageType::ExecStdout, "core.exec.stdout"),
            (MessageType::ExecStderr, "core.exec.stderr"),
            (MessageType::ExecExited, "core.exec.exited"),
            (MessageType::ExecResize, "core.exec.resize"),
            (MessageType::ExecSignal, "core.exec.signal"),
            (MessageType::FsRequest, "core.fs.request"),
            (MessageType::FsResponse, "core.fs.response"),
            (MessageType::FsData, "core.fs.data"),
        ]
    }

    /// `as_str` and `from_wire_str` must be inverses for every variant.
    #[test]
    fn test_message_type_roundtrip() {
        for (kind, wire_name) in wire_table() {
            assert_eq!(kind.as_str(), wire_name);
            assert_eq!(MessageType::from_wire_str(wire_name), Some(kind));
        }
    }

    /// Every variant must survive a CBOR encode/decode cycle unchanged.
    #[test]
    fn test_message_type_serde_roundtrip() {
        for (kind, _) in wire_table() {
            let mut encoded = Vec::new();
            ciborium::into_writer(&kind, &mut encoded).unwrap();

            let decoded: MessageType = ciborium::from_reader(encoded.as_slice()).unwrap();
            assert_eq!(decoded, kind);
        }
    }

    /// A wire name with no matching variant must be rejected.
    #[test]
    fn test_unknown_message_type() {
        assert_eq!(MessageType::from_wire_str("core.unknown"), None);
    }

    /// A typed payload must come back intact alongside the header fields.
    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let sent = ExecExited { code: 42 };
        let msg = Message::with_payload(MessageType::ExecExited, 7, &sent).unwrap();

        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);

        let received: ExecExited = msg.payload().unwrap();
        assert_eq!(received.code, 42);
    }
}