codive-relay 0.1.0

Relay server for secure tunneling
Documentation
//! WebSocket endpoint for agent connections

use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        ConnectInfo, State,
    },
    response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{error, info, warn};

use codive_tunnel::{message_type, ControlMessage, WireMessage, PROTOCOL_VERSION};

use crate::state::{AuthResult, RelayState};
use crate::tunnel::{generate_tunnel_id, TunnelConnection, WsMessage, WsSender};

/// WebSocket handler for agent connections
pub async fn agent_ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<RelayState>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| handle_agent_connection(socket, state, addr))
}

/// Send an error response and close the connection
async fn send_error_and_close(
    tx: &WsSender,
    code: &str,
    message: &str,
) {
    let error = ControlMessage::Error {
        code: code.to_string(),
        message: message.to_string(),
    };
    let error_json = serde_json::to_string(&error).unwrap_or_default();
    let _ = tx.send(WsMessage::Text(error_json)).await;
}

/// Handle an agent WebSocket connection
async fn handle_agent_connection(socket: WebSocket, state: Arc<RelayState>, addr: SocketAddr) {
    let source_ip = addr.ip().to_string();
    info!(%source_ip, "Agent connecting");

    let (mut ws_sink, mut ws_stream) = socket.split();

    // Create channel for sending messages to the WebSocket
    let (tx, mut rx): (WsSender, _) = mpsc::channel(100);

    // Spawn task to forward channel messages to WebSocket
    let send_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            let ws_msg = match msg {
                WsMessage::Text(text) => Message::Text(text.into()),
                WsMessage::Binary(data) => Message::Binary(data.into()),
            };
            if ws_sink.send(ws_msg).await.is_err() {
                break;
            }
        }
    });

    // Wait for Hello message
    let hello = match wait_for_hello(&mut ws_stream).await {
        Some(h) => h,
        None => {
            warn!(%source_ip, "Agent disconnected before Hello");
            send_task.abort();
            return;
        }
    };

    // Validate authentication with rate limiting
    match state.validate_auth(&source_ip, hello.auth_token.as_deref()) {
        AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired => {
            // Auth passed, continue
        }
        AuthResult::Banned { remaining } => {
            warn!(
                %source_ip,
                remaining_secs = remaining.as_secs(),
                "IP is temporarily banned due to too many failed auth attempts"
            );
            send_error_and_close(
                &tx,
                "BANNED",
                &format!(
                    "Too many failed authentication attempts. Try again in {} seconds.",
                    remaining.as_secs()
                ),
            ).await;
            send_task.abort();
            return;
        }
        AuthResult::Invalid(reason) => {
            warn!(%source_ip, %reason, "Authentication failed");
            send_error_and_close(&tx, "AUTH_FAILED", &reason).await;
            send_task.abort();
            return;
        }
    }

    // Check rate limiting
    if !state.can_create_tunnel(&source_ip) {
        let max = state.config.max_tunnels_per_ip;
        warn!(%source_ip, max_tunnels = max, "Rate limit exceeded");
        send_error_and_close(
            &tx,
            "RATE_LIMITED",
            &format!("Maximum tunnels ({}) per IP exceeded", max),
        ).await;
        send_task.abort();
        return;
    }

    // Generate tunnel ID - enforce random IDs if custom IDs are not allowed
    let tunnel_id = if state.config.allow_custom_ids {
        hello.requested_id.unwrap_or_else(generate_tunnel_id)
    } else {
        // Ignore requested_id for public relays (prevent subdomain squatting)
        if hello.requested_id.is_some() {
            tracing::debug!(%source_ip, "Custom tunnel ID requested but not allowed, generating random ID");
        }
        generate_tunnel_id()
    };

    // Create tunnel connection
    let tunnel = TunnelConnection::new(tunnel_id.clone(), tx.clone(), source_ip.clone());
    let tunnel_url = state.tunnel_url(&tunnel_id);

    // Register tunnel
    let tunnel = state.register_tunnel(tunnel);
    info!(
        tunnel_id = %tunnel_id,
        url = %tunnel_url,
        ip_tunnel_count = state.tunnel_count_for_ip(&source_ip),
        "Tunnel registered"
    );

    // Send Welcome message (as text for control messages)
    let welcome = ControlMessage::Welcome {
        tunnel_id: tunnel_id.clone(),
        tunnel_url: tunnel_url.clone(),
    };
    let welcome_json = serde_json::to_string(&welcome).expect("Welcome serialization should not fail");
    if tx.send(WsMessage::Text(welcome_json)).await.is_err() {
        error!(tunnel_id = %tunnel_id, "Failed to send Welcome");
        state.remove_tunnel(&tunnel_id);
        return;
    }

    // Main message loop
    while let Some(result) = ws_stream.next().await {
        match result {
            Ok(Message::Binary(data)) => {
                if let Err(e) = handle_agent_message(&tunnel, &data).await {
                    warn!(tunnel_id = %tunnel_id, error = %e, "Error handling agent message");
                }
            }
            Ok(Message::Text(text)) => {
                // Control messages as JSON text
                if let Err(e) = handle_control_message(&tunnel, &state, text.as_bytes()).await {
                    warn!(tunnel_id = %tunnel_id, error = %e, "Error handling control message");
                }
            }
            Ok(Message::Ping(data)) => {
                // Respond with Pong (binary data)
                if tx.send(WsMessage::Binary(data.to_vec())).await.is_err() {
                    break;
                }
            }
            Ok(Message::Close(_)) => {
                info!(tunnel_id = %tunnel_id, "Agent closed connection");
                break;
            }
            Err(e) => {
                warn!(tunnel_id = %tunnel_id, error = %e, "WebSocket error");
                break;
            }
            _ => {}
        }
    }

    // Cleanup
    info!(tunnel_id = %tunnel_id, "Tunnel disconnected, cleaning up");
    tunnel.cancel_all_requests();
    state.remove_tunnel(&tunnel_id);
    send_task.abort();
}

