Skip to main content

aster_server/tunnel/
lapstone.rs

1use super::TunnelInfo;
2use anyhow::{Context, Result};
3use futures::{SinkExt, StreamExt};
4use reqwest;
5use serde::{Deserialize, Serialize};
6use socket2::{SockRef, TcpKeepalive};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{mpsc, RwLock};
11use tokio::task::JoinHandle;
12use tokio_tungstenite::{connect_async, tungstenite::Message};
13use tracing::{error, info, warn};
14use url::Url;
15
/// Compares two secrets without exiting early on the first differing byte.
///
/// Returns `true` only when `a` and `b` are byte-for-byte equal. Comparing the
/// bytes directly has no false positives, unlike comparing hash digests. The
/// comparison also does not stop at the first mismatch, so its running time
/// does not reveal how long a correct prefix a caller has guessed. The length
/// check does return early, which reveals only whether the lengths match.
fn secure_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    if a.len() != b.len() {
        return false;
    }
    // OR together every differing bit instead of returning early; `black_box`
    // discourages the optimizer from reintroducing a short-circuit.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| acc | std::hint::black_box(x ^ y));
    diff == 0
}
31
/// Default tunnel worker base URL; `get_worker_url` prefers the
/// `ASTER_TUNNEL_WORKER_URL` environment variable when it is set.
const WORKER_URL: &str = "https://cloudflare-tunnel-proxy.michael-neale.workers.dev";
/// Seconds without a received text/ping/pong frame before the connection is
/// dropped and re-established (five minutes).
const IDLE_TIMEOUT_SECS: u64 = 5 * 60;
/// Seconds allowed for the WebSocket handshake with the worker.
const CONNECTION_TIMEOUT_SECS: u64 = 30;
/// Maximum number of body bytes carried by a single response message; larger
/// buffered bodies are split into several chunked messages.
const MAX_WS_SIZE: usize = 900 * 1_000;
36
37fn get_worker_url() -> String {
38    std::env::var("ASTER_TUNNEL_WORKER_URL")
39        .ok()
40        .unwrap_or_else(|| WORKER_URL.to_string())
41}
42
43type WebSocketSender = Arc<
44    RwLock<
45        Option<
46            futures::stream::SplitSink<
47                tokio_tungstenite::WebSocketStream<
48                    tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
49                >,
50                Message,
51            >,
52        >,
53    >,
54>;
55
56#[derive(Debug, Serialize, Deserialize)]
57struct TunnelMessage {
58    #[serde(rename = "requestId")]
59    request_id: String,
60    method: String,
61    path: String,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    headers: Option<HashMap<String, String>>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    body: Option<String>,
66}
67
68#[derive(Debug, Serialize)]
69struct TunnelResponse {
70    #[serde(rename = "requestId")]
71    request_id: String,
72    status: u16,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    headers: Option<HashMap<String, String>>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    body: Option<String>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    error: Option<String>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    #[serde(rename = "chunkIndex")]
81    chunk_index: Option<usize>,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    #[serde(rename = "totalChunks")]
84    total_chunks: Option<usize>,
85    #[serde(rename = "isChunked")]
86    is_chunked: bool,
87    #[serde(rename = "isStreaming")]
88    is_streaming: bool,
89    #[serde(rename = "isFirstChunk")]
90    is_first_chunk: bool,
91    #[serde(rename = "isLastChunk")]
92    is_last_chunk: bool,
93}
94
95fn validate_and_build_request(
96    client: &reqwest::Client,
97    url: &str,
98    message: &TunnelMessage,
99    tunnel_secret: &str,
100    server_secret: &str,
101) -> Result<reqwest::RequestBuilder> {
102    let incoming_secret = message
103        .headers
104        .as_ref()
105        .and_then(|h| {
106            h.iter()
107                .find(|(k, _)| k.eq_ignore_ascii_case("x-secret-key"))
108                .map(|(_, v)| v)
109        })
110        .ok_or_else(|| anyhow::anyhow!("Missing tunnel secret header"))?;
111
112    if !secure_compare(incoming_secret, tunnel_secret) {
113        anyhow::bail!("Invalid tunnel secret");
114    }
115
116    let mut request_builder = match message.method.as_str() {
117        "GET" => client.get(url),
118        "POST" => client.post(url),
119        "PUT" => client.put(url),
120        "DELETE" => client.delete(url),
121        "PATCH" => client.patch(url),
122        _ => client.get(url),
123    };
124
125    if let Some(headers) = &message.headers {
126        for (key, value) in headers {
127            if key.eq_ignore_ascii_case("x-secret-key") {
128                continue;
129            }
130            request_builder = request_builder.header(key, value);
131        }
132    }
133
134    request_builder = request_builder.header("X-Secret-Key", server_secret);
135
136    if let Some(body) = &message.body {
137        if message.method != "GET" && message.method != "HEAD" {
138            request_builder = request_builder.body(body.clone());
139        }
140    }
141
142    Ok(request_builder)
143}
144
145async fn handle_streaming_response(
146    response: reqwest::Response,
147    status: u16,
148    headers_map: HashMap<String, String>,
149    request_id: String,
150    message_path: String,
151    ws_tx: WebSocketSender,
152) -> Result<()> {
153    info!("← {} {} [{}] (streaming)", status, message_path, request_id);
154
155    let mut stream = response.bytes_stream();
156    let mut chunk_index = 0;
157    let mut is_first_chunk = true;
158
159    while let Some(chunk_result) = stream.next().await {
160        match chunk_result {
161            Ok(chunk) => {
162                let chunk_str = String::from_utf8_lossy(&chunk).to_string();
163                let tunnel_response = TunnelResponse {
164                    request_id: request_id.clone(),
165                    status,
166                    headers: if is_first_chunk {
167                        Some(headers_map.clone())
168                    } else {
169                        None
170                    },
171                    body: Some(chunk_str),
172                    error: None,
173                    chunk_index: Some(chunk_index),
174                    total_chunks: None,
175                    is_chunked: false,
176                    is_streaming: true,
177                    is_first_chunk,
178                    is_last_chunk: false,
179                };
180                send_response(ws_tx.clone(), tunnel_response).await?;
181                chunk_index += 1;
182                is_first_chunk = false;
183            }
184            Err(e) => {
185                error!("Error reading stream chunk: {}", e);
186                break;
187            }
188        }
189    }
190
191    let tunnel_response = TunnelResponse {
192        request_id: request_id.clone(),
193        status,
194        headers: None,
195        body: Some(String::new()),
196        error: None,
197        chunk_index: Some(chunk_index),
198        total_chunks: None,
199        is_chunked: false,
200        is_streaming: true,
201        is_first_chunk: false,
202        is_last_chunk: true,
203    };
204    send_response(ws_tx, tunnel_response).await?;
205    info!(
206        "← {} {} [{}] (complete, {} chunks)",
207        status, message_path, request_id, chunk_index
208    );
209    Ok(())
210}
211
212async fn handle_chunked_response(
213    body: String,
214    status: u16,
215    headers_map: HashMap<String, String>,
216    request_id: String,
217    message_path: String,
218    ws_tx: WebSocketSender,
219) -> Result<()> {
220    let total_chunks = body.len().div_ceil(MAX_WS_SIZE);
221    info!(
222        "← {} {} [{}] ({} bytes, {} chunks)",
223        status,
224        message_path,
225        request_id,
226        body.len(),
227        total_chunks
228    );
229
230    for (i, chunk) in body.as_bytes().chunks(MAX_WS_SIZE).enumerate() {
231        let chunk_str = String::from_utf8_lossy(chunk).to_string();
232        let tunnel_response = TunnelResponse {
233            request_id: request_id.clone(),
234            status,
235            headers: if i == 0 {
236                Some(headers_map.clone())
237            } else {
238                None
239            },
240            body: Some(chunk_str),
241            error: None,
242            chunk_index: Some(i),
243            total_chunks: Some(total_chunks),
244            is_chunked: true,
245            is_streaming: false,
246            is_first_chunk: false,
247            is_last_chunk: false,
248        };
249        send_response(ws_tx.clone(), tunnel_response).await?;
250    }
251    Ok(())
252}
253
254async fn handle_request(
255    message: TunnelMessage,
256    port: u16,
257    ws_tx: WebSocketSender,
258    tunnel_secret: String,
259    server_secret: String,
260) -> Result<()> {
261    let request_id = message.request_id.clone();
262
263    let client = reqwest::Client::new();
264    let url = format!("http://127.0.0.1:{}{}", port, message.path);
265
266    let request_builder =
267        match validate_and_build_request(&client, &url, &message, &tunnel_secret, &server_secret) {
268            Ok(builder) => builder,
269            Err(e) => {
270                error!("✗ Authentication error [{}]: {}", request_id, e);
271                let error_response = TunnelResponse {
272                    request_id,
273                    status: 401,
274                    headers: None,
275                    body: None,
276                    error: Some(e.to_string()),
277                    chunk_index: None,
278                    total_chunks: None,
279                    is_chunked: false,
280                    is_streaming: false,
281                    is_first_chunk: false,
282                    is_last_chunk: false,
283                };
284                send_response(ws_tx, error_response).await?;
285                return Ok(());
286            }
287        };
288
289    let response = match request_builder.send().await {
290        Ok(resp) => resp,
291        Err(e) => {
292            error!("✗ Request error [{}]: {}", request_id, e);
293            let error_response = TunnelResponse {
294                request_id,
295                status: 500,
296                headers: None,
297                body: None,
298                error: Some(e.to_string()),
299                chunk_index: None,
300                total_chunks: None,
301                is_chunked: false,
302                is_streaming: false,
303                is_first_chunk: false,
304                is_last_chunk: false,
305            };
306            send_response(ws_tx, error_response).await?;
307            return Ok(());
308        }
309    };
310
311    let status = response.status().as_u16();
312    // Normalize header names to lowercase per RFC 7230 (HTTP headers are case-insensitive)
313    let headers_map: HashMap<String, String> = response
314        .headers()
315        .iter()
316        .map(|(k, v)| {
317            (
318                k.as_str().to_lowercase(),
319                v.to_str().unwrap_or("").to_string(),
320            )
321        })
322        .collect();
323
324    let is_streaming = headers_map
325        .get("content-type")
326        .map(|ct| ct.contains("text/event-stream"))
327        .unwrap_or(false);
328
329    if is_streaming {
330        handle_streaming_response(
331            response,
332            status,
333            headers_map,
334            request_id,
335            message.path,
336            ws_tx,
337        )
338        .await?;
339    } else {
340        let body = response.text().await.unwrap_or_default();
341
342        if body.len() > MAX_WS_SIZE {
343            handle_chunked_response(body, status, headers_map, request_id, message.path, ws_tx)
344                .await?;
345        } else {
346            let tunnel_response = TunnelResponse {
347                request_id: request_id.clone(),
348                status,
349                headers: Some(headers_map),
350                body: Some(body),
351                error: None,
352                chunk_index: None,
353                total_chunks: None,
354                is_chunked: false,
355                is_streaming: false,
356                is_first_chunk: false,
357                is_last_chunk: false,
358            };
359            send_response(ws_tx, tunnel_response).await?;
360        }
361    }
362
363    Ok(())
364}
365
366async fn send_response(ws_tx: WebSocketSender, response: TunnelResponse) -> Result<()> {
367    let json = serde_json::to_string(&response)?;
368    if let Some(tx) = ws_tx.write().await.as_mut() {
369        tx.send(Message::Text(json.into()))
370            .await
371            .context("Failed to send response")?;
372    }
373    Ok(())
374}
375
376fn configure_tcp_keepalive(
377    stream: &tokio_tungstenite::WebSocketStream<
378        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
379    >,
380) {
381    let tcp_stream = stream.get_ref().get_ref();
382    let socket_ref = SockRef::from(tcp_stream);
383
384    let keepalive = TcpKeepalive::new()
385        .with_time(Duration::from_secs(30))
386        .with_interval(Duration::from_secs(30));
387
388    if let Err(e) = socket_ref.set_tcp_keepalive(&keepalive) {
389        warn!("Failed to set TCP keep-alive: {}", e);
390    } else {
391        info!("✓ TCP keep-alive enabled (30s interval)");
392    }
393}
394
395async fn handle_websocket_messages(
396    mut read: futures::stream::SplitStream<
397        tokio_tungstenite::WebSocketStream<
398            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
399        >,
400    >,
401    ws_tx: WebSocketSender,
402    port: u16,
403    tunnel_secret: String,
404    server_secret: String,
405    last_activity: Arc<RwLock<Instant>>,
406    active_tasks: Arc<RwLock<Vec<JoinHandle<()>>>>,
407) {
408    while let Some(msg) = read.next().await {
409        match msg {
410            Ok(Message::Text(text)) => {
411                *last_activity.write().await = Instant::now();
412
413                match serde_json::from_str::<TunnelMessage>(&text) {
414                    Ok(tunnel_msg) => {
415                        let ws_tx_clone = ws_tx.clone();
416                        let tunnel_secret_clone = tunnel_secret.clone();
417                        let server_secret_clone = server_secret.clone();
418                        let task = tokio::spawn(async move {
419                            if let Err(e) = handle_request(
420                                tunnel_msg,
421                                port,
422                                ws_tx_clone,
423                                tunnel_secret_clone,
424                                server_secret_clone,
425                            )
426                            .await
427                            {
428                                error!("Error handling request: {}", e);
429                            }
430                        });
431                        {
432                            let mut tasks = active_tasks.write().await;
433                            tasks.retain(|t| !t.is_finished());
434                            tasks.push(task);
435                        }
436                    }
437                    Err(e) => {
438                        error!("Error parsing tunnel message: {}", e);
439                    }
440                }
441            }
442            Ok(Message::Close(_)) => {
443                info!("✗ Connection closed by server");
444                break;
445            }
446            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
447                *last_activity.write().await = Instant::now();
448            }
449            Err(e) => {
450                error!("✗ WebSocket error: {}", e);
451                break;
452            }
453            _ => {}
454        }
455    }
456}
457
458async fn cleanup_connection(
459    ws_tx: WebSocketSender,
460    active_tasks: Arc<RwLock<Vec<JoinHandle<()>>>>,
461) {
462    if let Some(mut tx) = ws_tx.write().await.take() {
463        let _ = tx.close().await;
464    }
465
466    let tasks = active_tasks.write().await.drain(..).collect::<Vec<_>>();
467    info!("Aborting {} active request tasks", tasks.len());
468    for task in tasks {
469        task.abort();
470    }
471}
472
473async fn run_single_connection(
474    port: u16,
475    agent_id: String,
476    tunnel_secret: String,
477    server_secret: String,
478    restart_tx: mpsc::Sender<()>,
479) {
480    let _ = rustls::crypto::ring::default_provider().install_default();
481
482    let worker_url = get_worker_url();
483    let ws_url = worker_url
484        .replace("https://", "wss://")
485        .replace("http://", "ws://");
486
487    let url = format!("{}/connect?agent_id={}", ws_url, agent_id);
488
489    info!("Connecting to {}...", url);
490
491    let ws_stream = match tokio::time::timeout(
492        Duration::from_secs(CONNECTION_TIMEOUT_SECS),
493        connect_async(url.clone()),
494    )
495    .await
496    {
497        Ok(Ok((stream, _))) => {
498            configure_tcp_keepalive(&stream);
499            stream
500        }
501        Ok(Err(e)) => {
502            error!("✗ WebSocket connection error: {}", e);
503            let _ = restart_tx.send(()).await;
504            return;
505        }
506        Err(_) => {
507            error!(
508                "✗ WebSocket connection timeout after {}s",
509                CONNECTION_TIMEOUT_SECS
510            );
511            let _ = restart_tx.send(()).await;
512            return;
513        }
514    };
515
516    info!("✓ Connected as agent: {}", agent_id);
517    info!("✓ Proxying to: http://127.0.0.1:{}", port);
518    let public_url = format!("{}/tunnel/{}", worker_url, agent_id);
519    info!("✓ Public URL: {}", public_url);
520
521    let (write, read) = ws_stream.split();
522    let ws_tx: WebSocketSender = Arc::new(RwLock::new(Some(write)));
523    let last_activity = Arc::new(RwLock::new(Instant::now()));
524    let active_tasks: Arc<RwLock<Vec<JoinHandle<()>>>> = Arc::new(RwLock::new(Vec::new()));
525
526    let last_activity_clone = last_activity.clone();
527    let idle_task = async move {
528        loop {
529            tokio::time::sleep(Duration::from_secs(60)).await;
530            let elapsed = last_activity_clone.read().await.elapsed();
531            if elapsed > Duration::from_secs(IDLE_TIMEOUT_SECS) {
532                warn!(
533                    "No activity for {} minutes, forcing reconnect",
534                    IDLE_TIMEOUT_SECS / 60
535                );
536                break;
537            }
538        }
539    };
540
541    tokio::select! {
542        _ = idle_task => {
543            info!("✗ Idle timeout triggered");
544        }
545        _ = handle_websocket_messages(
546            read,
547            ws_tx.clone(),
548            port,
549            tunnel_secret.clone(),
550            server_secret.clone(),
551            last_activity,
552            active_tasks.clone()
553        ) => {
554            info!("✗ Connection ended");
555        }
556    }
557
558    cleanup_connection(ws_tx, active_tasks).await;
559
560    let _ = restart_tx.send(()).await;
561}
562
563pub async fn start(
564    port: u16,
565    tunnel_secret: String,
566    server_secret: String,
567    agent_id: String,
568    handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
569    restart_tx: mpsc::Sender<()>,
570) -> Result<TunnelInfo> {
571    let worker_url = get_worker_url();
572
573    let agent_id_clone = agent_id.clone();
574    let tunnel_secret_clone = tunnel_secret.clone();
575    let server_secret_clone = server_secret;
576
577    let task = tokio::spawn(async move {
578        run_single_connection(
579            port,
580            agent_id_clone,
581            tunnel_secret_clone,
582            server_secret_clone,
583            restart_tx,
584        )
585        .await;
586    });
587
588    *handle.write().await = Some(task);
589
590    let public_url = format!("{}/tunnel/{}", worker_url, agent_id);
591    let hostname = Url::parse(&worker_url)?
592        .host_str()
593        .unwrap_or("")
594        .to_string();
595
596    Ok(TunnelInfo {
597        state: super::TunnelState::Running,
598        url: public_url,
599        hostname,
600        secret: tunnel_secret,
601    })
602}
603
604pub async fn stop(handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>) {
605    if let Some(task) = handle.write().await.take() {
606        task.abort();
607        info!("Lapstone tunnel stopped");
608    }
609}