1use std::collections::VecDeque;
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use dashmap::DashMap;
8use futures::future::poll_fn;
9use futures::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::TcpListener;
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
13use yamux::{Config, Connection, Mode};
14
15use crate::log as tlog;
16use crate::protocol::{self, ControlMsg};
17use crate::proxy;
18
/// Upper bound on the number of tunnels registered at the same time, across all clients.
const MAX_TUNNELS_TOTAL: usize = 200;
/// Upper bound on the number of simultaneously active tunnels from one source IP.
const MAX_TUNNELS_PER_IP: usize = 5;
/// Number of control connections accepted from one IP within a single `RATE_WINDOW`.
const MAX_CONNECTS_PER_MINUTE: usize = 15;
/// Length of the sliding window used for per-IP connection rate limiting.
const RATE_WINDOW: Duration = Duration::from_secs(60);
23
/// One-shot reply channel on which a client's connection task returns the outcome of
/// opening a new outbound yamux stream: the stream itself, or the error.
type OpenStreamReply = oneshot::Sender<Result<yamux::Stream>>;
25
/// Registry entry for a connected tunnel client.
struct ClientHandle {
    /// Request channel to the task that drives this client's yamux connection; each
    /// message asks that task to open one outbound stream and reply on the given sender.
    stream_tx: mpsc::Sender<OpenStreamReply>,
}
29
/// Shared map of live tunnels, keyed by subdomain.
type Registry = Arc<DashMap<String, ClientHandle>>;
31
/// Per-source-IP bookkeeping used for rate limiting and tunnel quotas.
struct IpState {
    /// Timestamps of recently accepted control connections, oldest first.
    recent_connects: VecDeque<Instant>,
    /// Number of tunnels this IP currently has registered.
    active_tunnels: usize,
}
36
/// Shared map from client IP address to its [`IpState`].
type IpTracker = Arc<DashMap<IpAddr, IpState>>;
38
39pub async fn run(
40 control_port: u16,
41 http_port: u16,
42 domain: &str,
43 token: Option<&str>,
44) -> Result<()> {
45 let registry: Registry = Arc::new(DashMap::new());
46 let ip_tracker: IpTracker = Arc::new(DashMap::new());
47 let domain = domain.to_string();
48 let token = token.map(|t| t.to_string());
49
50 let control_listener = TcpListener::bind(format!("0.0.0.0:{control_port}"))
51 .await
52 .context(format!("failed to bind control port {control_port}"))?;
53 let http_listener = TcpListener::bind(format!("0.0.0.0:{http_port}"))
54 .await
55 .context(format!("failed to bind http port {http_port}"))?;
56
57 tlog::info(&format!("control listening on 0.0.0.0:{control_port}"));
58 tlog::info(&format!("http listening on 0.0.0.0:{http_port}"));
59 tlog::info(&format!("domain: *.{domain}"));
60 if token.is_none() {
61 tlog::info("token auth disabled — open server");
62 }
63
64 let reg = registry.clone();
65 let ipt = ip_tracker.clone();
66 let tok = token.clone();
67 let dom = domain.clone();
68 let control_task = tokio::spawn(async move {
69 loop {
70 match control_listener.accept().await {
71 Ok((socket, addr)) => {
72 let ip = addr.ip();
73
74 if !check_rate_limit(&ipt, ip) {
75 tlog::info(&format!(
76 "rate limited {ip} (>{MAX_CONNECTS_PER_MINUTE}/min)"
77 ));
78 drop(socket);
79 continue;
80 }
81
82 tlog::info(&format!("client connected from {addr}"));
83 let reg = reg.clone();
84 let ipt = ipt.clone();
85 let tok = tok.clone();
86 let dom = dom.clone();
87 tokio::spawn(async move {
88 if let Err(e) =
89 handle_client(socket, reg, ipt, ip, tok.as_deref(), &dom).await
90 {
91 tlog::error(&format!("client {addr}: {e}"));
92 }
93 });
94 }
95 Err(e) => tlog::error(&format!("accept error: {e}")),
96 }
97 }
98 });
99
100 let reg = registry.clone();
101 let dom = domain.clone();
102 let http_task = tokio::spawn(async move {
103 loop {
104 match http_listener.accept().await {
105 Ok((socket, _)) => {
106 let reg = reg.clone();
107 let dom = dom.clone();
108 tokio::spawn(async move {
109 if let Err(e) = handle_http(socket, reg, &dom).await {
110 tlog::error(&format!("http: {e}"));
111 }
112 });
113 }
114 Err(e) => tlog::error(&format!("http accept error: {e}")),
115 }
116 }
117 });
118
119 tokio::select! {
120 r = control_task => r?,
121 r = http_task => r?,
122 }
123
124 Ok(())
125}
126
127fn check_rate_limit(ip_tracker: &IpTracker, ip: IpAddr) -> bool {
130 let now = Instant::now();
131 let cutoff = now - RATE_WINDOW;
132 let mut entry = ip_tracker.entry(ip).or_insert_with(|| IpState {
133 recent_connects: VecDeque::new(),
134 active_tunnels: 0,
135 });
136 while entry.recent_connects.front().is_some_and(|t| *t < cutoff) {
137 entry.recent_connects.pop_front();
138 }
139 if entry.recent_connects.len() >= MAX_CONNECTS_PER_MINUTE {
140 return false;
141 }
142 entry.recent_connects.push_back(now);
143 true
144}
145
146async fn handle_client(
147 socket: tokio::net::TcpStream,
148 registry: Registry,
149 ip_tracker: IpTracker,
150 peer_ip: IpAddr,
151 expected_token: Option<&str>,
152 domain: &str,
153) -> Result<()> {
154 let mut config = Config::default();
155 config.set_split_send_size(16 * 1024);
156
157 let mut connection = Connection::new(socket.compat(), config, Mode::Server);
158
159 let mut control_stream = poll_fn(|cx| connection.poll_next_inbound(cx))
160 .await
161 .context("no control stream")??;
162
163 let (stream_tx, mut stream_rx) = mpsc::channel::<OpenStreamReply>(32);
164 let conn_task = tokio::spawn(async move {
165 loop {
166 tokio::select! {
167 biased;
168 Some(reply_tx) = stream_rx.recv() => {
169 let result = poll_fn(|cx| connection.poll_new_outbound(cx)).await;
170 let _ = reply_tx.send(result.map_err(|e| anyhow::anyhow!("{e}")));
171 }
172 inbound = poll_fn(|cx| connection.poll_next_inbound(cx)) => {
173 match inbound {
174 Some(Ok(_)) => {}
175 Some(Err(e)) => {
176 tlog::error(&format!("yamux error: {e}"));
177 break;
178 }
179 None => break,
180 }
181 }
182 }
183 }
184 });
185
186 let msg = protocol::read_msg(&mut control_stream).await?;
187 let (requested_subdomain, nonce, provided_hmac) = match msg {
188 ControlMsg::Auth {
189 subdomain,
190 nonce,
191 hmac,
192 } => (subdomain, nonce, hmac),
193 _ => anyhow::bail!("expected Auth message"),
194 };
195
196 if let Some(secret) = expected_token {
197 let expected_hmac = protocol::compute_hmac(secret, &nonce);
198 if provided_hmac.as_deref() != Some(expected_hmac.as_str()) {
199 let encoded = ControlMsg::Error {
200 message: "invalid secret".into(),
201 }
202 .encode()?;
203 control_stream.write_all(&encoded).await.ok();
204 control_stream.close().await.ok();
205 anyhow::bail!("invalid secret from {peer_ip}");
206 }
207 }
208
209 let ip_active = ip_tracker
210 .get(&peer_ip)
211 .map(|s| s.active_tunnels)
212 .unwrap_or(0);
213 if ip_active >= MAX_TUNNELS_PER_IP {
214 let encoded = ControlMsg::Error {
215 message: format!("max {MAX_TUNNELS_PER_IP} tunnels per IP"),
216 }
217 .encode()?;
218 control_stream.write_all(&encoded).await.ok();
219 control_stream.close().await.ok();
220 anyhow::bail!("{peer_ip} hit per-IP tunnel limit");
221 }
222
223 if registry.len() >= MAX_TUNNELS_TOTAL {
224 let encoded = ControlMsg::Error {
225 message: "server at capacity, try again later".into(),
226 }
227 .encode()?;
228 control_stream.write_all(&encoded).await.ok();
229 control_stream.close().await.ok();
230 anyhow::bail!("global tunnel limit reached");
231 }
232
233 let subdomain = requested_subdomain
234 .filter(|s| {
235 !s.is_empty() && s.len() <= 63 && s.chars().all(|c| c.is_alphanumeric() || c == '-')
236 })
237 .unwrap_or_else(|| {
238 use rand::Rng;
239 let mut rng = rand::rng();
240 format!("{:08x}", rng.random::<u32>())
241 });
242
243 if registry.contains_key(&subdomain) {
244 let encoded = ControlMsg::Error {
245 message: format!("subdomain '{subdomain}' already in use"),
246 }
247 .encode()?;
248 control_stream.write_all(&encoded).await.ok();
249 control_stream.close().await.ok();
250 anyhow::bail!("subdomain collision: {subdomain}");
251 }
252
253 let full_domain = format!("{subdomain}.{domain}");
254 let ok = ControlMsg::AuthOk {
255 subdomain: subdomain.clone(),
256 url: full_domain.clone(),
257 };
258 control_stream.write_all(&ok.encode()?).await?;
259 control_stream.flush().await?;
260
261 registry.insert(subdomain.clone(), ClientHandle { stream_tx });
262 ip_tracker
263 .entry(peer_ip)
264 .and_modify(|s| s.active_tunnels += 1);
265
266 tlog::success(&format!(
267 "tunnel live: {full_domain} (ip={peer_ip}, active={}/{MAX_TUNNELS_PER_IP})",
268 ip_active + 1
269 ));
270
271 let mut buf = [0u8; 1024];
272 loop {
273 match control_stream.read(&mut buf).await {
274 Ok(0) | Err(_) => break,
275 Ok(_) => {}
276 }
277 }
278
279 registry.remove(&subdomain);
280 ip_tracker.entry(peer_ip).and_modify(|s| {
281 if s.active_tunnels > 0 {
282 s.active_tunnels -= 1;
283 }
284 });
285 conn_task.abort();
286
287 tlog::info(&format!("client disconnected, removed {full_domain}"));
288
289 Ok(())
290}
291
/// Contents of `../install.sh`, embedded at compile time and served by `handle_http`
/// at `/install.sh` on the base domain.
const INSTALL_SH: &str = include_str!("../install.sh");
293
294async fn handle_http(
295 mut socket: tokio::net::TcpStream,
296 registry: Registry,
297 domain: &str,
298) -> Result<()> {
299 let head = proxy::read_http_head(&mut socket).await?;
300 let host = proxy::extract_host(&head).context("no Host header")?;
301
302 if host == domain {
303 let (_, path) = proxy::parse_request_line(&head);
304 let response: Vec<u8> = if path == "/install.sh" {
305 format!(
306 "HTTP/1.1 200 OK\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
307 INSTALL_SH.len(),
308 INSTALL_SH
309 )
310 .into_bytes()
311 } else {
312 b"HTTP/1.1 301 Moved Permanently\r\nLocation: https://github.com/jbingen/tnnl\r\nContent-Length: 0\r\nConnection: close\r\n\r\n".to_vec()
313 };
314 tokio::io::AsyncWriteExt::write_all(&mut socket, &response)
315 .await
316 .ok();
317 return Ok(());
318 }
319
320 let subdomain = host
321 .strip_suffix(&format!(".{domain}"))
322 .context(format!("host '{host}' not a subdomain of {domain}"))?
323 .to_string();
324
325 let stream_tx = match registry.get(&subdomain) {
326 Some(entry) => entry.stream_tx.clone(),
327 None => {
328 proxy::write_404(&mut socket).await.ok();
329 return Ok(());
330 }
331 };
332
333 let (reply_tx, reply_rx) = oneshot::channel();
334 stream_tx
335 .send(reply_tx)
336 .await
337 .map_err(|_| anyhow::anyhow!("client disconnected"))?;
338
339 let tunnel_stream = reply_rx
340 .await
341 .map_err(|_| anyhow::anyhow!("client disconnected"))??;
342
343 let mut tunnel_compat = tunnel_stream.compat();
344
345 tokio::io::AsyncWriteExt::write_all(&mut tunnel_compat, &head).await?;
346
347 tokio::io::copy_bidirectional(&mut socket, &mut tunnel_compat)
348 .await
349 .ok();
350
351 Ok(())
352}