1use std::io;
17
18use byteorder::{BigEndian, ByteOrder};
19use serde::{Deserialize, Serialize};
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21
22use super::{DaemonError, Result};
23
/// Current wire-protocol version number; presumably what clients put in `Hello::version`.
pub const PROTOCOL_VERSION: u32 = 1;

/// Largest JSON body (64 MiB) that `read_frame*` accepts and `write_frame*` emits.
///
/// The limit applies to the body only; the 4-byte length header is not counted.
pub const MAX_FRAME_BYTES: usize = 64 << 20;
/// Newtype around a raw protocol version number.
///
/// NOTE(review): not referenced in this part of the file; `Hello` and `Welcome` carry the
/// version as a bare `u32`. Confirm whether it is used elsewhere.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProtocolVersion(pub u32);
32
/// Client handshake payload, carried by [`Frame::Hello`] as `{"hello": {...}}`.
///
/// Field order is the JSON key order on the wire.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Hello {
    /// Protocol version the client speaks; presumably [`PROTOCOL_VERSION`] for current clients.
    pub version: u32,
    /// Process id of the connecting client, as reported by the client itself.
    pub client_pid: i32,
    /// Client's terminal device path, if it has one (e.g. `/dev/ttys003`).
    pub tty: Option<String>,
    /// Client's working directory, if reported.
    pub cwd: Option<String>,
    /// Client's `argv[0]`, if reported.
    pub argv0: Option<String>,
}
47
/// Daemon's successful handshake reply, carried by [`Frame::Welcome`] as `{"welcome": {...}}`.
///
/// Field order is the JSON key order on the wire.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Welcome {
    /// Protocol version the daemon speaks.
    pub version: u32,
    /// Identifier for this client connection; presumably assigned by the daemon — confirm.
    pub client_id: u64,
    /// Session identifier string.
    pub session_id: String,
    /// Process id of the daemon.
    pub daemon_pid: i32,
    /// Daemon uptime in milliseconds when the reply was built.
    pub daemon_uptime_ms: u64,
}
62
/// Error description sent over the wire.
///
/// Used in [`Frame::WelcomeErr`] and, under the `"err"` key, in failed [`Frame::Response`]s.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrPayload {
    /// Short error category; this module's `From` impls use `"sqlite"`, `"io"` and `"daemon"`.
    pub code: String,
    /// Human-readable message (the source error's `Display` text in the `From` impls).
    pub msg: String,
}
71
72impl ErrPayload {
73 pub fn new<C: Into<String>, M: Into<String>>(code: C, msg: M) -> Self {
75 Self {
76 code: code.into(),
77 msg: msg.into(),
78 }
79 }
80}
81
82impl From<rusqlite::Error> for ErrPayload {
83 fn from(e: rusqlite::Error) -> Self {
84 Self::new("sqlite", e.to_string())
85 }
86}
87
88impl From<std::io::Error> for ErrPayload {
89 fn from(e: std::io::Error) -> Self {
90 Self::new("io", e.to_string())
91 }
92}
93
94impl From<super::DaemonError> for ErrPayload {
95 fn from(e: super::DaemonError) -> Self {
96 Self::new("daemon", e.to_string())
97 }
98}
99
/// One message on the daemon socket, encoded as a single JSON object.
///
/// The enum is `#[serde(untagged)]`, so there is no discriminator key. On decode, serde tries
/// the variants in declaration order, and the first one whose fields deserialize wins.
/// Variant order is therefore part of the wire protocol; do not reorder. Unknown keys are
/// ignored because no variant uses `deny_unknown_fields`.
///
/// NOTE(review): with untagged matching, payload keys can shadow an earlier variant. For
/// example, a `Response` whose flattened payload contains a string `op` key decodes as a
/// `Request` (with `args` defaulting to `null`). Keep payload keys disjoint from the required
/// fields of earlier variants.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Frame {
    /// Client handshake: `{"hello": {...}}`.
    Hello { hello: Hello },
    /// Successful daemon handshake reply: `{"welcome": {...}}`.
    Welcome { welcome: Welcome },
    /// Failed handshake reply: `{"welcome": <any JSON>, "err": {...}}`.
    ///
    /// `welcome` is untyped; presumably it is a placeholder such as `null` (confirm with the
    /// sender). `Welcome` is tried first, so this variant only matches when `welcome` is not a
    /// valid [`Welcome`].
    WelcomeErr {
        welcome: serde_json::Value,
        err: ErrPayload,
    },
    /// Request `{"id": .., "op": .., "args": ..}`.
    ///
    /// A missing `args` key defaults to JSON `null` (`Value::default()`).
    Request {
        id: u64,
        op: String,
        #[serde(default)]
        args: serde_json::Value,
    },
    /// Reply frame; presumably `id` echoes the originating request's `id`.
    ///
    /// All keys other than `id` and `ok` are collected into `payload` (`#[serde(flatten)]`).
    /// The payload should therefore be a JSON object, because serde cannot flatten a number,
    /// string or array.
    Response {
        id: u64,
        ok: bool,
        #[serde(flatten)]
        payload: serde_json::Value,
    },
    /// Named event; all keys other than `event` are collected into `payload`.
    Event {
        event: String,
        #[serde(flatten)]
        payload: serde_json::Value,
    },
}
134
135impl Frame {
136 pub fn hello(h: Hello) -> Self {
138 Frame::Hello { hello: h }
139 }
140 pub fn welcome(w: Welcome) -> Self {
142 Frame::Welcome { welcome: w }
143 }
144 pub fn request(id: u64, op: impl Into<String>, args: serde_json::Value) -> Self {
146 Frame::Request {
147 id,
148 op: op.into(),
149 args,
150 }
151 }
152 pub fn ok_response(id: u64, payload: serde_json::Value) -> Self {
154 Frame::Response {
155 id,
156 ok: true,
157 payload,
158 }
159 }
160 pub fn err_response(id: u64, err: ErrPayload) -> Self {
162 let payload = serde_json::json!({ "err": err });
163 Frame::Response {
164 id,
165 ok: false,
166 payload,
167 }
168 }
169 pub fn event(name: impl Into<String>, payload: serde_json::Value) -> Self {
171 Frame::Event {
172 event: name.into(),
173 payload,
174 }
175 }
176}
177
/// Event kinds known to the daemon protocol.
///
/// Serialized in `snake_case`; for example, `ShardUpdated` ↔ `"shard_updated"` and
/// `LongCmdComplete` ↔ `"long_cmd_complete"`.
///
/// NOTE(review): [`Frame::Event`] carries its name as a plain `String`. Presumably these
/// serialized names are the values used there; confirm. Variant order also fixes the implicit
/// discriminants, so append new variants rather than reordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Event {
    ShardUpdated,
    RebuildComplete,
    CanonicalChanged,
    Match,
    CmdExecute,
    Notify,
    DaemonShutdown,
    AskPending,
    AskDismissed,
    AskProgress,
    LongCmdComplete,
    LongCmdStarted,
    LongCmdFailed,
    LongCmdSignaled,
}
211
212pub async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<Frame> {
216 let mut len_buf = [0u8; 4];
217 reader.read_exact(&mut len_buf).await?;
218 let len = BigEndian::read_u32(&len_buf) as usize;
219 if len == 0 {
220 return Err(DaemonError::other("zero-length frame"));
221 }
222 if len > MAX_FRAME_BYTES {
223 return Err(DaemonError::FrameTooLarge {
224 size: len,
225 max: MAX_FRAME_BYTES,
226 });
227 }
228
229 let mut buf = vec![0u8; len];
230 reader.read_exact(&mut buf).await?;
231
232 let frame: Frame = serde_json::from_slice(&buf)?;
233 Ok(frame)
234}
235
236pub async fn write_frame<W: AsyncWriteExt + Unpin>(writer: &mut W, frame: &Frame) -> Result<()> {
238 let body = serde_json::to_vec(frame)?;
239 if body.len() > MAX_FRAME_BYTES {
240 return Err(DaemonError::FrameTooLarge {
241 size: body.len(),
242 max: MAX_FRAME_BYTES,
243 });
244 }
245 let mut header = [0u8; 4];
246 BigEndian::write_u32(&mut header, body.len() as u32);
247 writer.write_all(&header).await?;
248 writer.write_all(&body).await?;
249 writer.flush().await?;
250 Ok(())
251}
252
253pub fn read_frame_sync<R: io::Read>(reader: &mut R) -> Result<Frame> {
255 let mut len_buf = [0u8; 4];
256 reader.read_exact(&mut len_buf)?;
257 let len = BigEndian::read_u32(&len_buf) as usize;
258 if len == 0 {
259 return Err(DaemonError::other("zero-length frame"));
260 }
261 if len > MAX_FRAME_BYTES {
262 return Err(DaemonError::FrameTooLarge {
263 size: len,
264 max: MAX_FRAME_BYTES,
265 });
266 }
267
268 let mut buf = vec![0u8; len];
269 reader.read_exact(&mut buf)?;
270
271 let frame: Frame = serde_json::from_slice(&buf)?;
272 Ok(frame)
273}
274
275pub fn write_frame_sync<W: io::Write>(writer: &mut W, frame: &Frame) -> Result<()> {
277 let body = serde_json::to_vec(frame)?;
278 if body.len() > MAX_FRAME_BYTES {
279 return Err(DaemonError::FrameTooLarge {
280 size: body.len(),
281 max: MAX_FRAME_BYTES,
282 });
283 }
284 let mut header = [0u8; 4];
285 BigEndian::write_u32(&mut header, body.len() as u32);
286 writer.write_all(&header)?;
287 writer.write_all(&body)?;
288 writer.flush()?;
289 Ok(())
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Encodes `frame` with [`write_frame_sync`], then decodes it again with [`read_frame_sync`].
    fn roundtrip_sync(frame: &Frame) -> Frame {
        let mut buf = Vec::new();
        write_frame_sync(&mut buf, frame).unwrap();
        read_frame_sync(&mut Cursor::new(buf)).unwrap()
    }

    /// Builds raw wire bytes: a big-endian `len` header followed by `body`.
    ///
    /// `body` may deliberately disagree with `len`.
    fn raw_frame(len: u32, body: &[u8]) -> Vec<u8> {
        let mut buf = vec![0u8; 4];
        BigEndian::write_u32(&mut buf[..4], len);
        buf.extend_from_slice(body);
        buf
    }

    #[test]
    fn roundtrip_hello_sync() {
        let h = Hello {
            version: PROTOCOL_VERSION,
            client_pid: 12345,
            tty: Some("/dev/ttys003".into()),
            cwd: Some("/home/wizard".into()),
            argv0: Some("zshrs".into()),
        };

        match roundtrip_sync(&Frame::hello(h)) {
            Frame::Hello { hello } => {
                assert_eq!(hello.version, PROTOCOL_VERSION);
                assert_eq!(hello.client_pid, 12345);
                assert_eq!(hello.tty.as_deref(), Some("/dev/ttys003"));
                assert_eq!(hello.cwd.as_deref(), Some("/home/wizard"));
                assert_eq!(hello.argv0.as_deref(), Some("zshrs"));
            }
            other => panic!("expected Hello, got {:?}", other),
        }
    }

    #[test]
    fn roundtrip_request_sync() {
        let frame = Frame::request(42, "ping", serde_json::json!({}));

        match roundtrip_sync(&frame) {
            Frame::Request { id, op, args } => {
                assert_eq!(id, 42);
                assert_eq!(op, "ping");
                assert!(args.is_object());
            }
            other => panic!("expected Request, got {:?}", other),
        }
    }

    #[test]
    fn roundtrip_ok_response_sync() {
        let frame = Frame::ok_response(9, serde_json::json!({"value": 1}));

        match roundtrip_sync(&frame) {
            Frame::Response { id, ok, payload } => {
                assert_eq!(id, 9);
                assert!(ok);
                assert_eq!(payload["value"], 1);
            }
            other => panic!("expected Response, got {:?}", other),
        }
    }

    #[test]
    fn roundtrip_err_response_sync() {
        let frame = Frame::err_response(3, ErrPayload::new("io", "boom"));

        match roundtrip_sync(&frame) {
            Frame::Response { id, ok, payload } => {
                assert_eq!(id, 3);
                assert!(!ok);
                assert_eq!(payload["err"]["code"], "io");
                assert_eq!(payload["err"]["msg"], "boom");
            }
            other => panic!("expected Response, got {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_sync() {
        let frame = Frame::event(
            "shard_updated",
            serde_json::json!({"shard":"foo","generation":3}),
        );

        match roundtrip_sync(&frame) {
            Frame::Event { event, payload } => {
                assert_eq!(event, "shard_updated");
                assert_eq!(payload["shard"], "foo");
                assert_eq!(payload["generation"], 3);
            }
            other => panic!("expected Event, got {:?}", other),
        }
    }

    #[test]
    fn header_is_big_endian_body_length() {
        let frame = Frame::request(1, "ping", serde_json::json!({}));
        let mut buf = Vec::new();
        write_frame_sync(&mut buf, &frame).unwrap();

        let (header, body) = buf.split_at(4);
        assert_eq!(BigEndian::read_u32(header) as usize, body.len());
        assert_eq!(body, serde_json::to_vec(&frame).unwrap().as_slice());
    }

    #[test]
    fn zero_length_frame_rejected() {
        let mut cur = Cursor::new(raw_frame(0, &[]));
        assert!(read_frame_sync(&mut cur).is_err());
    }

    #[test]
    fn truncated_body_rejected() {
        // The header promises 10 bytes, but only 2 follow.
        let mut cur = Cursor::new(raw_frame(10, b"{}"));
        assert!(read_frame_sync(&mut cur).is_err());
    }

    #[test]
    fn frame_too_large_rejected_on_write() {
        let big = "x".repeat(MAX_FRAME_BYTES + 1);
        let frame = Frame::request(1, "ping", serde_json::json!({"big": big}));
        let mut buf = Vec::new();
        let err = write_frame_sync(&mut buf, &frame).unwrap_err();
        // `matches!` only yields a bool; it must be asserted to actually test anything.
        assert!(
            matches!(
                err,
                DaemonError::FrameTooLarge { size, max }
                    if size > MAX_FRAME_BYTES && max == MAX_FRAME_BYTES
            ),
            "unexpected error: {:?}",
            err
        );
        assert!(buf.is_empty(), "nothing should be written for an oversized frame");
    }

    #[test]
    fn frame_too_large_rejected_on_read() {
        let bogus_len = (MAX_FRAME_BYTES + 1) as u32;
        let mut cur = Cursor::new(raw_frame(bogus_len, &[]));
        let err = read_frame_sync(&mut cur).unwrap_err();
        assert!(
            matches!(
                err,
                DaemonError::FrameTooLarge { size, max }
                    if size == MAX_FRAME_BYTES + 1 && max == MAX_FRAME_BYTES
            ),
            "unexpected error: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn roundtrip_async() {
        let frame = Frame::request(7, "info", serde_json::json!({}));
        let (mut a, mut b) = tokio::io::duplex(64 * 1024);

        let writer = tokio::spawn(async move {
            write_frame(&mut a, &frame).await.unwrap();
        });
        let read = read_frame(&mut b).await.unwrap();
        // Surface any panic from the writer task instead of silently detaching it.
        writer.await.expect("writer task panicked");

        match read {
            Frame::Request { id, op, .. } => {
                assert_eq!(id, 7);
                assert_eq!(op, "info");
            }
            other => panic!("expected Request, got {:?}", other),
        }
    }
}