1pub mod crypto;
40pub mod protocol;
41
42pub use crypto::{TunnelCrypto, TunnelKey, KEY_SIZE, NONCE_SIZE};
44pub use protocol::{
45 message_type, url::TunnelUrl, ControlMessage, DataMessage, WireMessage,
46};
47
/// Wire-protocol version carried in the `version` field of
/// [`ControlMessage::Hello`] handshakes.
pub const PROTOCOL_VERSION: u8 = 1;

/// Default relay endpoint, given as a TLS WebSocket (`wss://`) URL.
// NOTE(review): the choice between this and `RELAY_URL_ENV` is made outside this file.
pub const DEFAULT_RELAY_URL: &str = "wss://relay.agent.example.com";

/// Name of the environment variable (`AGENT_RELAY_URL`) associated with the relay URL.
// NOTE(review): presumably overrides `DEFAULT_RELAY_URL` when set; the lookup is not
// in this file, so confirm at the call site.
pub const RELAY_URL_ENV: &str = "AGENT_RELAY_URL";
56
#[cfg(test)]
mod integration_tests {
    //! Cross-module tests. Each test serializes a protocol message with
    //! `serde_json`, encrypts it with one [`TunnelCrypto`] instance and
    //! decrypts/decodes it with another instance built from the same key.

    use super::*;
    use base64::Engine;
    use std::collections::HashMap;

    /// Builds a tunnel URL that embeds the agent key and parses it back as a
    /// client would. Then checks that a `Hello` encrypted with the agent's key
    /// decrypts with the key recovered from the URL.
    #[test]
    fn test_encrypted_handshake_flow() {
        let key = TunnelKey::generate();
        let tunnel_id = "test-tunnel-abc123";

        let base_url = format!("https://{}.relay.example.com", tunnel_id);
        let url = TunnelUrl::build(&base_url, &key.to_base64());
        assert!(url.starts_with("https://test-tunnel-abc123.relay.example.com"));

        let parsed = TunnelUrl::parse(&url).unwrap();
        assert_eq!(parsed.tunnel_id, tunnel_id);

        // The client only knows what the URL gave it.
        let client_key = TunnelKey::from_base64(&parsed.encryption_key).unwrap();

        let agent_crypto = TunnelCrypto::new(&key);
        let client_crypto = TunnelCrypto::new(&client_key);

        let hello = ControlMessage::Hello {
            version: PROTOCOL_VERSION,
            requested_id: Some(tunnel_id.to_string()),
            auth_token: Some("test-token".to_string()),
        };
        let hello_json = serde_json::to_vec(&hello).unwrap();
        let encrypted = agent_crypto.encrypt(&hello_json).unwrap();

        let decrypted = client_crypto.decrypt(&encrypted).unwrap();
        let decoded: ControlMessage = serde_json::from_slice(&decrypted).unwrap();

        match decoded {
            ControlMessage::Hello {
                version,
                requested_id,
                auth_token,
            } => {
                assert_eq!(version, PROTOCOL_VERSION);
                assert_eq!(requested_id, Some(tunnel_id.to_string()));
                assert_eq!(auth_token, Some("test-token".to_string()));
            }
            _ => panic!("Expected Hello message"),
        }
    }

    /// A `Welcome` encrypted on the relay side decrypts intact on the agent side.
    #[test]
    fn test_encrypted_welcome_response() {
        let key = TunnelKey::generate();
        let agent_crypto = TunnelCrypto::new(&key);
        let relay_crypto = TunnelCrypto::new(&key);

        let welcome = ControlMessage::Welcome {
            tunnel_id: "final-tunnel-id".to_string(),
            tunnel_url: "https://final-tunnel-id.relay.example.com".to_string(),
        };

        let encrypted = relay_crypto
            .encrypt(&serde_json::to_vec(&welcome).unwrap())
            .unwrap();
        let decrypted = agent_crypto.decrypt(&encrypted).unwrap();
        let decoded: ControlMessage = serde_json::from_slice(&decrypted).unwrap();

        match decoded {
            ControlMessage::Welcome {
                tunnel_id,
                tunnel_url,
            } => {
                assert_eq!(tunnel_id, "final-tunnel-id");
                assert!(tunnel_url.contains("final-tunnel-id"));
            }
            _ => panic!("Expected Welcome message"),
        }
    }

