use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// First-byte tags for the binary frames built by `WireMessage::encode_encrypted`
/// and `WireMessage::encode_encrypted_with_routing`.
///
/// The tests in this file pin the numbering scheme: the `ENCRYPTED_*` tags are
/// below `0x10`, `PING`/`PONG`/`CLOSE` lie in `0x10..0xFF`, and `0xFF` is
/// `RELAY_ERROR`. These values are part of the wire protocol — never renumber them.
pub mod message_type {
/// Tag for an encrypted request frame.
pub const ENCRYPTED_REQUEST: u8 = 0x01;
/// Tag for an encrypted response frame.
pub const ENCRYPTED_RESPONSE: u8 = 0x02;
/// Tag for an encrypted event frame.
pub const ENCRYPTED_EVENT: u8 = 0x03;
/// Tag for a binary ping frame.
pub const PING: u8 = 0x10;
/// Tag for a binary pong frame.
pub const PONG: u8 = 0x11;
/// Tag for a binary close frame.
pub const CLOSE: u8 = 0x12;
/// Tag for an error frame; pinned to the top of the byte range (`0xFF`).
pub const RELAY_ERROR: u8 = 0xFF;
}
/// Plain-JSON control messages (not end-to-end encrypted by this module).
///
/// Serialized as an internally tagged JSON object: the variant name in
/// snake_case goes in a `"type"` field next to the variant's own fields, e.g.
/// `{"type":"hello","version":1}` or `{"type":"client_connected","client_id":"..."}`.
/// serde emits fields in declaration order after the tag, so reordering fields
/// changes the bytes on the wire.
///
/// NOTE(review): `Debug` is derived, so `{:?}` on a `Hello` prints `auth_token`
/// in clear text — avoid debug-logging these messages, or redact the token.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlMessage {
/// `"type":"hello"`. Presumably the opening handshake — confirm against the relay.
Hello {
/// Protocol version (the tests in this file send `1`).
version: u8,
/// Tunnel id the sender asks for. Omitted from the JSON when `None`; a
/// missing key deserializes back to `None`.
#[serde(skip_serializing_if = "Option::is_none")]
requested_id: Option<String>,
/// Optional authentication token. Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
auth_token: Option<String>,
},
/// `"type":"welcome"`: the assigned tunnel id and its public URL
/// (tests use e.g. `https://abc123.relay.example.com`).
Welcome {
tunnel_id: String,
tunnel_url: String,
},
/// `"type":"client_connected"`, identifying a client by `client_id`.
ClientConnected {
client_id: String,
},
/// `"type":"client_disconnected"`, identifying a client by `client_id`.
ClientDisconnected {
client_id: String,
},
/// `"type":"ping"`.
/// NOTE(review): the unit/epoch of `timestamp` is not established here — confirm.
Ping {
timestamp: u64,
},
/// `"type":"pong"`; presumably echoes the ping's `timestamp` — confirm.
Pong {
timestamp: u64,
},
/// `"type":"close"` with a free-form `reason` (e.g. `"graceful shutdown"` in tests).
Close {
reason: String,
},
/// `"type":"error"` with a short `code` (e.g. `"RATE_LIMITED"` in tests) and a
/// human-readable `message`.
Error {
code: String,
message: String,
},
}
/// HTTP-over-tunnel payload messages.
///
/// Serialized like `ControlMessage`: an internally tagged JSON object whose
/// `"type"` is the snake_case variant name (`http_request`, `http_response`,
/// `http_response_chunk`, `request_error`). Field order is the JSON key order.
///
/// NOTE(review): `body`/`chunk` are `String`s; the tests in this file fill them
/// with standard base64 — confirm that is the contract for binary bodies.
/// NOTE(review): `headers` is a `HashMap<String, String>`, so repeated header
/// names (e.g. several `Set-Cookie`) cannot be represented and header order is lost.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DataMessage {
/// `"type":"http_request"`.
HttpRequest {
/// Correlates this request with its response/chunks/errors.
request_id: String,
client_id: String,
method: String,
path: String,
/// Query string without a leading `?` (tests use `"foo=bar"`); omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
query: Option<String>,
headers: HashMap<String, String>,
/// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
body: Option<String>,
},
/// `"type":"http_response"`.
HttpResponse {
request_id: String,
status: u16,
headers: HashMap<String, String>,
/// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
body: Option<String>,
/// Deserializes to `false` when the key is absent. Presumably `true` means
/// the body follows as `HttpResponseChunk`s — confirm.
#[serde(default)]
streaming: bool,
},
/// `"type":"http_response_chunk"`: one piece of a streamed response body.
HttpResponseChunk {
request_id: String,
chunk: String,
/// Deserializes to `false` when absent; tests send an empty final chunk.
#[serde(default)]
is_final: bool,
},
/// `"type":"request_error"` (e.g. code `"TIMEOUT"` in tests). `request_id`
/// is omitted from the JSON when `None`.
RequestError {
#[serde(skip_serializing_if = "Option::is_none")]
request_id: Option<String>,
code: String,
message: String,
},
}
/// A message as carried on the transport: either a JSON `ControlMessage`, or an
/// opaque binary frame tagged with one of the `message_type` bytes.
///
/// The associated functions on this type (`encode_*` / `decode_*`) work directly
/// on byte buffers.
/// NOTE(review): no code visible here constructs or matches these variants —
/// confirm they are still used elsewhere.
#[derive(Debug, Clone)]
pub enum WireMessage {
Control(ControlMessage),
/// Presumably an already-encrypted payload (per the variant name) plus its tag byte.
EncryptedData {
message_type: u8,
payload: Vec<u8>,
},
}
impl WireMessage {
    /// Serializes a [`ControlMessage`] to its JSON wire form.
    ///
    /// # Panics
    ///
    /// Not in practice: every `ControlMessage` field is a string, an integer or
    /// an optional string, all of which `serde_json` serializes infallibly.
    pub fn encode_control(msg: &ControlMessage) -> Vec<u8> {
        serde_json::to_vec(msg).expect("Control message serialization should not fail")
    }

