Skip to main content

moire_wire/
lib.rs

1use facet::Facet;
2pub use moire_trace_types::{
3    BacktraceRecord, FrameKey as BacktraceFrameKey, ModuleId, RelPc, RuntimeBase,
4};
5use moire_types::{CutAck, CutRequest, ProcessId, PullChangesResponse, Snapshot};
6use std::fmt;
7
/// Largest payload, in bytes, accepted by the `*_default` encode/decode helpers (128 MiB).
pub const DEFAULT_MAX_FRAME_BYTES: usize = 128 << 20;
/// Four-byte protocol magic `0x4D4F4952`, i.e. the ASCII bytes `MOIR` read big-endian.
pub const PROTOCOL_MAGIC: u32 = u32::from_be_bytes(*b"MOIR");
10
/// Errors produced while encoding or decoding a length-prefixed frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameCodecError {
    /// The payload given to the encoder is longer than allowed: either the
    /// caller's limit or `u32::MAX`, the largest length the prefix can hold.
    PayloadTooLarge { len: usize, max: usize },
    /// The frame is shorter than the 4-byte length prefix.
    FrameTooShort { len: usize },
    /// The length declared in the frame's prefix exceeds the allowed maximum.
    FrameTooLarge { len: usize, max: usize },
    /// The number of payload bytes present differs from the declared length.
    /// Despite the name, this is also used when the frame carries *more*
    /// bytes than declared.
    FrameTruncated { expected: usize, actual: usize },
}

impl fmt::Display for FrameCodecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::PayloadTooLarge { len, max } => {
                write!(f, "payload too large: {len} > {max}")
            }
            Self::FrameTooShort { len } => write!(f, "frame too short: {len}"),
            Self::FrameTooLarge { len, max } => write!(f, "frame too large: {len} > {max}"),
            // `FrameTruncated` covers any length mismatch; do not call an
            // overlong frame "truncated", which would mislead whoever debugs it.
            Self::FrameTruncated { expected, actual } if actual > expected => {
                write!(
                    f,
                    "frame payload has trailing bytes: expected {expected}, got {actual}"
                )
            }
            Self::FrameTruncated { expected, actual } => {
                write!(
                    f,
                    "truncated frame payload: expected {expected}, got {actual}"
                )
            }
        }
    }
}

