use codive_tunnel::{DataMessage, WireMessage};
use anyhow::Result;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use tokio::sync::{mpsc, oneshot, RwLock};
/// A message queued for delivery through a tunnel's outbound channel
/// (the payload type of `mpsc::Sender<WsMessage>`).
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly
/// (for example with `assert_eq!`) instead of destructuring each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// Textual payload.
    Text(String),
    /// Raw byte payload.
    Binary(Vec<u8>),
}
/// Sending half of a `tokio::sync::mpsc` channel that carries [`WsMessage`] values.
pub type WsSender = mpsc::Sender<WsMessage>;
/// Channel end used to hand a response back to whoever registered a pending request.
pub enum ResponseSender {
/// One-shot channel carrying a single `DataMessage`; created by `register_request`.
Single(oneshot::Sender<DataMessage>),
/// Multi-message channel for responses delivered as a sequence of messages;
/// created by `register_streaming_request`.
Streaming(mpsc::Sender<DataMessage>),
}
/// Bookkeeping for a request that has been registered on a tunnel and not yet removed.
pub struct PendingRequest {
/// Where the response message(s) for this request are delivered.
pub response_tx: ResponseSender,
/// UTC time at which the request was registered.
pub started_at: DateTime<Utc>,
/// `true` when registered via `register_streaming_request` (paired with
/// `ResponseSender::Streaming`), `false` when registered via `register_request`.
pub is_streaming: bool,
}
/// State kept for one connected tunnel: its outbound channel plus the requests
/// that are still waiting for a response.
pub struct TunnelConnection {
/// Identifier of this tunnel, as passed to `new`.
pub tunnel_id: String,
/// Outbound channel; `send_encrypted` pushes binary messages into it.
pub ws_sender: WsSender,
/// In-flight requests keyed by request id.
pub pending_requests: DashMap<String, PendingRequest>,
/// UTC time at which this connection object was constructed.
pub created_at: DateTime<Utc>,
/// UTC time of the last successful `send_encrypted`; starts equal to `created_at`.
pub last_activity: RwLock<DateTime<Utc>>,
/// Stored exactly as passed to `new`.
/// NOTE(review): presumably the remote address of the tunnel client — confirm with callers.
pub source_ip: String,
}
impl TunnelConnection {
pub fn new(tunnel_id: String, ws_sender: WsSender, source_ip: String) -> Self {
let now = Utc::now();
Self {
tunnel_id,
ws_sender,
pending_requests: DashMap::new(),
created_at: now,
last_activity: RwLock::new(now),
source_ip,
}
}
pub async fn send_encrypted(&self, message_type: u8, encrypted: Vec<u8>) -> Result<()> {
let wire_msg = WireMessage::encode_encrypted(message_type, encrypted);
self.ws_sender
.send(WsMessage::Binary(wire_msg))
.await
.map_err(|_| anyhow::anyhow!("Failed to send to tunnel"))?;
*self.last_activity.write().await = Utc::now();
Ok(())
}
pub fn register_request(&self, request_id: String) -> oneshot::Receiver<DataMessage> {
let (tx, rx) = oneshot::channel();
tracing::debug!(request_id = %request_id, "Registering regular request");
self.pending_requests.insert(
request_id.clone(),
PendingRequest {
response_tx: ResponseSender::Single(tx),
started_at: Utc::now(),
is_streaming: false,
},
);
tracing::debug!(request_id = %request_id, count = self.pending_requests.len(), "Request registered");
rx
}
pub fn register_streaming_request(
&self,
request_id: String,
) -> mpsc::Receiver<DataMessage> {
let (tx, rx) = mpsc::channel(100); self.pending_requests.insert(
request_id,
PendingRequest {
response_tx: ResponseSender::Streaming(tx),
started_at: Utc::now(),
is_streaming: true,
},
);
rx
}
pub fn complete_request(&self, request_id: &str, response: DataMessage) -> bool {
tracing::debug!(
request_id = %request_id,
pending_count = self.pending_requests.len(),
"Attempting to complete request"
);
if let Some((_, pending)) = self.pending_requests.remove(request_id) {
match pending.response_tx {
ResponseSender::Single(tx) => {
tracing::debug!(request_id = %request_id, "Sending response via oneshot");
let _ = tx.send(response);
}
ResponseSender::Streaming(tx) => {
tracing::debug!(request_id = %request_id, "Sending response via streaming channel");
let _ = tx.try_send(response);
}
}
true
} else {
tracing::warn!(request_id = %request_id, "Request not found in pending_requests");
false
}
}
pub async fn send_chunk(&self, request_id: &str, chunk: DataMessage) -> bool {
if let Some(pending) = self.pending_requests.get(request_id) {
if let ResponseSender::Streaming(ref tx) = pending.response_tx {
tracing::debug!(request_id = %request_id, "Sending chunk to streaming request");
return tx.send(chunk).await.is_ok();
}
tracing::warn!(request_id = %request_id, "Found request but it's not streaming");
}
false
}
pub fn complete_streaming_request(&self, request_id: &str) {
self.pending_requests.remove(request_id);
}
pub fn cancel_all_requests(&self) {
self.pending_requests.clear();
}
}
/// Alphabet for generated tunnel ids: the 62 ASCII characters `0-9`, `a-z`, `A-Z`.
const ALPHANUMERIC: [char; 62] = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
];
/// Produces a fresh tunnel id: 8 characters drawn from `ALPHANUMERIC` via `nanoid`.
pub fn generate_tunnel_id() -> String {
    const TUNNEL_ID_LEN: usize = 8;
    nanoid::nanoid!(TUNNEL_ID_LEN, &ALPHANUMERIC)
}
// Unit tests for `WsMessage`, `TunnelConnection` request bookkeeping, and
// `generate_tunnel_id`.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
// A `Text` message keeps the string it was built with.
#[test]
fn test_ws_message_text() {
let msg = WsMessage::Text("hello".to_string());
match msg {
WsMessage::Text(s) => assert_eq!(s, "hello"),
_ => panic!("Expected Text message"),
}
}
// A `Binary` message keeps the bytes it was built with.
#[test]
fn test_ws_message_binary() {
let data = vec![1, 2, 3, 4, 5];
let msg = WsMessage::Binary(data.clone());
match msg {
WsMessage::Binary(d) => assert_eq!(d, data),
_ => panic!("Expected Binary message"),
}
}
// Cloning preserves both the variant and its payload.
#[test]
fn test_ws_message_clone() {
let text = WsMessage::Text("test".to_string());
let text_clone = text.clone();
assert!(matches!(text_clone, WsMessage::Text(s) if s == "test"));
let binary = WsMessage::Binary(vec![1, 2, 3]);
let binary_clone = binary.clone();
assert!(matches!(binary_clone, WsMessage::Binary(d) if d == vec![1, 2, 3]));
}
// Helper: builds a tunnel with fixed id/IP and returns the receiving half of
// its outbound channel so tests can inspect what was sent.
fn create_test_tunnel() -> (TunnelConnection, mpsc::Receiver<WsMessage>) {
let (tx, rx) = mpsc::channel(100);
let tunnel = TunnelConnection::new(
"test-tunnel-123".to_string(),
tx,
"127.0.0.1".to_string(),
);
(tunnel, rx)
}
// `new` stores the id and IP and starts with no pending requests.
#[test]
fn test_tunnel_connection_creation() {
let (tunnel, _rx) = create_test_tunnel();
assert_eq!(tunnel.tunnel_id, "test-tunnel-123");
assert_eq!(tunnel.source_ip, "127.0.0.1");
assert!(tunnel.pending_requests.is_empty());
}
// `register_request` adds a non-streaming entry under the given id.
#[tokio::test]
async fn test_register_request() {
let (tunnel, _rx) = create_test_tunnel();
let _receiver = tunnel.register_request("req-1".to_string());
assert_eq!(tunnel.pending_requests.len(), 1);
assert!(tunnel.pending_requests.contains_key("req-1"));
let pending = tunnel.pending_requests.get("req-1").unwrap();
assert!(!pending.is_streaming);
}
// `register_streaming_request` adds an entry flagged as streaming.
#[tokio::test]
async fn test_register_streaming_request() {
let (tunnel, _rx) = create_test_tunnel();
let _receiver = tunnel.register_streaming_request("req-sse-1".to_string());
assert_eq!(tunnel.pending_requests.len(), 1);
assert!(tunnel.pending_requests.contains_key("req-sse-1"));
let pending = tunnel.pending_requests.get("req-sse-1").unwrap();
assert!(pending.is_streaming);
}
// Completing a registered request removes it and delivers the response to
// the oneshot receiver.
#[tokio::test]
async fn test_complete_request_success() {
let (tunnel, _rx) = create_test_tunnel();
let receiver = tunnel.register_request("req-1".to_string());
let response = DataMessage::HttpResponse {
request_id: "req-1".to_string(),
status: 200,
headers: HashMap::new(),
body: None,
streaming: false,
};
let completed = tunnel.complete_request("req-1", response);
assert!(completed);
assert!(tunnel.pending_requests.is_empty());
let received = receiver.await.unwrap();
match received {
DataMessage::HttpResponse { status, .. } => {
assert_eq!(status, 200);
}
_ => panic!("Expected HttpResponse"),
}
}
// Completing an id that was never registered reports `false`.
#[tokio::test]
async fn test_complete_request_not_found() {
let (tunnel, _rx) = create_test_tunnel();
let response = DataMessage::HttpResponse {
request_id: "nonexistent".to_string(),
status: 200,
headers: HashMap::new(),
body: None,
streaming: false,
};
let completed = tunnel.complete_request("nonexistent", response);
assert!(!completed);
}
// `send_chunk` delivers messages to a streaming request and leaves the
// request registered afterwards.
#[tokio::test]
async fn test_send_chunk_to_streaming_request() {
let (tunnel, _rx) = create_test_tunnel();
let mut receiver = tunnel.register_streaming_request("req-sse-1".to_string());
let initial = DataMessage::HttpResponse {
request_id: "req-sse-1".to_string(),
status: 200,
headers: HashMap::new(),
body: None,
streaming: true,
};
let sent = tunnel.send_chunk("req-sse-1", initial).await;
assert!(sent);
let received = receiver.recv().await.unwrap();
assert!(matches!(received, DataMessage::HttpResponse { streaming: true, .. }));
let chunk = DataMessage::HttpResponseChunk {
request_id: "req-sse-1".to_string(),
chunk: "ZGF0YQ==".to_string(),
is_final: false,
};
let sent = tunnel.send_chunk("req-sse-1", chunk).await;
assert!(sent);
assert!(tunnel.pending_requests.contains_key("req-sse-1"));
}
// `send_chunk` reports `false` for an unknown request id.
#[tokio::test]
async fn test_send_chunk_to_nonexistent_request() {
let (tunnel, _rx) = create_test_tunnel();
let chunk = DataMessage::HttpResponseChunk {
request_id: "nonexistent".to_string(),
chunk: "ZGF0YQ==".to_string(),
is_final: false,
};
let sent = tunnel.send_chunk("nonexistent", chunk).await;
assert!(!sent);
}
// `send_chunk` reports `false` when the request exists but is not streaming.
#[tokio::test]
async fn test_send_chunk_to_non_streaming_request() {
let (tunnel, _rx) = create_test_tunnel();
let _receiver = tunnel.register_request("req-regular".to_string());
let chunk = DataMessage::HttpResponseChunk {
request_id: "req-regular".to_string(),
chunk: "ZGF0YQ==".to_string(),
is_final: false,
};
let sent = tunnel.send_chunk("req-regular", chunk).await;
assert!(!sent);
}
// `complete_streaming_request` removes the entry for the given id.
#[tokio::test]
async fn test_complete_streaming_request() {
let (tunnel, _rx) = create_test_tunnel();
let _receiver = tunnel.register_streaming_request("req-sse-1".to_string());
assert!(tunnel.pending_requests.contains_key("req-sse-1"));
tunnel.complete_streaming_request("req-sse-1");
assert!(!tunnel.pending_requests.contains_key("req-sse-1"));
}
// `cancel_all_requests` empties the pending map regardless of request kind.
#[tokio::test]
async fn test_cancel_all_requests() {
let (tunnel, _rx) = create_test_tunnel();
let _r1 = tunnel.register_request("req-1".to_string());
let _r2 = tunnel.register_request("req-2".to_string());
let _r3 = tunnel.register_streaming_request("req-sse-1".to_string());
assert_eq!(tunnel.pending_requests.len(), 3);
tunnel.cancel_all_requests();
assert!(tunnel.pending_requests.is_empty());
}
// Several requests can be pending at once and completed out of order; each
// receiver gets the response addressed to its own id.
#[tokio::test]
async fn test_multiple_concurrent_requests() {
let (tunnel, _rx) = create_test_tunnel();
let r1 = tunnel.register_request("req-1".to_string());
let r2 = tunnel.register_request("req-2".to_string());
let r3 = tunnel.register_streaming_request("req-sse-1".to_string());
assert_eq!(tunnel.pending_requests.len(), 3);
let response2 = DataMessage::HttpResponse {
request_id: "req-2".to_string(),
status: 201,
headers: HashMap::new(),
body: None,
streaming: false,
};
tunnel.complete_request("req-2", response2);
assert_eq!(tunnel.pending_requests.len(), 2);
let response1 = DataMessage::HttpResponse {
request_id: "req-1".to_string(),
status: 200,
headers: HashMap::new(),
body: None,
streaming: false,
};
tunnel.complete_request("req-1", response1);
assert_eq!(tunnel.pending_requests.len(), 1);
let received1 = r1.await.unwrap();
assert!(matches!(received1, DataMessage::HttpResponse { status: 200, .. }));
let received2 = r2.await.unwrap();
assert!(matches!(received2, DataMessage::HttpResponse { status: 201, .. }));
tunnel.complete_streaming_request("req-sse-1");
assert!(tunnel.pending_requests.is_empty());
drop(r3);
}
// `send_encrypted` emits a binary message; this test expects the first byte
// to be the message type and the remaining bytes to be the payload.
#[tokio::test]
async fn test_send_encrypted() {
let (tunnel, mut rx) = create_test_tunnel();
let encrypted = vec![0xAB, 0xCD, 0xEF];
let result = tunnel.send_encrypted(0x01, encrypted.clone()).await;
assert!(result.is_ok());
let msg = rx.recv().await.unwrap();
match msg {
WsMessage::Binary(data) => {
assert_eq!(data[0], 0x01); assert_eq!(&data[1..], &encrypted[..]);
}
_ => panic!("Expected Binary message"),
}
}
// Generated ids are 8 bytes long.
#[test]
fn test_generate_tunnel_id_length() {
let id = generate_tunnel_id();
assert_eq!(id.len(), 8);
}
// Generated ids contain only ASCII letters and digits.
#[test]
fn test_generate_tunnel_id_alphanumeric() {
let id = generate_tunnel_id();
assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
}
// 100 generated ids are expected to be pairwise distinct.
#[test]
fn test_generate_tunnel_id_uniqueness() {
let ids: std::collections::HashSet<String> = (0..100)
.map(|_| generate_tunnel_id())
.collect();
assert_eq!(ids.len(), 100);
}
// `created_at` and `last_activity` start within 100 ms of each other, and a
// later `send_encrypted` moves `last_activity` forward.
#[tokio::test]
async fn test_tunnel_timestamps() {
let (tunnel, _rx) = create_test_tunnel();
let created = tunnel.created_at;
let initial_activity = *tunnel.last_activity.read().await;
assert!((created - initial_activity).num_milliseconds().abs() < 100);
// Sleep so the second timestamp is strictly later than the first.
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let _ = tunnel.send_encrypted(0x01, vec![1, 2, 3]).await;
let updated_activity = *tunnel.last_activity.read().await;
assert!(updated_activity > initial_activity);
}
// The first completion removes the entry, so a second completion of the same
// id reports `false`.
#[tokio::test]
async fn test_complete_same_request_twice() {
let (tunnel, _rx) = create_test_tunnel();
let receiver = tunnel.register_request("req-1".to_string());
let response = DataMessage::HttpResponse {
request_id: "req-1".to_string(),
status: 200,
headers: HashMap::new(),
body: None,
streaming: false,
};
let first = tunnel.complete_request("req-1", response.clone());
assert!(first);
let second = tunnel.complete_request("req-1", response);
assert!(!second);
drop(receiver);
}
// The empty string is accepted as a request id like any other key.
#[tokio::test]
async fn test_request_with_empty_id() {
let (tunnel, _rx) = create_test_tunnel();
let _receiver = tunnel.register_request("".to_string());
assert!(tunnel.pending_requests.contains_key(""));
let response = DataMessage::HttpResponse {
request_id: "".to_string(),
status: 200,
headers: HashMap::new(),
body: None,
streaming: false,
};
let completed = tunnel.complete_request("", response);
assert!(completed);
}
}