pub mod crypto;
pub mod protocol;
pub use crypto::{TunnelCrypto, TunnelKey, KEY_SIZE, NONCE_SIZE};
pub use protocol::{
message_type, url::TunnelUrl, ControlMessage, DataMessage, WireMessage,
};
/// Version number of the tunnel protocol. The tests in this file send it as
/// the `version` field of `ControlMessage::Hello`.
pub const PROTOCOL_VERSION: u8 = 1;
/// Relay URL constant using the `wss://` scheme.
/// NOTE(review): presumably the fallback relay when `RELAY_URL_ENV` is not
/// set — confirm against the callers, which are not visible in this file.
pub const DEFAULT_RELAY_URL: &str = "wss://relay.agent.example.com";
/// Name of an environment variable.
/// NOTE(review): presumably read to override `DEFAULT_RELAY_URL` — confirm
/// against the callers, which are not visible in this file.
pub const RELAY_URL_ENV: &str = "AGENT_RELAY_URL";
#[cfg(test)]
mod integration_tests {
use super::*;
use base64::Engine;
use std::collections::HashMap;
/// Walks through the opening handshake with the key shared via the tunnel URL.
///
/// The agent generates a key and embeds it in a URL with `TunnelUrl::build`;
/// the client parses that URL and rebuilds the key from it. A `Hello` message
/// that the agent serializes and encrypts must then decrypt and deserialize on
/// the client side with every field unchanged.
#[test]
fn test_encrypted_handshake_flow() {
    let agent_key = TunnelKey::generate();
    let tunnel_id = "test-tunnel-abc123";

    // Agent side: publish the key as part of the tunnel URL.
    let base_url = format!("https://{}.relay.example.com", tunnel_id);
    let shared_url = TunnelUrl::build(&base_url, &agent_key.to_base64());
    assert!(shared_url.starts_with("https://test-tunnel-abc123.relay.example.com"));

    // Client side: recover the tunnel id and the key from the URL alone.
    let url_parts = TunnelUrl::parse(&shared_url).unwrap();
    assert_eq!(url_parts.tunnel_id, tunnel_id);
    let client_key = TunnelKey::from_base64(&url_parts.encryption_key).unwrap();

    let agent_crypto = TunnelCrypto::new(&agent_key);
    let client_crypto = TunnelCrypto::new(&client_key);

    let hello = ControlMessage::Hello {
        version: PROTOCOL_VERSION,
        requested_id: Some(tunnel_id.to_string()),
        auth_token: Some("test-token".to_string()),
    };

    // Agent encrypts, client decrypts with the key it rebuilt from the URL.
    let ciphertext = agent_crypto
        .encrypt(&serde_json::to_vec(&hello).unwrap())
        .unwrap();
    let plaintext = client_crypto.decrypt(&ciphertext).unwrap();
    let received: ControlMessage = serde_json::from_slice(&plaintext).unwrap();

    if let ControlMessage::Hello {
        version,
        requested_id,
        auth_token,
    } = received
    {
        assert_eq!(version, PROTOCOL_VERSION);
        assert_eq!(requested_id, Some(tunnel_id.to_string()));
        assert_eq!(auth_token, Some("test-token".to_string()));
    } else {
        panic!("Expected Hello message");
    }
}
/// A `Welcome` control message encrypted by one `TunnelCrypto` must be readable
/// by a second `TunnelCrypto` built from the same key, with the tunnel id and
/// URL intact.
#[test]
fn test_encrypted_welcome_response() {
    let shared_key = TunnelKey::generate();
    let agent_crypto = TunnelCrypto::new(&shared_key);
    let relay_crypto = TunnelCrypto::new(&shared_key);

    let welcome = ControlMessage::Welcome {
        tunnel_id: "final-tunnel-id".to_string(),
        tunnel_url: "https://final-tunnel-id.relay.example.com".to_string(),
    };

    // Relay -> agent direction.
    let welcome_json = serde_json::to_vec(&welcome).unwrap();
    let ciphertext = relay_crypto.encrypt(&welcome_json).unwrap();
    let plaintext = agent_crypto.decrypt(&ciphertext).unwrap();
    let received: ControlMessage = serde_json::from_slice(&plaintext).unwrap();

    if let ControlMessage::Welcome {
        tunnel_id,
        tunnel_url,
    } = received
    {
        assert_eq!(tunnel_id, "final-tunnel-id");
        assert!(tunnel_url.contains("final-tunnel-id"));
    } else {
        panic!("Expected Welcome message");
    }
}
/// Exercises one full request/response exchange over an encrypted tunnel.
///
/// The request is serialized, encrypted by the client, framed with
/// `WireMessage::encode_encrypted`, unframed, decrypted by the agent and
/// checked field by field. The response travels the opposite way (agent
/// encrypts, client decrypts) without wire framing.
#[test]
fn test_encrypted_http_request_response_flow() {
    let shared_key = TunnelKey::generate();
    let client_crypto = TunnelCrypto::new(&shared_key);
    let agent_crypto = TunnelCrypto::new(&shared_key);

    // ---- Request: client -> agent ----
    let request_headers: HashMap<String, String> = [
        ("Content-Type", "application/json"),
        ("Authorization", "Bearer secret-token"),
    ]
    .iter()
    .map(|&(name, value)| (name.to_string(), value.to_string()))
    .collect();

    let request = DataMessage::HttpRequest {
        request_id: "req-001".to_string(),
        client_id: "client-123".to_string(),
        method: "POST".to_string(),
        path: "/api/sessions".to_string(),
        query: None,
        headers: request_headers,
        body: Some("eyJtZXNzYWdlIjoiaGVsbG8ifQ".to_string()),
    };

    let encrypted_request = client_crypto
        .encrypt(&serde_json::to_vec(&request).unwrap())
        .unwrap();

    // Frame the ciphertext and confirm the type tag is the first byte.
    let framed =
        WireMessage::encode_encrypted(message_type::ENCRYPTED_REQUEST, encrypted_request);
    assert_eq!(framed[0], message_type::ENCRYPTED_REQUEST);

    let (frame_type, frame_payload) = WireMessage::decode_encrypted(&framed).unwrap();
    assert_eq!(frame_type, message_type::ENCRYPTED_REQUEST);

    let request_plain = agent_crypto.decrypt(frame_payload).unwrap();
    let received_request: DataMessage = serde_json::from_slice(&request_plain).unwrap();

    if let DataMessage::HttpRequest {
        request_id,
        method,
        path,
        headers,
        body,
        ..
    } = received_request
    {
        assert_eq!(request_id, "req-001");
        assert_eq!(method, "POST");
        assert_eq!(path, "/api/sessions");
        assert_eq!(headers.len(), 2);
        assert!(headers.contains_key("Authorization"));
        assert!(body.is_some());
    } else {
        panic!("Expected HttpRequest");
    }

    // ---- Response: agent -> client ----
    let response_headers: HashMap<String, String> = [
        ("Content-Type", "application/json"),
        ("Location", "/api/sessions/new-id"),
    ]
    .iter()
    .map(|&(name, value)| (name.to_string(), value.to_string()))
    .collect();

    let response = DataMessage::HttpResponse {
        request_id: "req-001".to_string(),
        status: 201,
        headers: response_headers,
        body: Some("eyJpZCI6Im5ldy1pZCJ9".to_string()),
        streaming: false,
    };

    let encrypted_response = agent_crypto
        .encrypt(&serde_json::to_vec(&response).unwrap())
        .unwrap();
    let response_plain = client_crypto.decrypt(&encrypted_response).unwrap();
    let received_response: DataMessage = serde_json::from_slice(&response_plain).unwrap();

    if let DataMessage::HttpResponse {
        request_id,
        status,
        streaming,
        ..
    } = received_response
    {
        assert_eq!(request_id, "req-001");
        assert_eq!(status, 201);
        assert!(!streaming);
    } else {
        panic!("Expected HttpResponse");
    }
}
/// Exercises a streaming (SSE-style) response sent through the encrypted tunnel.
///
/// First a header-only `HttpResponse` with `streaming: true` is round-tripped,
/// then three `HttpResponseChunk` messages. Each chunk's base64 payload must
/// decode to exactly the bytes that were sent, and its `is_final` flag must
/// survive the trip; only the last chunk is marked final.
#[test]
fn test_encrypted_streaming_response() {
    let key = TunnelKey::generate();
    let client_crypto = TunnelCrypto::new(&key);
    let agent_crypto = TunnelCrypto::new(&key);

    // Initial response: headers only, body delivered later as chunks.
    let mut headers = HashMap::new();
    headers.insert("Content-Type".to_string(), "text/event-stream".to_string());
    headers.insert("Cache-Control".to_string(), "no-cache".to_string());
    let response = DataMessage::HttpResponse {
        request_id: "sse-001".to_string(),
        status: 200,
        headers,
        body: None,
        streaming: true,
    };
    let encrypted = agent_crypto
        .encrypt(&serde_json::to_vec(&response).unwrap())
        .unwrap();
    let decrypted = client_crypto.decrypt(&encrypted).unwrap();
    let decoded: DataMessage = serde_json::from_slice(&decrypted).unwrap();
    match decoded {
        DataMessage::HttpResponse { streaming, .. } => {
            assert!(streaming, "Should be a streaming response");
        }
        _ => panic!("Expected HttpResponse"),
    }

    // (payload, is_final) pairs; only the last chunk closes the stream.
    let chunks = vec![
        ("data: event 1\n\n", false),
        ("data: event 2\n\n", false),
        ("data: event 3\n\n", true),
    ];
    for (data, is_final) in chunks {
        let chunk = DataMessage::HttpResponseChunk {
            request_id: "sse-001".to_string(),
            chunk: base64::engine::general_purpose::STANDARD.encode(data),
            is_final,
        };
        let encrypted = agent_crypto
            .encrypt(&serde_json::to_vec(&chunk).unwrap())
            .unwrap();
        let decrypted = client_crypto.decrypt(&encrypted).unwrap();
        let decoded: DataMessage = serde_json::from_slice(&decrypted).unwrap();
        match decoded {
            DataMessage::HttpResponseChunk {
                request_id,
                chunk: chunk_data,
                is_final: final_flag,
            } => {
                assert_eq!(request_id, "sse-001");
                let decoded_data = base64::engine::general_purpose::STANDARD
                    .decode(&chunk_data)
                    .unwrap();
                // Compare the exact bytes: a prefix check would also accept a
                // truncated, corrupted or swapped chunk.
                assert_eq!(
                    decoded_data,
                    data.as_bytes(),
                    "Chunk payload mismatch for: {:?}",
                    data
                );
                assert_eq!(final_flag, is_final);
            }
            _ => panic!("Expected HttpResponseChunk"),
        }
    }
}
/// A `DataMessage::RequestError` encrypted by the agent must decrypt on the
/// client with its request id, code and message preserved.
#[test]
fn test_encrypted_error_message() {
    let shared_key = TunnelKey::generate();
    let agent_crypto = TunnelCrypto::new(&shared_key);
    let client_crypto = TunnelCrypto::new(&shared_key);

    let failure = DataMessage::RequestError {
        request_id: Some("failed-001".to_string()),
        code: "CONNECTION_REFUSED".to_string(),
        message: "Connection refused: localhost:3001".to_string(),
    };

    let failure_json = serde_json::to_vec(&failure).unwrap();
    let ciphertext = agent_crypto.encrypt(&failure_json).unwrap();
    let plaintext = client_crypto.decrypt(&ciphertext).unwrap();
    let received: DataMessage = serde_json::from_slice(&plaintext).unwrap();

    if let DataMessage::RequestError {
        request_id,
        code,
        message,
    } = received
    {
        assert_eq!(request_id, Some("failed-001".to_string()));
        assert_eq!(code, "CONNECTION_REFUSED");
        assert!(message.contains("Connection refused"));
    } else {
        panic!("Expected RequestError");
    }
}
/// A `ControlMessage::Error` encrypted by the relay must decrypt on the agent
/// with its code and message preserved.
#[test]
fn test_encrypted_control_error() {
    let shared_key = TunnelKey::generate();
    let relay_crypto = TunnelCrypto::new(&shared_key);
    let agent_crypto = TunnelCrypto::new(&shared_key);

    let rate_limit = ControlMessage::Error {
        code: "RATE_LIMITED".to_string(),
        message: "Too many requests".to_string(),
    };

    let rate_limit_json = serde_json::to_vec(&rate_limit).unwrap();
    let ciphertext = relay_crypto.encrypt(&rate_limit_json).unwrap();
    let plaintext = agent_crypto.decrypt(&ciphertext).unwrap();
    let received: ControlMessage = serde_json::from_slice(&plaintext).unwrap();

    if let ControlMessage::Error { code, message } = received {
        assert_eq!(code, "RATE_LIMITED");
        assert!(message.contains("Too many requests"));
    } else {
        panic!("Expected Error message");
    }
}
/// For each encrypted wire message type (request, response, event), an
/// encrypted payload framed with `WireMessage::encode_encrypted` must come
/// back out of `decode_encrypted` with the same type tag and still decrypt to
/// the original bytes.
#[test]
fn test_wire_message_types_through_encryption() {
    let crypto = TunnelCrypto::new(&TunnelKey::generate());

    // (label used in failure messages, type tag under test)
    let cases = vec![
        ("request", message_type::ENCRYPTED_REQUEST),
        ("response", message_type::ENCRYPTED_RESPONSE),
        ("event", message_type::ENCRYPTED_EVENT),
    ];

    for (label, expected_type) in cases {
        let payload = b"test payload data";

        let ciphertext = crypto.encrypt(payload).unwrap();
        let framed = WireMessage::encode_encrypted(expected_type, ciphertext);

        let (actual_type, framed_payload) = WireMessage::decode_encrypted(&framed).unwrap();
        let plaintext = crypto.decrypt(framed_payload).unwrap();

        assert_eq!(
            actual_type, expected_type,
            "Message type mismatch for: {}",
            label
        );
        assert_eq!(plaintext, payload, "Payload mismatch for: {}", label);
    }
}
/// A `Hello` message passed through `WireMessage::encode_control` and
/// `WireMessage::decode_control` (no encryption involved) must come back with
/// identical fields, including the absent auth token.
#[test]
fn test_control_message_encoding() {
    let original = ControlMessage::Hello {
        version: 1,
        requested_id: Some("test-id".to_string()),
        auth_token: None,
    };

    let wire_bytes = WireMessage::encode_control(&original);
    let round_tripped = WireMessage::decode_control(&wire_bytes).unwrap();

    if let ControlMessage::Hello {
        version,
        requested_id,
        auth_token,
    } = round_tripped
    {
        assert_eq!(version, 1);
        assert_eq!(requested_id, Some("test-id".to_string()));
        assert_eq!(auth_token, None);
    } else {
        panic!("Expected Hello message");
    }
}
/// The key a client extracts from the public tunnel URL must be byte-for-byte
/// the key the agent generated, and data encrypted with the agent's key must
/// decrypt with the client's copy.
#[test]
fn test_full_url_key_exchange_flow() {
    let tunnel_id = "secure-tunnel-xyz";
    let agent_key = TunnelKey::generate();

    // Agent builds the shareable URL; the client only ever sees this string.
    let public_url = TunnelUrl::build(
        &format!("https://{}.relay.example.com", tunnel_id),
        &agent_key.to_base64(),
    );

    let url_parts = TunnelUrl::parse(&public_url).unwrap();
    let client_key = TunnelKey::from_base64(&url_parts.encryption_key).unwrap();
    assert_eq!(agent_key.as_bytes(), client_key.as_bytes());

    let agent_crypto = TunnelCrypto::new(&agent_key);
    let client_crypto = TunnelCrypto::new(&client_key);

    let test_message = b"Secure communication established!";
    let ciphertext = agent_crypto.encrypt(test_message).unwrap();
    let plaintext = client_crypto.decrypt(&ciphertext).unwrap();
    assert_eq!(test_message.as_slice(), plaintext.as_slice());
}
/// Building and parsing a tunnel URL must recover the tunnel id and the key
/// for several relay host shapes, including a host with an explicit port.
#[test]
fn test_url_with_different_relay_hosts() {
    let key = TunnelKey::generate();

    // (relay host, tunnel id) — the tunnel id becomes the leading host label.
    let cases = vec![
        ("relay.example.com", "test-id"),
        ("tunnel.mycompany.io", "tunnel123"),
        ("localhost:8080", "local"),
    ];

    for (relay_host, tunnel_id) in cases {
        let base_url = format!("https://{}.{}", tunnel_id, relay_host);
        let built = TunnelUrl::build(&base_url, &key.to_base64());

        let url_parts = TunnelUrl::parse(&built).unwrap();
        assert_eq!(url_parts.tunnel_id, tunnel_id);

        let recovered_key = TunnelKey::from_base64(&url_parts.encryption_key).unwrap();
        assert_eq!(recovered_key.as_bytes(), key.as_bytes());
    }
}
/// Several requests encrypted up front with the same key must each decrypt
/// correctly when processed in a different order than they were encrypted.
#[test]
fn test_multiple_requests_same_tunnel() {
    let shared_key = TunnelKey::generate();
    let client_crypto = TunnelCrypto::new(&shared_key);
    let agent_crypto = TunnelCrypto::new(&shared_key);

    let request_ids = vec!["req-1", "req-2", "req-3", "req-4", "req-5"];

    // Encrypt every request before decrypting any of them.
    let encrypted_requests: Vec<_> = request_ids
        .iter()
        .map(|id| {
            let request = DataMessage::HttpRequest {
                request_id: id.to_string(),
                client_id: "client-1".to_string(),
                method: "GET".to_string(),
                path: format!("/api/items/{}", id),
                query: None,
                headers: HashMap::new(),
                body: None,
            };
            let request_json = serde_json::to_vec(&request).unwrap();
            client_crypto.encrypt(&request_json).unwrap()
        })
        .collect();

    // Deliberately not the encryption order.
    let processing_order = vec![2, 0, 4, 1, 3];
    for idx in processing_order {
        let plaintext = agent_crypto.decrypt(&encrypted_requests[idx]).unwrap();
        let received: DataMessage = serde_json::from_slice(&plaintext).unwrap();

        if let DataMessage::HttpRequest { request_id, .. } = received {
            assert_eq!(request_id, request_ids[idx]);
        } else {
            panic!("Expected HttpRequest");
        }
    }
}
/// Keep-alive exchange: an encrypted `Ping` from the agent is decrypted by the
/// relay, which echoes the timestamp back in an encrypted `Pong`; the agent
/// must read back the timestamp it originally sent.
#[test]
fn test_ping_pong_through_encryption() {
    let shared_key = TunnelKey::generate();
    let agent_crypto = TunnelCrypto::new(&shared_key);
    let relay_crypto = TunnelCrypto::new(&shared_key);

    // Agent -> relay: Ping.
    let ping = ControlMessage::Ping {
        timestamp: 1234567890,
    };
    let ping_json = serde_json::to_vec(&ping).unwrap();
    let ping_cipher = agent_crypto.encrypt(&ping_json).unwrap();
    let ping_plain = relay_crypto.decrypt(&ping_cipher).unwrap();
    let received_ping: ControlMessage = serde_json::from_slice(&ping_plain).unwrap();

    let timestamp = if let ControlMessage::Ping { timestamp } = received_ping {
        timestamp
    } else {
        panic!("Expected Ping")
    };

    // Relay -> agent: Pong carrying the echoed timestamp.
    let pong = ControlMessage::Pong { timestamp };
    let pong_json = serde_json::to_vec(&pong).unwrap();
    let pong_cipher = relay_crypto.encrypt(&pong_json).unwrap();
    let pong_plain = agent_crypto.decrypt(&pong_cipher).unwrap();
    let received_pong: ControlMessage = serde_json::from_slice(&pong_plain).unwrap();

    if let ControlMessage::Pong { timestamp: ts } = received_pong {
        assert_eq!(ts, 1234567890);
    } else {
        panic!("Expected Pong");
    }
}
/// Pins `PROTOCOL_VERSION` to 1 so that changing it requires touching this
/// test as well.
#[test]
fn test_protocol_version_constant() {
    let expected_version: u8 = 1;
    assert_eq!(PROTOCOL_VERSION, expected_version);
}
/// The JSON form of a `Hello` message must contain a `"version"` entry whose
/// value is `PROTOCOL_VERSION`.
#[test]
fn test_hello_with_protocol_version() {
    let json = serde_json::to_string(&ControlMessage::Hello {
        version: PROTOCOL_VERSION,
        requested_id: None,
        auth_token: None,
    })
    .unwrap();

    let expected_fragment = format!("\"version\":{}", PROTOCOL_VERSION);
    assert!(json.contains(&expected_fragment));
}
/// Data encrypted under one generated key must be rejected when decrypted
/// with a `TunnelCrypto` built from a second, independently generated key.
#[test]
fn test_different_tunnels_different_keys() {
    let first_key = TunnelKey::generate();
    let second_key = TunnelKey::generate();
    let first_tunnel = TunnelCrypto::new(&first_key);
    let second_tunnel = TunnelCrypto::new(&second_key);

    let message = b"Secret message for tunnel 1";
    let ciphertext = first_tunnel.encrypt(message).unwrap();

    let cross_decrypt = second_tunnel.decrypt(&ciphertext);
    assert!(
        cross_decrypt.is_err(),
        "Different keys should not decrypt each other's data"
    );
}
/// Decryption must fail when a byte of the encrypted output has been altered.
///
/// Every bit of the byte at `TAMPER_INDEX` is flipped before decrypting. The
/// encrypted output must be long enough to contain that byte; this is checked
/// up front so that a too-short output is reported as such instead of
/// silently skipping the tampering step.
#[test]
fn test_tampered_message_fails() {
    // Offset of the byte that gets corrupted.
    const TAMPER_INDEX: usize = 30;

    let key = TunnelKey::generate();
    let crypto = TunnelCrypto::new(&key);
    let message = b"Original message";
    let mut encrypted = crypto.encrypt(message).unwrap();

    assert!(
        encrypted.len() > TAMPER_INDEX,
        "Encrypted output is only {} bytes; cannot tamper with byte {}",
        encrypted.len(),
        TAMPER_INDEX
    );
    encrypted[TAMPER_INDEX] ^= 0xFF;

    let result = crypto.decrypt(&encrypted);
    assert!(result.is_err(), "Tampered message should fail decryption");
}
/// The `Debug` rendering of a `TunnelKey` must contain the marker `REDACTED`
/// and must not contain the key's base64 encoding.
#[test]
fn test_key_not_exposed_in_debug() {
    let key = TunnelKey::generate();
    let rendered = format!("{:?}", key);

    assert!(
        rendered.contains("REDACTED"),
        "Key should be redacted in debug output"
    );

    let encoded_key = key.to_base64();
    assert!(
        !rendered.contains(&encoded_key),
        "Key bytes should not be in debug"
    );
}
}