1use codive_tunnel::{DataMessage, WireMessage};
4use anyhow::Result;
5use chrono::{DateTime, Utc};
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot, RwLock};
8
/// A message queued for delivery over a tunnel's WebSocket.
///
/// Values are pushed into a [`WsSender`]; the task owning the receiving end
/// is presumably the one that writes them to the socket
/// (NOTE(review): that writer is not visible in this file — confirm).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// A text message carrying a UTF-8 string.
    Text(String),
    /// A binary message; `TunnelConnection::send_encrypted` emits encoded
    /// wire frames through this variant.
    Binary(Vec<u8>),
}
17
18pub type WsSender = mpsc::Sender<WsMessage>;
20
21pub enum ResponseSender {
23 Single(oneshot::Sender<DataMessage>),
25 Streaming(mpsc::Sender<DataMessage>),
27}
28
29pub struct PendingRequest {
31 pub response_tx: ResponseSender,
33 pub started_at: DateTime<Utc>,
35 pub is_streaming: bool,
37}
38
39pub struct TunnelConnection {
41 pub tunnel_id: String,
43 pub ws_sender: WsSender,
45 pub pending_requests: DashMap<String, PendingRequest>,
47 pub created_at: DateTime<Utc>,
49 pub last_activity: RwLock<DateTime<Utc>>,
51 pub source_ip: String,
53}
54
55impl TunnelConnection {
56 pub fn new(tunnel_id: String, ws_sender: WsSender, source_ip: String) -> Self {
58 let now = Utc::now();
59 Self {
60 tunnel_id,
61 ws_sender,
62 pending_requests: DashMap::new(),
63 created_at: now,
64 last_activity: RwLock::new(now),
65 source_ip,
66 }
67 }
68
69 pub async fn send_encrypted(&self, message_type: u8, encrypted: Vec<u8>) -> Result<()> {
71 let wire_msg = WireMessage::encode_encrypted(message_type, encrypted);
72 self.ws_sender
73 .send(WsMessage::Binary(wire_msg))
74 .await
75 .map_err(|_| anyhow::anyhow!("Failed to send to tunnel"))?;
76
77 *self.last_activity.write().await = Utc::now();
79 Ok(())
80 }
81
82 pub fn register_request(&self, request_id: String) -> oneshot::Receiver<DataMessage> {
84 let (tx, rx) = oneshot::channel();
85 tracing::debug!(request_id = %request_id, "Registering regular request");
86 self.pending_requests.insert(
87 request_id.clone(),
88 PendingRequest {
89 response_tx: ResponseSender::Single(tx),
90 started_at: Utc::now(),
91 is_streaming: false,
92 },
93 );
94 tracing::debug!(request_id = %request_id, count = self.pending_requests.len(), "Request registered");
95 rx
96 }
97
98 pub fn register_streaming_request(
100 &self,
101 request_id: String,
102 ) -> mpsc::Receiver<DataMessage> {
103 let (tx, rx) = mpsc::channel(100); self.pending_requests.insert(
105 request_id,
106 PendingRequest {
107 response_tx: ResponseSender::Streaming(tx),
108 started_at: Utc::now(),
109 is_streaming: true,
110 },
111 );
112 rx
113 }
114
115 pub fn complete_request(&self, request_id: &str, response: DataMessage) -> bool {
117 tracing::debug!(
118 request_id = %request_id,
119 pending_count = self.pending_requests.len(),
120 "Attempting to complete request"
121 );
122 if let Some((_, pending)) = self.pending_requests.remove(request_id) {
123 match pending.response_tx {
124 ResponseSender::Single(tx) => {
125 tracing::debug!(request_id = %request_id, "Sending response via oneshot");
126 let _ = tx.send(response);
127 }
128 ResponseSender::Streaming(tx) => {
129 tracing::debug!(request_id = %request_id, "Sending response via streaming channel");
130 let _ = tx.try_send(response);
131 }
132 }
133 true
134 } else {
135 tracing::warn!(request_id = %request_id, "Request not found in pending_requests");
136 false
137 }
138 }
139
140 pub async fn send_chunk(&self, request_id: &str, chunk: DataMessage) -> bool {
142 if let Some(pending) = self.pending_requests.get(request_id) {
143 if let ResponseSender::Streaming(ref tx) = pending.response_tx {
144 tracing::debug!(request_id = %request_id, "Sending chunk to streaming request");
145 return tx.send(chunk).await.is_ok();
146 }
147 tracing::warn!(request_id = %request_id, "Found request but it's not streaming");
148 }
149 false
150 }
151
152 pub fn complete_streaming_request(&self, request_id: &str) {
154 self.pending_requests.remove(request_id);
155 }
156
157 pub fn cancel_all_requests(&self) {
159 self.pending_requests.clear();
160 }
161}
162
/// Alphabet for tunnel ids: the ten ASCII digits, then the 26 lowercase
/// letters, then the 26 uppercase letters (62 symbols).
const ALPHANUMERIC: [char; 62] = [
    // Digits.
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    // Lowercase letters.
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
    's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    // Uppercase letters.
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
    'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
];
171
172pub fn generate_tunnel_id() -> String {
174 nanoid::nanoid!(8, &ALPHANUMERIC)
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a tunnel whose outbound WebSocket queue is observable through
    /// the returned receiver.
    fn create_test_tunnel() -> (TunnelConnection, mpsc::Receiver<WsMessage>) {
        let (tx, rx) = mpsc::channel(100);
        let tunnel = TunnelConnection::new(
            "test-tunnel-123".to_string(),
            tx,
            "127.0.0.1".to_string(),
        );
        (tunnel, rx)
    }

    /// A status-200 HTTP response with no headers and no body.
    fn ok_response(request_id: &str, streaming: bool) -> DataMessage {
        DataMessage::HttpResponse {
            request_id: request_id.to_string(),
            status: 200,
            headers: HashMap::new(),
            body: None,
            streaming,
        }
    }

    /// A non-final body chunk carrying the base64 text "ZGF0YQ==" ("data").
    fn data_chunk(request_id: &str) -> DataMessage {
        DataMessage::HttpResponseChunk {
            request_id: request_id.to_string(),
            chunk: "ZGF0YQ==".to_string(),
            is_final: false,
        }
    }

    #[test]
    fn test_ws_message_text() {
        let msg = WsMessage::Text("hello".to_string());
        match msg {
            WsMessage::Text(s) => assert_eq!(s, "hello"),
            _ => panic!("Expected Text message"),
        }
    }

    #[test]
    fn test_ws_message_binary() {
        let data = vec![1, 2, 3, 4, 5];
        let msg = WsMessage::Binary(data.clone());
        match msg {
            WsMessage::Binary(d) => assert_eq!(d, data),
            _ => panic!("Expected Binary message"),
        }
    }

    #[test]
    fn test_ws_message_clone() {
        let text = WsMessage::Text("test".to_string());
        let text_clone = text.clone();
        assert!(matches!(text_clone, WsMessage::Text(s) if s == "test"));

        let binary = WsMessage::Binary(vec![1, 2, 3]);
        let binary_clone = binary.clone();
        assert!(matches!(binary_clone, WsMessage::Binary(d) if d == vec![1, 2, 3]));
    }

    #[test]
    fn test_tunnel_connection_creation() {
        let (tunnel, _rx) = create_test_tunnel();

        assert_eq!(tunnel.tunnel_id, "test-tunnel-123");
        assert_eq!(tunnel.source_ip, "127.0.0.1");
        assert!(tunnel.pending_requests.is_empty());
    }

    #[tokio::test]
    async fn test_register_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_request("req-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 1);
        assert!(tunnel.pending_requests.contains_key("req-1"));

        let pending = tunnel.pending_requests.get("req-1").unwrap();
        assert!(!pending.is_streaming);
    }

    #[tokio::test]
    async fn test_register_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_streaming_request("req-sse-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 1);
        assert!(tunnel.pending_requests.contains_key("req-sse-1"));

        let pending = tunnel.pending_requests.get("req-sse-1").unwrap();
        assert!(pending.is_streaming);
    }

    #[tokio::test]
    async fn test_register_duplicate_id_replaces_previous() {
        let (tunnel, _rx) = create_test_tunnel();

        let first = tunnel.register_request("req-dup".to_string());
        let second = tunnel.register_request("req-dup".to_string());
        assert_eq!(tunnel.pending_requests.len(), 1);

        // The replaced entry's sender was dropped, so its waiter fails fast
        // instead of hanging forever.
        assert!(first.await.is_err());

        assert!(tunnel.complete_request("req-dup", ok_response("req-dup", false)));
        let received = second.await.unwrap();
        assert!(matches!(received, DataMessage::HttpResponse { status: 200, .. }));
    }

    #[tokio::test]
    async fn test_complete_request_success() {
        let (tunnel, _rx) = create_test_tunnel();

        let receiver = tunnel.register_request("req-1".to_string());

        let completed = tunnel.complete_request("req-1", ok_response("req-1", false));
        assert!(completed);
        assert!(tunnel.pending_requests.is_empty());

        let received = receiver.await.unwrap();
        match received {
            DataMessage::HttpResponse { status, .. } => {
                assert_eq!(status, 200);
            }
            _ => panic!("Expected HttpResponse"),
        }
    }

    #[tokio::test]
    async fn test_complete_request_not_found() {
        let (tunnel, _rx) = create_test_tunnel();

        let completed = tunnel.complete_request("nonexistent", ok_response("nonexistent", false));
        assert!(!completed);
    }

    #[tokio::test]
    async fn test_complete_request_on_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let mut receiver = tunnel.register_streaming_request("req-sse-1".to_string());

        assert!(tunnel.complete_request("req-sse-1", ok_response("req-sse-1", false)));
        assert!(tunnel.pending_requests.is_empty());

        let received = receiver.recv().await.expect("final response should be delivered");
        assert!(matches!(received, DataMessage::HttpResponse { status: 200, .. }));

        // The sender was dropped together with the map entry, so the stream ends.
        assert!(receiver.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_send_chunk_to_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let mut receiver = tunnel.register_streaming_request("req-sse-1".to_string());

        // The response head is sent as the first streamed part.
        let sent = tunnel.send_chunk("req-sse-1", ok_response("req-sse-1", true)).await;
        assert!(sent);

        let received = receiver.recv().await.unwrap();
        assert!(matches!(received, DataMessage::HttpResponse { streaming: true, .. }));

        let sent = tunnel.send_chunk("req-sse-1", data_chunk("req-sse-1")).await;
        assert!(sent);

        let received = receiver.recv().await.unwrap();
        assert!(matches!(received, DataMessage::HttpResponseChunk { is_final: false, .. }));

        // Sending chunks must not complete the request.
        assert!(tunnel.pending_requests.contains_key("req-sse-1"));
    }

    #[tokio::test]
    async fn test_send_chunk_to_nonexistent_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let sent = tunnel.send_chunk("nonexistent", data_chunk("nonexistent")).await;
        assert!(!sent);
    }

    #[tokio::test]
    async fn test_send_chunk_to_non_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_request("req-regular".to_string());

        let sent = tunnel.send_chunk("req-regular", data_chunk("req-regular")).await;
        assert!(!sent);
    }

    #[tokio::test]
    async fn test_send_chunk_after_receiver_dropped() {
        let (tunnel, _rx) = create_test_tunnel();

        let receiver = tunnel.register_streaming_request("req-sse-1".to_string());
        drop(receiver);

        let sent = tunnel.send_chunk("req-sse-1", data_chunk("req-sse-1")).await;
        assert!(!sent);
    }

    #[tokio::test]
    async fn test_complete_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_streaming_request("req-sse-1".to_string());
        assert!(tunnel.pending_requests.contains_key("req-sse-1"));

        tunnel.complete_streaming_request("req-sse-1");
        assert!(!tunnel.pending_requests.contains_key("req-sse-1"));
    }

    #[tokio::test]
    async fn test_cancel_all_requests() {
        let (tunnel, _rx) = create_test_tunnel();

        let r1 = tunnel.register_request("req-1".to_string());
        let r2 = tunnel.register_request("req-2".to_string());
        let mut r3 = tunnel.register_streaming_request("req-sse-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 3);

        tunnel.cancel_all_requests();

        assert!(tunnel.pending_requests.is_empty());

        // Dropping the senders wakes every waiter instead of leaving it hanging.
        assert!(r1.await.is_err());
        assert!(r2.await.is_err());
        assert!(r3.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_multiple_concurrent_requests() {
        let (tunnel, _rx) = create_test_tunnel();

        let r1 = tunnel.register_request("req-1".to_string());
        let r2 = tunnel.register_request("req-2".to_string());
        let r3 = tunnel.register_streaming_request("req-sse-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 3);

        // Complete out of registration order.
        let response2 = DataMessage::HttpResponse {
            request_id: "req-2".to_string(),
            status: 201,
            headers: HashMap::new(),
            body: None,
            streaming: false,
        };
        tunnel.complete_request("req-2", response2);
        assert_eq!(tunnel.pending_requests.len(), 2);

        tunnel.complete_request("req-1", ok_response("req-1", false));
        assert_eq!(tunnel.pending_requests.len(), 1);

        let received1 = r1.await.unwrap();
        assert!(matches!(received1, DataMessage::HttpResponse { status: 200, .. }));

        let received2 = r2.await.unwrap();
        assert!(matches!(received2, DataMessage::HttpResponse { status: 201, .. }));

        tunnel.complete_streaming_request("req-sse-1");
        assert!(tunnel.pending_requests.is_empty());
        drop(r3);
    }

    #[tokio::test]
    async fn test_send_encrypted() {
        let (tunnel, mut rx) = create_test_tunnel();

        let encrypted = vec![0xAB, 0xCD, 0xEF];
        let result = tunnel.send_encrypted(0x01, encrypted.clone()).await;
        assert!(result.is_ok());

        let msg = rx.recv().await.unwrap();
        match msg {
            WsMessage::Binary(data) => {
                // Frame layout: message-type byte followed by the payload.
                assert_eq!(data[0], 0x01);
                assert_eq!(&data[1..], &encrypted[..]);
            }
            _ => panic!("Expected Binary message"),
        }
    }

    #[tokio::test]
    async fn test_send_encrypted_fails_when_ws_closed() {
        let (tunnel, rx) = create_test_tunnel();
        drop(rx);

        let before = *tunnel.last_activity.read().await;
        let result = tunnel.send_encrypted(0x01, vec![1, 2, 3]).await;
        assert!(result.is_err());

        // A failed send must not count as activity.
        assert_eq!(*tunnel.last_activity.read().await, before);
    }

    #[test]
    fn test_generate_tunnel_id_length() {
        let id = generate_tunnel_id();
        assert_eq!(id.len(), 8);
    }

    #[test]
    fn test_generate_tunnel_id_alphanumeric() {
        let id = generate_tunnel_id();
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_generate_tunnel_id_uniqueness() {
        let ids: std::collections::HashSet<String> = (0..100)
            .map(|_| generate_tunnel_id())
            .collect();

        assert_eq!(ids.len(), 100);
    }

    #[tokio::test]
    async fn test_tunnel_timestamps() {
        let (tunnel, _rx) = create_test_tunnel();

        let created = tunnel.created_at;
        let initial_activity = *tunnel.last_activity.read().await;

        assert!((created - initial_activity).num_milliseconds().abs() < 100);

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let _ = tunnel.send_encrypted(0x01, vec![1, 2, 3]).await;

        let updated_activity = *tunnel.last_activity.read().await;
        assert!(updated_activity > initial_activity);
    }

    #[tokio::test]
    async fn test_complete_same_request_twice() {
        let (tunnel, _rx) = create_test_tunnel();

        let receiver = tunnel.register_request("req-1".to_string());

        let response = ok_response("req-1", false);

        let first = tunnel.complete_request("req-1", response.clone());
        assert!(first);

        // The entry was removed by the first completion.
        let second = tunnel.complete_request("req-1", response);
        assert!(!second);

        drop(receiver);
    }

    #[tokio::test]
    async fn test_request_with_empty_id() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_request("".to_string());
        assert!(tunnel.pending_requests.contains_key(""));

        let completed = tunnel.complete_request("", ok_response("", false));
        assert!(completed);
    }
}