1use axum::{
4 extract::{
5 ws::{Message, WebSocket, WebSocketUpgrade},
6 ConnectInfo, State,
7 },
8 response::IntoResponse,
9};
10use futures_util::{SinkExt, StreamExt};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14use tracing::{error, info, warn};
15
16use codive_tunnel::{message_type, ControlMessage, WireMessage, PROTOCOL_VERSION};
17
18use crate::state::{AuthResult, RelayState};
19use crate::tunnel::{generate_tunnel_id, TunnelConnection, WsMessage, WsSender};
20
21pub async fn agent_ws_handler(
23 ws: WebSocketUpgrade,
24 State(state): State<Arc<RelayState>>,
25 ConnectInfo(addr): ConnectInfo<SocketAddr>,
26) -> impl IntoResponse {
27 ws.on_upgrade(move |socket| handle_agent_connection(socket, state, addr))
28}
29
30async fn send_error_and_close(
32 tx: &WsSender,
33 code: &str,
34 message: &str,
35) {
36 let error = ControlMessage::Error {
37 code: code.to_string(),
38 message: message.to_string(),
39 };
40 let error_json = serde_json::to_string(&error).unwrap_or_default();
41 let _ = tx.send(WsMessage::Text(error_json)).await;
42}
43
44async fn handle_agent_connection(socket: WebSocket, state: Arc<RelayState>, addr: SocketAddr) {
46 let source_ip = addr.ip().to_string();
47 info!(%source_ip, "Agent connecting");
48
49 let (mut ws_sink, mut ws_stream) = socket.split();
50
51 let (tx, mut rx): (WsSender, _) = mpsc::channel(100);
53
54 let send_task = tokio::spawn(async move {
56 while let Some(msg) = rx.recv().await {
57 let ws_msg = match msg {
58 WsMessage::Text(text) => Message::Text(text.into()),
59 WsMessage::Binary(data) => Message::Binary(data.into()),
60 };
61 if ws_sink.send(ws_msg).await.is_err() {
62 break;
63 }
64 }
65 });
66
67 let hello = match wait_for_hello(&mut ws_stream).await {
69 Some(h) => h,
70 None => {
71 warn!(%source_ip, "Agent disconnected before Hello");
72 send_task.abort();
73 return;
74 }
75 };
76
77 match state.validate_auth(&source_ip, hello.auth_token.as_deref()) {
79 AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired => {
80 }
82 AuthResult::Banned { remaining } => {
83 warn!(
84 %source_ip,
85 remaining_secs = remaining.as_secs(),
86 "IP is temporarily banned due to too many failed auth attempts"
87 );
88 send_error_and_close(
89 &tx,
90 "BANNED",
91 &format!(
92 "Too many failed authentication attempts. Try again in {} seconds.",
93 remaining.as_secs()
94 ),
95 ).await;
96 send_task.abort();
97 return;
98 }
99 AuthResult::Invalid(reason) => {
100 warn!(%source_ip, %reason, "Authentication failed");
101 send_error_and_close(&tx, "AUTH_FAILED", &reason).await;
102 send_task.abort();
103 return;
104 }
105 }
106
107 if !state.can_create_tunnel(&source_ip) {
109 let max = state.config.max_tunnels_per_ip;
110 warn!(%source_ip, max_tunnels = max, "Rate limit exceeded");
111 send_error_and_close(
112 &tx,
113 "RATE_LIMITED",
114 &format!("Maximum tunnels ({}) per IP exceeded", max),
115 ).await;
116 send_task.abort();
117 return;
118 }
119
120 let tunnel_id = if state.config.allow_custom_ids {
122 hello.requested_id.unwrap_or_else(generate_tunnel_id)
123 } else {
124 if hello.requested_id.is_some() {
126 tracing::debug!(%source_ip, "Custom tunnel ID requested but not allowed, generating random ID");
127 }
128 generate_tunnel_id()
129 };
130
131 let tunnel = TunnelConnection::new(tunnel_id.clone(), tx.clone(), source_ip.clone());
133 let tunnel_url = state.tunnel_url(&tunnel_id);
134
135 let tunnel = state.register_tunnel(tunnel);
137 info!(
138 tunnel_id = %tunnel_id,
139 url = %tunnel_url,
140 ip_tunnel_count = state.tunnel_count_for_ip(&source_ip),
141 "Tunnel registered"
142 );
143
144 let welcome = ControlMessage::Welcome {
146 tunnel_id: tunnel_id.clone(),
147 tunnel_url: tunnel_url.clone(),
148 };
149 let welcome_json = serde_json::to_string(&welcome).expect("Welcome serialization should not fail");
150 if tx.send(WsMessage::Text(welcome_json)).await.is_err() {
151 error!(tunnel_id = %tunnel_id, "Failed to send Welcome");
152 state.remove_tunnel(&tunnel_id);
153 return;
154 }
155
156 while let Some(result) = ws_stream.next().await {
158 match result {
159 Ok(Message::Binary(data)) => {
160 if let Err(e) = handle_agent_message(&tunnel, &data).await {
161 warn!(tunnel_id = %tunnel_id, error = %e, "Error handling agent message");
162 }
163 }
164 Ok(Message::Text(text)) => {
165 if let Err(e) = handle_control_message(&tunnel, &state, text.as_bytes()).await {
167 warn!(tunnel_id = %tunnel_id, error = %e, "Error handling control message");
168 }
169 }
170 Ok(Message::Ping(data)) => {
171 if tx.send(WsMessage::Binary(data.to_vec())).await.is_err() {
173 break;
174 }
175 }
176 Ok(Message::Close(_)) => {
177 info!(tunnel_id = %tunnel_id, "Agent closed connection");
178 break;
179 }
180 Err(e) => {
181 warn!(tunnel_id = %tunnel_id, error = %e, "WebSocket error");
182 break;
183 }
184 _ => {}
185 }
186 }
187
188 info!(tunnel_id = %tunnel_id, "Tunnel disconnected, cleaning up");
190 tunnel.cancel_all_requests();
191 state.remove_tunnel(&tunnel_id);
192 send_task.abort();
193}
194
/// The fields of the agent's `Hello` control message that the relay acts on.
struct HelloMessage {
    /// Credential presented by the agent, if any. It is passed to
    /// `RelayState::validate_auth`.
    auth_token: Option<String>,
    /// Tunnel id the agent asked for. It is honoured only when the relay
    /// config enables `allow_custom_ids`.
    requested_id: Option<String>,
}
200
201async fn wait_for_hello(
203 stream: &mut futures_util::stream::SplitStream<WebSocket>,
204) -> Option<HelloMessage> {
205 let timeout = tokio::time::timeout(std::time::Duration::from_secs(10), stream.next()).await;
207
208 match timeout {
209 Ok(Some(Ok(Message::Text(text)))) => {
210 match WireMessage::decode_control(text.as_bytes()) {
211 Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
212 if version != PROTOCOL_VERSION {
213 warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
214 }
215 Some(HelloMessage { requested_id, auth_token })
216 }
217 Ok(_) => {
218 warn!("Expected Hello message, got different control message");
219 None
220 }
221 Err(e) => {
222 warn!(error = %e, "Failed to parse Hello message");
223 None
224 }
225 }
226 }
227 Ok(Some(Ok(Message::Binary(data)))) => {
228 match WireMessage::decode_control(&data) {
230 Ok(ControlMessage::Hello { version, requested_id, auth_token }) => {
231 if version != PROTOCOL_VERSION {
232 warn!(version, expected = PROTOCOL_VERSION, "Protocol version mismatch");
233 }
234 Some(HelloMessage { requested_id, auth_token })
235 }
236 _ => {
237 warn!("Expected Hello message as first message");
238 None
239 }
240 }
241 }
242 _ => None,
243 }
244}
245
246async fn handle_agent_message(tunnel: &TunnelConnection, data: &[u8]) -> anyhow::Result<()> {
248 let (request_id_from_header, payload) =
250 if let Ok((msg_type, req_id, encrypted_payload)) =
251 WireMessage::decode_encrypted_with_routing(data)
252 {
253 if msg_type != message_type::ENCRYPTED_RESPONSE {
254 tracing::debug!(msg_type, "Ignoring non-response message type");
255 return Ok(());
256 }
257 (Some(req_id.to_string()), encrypted_payload)
258 } else {
259 let (msg_type, payload) = WireMessage::decode_encrypted(data)
261 .map_err(|e| anyhow::anyhow!("Invalid wire message: {}", e))?;
262 if msg_type != message_type::ENCRYPTED_RESPONSE {
263 tracing::debug!(msg_type, "Ignoring non-response message type");
264 return Ok(());
265 }
266 (None, payload)
267 };
268
269 tracing::debug!(
270 tunnel_id = %tunnel.tunnel_id,
271 request_id_header = ?request_id_from_header,
272 payload_len = payload.len(),
273 "Received response from tunnel client"
274 );
275
276 if let Ok(data_msg) = serde_json::from_slice::<codive_tunnel::DataMessage>(payload) {
279 route_data_message(tunnel, data_msg).await;
280 } else if let Some(req_id) = request_id_from_header {
281 tracing::warn!(
285 tunnel_id = %tunnel.tunnel_id,
286 request_id = %req_id,
287 "Received encrypted payload, relay cannot decrypt (E2E mode)"
288 );
289 let error_msg = codive_tunnel::DataMessage::RequestError {
290 request_id: Some(req_id.clone()),
291 code: "E2E_ENCRYPTED".to_string(),
292 message: "Response is end-to-end encrypted. Use E2E client to decrypt.".to_string(),
293 };
294 route_data_message(tunnel, error_msg).await;
295 } else {
296 warn!(
297 tunnel_id = %tunnel.tunnel_id,
298 "Failed to parse response payload and no routing header"
299 );
300 }
301
302 Ok(())
303}
304
305async fn route_data_message(tunnel: &TunnelConnection, data_msg: codive_tunnel::DataMessage) {
307 match data_msg {
308 codive_tunnel::DataMessage::HttpResponse { ref request_id, streaming, .. } => {
309 tracing::info!(
310 tunnel_id = %tunnel.tunnel_id,
311 request_id = %request_id,
312 streaming = %streaming,
313 "Routing response to pending request"
314 );
315 let req_id = request_id.clone();
316 if streaming {
317 tunnel.send_chunk(&req_id, data_msg).await;
320 } else if !tunnel.complete_request(&req_id, data_msg.clone()) {
321 tunnel.send_chunk(&req_id, data_msg).await;
323 }
324 }
325 codive_tunnel::DataMessage::HttpResponseChunk { ref request_id, is_final, .. } => {
326 tracing::debug!(
327 tunnel_id = %tunnel.tunnel_id,
328 request_id = %request_id,
329 is_final = %is_final,
330 "Routing chunk to streaming request"
331 );
332 let req_id = request_id.clone();
333 if !tunnel.send_chunk(&req_id, data_msg).await {
334 warn!(request_id = %req_id, "No streaming request found for chunk");
335 }
336 if is_final {
337 tunnel.complete_streaming_request(&req_id);
338 }
339 }
340 codive_tunnel::DataMessage::RequestError { ref request_id, .. } => {
341 if let Some(ref req_id) = request_id {
342 tracing::debug!(
343 tunnel_id = %tunnel.tunnel_id,
344 request_id = %req_id,
345 "Routing error to pending request"
346 );
347 let req_id = req_id.clone();
348 if !tunnel.complete_request(&req_id, data_msg.clone()) {
349 tunnel.send_chunk(&req_id, data_msg).await;
350 tunnel.complete_streaming_request(&req_id);
351 }
352 }
353 }
354 _ => {
355 warn!(tunnel_id = %tunnel.tunnel_id, "Unexpected data message type in response");
356 }
357 }
358}
359
360async fn handle_control_message(
362 tunnel: &TunnelConnection,
363 _state: &RelayState,
364 data: &[u8],
365) -> anyhow::Result<()> {
366 let msg = WireMessage::decode_control(data)?;
367
368 match msg {
369 ControlMessage::Ping { timestamp } => {
370 tracing::debug!(tunnel_id = %tunnel.tunnel_id, timestamp, "Received ping, sending pong");
371 let pong = ControlMessage::Pong { timestamp };
373 let pong_json = serde_json::to_string(&pong)?;
374 let _ = tunnel.ws_sender.send(WsMessage::Text(pong_json)).await;
375 }
376 ControlMessage::Pong { timestamp } => {
377 tracing::debug!(tunnel_id = %tunnel.tunnel_id, timestamp, "Received pong");
378 }
379 ControlMessage::Close { reason } => {
380 info!(tunnel_id = %tunnel.tunnel_id, reason, "Agent requested close");
381 }
382 _ => {
383 warn!(tunnel_id = %tunnel.tunnel_id, "Unexpected control message from agent");
384 }
385 }
386
387 Ok(())
388}