// codive_tunnel/protocol.rs

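//! Tunnel protocol message definitions.
//!
//! This module defines the wire protocol for communication between:
//! - Local agent and relay server (control + encrypted data)
//! - Remote client and relay server (encrypted data only)
//!
//! Control messages travel as plaintext JSON; HTTP traffic travels as encrypted
//! binary frames prefixed with a one-byte `message_type`.
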
use std::collections::HashMap;
use std::fmt;

use serde::{Deserialize, Serialize};

/// Single-byte discriminators that prefix every binary frame on the wire.
///
/// Value ranges:
/// - `0x01..=0x0F`: encrypted data frames
/// - `0x10..=0xFE`: control frames
/// - `0xFF`: relay error
pub mod message_type {
    // --- Encrypted data frames (relayed without decryption) ---------------

    /// Client -> agent (via relay): encrypted HTTP request.
    pub const ENCRYPTED_REQUEST: u8 = 0x01;
    /// Agent -> client (via relay): encrypted HTTP response.
    pub const ENCRYPTED_RESPONSE: u8 = 0x02;
    /// Agent -> client (via relay): encrypted SSE event.
    pub const ENCRYPTED_EVENT: u8 = 0x03;

    // --- Control frames ---------------------------------------------------

    /// Keep-alive ping.
    pub const PING: u8 = 0x10;
    /// Keep-alive reply to [`PING`].
    pub const PONG: u8 = 0x11;
    /// Connection close.
    pub const CLOSE: u8 = 0x12;

    // --- Relay-originated -------------------------------------------------

    /// Error reported by the relay itself (sent unencrypted).
    pub const RELAY_ERROR: u8 = 0xFF;
}

28/// Control messages for tunnel management (sent in plaintext over WSS)
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(tag = "type", rename_all = "snake_case")]
31pub enum ControlMessage {
32    /// Agent -> Relay: Initial handshake to establish tunnel
33    Hello {
34        /// Protocol version
35        version: u8,
36        /// Requested tunnel ID (optional, relay may assign random)
37        #[serde(skip_serializing_if = "Option::is_none")]
38        requested_id: Option<String>,
39        /// Authentication token (optional, required if relay enforces auth)
40        #[serde(skip_serializing_if = "Option::is_none")]
41        auth_token: Option<String>,
42    },
43
44    /// Relay -> Agent: Handshake response with assigned tunnel info
45    Welcome {
46        /// Assigned tunnel ID
47        tunnel_id: String,
48        /// Base URL for the tunnel (without fragment)
49        /// e.g., "https://abc12345.relay.example.com"
50        tunnel_url: String,
51    },
52
53    /// Relay -> Agent: A client has connected to the tunnel
54    ClientConnected {
55        /// Unique identifier for this client connection
56        client_id: String,
57    },
58
59    /// Relay -> Agent: A client has disconnected
60    ClientDisconnected {
61        /// Client identifier
62        client_id: String,
63    },
64
65    /// Bidirectional: Keep-alive ping
66    Ping {
67        /// Timestamp for latency measurement (milliseconds)
68        timestamp: u64,
69    },
70
71    /// Bidirectional: Keep-alive pong
72    Pong {
73        /// Echo of the ping timestamp
74        timestamp: u64,
75    },
76
77    /// Either side: Graceful close
78    Close {
79        /// Human-readable reason for closing
80        reason: String,
81    },
82
83    /// Relay -> Agent/Client: Error occurred
84    Error {
85        /// Error code
86        code: String,
87        /// Human-readable error message
88        message: String,
89    },
90}
91
92/// Data messages for HTTP traffic (encrypted with XChaCha20-Poly1305)
93#[derive(Debug, Clone, Serialize, Deserialize)]
94#[serde(tag = "type", rename_all = "snake_case")]
95pub enum DataMessage {
96    /// HTTP request from client to agent
97    HttpRequest {
98        /// Unique ID for correlating response
99        request_id: String,
100        /// Client connection ID
101        client_id: String,
102        /// HTTP method (GET, POST, etc.)
103        method: String,
104        /// Request path (e.g., "/api/sessions")
105        path: String,
106        /// Query string (without leading ?)
107        #[serde(skip_serializing_if = "Option::is_none")]
108        query: Option<String>,
109        /// HTTP headers
110        headers: HashMap<String, String>,
111        /// Request body (base64 encoded for binary safety)
112        #[serde(skip_serializing_if = "Option::is_none")]
113        body: Option<String>,
114    },
115
116    /// HTTP response from agent to client
117    HttpResponse {
118        /// Correlates to request_id
119        request_id: String,
120        /// HTTP status code
121        status: u16,
122        /// Response headers
123        headers: HashMap<String, String>,
124        /// Response body (base64 encoded)
125        #[serde(skip_serializing_if = "Option::is_none")]
126        body: Option<String>,
127        /// Is this a streaming response? (SSE)
128        #[serde(default)]
129        streaming: bool,
130    },
131
132    /// Streaming response chunk (for SSE support)
133    HttpResponseChunk {
134        /// Correlates to request_id
135        request_id: String,
136        /// Chunk data (base64 encoded)
137        chunk: String,
138        /// Is this the final chunk?
139        #[serde(default)]
140        is_final: bool,
141    },
142
143    /// Error processing request
144    RequestError {
145        /// Correlates to request_id (if known)
146        #[serde(skip_serializing_if = "Option::is_none")]
147        request_id: Option<String>,
148        /// Error code
149        code: String,
150        /// Human-readable error message
151        message: String,
152    },
153}
154
155/// Wrapper for wire messages with type prefix
156#[derive(Debug, Clone)]
157pub enum WireMessage {
158    /// Control message (JSON)
159    Control(ControlMessage),
160    /// Encrypted data (binary)
161    EncryptedData {
162        message_type: u8,
163        payload: Vec<u8>,
164    },
165}
166
167impl WireMessage {
168    /// Encode a control message to bytes
169    pub fn encode_control(msg: &ControlMessage) -> Vec<u8> {
170        serde_json::to_vec(msg).expect("Control message serialization should not fail")
171    }
172
173    /// Decode a control message from bytes
174    pub fn decode_control(data: &[u8]) -> Result<ControlMessage, serde_json::Error> {
175        serde_json::from_slice(data)
176    }
177
178    /// Encode an encrypted data message to bytes (with type prefix)
179    /// Format: [message_type: 1 byte][payload]
180    pub fn encode_encrypted(message_type: u8, encrypted_payload: Vec<u8>) -> Vec<u8> {
181        let mut result = Vec::with_capacity(1 + encrypted_payload.len());
182        result.push(message_type);
183        result.extend(encrypted_payload);
184        result
185    }
186
187    /// Encode an encrypted data message with routing header
188    /// Format: [message_type: 1 byte][request_id_len: 1 byte][request_id][encrypted_payload]
189    /// This allows the relay to route responses without decrypting the payload
190    pub fn encode_encrypted_with_routing(
191        message_type: u8,
192        request_id: &str,
193        encrypted_payload: Vec<u8>,
194    ) -> Vec<u8> {
195        let request_id_bytes = request_id.as_bytes();
196        let id_len = request_id_bytes.len().min(255) as u8;
197        let mut result = Vec::with_capacity(2 + id_len as usize + encrypted_payload.len());
198        result.push(message_type);
199        result.push(id_len);
200        result.extend_from_slice(&request_id_bytes[..id_len as usize]);
201        result.extend(encrypted_payload);
202        result
203    }
204
205    /// Decode an encrypted data message from bytes
206    pub fn decode_encrypted(data: &[u8]) -> Result<(u8, &[u8]), &'static str> {
207        if data.is_empty() {
208            return Err("Empty message");
209        }
210        let message_type = data[0];
211        let payload = &data[1..];
212        Ok((message_type, payload))
213    }
214
215    /// Decode an encrypted data message with routing header
216    /// Returns (message_type, request_id, encrypted_payload)
217    pub fn decode_encrypted_with_routing(data: &[u8]) -> Result<(u8, &str, &[u8]), &'static str> {
218        if data.len() < 2 {
219            return Err("Message too short");
220        }
221        let message_type = data[0];
222        let id_len = data[1] as usize;
223        if data.len() < 2 + id_len {
224            return Err("Message truncated");
225        }
226        let request_id = std::str::from_utf8(&data[2..2 + id_len])
227            .map_err(|_| "Invalid request_id encoding")?;
228        let payload = &data[2 + id_len..];
229        Ok((message_type, request_id, payload))
230    }
231}
232
233/// URL utilities for tunnel URLs with embedded encryption keys
234pub mod url {
235    use anyhow::{anyhow, Result};
236
237    /// Components extracted from a tunnel URL
238    #[derive(Debug, Clone)]
239    pub struct TunnelUrl {
240        /// Full URL without fragment (for WebSocket connection)
241        pub base_url: String,
242        /// The tunnel ID extracted from subdomain
243        pub tunnel_id: String,
244        /// The encryption key from the URL fragment
245        pub encryption_key: String,
246    }
247
248    impl TunnelUrl {
249        /// Parse a tunnel URL like "https://abc123.relay.example.com#key"
250        pub fn parse(url: &str) -> Result<Self> {
251            // Split on fragment
252            let (base, fragment) = url
253                .split_once('#')
254                .ok_or_else(|| anyhow!("Missing encryption key in URL fragment"))?;
255
256            if fragment.is_empty() {
257                return Err(anyhow!("Empty encryption key in URL fragment"));
258            }
259
260            // Extract tunnel_id from subdomain
261            // URL format: https://{tunnel_id}.relay.example.com
262            let host = base
263                .strip_prefix("https://")
264                .or_else(|| base.strip_prefix("http://"))
265                .ok_or_else(|| anyhow!("Invalid URL scheme"))?;
266
267            let host = host.split('/').next().unwrap_or(host);
268            let tunnel_id = host
269                .split('.')
270                .next()
271                .ok_or_else(|| anyhow!("Cannot extract tunnel ID from URL"))?;
272
273            Ok(Self {
274                base_url: base.to_string(),
275                tunnel_id: tunnel_id.to_string(),
276                encryption_key: fragment.to_string(),
277            })
278        }
279
280        /// Construct a tunnel URL from components
281        pub fn build(base_url: &str, encryption_key: &str) -> String {
282            format!("{}#{}", base_url, encryption_key)
283        }
284    }
285
286    #[cfg(test)]
287    mod tests {
288        use super::TunnelUrl;
289
290        #[test]
291        fn test_parse_tunnel_url() {
292            let url = "https://abc12345.relay.example.com#K8dX2mPqR7vNzL5hJwYtF3gBcE9sUoAi";
293            let parsed = TunnelUrl::parse(url).unwrap();
294
295            assert_eq!(parsed.base_url, "https://abc12345.relay.example.com");
296            assert_eq!(parsed.tunnel_id, "abc12345");
297            assert_eq!(parsed.encryption_key, "K8dX2mPqR7vNzL5hJwYtF3gBcE9sUoAi");
298        }
299
300        #[test]
301        fn test_parse_tunnel_url_with_path() {
302            let url = "https://abc12345.relay.example.com/some/path#key123";
303            let parsed = TunnelUrl::parse(url).unwrap();
304
305            assert_eq!(parsed.base_url, "https://abc12345.relay.example.com/some/path");
306            assert_eq!(parsed.tunnel_id, "abc12345");
307            assert_eq!(parsed.encryption_key, "key123");
308        }
309
310        #[test]
311        fn test_parse_missing_fragment() {
312            let url = "https://abc12345.relay.example.com";
313            let result = TunnelUrl::parse(url);
314            assert!(result.is_err());
315        }
316
317        #[test]
318        fn test_parse_empty_fragment() {
319            let url = "https://abc12345.relay.example.com#";
320            let result = TunnelUrl::parse(url);
321            assert!(result.is_err());
322        }
323
324        #[test]
325        fn test_build_tunnel_url() {
326            let url = TunnelUrl::build("https://abc.relay.example.com", "mykey123");
327            assert_eq!(url, "https://abc.relay.example.com#mykey123");
328        }
329    }
330}
331
// Unit tests for the message types and wire framing in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    // ========================================================================
    // Control Message Tests
    // ========================================================================

    #[test]
    fn test_control_message_serialization() {
        let msg = ControlMessage::Hello {
            version: 1,
            requested_id: Some("test123".to_string()),
            auth_token: None,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"hello\""));
        assert!(json.contains("\"version\":1"));

        match serde_json::from_str::<ControlMessage>(&json).unwrap() {
            ControlMessage::Hello { version, requested_id, auth_token } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id.as_deref(), Some("test123"));
                assert_eq!(auth_token, None);
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_hello_without_requested_id() {
        let msg = ControlMessage::Hello {
            version: 1,
            requested_id: None,
            auth_token: None,
        };

        let json = serde_json::to_string(&msg).unwrap();
        // Optional fields are omitted entirely when `None`.
        assert!(!json.contains("requested_id"));
        assert!(!json.contains("auth_token"));

        match serde_json::from_str::<ControlMessage>(&json).unwrap() {
            ControlMessage::Hello { version, requested_id, auth_token } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id, None);
                assert_eq!(auth_token, None);
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_hello_with_auth_token() {
        let msg = ControlMessage::Hello {
            version: 1,
            requested_id: Some("test123".to_string()),
            auth_token: Some("secret-token-abc".to_string()),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"auth_token\":\"secret-token-abc\""));

        match serde_json::from_str::<ControlMessage>(&json).unwrap() {
            ControlMessage::Hello { version, requested_id, auth_token } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id.as_deref(), Some("test123"));
                assert_eq!(auth_token.as_deref(), Some("secret-token-abc"));
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_welcome_message() {
        let msg = ControlMessage::Welcome {
            tunnel_id: "abc123".to_string(),
            tunnel_url: "https://abc123.relay.example.com".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"welcome\""));
        assert!(json.contains("\"tunnel_id\":\"abc123\""));

        match serde_json::from_str::<ControlMessage>(&json).unwrap() {
            ControlMessage::Welcome { tunnel_id, tunnel_url } => {
                assert_eq!(tunnel_id, "abc123");
                assert_eq!(tunnel_url, "https://abc123.relay.example.com");
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_ping_pong_messages() {
        let ping = ControlMessage::Ping { timestamp: 1234567890 };
        let ping_json = serde_json::to_string(&ping).unwrap();
        assert!(ping_json.contains("\"type\":\"ping\""));
        assert!(ping_json.contains("\"timestamp\":1234567890"));

        let pong = ControlMessage::Pong { timestamp: 1234567890 };
        let pong_json = serde_json::to_string(&pong).unwrap();
        assert!(pong_json.contains("\"type\":\"pong\""));

        // The echoed timestamp must survive the round trip unchanged.
        match serde_json::from_str::<ControlMessage>(&pong_json).unwrap() {
            ControlMessage::Pong { timestamp } => assert_eq!(timestamp, 1234567890),
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_close_message() {
        let msg = ControlMessage::Close {
            reason: "graceful shutdown".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"close\""));
        assert!(json.contains("graceful shutdown"));
    }

    #[test]
    fn test_error_message() {
        let msg = ControlMessage::Error {
            code: "RATE_LIMITED".to_string(),
            message: "Too many requests".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("RATE_LIMITED"));
    }

    #[test]
    fn test_client_connected_disconnected() {
        let connected = ControlMessage::ClientConnected {
            client_id: "client-123".to_string(),
        };
        let connected_json = serde_json::to_string(&connected).unwrap();
        assert!(connected_json.contains("\"type\":\"client_connected\""));

        let disconnected = ControlMessage::ClientDisconnected {
            client_id: "client-123".to_string(),
        };
        let disconnected_json = serde_json::to_string(&disconnected).unwrap();
        assert!(disconnected_json.contains("\"type\":\"client_disconnected\""));
    }

    // ========================================================================
    // Data Message Tests
    // ========================================================================

    #[test]
    fn test_data_message_serialization() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        let msg = DataMessage::HttpRequest {
            request_id: "req-123".to_string(),
            client_id: "client-456".to_string(),
            method: "POST".to_string(),
            path: "/api/sessions".to_string(),
            query: Some("foo=bar".to_string()),
            headers,
            body: Some("eyJoZWxsbyI6IndvcmxkIn0=".to_string()),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"http_request\""));
        assert!(json.contains("\"method\":\"POST\""));

        match serde_json::from_str::<DataMessage>(&json).unwrap() {
            DataMessage::HttpRequest {
                request_id,
                client_id,
                method,
                path,
                query,
                headers,
                body,
            } => {
                assert_eq!(request_id, "req-123");
                assert_eq!(client_id, "client-456");
                assert_eq!(method, "POST");
                assert_eq!(path, "/api/sessions");
                assert_eq!(query.as_deref(), Some("foo=bar"));
                assert_eq!(
                    headers.get("Content-Type").map(String::as_str),
                    Some("application/json")
                );
                assert_eq!(body.as_deref(), Some("eyJoZWxsbyI6IndvcmxkIn0="));
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_http_request_minimal() {
        let msg = DataMessage::HttpRequest {
            request_id: "req-1".to_string(),
            client_id: "client-1".to_string(),
            method: "GET".to_string(),
            path: "/health".to_string(),
            query: None,
            headers: HashMap::new(),
            body: None,
        };

        let json = serde_json::to_string(&msg).unwrap();
        // Optional fields are omitted entirely when `None`.
        assert!(!json.contains("\"query\""));
        assert!(!json.contains("\"body\""));
    }

    #[test]
    fn test_http_response() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        let msg = DataMessage::HttpResponse {
            request_id: "req-123".to_string(),
            status: 200,
            headers,
            body: Some("eyJvayI6dHJ1ZX0=".to_string()),
            streaming: false,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"http_response\""));
        assert!(json.contains("\"status\":200"));
        assert!(json.contains("\"streaming\":false"));
    }

    #[test]
    fn test_http_response_streaming() {
        let msg = DataMessage::HttpResponse {
            request_id: "req-123".to_string(),
            status: 200,
            headers: HashMap::new(),
            body: None,
            streaming: true,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"streaming\":true"));
    }

    #[test]
    fn test_http_response_chunk() {
        let msg = DataMessage::HttpResponseChunk {
            request_id: "req-123".to_string(),
            chunk: "ZGF0YTogaGVsbG8K".to_string(), // base64 of "data: hello\n"
            is_final: false,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"http_response_chunk\""));
        assert!(json.contains("\"is_final\":false"));

        let final_chunk = DataMessage::HttpResponseChunk {
            request_id: "req-123".to_string(),
            chunk: "".to_string(),
            is_final: true,
        };

        let final_json = serde_json::to_string(&final_chunk).unwrap();
        assert!(final_json.contains("\"is_final\":true"));
    }

    #[test]
    fn test_request_error() {
        let msg = DataMessage::RequestError {
            request_id: Some("req-123".to_string()),
            code: "TIMEOUT".to_string(),
            message: "Request timed out".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"request_error\""));
        assert!(json.contains("TIMEOUT"));

        // Without a request_id the field is omitted entirely.
        let msg_no_id = DataMessage::RequestError {
            request_id: None,
            code: "INTERNAL_ERROR".to_string(),
            message: "Something went wrong".to_string(),
        };

        let json_no_id = serde_json::to_string(&msg_no_id).unwrap();
        assert!(!json_no_id.contains("\"request_id\""));
    }

    // ========================================================================
    // Wire Message Tests
    // ========================================================================

    #[test]
    fn test_wire_message_encoding() {
        let encrypted = vec![1, 2, 3, 4, 5];
        let encoded =
            WireMessage::encode_encrypted(message_type::ENCRYPTED_REQUEST, encrypted.clone());

        assert_eq!(encoded[0], message_type::ENCRYPTED_REQUEST);
        assert_eq!(&encoded[1..], &encrypted[..]);

        let (msg_type, payload) = WireMessage::decode_encrypted(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_REQUEST);
        assert_eq!(payload, &encrypted[..]);
    }

    #[test]
    fn test_wire_message_all_types() {
        let test_cases = [
            message_type::ENCRYPTED_REQUEST,
            message_type::ENCRYPTED_RESPONSE,
            message_type::ENCRYPTED_EVENT,
            message_type::PING,
            message_type::PONG,
            message_type::CLOSE,
            message_type::RELAY_ERROR,
        ];

        for msg_type in test_cases {
            let payload = vec![0xAB, 0xCD, 0xEF];
            let encoded = WireMessage::encode_encrypted(msg_type, payload.clone());
            let (decoded_type, decoded_payload) = WireMessage::decode_encrypted(&encoded).unwrap();
            assert_eq!(decoded_type, msg_type, "Message type mismatch for 0x{:02X}", msg_type);
            assert_eq!(decoded_payload, &payload[..]);
        }
    }

    #[test]
    fn test_wire_message_empty_payload() {
        let encoded = WireMessage::encode_encrypted(message_type::ENCRYPTED_REQUEST, vec![]);
        assert_eq!(encoded.len(), 1);
        assert_eq!(encoded[0], message_type::ENCRYPTED_REQUEST);

        let (msg_type, payload) = WireMessage::decode_encrypted(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_REQUEST);
        assert!(payload.is_empty());
    }

    #[test]
    fn test_wire_message_large_payload() {
        // 1 MB payload
        let large_payload: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encoded =
            WireMessage::encode_encrypted(message_type::ENCRYPTED_RESPONSE, large_payload.clone());

        assert_eq!(encoded.len(), 1 + large_payload.len());
        let (msg_type, payload) = WireMessage::decode_encrypted(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(payload.len(), large_payload.len());
        assert_eq!(payload, &large_payload[..]);
    }

    #[test]
    fn test_wire_message_decode_empty() {
        let result = WireMessage::decode_encrypted(&[]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Empty message");
    }

    // ========================================================================
    // Wire Message with Routing Header Tests
    // ========================================================================

    #[test]
    fn test_wire_message_with_routing_roundtrip() {
        let request_id = "req-abc-123";
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            payload.clone(),
        );

        let (msg_type, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();

        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(decoded_id, request_id);
        assert_eq!(decoded_payload, &payload[..]);
    }

    #[test]
    fn test_wire_message_with_routing_empty_payload() {
        let request_id = "req-empty";
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            vec![],
        );

        let (msg_type, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();

        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(decoded_id, request_id);
        assert!(decoded_payload.is_empty());
    }

    #[test]
    fn test_wire_message_with_routing_empty_request_id() {
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            "",
            vec![9],
        );
        assert_eq!(encoded, vec![message_type::ENCRYPTED_RESPONSE, 0, 9]);

        let (_, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();
        assert_eq!(decoded_id, "");
        assert_eq!(decoded_payload, &[9]);
    }

    #[test]
    fn test_wire_message_with_routing_uuid_request_id() {
        // UUID-format request IDs are common.
        let request_id = "550e8400-e29b-41d4-a716-446655440000";
        let payload = b"encrypted data here".to_vec();

        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            payload.clone(),
        );

        let (msg_type, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();

        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(decoded_id, request_id);
        assert_eq!(decoded_payload, &payload[..]);
    }

    #[test]
    fn test_wire_message_with_routing_format() {
        // Wire format: [msg_type][id_len][id][payload]
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            "test",
            vec![0xAA, 0xBB],
        );

        assert_eq!(encoded[0], message_type::ENCRYPTED_RESPONSE); // msg_type
        assert_eq!(encoded[1], 4); // id_len = "test".len()
        assert_eq!(&encoded[2..6], b"test"); // id
        assert_eq!(&encoded[6..], &[0xAA, 0xBB]); // payload
    }

    #[test]
    fn test_wire_message_with_routing_decode_too_short() {
        // Only the message type byte.
        let result = WireMessage::decode_encrypted_with_routing(&[0x02]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Message too short");
    }

    #[test]
    fn test_wire_message_with_routing_decode_truncated_id() {
        // Declares id_len = 10 but carries only 3 ID bytes.
        let data = vec![0x02, 10, b'a', b'b', b'c'];
        let result = WireMessage::decode_encrypted_with_routing(&data);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Message truncated");
    }

    #[test]
    fn test_wire_message_with_routing_decode_invalid_utf8_id() {
        let data = [message_type::ENCRYPTED_RESPONSE, 2, 0xFF, 0xFE, 0x01];
        let result = WireMessage::decode_encrypted_with_routing(&data);
        assert_eq!(result.unwrap_err(), "Invalid request_id encoding");
    }

    #[test]
    fn test_wire_message_with_routing_long_request_id() {
        // Request IDs longer than 255 bytes are truncated in the header.
        let long_id: String = "x".repeat(300);
        let payload = vec![1, 2, 3];

        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            &long_id,
            payload.clone(),
        );

        let (_, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();

        // The ID is cut to 255 bytes (one byte per char for ASCII).
        assert_eq!(decoded_id.len(), 255);
        assert!(long_id.starts_with(decoded_id));
        assert_eq!(decoded_payload, &payload[..]);
    }

    #[test]
    fn test_control_message_encode_decode() {
        let msg = ControlMessage::Welcome {
            tunnel_id: "test123".to_string(),
            tunnel_url: "https://test123.relay.example.com".to_string(),
        };

        let encoded = WireMessage::encode_control(&msg);
        match WireMessage::decode_control(&encoded).unwrap() {
            ControlMessage::Welcome { tunnel_id, .. } => assert_eq!(tunnel_id, "test123"),
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    // ========================================================================
    // Message Type Constants Tests
    // ========================================================================

    #[test]
    fn test_message_type_constants_unique() {
        let types = [
            message_type::ENCRYPTED_REQUEST,
            message_type::ENCRYPTED_RESPONSE,
            message_type::ENCRYPTED_EVENT,
            message_type::PING,
            message_type::PONG,
            message_type::CLOSE,
            message_type::RELAY_ERROR,
        ];

        let mut seen = std::collections::HashSet::new();
        for t in types {
            assert!(seen.insert(t), "Duplicate message type: 0x{:02X}", t);
        }
    }

    #[test]
    fn test_message_type_ranges() {
        // Data messages: 0x01-0x0F
        assert!(message_type::ENCRYPTED_REQUEST < 0x10);
        assert!(message_type::ENCRYPTED_RESPONSE < 0x10);
        assert!(message_type::ENCRYPTED_EVENT < 0x10);

        // Control messages: 0x10-0xFE
        assert!(message_type::PING >= 0x10 && message_type::PING < 0xFF);
        assert!(message_type::PONG >= 0x10 && message_type::PONG < 0xFF);
        assert!(message_type::CLOSE >= 0x10 && message_type::CLOSE < 0xFF);

        // Relay error is the special value 0xFF.
        assert_eq!(message_type::RELAY_ERROR, 0xFF);
    }

    // ========================================================================
    // Roundtrip Tests
    // ========================================================================

    #[test]
    fn test_http_request_response_roundtrip() {
        let engine = base64::engine::general_purpose::STANDARD;

        // Simulate a complete request-response cycle.
        let mut req_headers = HashMap::new();
        req_headers.insert("Content-Type".to_string(), "application/json".to_string());
        req_headers.insert("Authorization".to_string(), "Bearer token123".to_string());

        let request = DataMessage::HttpRequest {
            request_id: "req-roundtrip-1".to_string(),
            client_id: "client-1".to_string(),
            method: "POST".to_string(),
            path: "/api/data".to_string(),
            query: Some("format=json".to_string()),
            headers: req_headers,
            body: Some(engine.encode(r#"{"data":"test"}"#)),
        };

        let req_json = serde_json::to_vec(&request).unwrap();
        let req_decoded: DataMessage = serde_json::from_slice(&req_json).unwrap();

        // Pull out the request_id for the response and check the body survived.
        let request_id = match req_decoded {
            DataMessage::HttpRequest { request_id, body, .. } => {
                let raw = engine.decode(body.expect("request body")).unwrap();
                assert_eq!(raw, br#"{"data":"test"}"#.to_vec());
                request_id
            }
            other => panic!("Expected HttpRequest, got {:?}", other),
        };

        let mut resp_headers = HashMap::new();
        resp_headers.insert("Content-Type".to_string(), "application/json".to_string());

        let response = DataMessage::HttpResponse {
            request_id,
            status: 201,
            headers: resp_headers,
            body: Some(engine.encode(r#"{"id":"123"}"#)),
            streaming: false,
        };

        let resp_json = serde_json::to_vec(&response).unwrap();
        match serde_json::from_slice::<DataMessage>(&resp_json).unwrap() {
            DataMessage::HttpResponse { request_id, status, .. } => {
                assert_eq!(request_id, "req-roundtrip-1");
                assert_eq!(status, 201);
            }
            other => panic!("Expected HttpResponse, got {:?}", other),
        }
    }
}
904}