use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum RelayFrame {
Ack { subdomain: String },
Deny { reason: String },
Req {
id: String,
method: String,
path: String,
headers: HashMap<String, String>,
body: String, },
}
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientFrame<'a> {
Register { subdomain: &'a str, token: &'a str },
ResHead { id: &'a str, status: u16, headers: &'a HashMap<String, String> },
ResBody { id: &'a str, data: &'a str },
ResEnd { id: &'a str },
ResErr { id: &'a str, message: &'a str },
}
pub async fn run_live(
relay_ws_url: &str,
subdomain: &str,
token: &str,
local_url: &str,
) -> Result<()> {
let mut backoff = Duration::from_secs(2);
loop {
match connect_and_serve(relay_ws_url, subdomain, token, local_url).await {
Ok(()) => {
println!(" · Tunnel closed.");
break;
}
Err(e) => {
let secs = backoff.as_secs();
println!(" · Tunnel disconnected ({e}), reconnecting in {secs}s…");
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(Duration::from_secs(60));
}
}
}
Ok(())
}
async fn connect_and_serve(
relay_ws_url: &str,
subdomain: &str,
token: &str,
local_url: &str,
) -> Result<()> {
let (ws, _) = tokio_tungstenite::connect_async(relay_ws_url)
.await
.context("Failed to connect to relay")?;
let (sink, mut stream) = ws.split();
let sink = Arc::new(Mutex::new(sink));
{
let frame = ClientFrame::Register { subdomain, token };
let text = serde_json::to_string(&frame)?;
sink.lock().await.send(Message::Text(text)).await?;
}
match stream.next().await {
Some(Ok(Message::Text(text))) => {
let frame: RelayFrame = serde_json::from_str(&text)
.context("Invalid relay response")?;
match frame {
RelayFrame::Ack { subdomain: s } => {
println!(" ✓ Tunnel connected: {s}");
}
RelayFrame::Deny { reason } => {
anyhow::bail!("Relay denied registration: {reason}");
}
_ => anyhow::bail!("Unexpected relay response"),
}
}
_ => anyhow::bail!("No response from relay during registration"),
}
let http = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()?;
while let Some(msg) = stream.next().await {
let msg = msg?;
let text = match msg {
Message::Text(t) => t,
Message::Close(_) => break,
Message::Ping(d) => {
let _ = sink.lock().await.send(Message::Pong(d)).await;
continue;
}
_ => continue,
};
let frame: RelayFrame = match serde_json::from_str(&text) {
Ok(f) => f,
Err(_) => continue,
};
if let RelayFrame::Req { id, method, path, headers, body } = frame {
let sink = sink.clone();
let http = http.clone();
let local = local_url.to_owned();
tokio::spawn(async move {
handle_request(id, method, path, headers, body, &local, &http, &sink).await;
});
}
}
Ok(())
}
type WsSink = Arc<Mutex<futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>
>,
Message,
>>>;
async fn handle_request(
id: String,
method: String,
path: String,
headers: HashMap<String, String>,
body_b64: String,
local_url: &str,
http: &reqwest::Client,
sink: &WsSink,
) {
let send_err = |msg: &str| {
let frame = serde_json::to_string(&ClientFrame::ResErr { id: &id, message: msg }).unwrap();
async move {
let _ = sink.lock().await.send(Message::Text(frame)).await;
}
};
let body_bytes = match B64.decode(&body_b64) {
Ok(b) => b,
Err(e) => { send_err(&format!("base64 decode: {e}")).await; return; }
};
let url = format!("{local_url}{path}");
let method = match method.parse::<reqwest::Method>() {
Ok(m) => m,
Err(e) => { send_err(&format!("bad method: {e}")).await; return; }
};
let mut req = http.request(method, &url);
for (k, v) in &headers {
req = req.header(k, v);
}
req = req.body(body_bytes);
let resp = match req.send().await {
Ok(r) => r,
Err(e) => { send_err(&format!("local request failed: {e}")).await; return; }
};
let status = resp.status().as_u16();
let resp_headers: HashMap<String, String> = resp.headers().iter()
.filter_map(|(k, v)| v.to_str().ok().map(|v| (k.as_str().to_owned(), v.to_owned())))
.collect();
{
let frame = serde_json::to_string(&ClientFrame::ResHead {
id: &id, status, headers: &resp_headers,
}).unwrap();
if sink.lock().await.send(Message::Text(frame)).await.is_err() { return; }
}
let mut stream = resp.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
let data = B64.encode(&bytes);
let frame = serde_json::to_string(&ClientFrame::ResBody {
id: &id, data: &data,
}).unwrap();
if sink.lock().await.send(Message::Text(frame)).await.is_err() { return; }
}
Err(e) => {
send_err(&format!("body stream error: {e}")).await;
return;
}
}
}
let frame = serde_json::to_string(&ClientFrame::ResEnd { id: &id }).unwrap();
let _ = sink.lock().await.send(Message::Text(frame)).await;
}