impl std::error::Error for FrameCodecError {}
38
39#[derive(Debug)]
40pub enum WireError {
41    Frame(FrameCodecError),
42    Json(String),
43    MagicMismatch { expected: u32, actual: u32 },
44}
45
46impl fmt::Display for WireError {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            Self::Frame(err) => write!(f, "{err}"),
50            Self::Json(err) => write!(f, "{err}"),
51            Self::MagicMismatch { expected, actual } => {
52                write!(
53                    f,
54                    "protocol magic mismatch: expected 0x{expected:08x}, got 0x{actual:08x}"
55                )
56            }
57        }
58    }
59}
60
61impl std::error::Error for WireError {}
62
63impl From<FrameCodecError> for WireError {
64    fn from(value: FrameCodecError) -> Self {
65        Self::Frame(value)
66    }
67}
68
69// r[impl wire.framing]
70pub fn encode_frame(payload: &[u8], max_payload_bytes: usize) -> Result<Vec<u8>, FrameCodecError> {
71    if payload.len() > max_payload_bytes {
72        return Err(FrameCodecError::PayloadTooLarge {
73            len: payload.len(),
74            max: max_payload_bytes,
75        });
76    }
77
78    let payload_len =
79        u32::try_from(payload.len()).map_err(|_| FrameCodecError::PayloadTooLarge {
80            len: payload.len(),
81            max: u32::MAX as usize,
82        })?;
83
84    let mut out = Vec::with_capacity(4 + payload.len());
85    out.extend_from_slice(&payload_len.to_be_bytes());
86    out.extend_from_slice(payload);
87    Ok(out)
88}
89
90pub fn encode_frame_default(payload: &[u8]) -> Result<Vec<u8>, FrameCodecError> {
91    encode_frame(payload, DEFAULT_MAX_FRAME_BYTES)
92}
93
94pub fn decode_frame(frame: &[u8], max_payload_bytes: usize) -> Result<&[u8], FrameCodecError> {
95    if frame.len() < 4 {
96        return Err(FrameCodecError::FrameTooShort { len: frame.len() });
97    }
98
99    let mut prefix = [0u8; 4];
100    prefix.copy_from_slice(&frame[..4]);
101    let payload_len = u32::from_be_bytes(prefix) as usize;
102    if payload_len > max_payload_bytes {
103        return Err(FrameCodecError::FrameTooLarge {
104            len: payload_len,
105            max: max_payload_bytes,
106        });
107    }
108
109    let actual_payload_len = frame.len() - 4;
110    if actual_payload_len != payload_len {
111        return Err(FrameCodecError::FrameTruncated {
112            expected: payload_len,
113            actual: actual_payload_len,
114        });
115    }
116
117    Ok(&frame[4..])
118}
119
120pub fn decode_frame_default(frame: &[u8]) -> Result<&[u8], FrameCodecError> {
121    decode_frame(frame, DEFAULT_MAX_FRAME_BYTES)
122}
123
/// Identity string attached to a module in the handshake manifest, tagged by
/// kind. Variant names are renamed to snake_case when serialized
/// (`build_id` / `debug_id`).
#[derive(Facet, Clone)]
#[repr(u8)]
#[facet(rename_all = "snake_case")]
pub enum ModuleIdentity {
    /// Identity given as a build ID (string format is not validated here).
    BuildId(String),
    /// Identity given as a debug ID (string format is not validated here).
    DebugId(String),
}
131
#[derive(Facet, Clone)]
// r[impl wire.handshake.module-manifest]
/// One module listed in a [`Handshake`]'s `module_manifest`.
///
/// Field order is the JSON field order on the wire.
pub struct ModuleManifestEntry {
    /// Id of this module; presumably the same id that backtrace frame keys
    /// carry in their `module_id` (TODO confirm).
    pub module_id: ModuleId,
    /// Path of the module file as reported by the client (e.g. a shared-library path).
    pub module_path: String,
    /// Runtime base of the module, serialized as a JSON number; presumably the
    /// load address used to interpret relative PCs — confirm.
    pub runtime_base: RuntimeBase,
    /// Build-ID or debug-ID identity of the module.
    pub identity: ModuleIdentity,
    /// Architecture name string, e.g. `"aarch64"`.
    pub arch: String,
}
141
/// Client message identifying the process and listing its modules, sent as
/// [`ClientMessage::Handshake`]; presumably sent once at connection start
/// (confirm against the protocol spec).
///
/// Field order is the JSON field order on the wire.
#[derive(Facet)]
pub struct Handshake {
    /// Process identifier of type `ProcessId`, distinct from the OS `pid` below;
    /// serialized as a string.
    pub process_id: ProcessId,
    /// Human-readable process name.
    pub process_name: String,
    /// Operating-system process id (presumably; not validated here).
    pub pid: u32,
    /// Command-line arguments, presumably with the program path first.
    pub args: Vec<String>,
    /// Environment entries, presumably `KEY=VALUE` strings.
    pub env: Vec<String>,
    /// Modules loaded in the process (presumably as of the handshake).
    pub module_manifest: Vec<ModuleManifestEntry>,
}
151
152// r[impl wire.magic]
153pub fn encode_protocol_magic() -> [u8; 4] {
154    PROTOCOL_MAGIC.to_be_bytes()
155}
156
157pub fn decode_protocol_magic(bytes: [u8; 4]) -> Result<(), WireError> {
158    let actual = u32::from_be_bytes(bytes);
159    if actual != PROTOCOL_MAGIC {
160        return Err(WireError::MagicMismatch {
161            expected: PROTOCOL_MAGIC,
162            actual,
163        });
164    }
165    Ok(())
166}
167
/// Server request asking the client for a snapshot, sent as
/// [`ServerMessage::SnapshotRequest`].
#[derive(Facet)]
pub struct SnapshotRequest {
    /// Id of this request; presumably echoed back in `SnapshotReply::snapshot_id`.
    pub snapshot_id: i64,
    /// Reply timeout in milliseconds (per the field name); handling of
    /// non-positive values is not defined in this crate.
    pub timeout_ms: i64,
}
173
/// Client reply carrying a snapshot, sent as [`ClientMessage::SnapshotReply`];
/// presumably answers a [`SnapshotRequest`] with the same `snapshot_id`.
#[derive(Facet)]
pub struct SnapshotReply {
    /// Id of the request being answered (presumably copied from the request).
    pub snapshot_id: i64,
    /// Process-relative milliseconds at the moment the process assembled this snapshot.
    pub ptime_now_ms: u64,
    /// The snapshot itself. `skip_unless_truthy` presumably omits the field
    /// from the JSON when it is `None` — confirm against facet docs.
    #[facet(skip_unless_truthy)]
    pub snapshot: Option<Snapshot>,
}
182
/// Error reported by a client, sent as [`ClientMessage::Error`].
#[derive(Facet)]
pub struct ClientError {
    /// Name of the reporting process.
    pub process_name: String,
    /// Process id of the reporting process.
    pub pid: u32,
    /// Free-form label for where the failure happened (not validated here).
    pub stage: String,
    /// Human-readable error description.
    pub error: String,
    /// Optional UTF-8 text of the last frame; presumably the frame that
    /// triggered the error (confirm). Skipped in JSON unless truthy.
    #[facet(skip_unless_truthy)]
    pub last_frame_utf8: Option<String>,
}
192
// r[impl wire.client-message]
/// Every message a client may send. Encoded as externally tagged JSON with
/// snake_case variant names, e.g. `{"snapshot_reply":{...}}`.
///
/// Variant order is part of the `#[repr(u8)]` layout; do not reorder.
#[derive(Facet)]
#[repr(u8)]
#[facet(rename_all = "snake_case")]
pub enum ClientMessage {
    /// Process identification and module manifest.
    Handshake(Handshake),
    // r[impl wire.backtrace-record]
    /// A backtrace whose frames are `module_id` + `rel_pc` keys.
    BacktraceRecord(BacktraceRecord),
    /// Reply to a [`ServerMessage::SnapshotRequest`].
    SnapshotReply(SnapshotReply),
    /// Batch of changes, carried as a `PullChangesResponse`.
    DeltaBatch(PullChangesResponse),
    /// Acknowledgement presumably answering a [`ServerMessage::CutRequest`].
    CutAck(CutAck),
    /// Error reported by the client.
    Error(ClientError),
}
206
// r[impl wire.server-message]
/// Every message the server may send. Encoded as externally tagged JSON with
/// snake_case variant names, e.g. `{"cut_request":{...}}`.
///
/// Variant order is part of the `#[repr(u8)]` layout; do not reorder.
#[derive(Facet)]
#[repr(u8)]
#[facet(rename_all = "snake_case")]
pub enum ServerMessage {
    /// Ask the client for a snapshot.
    SnapshotRequest(SnapshotRequest),
    /// Cut request; presumably answered by [`ClientMessage::CutAck`].
    CutRequest(CutRequest),
}
215
216pub fn encode_client_message(
217    message: &ClientMessage,
218    max_payload_bytes: usize,
219) -> Result<Vec<u8>, WireError> {
220    let payload = facet_json::to_vec(message).map_err(|e| WireError::Json(e.to_string()))?;
221    Ok(encode_frame(&payload, max_payload_bytes)?)
222}
223
224pub fn encode_client_message_default(message: &ClientMessage) -> Result<Vec<u8>, WireError> {
225    encode_client_message(message, DEFAULT_MAX_FRAME_BYTES)
226}
227
228pub fn decode_client_message(
229    frame: &[u8],
230    max_payload_bytes: usize,
231) -> Result<ClientMessage, WireError> {
232    let payload = decode_frame(frame, max_payload_bytes)?;
233    facet_json::from_slice(payload).map_err(|e| WireError::Json(e.to_string()))
234}
235
236pub fn decode_client_message_default(frame: &[u8]) -> Result<ClientMessage, WireError> {
237    decode_client_message(frame, DEFAULT_MAX_FRAME_BYTES)
238}
239
240pub fn encode_server_message(
241    message: &ServerMessage,
242    max_payload_bytes: usize,
243) -> Result<Vec<u8>, WireError> {
244    let payload = facet_json::to_vec(message).map_err(|e| WireError::Json(e.to_string()))?;
245    Ok(encode_frame(&payload, max_payload_bytes)?)
246}
247
248pub fn encode_server_message_default(message: &ServerMessage) -> Result<Vec<u8>, WireError> {
249    encode_server_message(message, DEFAULT_MAX_FRAME_BYTES)
250}
251
252pub fn decode_server_message(
253    frame: &[u8],
254    max_payload_bytes: usize,
255) -> Result<ServerMessage, WireError> {
256    let payload = decode_frame(frame, max_payload_bytes)?;
257    facet_json::from_slice(payload).map_err(|e| WireError::Json(e.to_string()))
258}
259
260pub fn decode_server_message_default(frame: &[u8]) -> Result<ServerMessage, WireError> {
261    decode_server_message(frame, DEFAULT_MAX_FRAME_BYTES)
262}
263
#[cfg(test)]
mod tests {
    //! Wire-shape tests pin the exact JSON layout; the framing and error-path
    //! tests pin each `FrameCodecError` / `WireError` case.
    use super::*;
    use moire_trace_types::{BacktraceId, ModuleId};
    use moire_types::{CutId, ProcessId, SeqNo, Snapshot, StreamCursor, StreamId};

