Skip to main content

zlayer_tunnel/
protocol.rs

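//! Binary message protocol for tunnel communication.
//!
//! Message format: a one-byte type tag, a four-byte big-endian length that
//! counts the payload bytes only, then the payload itself. A whole frame
//! (header plus payload) is at most `MAX_MESSAGE_SIZE` bytes.
//! ```text
//! +----------+----------+----------------------------------+
//! | Type(1)  | Len(4)   | Payload (variable)               |
//! +----------+----------+----------------------------------+
//! ```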
9
10use crate::error::{Result, TunnelError};
11use serde::{Deserialize, Serialize};
12use uuid::Uuid;
13
/// Version number of the tunnel wire protocol.
pub const PROTOCOL_VERSION: u8 = 1;

/// Largest frame accepted on the wire, header included (64 KiB).
pub const MAX_MESSAGE_SIZE: usize = 64 * 1024;

/// Size of the fixed frame header: type tag (1 byte) + payload length (4 bytes).
pub const HEADER_SIZE: usize = 1 + 4;
22
/// One-byte tag carried in the first byte of every frame.
///
/// The discriminant of each variant is the exact value written on the wire,
/// so `variant as u8` yields the tag byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum MessageType {
    // 0x0_: authentication
    /// Authentication request sent by the client
    Auth = 0x01,
    /// Authentication accepted by the server
    AuthOk = 0x02,
    /// Authentication rejected by the server
    AuthFail = 0x03,

    // 0x1_: service registration
    /// Service registration request sent by the client
    Register = 0x10,
    /// Registration accepted by the server
    RegisterOk = 0x11,
    /// Registration rejected by the server
    RegisterFail = 0x12,

    // 0x2_: connection signalling
    /// Server announces an incoming connection
    Connect = 0x20,
    /// Client acknowledges the connection
    ConnectAck = 0x21,
    /// Client reports that the connection failed
    ConnectFail = 0x22,

    // 0x3_: liveness
    /// Heartbeat (either direction)
    Heartbeat = 0x30,
    /// Reply to a heartbeat (either direction)
    HeartbeatAck = 0x31,

    // 0x4_: teardown
    /// Client removes a previously registered service
    Unregister = 0x40,
    /// Server notifies the client of a disconnect
    Disconnect = 0x41,
}
54
55impl TryFrom<u8> for MessageType {
56    type Error = TunnelError;
57
58    fn try_from(value: u8) -> Result<Self> {
59        match value {
60            0x01 => Ok(Self::Auth),
61            0x02 => Ok(Self::AuthOk),
62            0x03 => Ok(Self::AuthFail),
63            0x10 => Ok(Self::Register),
64            0x11 => Ok(Self::RegisterOk),
65            0x12 => Ok(Self::RegisterFail),
66            0x20 => Ok(Self::Connect),
67            0x21 => Ok(Self::ConnectAck),
68            0x22 => Ok(Self::ConnectFail),
69            0x30 => Ok(Self::Heartbeat),
70            0x31 => Ok(Self::HeartbeatAck),
71            0x40 => Ok(Self::Unregister),
72            0x41 => Ok(Self::Disconnect),
73            _ => Err(TunnelError::protocol(format!(
74                "unknown message type: 0x{value:02x}"
75            ))),
76        }
77    }
78}
79
80/// Service protocol type
81#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
82#[serde(rename_all = "lowercase")]
83pub enum ServiceProtocol {
84    /// TCP protocol
85    #[default]
86    Tcp,
87    /// UDP protocol
88    Udp,
89}
90
91impl ServiceProtocol {
92    /// Convert to wire format byte
93    #[must_use]
94    pub const fn to_byte(self) -> u8 {
95        match self {
96            Self::Tcp => 0,
97            Self::Udp => 1,
98        }
99    }
100
101    /// Parse from wire format byte
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the byte is not a valid protocol type (0 or 1).
106    pub fn from_byte(byte: u8) -> Result<Self> {
107        match byte {
108            0 => Ok(Self::Tcp),
109            1 => Ok(Self::Udp),
110            _ => Err(TunnelError::protocol(format!(
111                "unknown protocol type: {byte}"
112            ))),
113        }
114    }
115}
116
117/// Protocol messages
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub enum Message {
120    /// Client authentication request (C->S)
121    Auth {
122        /// Authentication token
123        token: String,
124        /// Client identifier
125        client_id: Uuid,
126    },
127
128    /// Server authentication success (S->C)
129    AuthOk {
130        /// Assigned tunnel ID
131        tunnel_id: Uuid,
132    },
133
134    /// Server authentication failure (S->C)
135    AuthFail {
136        /// Failure reason
137        reason: String,
138    },
139
140    /// Client service registration (C->S)
141    Register {
142        /// Service name
143        name: String,
144        /// Protocol type
145        protocol: ServiceProtocol,
146        /// Local port on client
147        local_port: u16,
148        /// Requested remote port (0 = auto-assign)
149        remote_port: u16,
150    },
151
152    /// Server registration success (S->C)
153    RegisterOk {
154        /// Assigned service ID
155        service_id: Uuid,
156    },
157
158    /// Server registration failure (S->C)
159    RegisterFail {
160        /// Failure reason
161        reason: String,
162    },
163
164    /// Server connection announcement (S->C)
165    Connect {
166        /// Service ID for this connection
167        service_id: Uuid,
168        /// Unique connection ID
169        connection_id: Uuid,
170        /// Client address (IP:port)
171        client_addr: String,
172    },
173
174    /// Client connection acknowledgment (C->S)
175    ConnectAck {
176        /// Connection ID being acknowledged
177        connection_id: Uuid,
178    },
179
180    /// Client connection failure (C->S)
181    ConnectFail {
182        /// Connection ID that failed
183        connection_id: Uuid,
184        /// Failure reason
185        reason: String,
186    },
187
188    /// Heartbeat (bidirectional)
189    Heartbeat {
190        /// Unix timestamp in milliseconds
191        timestamp: u64,
192    },
193
194    /// Heartbeat acknowledgment (bidirectional)
195    HeartbeatAck {
196        /// Echo of original timestamp
197        timestamp: u64,
198    },
199
200    /// Client service unregistration (C->S)
201    Unregister {
202        /// Service ID to unregister
203        service_id: Uuid,
204    },
205
206    /// Server disconnect notification (S->C)
207    Disconnect {
208        /// Disconnect reason
209        reason: String,
210    },
211}
212
213impl Message {
214    /// Get the message type
215    #[must_use]
216    pub const fn message_type(&self) -> MessageType {
217        match self {
218            Self::Auth { .. } => MessageType::Auth,
219            Self::AuthOk { .. } => MessageType::AuthOk,
220            Self::AuthFail { .. } => MessageType::AuthFail,
221            Self::Register { .. } => MessageType::Register,
222            Self::RegisterOk { .. } => MessageType::RegisterOk,
223            Self::RegisterFail { .. } => MessageType::RegisterFail,
224            Self::Connect { .. } => MessageType::Connect,
225            Self::ConnectAck { .. } => MessageType::ConnectAck,
226            Self::ConnectFail { .. } => MessageType::ConnectFail,
227            Self::Heartbeat { .. } => MessageType::Heartbeat,
228            Self::HeartbeatAck { .. } => MessageType::HeartbeatAck,
229            Self::Unregister { .. } => MessageType::Unregister,
230            Self::Disconnect { .. } => MessageType::Disconnect,
231        }
232    }
233
234    /// Encode the message to binary format
235    #[must_use]
236    #[allow(clippy::cast_possible_truncation, clippy::match_same_arms)]
237    pub fn encode(&self) -> Vec<u8> {
238        let mut payload = Vec::new();
239
240        match self {
241            Self::Auth { token, client_id } => {
242                // token_len(2) + token + client_id(16)
243                let token_bytes = token.as_bytes();
244                payload.extend_from_slice(&(token_bytes.len() as u16).to_be_bytes());
245                payload.extend_from_slice(token_bytes);
246                payload.extend_from_slice(client_id.as_bytes());
247            }
248
249            Self::AuthOk { tunnel_id } => {
250                // tunnel_id(16)
251                payload.extend_from_slice(tunnel_id.as_bytes());
252            }
253
254            Self::AuthFail { reason } => {
255                // reason_len(2) + reason
256                let reason_bytes = reason.as_bytes();
257                payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
258                payload.extend_from_slice(reason_bytes);
259            }
260
261            Self::Register {
262                name,
263                protocol,
264                local_port,
265                remote_port,
266            } => {
267                // name_len(1) + name + protocol(1) + local_port(2) + remote_port(2)
268                let name_bytes = name.as_bytes();
269                payload.push(name_bytes.len() as u8);
270                payload.extend_from_slice(name_bytes);
271                payload.push(protocol.to_byte());
272                payload.extend_from_slice(&local_port.to_be_bytes());
273                payload.extend_from_slice(&remote_port.to_be_bytes());
274            }
275
276            Self::RegisterOk { service_id } => {
277                // service_id(16)
278                payload.extend_from_slice(service_id.as_bytes());
279            }
280
281            Self::RegisterFail { reason } => {
282                // reason_len(2) + reason
283                let reason_bytes = reason.as_bytes();
284                payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
285                payload.extend_from_slice(reason_bytes);
286            }
287
288            Self::Connect {
289                service_id,
290                connection_id,
291                client_addr,
292            } => {
293                // service_id(16) + connection_id(16) + addr_len(2) + client_addr
294                payload.extend_from_slice(service_id.as_bytes());
295                payload.extend_from_slice(connection_id.as_bytes());
296                let addr_bytes = client_addr.as_bytes();
297                payload.extend_from_slice(&(addr_bytes.len() as u16).to_be_bytes());
298                payload.extend_from_slice(addr_bytes);
299            }
300
301            Self::ConnectAck { connection_id } => {
302                // connection_id(16)
303                payload.extend_from_slice(connection_id.as_bytes());
304            }
305
306            Self::ConnectFail {
307                connection_id,
308                reason,
309            } => {
310                // connection_id(16) + reason_len(2) + reason
311                payload.extend_from_slice(connection_id.as_bytes());
312                let reason_bytes = reason.as_bytes();
313                payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
314                payload.extend_from_slice(reason_bytes);
315            }
316
317            Self::Heartbeat { timestamp } | Self::HeartbeatAck { timestamp } => {
318                // timestamp(8)
319                payload.extend_from_slice(&timestamp.to_be_bytes());
320            }
321
322            Self::Unregister { service_id } => {
323                // service_id(16)
324                payload.extend_from_slice(service_id.as_bytes());
325            }
326
327            Self::Disconnect { reason } => {
328                // reason_len(2) + reason
329                let reason_bytes = reason.as_bytes();
330                payload.extend_from_slice(&(reason_bytes.len() as u16).to_be_bytes());
331                payload.extend_from_slice(reason_bytes);
332            }
333        }
334
335        // Build final message: type(1) + len(4) + payload
336        let msg_type = self.message_type() as u8;
337        let payload_len = payload.len() as u32;
338
339        let mut result = Vec::with_capacity(HEADER_SIZE + payload.len());
340        result.push(msg_type);
341        result.extend_from_slice(&payload_len.to_be_bytes());
342        result.extend_from_slice(&payload);
343
344        result
345    }
346
347    /// Decode a message from binary format
348    ///
349    /// Returns the decoded message and the number of bytes consumed.
350    ///
351    /// # Errors
352    ///
353    /// Returns an error if:
354    /// - The buffer is too short for a complete message
355    /// - The message type is unknown
356    /// - The payload is malformed or contains invalid data
357    pub fn decode(bytes: &[u8]) -> Result<(Self, usize)> {
358        if bytes.len() < HEADER_SIZE {
359            return Err(TunnelError::protocol(format!(
360                "message too short: {} bytes, need at least {}",
361                bytes.len(),
362                HEADER_SIZE
363            )));
364        }
365
366        let msg_type = MessageType::try_from(bytes[0])?;
367        let payload_len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
368
369        if payload_len > MAX_MESSAGE_SIZE - HEADER_SIZE {
370            return Err(TunnelError::protocol(format!(
371                "payload too large: {payload_len} bytes, max {}",
372                MAX_MESSAGE_SIZE - HEADER_SIZE
373            )));
374        }
375
376        let total_len = HEADER_SIZE + payload_len;
377        if bytes.len() < total_len {
378            return Err(TunnelError::protocol(format!(
379                "incomplete message: have {} bytes, need {}",
380                bytes.len(),
381                total_len
382            )));
383        }
384
385        let payload = &bytes[HEADER_SIZE..total_len];
386        let message = Self::decode_payload(msg_type, payload)?;
387
388        Ok((message, total_len))
389    }
390
391    /// Decode the payload for a given message type
392    #[allow(clippy::too_many_lines)]
393    fn decode_payload(msg_type: MessageType, payload: &[u8]) -> Result<Self> {
394        match msg_type {
395            MessageType::Auth => Self::decode_auth(payload),
396            MessageType::AuthOk => Self::decode_auth_ok(payload),
397            MessageType::AuthFail => Self::decode_auth_fail(payload),
398            MessageType::Register => Self::decode_register(payload),
399            MessageType::RegisterOk => Self::decode_register_ok(payload),
400            MessageType::RegisterFail => Self::decode_register_fail(payload),
401            MessageType::Connect => Self::decode_connect(payload),
402            MessageType::ConnectAck => Self::decode_connect_ack(payload),
403            MessageType::ConnectFail => Self::decode_connect_fail(payload),
404            MessageType::Heartbeat => Self::decode_heartbeat(payload),
405            MessageType::HeartbeatAck => Self::decode_heartbeat_ack(payload),
406            MessageType::Unregister => Self::decode_unregister(payload),
407            MessageType::Disconnect => Self::decode_disconnect(payload),
408        }
409    }
410
411    fn decode_auth(payload: &[u8]) -> Result<Self> {
412        // token_len(2) + token + client_id(16)
413        if payload.len() < 2 {
414            return Err(TunnelError::protocol(
415                "Auth: payload too short for token length",
416            ));
417        }
418        let token_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
419        if payload.len() < 2 + token_len + 16 {
420            return Err(TunnelError::protocol("Auth: payload too short"));
421        }
422        let token = String::from_utf8(payload[2..2 + token_len].to_vec())
423            .map_err(|e| TunnelError::protocol(format!("Auth: invalid token UTF-8: {e}")))?;
424        let client_id = Uuid::from_slice(&payload[2 + token_len..2 + token_len + 16])
425            .map_err(|e| TunnelError::protocol(format!("Auth: invalid client_id: {e}")))?;
426        Ok(Self::Auth { token, client_id })
427    }
428
429    fn decode_auth_ok(payload: &[u8]) -> Result<Self> {
430        // tunnel_id(16)
431        if payload.len() < 16 {
432            return Err(TunnelError::protocol("AuthOk: payload too short"));
433        }
434        let tunnel_id = Uuid::from_slice(&payload[..16])
435            .map_err(|e| TunnelError::protocol(format!("AuthOk: invalid tunnel_id: {e}")))?;
436        Ok(Self::AuthOk { tunnel_id })
437    }
438
439    fn decode_auth_fail(payload: &[u8]) -> Result<Self> {
440        // reason_len(2) + reason
441        if payload.len() < 2 {
442            return Err(TunnelError::protocol(
443                "AuthFail: payload too short for reason length",
444            ));
445        }
446        let reason_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
447        if payload.len() < 2 + reason_len {
448            return Err(TunnelError::protocol(
449                "AuthFail: payload too short for reason",
450            ));
451        }
452        let reason = String::from_utf8(payload[2..2 + reason_len].to_vec())
453            .map_err(|e| TunnelError::protocol(format!("AuthFail: invalid reason UTF-8: {e}")))?;
454        Ok(Self::AuthFail { reason })
455    }
456
457    fn decode_register(payload: &[u8]) -> Result<Self> {
458        // name_len(1) + name + protocol(1) + local_port(2) + remote_port(2)
459        if payload.is_empty() {
460            return Err(TunnelError::protocol(
461                "Register: payload too short for name length",
462            ));
463        }
464        let name_len = payload[0] as usize;
465        if payload.len() < 1 + name_len + 1 + 2 + 2 {
466            return Err(TunnelError::protocol("Register: payload too short"));
467        }
468        let name = String::from_utf8(payload[1..=name_len].to_vec())
469            .map_err(|e| TunnelError::protocol(format!("Register: invalid name UTF-8: {e}")))?;
470        let protocol = ServiceProtocol::from_byte(payload[1 + name_len])?;
471        let local_port = u16::from_be_bytes([payload[2 + name_len], payload[3 + name_len]]);
472        let remote_port = u16::from_be_bytes([payload[4 + name_len], payload[5 + name_len]]);
473        Ok(Self::Register {
474            name,
475            protocol,
476            local_port,
477            remote_port,
478        })
479    }
480
481    fn decode_register_ok(payload: &[u8]) -> Result<Self> {
482        // service_id(16)
483        if payload.len() < 16 {
484            return Err(TunnelError::protocol("RegisterOk: payload too short"));
485        }
486        let service_id = Uuid::from_slice(&payload[..16])
487            .map_err(|e| TunnelError::protocol(format!("RegisterOk: invalid service_id: {e}")))?;
488        Ok(Self::RegisterOk { service_id })
489    }
490
491    fn decode_register_fail(payload: &[u8]) -> Result<Self> {
492        // reason_len(2) + reason
493        if payload.len() < 2 {
494            return Err(TunnelError::protocol(
495                "RegisterFail: payload too short for reason length",
496            ));
497        }
498        let reason_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
499        if payload.len() < 2 + reason_len {
500            return Err(TunnelError::protocol(
501                "RegisterFail: payload too short for reason",
502            ));
503        }
504        let reason = String::from_utf8(payload[2..2 + reason_len].to_vec()).map_err(|e| {
505            TunnelError::protocol(format!("RegisterFail: invalid reason UTF-8: {e}"))
506        })?;
507        Ok(Self::RegisterFail { reason })
508    }
509
510    fn decode_connect(payload: &[u8]) -> Result<Self> {
511        // service_id(16) + connection_id(16) + addr_len(2) + client_addr
512        if payload.len() < 16 + 16 + 2 {
513            return Err(TunnelError::protocol("Connect: payload too short"));
514        }
515        let service_id = Uuid::from_slice(&payload[..16])
516            .map_err(|e| TunnelError::protocol(format!("Connect: invalid service_id: {e}")))?;
517        let connection_id = Uuid::from_slice(&payload[16..32])
518            .map_err(|e| TunnelError::protocol(format!("Connect: invalid connection_id: {e}")))?;
519        let addr_len = u16::from_be_bytes([payload[32], payload[33]]) as usize;
520        if payload.len() < 34 + addr_len {
521            return Err(TunnelError::protocol(
522                "Connect: payload too short for client_addr",
523            ));
524        }
525        let client_addr = String::from_utf8(payload[34..34 + addr_len].to_vec()).map_err(|e| {
526            TunnelError::protocol(format!("Connect: invalid client_addr UTF-8: {e}"))
527        })?;
528        Ok(Self::Connect {
529            service_id,
530            connection_id,
531            client_addr,
532        })
533    }
534
535    fn decode_connect_ack(payload: &[u8]) -> Result<Self> {
536        // connection_id(16)
537        if payload.len() < 16 {
538            return Err(TunnelError::protocol("ConnectAck: payload too short"));
539        }
540        let connection_id = Uuid::from_slice(&payload[..16]).map_err(|e| {
541            TunnelError::protocol(format!("ConnectAck: invalid connection_id: {e}"))
542        })?;
543        Ok(Self::ConnectAck { connection_id })
544    }
545
546    fn decode_connect_fail(payload: &[u8]) -> Result<Self> {
547        // connection_id(16) + reason_len(2) + reason
548        if payload.len() < 16 + 2 {
549            return Err(TunnelError::protocol("ConnectFail: payload too short"));
550        }
551        let connection_id = Uuid::from_slice(&payload[..16]).map_err(|e| {
552            TunnelError::protocol(format!("ConnectFail: invalid connection_id: {e}"))
553        })?;
554        let reason_len = u16::from_be_bytes([payload[16], payload[17]]) as usize;
555        if payload.len() < 18 + reason_len {
556            return Err(TunnelError::protocol(
557                "ConnectFail: payload too short for reason",
558            ));
559        }
560        let reason = String::from_utf8(payload[18..18 + reason_len].to_vec()).map_err(|e| {
561            TunnelError::protocol(format!("ConnectFail: invalid reason UTF-8: {e}"))
562        })?;
563        Ok(Self::ConnectFail {
564            connection_id,
565            reason,
566        })
567    }
568
569    fn decode_heartbeat(payload: &[u8]) -> Result<Self> {
570        // timestamp(8)
571        if payload.len() < 8 {
572            return Err(TunnelError::protocol("Heartbeat: payload too short"));
573        }
574        let timestamp = u64::from_be_bytes([
575            payload[0], payload[1], payload[2], payload[3], payload[4], payload[5], payload[6],
576            payload[7],
577        ]);
578        Ok(Self::Heartbeat { timestamp })
579    }
580
581    fn decode_heartbeat_ack(payload: &[u8]) -> Result<Self> {
582        // timestamp(8)
583        if payload.len() < 8 {
584            return Err(TunnelError::protocol("HeartbeatAck: payload too short"));
585        }
586        let timestamp = u64::from_be_bytes([
587            payload[0], payload[1], payload[2], payload[3], payload[4], payload[5], payload[6],
588            payload[7],
589        ]);
590        Ok(Self::HeartbeatAck { timestamp })
591    }
592
593    fn decode_unregister(payload: &[u8]) -> Result<Self> {
594        // service_id(16)
595        if payload.len() < 16 {
596            return Err(TunnelError::protocol("Unregister: payload too short"));
597        }
598        let service_id = Uuid::from_slice(&payload[..16])
599            .map_err(|e| TunnelError::protocol(format!("Unregister: invalid service_id: {e}")))?;
600        Ok(Self::Unregister { service_id })
601    }
602
603    fn decode_disconnect(payload: &[u8]) -> Result<Self> {
604        // reason_len(2) + reason
605        if payload.len() < 2 {
606            return Err(TunnelError::protocol(
607                "Disconnect: payload too short for reason length",
608            ));
609        }
610        let reason_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
611        if payload.len() < 2 + reason_len {
612            return Err(TunnelError::protocol(
613                "Disconnect: payload too short for reason",
614            ));
615        }
616        let reason = String::from_utf8(payload[2..2 + reason_len].to_vec())
617            .map_err(|e| TunnelError::protocol(format!("Disconnect: invalid reason UTF-8: {e}")))?;
618        Ok(Self::Disconnect { reason })
619    }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `msg`, decode the bytes again, and require that the whole
    /// buffer was consumed and the original value came back.
    fn roundtrip(msg: &Message) {
        let wire = msg.encode();
        let (parsed, used) = Message::decode(&wire).expect("decode failed");
        assert_eq!(used, wire.len(), "consumed bytes mismatch");
        assert_eq!(&parsed, msg, "roundtrip mismatch");
    }

    #[test]
    fn test_auth_roundtrip() {
        // Typical token, empty token, long token.
        let cases = [
            ("tun_abc123".to_string(), Uuid::new_v4()),
            (String::new(), Uuid::nil()),
            ("a".repeat(1000), Uuid::new_v4()),
        ];
        for (token, client_id) in cases {
            roundtrip(&Message::Auth { token, client_id });
        }
    }

    #[test]
    fn test_auth_ok_roundtrip() {
        for tunnel_id in [Uuid::new_v4(), Uuid::nil()] {
            roundtrip(&Message::AuthOk { tunnel_id });
        }
    }

    #[test]
    fn test_auth_fail_roundtrip() {
        let reasons = ["invalid token".to_string(), String::new(), "x".repeat(500)];
        for reason in reasons {
            roundtrip(&Message::AuthFail { reason });
        }
    }

    #[test]
    fn test_register_roundtrip() {
        let cases = [
            ("ssh".to_string(), ServiceProtocol::Tcp, 22_u16, 2222_u16),
            // remote_port 0 asks the server to auto-assign
            ("game".to_string(), ServiceProtocol::Udp, 27015, 0),
            // 255 bytes is the longest name a u8 length prefix can carry
            ("a".repeat(255), ServiceProtocol::Tcp, 65535, 65535),
        ];
        for (name, protocol, local_port, remote_port) in cases {
            roundtrip(&Message::Register {
                name,
                protocol,
                local_port,
                remote_port,
            });
        }
    }

    #[test]
    fn test_register_ok_roundtrip() {
        let service_id = Uuid::new_v4();
        roundtrip(&Message::RegisterOk { service_id });
    }

    #[test]
    fn test_register_fail_roundtrip() {
        let reason = "port already in use".to_string();
        roundtrip(&Message::RegisterFail { reason });
    }

    #[test]
    fn test_connect_roundtrip() {
        // One IPv4 and one IPv6 peer address.
        for addr in ["192.168.1.100:54321", "[::1]:8080"] {
            roundtrip(&Message::Connect {
                service_id: Uuid::new_v4(),
                connection_id: Uuid::new_v4(),
                client_addr: addr.to_string(),
            });
        }
    }

    #[test]
    fn test_connect_ack_roundtrip() {
        let connection_id = Uuid::new_v4();
        roundtrip(&Message::ConnectAck { connection_id });
    }

    #[test]
    fn test_connect_fail_roundtrip() {
        let msg = Message::ConnectFail {
            connection_id: Uuid::new_v4(),
            reason: "connection refused".to_string(),
        };
        roundtrip(&msg);
    }

    #[test]
    fn test_heartbeat_roundtrip() {
        for timestamp in [1_705_320_000_000, 0, u64::MAX] {
            roundtrip(&Message::Heartbeat { timestamp });
        }
    }

    #[test]
    fn test_heartbeat_ack_roundtrip() {
        let timestamp = 1_705_320_000_000;
        roundtrip(&Message::HeartbeatAck { timestamp });
    }

    #[test]
    fn test_unregister_roundtrip() {
        let service_id = Uuid::new_v4();
        roundtrip(&Message::Unregister { service_id });
    }

    #[test]
    fn test_disconnect_roundtrip() {
        let reason = "server shutdown".to_string();
        roundtrip(&Message::Disconnect { reason });
    }

    #[test]
    fn test_message_type_discriminants() {
        let nil = Uuid::nil();
        let table = [
            (
                Message::Auth {
                    token: String::new(),
                    client_id: nil,
                },
                MessageType::Auth,
            ),
            (Message::AuthOk { tunnel_id: nil }, MessageType::AuthOk),
            (
                Message::AuthFail {
                    reason: String::new(),
                },
                MessageType::AuthFail,
            ),
            (
                Message::Register {
                    name: String::new(),
                    protocol: ServiceProtocol::Tcp,
                    local_port: 0,
                    remote_port: 0,
                },
                MessageType::Register,
            ),
            (
                Message::RegisterOk { service_id: nil },
                MessageType::RegisterOk,
            ),
            (
                Message::RegisterFail {
                    reason: String::new(),
                },
                MessageType::RegisterFail,
            ),
            (
                Message::Connect {
                    service_id: nil,
                    connection_id: nil,
                    client_addr: String::new(),
                },
                MessageType::Connect,
            ),
            (
                Message::ConnectAck { connection_id: nil },
                MessageType::ConnectAck,
            ),
            (
                Message::ConnectFail {
                    connection_id: nil,
                    reason: String::new(),
                },
                MessageType::ConnectFail,
            ),
            (Message::Heartbeat { timestamp: 0 }, MessageType::Heartbeat),
            (
                Message::HeartbeatAck { timestamp: 0 },
                MessageType::HeartbeatAck,
            ),
            (
                Message::Unregister { service_id: nil },
                MessageType::Unregister,
            ),
            (
                Message::Disconnect {
                    reason: String::new(),
                },
                MessageType::Disconnect,
            ),
        ];

        for (msg, expected) in table {
            assert_eq!(msg.message_type(), expected);
        }
    }

    #[test]
    fn test_message_type_from_u8() {
        let known: [(u8, MessageType); 13] = [
            (0x01, MessageType::Auth),
            (0x02, MessageType::AuthOk),
            (0x03, MessageType::AuthFail),
            (0x10, MessageType::Register),
            (0x11, MessageType::RegisterOk),
            (0x12, MessageType::RegisterFail),
            (0x20, MessageType::Connect),
            (0x21, MessageType::ConnectAck),
            (0x22, MessageType::ConnectFail),
            (0x30, MessageType::Heartbeat),
            (0x31, MessageType::HeartbeatAck),
            (0x40, MessageType::Unregister),
            (0x41, MessageType::Disconnect),
        ];
        for (byte, expected) in known {
            assert_eq!(MessageType::try_from(byte).unwrap(), expected);
        }

        // Bytes that are not assigned to any message type
        for byte in [0xFF_u8, 0x00] {
            assert!(MessageType::try_from(byte).is_err());
        }
    }

    #[test]
    fn test_service_protocol_roundtrip() {
        for proto in [ServiceProtocol::Tcp, ServiceProtocol::Udp] {
            assert_eq!(ServiceProtocol::from_byte(proto.to_byte()).unwrap(), proto);
        }
        assert!(ServiceProtocol::from_byte(0xFF).is_err());
    }

    #[test]
    fn test_decode_too_short() {
        // Every buffer here is shorter than the 5-byte header.
        let truncated: [&[u8]; 3] = [&[], &[0x01], &[0x01, 0x00, 0x00, 0x00]];
        for buf in truncated {
            assert!(Message::decode(buf).is_err());
        }
    }

    #[test]
    fn test_decode_incomplete_payload() {
        // Header announces a 32-byte payload, but only a single byte follows.
        let frame = [0x01, 0x00, 0x00, 0x00, 0x20, 0x00];
        let outcome = Message::decode(&frame);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_decode_invalid_message_type() {
        // Unknown type tag with an empty payload.
        let frame = [0xFF, 0x00, 0x00, 0x00, 0x00];
        let outcome = Message::decode(&frame);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_decode_payload_too_large() {
        // Declared payload length (~4 GiB) is far beyond MAX_MESSAGE_SIZE.
        let frame = [0x01, 0xFF, 0xFF, 0xFF, 0xFF];
        let outcome = Message::decode(&frame);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_header_size_constant() {
        // A heartbeat carries exactly 8 payload bytes, so anything beyond
        // that in the encoded frame must be the header.
        let wire = Message::Heartbeat { timestamp: 0 }.encode();
        assert_eq!(wire.len(), HEADER_SIZE + 8);
    }

    #[test]
    fn test_multiple_messages_in_buffer() {
        let first = Message::Heartbeat { timestamp: 100 };
        let second = Message::HeartbeatAck { timestamp: 100 };

        // Two frames laid out back to back in one buffer.
        let mut stream = first.encode();
        stream.extend_from_slice(&second.encode());

        let (got_first, used_first) = Message::decode(&stream).unwrap();
        assert_eq!(got_first, first);

        // The consumed count tells us where the next frame starts.
        let (got_second, used_second) = Message::decode(&stream[used_first..]).unwrap();
        assert_eq!(got_second, second);

        assert_eq!(used_first + used_second, stream.len());
    }
}