    /// Full request/response round trip. The request is encrypted, framed with
    /// `WireMessage::encode_encrypted`, unframed and decrypted. The response is
    /// encrypted and decrypted directly. Field values are compared exactly.
    #[test]
    fn test_encrypted_http_request_response_flow() {
        // Unpadded base64 of `{"message":"hello"}`; treated as opaque here.
        const REQUEST_BODY: &str = "eyJtZXNzYWdlIjoiaGVsbG8ifQ";
        // Unpadded base64 of `{"id":"new-id"}`; treated as opaque here.
        const RESPONSE_BODY: &str = "eyJpZCI6Im5ldy1pZCJ9";

        let key = TunnelKey::generate();
        let client_crypto = TunnelCrypto::new(&key);
        let agent_crypto = TunnelCrypto::new(&key);

        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert(
            "Authorization".to_string(),
            "Bearer secret-token".to_string(),
        );

        let request = DataMessage::HttpRequest {
            request_id: "req-001".to_string(),
            client_id: "client-123".to_string(),
            method: "POST".to_string(),
            path: "/api/sessions".to_string(),
            query: None,
            headers,
            body: Some(REQUEST_BODY.to_string()),
        };

        let request_json = serde_json::to_vec(&request).unwrap();
        let encrypted_request = client_crypto.encrypt(&request_json).unwrap();
        let wire_data =
            WireMessage::encode_encrypted(message_type::ENCRYPTED_REQUEST, encrypted_request);

        // The frame starts with its message-type byte.
        assert_eq!(wire_data[0], message_type::ENCRYPTED_REQUEST);

        let (msg_type, payload) = WireMessage::decode_encrypted(&wire_data).unwrap();
        assert_eq!(msg_type, message_type::ENCRYPTED_REQUEST);
        let decrypted_request = agent_crypto.decrypt(payload).unwrap();
        let decoded_request: DataMessage = serde_json::from_slice(&decrypted_request).unwrap();

        match decoded_request {
            DataMessage::HttpRequest {
                request_id,
                method,
                path,
                headers,
                body,
                ..
            } => {
                assert_eq!(request_id, "req-001");
                assert_eq!(method, "POST");
                assert_eq!(path, "/api/sessions");
                assert_eq!(headers.len(), 2);
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer secret-token")
                );
                assert_eq!(body.as_deref(), Some(REQUEST_BODY));
            }
            _ => panic!("Expected HttpRequest"),
        }

        let mut response_headers = HashMap::new();
        response_headers.insert("Content-Type".to_string(), "application/json".to_string());
        response_headers.insert("Location".to_string(), "/api/sessions/new-id".to_string());

        let response = DataMessage::HttpResponse {
            request_id: "req-001".to_string(),
            status: 201,
            headers: response_headers,
            body: Some(RESPONSE_BODY.to_string()),
            streaming: false,
        };

        let response_json = serde_json::to_vec(&response).unwrap();
        let encrypted_response = agent_crypto.encrypt(&response_json).unwrap();

        let decrypted_response = client_crypto.decrypt(&encrypted_response).unwrap();
        let decoded_response: DataMessage = serde_json::from_slice(&decrypted_response).unwrap();

