1use super::TunnelInfo;
2use anyhow::{Context, Result};
3use futures::{SinkExt, StreamExt};
4use reqwest;
5use serde::{Deserialize, Serialize};
6use socket2::{SockRef, TcpKeepalive};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{mpsc, RwLock};
11use tokio::task::JoinHandle;
12use tokio_tungstenite::{connect_async, tungstenite::Message};
13use tracing::{error, info, warn};
14use url::Url;
15
/// Compares two secrets for exact byte equality without stopping at the first
/// differing byte.
///
/// Every position up to the longer input's length is visited and the length
/// difference is folded into the result. Inputs of different lengths never
/// compare equal. Unlike comparing hash digests, this cannot report two
/// different strings as equal.
fn secure_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // Seed with the length difference: it is zero only when the lengths match.
    let mut diff: usize = a.len() ^ b.len();
    for i in 0..a.len().max(b.len()) {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    // black_box discourages the optimizer from turning the fold into an early exit.
    std::hint::black_box(diff) == 0
}
31
/// Built-in relay worker base URL, used unless overridden by the
/// `ASTER_TUNNEL_WORKER_URL` environment variable (see `get_worker_url`).
const WORKER_URL: &str = "https://cloudflare-tunnel-proxy.michael-neale.workers.dev";
/// Seconds (5 minutes) without received text/ping/pong frames before the
/// connection is dropped and a reconnect is requested.
const IDLE_TIMEOUT_SECS: u64 = 5 * 60;
/// Seconds allowed for the WebSocket connect handshake.
const CONNECTION_TIMEOUT_SECS: u64 = 30;
/// Response bodies longer than this many bytes are split across several messages.
const MAX_WS_SIZE: usize = 900 * 1000;
36
37fn get_worker_url() -> String {
38 std::env::var("ASTER_TUNNEL_WORKER_URL")
39 .ok()
40 .unwrap_or_else(|| WORKER_URL.to_string())
41}
42
43type WebSocketSender = Arc<
44 RwLock<
45 Option<
46 futures::stream::SplitSink<
47 tokio_tungstenite::WebSocketStream<
48 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
49 >,
50 Message,
51 >,
52 >,
53 >,
54>;
55
56#[derive(Debug, Serialize, Deserialize)]
57struct TunnelMessage {
58 #[serde(rename = "requestId")]
59 request_id: String,
60 method: String,
61 path: String,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 headers: Option<HashMap<String, String>>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 body: Option<String>,
66}
67
68#[derive(Debug, Serialize)]
69struct TunnelResponse {
70 #[serde(rename = "requestId")]
71 request_id: String,
72 status: u16,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 headers: Option<HashMap<String, String>>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 body: Option<String>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 error: Option<String>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 #[serde(rename = "chunkIndex")]
81 chunk_index: Option<usize>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 #[serde(rename = "totalChunks")]
84 total_chunks: Option<usize>,
85 #[serde(rename = "isChunked")]
86 is_chunked: bool,
87 #[serde(rename = "isStreaming")]
88 is_streaming: bool,
89 #[serde(rename = "isFirstChunk")]
90 is_first_chunk: bool,
91 #[serde(rename = "isLastChunk")]
92 is_last_chunk: bool,
93}
94
95fn validate_and_build_request(
96 client: &reqwest::Client,
97 url: &str,
98 message: &TunnelMessage,
99 tunnel_secret: &str,
100 server_secret: &str,
101) -> Result<reqwest::RequestBuilder> {
102 let incoming_secret = message
103 .headers
104 .as_ref()
105 .and_then(|h| {
106 h.iter()
107 .find(|(k, _)| k.eq_ignore_ascii_case("x-secret-key"))
108 .map(|(_, v)| v)
109 })
110 .ok_or_else(|| anyhow::anyhow!("Missing tunnel secret header"))?;
111
112 if !secure_compare(incoming_secret, tunnel_secret) {
113 anyhow::bail!("Invalid tunnel secret");
114 }
115
116 let mut request_builder = match message.method.as_str() {
117 "GET" => client.get(url),
118 "POST" => client.post(url),
119 "PUT" => client.put(url),
120 "DELETE" => client.delete(url),
121 "PATCH" => client.patch(url),
122 _ => client.get(url),
123 };
124
125 if let Some(headers) = &message.headers {
126 for (key, value) in headers {
127 if key.eq_ignore_ascii_case("x-secret-key") {
128 continue;
129 }
130 request_builder = request_builder.header(key, value);
131 }
132 }
133
134 request_builder = request_builder.header("X-Secret-Key", server_secret);
135
136 if let Some(body) = &message.body {
137 if message.method != "GET" && message.method != "HEAD" {
138 request_builder = request_builder.body(body.clone());
139 }
140 }
141
142 Ok(request_builder)
143}
144
145async fn handle_streaming_response(
146 response: reqwest::Response,
147 status: u16,
148 headers_map: HashMap<String, String>,
149 request_id: String,
150 message_path: String,
151 ws_tx: WebSocketSender,
152) -> Result<()> {
153 info!("← {} {} [{}] (streaming)", status, message_path, request_id);
154
155 let mut stream = response.bytes_stream();
156 let mut chunk_index = 0;
157 let mut is_first_chunk = true;
158
159 while let Some(chunk_result) = stream.next().await {
160 match chunk_result {
161 Ok(chunk) => {
162 let chunk_str = String::from_utf8_lossy(&chunk).to_string();
163 let tunnel_response = TunnelResponse {
164 request_id: request_id.clone(),
165 status,
166 headers: if is_first_chunk {
167 Some(headers_map.clone())
168 } else {
169 None
170 },
171 body: Some(chunk_str),
172 error: None,
173 chunk_index: Some(chunk_index),
174 total_chunks: None,
175 is_chunked: false,
176 is_streaming: true,
177 is_first_chunk,
178 is_last_chunk: false,
179 };
180 send_response(ws_tx.clone(), tunnel_response).await?;
181 chunk_index += 1;
182 is_first_chunk = false;
183 }
184 Err(e) => {
185 error!("Error reading stream chunk: {}", e);
186 break;
187 }
188 }
189 }
190
191 let tunnel_response = TunnelResponse {
192 request_id: request_id.clone(),
193 status,
194 headers: None,
195 body: Some(String::new()),
196 error: None,
197 chunk_index: Some(chunk_index),
198 total_chunks: None,
199 is_chunked: false,
200 is_streaming: true,
201 is_first_chunk: false,
202 is_last_chunk: true,
203 };
204 send_response(ws_tx, tunnel_response).await?;
205 info!(
206 "← {} {} [{}] (complete, {} chunks)",
207 status, message_path, request_id, chunk_index
208 );
209 Ok(())
210}
211
212async fn handle_chunked_response(
213 body: String,
214 status: u16,
215 headers_map: HashMap<String, String>,
216 request_id: String,
217 message_path: String,
218 ws_tx: WebSocketSender,
219) -> Result<()> {
220 let total_chunks = body.len().div_ceil(MAX_WS_SIZE);
221 info!(
222 "← {} {} [{}] ({} bytes, {} chunks)",
223 status,
224 message_path,
225 request_id,
226 body.len(),
227 total_chunks
228 );
229
230 for (i, chunk) in body.as_bytes().chunks(MAX_WS_SIZE).enumerate() {
231 let chunk_str = String::from_utf8_lossy(chunk).to_string();
232 let tunnel_response = TunnelResponse {
233 request_id: request_id.clone(),
234 status,
235 headers: if i == 0 {
236 Some(headers_map.clone())
237 } else {
238 None
239 },
240 body: Some(chunk_str),
241 error: None,
242 chunk_index: Some(i),
243 total_chunks: Some(total_chunks),
244 is_chunked: true,
245 is_streaming: false,
246 is_first_chunk: false,
247 is_last_chunk: false,
248 };
249 send_response(ws_tx.clone(), tunnel_response).await?;
250 }
251 Ok(())
252}
253
254async fn handle_request(
255 message: TunnelMessage,
256 port: u16,
257 ws_tx: WebSocketSender,
258 tunnel_secret: String,
259 server_secret: String,
260) -> Result<()> {
261 let request_id = message.request_id.clone();
262
263 let client = reqwest::Client::new();
264 let url = format!("http://127.0.0.1:{}{}", port, message.path);
265
266 let request_builder =
267 match validate_and_build_request(&client, &url, &message, &tunnel_secret, &server_secret) {
268 Ok(builder) => builder,
269 Err(e) => {
270 error!("✗ Authentication error [{}]: {}", request_id, e);
271 let error_response = TunnelResponse {
272 request_id,
273 status: 401,
274 headers: None,
275 body: None,
276 error: Some(e.to_string()),
277 chunk_index: None,
278 total_chunks: None,
279 is_chunked: false,
280 is_streaming: false,
281 is_first_chunk: false,
282 is_last_chunk: false,
283 };
284 send_response(ws_tx, error_response).await?;
285 return Ok(());
286 }
287 };
288
289 let response = match request_builder.send().await {
290 Ok(resp) => resp,
291 Err(e) => {
292 error!("✗ Request error [{}]: {}", request_id, e);
293 let error_response = TunnelResponse {
294 request_id,
295 status: 500,
296 headers: None,
297 body: None,
298 error: Some(e.to_string()),
299 chunk_index: None,
300 total_chunks: None,
301 is_chunked: false,
302 is_streaming: false,
303 is_first_chunk: false,
304 is_last_chunk: false,
305 };
306 send_response(ws_tx, error_response).await?;
307 return Ok(());
308 }
309 };
310
311 let status = response.status().as_u16();
312 let headers_map: HashMap<String, String> = response
314 .headers()
315 .iter()
316 .map(|(k, v)| {
317 (
318 k.as_str().to_lowercase(),
319 v.to_str().unwrap_or("").to_string(),
320 )
321 })
322 .collect();
323
324 let is_streaming = headers_map
325 .get("content-type")
326 .map(|ct| ct.contains("text/event-stream"))
327 .unwrap_or(false);
328
329 if is_streaming {
330 handle_streaming_response(
331 response,
332 status,
333 headers_map,
334 request_id,
335 message.path,
336 ws_tx,
337 )
338 .await?;
339 } else {
340 let body = response.text().await.unwrap_or_default();
341
342 if body.len() > MAX_WS_SIZE {
343 handle_chunked_response(body, status, headers_map, request_id, message.path, ws_tx)
344 .await?;
345 } else {
346 let tunnel_response = TunnelResponse {
347 request_id: request_id.clone(),
348 status,
349 headers: Some(headers_map),
350 body: Some(body),
351 error: None,
352 chunk_index: None,
353 total_chunks: None,
354 is_chunked: false,
355 is_streaming: false,
356 is_first_chunk: false,
357 is_last_chunk: false,
358 };
359 send_response(ws_tx, tunnel_response).await?;
360 }
361 }
362
363 Ok(())
364}
365
366async fn send_response(ws_tx: WebSocketSender, response: TunnelResponse) -> Result<()> {
367 let json = serde_json::to_string(&response)?;
368 if let Some(tx) = ws_tx.write().await.as_mut() {
369 tx.send(Message::Text(json.into()))
370 .await
371 .context("Failed to send response")?;
372 }
373 Ok(())
374}
375
376fn configure_tcp_keepalive(
377 stream: &tokio_tungstenite::WebSocketStream<
378 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
379 >,
380) {
381 let tcp_stream = stream.get_ref().get_ref();
382 let socket_ref = SockRef::from(tcp_stream);
383
384 let keepalive = TcpKeepalive::new()
385 .with_time(Duration::from_secs(30))
386 .with_interval(Duration::from_secs(30));
387
388 if let Err(e) = socket_ref.set_tcp_keepalive(&keepalive) {
389 warn!("Failed to set TCP keep-alive: {}", e);
390 } else {
391 info!("✓ TCP keep-alive enabled (30s interval)");
392 }
393}
394
395async fn handle_websocket_messages(
396 mut read: futures::stream::SplitStream<
397 tokio_tungstenite::WebSocketStream<
398 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
399 >,
400 >,
401 ws_tx: WebSocketSender,
402 port: u16,
403 tunnel_secret: String,
404 server_secret: String,
405 last_activity: Arc<RwLock<Instant>>,
406 active_tasks: Arc<RwLock<Vec<JoinHandle<()>>>>,
407) {
408 while let Some(msg) = read.next().await {
409 match msg {
410 Ok(Message::Text(text)) => {
411 *last_activity.write().await = Instant::now();
412
413 match serde_json::from_str::<TunnelMessage>(&text) {
414 Ok(tunnel_msg) => {
415 let ws_tx_clone = ws_tx.clone();
416 let tunnel_secret_clone = tunnel_secret.clone();
417 let server_secret_clone = server_secret.clone();
418 let task = tokio::spawn(async move {
419 if let Err(e) = handle_request(
420 tunnel_msg,
421 port,
422 ws_tx_clone,
423 tunnel_secret_clone,
424 server_secret_clone,
425 )
426 .await
427 {
428 error!("Error handling request: {}", e);
429 }
430 });
431 {
432 let mut tasks = active_tasks.write().await;
433 tasks.retain(|t| !t.is_finished());
434 tasks.push(task);
435 }
436 }
437 Err(e) => {
438 error!("Error parsing tunnel message: {}", e);
439 }
440 }
441 }
442 Ok(Message::Close(_)) => {
443 info!("✗ Connection closed by server");
444 break;
445 }
446 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
447 *last_activity.write().await = Instant::now();
448 }
449 Err(e) => {
450 error!("✗ WebSocket error: {}", e);
451 break;
452 }
453 _ => {}
454 }
455 }
456}
457
458async fn cleanup_connection(
459 ws_tx: WebSocketSender,
460 active_tasks: Arc<RwLock<Vec<JoinHandle<()>>>>,
461) {
462 if let Some(mut tx) = ws_tx.write().await.take() {
463 let _ = tx.close().await;
464 }
465
466 let tasks = active_tasks.write().await.drain(..).collect::<Vec<_>>();
467 info!("Aborting {} active request tasks", tasks.len());
468 for task in tasks {
469 task.abort();
470 }
471}
472
473async fn run_single_connection(
474 port: u16,
475 agent_id: String,
476 tunnel_secret: String,
477 server_secret: String,
478 restart_tx: mpsc::Sender<()>,
479) {
480 let _ = rustls::crypto::ring::default_provider().install_default();
481
482 let worker_url = get_worker_url();
483 let ws_url = worker_url
484 .replace("https://", "wss://")
485 .replace("http://", "ws://");
486
487 let url = format!("{}/connect?agent_id={}", ws_url, agent_id);
488
489 info!("Connecting to {}...", url);
490
491 let ws_stream = match tokio::time::timeout(
492 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
493 connect_async(url.clone()),
494 )
495 .await
496 {
497 Ok(Ok((stream, _))) => {
498 configure_tcp_keepalive(&stream);
499 stream
500 }
501 Ok(Err(e)) => {
502 error!("✗ WebSocket connection error: {}", e);
503 let _ = restart_tx.send(()).await;
504 return;
505 }
506 Err(_) => {
507 error!(
508 "✗ WebSocket connection timeout after {}s",
509 CONNECTION_TIMEOUT_SECS
510 );
511 let _ = restart_tx.send(()).await;
512 return;
513 }
514 };
515
516 info!("✓ Connected as agent: {}", agent_id);
517 info!("✓ Proxying to: http://127.0.0.1:{}", port);
518 let public_url = format!("{}/tunnel/{}", worker_url, agent_id);
519 info!("✓ Public URL: {}", public_url);
520
521 let (write, read) = ws_stream.split();
522 let ws_tx: WebSocketSender = Arc::new(RwLock::new(Some(write)));
523 let last_activity = Arc::new(RwLock::new(Instant::now()));
524 let active_tasks: Arc<RwLock<Vec<JoinHandle<()>>>> = Arc::new(RwLock::new(Vec::new()));
525
526 let last_activity_clone = last_activity.clone();
527 let idle_task = async move {
528 loop {
529 tokio::time::sleep(Duration::from_secs(60)).await;
530 let elapsed = last_activity_clone.read().await.elapsed();
531 if elapsed > Duration::from_secs(IDLE_TIMEOUT_SECS) {
532 warn!(
533 "No activity for {} minutes, forcing reconnect",
534 IDLE_TIMEOUT_SECS / 60
535 );
536 break;
537 }
538 }
539 };
540
541 tokio::select! {
542 _ = idle_task => {
543 info!("✗ Idle timeout triggered");
544 }
545 _ = handle_websocket_messages(
546 read,
547 ws_tx.clone(),
548 port,
549 tunnel_secret.clone(),
550 server_secret.clone(),
551 last_activity,
552 active_tasks.clone()
553 ) => {
554 info!("✗ Connection ended");
555 }
556 }
557
558 cleanup_connection(ws_tx, active_tasks).await;
559
560 let _ = restart_tx.send(()).await;
561}
562
563pub async fn start(
564 port: u16,
565 tunnel_secret: String,
566 server_secret: String,
567 agent_id: String,
568 handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
569 restart_tx: mpsc::Sender<()>,
570) -> Result<TunnelInfo> {
571 let worker_url = get_worker_url();
572
573 let agent_id_clone = agent_id.clone();
574 let tunnel_secret_clone = tunnel_secret.clone();
575 let server_secret_clone = server_secret;
576
577 let task = tokio::spawn(async move {
578 run_single_connection(
579 port,
580 agent_id_clone,
581 tunnel_secret_clone,
582 server_secret_clone,
583 restart_tx,
584 )
585 .await;
586 });
587
588 *handle.write().await = Some(task);
589
590 let public_url = format!("{}/tunnel/{}", worker_url, agent_id);
591 let hostname = Url::parse(&worker_url)?
592 .host_str()
593 .unwrap_or("")
594 .to_string();
595
596 Ok(TunnelInfo {
597 state: super::TunnelState::Running,
598 url: public_url,
599 hostname,
600 secret: tunnel_secret,
601 })
602}
603
604pub async fn stop(handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>) {
605 if let Some(task) = handle.write().await.take() {
606 task.abort();
607 info!("Lapstone tunnel stopped");
608 }
609}