Skip to main content

codive_relay/
tunnel.rs

1//! Tunnel connection management
2
3use codive_tunnel::{DataMessage, WireMessage};
4use anyhow::Result;
5use chrono::{DateTime, Utc};
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot, RwLock};
8
/// A frame queued for the WebSocket sender, tagged as text or binary.
///
/// Text frames carry control traffic (Welcome, Ping, Pong); binary frames
/// carry encrypted data payloads.
#[derive(Clone, Debug)]
pub enum WsMessage {
    /// Text frame, used for control messages.
    Text(String),
    /// Binary frame, used for encrypted data.
    Binary(Vec<u8>),
}
17
18/// Sender type for WebSocket messages
19pub type WsSender = mpsc::Sender<WsMessage>;
20
21/// Response sender type - either oneshot for single responses or mpsc for streaming
22pub enum ResponseSender {
23    /// Single response (normal HTTP requests)
24    Single(oneshot::Sender<DataMessage>),
25    /// Streaming responses (SSE)
26    Streaming(mpsc::Sender<DataMessage>),
27}
28
29/// A pending HTTP request waiting for a response
30pub struct PendingRequest {
31    /// Channel to send the response back
32    pub response_tx: ResponseSender,
33    /// Request timestamp for timeout tracking
34    pub started_at: DateTime<Utc>,
35    /// Is this a streaming request?
36    pub is_streaming: bool,
37}
38
39/// Represents an active tunnel connection from a local agent
40pub struct TunnelConnection {
41    /// Unique tunnel identifier
42    pub tunnel_id: String,
43    /// WebSocket sender for sending messages to the agent
44    pub ws_sender: WsSender,
45    /// Pending HTTP requests awaiting responses
46    pub pending_requests: DashMap<String, PendingRequest>,
47    /// Creation timestamp
48    pub created_at: DateTime<Utc>,
49    /// Last activity timestamp
50    pub last_activity: RwLock<DateTime<Utc>>,
51    /// Source IP address
52    pub source_ip: String,
53}
54
55impl TunnelConnection {
56    /// Create a new tunnel connection
57    pub fn new(tunnel_id: String, ws_sender: WsSender, source_ip: String) -> Self {
58        let now = Utc::now();
59        Self {
60            tunnel_id,
61            ws_sender,
62            pending_requests: DashMap::new(),
63            created_at: now,
64            last_activity: RwLock::new(now),
65            source_ip,
66        }
67    }
68
69    /// Send a data message through the tunnel (already encrypted)
70    pub async fn send_encrypted(&self, message_type: u8, encrypted: Vec<u8>) -> Result<()> {
71        let wire_msg = WireMessage::encode_encrypted(message_type, encrypted);
72        self.ws_sender
73            .send(WsMessage::Binary(wire_msg))
74            .await
75            .map_err(|_| anyhow::anyhow!("Failed to send to tunnel"))?;
76
77        // Update last activity
78        *self.last_activity.write().await = Utc::now();
79        Ok(())
80    }
81
82    /// Register a pending request and get a receiver for the response
83    pub fn register_request(&self, request_id: String) -> oneshot::Receiver<DataMessage> {
84        let (tx, rx) = oneshot::channel();
85        tracing::debug!(request_id = %request_id, "Registering regular request");
86        self.pending_requests.insert(
87            request_id.clone(),
88            PendingRequest {
89                response_tx: ResponseSender::Single(tx),
90                started_at: Utc::now(),
91                is_streaming: false,
92            },
93        );
94        tracing::debug!(request_id = %request_id, count = self.pending_requests.len(), "Request registered");
95        rx
96    }
97
98    /// Register a streaming request (for SSE) and get a receiver for chunks
99    pub fn register_streaming_request(
100        &self,
101        request_id: String,
102    ) -> mpsc::Receiver<DataMessage> {
103        let (tx, rx) = mpsc::channel(100); // Buffer up to 100 chunks
104        self.pending_requests.insert(
105            request_id,
106            PendingRequest {
107                response_tx: ResponseSender::Streaming(tx),
108                started_at: Utc::now(),
109                is_streaming: true,
110            },
111        );
112        rx
113    }
114
115    /// Complete a pending request with a response
116    pub fn complete_request(&self, request_id: &str, response: DataMessage) -> bool {
117        tracing::debug!(
118            request_id = %request_id,
119            pending_count = self.pending_requests.len(),
120            "Attempting to complete request"
121        );
122        if let Some((_, pending)) = self.pending_requests.remove(request_id) {
123            match pending.response_tx {
124                ResponseSender::Single(tx) => {
125                    tracing::debug!(request_id = %request_id, "Sending response via oneshot");
126                    let _ = tx.send(response);
127                }
128                ResponseSender::Streaming(tx) => {
129                    tracing::debug!(request_id = %request_id, "Sending response via streaming channel");
130                    let _ = tx.try_send(response);
131                }
132            }
133            true
134        } else {
135            tracing::warn!(request_id = %request_id, "Request not found in pending_requests");
136            false
137        }
138    }
139
140    /// Send a chunk to a streaming request (returns false if request not found or not streaming)
141    pub async fn send_chunk(&self, request_id: &str, chunk: DataMessage) -> bool {
142        if let Some(pending) = self.pending_requests.get(request_id) {
143            if let ResponseSender::Streaming(ref tx) = pending.response_tx {
144                tracing::debug!(request_id = %request_id, "Sending chunk to streaming request");
145                return tx.send(chunk).await.is_ok();
146            }
147            tracing::warn!(request_id = %request_id, "Found request but it's not streaming");
148        }
149        false
150    }
151
152    /// Complete a streaming request (removes it from pending)
153    pub fn complete_streaming_request(&self, request_id: &str) {
154        self.pending_requests.remove(request_id);
155    }
156
157    /// Cancel all pending requests (called on disconnect)
158    pub fn cancel_all_requests(&self) {
159        self.pending_requests.clear();
160    }
161}
162
/// Alphabet used for tunnel IDs: ASCII digits, then lowercase, then uppercase
/// letters (62 characters in total).
const ALPHANUMERIC: [char; 62] = {
    // Spelled as one byte string so the alphabet is easy to audit; the
    // `[u8; 62]` type makes the compiler reject a wrong length.
    const SOURCE: &[u8; 62] = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    let mut table = ['\0'; 62];
    let mut idx = 0;
    while idx < SOURCE.len() {
        table[idx] = SOURCE[idx] as char;
        idx += 1;
    }
    table
};
171
172/// Generate a unique tunnel ID
173pub fn generate_tunnel_id() -> String {
174    nanoid::nanoid!(8, &ALPHANUMERIC)
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // ============================================================================
    // Helpers
    // ============================================================================

    /// Build a `200` HTTP response for `request_id` with no headers and no body.
    fn ok_response(request_id: &str, streaming: bool) -> DataMessage {
        DataMessage::HttpResponse {
            request_id: request_id.to_string(),
            status: 200,
            headers: HashMap::new(),
            body: None,
            streaming,
        }
    }

    /// Build a non-final response chunk for `request_id` (payload is base64 "data").
    fn data_chunk(request_id: &str) -> DataMessage {
        DataMessage::HttpResponseChunk {
            request_id: request_id.to_string(),
            chunk: "ZGF0YQ==".to_string(),
            is_final: false,
        }
    }

    fn create_test_tunnel() -> (TunnelConnection, mpsc::Receiver<WsMessage>) {
        let (tx, rx) = mpsc::channel(100);
        let tunnel = TunnelConnection::new(
            "test-tunnel-123".to_string(),
            tx,
            "127.0.0.1".to_string(),
        );
        (tunnel, rx)
    }

    // ============================================================================
    // WsMessage Tests
    // ============================================================================

    #[test]
    fn test_ws_message_text() {
        let msg = WsMessage::Text("hello".to_string());
        match msg {
            WsMessage::Text(s) => assert_eq!(s, "hello"),
            _ => panic!("Expected Text message"),
        }
    }

    #[test]
    fn test_ws_message_binary() {
        let data = vec![1, 2, 3, 4, 5];
        let msg = WsMessage::Binary(data.clone());
        match msg {
            WsMessage::Binary(d) => assert_eq!(d, data),
            _ => panic!("Expected Binary message"),
        }
    }

    #[test]
    fn test_ws_message_clone() {
        let text = WsMessage::Text("test".to_string());
        let text_clone = text.clone();
        assert!(matches!(text_clone, WsMessage::Text(s) if s == "test"));

        let binary = WsMessage::Binary(vec![1, 2, 3]);
        let binary_clone = binary.clone();
        assert!(matches!(binary_clone, WsMessage::Binary(d) if d == vec![1, 2, 3]));
    }

    // ============================================================================
    // TunnelConnection Tests
    // ============================================================================

    #[test]
    fn test_tunnel_connection_creation() {
        let (tunnel, _rx) = create_test_tunnel();

        assert_eq!(tunnel.tunnel_id, "test-tunnel-123");
        assert_eq!(tunnel.source_ip, "127.0.0.1");
        assert!(tunnel.pending_requests.is_empty());
    }

    #[tokio::test]
    async fn test_register_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_request("req-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 1);
        assert!(tunnel.pending_requests.contains_key("req-1"));

        // Check that it's marked as non-streaming
        let pending = tunnel.pending_requests.get("req-1").unwrap();
        assert!(!pending.is_streaming);
    }

    #[tokio::test]
    async fn test_register_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_streaming_request("req-sse-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 1);
        assert!(tunnel.pending_requests.contains_key("req-sse-1"));

        // Check that it's marked as streaming
        let pending = tunnel.pending_requests.get("req-sse-1").unwrap();
        assert!(pending.is_streaming);
    }

    #[tokio::test]
    async fn test_complete_request_success() {
        let (tunnel, _rx) = create_test_tunnel();

        let receiver = tunnel.register_request("req-1".to_string());

        let completed = tunnel.complete_request("req-1", ok_response("req-1", false));
        assert!(completed);
        assert!(tunnel.pending_requests.is_empty());

        // The receiver should have the response
        let received = receiver.await.unwrap();
        match received {
            DataMessage::HttpResponse { status, .. } => {
                assert_eq!(status, 200);
            }
            _ => panic!("Expected HttpResponse"),
        }
    }

    #[tokio::test]
    async fn test_complete_request_not_found() {
        let (tunnel, _rx) = create_test_tunnel();

        let completed = tunnel.complete_request("nonexistent", ok_response("nonexistent", false));
        assert!(!completed);
    }

    #[tokio::test]
    async fn test_complete_request_on_streaming_request_delivers_and_closes() {
        let (tunnel, _rx) = create_test_tunnel();

        let mut receiver = tunnel.register_streaming_request("req-sse-1".to_string());

        assert!(tunnel.complete_request("req-sse-1", ok_response("req-sse-1", true)));
        assert!(tunnel.pending_requests.is_empty());

        // The final message is buffered for the consumer...
        let received = receiver.recv().await.expect("final response should be buffered");
        assert!(matches!(received, DataMessage::HttpResponse { streaming: true, .. }));

        // ...and the stream ends because the only sender was dropped with the entry.
        assert!(receiver.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_send_chunk_to_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let mut receiver = tunnel.register_streaming_request("req-sse-1".to_string());

        // Send initial response
        let sent = tunnel.send_chunk("req-sse-1", ok_response("req-sse-1", true)).await;
        assert!(sent);

        // Receive it
        let received = receiver.recv().await.unwrap();
        assert!(matches!(received, DataMessage::HttpResponse { streaming: true, .. }));

        // Send chunk
        let sent = tunnel.send_chunk("req-sse-1", data_chunk("req-sse-1")).await;
        assert!(sent);

        // Request should still be pending
        assert!(tunnel.pending_requests.contains_key("req-sse-1"));
    }

    #[tokio::test]
    async fn test_send_chunk_to_nonexistent_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let sent = tunnel.send_chunk("nonexistent", data_chunk("nonexistent")).await;
        assert!(!sent);
    }

    #[tokio::test]
    async fn test_send_chunk_to_non_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        // Register a regular (non-streaming) request
        let _receiver = tunnel.register_request("req-regular".to_string());

        // This should fail because it's not a streaming request
        let sent = tunnel.send_chunk("req-regular", data_chunk("req-regular")).await;
        assert!(!sent);
    }

    #[tokio::test]
    async fn test_send_chunk_after_receiver_dropped() {
        let (tunnel, _rx) = create_test_tunnel();

        let receiver = tunnel.register_streaming_request("req-sse-1".to_string());
        drop(receiver);

        // The channel is closed, so delivery must report failure.
        let sent = tunnel.send_chunk("req-sse-1", data_chunk("req-sse-1")).await;
        assert!(!sent);
    }

    #[tokio::test]
    async fn test_complete_streaming_request() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_streaming_request("req-sse-1".to_string());
        assert!(tunnel.pending_requests.contains_key("req-sse-1"));

        tunnel.complete_streaming_request("req-sse-1");
        assert!(!tunnel.pending_requests.contains_key("req-sse-1"));
    }

    #[tokio::test]
    async fn test_cancel_all_requests() {
        let (tunnel, _rx) = create_test_tunnel();

        let _r1 = tunnel.register_request("req-1".to_string());
        let _r2 = tunnel.register_request("req-2".to_string());
        let _r3 = tunnel.register_streaming_request("req-sse-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 3);

        tunnel.cancel_all_requests();

        assert!(tunnel.pending_requests.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_all_requests_closes_receivers() {
        let (tunnel, _rx) = create_test_tunnel();

        let single = tunnel.register_request("req-1".to_string());
        let mut stream = tunnel.register_streaming_request("req-sse-1".to_string());

        tunnel.cancel_all_requests();

        // Dropping the stored senders must wake both kinds of receiver.
        assert!(single.await.is_err());
        assert!(stream.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_multiple_concurrent_requests() {
        let (tunnel, _rx) = create_test_tunnel();

        // Register multiple requests
        let r1 = tunnel.register_request("req-1".to_string());
        let r2 = tunnel.register_request("req-2".to_string());
        let r3 = tunnel.register_streaming_request("req-sse-1".to_string());

        assert_eq!(tunnel.pending_requests.len(), 3);

        // Complete them in different order
        let response2 = DataMessage::HttpResponse {
            request_id: "req-2".to_string(),
            status: 201,
            headers: HashMap::new(),
            body: None,
            streaming: false,
        };
        tunnel.complete_request("req-2", response2);
        assert_eq!(tunnel.pending_requests.len(), 2);

        tunnel.complete_request("req-1", ok_response("req-1", false));
        assert_eq!(tunnel.pending_requests.len(), 1);

        // Verify responses
        let received1 = r1.await.unwrap();
        assert!(matches!(received1, DataMessage::HttpResponse { status: 200, .. }));

        let received2 = r2.await.unwrap();
        assert!(matches!(received2, DataMessage::HttpResponse { status: 201, .. }));

        // Complete streaming request
        tunnel.complete_streaming_request("req-sse-1");
        assert!(tunnel.pending_requests.is_empty());
        drop(r3);
    }

    #[tokio::test]
    async fn test_send_encrypted() {
        let (tunnel, mut rx) = create_test_tunnel();

        let encrypted = vec![0xAB, 0xCD, 0xEF];
        let result = tunnel.send_encrypted(0x01, encrypted.clone()).await;
        assert!(result.is_ok());

        // Check that the message was sent
        let msg = rx.recv().await.unwrap();
        match msg {
            WsMessage::Binary(data) => {
                assert_eq!(data[0], 0x01); // Message type
                assert_eq!(&data[1..], &encrypted[..]);
            }
            _ => panic!("Expected Binary message"),
        }
    }

    #[tokio::test]
    async fn test_send_encrypted_fails_when_receiver_dropped() {
        let (tunnel, rx) = create_test_tunnel();
        let before = *tunnel.last_activity.read().await;

        drop(rx);

        assert!(tunnel.send_encrypted(0x01, vec![1, 2, 3]).await.is_err());
        // A failed send must not count as activity.
        assert_eq!(*tunnel.last_activity.read().await, before);
    }

    // ============================================================================
    // Tunnel ID Generation Tests
    // ============================================================================

    #[test]
    fn test_generate_tunnel_id_length() {
        let id = generate_tunnel_id();
        assert_eq!(id.len(), 8);
    }

    #[test]
    fn test_generate_tunnel_id_alphanumeric() {
        let id = generate_tunnel_id();
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_generate_tunnel_id_uniqueness() {
        let ids: std::collections::HashSet<String> = (0..100)
            .map(|_| generate_tunnel_id())
            .collect();

        // All 100 IDs should be unique
        assert_eq!(ids.len(), 100);
    }

    // ============================================================================
    // Timestamp Tests
    // ============================================================================

    #[tokio::test]
    async fn test_tunnel_timestamps() {
        let (tunnel, _rx) = create_test_tunnel();

        let created = tunnel.created_at;
        let initial_activity = *tunnel.last_activity.read().await;

        // Timestamps should be approximately equal at creation
        assert!((created - initial_activity).num_milliseconds().abs() < 100);

        // Sleep briefly to ensure time passes
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        // Send a message to update last_activity
        let _ = tunnel.send_encrypted(0x01, vec![1, 2, 3]).await;

        let updated_activity = *tunnel.last_activity.read().await;
        assert!(updated_activity > initial_activity);
    }

    // ============================================================================
    // Edge Cases
    // ============================================================================

    #[tokio::test]
    async fn test_complete_same_request_twice() {
        let (tunnel, _rx) = create_test_tunnel();

        let receiver = tunnel.register_request("req-1".to_string());

        // First completion should succeed
        let first = tunnel.complete_request("req-1", ok_response("req-1", false));
        assert!(first);

        // Second completion should fail (request already removed)
        let second = tunnel.complete_request("req-1", ok_response("req-1", false));
        assert!(!second);

        drop(receiver);
    }

    #[tokio::test]
    async fn test_request_with_empty_id() {
        let (tunnel, _rx) = create_test_tunnel();

        let _receiver = tunnel.register_request("".to_string());
        assert!(tunnel.pending_requests.contains_key(""));

        let completed = tunnel.complete_request("", ok_response("", false));
        assert!(completed);
    }
}