    /// Parses a JSON-encoded [`ControlMessage`].
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error for malformed JSON, an unknown `"type"`
    /// tag, or missing/ill-typed fields.
    pub fn decode_control(data: &[u8]) -> Result<ControlMessage, serde_json::Error> {
        serde_json::from_slice(data)
    }

    /// Frames an encrypted payload as `[message_type][payload...]`.
    pub fn encode_encrypted(message_type: u8, encrypted_payload: Vec<u8>) -> Vec<u8> {
        let mut frame = Vec::with_capacity(1 + encrypted_payload.len());
        frame.push(message_type);
        frame.extend_from_slice(&encrypted_payload);
        frame
    }

    /// Frames an encrypted payload together with a plaintext routing id:
    /// `[message_type][id_len: u8][request_id bytes][payload...]`.
    ///
    /// The length prefix is one byte, so ids longer than 255 bytes are
    /// truncated. Truncation backs off to a UTF-8 character boundary, so the
    /// frame always decodes with [`WireMessage::decode_encrypted_with_routing`].
    /// Distinct long ids that share a 255-byte prefix will collide.
    pub fn encode_encrypted_with_routing(
        message_type: u8,
        request_id: &str,
        encrypted_payload: Vec<u8>,
    ) -> Vec<u8> {
        const MAX_ID_LEN: usize = u8::MAX as usize;

        // Cutting a multi-byte UTF-8 sequence in half would produce an id the
        // decoder rejects, so walk back to a char boundary (0 always is one).
        let mut id_len = request_id.len().min(MAX_ID_LEN);
        while !request_id.is_char_boundary(id_len) {
            id_len -= 1;
        }
        let id_bytes = &request_id.as_bytes()[..id_len];

        let mut frame = Vec::with_capacity(2 + id_len + encrypted_payload.len());
        frame.push(message_type);
        // id_len <= 255 by construction, so this cast cannot truncate.
        frame.push(id_len as u8);
        frame.extend_from_slice(id_bytes);
        frame.extend_from_slice(&encrypted_payload);
        frame
    }

