1use bincode::{config, decode_from_slice, encode_to_vec};
2use serde::{Deserialize, Serialize};
3
4use muxtop_core::system::SystemSnapshot;
5
6use crate::ProtoError;
7use crate::frame::{
8 Frame, MAX_FRAME_SIZE, MSG_ERROR, MSG_HEARTBEAT, MSG_HELLO, MSG_SNAPSHOT, MSG_WELCOME,
9};
10
/// Application-level message exchanged over the wire.
///
/// Each variant corresponds to one frame tag (`MSG_SNAPSHOT`, `MSG_HEARTBEAT`,
/// `MSG_ERROR`, `MSG_HELLO`, `MSG_WELCOME`). [`WireMessage::to_frame`] and
/// [`WireMessage::from_frame`] bincode-encode only the variant's fields; the
/// tag travels in the [`Frame`] header, not in the payload.
///
/// `Debug` is implemented by hand (not derived) so that the `Hello` auth token
/// is redacted in logs.
// NOTE(review): the serde derives are not used by `to_frame`/`from_frame`.
// Reordering variants would still change the serde representation for
// index-based formats — confirm whether anything relies on these impls.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub enum WireMessage {
    /// A full system snapshot (CPU, memory, load, processes, networks, ...).
    Snapshot(SystemSnapshot),

    /// Liveness message carrying the sender's version and uptime.
    Heartbeat {
        /// Version string of the server that sent the heartbeat.
        server_version: String,
        /// Uptime reported by the server; seconds, per the field name — confirm.
        uptime_secs: u64,
    },

    /// Error report: a numeric code plus a human-readable message.
    // NOTE(review): tests use `503` / "max clients reached", suggesting
    // HTTP-like codes — confirm the code space.
    Error { code: u16, message: String },

    /// Handshake sent by the client, presumably as its first message — confirm.
    Hello {
        /// Version string of the connecting client.
        client_version: String,
        /// Optional authentication token; `None` when the client sends none.
        /// Never printed by `Debug`.
        auth_token: Option<String>,
    },

    /// Handshake reply from the server, presumably answering `Hello` — confirm.
    Welcome {
        /// Version string of the server.
        server_version: String,
        /// Hostname the server reports for itself.
        hostname: String,
        /// Snapshot refresh rate; Hz, per the field name — confirm.
        refresh_hz: u32,
    },
}
42
43impl std::fmt::Debug for WireMessage {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 WireMessage::Snapshot(s) => f.debug_tuple("Snapshot").field(s).finish(),
47 WireMessage::Heartbeat {
48 server_version,
49 uptime_secs,
50 } => f
51 .debug_struct("Heartbeat")
52 .field("server_version", server_version)
53 .field("uptime_secs", uptime_secs)
54 .finish(),
55 WireMessage::Error { code, message } => f
56 .debug_struct("Error")
57 .field("code", code)
58 .field("message", message)
59 .finish(),
60 WireMessage::Hello {
61 client_version,
62 auth_token,
63 } => f
64 .debug_struct("Hello")
65 .field("client_version", client_version)
66 .field("auth_token", &auth_token.as_ref().map(|_| "[REDACTED]"))
67 .finish(),
68 WireMessage::Welcome {
69 server_version,
70 hostname,
71 refresh_hz,
72 } => f
73 .debug_struct("Welcome")
74 .field("server_version", server_version)
75 .field("hostname", hostname)
76 .field("refresh_hz", refresh_hz)
77 .finish(),
78 }
79 }
80}
81
82const MAX_DECODE_BYTES: usize = MAX_FRAME_SIZE as usize;
91
92fn bincode_config() -> impl bincode::config::Config {
93 config::standard().with_limit::<MAX_DECODE_BYTES>()
94}
95
96impl WireMessage {
97 pub fn encode_snapshot_ref(snap: &SystemSnapshot) -> Result<Frame, ProtoError> {
107 Ok(Frame {
108 msg_type: MSG_SNAPSHOT,
109 payload: encode_to_vec(snap, bincode_config())?,
110 })
111 }
112
113 pub fn to_frame(&self) -> Result<Frame, ProtoError> {
115 let (msg_type, payload) = match self {
116 WireMessage::Snapshot(snap) => (MSG_SNAPSHOT, encode_to_vec(snap, bincode_config())?),
117 WireMessage::Heartbeat {
118 server_version,
119 uptime_secs,
120 } => (
121 MSG_HEARTBEAT,
122 encode_to_vec((server_version, uptime_secs), bincode_config())?,
123 ),
124 WireMessage::Error { code, message } => {
125 (MSG_ERROR, encode_to_vec((code, message), bincode_config())?)
126 }
127 WireMessage::Hello {
128 client_version,
129 auth_token,
130 } => (
131 MSG_HELLO,
132 encode_to_vec((client_version, auth_token), bincode_config())?,
133 ),
134 WireMessage::Welcome {
135 server_version,
136 hostname,
137 refresh_hz,
138 } => (
139 MSG_WELCOME,
140 encode_to_vec((server_version, hostname, refresh_hz), bincode_config())?,
141 ),
142 };
143
144 Ok(Frame { msg_type, payload })
145 }
146
147 pub fn from_frame(frame: &Frame) -> Result<Self, ProtoError> {
149 match frame.msg_type {
150 MSG_SNAPSHOT => {
151 let (snap, _): (SystemSnapshot, _) =
152 decode_from_slice(&frame.payload, bincode_config())?;
153 Ok(WireMessage::Snapshot(snap))
154 }
155 MSG_HEARTBEAT => {
156 let ((server_version, uptime_secs), _): ((String, u64), _) =
157 decode_from_slice(&frame.payload, bincode_config())?;
158 Ok(WireMessage::Heartbeat {
159 server_version,
160 uptime_secs,
161 })
162 }
163 MSG_ERROR => {
164 let ((code, message), _): ((u16, String), _) =
165 decode_from_slice(&frame.payload, bincode_config())?;
166 Ok(WireMessage::Error { code, message })
167 }
168 MSG_HELLO => {
169 let ((client_version, auth_token), _): ((String, Option<String>), _) =
170 decode_from_slice(&frame.payload, bincode_config())?;
171 Ok(WireMessage::Hello {
172 client_version,
173 auth_token,
174 })
175 }
176 MSG_WELCOME => {
177 let ((server_version, hostname, refresh_hz), _): ((String, String, u32), _) =
178 decode_from_slice(&frame.payload, bincode_config())?;
179 Ok(WireMessage::Welcome {
180 server_version,
181 hostname,
182 refresh_hz,
183 })
184 }
185 other => Err(ProtoError::UnknownMessageType(other)),
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use muxtop_core::network::{NetworkInterfaceSnapshot, NetworkSnapshot};
    use muxtop_core::process::ProcessInfo;
    use muxtop_core::system::{CoreSnapshot, CpuSnapshot, LoadSnapshot, MemorySnapshot};

    /// Small but fully populated snapshot: one core, one process, one interface.
    fn make_test_snapshot() -> SystemSnapshot {
        SystemSnapshot {
            cpu: CpuSnapshot {
                global_usage: 45.2,
                cores: vec![CoreSnapshot {
                    name: "cpu0".into(),
                    usage: 45.2,
                    frequency: 3600,
                }],
            },
            memory: MemorySnapshot {
                total: 16_000_000_000,
                used: 8_000_000_000,
                available: 8_000_000_000,
                swap_total: 4_000_000_000,
                swap_used: 1_000_000_000,
            },
            load: LoadSnapshot {
                one: 1.5,
                five: 1.2,
                fifteen: 0.8,
                uptime_secs: 3600,
            },
            processes: vec![ProcessInfo {
                pid: 1,
                parent_pid: None,
                name: "init".into(),
                command: "/sbin/init".into(),
                user: "root".into(),
                cpu_percent: 0.1,
                memory_bytes: 4096,
                memory_percent: 0.01,
                status: "Running".into(),
            }],
            networks: NetworkSnapshot {
                interfaces: vec![NetworkInterfaceSnapshot {
                    name: "lo".into(),
                    bytes_rx: 1000,
                    bytes_tx: 1000,
                    packets_rx: 10,
                    packets_tx: 10,
                    errors_rx: 0,
                    errors_tx: 0,
                    mac_address: "00:00:00:00:00:00".into(),
                    is_up: true,
                }],
                total_rx: 1000,
                total_tx: 1000,
            },
            containers: None,
            timestamp_ms: 1_713_200_000_000,
        }
    }

    #[test]
    fn test_wire_snapshot_roundtrip() {
        let msg = WireMessage::Snapshot(make_test_snapshot());
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_SNAPSHOT);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_encode_snapshot_ref_matches_to_frame() {
        let snap = make_test_snapshot();
        // The borrowing encoder must be byte-for-byte interchangeable with the
        // owning path, and its output must decode like any other frame.
        let owning_frame = WireMessage::Snapshot(snap.clone()).to_frame().unwrap();
        let borrow_frame = WireMessage::encode_snapshot_ref(&snap).unwrap();
        assert_eq!(owning_frame.msg_type, borrow_frame.msg_type);
        assert_eq!(owning_frame.payload, borrow_frame.payload);
        let decoded = WireMessage::from_frame(&borrow_frame).unwrap();
        assert_eq!(decoded, WireMessage::Snapshot(snap));
    }

    #[test]
    fn test_wire_heartbeat_roundtrip() {
        let msg = WireMessage::Heartbeat {
            server_version: "0.2.0".into(),
            uptime_secs: 86400,
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_HEARTBEAT);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_error_roundtrip() {
        let msg = WireMessage::Error {
            code: 503,
            message: "max clients reached".into(),
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_ERROR);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_hello_roundtrip() {
        let msg = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("secret-token".into()),
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_HELLO);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_hello_no_token_roundtrip() {
        let msg = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: None,
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_HELLO);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_welcome_roundtrip() {
        let msg = WireMessage::Welcome {
            server_version: "0.2.0".into(),
            hostname: "prod-server-01".into(),
            refresh_hz: 1,
        };
        let frame = msg.to_frame().unwrap();
        assert_eq!(frame.msg_type, MSG_WELCOME);
        let decoded = WireMessage::from_frame(&frame).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_wire_unknown_message_type() {
        let frame = Frame {
            msg_type: 0xFF,
            payload: vec![1, 2, 3],
        };
        let err = WireMessage::from_frame(&frame).unwrap_err();
        assert!(matches!(err, ProtoError::UnknownMessageType(0xFF)));
    }

    #[test]
    fn test_wire_truncated_payload_is_decode_error() {
        let frame = WireMessage::Welcome {
            server_version: "0.2.0".into(),
            hostname: "prod-server-01".into(),
            refresh_hz: 1,
        }
        .to_frame()
        .unwrap();
        // Drop the last byte (the varint-encoded `refresh_hz`), so the decoder
        // runs out of input partway through the tuple.
        let truncated = Frame {
            msg_type: frame.msg_type,
            payload: frame.payload[..frame.payload.len() - 1].to_vec(),
        };
        let err = WireMessage::from_frame(&truncated).expect_err("truncated payload must fail");
        assert!(
            matches!(err, ProtoError::Decode(_)),
            "expected Decode error, got {err:?}"
        );
    }

    #[test]
    fn test_decode_limit_rejects_giant_string_claim() {
        let claimed_len: u32 = 100 * 1024 * 1024;
        // Hand-built Hello payload whose first field (`client_version`) claims
        // a 100 MiB string. In bincode's varint encoding, 0xFC means "a u32
        // length follows in little-endian". The decode limit must reject it
        // before any string bytes are read.
        let mut payload = Vec::new();
        payload.push(0xFC);
        payload.extend_from_slice(&claimed_len.to_le_bytes());
        let frame = Frame {
            msg_type: MSG_HELLO,
            payload,
        };
        let err =
            WireMessage::from_frame(&frame).expect_err("decoder must reject 100 MiB length claim");
        assert!(
            matches!(err, ProtoError::Decode(_)),
            "expected Decode error, got {err:?}"
        );
    }

    #[test]
    fn test_hello_debug_redacts_token() {
        let msg = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("secret-token".into()),
        };
        let rendered = format!("{msg:?}");
        assert!(!rendered.contains("secret-token"), "token leaked: {rendered}");
        assert!(
            rendered.contains("[REDACTED]"),
            "missing redaction marker: {rendered}"
        );

        // Without a token there is nothing to redact, so no marker appears.
        let anonymous = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: None,
        };
        assert!(!format!("{anonymous:?}").contains("[REDACTED]"));
    }

    #[test]
    fn test_hello_token_validation() {
        let expected_token = "correct-token";

        // `let ... else` (not `if let`) so a non-matching pattern fails the
        // test instead of silently skipping the assertion.
        let hello = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("wrong-token".into()),
        };
        let WireMessage::Hello { auth_token, .. } = &hello else {
            panic!("expected a Hello message, got {hello:?}");
        };
        let valid = auth_token.as_deref().is_some_and(|t| t == expected_token);
        assert!(!valid, "wrong token should not validate");

        let hello_correct = WireMessage::Hello {
            client_version: "0.2.0".into(),
            auth_token: Some("correct-token".into()),
        };
        let WireMessage::Hello { auth_token, .. } = &hello_correct else {
            panic!("expected a Hello message, got {hello_correct:?}");
        };
        let valid = auth_token.as_deref().is_some_and(|t| t == expected_token);
        assert!(valid, "correct token should validate");
    }
}