/// Parsed Hello message fields
///
/// Holds the parts of the agent's `ControlMessage::Hello` that the connection
/// handler needs after the handshake; the protocol version is checked while
/// parsing and is not kept.
struct HelloMessage {
    /// Tunnel ID the agent asked for, if any. Only honored when the relay
    /// configuration allows custom IDs; otherwise a random ID is generated.
    requested_id: Option<String>,
    /// Authentication token supplied by the agent, if any. Passed to the
    /// relay state's auth validation together with the source IP.
    auth_token: Option<String>,
}

/// Read the first frame from the agent and interpret it as a Hello.
///
/// Gives up after 10 seconds. Returns `None` on timeout, end of stream,
/// transport error, a frame that is neither text nor binary, a frame that
/// fails to decode, or a control message other than Hello. A protocol version
/// mismatch is logged but does not reject the Hello.
async fn wait_for_hello(
    stream: &mut futures_util::stream::SplitStream<WebSocket>,
) -> Option<HelloMessage> {
    let deadline = std::time::Duration::from_secs(10);
    let first = tokio::time::timeout(deadline, stream.next()).await;

    // Timeout, closed stream and transport errors all end the handshake quietly.
    let Ok(Some(Ok(frame))) = first else {
        return None;
    };

    // Both text and binary frames may carry the Hello; remember which one it
    // was because the two paths log failures differently.
    let (decoded, from_text) = match &frame {
        Message::Text(text) => (WireMessage::decode_control(text.as_bytes()), true),
        Message::Binary(data) => (WireMessage::decode_control(data), false),
        _ => return None,
    };

    match decoded {
        Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
            if version != PROTOCOL_VERSION {
                warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
            }
            Some(HelloMessage { requested_id, auth_token })
        }
        Ok(_) if from_text => {
            warn!("Expected Hello message, got different control message");
            None
        }
        Err(e) if from_text => {
            warn!(error = %e, "Failed to parse Hello message");
            None
        }
        _ => {
            warn!("Expected Hello message as first message");
            None
        }
    }
}

/// Process one binary frame received from the agent.
///
/// The frame is decoded first in the routing-header format and, if that
/// fails, in the older header-less format. Frames whose type is not
/// `ENCRYPTED_RESPONSE` are ignored. The payload is then parsed as a JSON
/// `DataMessage` and routed; if it is not JSON but a routing header supplied a
/// request ID, an `E2E_ENCRYPTED` error is routed to that request instead.
///
/// # Errors
/// Returns an error when the frame decodes in neither wire format.
async fn handle_agent_message(tunnel: &TunnelConnection, data: &[u8]) -> anyhow::Result<()> {
    // Prefer the routing-header format; fall back to the legacy layout.
    let (msg_type, routed_id, payload) = match WireMessage::decode_encrypted_with_routing(data) {
        Ok((kind, req_id, body)) => (kind, Some(req_id.to_string()), body),
        Err(_) => {
            let (kind, body) = WireMessage::decode_encrypted(data)
                .map_err(|e| anyhow::anyhow!("Invalid wire message: {}", e))?;
            (kind, None, body)
        }
    };

    if msg_type != message_type::ENCRYPTED_RESPONSE {
        tracing::debug!(msg_type, "Ignoring non-response message type");
        return Ok(());
    }

    tracing::debug!(
        tunnel_id = %tunnel.tunnel_id,
        request_id_header = ?routed_id,
        payload_len = payload.len(),
        "Received response from tunnel client"
    );

    // A payload that parses as JSON is a plain (relay-readable) data message.
    // Anything else is treated as an opaque end-to-end encrypted blob.
    let parsed = serde_json::from_slice::<codive_tunnel::DataMessage>(payload);
    match (parsed, routed_id) {
        (Ok(data_msg), _) => {
            route_data_message(tunnel, data_msg).await;
        }
        (Err(_), Some(req_id)) => {
            // The relay cannot decrypt, but the header tells us which request
            // is waiting, so answer it with an explanatory error.
            tracing::warn!(
                tunnel_id = %tunnel.tunnel_id,
                request_id = %req_id,
                "Received encrypted payload, relay cannot decrypt (E2E mode)"
            );
            let opaque_reply = codive_tunnel::DataMessage::RequestError {
                request_id: Some(req_id),
                code: "E2E_ENCRYPTED".to_string(),
                message: "Response is end-to-end encrypted. Use E2E client to decrypt.".to_string(),
            };
            route_data_message(tunnel, opaque_reply).await;
        }
        (Err(_), None) => {
            warn!(
                tunnel_id = %tunnel.tunnel_id,
                "Failed to parse response payload and no routing header"
            );
        }
    }

    Ok(())
}