    fn client_payload_json(message: &ClientMessage) -> String {
        let frame = encode_client_message_default(message).expect("client frame should encode");
        let payload = decode_frame_default(&frame).expect("frame should decode");
        std::str::from_utf8(payload)
            .expect("payload should be utf8 json")
            .to_string()
    }

    fn server_payload_json(message: &ServerMessage) -> String {
        let frame = encode_server_message_default(message).expect("server frame should encode");
        let payload = decode_frame_default(&frame).expect("frame should decode");
        std::str::from_utf8(payload)
            .expect("payload should be utf8 json")
            .to_string()
    }

    #[test]
    fn client_handshake_wire_shape() {
        let module_id = ModuleId::next().expect("valid module id");
        let json = client_payload_json(&ClientMessage::Handshake(Handshake {
            process_id: ProcessId::new("0011223344556677"),
            process_name: "vixenfs-swift".into(),
            pid: 42,
            args: vec!["/usr/bin/vixenfs-swift".into(), "--verbose".into()],
            env: vec!["RUST_LOG=debug".into(), "HOME=/Users/dev".into()],
            module_manifest: vec![ModuleManifestEntry {
                module_id,
                module_path: "/usr/lib/libvixenfs_swift.dylib".into(),
                runtime_base: RuntimeBase::new(4_294_967_296).expect("valid runtime_base"),
                identity: ModuleIdentity::DebugId("debugid:def456".into()),
                arch: "aarch64".into(),
            }],
        }));
        assert!(
            json.contains(
                r#""handshake":{"process_id":"0011223344556677","process_name":"vixenfs-swift","pid":42"#
            )
        );
        assert!(json.contains(r#""module_id":"#));
        assert!(json.contains(r#""module_path":"/usr/lib/libvixenfs_swift.dylib""#));
        assert!(json.contains(r#""runtime_base":4294967296"#));
    }

    #[test]
    fn protocol_magic_roundtrip() {
        let bytes = encode_protocol_magic();
        decode_protocol_magic(bytes).expect("protocol magic should decode");
    }

    #[test]
    fn protocol_magic_mismatch_is_rejected() {
        match decode_protocol_magic(*b"NOPE") {
            Err(WireError::MagicMismatch { expected, actual }) => {
                assert_eq!(expected, PROTOCOL_MAGIC);
                assert_eq!(actual, u32::from_be_bytes(*b"NOPE"));
            }
            other => panic!("expected a magic mismatch, got {other:?}"),
        }
    }

    #[test]
    fn frame_roundtrip() {
        let frame = encode_frame(b"hello", 16).expect("frame should encode");
        assert_eq!(&frame[..4], &5u32.to_be_bytes()[..]);
        assert_eq!(&frame[4..], &b"hello"[..]);
        assert_eq!(
            decode_frame(&frame, 16).expect("frame should decode"),
            &b"hello"[..]
        );
    }

    #[test]
    fn empty_payload_frame_roundtrip() {
        let frame = encode_frame(b"", 0).expect("empty frame should encode");
        assert_eq!(frame, vec![0u8, 0, 0, 0]);
        assert_eq!(
            decode_frame(&frame, 0).expect("empty frame should decode"),
            &b""[..]
        );
    }

    #[test]
    fn encode_frame_rejects_payload_over_limit() {
        assert_eq!(
            encode_frame(b"hello", 4),
            Err(FrameCodecError::PayloadTooLarge { len: 5, max: 4 })
        );
    }

    #[test]
    fn decode_frame_rejects_missing_prefix() {
        assert_eq!(
            decode_frame(&[0u8, 0, 1], 16),
            Err(FrameCodecError::FrameTooShort { len: 3 })
        );
    }

    #[test]
    fn decode_frame_rejects_declared_length_over_limit() {
        let frame = 17u32.to_be_bytes();
        assert_eq!(
            decode_frame(&frame, 16),
            Err(FrameCodecError::FrameTooLarge { len: 17, max: 16 })
        );
    }

    #[test]
    fn decode_frame_rejects_length_mismatch() {
        // Fewer bytes than declared.
        assert_eq!(
            decode_frame(&[0u8, 0, 0, 3, b'a'], 16),
            Err(FrameCodecError::FrameTruncated { expected: 3, actual: 1 })
        );
        // More bytes than declared is reported with the same variant.
        assert_eq!(
            decode_frame(&[0u8, 0, 0, 1, b'a', b'b'], 16),
            Err(FrameCodecError::FrameTruncated { expected: 1, actual: 2 })
        );
    }

    #[test]
    fn message_decode_reports_invalid_json() {
        let frame = encode_frame_default(b"not json").expect("frame should encode");
        assert!(matches!(
            decode_server_message_default(&frame),
            Err(WireError::Json(_))
        ));
    }

    #[test]
    fn message_encode_enforces_frame_limit() {
        let result = encode_server_message(
            &ServerMessage::CutRequest(moire_types::CutRequest {
                cut_id: CutId::new("cut-1"),
            }),
            4,
        );
        assert!(matches!(
            result,
            Err(WireError::Frame(FrameCodecError::PayloadTooLarge { max: 4, .. }))
        ));
    }

    #[test]
    fn server_message_roundtrip() {
        let frame = encode_server_message_default(&ServerMessage::SnapshotRequest(
            SnapshotRequest {
                snapshot_id: 9,
                timeout_ms: 250,
            },
        ))
        .expect("server frame should encode");
        match decode_server_message_default(&frame).expect("server frame should decode") {
            ServerMessage::SnapshotRequest(request) => {
                assert_eq!(request.snapshot_id, 9);
                assert_eq!(request.timeout_ms, 250);
            }
            ServerMessage::CutRequest(_) => panic!("decoded the wrong variant"),
        }
    }

    #[test]
    fn client_snapshot_reply_wire_shape() {
        let json = client_payload_json(&ClientMessage::SnapshotReply(SnapshotReply {
            snapshot_id: 7,
            ptime_now_ms: 1234,
            snapshot: Some(Snapshot {
                entities: vec![],
                scopes: vec![],
                edges: vec![],
                events: vec![],
            }),
        }));
        assert_eq!(
            json,
            r#"{"snapshot_reply":{"snapshot_id":7,"ptime_now_ms":1234,"snapshot":{"entities":[],"scopes":[],"edges":[],"events":[]}}}"#
        );
    }

    #[test]
    fn client_backtrace_record_wire_shape() {
        let backtrace_id = BacktraceId::next().expect("valid backtrace id");
        let module_a = ModuleId::next().expect("valid module id");
        let module_b = ModuleId::next().expect("valid module id");
        let json = client_payload_json(&ClientMessage::BacktraceRecord(BacktraceRecord {
            id: backtrace_id,
            frames: vec![
                BacktraceFrameKey {
                    module_id: module_a,
                    rel_pc: RelPc::new(4096).expect("valid rel_pc"),
                },
                BacktraceFrameKey {
                    module_id: module_b,
                    rel_pc: RelPc::new(8192).expect("valid rel_pc"),
                },
            ],
        }));
        assert!(json.contains(r#""backtrace_record":{"id":"#));
        assert!(json.contains(r#""rel_pc":4096"#));
        assert!(json.contains(r#""rel_pc":8192"#));
    }

    #[test]
    fn client_cut_ack_wire_shape() {
        let json = client_payload_json(&ClientMessage::CutAck(moire_types::CutAck {
            cut_id: CutId::new("cut-1"),
            cursor: StreamCursor {
                stream_id: StreamId("vixenfs-swift-42".into()),
                next_seq_no: SeqNo(0),
            },
        }));
        assert_eq!(
            json,
            r#"{"cut_ack":{"cut_id":"cut-1","cursor":{"stream_id":"vixenfs-swift-42","next_seq_no":0}}}"#
        );
    }

    #[test]
    fn server_snapshot_request_wire_shape() {
        let json = server_payload_json(&ServerMessage::SnapshotRequest(SnapshotRequest {
            snapshot_id: 7,
            timeout_ms: 5000,
        }));
        assert_eq!(
            json,
            r#"{"snapshot_request":{"snapshot_id":7,"timeout_ms":5000}}"#
        );
    }

    #[test]
    fn server_cut_request_wire_shape() {
        let json = server_payload_json(&ServerMessage::CutRequest(moire_types::CutRequest {
            cut_id: CutId::new("cut-1"),
        }));
        assert_eq!(json, r#"{"cut_request":{"cut_id":"cut-1"}}"#);
    }
}