http_tunnel_common/protocol/message.rs

1use serde::{Deserialize, Serialize};
2
3use super::{HttpRequest, HttpResponse};
4
5/// All WebSocket messages are wrapped in this typed envelope
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(tag = "type", rename_all = "snake_case")]
8pub enum Message {
9    /// Control plane messages
10    Ping,
11    Pong,
12    Ready, // Sent by forwarder after connection to request connection info
13
14    /// Connection lifecycle
15    ConnectionEstablished {
16        connection_id: String,
17        tunnel_id: String,
18        public_url: String,
19        #[serde(skip_serializing_if = "Option::is_none")]
20        subdomain_url: Option<String>,
21        #[serde(skip_serializing_if = "Option::is_none")]
22        path_based_url: Option<String>,
23    },
24
25    /// Data plane messages
26    HttpRequest(HttpRequest),
27    HttpResponse(HttpResponse),
28
29    /// Error handling
30    Error {
31        request_id: Option<String>,
32        code: ErrorCode,
33        message: String,
34    },
35}
36
37/// Error codes for tunnel operations
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(rename_all = "snake_case")]
40pub enum ErrorCode {
41    InvalidRequest,
42    Timeout,
43    LocalServiceUnavailable,
44    InternalError,
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Serializes `msg` to JSON and parses it back, returning both the JSON
    /// text and the round-tripped message.
    fn round_trip(msg: &Message) -> (String, Message) {
        let json = serde_json::to_string(msg).expect("Message should serialize");
        let parsed = serde_json::from_str(&json).expect("serialized Message should deserialize");
        (json, parsed)
    }

    #[test]
    fn test_ping_pong_serialization() {
        let (json, parsed) = round_trip(&Message::Ping);
        assert_eq!(json, r#"{"type":"ping"}"#);
        assert!(matches!(parsed, Message::Ping));

        let (json, parsed) = round_trip(&Message::Pong);
        assert_eq!(json, r#"{"type":"pong"}"#);
        assert!(matches!(parsed, Message::Pong));
    }

    #[test]
    fn test_ready_serialization() {
        let (json, parsed) = round_trip(&Message::Ready);
        assert_eq!(json, r#"{"type":"ready"}"#);
        assert!(matches!(parsed, Message::Ready));
    }

    #[test]
    fn test_connection_established_serialization() {
        let msg = Message::ConnectionEstablished {
            connection_id: "conn_123".to_string(),
            tunnel_id: "abc123def456".to_string(),
            public_url: "https://abc123def456.tunnel.example.com".to_string(),
            subdomain_url: Some("https://abc123def456.tunnel.example.com".to_string()),
            path_based_url: Some("https://tunnel.example.com/abc123def456".to_string()),
        };

        let (json, parsed) = round_trip(&msg);
        // Include the closing quote so e.g. "conn_1234" would not match.
        assert!(json.contains(r#""type":"connection_established""#));
        assert!(json.contains(r#""connection_id":"conn_123""#));

        match parsed {
            Message::ConnectionEstablished {
                connection_id,
                tunnel_id,
                public_url,
                subdomain_url,
                path_based_url,
            } => {
                assert_eq!(connection_id, "conn_123");
                assert_eq!(tunnel_id, "abc123def456");
                assert_eq!(public_url, "https://abc123def456.tunnel.example.com");
                assert_eq!(
                    subdomain_url.as_deref(),
                    Some("https://abc123def456.tunnel.example.com")
                );
                assert_eq!(
                    path_based_url.as_deref(),
                    Some("https://tunnel.example.com/abc123def456")
                );
            }
            other => panic!("Expected ConnectionEstablished, got {:?}", other),
        }
    }

    #[test]
    fn test_connection_established_omits_absent_urls() {
        let msg = Message::ConnectionEstablished {
            connection_id: "conn_123".to_string(),
            tunnel_id: "abc123def456".to_string(),
            public_url: "https://tunnel.example.com/abc123def456".to_string(),
            subdomain_url: None,
            path_based_url: None,
        };

        let (json, parsed) = round_trip(&msg);
        // `skip_serializing_if = "Option::is_none"` must drop the keys entirely.
        assert!(!json.contains("subdomain_url"));
        assert!(!json.contains("path_based_url"));

        match parsed {
            Message::ConnectionEstablished {
                subdomain_url,
                path_based_url,
                ..
            } => {
                assert!(subdomain_url.is_none());
                assert!(path_based_url.is_none());
            }
            other => panic!("Expected ConnectionEstablished, got {:?}", other),
        }
    }

    #[test]
    fn test_connection_established_backward_compat() {
        // Messages from older peers, without subdomain_url/path_based_url, must still parse.
        let json = r#"{"type":"connection_established","connection_id":"conn_123","tunnel_id":"abc123def456","public_url":"https://tunnel.example.com/abc123def456"}"#;

        let parsed: Message = serde_json::from_str(json).unwrap();
        match parsed {
            Message::ConnectionEstablished {
                connection_id,
                subdomain_url,
                path_based_url,
                ..
            } => {
                assert_eq!(connection_id, "conn_123");
                assert!(subdomain_url.is_none());
                assert!(path_based_url.is_none());
            }
            other => panic!("Expected ConnectionEstablished, got {:?}", other),
        }
    }

    #[test]
    fn test_http_request_serialization() {
        let request = HttpRequest {
            request_id: "req_123".to_string(),
            method: "GET".to_string(),
            uri: "/api/v1/users".to_string(),
            headers: HashMap::new(),
            body: String::new(),
            timestamp: 1234567890,
        };

        let (json, parsed) = round_trip(&Message::HttpRequest(request));
        assert!(json.contains(r#""type":"http_request""#));
        assert!(json.contains(r#""request_id":"req_123""#));

        match parsed {
            Message::HttpRequest(req) => {
                assert_eq!(req.request_id, "req_123");
                assert_eq!(req.method, "GET");
                assert_eq!(req.uri, "/api/v1/users");
                assert!(req.headers.is_empty());
                assert!(req.body.is_empty());
                assert_eq!(req.timestamp, 1234567890);
            }
            other => panic!("Expected HttpRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_error_serialization() {
        let msg = Message::Error {
            request_id: Some("req_123".to_string()),
            code: ErrorCode::Timeout,
            message: "Request timed out".to_string(),
        };

        let (json, parsed) = round_trip(&msg);
        assert!(json.contains(r#""type":"error""#));
        assert!(json.contains(r#""code":"timeout""#));
        assert!(json.contains(r#""message":"Request timed out""#));

        match parsed {
            Message::Error {
                request_id,
                code,
                message,
            } => {
                assert_eq!(request_id.as_deref(), Some("req_123"));
                assert!(matches!(code, ErrorCode::Timeout));
                assert_eq!(message, "Request timed out");
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    #[test]
    fn test_error_without_request_id() {
        let msg = Message::Error {
            request_id: None,
            code: ErrorCode::InternalError,
            message: "boom".to_string(),
        };

        let (_, parsed) = round_trip(&msg);
        match parsed {
            Message::Error {
                request_id, code, ..
            } => {
                assert!(request_id.is_none());
                assert!(matches!(code, ErrorCode::InternalError));
            }
            other => panic!("Expected Error, got {:?}", other),
        }
    }

    #[test]
    fn test_error_code_serialization() {
        let codes = [
            (ErrorCode::InvalidRequest, "invalid_request"),
            (ErrorCode::Timeout, "timeout"),
            (
                ErrorCode::LocalServiceUnavailable,
                "local_service_unavailable",
            ),
            (ErrorCode::InternalError, "internal_error"),
        ];

        for (code, expected_json) in codes {
            let json = serde_json::to_string(&code).unwrap();
            assert_eq!(json, format!(r#""{}""#, expected_json));

            let parsed: ErrorCode = serde_json::from_str(&json).unwrap();
            // Compare variants via their discriminants so this test does not
            // depend on ErrorCode implementing PartialEq.
            assert_eq!(
                std::mem::discriminant(&parsed),
                std::mem::discriminant(&code)
            );
        }
    }
}