/// Route a parsed data message to the appropriate pending request
async fn route_data_message(tunnel: &TunnelConnection, data_msg: codive_tunnel::DataMessage) {
    match data_msg {
        codive_tunnel::DataMessage::HttpResponse { ref request_id, streaming, .. } => {
            tracing::info!(
                tunnel_id = %tunnel.tunnel_id,
                request_id = %request_id,
                streaming = %streaming,
                "Routing response to pending request"
            );
            let req_id = request_id.clone();
            if streaming {
                // For streaming responses, use send_chunk to keep the request alive
                // Don't use complete_request as that removes the pending request
                tunnel.send_chunk(&req_id, data_msg).await;
            } else if !tunnel.complete_request(&req_id, data_msg.clone()) {
                // Fallback for edge cases
                tunnel.send_chunk(&req_id, data_msg).await;
            }
        }
        codive_tunnel::DataMessage::HttpResponseChunk { ref request_id, is_final, .. } => {
            tracing::debug!(
                tunnel_id = %tunnel.tunnel_id,
                request_id = %request_id,
                is_final = %is_final,
                "Routing chunk to streaming request"
            );
            let req_id = request_id.clone();
            if !tunnel.send_chunk(&req_id, data_msg).await {
                warn!(request_id = %req_id, "No streaming request found for chunk");
            }
            if is_final {
                tunnel.complete_streaming_request(&req_id);
            }
        }
        codive_tunnel::DataMessage::RequestError { ref request_id, .. } => {
            if let Some(ref req_id) = request_id {
                tracing::debug!(
                    tunnel_id = %tunnel.tunnel_id,
                    request_id = %req_id,
                    "Routing error to pending request"
                );
                let req_id = req_id.clone();
                if !tunnel.complete_request(&req_id, data_msg.clone()) {
                    tunnel.send_chunk(&req_id, data_msg).await;
                    tunnel.complete_streaming_request(&req_id);
                }
            }
        }
        _ => {
            warn!(tunnel_id = %tunnel.tunnel_id, "Unexpected data message type in response");
        }
    }
}

/// Decode and act on a JSON control message sent by the agent.
///
/// * `Ping` is answered with a `Pong` carrying the same timestamp; a failure
///   to queue the reply is ignored.
/// * `Pong` and `Close` are only logged.
/// * Anything else is logged as unexpected.
///
/// `_state` is currently unused.
///
/// # Errors
/// Returns an error if the message cannot be decoded or the Pong reply cannot
/// be serialized.
async fn handle_control_message(
    tunnel: &TunnelConnection,
    _state: &RelayState,
    data: &[u8],
) -> anyhow::Result<()> {
    let tunnel_id = &tunnel.tunnel_id;

    match WireMessage::decode_control(data)? {
        ControlMessage::Ping { timestamp } => {
            tracing::debug!(tunnel_id = %tunnel_id, timestamp, "Received ping, sending pong");
            // Echo the timestamp back so the agent can match the reply.
            let reply = serde_json::to_string(&ControlMessage::Pong { timestamp })?;
            let _ = tunnel.ws_sender.send(WsMessage::Text(reply)).await;
        }
        ControlMessage::Pong { timestamp } => {
            tracing::debug!(tunnel_id = %tunnel_id, timestamp, "Received pong");
        }
        ControlMessage::Close { reason } => {
            info!(tunnel_id = %tunnel_id, reason, "Agent requested close");
        }
        _ => {
            warn!(tunnel_id = %tunnel_id, "Unexpected control message from agent");
        }
    }

    Ok(())
}