    /// Splits a frame produced by [`WireMessage::encode_encrypted`] into its
    /// tag byte and payload.
    ///
    /// # Errors
    ///
    /// `"Empty message"` when `data` is empty.
    pub fn decode_encrypted(data: &[u8]) -> Result<(u8, &[u8]), &'static str> {
        let (&message_type, payload) = data.split_first().ok_or("Empty message")?;
        Ok((message_type, payload))
    }

    /// Splits a frame produced by [`WireMessage::encode_encrypted_with_routing`]
    /// into `(message_type, request_id, payload)`, borrowing from `data`.
    ///
    /// # Errors
    ///
    /// - `"Message too short"`: fewer than 2 bytes (no room for tag + length).
    /// - `"Message truncated"`: fewer id bytes than the length prefix claims.
    /// - `"Invalid request_id encoding"`: the id bytes are not valid UTF-8.
    pub fn decode_encrypted_with_routing(
        data: &[u8],
    ) -> Result<(u8, &str, &[u8]), &'static str> {
        let (message_type, id_len, rest) = match data {
            [message_type, id_len, rest @ ..] => (*message_type, usize::from(*id_len), rest),
            _ => return Err("Message too short"),
        };
        if rest.len() < id_len {
            return Err("Message truncated");
        }
        let (id_bytes, payload) = rest.split_at(id_len);
        let request_id =
            std::str::from_utf8(id_bytes).map_err(|_| "Invalid request_id encoding")?;
        Ok((message_type, request_id, payload))
    }
}
pub mod url {
    //! Parsing and building of shareable tunnel URLs of the form
    //! `scheme://<tunnel_id>.<relay-host>[/path]#<encryption_key>`.
    //!
    //! The key lives in the URL fragment, which browsers do not send to the server.

    use anyhow::{anyhow, Result};
    use std::fmt;

    /// A tunnel URL split into its components.
    ///
    /// `Debug` is implemented by hand so that `encryption_key` never ends up in logs.
    #[derive(Clone)]
    pub struct TunnelUrl {
        /// Everything before the first `#`, including the scheme and any path.
        pub base_url: String,
        /// First DNS label of the host, e.g. `abc12345` in `abc12345.relay.example.com`.
        pub tunnel_id: String,
        /// Raw text after the first `#`; never empty for a parsed URL.
        pub encryption_key: String,
    }

