borev6_cli/
server.rs

1//! Server implementation for the `bore` service.
2
3use std::{io, net::SocketAddrV6, net::Ipv6Addr, ops::RangeInclusive, sync::Arc, time::Duration};
4
5use anyhow::Result;
6use dashmap::DashMap;
7use tokio::io::AsyncWriteExt;
8use tokio::net::{TcpListener, TcpStream};
9use tokio::time::{sleep, timeout};
10use tracing::{info, info_span, warn, Instrument};
11use uuid::Uuid;
12
13use crate::auth::Authenticator;
14use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
15
16/// State structure for the server.
17pub struct Server {
18    /// Range of TCP ports that can be forwarded.
19    port_range: RangeInclusive<u16>,
20
21    /// Optional secret used to authenticate clients.
22    auth: Option<Authenticator>,
23
24    /// Concurrent map of IDs to incoming connections.
25    conns: Arc<DashMap<Uuid, TcpStream>>,
26}
27
28impl Server {
29    /// Create a new server with a specified minimum port number.
30    pub fn new(port_range: RangeInclusive<u16>, secret: Option<&str>) -> Self {
31        assert!(!port_range.is_empty(), "must provide at least one port");
32        Server {
33            port_range,
34            conns: Arc::new(DashMap::new()),
35            auth: secret.map(Authenticator::new),
36        }
37    }
38
39    /// Start the server, listening for new connections.
40    pub async fn listen(self) -> Result<()> {
41        let this = Arc::new(self);
42        let addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, CONTROL_PORT, 0, 0);
43        let listener = TcpListener::bind(&addr).await?;
44        info!(?addr, "server listening");
45
46        loop {
47            let (stream, addr) = listener.accept().await?;
48            let this = Arc::clone(&this);
49            tokio::spawn(
50                async move {
51                    info!("incoming connection");
52                    if let Err(err) = this.handle_connection(stream).await {
53                        warn!(%err, "connection exited with error");
54                    } else {
55                        info!("connection exited");
56                    }
57                }
58                .instrument(info_span!("control", ?addr)),
59            );
60        }
61    }
62
63    async fn create_listener(&self, port: u16) -> Result<TcpListener, &'static str> {
64        let try_bind = |port: u16| async move {
65            let addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0);
66            info!(?addr, "bind address");
67            TcpListener::bind(&addr).await
68                .map_err(|err| match err.kind() {
69                    io::ErrorKind::AddrInUse => "port already in use",
70                    io::ErrorKind::PermissionDenied => "permission denied",
71                    _ => "failed to bind to port",
72                })
73        };
74        if port > 0 {
75            // Client requests a specific port number.
76            if !self.port_range.contains(&port) {
77                return Err("client port number not in allowed range");
78            }
79            try_bind(port).await
80        } else {
81            // Client requests any available port in range.
82            //
83            // In this case, we bind to 150 random port numbers. We choose this value because in
84            // order to find a free port with probability at least 1-δ, when ε proportion of the
85            // ports are currently available, it suffices to check approximately -2 ln(δ) / ε
86            // independently and uniformly chosen ports (up to a second-order term in ε).
87            //
88            // Checking 150 times gives us 99.999% success at utilizing 85% of ports under these
89            // conditions, when ε=0.15 and δ=0.00001.
90            for _ in 0..150 {
91                let port = fastrand::u16(self.port_range.clone());
92                match try_bind(port).await {
93                    Ok(listener) => return Ok(listener),
94                    Err(_) => continue,
95                }
96            }
97            Err("failed to find an available port")
98        }
99    }
100
101    async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
102        let mut stream = Delimited::new(stream);
103        if let Some(auth) = &self.auth {
104            if let Err(err) = auth.server_handshake(&mut stream).await {
105                warn!(%err, "server handshake failed");
106                stream.send(ServerMessage::Error(err.to_string())).await?;
107                return Ok(());
108            }
109        }
110
111        match stream.recv_timeout().await? {
112            Some(ClientMessage::Authenticate(_)) => {
113                warn!("unexpected authenticate");
114                Ok(())
115            }
116            Some(ClientMessage::Hello(port)) => {
117                let listener = match self.create_listener(port).await {
118                    Ok(listener) => listener,
119                    Err(err) => {
120                        stream.send(ServerMessage::Error(err.into())).await?;
121                        return Ok(());
122                    }
123                };
124                let port = listener.local_addr()?.port();
125                info!(?port, "new client");
126                stream.send(ServerMessage::Hello(port)).await?;
127
128                loop {
129                    if stream.send(ServerMessage::Heartbeat).await.is_err() {
130                        // Assume that the TCP connection has been dropped.
131                        return Ok(());
132                    }
133                    const TIMEOUT: Duration = Duration::from_millis(500);
134                    if let Ok(result) = timeout(TIMEOUT, listener.accept()).await {
135                        let (stream2, addr) = result?;
136                        info!(?addr, ?port, "new connection");
137
138                        let id = Uuid::new_v4();
139                        let conns = Arc::clone(&self.conns);
140
141                        conns.insert(id, stream2);
142                        tokio::spawn(async move {
143                            // Remove stale entries to avoid memory leaks.
144                            sleep(Duration::from_secs(10)).await;
145                            if conns.remove(&id).is_some() {
146                                warn!(%id, "removed stale connection");
147                            }
148                        });
149                        stream.send(ServerMessage::Connection(id)).await?;
150                    }
151                }
152            }
153            Some(ClientMessage::Accept(id)) => {
154                info!(%id, "forwarding connection");
155                match self.conns.remove(&id) {
156                    Some((_, mut stream2)) => {
157                        let parts = stream.into_parts();
158                        debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
159                        stream2.write_all(&parts.read_buf).await?;
160                        proxy(parts.io, stream2).await?
161                    }
162                    None => warn!(%id, "missing connection"),
163                }
164                Ok(())
165            }
166            None => Ok(()),
167        }
168    }
169}