1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
5use futures::AsyncWriteExt;
6use futures::future::poll_fn;
7use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
8use tokio::net::TcpStream;
9use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
10use yamux::{Config, Connection, Mode};
11
12use crate::log as tlog;
13use crate::protocol::{self, ControlMsg};
14use crate::proxy;
15use crate::store;
16use crate::update;
17
/// Delay before the first reconnect attempt, and the value the delay resets to.
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Upper bound for the exponentially growing reconnect delay.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Maximum number of body bytes buffered in memory for inspection (10 MiB).
const BODY_CAP: usize = 10 << 20;
21
/// Borrowed settings for one tunnel session.
pub struct TunnelOpts<'a> {
    /// Port of the local service that tunnelled requests are forwarded to.
    pub local_port: u16,
    /// Host of the local service that tunnelled requests are forwarded to.
    pub local_host: &'a str,
    /// Host name or address of the tunnel server.
    pub server_addr: &'a str,
    /// TCP port of the tunnel server.
    pub server_port: u16,
    /// Shared secret used to compute the handshake HMAC; an empty token sends no HMAC.
    pub token: &'a str,
    /// Subdomain sent to the server in the handshake, if any.
    pub subdomain: Option<&'a str>,
    /// Credentials that, base64-encoded, incoming requests must present as
    /// `Authorization: Basic …`; `None` disables the check.
    pub auth: Option<&'a str>,
    /// When `true`, requests and responses are buffered and logged for inspection.
    pub inspect: bool,
}
32
33pub async fn run(opts: TunnelOpts<'_>) -> Result<()> {
34 store::init();
35 let expected_auth = opts.auth.map(|a| format!("Basic {}", B64.encode(a)));
36 let mut backoff = INITIAL_BACKOFF;
37
38 loop {
39 tlog::info(&format!(
40 "connecting to {}:{}...",
41 opts.server_addr, opts.server_port
42 ));
43
44 let attempt_start = std::time::Instant::now();
45 match connect_and_tunnel(&opts, expected_auth.as_deref()).await {
46 Ok(()) => {
47 tlog::info("connection closed");
48 break;
49 }
50 Err(e) => {
51 tlog::error(&format!("{e:#}"));
52 tlog::info(&format!("reconnecting in {}s...", backoff.as_secs()));
53 tokio::time::sleep(backoff).await;
54 if attempt_start.elapsed() > Duration::from_secs(5) {
56 backoff = INITIAL_BACKOFF;
57 } else {
58 backoff = (backoff * 2).min(MAX_BACKOFF);
59 }
60 }
61 }
62 }
63
64 Ok(())
65}
66
67async fn connect_and_tunnel(opts: &TunnelOpts<'_>, expected_auth: Option<&str>) -> Result<()> {
68 let TunnelOpts {
69 local_port,
70 local_host,
71 server_addr,
72 server_port,
73 token,
74 subdomain,
75 inspect,
76 ..
77 } = opts;
78 let socket = tokio::time::timeout(
79 Duration::from_secs(10),
80 TcpStream::connect(format!("{server_addr}:{server_port}")),
81 )
82 .await
83 .context("connection timed out")?
84 .context("failed to connect to server")?;
85
86 let mut config = Config::default();
87 config.set_split_send_size(16 * 1024);
88
89 let mut connection = Connection::new(socket.compat(), config, Mode::Client);
90
91 let mut control_stream = poll_fn(|cx| connection.poll_new_outbound(cx))
92 .await
93 .context("failed to open control stream")?;
94
95 let (inbound_tx, mut inbound_rx) = tokio::sync::mpsc::channel::<yamux::Stream>(32);
96 tokio::spawn(async move {
97 loop {
98 match poll_fn(|cx| connection.poll_next_inbound(cx)).await {
99 Some(Ok(stream)) => {
100 if inbound_tx.send(stream).await.is_err() {
101 break;
102 }
103 }
104 Some(Err(e)) => {
105 tlog::error(&format!("yamux: {e}"));
106 break;
107 }
108 None => break,
109 }
110 }
111 });
112
113 let nonce = {
114 use rand::Rng;
115 let bytes: [u8; 32] = rand::rng().random();
116 B64.encode(bytes)
117 };
118 let hmac = if token.is_empty() {
119 None
120 } else {
121 Some(protocol::compute_hmac(token, &nonce))
122 };
123 let auth = ControlMsg::Auth {
124 subdomain: subdomain.map(|s| s.to_string()),
125 nonce,
126 hmac,
127 };
128 control_stream.write_all(&auth.encode()?).await?;
129 control_stream.flush().await?;
130
131 let resp = protocol::read_msg(&mut control_stream).await?;
132 let tunnel_url = match resp {
133 ControlMsg::AuthOk { url, .. } => url,
134 ControlMsg::Error { message } => anyhow::bail!("server error: {message}"),
135 _ => anyhow::bail!("unexpected response from server"),
136 };
137
138 let display_url = if *server_addr == "127.0.0.1" || *server_addr == "localhost" {
139 format!("http://{tunnel_url}")
140 } else {
141 format!("https://{tunnel_url}")
142 };
143 tlog::banner(&display_url, local_host, *local_port, *inspect);
144 update::check_in_background();
145
146 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
147 tokio::spawn(async move {
148 tokio::signal::ctrl_c().await.ok();
149 eprintln!();
150 tlog::info("shutting down...");
151 let _ = shutdown_tx.send(());
152 });
153
154 loop {
155 tokio::select! {
156 biased;
157 _ = &mut shutdown_rx => {
158 control_stream.close().await.ok();
159 break;
160 }
161 stream = inbound_rx.recv() => {
162 match stream {
163 Some(s) => { tokio::spawn(handle_stream(s, *local_port, local_host.to_string(), expected_auth.map(|s| s.to_string()), *inspect)); }
164 None => break,
165 }
166 }
167 }
168 }
169
170 Ok(())
171}
172
173async fn handle_stream(
174 stream: yamux::Stream,
175 local_port: u16,
176 local_host: String,
177 expected_auth: Option<String>,
178 inspect: bool,
179) {
180 if let Err(e) = proxy_to_local(
181 stream,
182 local_port,
183 &local_host,
184 expected_auth.as_deref(),
185 inspect,
186 )
187 .await
188 {
189 tlog::error(&format!("proxy: {e}"));
190 }
191}
192
193async fn proxy_to_local(
194 stream: yamux::Stream,
195 local_port: u16,
196 local_host: &str,
197 expected_auth: Option<&str>,
198 inspect: bool,
199) -> Result<()> {
200 let mut tunnel = stream.compat();
201
202 let req_head = proxy::read_http_head(&mut tunnel).await?;
203 let (method, path) = proxy::parse_request_line(&req_head);
204 let start = std::time::Instant::now();
205 let id = store::next_id();
206
207 let head_end = proxy::headers_end(&req_head).unwrap_or(req_head.len());
208 let req_headers = &req_head[..head_end];
209 let body_prefix = req_head[head_end..].to_vec();
210 let content_length = proxy::parse_content_length(req_headers);
211
212 if let Some(expected) = expected_auth {
213 let provided = proxy::extract_authorization(req_headers);
214 if provided.as_deref() != Some(expected) {
215 proxy::write_401(&mut tunnel).await.ok();
216 tlog::request(&method, &path, 401, start.elapsed().as_millis() as u64, id);
217 return Ok(());
218 }
219 }
220
221 let mut local = match TcpStream::connect(format!("{local_host}:{local_port}")).await {
222 Ok(s) => s,
223 Err(_) => {
224 proxy::write_502(&mut tunnel).await.ok();
225 tlog::request(&method, &path, 502, start.elapsed().as_millis() as u64, id);
226 return Ok(());
227 }
228 };
229
230 if inspect {
231 let req_body = read_body_exact(&mut tunnel, body_prefix, content_length).await;
232
233 local.write_all(req_headers).await?;
234 local.write_all(&req_body).await?;
235 local.flush().await?;
236
237 store::store(store::StoredRequest {
238 id,
239 port: local_port,
240 method: method.clone(),
241 path: path.clone(),
242 raw_headers: String::from_utf8_lossy(req_headers).into_owned(),
243 body_b64: B64.encode(&req_body),
244 });
245
246 let resp_head = proxy::read_http_head(&mut local).await.unwrap_or_default();
247 let status = proxy::parse_response_status(&resp_head);
248 let resp_head_end = proxy::headers_end(&resp_head).unwrap_or(resp_head.len());
249 let resp_headers = &resp_head[..resp_head_end];
250 let resp_already = resp_head[resp_head_end..].to_vec();
251
252 let resp_cl = proxy::parse_content_length(resp_headers);
253 let resp_body = if resp_cl > 0 {
254 read_body_exact(&mut local, resp_already, resp_cl).await
255 } else if proxy::is_chunked(resp_headers) {
256 read_chunked(&mut local, resp_already).await
257 } else {
258 resp_already
259 };
260
261 let out_headers = if proxy::is_chunked(resp_headers) {
262 rebuild_resp_headers(resp_headers, resp_body.len())
263 } else {
264 resp_headers.to_vec()
265 };
266
267 tunnel.write_all(&out_headers).await?;
268 tunnel.write_all(&resp_body).await?;
269 tunnel.flush().await.ok();
270
271 let elapsed = start.elapsed().as_millis() as u64;
272 tlog::request(&method, &path, status, elapsed, id);
273
274 let req_raw = String::from_utf8_lossy(req_headers);
275 let req_body_str = String::from_utf8_lossy(&req_body);
276 tlog::inspect_request(id, &req_raw, &req_body_str);
277
278 let resp_raw = String::from_utf8_lossy(&out_headers);
279 let resp_body_str = String::from_utf8_lossy(&resp_body);
280 tlog::inspect_response(status, &resp_raw, &resp_body_str, id);
281 } else {
282 local.write_all(req_headers).await?;
283 local.write_all(&body_prefix).await?;
284 local.flush().await?;
285
286 store::store(store::StoredRequest {
289 id,
290 port: local_port,
291 method: method.clone(),
292 path: path.clone(),
293 raw_headers: String::from_utf8_lossy(req_headers).into_owned(),
294 body_b64: B64.encode(&body_prefix),
295 });
296
297 let mut peek = [0u8; 512];
298 let n = local.read(&mut peek).await.unwrap_or(0);
299 let status = proxy::parse_response_status(&peek[..n]);
300 tunnel.write_all(&peek[..n]).await?;
301 tokio::io::copy_bidirectional(&mut local, &mut tunnel)
302 .await
303 .ok();
304 tlog::request(
305 &method,
306 &path,
307 status,
308 start.elapsed().as_millis() as u64,
309 id,
310 );
311 }
312
313 Ok(())
314}
315
316async fn read_body_exact<R: tokio::io::AsyncRead + Unpin>(
317 reader: &mut R,
318 mut buf: Vec<u8>,
319 total: usize,
320) -> Vec<u8> {
321 let target = total.min(BODY_CAP);
322 let mut tmp = [0u8; 8192];
323 while buf.len() < target {
324 let want = (target - buf.len()).min(8192);
325 match reader.read(&mut tmp[..want]).await {
326 Ok(0) | Err(_) => break,
327 Ok(n) => buf.extend_from_slice(&tmp[..n]),
328 }
329 }
330 buf
331}
332
333async fn read_chunked<R: tokio::io::AsyncRead + Unpin>(
334 reader: &mut R,
335 initial: Vec<u8>,
336) -> Vec<u8> {
337 let mut raw = initial;
338 let mut body = Vec::new();
339 let mut tmp = [0u8; 8192];
340
341 'outer: loop {
342 let mut pos = 0;
343 loop {
344 let slice = &raw[pos..];
345 let Some(crlf) = slice.windows(2).position(|w| w == b"\r\n") else {
346 break;
347 };
348 let size_str = std::str::from_utf8(&slice[..crlf])
349 .unwrap_or("0")
350 .split(';')
351 .next()
352 .unwrap_or("0")
353 .trim();
354 let chunk_size = usize::from_str_radix(size_str, 16).unwrap_or(0);
355 if chunk_size == 0 {
356 let after_size_line = pos + crlf + 2;
357 if after_size_line + 2 <= raw.len() {
358 break 'outer;
359 }
360 break;
361 }
362 let data_start = pos + crlf + 2;
363 let data_end = data_start + chunk_size;
364 if data_end + 2 > raw.len() {
365 break;
366 }
367 body.extend_from_slice(&raw[data_start..data_end]);
368 pos = data_end + 2;
369 if body.len() >= BODY_CAP {
370 break 'outer;
371 }
372 }
373 raw.drain(..pos);
374 match reader.read(&mut tmp).await {
375 Ok(0) | Err(_) => break,
376 Ok(n) => raw.extend_from_slice(&tmp[..n]),
377 }
378 }
379 body
380}
381
/// Rewrites a response header block for a body that has been de-chunked.
///
/// Every `Transfer-Encoding` and `Content-Length` line is dropped (matched
/// case-insensitively), and a single `Content-Length: <body_len>` line is
/// inserted directly after the first line. Empty lines are skipped and the
/// result always ends with the blank line that terminates the header block.
///
/// Lines are handled as raw bytes, so header values that are not valid UTF-8
/// are forwarded unchanged.
fn rebuild_resp_headers(headers: &[u8], body_len: usize) -> Vec<u8> {
    /// ASCII case-insensitive `starts_with` for byte slices.
    fn has_prefix_ignore_case(line: &[u8], prefix: &[u8]) -> bool {
        line.len() >= prefix.len() && line[..prefix.len()].eq_ignore_ascii_case(prefix)
    }

    let mut out = Vec::with_capacity(headers.len() + 32);
    let mut rest = headers;
    let mut index = 0usize;

    loop {
        // Split off the next CRLF-terminated line; the final segment has no
        // terminator and ends the loop.
        let (line, tail) = match rest.windows(2).position(|w| w == b"\r\n") {
            Some(at) => (&rest[..at], Some(&rest[at + 2..])),
            None => (rest, None),
        };

        let dropped = line.is_empty()
            || has_prefix_ignore_case(line, b"transfer-encoding:")
            || has_prefix_ignore_case(line, b"content-length:");
        if !dropped {
            out.extend_from_slice(line);
            out.extend_from_slice(b"\r\n");
            if index == 0 {
                // The new length goes right after the status line.
                out.extend_from_slice(format!("Content-Length: {body_len}\r\n").as_bytes());
            }
        }

        index += 1;
        match tail {
            Some(next) => rest = next,
            None => break,
        }
    }

    out.extend_from_slice(b"\r\n");
    out
}
402
403pub async fn replay(id: u64) -> Result<()> {
404 let req = store::find(id).ok_or_else(|| anyhow::anyhow!("request #{id} not found"))?;
405
406 tlog::info(&format!("replaying #{id}: {} {}", req.method, req.path));
407
408 let mut local = TcpStream::connect(format!("127.0.0.1:{}", req.port))
409 .await
410 .with_context(|| format!("failed to connect to localhost:{}", req.port))?;
411
412 local.write_all(req.raw_headers.as_bytes()).await?;
413 let body = B64.decode(&req.body_b64).unwrap_or_default();
414 local.write_all(&body).await?;
415 local.flush().await?;
416
417 let resp_head = proxy::read_http_head(&mut local).await.unwrap_or_default();
418 let resp_head_end = proxy::headers_end(&resp_head).unwrap_or(resp_head.len());
419 let resp_headers = &resp_head[..resp_head_end];
420 let resp_already = resp_head[resp_head_end..].to_vec();
421
422 let resp_cl = proxy::parse_content_length(resp_headers);
423 let resp_body = if resp_cl > 0 {
424 read_body_exact(&mut local, resp_already, resp_cl).await
425 } else if proxy::is_chunked(resp_headers) {
426 read_chunked(&mut local, resp_already).await
427 } else {
428 resp_already
429 };
430
431 let body_str = String::from_utf8_lossy(&resp_body);
432 let status = proxy::parse_response_status(&resp_head);
433 tlog::success(&format!("replayed #{id} → {status}"));
434 if !body_str.trim().is_empty() {
435 eprintln!("{body_str}");
436 }
437 Ok(())
438}