// codive_relay/routes/agent.rs
1//! WebSocket endpoint for agent connections
2
3use axum::{
4    extract::{
5        ws::{Message, WebSocket, WebSocketUpgrade},
6        ConnectInfo, State,
7    },
8    response::IntoResponse,
9};
10use futures_util::{SinkExt, StreamExt};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14use tracing::{error, info, warn};
15
16use codive_tunnel::{message_type, ControlMessage, WireMessage, PROTOCOL_VERSION};
17
18use crate::state::{AuthResult, RelayState};
19use crate::tunnel::{generate_tunnel_id, TunnelConnection, WsMessage, WsSender};
20
21/// WebSocket handler for agent connections
22pub async fn agent_ws_handler(
23    ws: WebSocketUpgrade,
24    State(state): State<Arc<RelayState>>,
25    ConnectInfo(addr): ConnectInfo<SocketAddr>,
26) -> impl IntoResponse {
27    ws.on_upgrade(move |socket| handle_agent_connection(socket, state, addr))
28}
29
30/// Send an error response and close the connection
31async fn send_error_and_close(
32    tx: &WsSender,
33    code: &str,
34    message: &str,
35) {
36    let error = ControlMessage::Error {
37        code: code.to_string(),
38        message: message.to_string(),
39    };
40    let error_json = serde_json::to_string(&error).unwrap_or_default();
41    let _ = tx.send(WsMessage::Text(error_json)).await;
42}
43
44/// Handle an agent WebSocket connection
45async fn handle_agent_connection(socket: WebSocket, state: Arc<RelayState>, addr: SocketAddr) {
46    let source_ip = addr.ip().to_string();
47    info!(%source_ip, "Agent connecting");
48
49    let (mut ws_sink, mut ws_stream) = socket.split();
50
51    // Create channel for sending messages to the WebSocket
52    let (tx, mut rx): (WsSender, _) = mpsc::channel(100);
53
54    // Spawn task to forward channel messages to WebSocket
55    let send_task = tokio::spawn(async move {
56        while let Some(msg) = rx.recv().await {
57            let ws_msg = match msg {
58                WsMessage::Text(text) => Message::Text(text.into()),
59                WsMessage::Binary(data) => Message::Binary(data.into()),
60            };
61            if ws_sink.send(ws_msg).await.is_err() {
62                break;
63            }
64        }
65    });
66
67    // Wait for Hello message
68    let hello = match wait_for_hello(&mut ws_stream).await {
69        Some(h) => h,
70        None => {
71            warn!(%source_ip, "Agent disconnected before Hello");
72            send_task.abort();
73            return;
74        }
75    };
76
77    // Validate authentication with rate limiting
78    match state.validate_auth(&source_ip, hello.auth_token.as_deref()) {
79        AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired => {
80            // Auth passed, continue
81        }
82        AuthResult::Banned { remaining } => {
83            warn!(
84                %source_ip,
85                remaining_secs = remaining.as_secs(),
86                "IP is temporarily banned due to too many failed auth attempts"
87            );
88            send_error_and_close(
89                &tx,
90                "BANNED",
91                &format!(
92                    "Too many failed authentication attempts. Try again in {} seconds.",
93                    remaining.as_secs()
94                ),
95            ).await;
96            send_task.abort();
97            return;
98        }
99        AuthResult::Invalid(reason) => {
100            warn!(%source_ip, %reason, "Authentication failed");
101            send_error_and_close(&tx, "AUTH_FAILED", &reason).await;
102            send_task.abort();
103            return;
104        }
105    }
106
107    // Check rate limiting
108    if !state.can_create_tunnel(&source_ip) {
109        let max = state.config.max_tunnels_per_ip;
110        warn!(%source_ip, max_tunnels = max, "Rate limit exceeded");
111        send_error_and_close(
112            &tx,
113            "RATE_LIMITED",
114            &format!("Maximum tunnels ({}) per IP exceeded", max),
115        ).await;
116        send_task.abort();
117        return;
118    }
119
120    // Generate tunnel ID - enforce random IDs if custom IDs are not allowed
121    let tunnel_id = if state.config.allow_custom_ids {
122        hello.requested_id.unwrap_or_else(generate_tunnel_id)
123    } else {
124        // Ignore requested_id for public relays (prevent subdomain squatting)
125        if hello.requested_id.is_some() {
126            tracing::debug!(%source_ip, "Custom tunnel ID requested but not allowed, generating random ID");
127        }
128        generate_tunnel_id()
129    };
130
131    // Create tunnel connection
132    let tunnel = TunnelConnection::new(tunnel_id.clone(), tx.clone(), source_ip.clone());
133    let tunnel_url = state.tunnel_url(&tunnel_id);
134
135    // Register tunnel
136    let tunnel = state.register_tunnel(tunnel);
137    info!(
138        tunnel_id = %tunnel_id,
139        url = %tunnel_url,
140        ip_tunnel_count = state.tunnel_count_for_ip(&source_ip),
141        "Tunnel registered"
142    );
143
144    // Send Welcome message (as text for control messages)
145    let welcome = ControlMessage::Welcome {
146        tunnel_id: tunnel_id.clone(),
147        tunnel_url: tunnel_url.clone(),
148    };
149    let welcome_json = serde_json::to_string(&welcome).expect("Welcome serialization should not fail");
150    if tx.send(WsMessage::Text(welcome_json)).await.is_err() {
151        error!(tunnel_id = %tunnel_id, "Failed to send Welcome");
152        state.remove_tunnel(&tunnel_id);
153        return;
154    }
155
156    // Main message loop
157    while let Some(result) = ws_stream.next().await {
158        match result {
159            Ok(Message::Binary(data)) => {
160                if let Err(e) = handle_agent_message(&tunnel, &data).await {
161                    warn!(tunnel_id = %tunnel_id, error = %e, "Error handling agent message");
162                }
163            }
164            Ok(Message::Text(text)) => {
165                // Control messages as JSON text
166                if let Err(e) = handle_control_message(&tunnel, &state, text.as_bytes()).await {
167                    warn!(tunnel_id = %tunnel_id, error = %e, "Error handling control message");
168                }
169            }
170            Ok(Message::Ping(data)) => {
171                // Respond with Pong (binary data)
172                if tx.send(WsMessage::Binary(data.to_vec())).await.is_err() {
173                    break;
174                }
175            }
176            Ok(Message::Close(_)) => {
177                info!(tunnel_id = %tunnel_id, "Agent closed connection");
178                break;
179            }
180            Err(e) => {
181                warn!(tunnel_id = %tunnel_id, error = %e, "WebSocket error");
182                break;
183            }
184            _ => {}
185        }
186    }
187
188    // Cleanup
189    info!(tunnel_id = %tunnel_id, "Tunnel disconnected, cleaning up");
190    tunnel.cancel_all_requests();
191    state.remove_tunnel(&tunnel_id);
192    send_task.abort();
193}
194
/// Fields extracted from the agent's `Hello` handshake message.
struct HelloMessage {
    /// Authentication token presented by the agent, if any.
    auth_token: Option<String>,
    /// Tunnel ID the agent asked for; honored only when the relay allows custom IDs.
    requested_id: Option<String>,
}
200
201/// Wait for the Hello message from the agent
202async fn wait_for_hello(
203    stream: &mut futures_util::stream::SplitStream<WebSocket>,
204) -> Option<HelloMessage> {
205    // Wait for first message with timeout
206    let timeout = tokio::time::timeout(std::time::Duration::from_secs(10), stream.next()).await;
207
208    match timeout {
209        Ok(Some(Ok(Message::Text(text)))) => {
210            match WireMessage::decode_control(text.as_bytes()) {
211                Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
212                    if version != PROTOCOL_VERSION {
213                        warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
214                    }
215                    Some(HelloMessage { requested_id, auth_token })
216                }
217                Ok(_) => {
218                    warn!("Expected Hello message, got different control message");
219                    None
220                }
221                Err(e) => {
222                    warn!(error = %e, "Failed to parse Hello message");
223                    None
224                }
225            }
226        }
227        Ok(Some(Ok(Message::Binary(data)))) => {
228            // Try to parse as control message
229            match WireMessage::decode_control(&data) {
230                Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
231                    if version != PROTOCOL_VERSION {
232                        warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
233                    }
234                    Some(HelloMessage { requested_id, auth_token })
235                }
236                _ => {
237                    warn!("Expected Hello message as first message");
238                    None
239                }
240            }
241        }
242        _ => None,
243    }
244}
245
246/// Handle a binary message from the agent (encrypted data)
247async fn handle_agent_message(tunnel: &TunnelConnection, data: &[u8]) -> anyhow::Result<()> {
248    // Try the new format with routing header first
249    let (request_id_from_header, payload) =
250        if let Ok((msg_type, req_id, encrypted_payload)) =
251            WireMessage::decode_encrypted_with_routing(data)
252        {
253            if msg_type != message_type::ENCRYPTED_RESPONSE {
254                tracing::debug!(msg_type, "Ignoring non-response message type");
255                return Ok(());
256            }
257            (Some(req_id.to_string()), encrypted_payload)
258        } else {
259            // Fall back to old format (no routing header)
260            let (msg_type, payload) = WireMessage::decode_encrypted(data)
261                .map_err(|e| anyhow::anyhow!("Invalid wire message: {}", e))?;
262            if msg_type != message_type::ENCRYPTED_RESPONSE {
263                tracing::debug!(msg_type, "Ignoring non-response message type");
264                return Ok(());
265            }
266            (None, payload)
267        };
268
269    tracing::debug!(
270        tunnel_id = %tunnel.tunnel_id,
271        request_id_header = ?request_id_from_header,
272        payload_len = payload.len(),
273        "Received response from tunnel client"
274    );
275
276    // Try to parse the payload as JSON (for unencrypted mode or after decryption)
277    // In full E2E mode with remote CLI client, we'd forward the encrypted blob directly
278    if let Ok(data_msg) = serde_json::from_slice::<codive_tunnel::DataMessage>(payload) {
279        route_data_message(tunnel, data_msg).await;
280    } else if let Some(req_id) = request_id_from_header {
281        // Payload is encrypted - we have request_id from header for routing
282        // For now, create an error response since relay can't decrypt
283        // In full E2E mode, we'd forward encrypted payload to E2E clients
284        tracing::warn!(
285            tunnel_id = %tunnel.tunnel_id,
286            request_id = %req_id,
287            "Received encrypted payload, relay cannot decrypt (E2E mode)"
288        );
289        let error_msg = codive_tunnel::DataMessage::RequestError {
290            request_id: Some(req_id.clone()),
291            code: "E2E_ENCRYPTED".to_string(),
292            message: "Response is end-to-end encrypted. Use E2E client to decrypt.".to_string(),
293        };
294        route_data_message(tunnel, error_msg).await;
295    } else {
296        warn!(
297            tunnel_id = %tunnel.tunnel_id,
298            "Failed to parse response payload and no routing header"
299        );
300    }
301
302    Ok(())
303}
304
305/// Route a parsed data message to the appropriate pending request
306async fn route_data_message(tunnel: &TunnelConnection, data_msg: codive_tunnel::DataMessage) {
307    match data_msg {
308        codive_tunnel::DataMessage::HttpResponse { ref request_id, streaming, .. } => {
309            tracing::info!(
310                tunnel_id = %tunnel.tunnel_id,
311                request_id = %request_id,
312                streaming = %streaming,
313                "Routing response to pending request"
314            );
315            let req_id = request_id.clone();
316            if streaming {
317                // For streaming responses, use send_chunk to keep the request alive
318                // Don't use complete_request as that removes the pending request
319                tunnel.send_chunk(&req_id, data_msg).await;
320            } else if !tunnel.complete_request(&req_id, data_msg.clone()) {
321                // Fallback for edge cases
322                tunnel.send_chunk(&req_id, data_msg).await;
323            }
324        }
325        codive_tunnel::DataMessage::HttpResponseChunk { ref request_id, is_final, .. } => {
326            tracing::debug!(
327                tunnel_id = %tunnel.tunnel_id,
328                request_id = %request_id,
329                is_final = %is_final,
330                "Routing chunk to streaming request"
331            );
332            let req_id = request_id.clone();
333            if !tunnel.send_chunk(&req_id, data_msg).await {
334                warn!(request_id = %req_id, "No streaming request found for chunk");
335            }
336            if is_final {
337                tunnel.complete_streaming_request(&req_id);
338            }
339        }
340        codive_tunnel::DataMessage::RequestError { ref request_id, .. } => {
341            if let Some(ref req_id) = request_id {
342                tracing::debug!(
343                    tunnel_id = %tunnel.tunnel_id,
344                    request_id = %req_id,
345                    "Routing error to pending request"
346                );
347                let req_id = req_id.clone();
348                if !tunnel.complete_request(&req_id, data_msg.clone()) {
349                    tunnel.send_chunk(&req_id, data_msg).await;
350                    tunnel.complete_streaming_request(&req_id);
351                }
352            }
353        }
354        _ => {
355            warn!(tunnel_id = %tunnel.tunnel_id, "Unexpected data message type in response");
356        }
357    }
358}
359
360/// Handle a control message from the agent
361async fn handle_control_message(
362    tunnel: &TunnelConnection,
363    _state: &RelayState,
364    data: &[u8],
365) -> anyhow::Result<()> {
366    let msg = WireMessage::decode_control(data)?;
367
368    match msg {
369        ControlMessage::Ping { timestamp } => {
370            tracing::debug!(tunnel_id = %tunnel.tunnel_id, timestamp, "Received ping, sending pong");
371            // Respond with Pong
372            let pong = ControlMessage::Pong { timestamp };
373            let pong_json = serde_json::to_string(&pong)?;
374            let _ = tunnel.ws_sender.send(WsMessage::Text(pong_json)).await;
375        }
376        ControlMessage::Pong { timestamp } => {
377            tracing::debug!(tunnel_id = %tunnel.tunnel_id, timestamp, "Received pong");
378        }
379        ControlMessage::Close { reason } => {
380            info!(tunnel_id = %tunnel.tunnel_id, reason, "Agent requested close");
381        }
382        _ => {
383            warn!(tunnel_id = %tunnel.tunnel_id, "Unexpected control message from agent");
384        }
385    }
386
387    Ok(())
388}