http_tunnel_common/protocol/message.rs

1use serde::{Deserialize, Serialize};
2
3use super::{HttpRequest, HttpResponse};
4
5/// All WebSocket messages are wrapped in this typed envelope
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(tag = "type", rename_all = "snake_case")]
8pub enum Message {
9    /// Control plane messages
10    Ping,
11    Pong,
12    Ready, // Sent by forwarder after connection to request connection info
13
14    /// Connection lifecycle
15    ConnectionEstablished {
16        connection_id: String,
17        tunnel_id: String,
18        public_url: String,
19    },
20
21    /// Data plane messages
22    HttpRequest(HttpRequest),
23    HttpResponse(HttpResponse),
24
25    /// Error handling
26    Error {
27        request_id: Option<String>,
28        code: ErrorCode,
29        message: String,
30    },
31}
32
33/// Error codes for tunnel operations
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum ErrorCode {
37    InvalidRequest,
38    Timeout,
39    LocalServiceUnavailable,
40    InternalError,
41}
42
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Serializes `msg` to JSON text and parses that text back into a `Message`.
    ///
    /// Panics (failing the calling test) if either direction fails.
    fn round_trip(msg: &Message) -> (String, Message) {
        let json = serde_json::to_string(msg).expect("Message should serialize");
        let parsed = serde_json::from_str(&json).expect("serialized Message should deserialize");
        (json, parsed)
    }

    #[test]
    fn test_ping_pong_serialization() {
        // Both directions are checked for both variants.
        let (json, parsed) = round_trip(&Message::Ping);
        assert_eq!(json, r#"{"type":"ping"}"#);
        assert!(matches!(parsed, Message::Ping));

        let (json, parsed) = round_trip(&Message::Pong);
        assert_eq!(json, r#"{"type":"pong"}"#);
        assert!(matches!(parsed, Message::Pong));
    }

    #[test]
    fn test_ready_serialization() {
        let (json, parsed) = round_trip(&Message::Ready);
        assert_eq!(json, r#"{"type":"ready"}"#);
        assert!(matches!(parsed, Message::Ready));
    }

    #[test]
    fn test_connection_established_serialization() {
        let msg = Message::ConnectionEstablished {
            connection_id: "conn_123".to_string(),
            tunnel_id: "abc123def456".to_string(),
            public_url: "https://tunnel.example.com/abc123def456".to_string(),
        };

        // Compare as a JSON value: exact, but independent of key order.
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "connection_established",
                "connection_id": "conn_123",
                "tunnel_id": "abc123def456",
                "public_url": "https://tunnel.example.com/abc123def456"
            })
        );

        let (_, parsed) = round_trip(&msg);
        match parsed {
            Message::ConnectionEstablished {
                connection_id,
                tunnel_id,
                public_url,
            } => {
                assert_eq!(connection_id, "conn_123");
                assert_eq!(tunnel_id, "abc123def456");
                assert_eq!(public_url, "https://tunnel.example.com/abc123def456");
            }
            other => panic!("Expected ConnectionEstablished, got {:?}", other),
        }
    }

    #[test]
    fn test_http_request_serialization() {
        let request = HttpRequest {
            request_id: "req_123".to_string(),
            method: "GET".to_string(),
            uri: "/api/v1/users".to_string(),
            headers: HashMap::new(),
            body: String::new(),
            timestamp: 1234567890,
        };

        let msg = Message::HttpRequest(request);

        // Internally tagged newtype variant: the request's fields sit next to the tag.
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["type"], "http_request");
        assert_eq!(value["request_id"], "req_123");

        let (_, parsed) = round_trip(&msg);
        match parsed {
            Message::HttpRequest(req) => {
                assert_eq!(req.request_id, "req_123");
                assert_eq!(req.method, "GET");
                assert_eq!(req.uri, "/api/v1/users");
                assert!(req.headers.is_empty());
                assert!(req.body.is_empty());
                assert_eq!(req.timestamp, 1234567890);
            }
            other => panic!("Expected HttpRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_error_serialization() {
        let msg = Message::Error {
            request_id: Some("req_123".to_string()),
            code: ErrorCode::Timeout,
            message: "Request timed out".to_string(),
        };

        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "error",
                "request_id": "req_123",
                "code": "timeout",
                "message": "Request timed out"
            })
        );

        let (_, parsed) = round_trip(&msg);
        match parsed {
            Message::Error {
                request_id,
                code,
                message,
            } => {
                assert_eq!(request_id.as_deref(), Some("req_123"));
                assert!(matches!(code, ErrorCode::Timeout));
                assert_eq!(message, "Request timed out");
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    #[test]
    fn test_error_without_request_id() {
        let msg = Message::Error {
            request_id: None,
            code: ErrorCode::InternalError,
            message: "boom".to_string(),
        };
        let (_, parsed) = round_trip(&msg);
        assert!(matches!(
            parsed,
            Message::Error {
                request_id: None,
                code: ErrorCode::InternalError,
                ..
            }
        ));

        // An omitted `request_id` is accepted and read back as `None`.
        let parsed: Message =
            serde_json::from_str(r#"{"type":"error","code":"invalid_request","message":"bad"}"#)
                .unwrap();
        assert!(matches!(
            parsed,
            Message::Error {
                request_id: None,
                code: ErrorCode::InvalidRequest,
                ..
            }
        ));
    }

    #[test]
    fn test_unknown_or_missing_type_is_rejected() {
        assert!(serde_json::from_str::<Message>(r#"{"type":"bogus"}"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"{"connection_id":"conn_123"}"#).is_err());
    }

    #[test]
    fn test_error_code_serialization() {
        let cases = [
            (ErrorCode::InvalidRequest, "invalid_request"),
            (ErrorCode::Timeout, "timeout"),
            (
                ErrorCode::LocalServiceUnavailable,
                "local_service_unavailable",
            ),
            (ErrorCode::InternalError, "internal_error"),
        ];

        for (code, expected_json) in cases {
            let json = serde_json::to_string(&code).unwrap();
            assert_eq!(json, format!(r#""{}""#, expected_json));

            let parsed: ErrorCode = serde_json::from_str(&json).unwrap();
            assert!(matches!(
                (code, parsed),
                (ErrorCode::InvalidRequest, ErrorCode::InvalidRequest)
                    | (ErrorCode::Timeout, ErrorCode::Timeout)
                    | (
                        ErrorCode::LocalServiceUnavailable,
                        ErrorCode::LocalServiceUnavailable
                    )
                    | (ErrorCode::InternalError, ErrorCode::InternalError)
            ));
        }
    }
}