Skip to main content

shunt/
live_relay.rs

1//! Persistent tunnel relay server — `shunt relay serve`.
2//!
3//! Runs on the VPS. Two roles:
4//!   1. Accepts WebSocket tunnel connections from shunt instances (`GET /tunnel`).
5//!   2. Proxies HTTP requests to the right tunnel based on the `Host` header.
6//!
7//! Multi-tenant from day one: all state is keyed by subdomain. Adding users later
8//! is a token-registry change (env var → SQLite), not a protocol change.
9
10use anyhow::Result;
11use axum::{
12    body::Body,
13    extract::{
14        ws::{Message, WebSocket, WebSocketUpgrade},
15        State,
16    },
17    http::{Request, StatusCode},
18    response::{IntoResponse, Response},
19    routing::get,
20    Router,
21};
22use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
23use bytes::Bytes;
24use futures_util::{SinkExt, StreamExt};
25use parking_lot::RwLock;
26use serde::{Deserialize, Serialize};
27use std::{collections::HashMap, sync::Arc};
28use tokio::sync::mpsc;
29use tokio_stream::wrappers::ReceiverStream;
30
31// ---------------------------------------------------------------------------
32// Protocol frames (JSON over WebSocket text messages)
33// ---------------------------------------------------------------------------
34
35/// Client → Relay
36#[derive(Debug, Deserialize)]
37#[serde(tag = "type", rename_all = "snake_case")]
38enum ClientFrame {
39    Register { subdomain: String, token: String },
40    ResHead { id: String, status: u16, headers: HashMap<String, String> },
41    ResBody { id: String, data: String }, // base64
42    ResEnd  { id: String },
43    ResErr  { id: String, message: String },
44}
45
46/// Relay → Client
47#[derive(Debug, Serialize)]
48#[serde(tag = "type", rename_all = "snake_case")]
49enum RelayFrame<'a> {
50    Ack  { subdomain: &'a str },
51    Deny { reason: &'a str },
52    Req  {
53        id:      &'a str,
54        method:  &'a str,
55        path:    &'a str,
56        headers: &'a HashMap<String, String>,
57        body:    &'a str, // base64
58    },
59}
60
61// ---------------------------------------------------------------------------
62// Relay state
63// ---------------------------------------------------------------------------
64
65#[derive(Clone)]
66pub struct RelayState {
67    tunnels:       Arc<RwLock<HashMap<String, TunnelHandle>>>,
68    allowed_token: Arc<String>,
69    // Future: replace allowed_token with Arc<RwLock<HashMap<token, subdomain>>>
70    //         loaded from SQLite for multi-user.
71}
72
73#[derive(Clone)]
74struct TunnelHandle {
75    tx: mpsc::Sender<TunnelRequest>,
76}
77
78struct TunnelRequest {
79    id:      String,
80    method:  String,
81    path:    String,
82    headers: HashMap<String, String>,
83    body:    Bytes,
84    res_tx:  mpsc::Sender<ResponseChunk>,
85}
86
87#[derive(Debug)]
88enum ResponseChunk {
89    Head   { status: u16, headers: HashMap<String, String> },
90    Body   (Bytes),
91    End,
92    Err    (String),
93}
94
95// ---------------------------------------------------------------------------
96// Entry point
97// ---------------------------------------------------------------------------
98
99pub async fn run_relay_server(port: u16, token: String) -> Result<()> {
100    let state = RelayState {
101        tunnels:       Arc::new(RwLock::new(HashMap::new())),
102        allowed_token: Arc::new(token),
103    };
104
105    let app = Router::new()
106        .route("/tunnel", get(ws_handler))
107        .fallback(proxy_handler)
108        .with_state(state);
109
110    let addr = format!("0.0.0.0:{port}");
111    println!("  ◆ shunt relay  listening on {addr}");
112    let listener = tokio::net::TcpListener::bind(&addr).await?;
113    axum::serve(listener, app).await?;
114    Ok(())
115}
116
117// ---------------------------------------------------------------------------
118// WebSocket tunnel handler
119// ---------------------------------------------------------------------------
120
121async fn ws_handler(
122    ws: WebSocketUpgrade,
123    State(state): State<RelayState>,
124) -> Response {
125    ws.on_upgrade(move |socket| handle_tunnel(socket, state))
126}
127
128async fn handle_tunnel(socket: WebSocket, state: RelayState) {
129    let (mut sink, mut stream) = socket.split();
130
131    // ── Step 1: expect a Register frame ─────────────────────────────────────
132    let subdomain = loop {
133        match stream.next().await {
134            Some(Ok(Message::Text(text))) => {
135                match serde_json::from_str::<ClientFrame>(&text) {
136                    Ok(ClientFrame::Register { subdomain, token }) => {
137                        if token != *state.allowed_token {
138                            let _ = sink.send(Message::Text(
139                                serde_json::to_string(&RelayFrame::Deny { reason: "invalid token" }).unwrap()
140                            )).await;
141                            return;
142                        }
143                        let _ = sink.send(Message::Text(
144                            serde_json::to_string(&RelayFrame::Ack { subdomain: &subdomain }).unwrap()
145                        )).await;
146                        break subdomain;
147                    }
148                    _ => { return; } // unexpected frame before registration
149                }
150            }
151            _ => return,
152        }
153    };
154
155    // ── Step 2: register tunnel ──────────────────────────────────────────────
156    let (tunnel_tx, mut tunnel_rx) = mpsc::channel::<TunnelRequest>(16);
157    state.tunnels.write().insert(subdomain.clone(), TunnelHandle { tx: tunnel_tx });
158    println!("  ◆ tunnel registered: {subdomain}");
159
160    // ── Step 3: sender task (relay → tunnel) ─────────────────────────────────
161    // Pending requests: id → response channel
162    let pending: Arc<RwLock<HashMap<String, mpsc::Sender<ResponseChunk>>>> =
163        Arc::new(RwLock::new(HashMap::new()));
164
165    // Channel for outbound WS messages (from both sender task and request handlers)
166    let (ws_tx, mut ws_rx) = mpsc::channel::<Message>(64);
167
168    // Flush outbound messages to the WS sink
169    let ws_tx_clone = ws_tx.clone();
170    tokio::spawn(async move {
171        while let Some(msg) = ws_rx.recv().await {
172            if sink.send(msg).await.is_err() { break; }
173        }
174    });
175
176    // Forward incoming TunnelRequests to the WS
177    let ws_tx2 = ws_tx_clone.clone();
178    let pending2 = pending.clone();
179    tokio::spawn(async move {
180        while let Some(req) = tunnel_rx.recv().await {
181            pending2.write().insert(req.id.clone(), req.res_tx);
182            let body_b64 = B64.encode(&req.body);
183            let frame = RelayFrame::Req {
184                id:      &req.id,
185                method:  &req.method,
186                path:    &req.path,
187                headers: &req.headers,
188                body:    &body_b64,
189            };
190            let text = serde_json::to_string(&frame).unwrap();
191            if ws_tx2.send(Message::Text(text)).await.is_err() { break; }
192        }
193    });
194
195    // ── Step 4: reader loop (tunnel → relay) ─────────────────────────────────
196    while let Some(Ok(msg)) = stream.next().await {
197        let text = match msg {
198            Message::Text(t) => t,
199            Message::Close(_) => break,
200            _ => continue,
201        };
202        let frame = match serde_json::from_str::<ClientFrame>(&text) {
203            Ok(f) => f,
204            Err(_) => continue,
205        };
206        match frame {
207            ClientFrame::ResHead { id, status, headers } => {
208                // Clone tx before awaiting — parking_lot guard must not cross await points
209                let tx = pending.read().get(&id).cloned();
210                if let Some(tx) = tx {
211                    let _ = tx.send(ResponseChunk::Head { status, headers }).await;
212                }
213            }
214            ClientFrame::ResBody { id, data } => {
215                let tx = pending.read().get(&id).cloned();
216                if let Some(tx) = tx {
217                    if let Ok(bytes) = B64.decode(&data) {
218                        let _ = tx.send(ResponseChunk::Body(Bytes::from(bytes))).await;
219                    }
220                }
221            }
222            ClientFrame::ResEnd { id } => {
223                let tx = pending.write().remove(&id);
224                if let Some(tx) = tx {
225                    let _ = tx.send(ResponseChunk::End).await;
226                }
227            }
228            ClientFrame::ResErr { id, message } => {
229                let tx = pending.write().remove(&id);
230                if let Some(tx) = tx {
231                    let _ = tx.send(ResponseChunk::Err(message)).await;
232                }
233            }
234            ClientFrame::Register { .. } => {} // ignore duplicate registration
235        }
236    }
237
238    // ── Cleanup ───────────────────────────────────────────────────────────────
239    state.tunnels.write().remove(&subdomain);
240    println!("  · tunnel disconnected: {subdomain}");
241}
242
243// ---------------------------------------------------------------------------
244// HTTP proxy handler
245// ---------------------------------------------------------------------------
246
247async fn proxy_handler(
248    State(state): State<RelayState>,
249    req: Request<Body>,
250) -> Response {
251    // Extract subdomain from Host header: "shunt.ramcharan.shop" → "shunt"
252    let subdomain = match extract_subdomain(req.headers()) {
253        Some(s) => s,
254        None => return (StatusCode::BAD_REQUEST, "missing Host header").into_response(),
255    };
256
257    // Find tunnel
258    let handle = state.tunnels.read().get(&subdomain).cloned();
259    let handle = match handle {
260        Some(h) => h,
261        None => return (
262            StatusCode::BAD_GATEWAY,
263            format!("no tunnel connected for '{subdomain}'"),
264        ).into_response(),
265    };
266
267    // Build request fields to send through tunnel
268    let id = uuid::Uuid::new_v4().to_string();
269    let method = req.method().to_string();
270    let path = req.uri().path_and_query()
271        .map(|p| p.as_str().to_owned())
272        .unwrap_or_else(|| "/".to_owned());
273    let headers: HashMap<String, String> = req.headers().iter()
274        .filter_map(|(k, v)| {
275            let key = k.as_str().to_lowercase();
276            // Don't forward hop-by-hop headers
277            if matches!(key.as_str(), "host" | "connection" | "transfer-encoding" | "upgrade") {
278                return None;
279            }
280            v.to_str().ok().map(|v| (key, v.to_owned()))
281        })
282        .collect();
283    let body = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
284        Ok(b) => b,
285        Err(_) => return (StatusCode::BAD_REQUEST, "failed to read body").into_response(),
286    };
287
288    // Send request to tunnel
289    let (res_tx, res_rx) = mpsc::channel::<ResponseChunk>(32);
290    let tunnel_req = TunnelRequest { id, method, path, headers, body, res_tx };
291    if handle.tx.send(tunnel_req).await.is_err() {
292        return (StatusCode::BAD_GATEWAY, "tunnel send failed").into_response();
293    }
294
295    // Wait for response head
296    let mut rx = res_rx;
297    let (status, res_headers) = match rx.recv().await {
298        Some(ResponseChunk::Head { status, headers }) => (status, headers),
299        Some(ResponseChunk::Err(e)) => return (StatusCode::BAD_GATEWAY, e).into_response(),
300        _ => return (StatusCode::BAD_GATEWAY, "no response from tunnel").into_response(),
301    };
302
303    // Build streaming body from remaining chunks
304    let stream = ReceiverStream::new(rx).filter_map(|chunk| async move {
305        match chunk {
306            ResponseChunk::Body(b) => Some(Ok::<_, std::convert::Infallible>(b)),
307            ResponseChunk::End | ResponseChunk::Head { .. } | ResponseChunk::Err(_) => None,
308        }
309    });
310
311    let mut builder = Response::builder()
312        .status(status);
313    for (k, v) in &res_headers {
314        builder = builder.header(k, v);
315    }
316    builder.body(Body::from_stream(stream)).unwrap_or_else(|_| {
317        (StatusCode::INTERNAL_SERVER_ERROR, "response build failed").into_response()
318    })
319}
320
321// ---------------------------------------------------------------------------
322// Helpers
323// ---------------------------------------------------------------------------
324
325fn extract_subdomain(headers: &axum::http::HeaderMap) -> Option<String> {
326    let host = headers.get("host")?.to_str().ok()?;
327    // "shunt.ramcharan.shop" → "shunt"
328    // "shunt.ramcharan.shop:8085" → "shunt"
329    let host = host.split(':').next()?;
330    let subdomain = host.split('.').next()?;
331    if subdomain.is_empty() { return None; }
332    Some(subdomain.to_owned())
333}