1use std::net::{IpAddr, Ipv4Addr};
4use std::{io, ops::RangeInclusive, sync::Arc, time::Duration};
5
6use anyhow::Result;
7use dashmap::DashMap;
8use tokio::io::AsyncWriteExt;
9use tokio::net::{TcpListener, TcpStream};
10use tokio::time::{sleep, timeout};
11use tracing::{info, info_span, warn, Instrument};
12use uuid::Uuid;
13
14use crate::auth::Authenticator;
15use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
16
17pub struct Server {
19 port_range: RangeInclusive<u16>,
21
22 auth: Option<Authenticator>,
24
25 conns: Arc<DashMap<Uuid, TcpStream>>,
27
28 bind_addr: IpAddr,
30
31 bind_tunnels: IpAddr,
33}
34
35impl Server {
36 pub fn new(port_range: RangeInclusive<u16>, secret: Option<&str>) -> Self {
38 assert!(!port_range.is_empty(), "must provide at least one port");
39 Server {
40 port_range,
41 conns: Arc::new(DashMap::new()),
42 auth: secret.map(Authenticator::new),
43 bind_addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
44 bind_tunnels: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
45 }
46 }
47
48 pub fn set_bind_addr(&mut self, bind_addr: IpAddr) {
50 self.bind_addr = bind_addr;
51 }
52
53 pub fn set_bind_tunnels(&mut self, bind_tunnels: IpAddr) {
55 self.bind_tunnels = bind_tunnels;
56 }
57
58 pub async fn listen(self) -> Result<()> {
60 let this = Arc::new(self);
61 let listener = TcpListener::bind((this.bind_addr, CONTROL_PORT)).await?;
62 info!(addr = ?this.bind_addr, "server listening");
63
64 loop {
65 let (stream, addr) = listener.accept().await?;
66 let this = Arc::clone(&this);
67 tokio::spawn(
68 async move {
69 info!("incoming connection");
70 if let Err(err) = this.handle_connection(stream).await {
71 warn!(%err, "connection exited with error");
72 } else {
73 info!("connection exited");
74 }
75 }
76 .instrument(info_span!("control", ?addr)),
77 );
78 }
79 }
80
81 async fn create_listener(&self, port: u16) -> Result<TcpListener, &'static str> {
82 let try_bind = |port: u16| async move {
83 TcpListener::bind((self.bind_tunnels, port))
84 .await
85 .map_err(|err| match err.kind() {
86 io::ErrorKind::AddrInUse => "port already in use",
87 io::ErrorKind::PermissionDenied => "permission denied",
88 _ => "failed to bind to port",
89 })
90 };
91 if port > 0 {
92 if !self.port_range.contains(&port) {
94 return Err("client port number not in allowed range");
95 }
96 try_bind(port).await
97 } else {
98 for _ in 0..150 {
108 let port = fastrand::u16(self.port_range.clone());
109 match try_bind(port).await {
110 Ok(listener) => return Ok(listener),
111 Err(_) => continue,
112 }
113 }
114 Err("failed to find an available port")
115 }
116 }
117
118 async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
119 let mut stream = Delimited::new(stream);
120 if let Some(auth) = &self.auth {
121 if let Err(err) = auth.server_handshake(&mut stream).await {
122 warn!(%err, "server handshake failed");
123 stream.send(ServerMessage::Error(err.to_string())).await?;
124 return Ok(());
125 }
126 }
127
128 match stream.recv_timeout().await? {
129 Some(ClientMessage::Authenticate(_)) => {
130 warn!("unexpected authenticate");
131 Ok(())
132 }
133 Some(ClientMessage::Hello(port)) => {
134 let listener = match self.create_listener(port).await {
135 Ok(listener) => listener,
136 Err(err) => {
137 stream.send(ServerMessage::Error(err.into())).await?;
138 return Ok(());
139 }
140 };
141 let host = listener.local_addr()?.ip();
142 let port = listener.local_addr()?.port();
143 info!(?host, ?port, "new client");
144 stream.send(ServerMessage::Hello(port)).await?;
145
146 loop {
147 if stream.send(ServerMessage::Heartbeat).await.is_err() {
148 return Ok(());
150 }
151 const TIMEOUT: Duration = Duration::from_millis(500);
152 if let Ok(result) = timeout(TIMEOUT, listener.accept()).await {
153 let (stream2, addr) = result?;
154 info!(?addr, ?port, "new connection");
155
156 let id = Uuid::new_v4();
157 let conns = Arc::clone(&self.conns);
158
159 conns.insert(id, stream2);
160 tokio::spawn(async move {
161 sleep(Duration::from_secs(10)).await;
163 if conns.remove(&id).is_some() {
164 warn!(%id, "removed stale connection");
165 }
166 });
167 stream.send(ServerMessage::Connection(id)).await?;
168 }
169 }
170 }
171 Some(ClientMessage::Accept(id)) => {
172 info!(%id, "forwarding connection");
173 match self.conns.remove(&id) {
174 Some((_, mut stream2)) => {
175 let mut parts = stream.into_parts();
176 debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
177 stream2.write_all(&parts.read_buf).await?;
178 tokio::io::copy_bidirectional(&mut parts.io, &mut stream2).await?;
179 }
180 None => warn!(%id, "missing connection"),
181 }
182 Ok(())
183 }
184 None => Ok(()),
185 }
186 }
187}