Skip to main content

muxtop_proto/
wire.rs

1use bincode::{config, decode_from_slice, encode_to_vec};
2use serde::{Deserialize, Serialize};
3
4use muxtop_core::system::SystemSnapshot;
5
6use crate::ProtoError;
7use crate::frame::{
8    Frame, MAX_FRAME_SIZE, MSG_ERROR, MSG_HEARTBEAT, MSG_HELLO, MSG_SNAPSHOT, MSG_WELCOME,
9};
10
11/// Wire protocol messages exchanged between muxtop client and server.
12///
13/// Uses a custom `Debug` impl to redact `auth_token` in `Hello` messages,
14/// preventing accidental token leakage in logs or panic messages.
15#[derive(Clone, PartialEq, Serialize, Deserialize)]
16pub enum WireMessage {
17    /// Full system snapshot (server → client).
18    Snapshot(SystemSnapshot),
19
20    /// Keepalive heartbeat (server → client).
21    Heartbeat {
22        server_version: String,
23        uptime_secs: u64,
24    },
25
26    /// Error message (server → client).
27    Error { code: u16, message: String },
28
29    /// Client handshake (client → server).
30    Hello {
31        client_version: String,
32        auth_token: Option<String>,
33    },
34
35    /// Server handshake response (server → client).
36    Welcome {
37        server_version: String,
38        hostname: String,
39        refresh_hz: u32,
40    },
41}
42
43impl std::fmt::Debug for WireMessage {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            WireMessage::Snapshot(s) => f.debug_tuple("Snapshot").field(s).finish(),
47            WireMessage::Heartbeat {
48                server_version,
49                uptime_secs,
50            } => f
51                .debug_struct("Heartbeat")
52                .field("server_version", server_version)
53                .field("uptime_secs", uptime_secs)
54                .finish(),
55            WireMessage::Error { code, message } => f
56                .debug_struct("Error")
57                .field("code", code)
58                .field("message", message)
59                .finish(),
60            WireMessage::Hello {
61                client_version,
62                auth_token,
63            } => f
64                .debug_struct("Hello")
65                .field("client_version", client_version)
66                .field("auth_token", &auth_token.as_ref().map(|_| "[REDACTED]"))
67                .finish(),
68            WireMessage::Welcome {
69                server_version,
70                hostname,
71                refresh_hz,
72            } => f
73                .debug_struct("Welcome")
74                .field("server_version", server_version)
75                .field("hostname", hostname)
76                .field("refresh_hz", refresh_hz)
77                .finish(),
78        }
79    }
80}
81
82/// Bincode configuration shared by encode/decode paths.
83///
84/// `with_limit::<MAX_DECODE_BYTES>()` caps the total bytes the decoder is
85/// willing to allocate while reading a single value, regardless of the
86/// var-int length-prefixes embedded inside the payload (MED-S1, proto-side).
87/// Without this cap a malicious peer can claim a huge collection or `String`
88/// length and force the decoder to pre-allocate hundreds of MiB before the
89/// underlying buffer is exhausted.
90const MAX_DECODE_BYTES: usize = MAX_FRAME_SIZE as usize;
91
92fn bincode_config() -> impl bincode::config::Config {
93    config::standard().with_limit::<MAX_DECODE_BYTES>()
94}
95
96impl WireMessage {
97    /// Encode a `SystemSnapshot` into a `Snapshot` frame **without taking
98    /// ownership** of the snapshot.
99    ///
100    /// Per ADR-30-4 the relay path holds an `Arc<SystemSnapshot>` and needs to
101    /// produce a single encoded `Frame` that it can then broadcast as bytes to
102    /// every client task. Calling `to_frame()` would require constructing a
103    /// `WireMessage::Snapshot` variant that owns the `SystemSnapshot`, which
104    /// defeats the whole point of the `Arc`. This helper bypasses the wrapper
105    /// and encodes directly from a borrow.
106    pub fn encode_snapshot_ref(snap: &SystemSnapshot) -> Result<Frame, ProtoError> {
107        Ok(Frame {
108            msg_type: MSG_SNAPSHOT,
109            payload: encode_to_vec(snap, bincode_config())?,
110        })
111    }
112
113    /// Serialize this message into a [`Frame`].
114    pub fn to_frame(&self) -> Result<Frame, ProtoError> {
115        let (msg_type, payload) = match self {
116            WireMessage::Snapshot(snap) => (MSG_SNAPSHOT, encode_to_vec(snap, bincode_config())?),
117            WireMessage::Heartbeat {
118                server_version,
119                uptime_secs,
120            } => (
121                MSG_HEARTBEAT,
122                encode_to_vec((server_version, uptime_secs), bincode_config())?,
123            ),
124            WireMessage::Error { code, message } => {
125                (MSG_ERROR, encode_to_vec((code, message), bincode_config())?)
126            }
127            WireMessage::Hello {
128                client_version,
129                auth_token,
130            } => (
131                MSG_HELLO,
132                encode_to_vec((client_version, auth_token), bincode_config())?,
133            ),
134            WireMessage::Welcome {
135                server_version,
136                hostname,
137                refresh_hz,
138            } => (
139                MSG_WELCOME,
140                encode_to_vec((server_version, hostname, refresh_hz), bincode_config())?,
141            ),
142        };
143
144        Ok(Frame { msg_type, payload })
145    }
146
147    /// Deserialize a [`Frame`] into a `WireMessage`.
148    pub fn from_frame(frame: &Frame) -> Result<Self, ProtoError> {
149        match frame.msg_type {
150            MSG_SNAPSHOT => {
151                let (snap, _): (SystemSnapshot, _) =
152                    decode_from_slice(&frame.payload, bincode_config())?;
153                Ok(WireMessage::Snapshot(snap))
154            }
155            MSG_HEARTBEAT => {
156                let ((server_version, uptime_secs), _): ((String, u64), _) =
157                    decode_from_slice(&frame.payload, bincode_config())?;
158                Ok(WireMessage::Heartbeat {
159                    server_version,
160                    uptime_secs,
161                })
162            }
163            MSG_ERROR => {
164                let ((code, message), _): ((u16, String), _) =
165                    decode_from_slice(&frame.payload, bincode_config())?;
166                Ok(WireMessage::Error { code, message })
167            }
168            MSG_HELLO => {
169                let ((client_version, auth_token), _): ((String, Option<String>), _) =
170                    decode_from_slice(&frame.payload, bincode_config())?;
171                Ok(WireMessage::Hello {
172                    client_version,
173                    auth_token,
174                })
175            }
176            MSG_WELCOME => {
177                let ((server_version, hostname, refresh_hz), _): ((String, String, u32), _) =
178                    decode_from_slice(&frame.payload, bincode_config())?;
179                Ok(WireMessage::Welcome {
180                    server_version,
181                    hostname,
182                    refresh_hz,
183                })
184            }
185            other => Err(ProtoError::UnknownMessageType(other)),
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use muxtop_core::network::{NetworkInterfaceSnapshot, NetworkSnapshot};
    use muxtop_core::process::ProcessInfo;
    use muxtop_core::system::{CoreSnapshot, CpuSnapshot, LoadSnapshot, MemorySnapshot};

    /// Build a small but fully-populated snapshot for round-trip tests.
    fn make_test_snapshot() -> SystemSnapshot {
        SystemSnapshot {
            cpu: CpuSnapshot {
                global_usage: 45.2,
                cores: vec![CoreSnapshot {
                    name: "cpu0".into(),
                    usage: 45.2,
                    frequency: 3600,
                }],
            },
            memory: MemorySnapshot {
                total: 16_000_000_000,
                used: 8_000_000_000,
                available: 8_000_000_000,
                swap_total: 4_000_000_000,
                swap_used: 1_000_000_000,
            },
            load: LoadSnapshot {
                one: 1.5,
                five: 1.2,
                fifteen: 0.8,
                uptime_secs: 3600,
            },
            processes: vec![ProcessInfo {
                pid: 1,
                parent_pid: None,
                name: "init".into(),
                command: "/sbin/init".into(),
                user: "root".into(),
                cpu_percent: 0.1,
                memory_bytes: 4096,
                memory_percent: 0.01,
                status: "Running".into(),
            }],
            networks: NetworkSnapshot {
                interfaces: vec![NetworkInterfaceSnapshot {
                    name: "lo".into(),
                    bytes_rx: 1000,
                    bytes_tx: 1000,
                    packets_rx: 10,
                    packets_tx: 10,
                    errors_rx: 0,
                    errors_tx: 0,
                    mac_address: "00:00:00:00:00:00".into(),
                    is_up: true,
                }],
                total_rx: 1000,
                total_tx: 1000,
            },
            containers: None,
            timestamp_ms: 1_713_200_000_000,
        }
    }

    #[test]
    fn test_wire_snapshot_roundtrip() {
        let msg = WireMessage::Snapshot(make_test_snapshot());
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_SNAPSHOT);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_encode_snapshot_ref_matches_to_frame() {
        // PERF-L1 / ADR-30-4: the borrow-only encode helper must produce
        // the same wire bytes as the owning path so the server can hand
        // out a single pre-encoded frame to every connected client.
        let snap = make_test_snapshot();
        let owning_frame = WireMessage::Snapshot(snap.clone()).to_frame().unwrap();
        let borrow_frame = WireMessage::encode_snapshot_ref(&snap).unwrap();
        assert_eq!(owning_frame.msg_type, borrow_frame.msg_type);
        assert_eq!(owning_frame.payload, borrow_frame.payload);
        // The decoded value must also be identical to the original snapshot.
        let decoded = WireMessage::from_frame(&borrow_frame).unwrap();
        assert_eq!(decoded, WireMessage::Snapshot(snap));
    }

    #[test]
    fn test_wire_heartbeat_roundtrip() {
        let msg = WireMessage::Heartbeat {
            server_version: "0.2.0".into(),
            uptime_secs: 86400,
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_HEARTBEAT);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_error_roundtrip() {
        let msg = WireMessage::Error {
            code: 503,
            message: "max clients reached".into(),
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_ERROR);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_hello_roundtrip() {
        let msg = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("secret-token".into()),
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_HELLO);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_hello_no_token_roundtrip() {
        let msg = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: None,
        };
        let frame = msg.to_frame().unwrap();
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_welcome_roundtrip() {
        let msg = WireMessage::Welcome {
            server_version: "0.2.0".into(),
            hostname: "prod-server-01".into(),
            refresh_hz: 1,
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_WELCOME);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_unknown_message_type() {
        let frame = Frame {
            msg_type: 0xFF,
            payload: vec![1, 2, 3],
        };
        let err = WireMessage::from_frame(&frame).unwrap_err();
        assert!(matches!(err, ProtoError::UnknownMessageType(0xFF)));
    }

    #[test]
    fn test_empty_payload_is_decode_error() {
        // Every known message type needs at least one byte of payload, so an
        // empty payload must surface as a decode error rather than a panic.
        for &msg_type in &[MSG_SNAPSHOT, MSG_HEARTBEAT, MSG_ERROR, MSG_HELLO, MSG_WELCOME] {
            let frame = Frame {
                msg_type,
                payload: Vec::new(),
            };
            let err = WireMessage::from_frame(&frame).expect_err("empty payload must not decode");
            assert!(
                matches!(err, ProtoError::Decode(_)),
                "msg_type {msg_type:?}: expected Decode error, got {err:?}"
            );
        }
    }

    #[test]
    fn test_truncated_payload_is_decode_error() {
        // Dropping the final byte (the one-byte var-int `refresh_hz`) leaves
        // the tuple incomplete; the decoder must report it, not invent a value.
        let mut frame = WireMessage::Welcome {
            server_version: "0.2.0".into(),
            hostname: "prod-server-01".into(),
            refresh_hz: 1,
        }
        .to_frame()
        .unwrap();
        frame.payload.pop();
        let err = WireMessage::from_frame(&frame).expect_err("truncated payload must not decode");
        assert!(
            matches!(err, ProtoError::Decode(_)),
            "expected Decode error, got {err:?}"
        );
    }

    #[test]
    fn test_decode_limit_rejects_giant_string_claim() {
        // MED-S1 (proto-side): a hand-crafted Hello payload that *claims* a
        // 100 MiB `client_version` String must fail to decode without
        // allocating that much memory. We bypass `to_frame()` entirely and
        // build the bincode payload by hand.
        //
        // Wire layout for `Hello { client_version, auth_token }` =
        //   tuple `(String, Option<String>)` (no tuple-length prefix in
        //   bincode 2). The first bytes are therefore the String length
        //   var-int directly.
        //
        // bincode 2 var-int format (little-endian, default `standard()`):
        //   - 0..=250    : single byte
        //   - 0xFB       : marker for u16, followed by 2 LE bytes
        //   - 0xFC       : marker for u32, followed by 4 LE bytes
        //   - 0xFD       : marker for u64, followed by 8 LE bytes
        //   - 0xFE       : marker for u128, followed by 16 LE bytes
        // 100 MiB fits in a u32 → 0xFC + 4 LE bytes.
        let claimed_len: u32 = 100 * 1024 * 1024;
        let mut payload = Vec::new();
        payload.push(0xFC);
        payload.extend_from_slice(&claimed_len.to_le_bytes());
        // The length prefix is deliberately followed by no string bytes. The
        // decoder must refuse *before* trying to read 100 MiB of UTF-8.

        let frame = Frame {
            msg_type: MSG_HELLO,
            payload,
        };
        let err =
            WireMessage::from_frame(&frame).expect_err("decoder must reject 100 MiB length claim");
        // The `with_limit` config raises `LimitExceeded`, which we surface as
        // `ProtoError::Decode`. Either way there is no panic and no 100 MiB
        // allocation.
        assert!(
            matches!(err, ProtoError::Decode(_)),
            "expected Decode error, got {err:?}"
        );
    }

    #[test]
    fn test_debug_redacts_auth_token() {
        // The hand-written `Debug` impl exists to keep the token out of logs
        // and panic messages; pin that behaviour for both `{:?}` and `{:#?}`.
        let msg = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("super-secret-token".into()),
        };
        for rendered in [format!("{msg:?}"), format!("{msg:#?}")] {
            assert!(
                !rendered.contains("super-secret-token"),
                "token leaked into Debug output: {rendered}"
            );
            assert!(
                rendered.contains("[REDACTED]"),
                "missing redaction marker: {rendered}"
            );
            assert!(
                rendered.contains("0.2.0"),
                "client_version should stay visible: {rendered}"
            );
        }

        let anonymous = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: None,
        };
        let rendered = format!("{anonymous:?}");
        assert!(rendered.contains("None"), "absent token should render as None: {rendered}");
        assert!(
            !rendered.contains("[REDACTED]"),
            "absent token must not look redacted: {rendered}"
        );
    }

    #[test]
    fn test_hello_token_validation() {
        let expected_token = "correct-token";
        // `let … else` makes the test fail loudly on a non-`Hello` value
        // instead of silently skipping the assertion.
        let token_is_valid = |msg: &WireMessage| {
            let WireMessage::Hello { auth_token, .. } = msg else {
                panic!("expected a Hello message, got {msg:?}");
            };
            auth_token.as_deref().is_some_and(|t| t == expected_token)
        };

        let hello = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("wrong-token".into()),
        };
        assert!(!token_is_valid(&hello), "wrong token should not validate");

        let hello_missing = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: None,
        };
        assert!(!token_is_valid(&hello_missing), "missing token should not validate");

        let hello_correct = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("correct-token".into()),
        };
        assert!(token_is_valid(&hello_correct), "correct token should validate");
    }
}