1use std::{io, net::SocketAddrV6, net::Ipv6Addr, ops::RangeInclusive, sync::Arc, time::Duration};
4
5use anyhow::Result;
6use dashmap::DashMap;
7use tokio::io::AsyncWriteExt;
8use tokio::net::{TcpListener, TcpStream};
9use tokio::time::{sleep, timeout};
10use tracing::{info, info_span, warn, Instrument};
11use uuid::Uuid;
12
13use crate::auth::Authenticator;
14use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
15
16pub struct Server {
18 port_range: RangeInclusive<u16>,
20
21 auth: Option<Authenticator>,
23
24 conns: Arc<DashMap<Uuid, TcpStream>>,
26}
27
28impl Server {
29 pub fn new(port_range: RangeInclusive<u16>, secret: Option<&str>) -> Self {
31 assert!(!port_range.is_empty(), "must provide at least one port");
32 Server {
33 port_range,
34 conns: Arc::new(DashMap::new()),
35 auth: secret.map(Authenticator::new),
36 }
37 }
38
39 pub async fn listen(self) -> Result<()> {
41 let this = Arc::new(self);
42 let addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, CONTROL_PORT, 0, 0);
43 let listener = TcpListener::bind(&addr).await?;
44 info!(?addr, "server listening");
45
46 loop {
47 let (stream, addr) = listener.accept().await?;
48 let this = Arc::clone(&this);
49 tokio::spawn(
50 async move {
51 info!("incoming connection");
52 if let Err(err) = this.handle_connection(stream).await {
53 warn!(%err, "connection exited with error");
54 } else {
55 info!("connection exited");
56 }
57 }
58 .instrument(info_span!("control", ?addr)),
59 );
60 }
61 }
62
63 async fn create_listener(&self, port: u16) -> Result<TcpListener, &'static str> {
64 let try_bind = |port: u16| async move {
65 let addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0);
66 info!(?addr, "bind address");
67 TcpListener::bind(&addr).await
68 .map_err(|err| match err.kind() {
69 io::ErrorKind::AddrInUse => "port already in use",
70 io::ErrorKind::PermissionDenied => "permission denied",
71 _ => "failed to bind to port",
72 })
73 };
74 if port > 0 {
75 if !self.port_range.contains(&port) {
77 return Err("client port number not in allowed range");
78 }
79 try_bind(port).await
80 } else {
81 for _ in 0..150 {
91 let port = fastrand::u16(self.port_range.clone());
92 match try_bind(port).await {
93 Ok(listener) => return Ok(listener),
94 Err(_) => continue,
95 }
96 }
97 Err("failed to find an available port")
98 }
99 }
100
101 async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
102 let mut stream = Delimited::new(stream);
103 if let Some(auth) = &self.auth {
104 if let Err(err) = auth.server_handshake(&mut stream).await {
105 warn!(%err, "server handshake failed");
106 stream.send(ServerMessage::Error(err.to_string())).await?;
107 return Ok(());
108 }
109 }
110
111 match stream.recv_timeout().await? {
112 Some(ClientMessage::Authenticate(_)) => {
113 warn!("unexpected authenticate");
114 Ok(())
115 }
116 Some(ClientMessage::Hello(port)) => {
117 let listener = match self.create_listener(port).await {
118 Ok(listener) => listener,
119 Err(err) => {
120 stream.send(ServerMessage::Error(err.into())).await?;
121 return Ok(());
122 }
123 };
124 let port = listener.local_addr()?.port();
125 info!(?port, "new client");
126 stream.send(ServerMessage::Hello(port)).await?;
127
128 loop {
129 if stream.send(ServerMessage::Heartbeat).await.is_err() {
130 return Ok(());
132 }
133 const TIMEOUT: Duration = Duration::from_millis(500);
134 if let Ok(result) = timeout(TIMEOUT, listener.accept()).await {
135 let (stream2, addr) = result?;
136 info!(?addr, ?port, "new connection");
137
138 let id = Uuid::new_v4();
139 let conns = Arc::clone(&self.conns);
140
141 conns.insert(id, stream2);
142 tokio::spawn(async move {
143 sleep(Duration::from_secs(10)).await;
145 if conns.remove(&id).is_some() {
146 warn!(%id, "removed stale connection");
147 }
148 });
149 stream.send(ServerMessage::Connection(id)).await?;
150 }
151 }
152 }
153 Some(ClientMessage::Accept(id)) => {
154 info!(%id, "forwarding connection");
155 match self.conns.remove(&id) {
156 Some((_, mut stream2)) => {
157 let parts = stream.into_parts();
158 debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
159 stream2.write_all(&parts.read_buf).await?;
160 proxy(parts.io, stream2).await?
161 }
162 None => warn!(%id, "missing connection"),
163 }
164 Ok(())
165 }
166 None => Ok(()),
167 }
168 }
169}