    impl fmt::Debug for TunnelUrl {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("TunnelUrl")
                .field("base_url", &self.base_url)
                .field("tunnel_id", &self.tunnel_id)
                .field("encryption_key", &"<redacted>")
                .finish()
        }
    }

    impl TunnelUrl {
        /// Parses a tunnel URL.
        ///
        /// # Errors
        ///
        /// - no `#`, or an empty fragment (missing encryption key);
        /// - a scheme other than `https://` or `http://`;
        /// - an empty first host label (e.g. `https://#key`, `https:///p#key`,
        ///   `https://.relay.example.com#key`), which leaves no tunnel id.
        pub fn parse(url: &str) -> Result<Self> {
            // Split at the first '#': everything after it is the key.
            let (base, fragment) = url
                .split_once('#')
                .ok_or_else(|| anyhow!("Missing encryption key in URL fragment"))?;
            if fragment.is_empty() {
                return Err(anyhow!("Empty encryption key in URL fragment"));
            }
            let authority = base
                .strip_prefix("https://")
                .or_else(|| base.strip_prefix("http://"))
                .ok_or_else(|| anyhow!("Invalid URL scheme"))?;
            // The authority ends at the first '/' or '?' ('#' was removed above).
            let host = authority
                .split(|c: char| c == '/' || c == '?')
                .next()
                .unwrap_or(authority);
            // `split` always yields at least one item, so the emptiness filter
            // is what makes the error below reachable.
            let tunnel_id = host
                .split('.')
                .next()
                .filter(|id| !id.is_empty())
                .ok_or_else(|| anyhow!("Cannot extract tunnel ID from URL"))?;
            Ok(Self {
                base_url: base.to_string(),
                tunnel_id: tunnel_id.to_string(),
                encryption_key: fragment.to_string(),
            })
        }

        /// Joins a base URL and an encryption key as `<base_url>#<encryption_key>`.
        pub fn build(base_url: &str, encryption_key: &str) -> String {
            format!("{}#{}", base_url, encryption_key)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::TunnelUrl;

        #[test]
        fn test_parse_tunnel_url() {
            let url = "https://abc12345.relay.example.com#K8dX2mPqR7vNzL5hJwYtF3gBcE9sUoAi";
            let parsed = TunnelUrl::parse(url).unwrap();
            assert_eq!(parsed.base_url, "https://abc12345.relay.example.com");
            assert_eq!(parsed.tunnel_id, "abc12345");
            assert_eq!(parsed.encryption_key, "K8dX2mPqR7vNzL5hJwYtF3gBcE9sUoAi");
        }

        #[test]
        fn test_parse_tunnel_url_with_path() {
            let url = "https://abc12345.relay.example.com/some/path#key123";
            let parsed = TunnelUrl::parse(url).unwrap();
            assert_eq!(parsed.base_url, "https://abc12345.relay.example.com/some/path");
            assert_eq!(parsed.tunnel_id, "abc12345");
            assert_eq!(parsed.encryption_key, "key123");
        }

        #[test]
        fn test_parse_missing_fragment() {
            let url = "https://abc12345.relay.example.com";
            let result = TunnelUrl::parse(url);
            assert!(result.is_err());
        }

        #[test]
        fn test_parse_empty_fragment() {
            let url = "https://abc12345.relay.example.com#";
            let result = TunnelUrl::parse(url);
            assert!(result.is_err());
        }

        #[test]
        fn test_parse_invalid_scheme() {
            assert!(TunnelUrl::parse("ftp://abc12345.relay.example.com#key").is_err());
        }

        #[test]
        fn test_parse_empty_tunnel_id() {
            assert!(TunnelUrl::parse("https://#key").is_err());
            assert!(TunnelUrl::parse("https:///some/path#key").is_err());
            assert!(TunnelUrl::parse("https://.relay.example.com#key").is_err());
        }

        #[test]
        fn test_parse_query_without_path() {
            let parsed = TunnelUrl::parse("https://abc12345?x=1#key").unwrap();
            assert_eq!(parsed.tunnel_id, "abc12345");
            assert_eq!(parsed.base_url, "https://abc12345?x=1");
        }

        #[test]
        fn test_debug_redacts_encryption_key() {
            let parsed = TunnelUrl::parse("https://abc12345.relay.example.com#supersecret").unwrap();
            let debug = format!("{:?}", parsed);
            assert!(!debug.contains("supersecret"));
            assert!(debug.contains("abc12345"));
        }

        #[test]
        fn test_build_tunnel_url() {
            let url = TunnelUrl::build("https://abc.relay.example.com", "mykey123");
            assert_eq!(url, "https://abc.relay.example.com#mykey123");
        }
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip and wire-format tests for the protocol messages and framing helpers.

    use super::*;
    use base64::Engine;

    // ---- ControlMessage ------------------------------------------------------

    #[test]
    fn test_control_message_serialization() {
        let msg = ControlMessage::Hello {
            version: 1,
            requested_id: Some("test123".to_string()),
            auth_token: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"hello\""));
        assert!(json.contains("\"version\":1"));
        let decoded: ControlMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            ControlMessage::Hello {
                version,
                requested_id,
                auth_token,
            } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id, Some("test123".to_string()));
                assert_eq!(auth_token, None);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_hello_without_requested_id() {
        let msg = ControlMessage::Hello {
            version: 1,
            requested_id: None,
            auth_token: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(!json.contains("requested_id"));
        let decoded: ControlMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            ControlMessage::Hello {
                version,
                requested_id,
                auth_token,
            } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id, None);
                assert_eq!(auth_token, None);
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_hello_with_auth_token() {
        let msg = ControlMessage::Hello {
            version: 1,
            requested_id: Some("test123".to_string()),
            auth_token: Some("secret-token-abc".to_string()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"auth_token\":\"secret-token-abc\""));
        let decoded: ControlMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            ControlMessage::Hello {
                version,
                requested_id,
                auth_token,
            } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id, Some("test123".to_string()));
                assert_eq!(auth_token, Some("secret-token-abc".to_string()));
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_welcome_message() {
        let msg = ControlMessage::Welcome {
            tunnel_id: "abc123".to_string(),
            tunnel_url: "https://abc123.relay.example.com".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"welcome\""));
        assert!(json.contains("\"tunnel_id\":\"abc123\""));
        let decoded: ControlMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            ControlMessage::Welcome {
                tunnel_id,
                tunnel_url,
            } => {
                assert_eq!(tunnel_id, "abc123");
                assert_eq!(tunnel_url, "https://abc123.relay.example.com");
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_ping_pong_messages() {
        let ping = ControlMessage::Ping {
            timestamp: 1234567890,
        };
        let ping_json = serde_json::to_string(&ping).unwrap();
        assert!(ping_json.contains("\"type\":\"ping\""));
        assert!(ping_json.contains("\"timestamp\":1234567890"));
        let decoded: ControlMessage = serde_json::from_str(&ping_json).unwrap();
        assert!(matches!(
            decoded,
            ControlMessage::Ping {
                timestamp: 1234567890
            }
        ));

        let pong = ControlMessage::Pong {
            timestamp: 1234567890,
        };
        let pong_json = serde_json::to_string(&pong).unwrap();
        assert!(pong_json.contains("\"type\":\"pong\""));
        let decoded: ControlMessage = serde_json::from_str(&pong_json).unwrap();
        assert!(matches!(
            decoded,
            ControlMessage::Pong {
                timestamp: 1234567890
            }
        ));
    }

    #[test]
    fn test_close_message() {
        let msg = ControlMessage::Close {
            reason: "graceful shutdown".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"close\""));
        assert!(json.contains("graceful shutdown"));
        let decoded: ControlMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            ControlMessage::Close { reason } => assert_eq!(reason, "graceful shutdown"),
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_error_message() {
        let msg = ControlMessage::Error {
            code: "RATE_LIMITED".to_string(),
            message: "Too many requests".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("RATE_LIMITED"));
        let decoded: ControlMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            ControlMessage::Error { code, message } => {
                assert_eq!(code, "RATE_LIMITED");
                assert_eq!(message, "Too many requests");
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_client_connected_disconnected() {
        let connected = ControlMessage::ClientConnected {
            client_id: "client-123".to_string(),
        };
        let connected_json = serde_json::to_string(&connected).unwrap();
        assert!(connected_json.contains("\"type\":\"client_connected\""));
        let decoded: ControlMessage = serde_json::from_str(&connected_json).unwrap();
        match decoded {
            ControlMessage::ClientConnected { client_id } => assert_eq!(client_id, "client-123"),
            _ => panic!("Wrong message type"),
        }

        let disconnected = ControlMessage::ClientDisconnected {
            client_id: "client-123".to_string(),
        };
        let disconnected_json = serde_json::to_string(&disconnected).unwrap();
        assert!(disconnected_json.contains("\"type\":\"client_disconnected\""));
        let decoded: ControlMessage = serde_json::from_str(&disconnected_json).unwrap();
        match decoded {
            ControlMessage::ClientDisconnected { client_id } => {
                assert_eq!(client_id, "client-123")
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_control_message_decode_rejects_bad_input() {
        assert!(WireMessage::decode_control(br#"{"type":"bogus"}"#).is_err());
        assert!(WireMessage::decode_control(b"not json").is_err());
    }

    // ---- DataMessage ---------------------------------------------------------

    #[test]
    fn test_data_message_serialization() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        let msg = DataMessage::HttpRequest {
            request_id: "req-123".to_string(),
            client_id: "client-456".to_string(),
            method: "POST".to_string(),
            path: "/api/sessions".to_string(),
            query: Some("foo=bar".to_string()),
            headers,
            body: Some("eyJoZWxsbyI6IndvcmxkIn0=".to_string()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"http_request\""));
        assert!(json.contains("\"method\":\"POST\""));
        let decoded: DataMessage = serde_json::from_str(&json).unwrap();
        match decoded {
            DataMessage::HttpRequest { method, path, .. } => {
                assert_eq!(method, "POST");
                assert_eq!(path, "/api/sessions");
            }
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_http_request_minimal() {
        let msg = DataMessage::HttpRequest {
            request_id: "req-1".to_string(),
            client_id: "client-1".to_string(),
            method: "GET".to_string(),
            path: "/health".to_string(),
            query: None,
            headers: HashMap::new(),
            body: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(!json.contains("\"query\""));
        assert!(!json.contains("\"body\""));
    }

    #[test]
    fn test_http_response() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        let msg = DataMessage::HttpResponse {
            request_id: "req-123".to_string(),
            status: 200,
            headers,
            body: Some("eyJvayI6dHJ1ZX0=".to_string()),
            streaming: false,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"http_response\""));
        assert!(json.contains("\"status\":200"));
        assert!(json.contains("\"streaming\":false"));
    }

    #[test]
    fn test_http_response_streaming() {
        let msg = DataMessage::HttpResponse {
            request_id: "req-123".to_string(),
            status: 200,
            headers: HashMap::new(),
            body: None,
            streaming: true,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"streaming\":true"));
    }

    #[test]
    fn test_http_response_streaming_defaults_to_false() {
        let json = r#"{"type":"http_response","request_id":"r","status":204,"headers":{}}"#;
        let decoded: DataMessage = serde_json::from_str(json).unwrap();
        match decoded {
            DataMessage::HttpResponse {
                status,
                body,
                streaming,
                ..
            } => {
                assert_eq!(status, 204);
                assert_eq!(body, None);
                assert!(!streaming);
            }
            _ => panic!("Expected HttpResponse"),
        }
    }

    #[test]
    fn test_http_response_chunk() {
        let msg = DataMessage::HttpResponseChunk {
            request_id: "req-123".to_string(),
            chunk: "ZGF0YTogaGVsbG8K".to_string(),
            is_final: false,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"http_response_chunk\""));
        assert!(json.contains("\"is_final\":false"));

        let final_chunk = DataMessage::HttpResponseChunk {
            request_id: "req-123".to_string(),
            chunk: "".to_string(),
            is_final: true,
        };
        let final_json = serde_json::to_string(&final_chunk).unwrap();
        assert!(final_json.contains("\"is_final\":true"));

        // `is_final` is `#[serde(default)]`, so an absent key means "not final".
        let defaulted: DataMessage =
            serde_json::from_str(r#"{"type":"http_response_chunk","request_id":"r","chunk":""}"#)
                .unwrap();
        assert!(matches!(
            defaulted,
            DataMessage::HttpResponseChunk {
                is_final: false,
                ..
            }
        ));
    }

    #[test]
    fn test_request_error() {
        let msg = DataMessage::RequestError {
            request_id: Some("req-123".to_string()),
            code: "TIMEOUT".to_string(),
            message: "Request timed out".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"request_error\""));
        assert!(json.contains("TIMEOUT"));

        let msg_no_id = DataMessage::RequestError {
            request_id: None,
            code: "INTERNAL_ERROR".to_string(),
            message: "Something went wrong".to_string(),
        };
        let json_no_id = serde_json::to_string(&msg_no_id).unwrap();
        assert!(!json_no_id.contains("\"request_id\""));
    }

    // ---- Binary framing ------------------------------------------------------

    #[test]
    fn test_wire_message_encoding() {
        let encrypted = vec![1, 2, 3, 4, 5];
        let encoded =
            WireMessage::encode_encrypted(message_type::ENCRYPTED_REQUEST, encrypted.clone());
        assert_eq!(encoded[0], message_type::ENCRYPTED_REQUEST);
        assert_eq!(&encoded[1..], &encrypted[..]);
        let (msg_type, payload) = WireMessage::decode_encrypted(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_REQUEST);
        assert_eq!(payload, &encrypted[..]);
    }

    #[test]
    fn test_wire_message_all_types() {
        let test_cases = [
            message_type::ENCRYPTED_REQUEST,
            message_type::ENCRYPTED_RESPONSE,
            message_type::ENCRYPTED_EVENT,
            message_type::PING,
            message_type::PONG,
            message_type::CLOSE,
            message_type::RELAY_ERROR,
        ];
        for msg_type in test_cases {
            let payload = vec![0xAB, 0xCD, 0xEF];
            let encoded = WireMessage::encode_encrypted(msg_type, payload.clone());
            let (decoded_type, decoded_payload) = WireMessage::decode_encrypted(&encoded).unwrap();
            assert_eq!(decoded_type, msg_type, "Message type mismatch for 0x{:02X}", msg_type);
            assert_eq!(decoded_payload, &payload[..]);
        }
    }

    #[test]
    fn test_wire_message_empty_payload() {
        let encoded = WireMessage::encode_encrypted(message_type::ENCRYPTED_REQUEST, vec![]);
        assert_eq!(encoded.len(), 1);
        assert_eq!(encoded[0], message_type::ENCRYPTED_REQUEST);
        let (msg_type, payload) = WireMessage::decode_encrypted(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_REQUEST);
        assert!(payload.is_empty());
    }

    #[test]
    fn test_wire_message_large_payload() {
        let large_payload: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encoded =
            WireMessage::encode_encrypted(message_type::ENCRYPTED_RESPONSE, large_payload.clone());
        assert_eq!(encoded.len(), 1 + large_payload.len());
        let (msg_type, payload) = WireMessage::decode_encrypted(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(payload.len(), large_payload.len());
        assert_eq!(payload, &large_payload[..]);
    }

    #[test]
    fn test_wire_message_decode_empty() {
        let result = WireMessage::decode_encrypted(&[]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Empty message");
    }

    #[test]
    fn test_wire_message_with_routing_roundtrip() {
        let request_id = "req-abc-123";
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            payload.clone(),
        );
        let (msg_type, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(decoded_id, request_id);
        assert_eq!(decoded_payload, &payload[..]);
    }

    #[test]
    fn test_wire_message_with_routing_empty_payload() {
        let request_id = "req-empty";
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            vec![],
        );
        let (msg_type, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(decoded_id, request_id);
        assert!(decoded_payload.is_empty());
    }

    #[test]
    fn test_wire_message_with_routing_uuid_request_id() {
        let request_id = "550e8400-e29b-41d4-a716-446655440000";
        let payload = b"encrypted data here".to_vec();
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            payload.clone(),
        );
        let (msg_type, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_RESPONSE);
        assert_eq!(decoded_id, request_id);
        assert_eq!(decoded_payload, &payload[..]);
    }

    #[test]
    fn test_wire_message_with_routing_format() {
        let request_id = "test";
        let payload = vec![0xAA, 0xBB];
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            request_id,
            payload,
        );
        // Layout: [type][id_len]["test"][payload]
        assert_eq!(encoded[0], message_type::ENCRYPTED_RESPONSE);
        assert_eq!(encoded[1], 4);
        assert_eq!(&encoded[2..6], b"test");
        assert_eq!(&encoded[6..], &[0xAA, 0xBB]);
    }

    #[test]
    fn test_wire_message_with_routing_decode_too_short() {
        let result = WireMessage::decode_encrypted_with_routing(&[0x02]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Message too short");
    }

    #[test]
    fn test_wire_message_with_routing_decode_truncated_id() {
        let data = vec![0x02, 10, b'a', b'b', b'c'];
        let result = WireMessage::decode_encrypted_with_routing(&data);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Message truncated");
    }

    #[test]
    fn test_wire_message_with_routing_decode_invalid_utf8() {
        let data = [message_type::ENCRYPTED_RESPONSE, 1, 0xFF, 0x00];
        let result = WireMessage::decode_encrypted_with_routing(&data);
        assert_eq!(result.unwrap_err(), "Invalid request_id encoding");
    }

    #[test]
    fn test_wire_message_with_routing_long_request_id() {
        let long_id: String = "x".repeat(300);
        let payload = vec![1, 2, 3];
        let encoded = WireMessage::encode_encrypted_with_routing(
            message_type::ENCRYPTED_RESPONSE,
            &long_id,
            payload.clone(),
        );
        let (_, decoded_id, decoded_payload) =
            WireMessage::decode_encrypted_with_routing(&encoded).unwrap();
        assert_eq!(decoded_id.len(), 255);
        assert_eq!(decoded_payload, &payload[..]);
    }

    #[test]
    fn test_control_message_encode_decode() {
        let msg = ControlMessage::Welcome {
            tunnel_id: "test123".to_string(),
            tunnel_url: "https://test123.relay.example.com".to_string(),
        };
        let encoded = WireMessage::encode_control(&msg);
        let decoded = WireMessage::decode_control(&encoded).unwrap();
        match decoded {
            ControlMessage::Welcome { tunnel_id, .. } => {
                assert_eq!(tunnel_id, "test123");
            }
            _ => panic!("Wrong message type"),
        }
    }

    // ---- message_type constants ----------------------------------------------

    #[test]
    fn test_message_type_constants_unique() {
        let types = [
            message_type::ENCRYPTED_REQUEST,
            message_type::ENCRYPTED_RESPONSE,
            message_type::ENCRYPTED_EVENT,
            message_type::PING,
            message_type::PONG,
            message_type::CLOSE,
            message_type::RELAY_ERROR,
        ];
        let mut seen = std::collections::HashSet::new();
        for t in types {
            assert!(seen.insert(t), "Duplicate message type: 0x{:02X}", t);
        }
    }

    #[test]
    fn test_message_type_ranges() {
        assert!(message_type::ENCRYPTED_REQUEST < 0x10);
        assert!(message_type::ENCRYPTED_RESPONSE < 0x10);
        assert!(message_type::ENCRYPTED_EVENT < 0x10);
        assert!(message_type::PING >= 0x10 && message_type::PING < 0xFF);
        assert!(message_type::PONG >= 0x10 && message_type::PONG < 0xFF);
        assert!(message_type::CLOSE >= 0x10 && message_type::CLOSE < 0xFF);
        assert_eq!(message_type::RELAY_ERROR, 0xFF);
    }

    // ---- End-to-end ----------------------------------------------------------

    #[test]
    fn test_http_request_response_roundtrip() {
        let mut req_headers = HashMap::new();
        req_headers.insert("Content-Type".to_string(), "application/json".to_string());
        req_headers.insert("Authorization".to_string(), "Bearer token123".to_string());
        let request = DataMessage::HttpRequest {
            request_id: "req-roundtrip-1".to_string(),
            client_id: "client-1".to_string(),
            method: "POST".to_string(),
            path: "/api/data".to_string(),
            query: Some("format=json".to_string()),
            headers: req_headers,
            body: Some(base64::engine::general_purpose::STANDARD.encode(r#"{"data":"test"}"#)),
        };
        let req_json = serde_json::to_vec(&request).unwrap();
        let req_decoded: DataMessage = serde_json::from_slice(&req_json).unwrap();
        // `req_decoded` is owned, so the id can be moved out instead of cloned.
        let request_id = match req_decoded {
            DataMessage::HttpRequest { request_id, .. } => request_id,
            _ => panic!("Expected HttpRequest"),
        };

        let mut resp_headers = HashMap::new();
        resp_headers.insert("Content-Type".to_string(), "application/json".to_string());
        let response = DataMessage::HttpResponse {
            request_id,
            status: 201,
            headers: resp_headers,
            body: Some(base64::engine::general_purpose::STANDARD.encode(r#"{"id":"123"}"#)),
            streaming: false,
        };
        let resp_json = serde_json::to_vec(&response).unwrap();
        let resp_decoded: DataMessage = serde_json::from_slice(&resp_json).unwrap();
        match resp_decoded {
            DataMessage::HttpResponse {
                request_id, status, ..
            } => {
                assert_eq!(request_id, "req-roundtrip-1");
                assert_eq!(status, 201);
            }
            _ => panic!("Expected HttpResponse"),
        }
    }
}