use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
ConnectInfo, State,
},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use codive_tunnel::{message_type, ControlMessage, WireMessage, PROTOCOL_VERSION};
use crate::state::{AuthResult, RelayState};
use crate::tunnel::{generate_tunnel_id, TunnelConnection, WsMessage, WsSender};
pub async fn agent_ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<RelayState>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_agent_connection(socket, state, addr))
}
async fn send_error_and_close(
tx: &WsSender,
code: &str,
message: &str,
) {
let error = ControlMessage::Error {
code: code.to_string(),
message: message.to_string(),
};
let error_json = serde_json::to_string(&error).unwrap_or_default();
let _ = tx.send(WsMessage::Text(error_json)).await;
}
async fn handle_agent_connection(socket: WebSocket, state: Arc<RelayState>, addr: SocketAddr) {
let source_ip = addr.ip().to_string();
info!(%source_ip, "Agent connecting");
let (mut ws_sink, mut ws_stream) = socket.split();
let (tx, mut rx): (WsSender, _) = mpsc::channel(100);
let send_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
let ws_msg = match msg {
WsMessage::Text(text) => Message::Text(text.into()),
WsMessage::Binary(data) => Message::Binary(data.into()),
};
if ws_sink.send(ws_msg).await.is_err() {
break;
}
}
});
let hello = match wait_for_hello(&mut ws_stream).await {
Some(h) => h,
None => {
warn!(%source_ip, "Agent disconnected before Hello");
send_task.abort();
return;
}
};
match state.validate_auth(&source_ip, hello.auth_token.as_deref()) {
AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired => {
}
AuthResult::Banned { remaining } => {
warn!(
%source_ip,
remaining_secs = remaining.as_secs(),
"IP is temporarily banned due to too many failed auth attempts"
);
send_error_and_close(
&tx,
"BANNED",
&format!(
"Too many failed authentication attempts. Try again in {} seconds.",
remaining.as_secs()
),
).await;
send_task.abort();
return;
}
AuthResult::Invalid(reason) => {
warn!(%source_ip, %reason, "Authentication failed");
send_error_and_close(&tx, "AUTH_FAILED", &reason).await;
send_task.abort();
return;
}
}
if !state.can_create_tunnel(&source_ip) {
let max = state.config.max_tunnels_per_ip;
warn!(%source_ip, max_tunnels = max, "Rate limit exceeded");
send_error_and_close(
&tx,
"RATE_LIMITED",
&format!("Maximum tunnels ({}) per IP exceeded", max),
).await;
send_task.abort();
return;
}
let tunnel_id = if state.config.allow_custom_ids {
hello.requested_id.unwrap_or_else(generate_tunnel_id)
} else {
if hello.requested_id.is_some() {
tracing::debug!(%source_ip, "Custom tunnel ID requested but not allowed, generating random ID");
}
generate_tunnel_id()
};
let tunnel = TunnelConnection::new(tunnel_id.clone(), tx.clone(), source_ip.clone());
let tunnel_url = state.tunnel_url(&tunnel_id);
let tunnel = state.register_tunnel(tunnel);
info!(
tunnel_id = %tunnel_id,
url = %tunnel_url,
ip_tunnel_count = state.tunnel_count_for_ip(&source_ip),
"Tunnel registered"
);
let welcome = ControlMessage::Welcome {
tunnel_id: tunnel_id.clone(),
tunnel_url: tunnel_url.clone(),
};
let welcome_json = serde_json::to_string(&welcome).expect("Welcome serialization should not fail");
if tx.send(WsMessage::Text(welcome_json)).await.is_err() {
error!(tunnel_id = %tunnel_id, "Failed to send Welcome");
state.remove_tunnel(&tunnel_id);
return;
}
while let Some(result) = ws_stream.next().await {
match result {
Ok(Message::Binary(data)) => {
if let Err(e) = handle_agent_message(&tunnel, &data).await {
warn!(tunnel_id = %tunnel_id, error = %e, "Error handling agent message");
}
}
Ok(Message::Text(text)) => {
if let Err(e) = handle_control_message(&tunnel, &state, text.as_bytes()).await {
warn!(tunnel_id = %tunnel_id, error = %e, "Error handling control message");
}
}
Ok(Message::Ping(data)) => {
if tx.send(WsMessage::Binary(data.to_vec())).await.is_err() {
break;
}
}
Ok(Message::Close(_)) => {
info!(tunnel_id = %tunnel_id, "Agent closed connection");
break;
}
Err(e) => {
warn!(tunnel_id = %tunnel_id, error = %e, "WebSocket error");
break;
}
_ => {}
}
}
info!(tunnel_id = %tunnel_id, "Tunnel disconnected, cleaning up");
tunnel.cancel_all_requests();
state.remove_tunnel(&tunnel_id);
send_task.abort();
}
/// Fields taken from the agent's initial `Hello` control message.
///
/// Intentionally does not derive `Debug`: `auth_token` is a credential and should not
/// leak into logs through `{:?}` formatting.
struct HelloMessage {
    /// Credential presented by the agent; `None` when no token was sent.
    auth_token: Option<String>,
    /// Tunnel ID requested by the agent; only honoured when custom IDs are enabled.
    requested_id: Option<String>,
}
async fn wait_for_hello(
stream: &mut futures_util::stream::SplitStream<WebSocket>,
) -> Option<HelloMessage> {
let timeout = tokio::time::timeout(std::time::Duration::from_secs(10), stream.next()).await;
match timeout {
Ok(Some(Ok(Message::Text(text)))) => {
match WireMessage::decode_control(text.as_bytes()) {
Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
if version != PROTOCOL_VERSION {
warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
}
Some(HelloMessage { requested_id, auth_token })
}
Ok(_) => {
warn!("Expected Hello message, got different control message");
None
}
Err(e) => {
warn!(error = %e, "Failed to parse Hello message");
None
}
}
}
Ok(Some(Ok(Message::Binary(data)))) => {
match WireMessage::decode_control(&data) {
Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
if version != PROTOCOL_VERSION {
warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
}
Some(HelloMessage { requested_id, auth_token })
}
_ => {
warn!("Expected Hello message as first message");
None
}
}
}
_ => None,
}
}
/// Handles one Binary frame received from the agent on the data path.
///
/// Two framings are accepted. The routed framing carries the request ID in a header
/// next to the payload; the plain framing has no header. Only
/// `message_type::ENCRYPTED_RESPONSE` frames are processed; other types are logged at
/// debug level and ignored.
///
/// If the payload parses as a JSON `DataMessage`, it is routed to the waiting request.
/// Otherwise, if the routed header supplied a request ID, the relay cannot read the
/// payload and answers that request with an `E2E_ENCRYPTED` error instead. A payload
/// that neither parses nor carries a request ID is logged and dropped.
///
/// # Errors
/// Returns an error only when the frame cannot be decoded with either framing.
async fn handle_agent_message(tunnel: &TunnelConnection, data: &[u8]) -> anyhow::Result<()> {
    let (msg_type, routed_request_id, payload) =
        match WireMessage::decode_encrypted_with_routing(data) {
            Ok((msg_type, req_id, body)) => (msg_type, Some(req_id.to_string()), body),
            Err(_) => {
                let (msg_type, body) = WireMessage::decode_encrypted(data)
                    .map_err(|e| anyhow::anyhow!("Invalid wire message: {}", e))?;
                (msg_type, None, body)
            }
        };

    if msg_type != message_type::ENCRYPTED_RESPONSE {
        tracing::debug!(msg_type, "Ignoring non-response message type");
        return Ok(());
    }

    tracing::debug!(
        tunnel_id = %tunnel.tunnel_id,
        request_id_header = ?routed_request_id,
        payload_len = payload.len(),
        "Received response from tunnel client"
    );

    match (
        serde_json::from_slice::<codive_tunnel::DataMessage>(payload),
        routed_request_id,
    ) {
        (Ok(data_msg), _) => {
            route_data_message(tunnel, data_msg).await;
        }
        (Err(_), Some(req_id)) => {
            tracing::warn!(
                tunnel_id = %tunnel.tunnel_id,
                request_id = %req_id,
                "Received encrypted payload, relay cannot decrypt (E2E mode)"
            );
            let fallback = codive_tunnel::DataMessage::RequestError {
                request_id: Some(req_id),
                code: "E2E_ENCRYPTED".to_string(),
                message: "Response is end-to-end encrypted. Use E2E client to decrypt.".to_string(),
            };
            route_data_message(tunnel, fallback).await;
        }
        (Err(_), None) => {
            warn!(
                tunnel_id = %tunnel.tunnel_id,
                "Failed to parse response payload and no routing header"
            );
        }
    }

    Ok(())
}
async fn route_data_message(tunnel: &TunnelConnection, data_msg: codive_tunnel::DataMessage) {
match data_msg {
codive_tunnel::DataMessage::HttpResponse { ref request_id, streaming, .. } => {
tracing::info!(
tunnel_id = %tunnel.tunnel_id,
request_id = %request_id,
streaming = %streaming,
"Routing response to pending request"
);
let req_id = request_id.clone();
if streaming {
tunnel.send_chunk(&req_id, data_msg).await;
} else if !tunnel.complete_request(&req_id, data_msg.clone()) {
tunnel.send_chunk(&req_id, data_msg).await;
}
}
codive_tunnel::DataMessage::HttpResponseChunk { ref request_id, is_final, .. } => {
tracing::debug!(
tunnel_id = %tunnel.tunnel_id,
request_id = %request_id,
is_final = %is_final,
"Routing chunk to streaming request"
);
let req_id = request_id.clone();
if !tunnel.send_chunk(&req_id, data_msg).await {
warn!(request_id = %req_id, "No streaming request found for chunk");
}
if is_final {
tunnel.complete_streaming_request(&req_id);
}
}
codive_tunnel::DataMessage::RequestError { ref request_id, .. } => {
if let Some(ref req_id) = request_id {
tracing::debug!(
tunnel_id = %tunnel.tunnel_id,
request_id = %req_id,
"Routing error to pending request"
);
let req_id = req_id.clone();
if !tunnel.complete_request(&req_id, data_msg.clone()) {
tunnel.send_chunk(&req_id, data_msg).await;
tunnel.complete_streaming_request(&req_id);
}
}
}
_ => {
warn!(tunnel_id = %tunnel.tunnel_id, "Unexpected data message type in response");
}
}
}
/// Handles a Text frame from the agent on the control path.
///
/// - `Ping`: answered with a `Pong` carrying the same timestamp.
/// - `Pong` and `Close`: only logged. `Close` does not end the connection here.
/// - Any other control message: a warning is logged.
///
/// `_state` is currently unused.
///
/// # Errors
/// Returns an error if the frame is not a valid control message or the `Pong` cannot be
/// serialized.
async fn handle_control_message(
    tunnel: &TunnelConnection,
    _state: &RelayState,
    data: &[u8],
) -> anyhow::Result<()> {
    let tunnel_id = &tunnel.tunnel_id;
    match WireMessage::decode_control(data)? {
        ControlMessage::Ping { timestamp } => {
            tracing::debug!(tunnel_id = %tunnel_id, timestamp, "Received ping, sending pong");
            let reply = serde_json::to_string(&ControlMessage::Pong { timestamp })?;
            // Best-effort: a failed send is deliberately ignored.
            let _ = tunnel.ws_sender.send(WsMessage::Text(reply)).await;
        }
        ControlMessage::Pong { timestamp } => {
            tracing::debug!(tunnel_id = %tunnel_id, timestamp, "Received pong");
        }
        ControlMessage::Close { reason } => {
            info!(tunnel_id = %tunnel_id, reason, "Agent requested close");
        }
        _ => {
            warn!(tunnel_id = %tunnel_id, "Unexpected control message from agent");
        }
    }
    Ok(())
}