        match decoded_response {
            DataMessage::HttpResponse {
                request_id,
                status,
                headers,
                body,
                streaming,
            } => {
                assert_eq!(request_id, "req-001");
                assert_eq!(status, 201);
                assert_eq!(
                    headers.get("Location").map(String::as_str),
                    Some("/api/sessions/new-id")
                );
                assert_eq!(body.as_deref(), Some(RESPONSE_BODY));
                assert!(!streaming);
            }
            _ => panic!("Expected HttpResponse"),
        }
    }

    /// A streaming (`streaming: true`) response header followed by
    /// base64-encoded SSE chunks. Each chunk's bytes and final flag must
    /// survive encryption exactly.
    #[test]
    fn test_encrypted_streaming_response() {
        let key = TunnelKey::generate();
        let client_crypto = TunnelCrypto::new(&key);
        let agent_crypto = TunnelCrypto::new(&key);

        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "text/event-stream".to_string());
        headers.insert("Cache-Control".to_string(), "no-cache".to_string());

        let response = DataMessage::HttpResponse {
            request_id: "sse-001".to_string(),
            status: 200,
            headers,
            body: None,
            streaming: true,
        };

        let encrypted = agent_crypto
            .encrypt(&serde_json::to_vec(&response).unwrap())
            .unwrap();
        let decrypted = client_crypto.decrypt(&encrypted).unwrap();
        let decoded: DataMessage = serde_json::from_slice(&decrypted).unwrap();

        match decoded {
            DataMessage::HttpResponse { streaming, .. } => {
                assert!(streaming, "Should be a streaming response");
            }
            _ => panic!("Expected HttpResponse"),
        }

        let chunks = [
            ("data: event 1\n\n", false),
            ("data: event 2\n\n", false),
            // Only the last chunk is flagged final.
            ("data: event 3\n\n", true),
        ];

        for (data, is_final) in chunks {
            let chunk = DataMessage::HttpResponseChunk {
                request_id: "sse-001".to_string(),
                chunk: base64::engine::general_purpose::STANDARD.encode(data),
                is_final,
            };

            let encrypted = agent_crypto
                .encrypt(&serde_json::to_vec(&chunk).unwrap())
                .unwrap();
            let decrypted = client_crypto.decrypt(&encrypted).unwrap();
            let decoded: DataMessage = serde_json::from_slice(&decrypted).unwrap();

            match decoded {
                DataMessage::HttpResponseChunk {
                    request_id,
                    chunk: chunk_data,
                    is_final: final_flag,
                } => {
                    assert_eq!(request_id, "sse-001");
                    let decoded_data = base64::engine::general_purpose::STANDARD
                        .decode(&chunk_data)
                        .unwrap();
                    // Exact bytes, not just a prefix, so a corrupted chunk is caught.
                    assert_eq!(decoded_data, data.as_bytes());
                    assert_eq!(final_flag, is_final);
                }
                _ => panic!("Expected HttpResponseChunk"),
            }
        }
    }

    /// A per-request `RequestError` survives the encrypt/decrypt round trip.
    #[test]
    fn test_encrypted_error_message() {
        let key = TunnelKey::generate();
        let agent_crypto = TunnelCrypto::new(&key);
        let client_crypto = TunnelCrypto::new(&key);

        let error = DataMessage::RequestError {
            request_id: Some("failed-001".to_string()),
            code: "CONNECTION_REFUSED".to_string(),
            message: "Connection refused: localhost:3001".to_string(),
        };

        let encrypted = agent_crypto
            .encrypt(&serde_json::to_vec(&error).unwrap())
            .unwrap();
        let decrypted = client_crypto.decrypt(&encrypted).unwrap();
        let decoded: DataMessage = serde_json::from_slice(&decrypted).unwrap();

        match decoded {
            DataMessage::RequestError {
                request_id,
                code,
                message,
            } => {
                assert_eq!(request_id, Some("failed-001".to_string()));
                assert_eq!(code, "CONNECTION_REFUSED");
                assert!(message.contains("Connection refused"));
            }
            _ => panic!("Expected RequestError"),
        }
    }

    /// A control-channel `Error` encrypted on the relay side decrypts on the agent side.
    #[test]
    fn test_encrypted_control_error() {
        let key = TunnelKey::generate();
        let relay_crypto = TunnelCrypto::new(&key);
        let agent_crypto = TunnelCrypto::new(&key);

        let error = ControlMessage::Error {
            code: "RATE_LIMITED".to_string(),
            message: "Too many requests".to_string(),
        };

        let encrypted = relay_crypto
            .encrypt(&serde_json::to_vec(&error).unwrap())
            .unwrap();
        let decrypted = agent_crypto.decrypt(&encrypted).unwrap();
        let decoded: ControlMessage = serde_json::from_slice(&decrypted).unwrap();

        match decoded {
            ControlMessage::Error { code, message } => {
                assert_eq!(code, "RATE_LIMITED");
                assert!(message.contains("Too many requests"));
            }
            _ => panic!("Expected Error message"),
        }
    }

    /// Every encrypted message type can be framed and unframed with
    /// `WireMessage` without losing its type byte or payload.
    #[test]
    fn test_wire_message_types_through_encryption() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let payload = b"test payload data";

        let message_types = [
            ("request", message_type::ENCRYPTED_REQUEST),
            ("response", message_type::ENCRYPTED_RESPONSE),
            ("event", message_type::ENCRYPTED_EVENT),
        ];

        for (name, msg_type) in message_types {
            let encrypted = crypto.encrypt(payload).unwrap();
            let wire_data = WireMessage::encode_encrypted(msg_type, encrypted);

            let (decoded_type, decoded_payload) =
                WireMessage::decode_encrypted(&wire_data).unwrap();
            let decrypted = crypto.decrypt(decoded_payload).unwrap();

            assert_eq!(decoded_type, msg_type, "Message type mismatch for: {}", name);
            assert_eq!(decrypted, payload, "Payload mismatch for: {}", name);
        }
    }

    /// `encode_control` followed by `decode_control` returns the original `Hello`.
    #[test]
    fn test_control_message_encoding() {
        let hello = ControlMessage::Hello {
            version: 1,
            requested_id: Some("test-id".to_string()),
            auth_token: None,
        };

        let encoded = WireMessage::encode_control(&hello);
        let decoded = WireMessage::decode_control(&encoded).unwrap();

        match decoded {
            ControlMessage::Hello {
                version,
                requested_id,
                auth_token,
            } => {
                assert_eq!(version, 1);
                assert_eq!(requested_id, Some("test-id".to_string()));
                assert_eq!(auth_token, None);
            }
            _ => panic!("Expected Hello message"),
        }
    }

    /// The key recovered from a public tunnel URL is byte-identical to the
    /// agent's key and interoperates with it for encryption.
    #[test]
    fn test_full_url_key_exchange_flow() {
        let agent_key = TunnelKey::generate();
        let tunnel_id = "secure-tunnel-xyz";
        let base_url = format!("https://{}.relay.example.com", tunnel_id);

        let public_url = TunnelUrl::build(&base_url, &agent_key.to_base64());

        let parsed = TunnelUrl::parse(&public_url).unwrap();
        assert_eq!(parsed.tunnel_id, tunnel_id);
        let client_key = TunnelKey::from_base64(&parsed.encryption_key).unwrap();

        assert_eq!(agent_key.as_bytes(), client_key.as_bytes());

        let agent_crypto = TunnelCrypto::new(&agent_key);
        let client_crypto = TunnelCrypto::new(&client_key);

        let test_message = b"Secure communication established!";
        let encrypted = agent_crypto.encrypt(test_message).unwrap();
        let decrypted = client_crypto.decrypt(&encrypted).unwrap();
        assert_eq!(test_message.as_slice(), decrypted.as_slice());
    }

    /// URL build/parse recovers the tunnel id and key for several relay hosts,
    /// including one with an explicit port.
    #[test]
    fn test_url_with_different_relay_hosts() {
        let key = TunnelKey::generate();
        let key_b64 = key.to_base64();
        let hosts = [
            ("relay.example.com", "test-id"),
            ("tunnel.mycompany.io", "tunnel123"),
            ("localhost:8080", "local"),
        ];

        for (host, tunnel_id) in hosts {
            let base_url = format!("https://{}.{}", tunnel_id, host);
            let url = TunnelUrl::build(&base_url, &key_b64);
            let parsed = TunnelUrl::parse(&url).unwrap();

            assert_eq!(parsed.tunnel_id, tunnel_id);
            assert_eq!(
                TunnelKey::from_base64(&parsed.encryption_key)
                    .unwrap()
                    .as_bytes(),
                key.as_bytes()
            );
        }
    }

    /// Several requests encrypted under one key decrypt correctly when
    /// processed in an order different from the one they were produced in.
    #[test]
    fn test_multiple_requests_same_tunnel() {
        let key = TunnelKey::generate();
        let client_crypto = TunnelCrypto::new(&key);
        let agent_crypto = TunnelCrypto::new(&key);

        let request_ids = ["req-1", "req-2", "req-3", "req-4", "req-5"];

        let encrypted_requests: Vec<_> = request_ids
            .iter()
            .map(|id| {
                let request = DataMessage::HttpRequest {
                    request_id: id.to_string(),
                    client_id: "client-1".to_string(),
                    method: "GET".to_string(),
                    path: format!("/api/items/{}", id),
                    query: None,
                    headers: HashMap::new(),
                    body: None,
                };
                client_crypto
                    .encrypt(&serde_json::to_vec(&request).unwrap())
                    .unwrap()
            })
            .collect();

        // Deliberately out of order.
        let processing_order = [2, 0, 4, 1, 3];
        for idx in processing_order {
            let decrypted = agent_crypto.decrypt(&encrypted_requests[idx]).unwrap();
            let decoded: DataMessage = serde_json::from_slice(&decrypted).unwrap();

            match decoded {
                DataMessage::HttpRequest { request_id, .. } => {
                    assert_eq!(request_id, request_ids[idx]);
                }
                _ => panic!("Expected HttpRequest"),
            }
        }
    }

    /// A `Ping` from the agent is answered by a `Pong` echoing the same
    /// timestamp, with both legs encrypted.
    #[test]
    fn test_ping_pong_through_encryption() {
        let key = TunnelKey::generate();
        let agent_crypto = TunnelCrypto::new(&key);
        let relay_crypto = TunnelCrypto::new(&key);

        let ping = ControlMessage::Ping {
            timestamp: 1234567890,
        };
        let encrypted = agent_crypto
            .encrypt(&serde_json::to_vec(&ping).unwrap())
            .unwrap();

        let decrypted = relay_crypto.decrypt(&encrypted).unwrap();
        let decoded: ControlMessage = serde_json::from_slice(&decrypted).unwrap();

        let timestamp = match decoded {
            ControlMessage::Ping { timestamp } => timestamp,
            _ => panic!("Expected Ping"),
        };

        let pong = ControlMessage::Pong { timestamp };
        let encrypted = relay_crypto
            .encrypt(&serde_json::to_vec(&pong).unwrap())
            .unwrap();
        let decrypted = agent_crypto.decrypt(&encrypted).unwrap();
        let decoded: ControlMessage = serde_json::from_slice(&decrypted).unwrap();

        match decoded {
            ControlMessage::Pong { timestamp: ts } => {
                assert_eq!(ts, 1234567890);
            }
            _ => panic!("Expected Pong"),
        }
    }

    /// Pins the protocol version so a bump is a deliberate, test-visible change.
    #[test]
    fn test_protocol_version_constant() {
        assert_eq!(PROTOCOL_VERSION, 1);
    }

    /// The serialized `Hello` carries `PROTOCOL_VERSION` under the `version` key.
    #[test]
    fn test_hello_with_protocol_version() {
        let hello = ControlMessage::Hello {
            version: PROTOCOL_VERSION,
            requested_id: None,
            auth_token: None,
        };

        let json = serde_json::to_string(&hello).unwrap();
        assert!(json.contains(&format!("\"version\":{}", PROTOCOL_VERSION)));
    }

    /// Data encrypted under one tunnel's key cannot be decrypted with another's.
    /// The owning key must still decrypt it, so a `decrypt` that always fails
    /// cannot pass this test.
    #[test]
    fn test_different_tunnels_different_keys() {
        let key1 = TunnelKey::generate();
        let key2 = TunnelKey::generate();
        assert_ne!(
            key1.as_bytes(),
            key2.as_bytes(),
            "Generated keys should be distinct"
        );

        let crypto1 = TunnelCrypto::new(&key1);
        let crypto2 = TunnelCrypto::new(&key2);

        let message = b"Secret message for tunnel 1";
        let encrypted = crypto1.encrypt(message).unwrap();

        // Positive control: the owning tunnel can read its own data.
        assert_eq!(crypto1.decrypt(&encrypted).unwrap(), message);

        let result = crypto2.decrypt(&encrypted);
        assert!(result.is_err(), "Different keys should not decrypt each other's data");
    }

    /// Flipping a byte at the start, middle or end of a ciphertext makes
    /// decryption fail. The untouched ciphertext must still decrypt, so the
    /// failures are attributable to the tampering.
    #[test]
    fn test_tampered_message_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let message = b"Original message";
        let encrypted = crypto.encrypt(message).unwrap();

        // Positive control.
        assert_eq!(crypto.decrypt(&encrypted).unwrap(), message);
        assert!(!encrypted.is_empty(), "Ciphertext should not be empty");

        // NOTE(review): assuming the usual AEAD layout (nonce prefix, tag suffix),
        // these positions hit the nonce, the body and the tag. Confirm against `crypto`.
        let positions = [0, encrypted.len() / 2, encrypted.len() - 1];
        for position in positions {
            let mut tampered = encrypted.clone();
            tampered[position] ^= 0xFF;
            assert!(
                crypto.decrypt(&tampered).is_err(),
                "Tampered message should fail decryption (byte {} flipped)",
                position
            );
        }
    }

    /// `Debug` output for a key is redacted and does not leak its base64 form.
    #[test]
    fn test_key_not_exposed_in_debug() {
        let key = TunnelKey::generate();
        let debug_output = format!("{:?}", key);

        assert!(debug_output.contains("REDACTED"), "Key should be redacted in debug output");
        assert!(!debug_output.contains(&key.to_base64()), "Key bytes should not be in debug");
    }
}