bore_cli/
server.rs

1//! Server implementation for the `bore` service.
2
3use std::net::{IpAddr, Ipv4Addr};
4use std::{io, ops::RangeInclusive, sync::Arc, time::Duration};
5
6use anyhow::Result;
7use dashmap::DashMap;
8use tokio::io::AsyncWriteExt;
9use tokio::net::{TcpListener, TcpStream};
10use tokio::time::{sleep, timeout};
11use tracing::{info, info_span, warn, Instrument};
12use uuid::Uuid;
13
14use crate::auth::Authenticator;
15use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
16
17/// State structure for the server.
18pub struct Server {
19    /// Range of TCP ports that can be forwarded.
20    port_range: RangeInclusive<u16>,
21
22    /// Optional secret used to authenticate clients.
23    auth: Option<Authenticator>,
24
25    /// Concurrent map of IDs to incoming connections.
26    conns: Arc<DashMap<Uuid, TcpStream>>,
27
28    /// IP address where the control server will bind to.
29    bind_addr: IpAddr,
30
31    /// IP address where tunnels will listen on.
32    bind_tunnels: IpAddr,
33}
34
35impl Server {
36    /// Create a new server with a specified minimum port number.
37    pub fn new(port_range: RangeInclusive<u16>, secret: Option<&str>) -> Self {
38        assert!(!port_range.is_empty(), "must provide at least one port");
39        Server {
40            port_range,
41            conns: Arc::new(DashMap::new()),
42            auth: secret.map(Authenticator::new),
43            bind_addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
44            bind_tunnels: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
45        }
46    }
47
48    /// Set the IP address where tunnels will listen on.
49    pub fn set_bind_addr(&mut self, bind_addr: IpAddr) {
50        self.bind_addr = bind_addr;
51    }
52
53    /// Set the IP address where the control server will bind to.
54    pub fn set_bind_tunnels(&mut self, bind_tunnels: IpAddr) {
55        self.bind_tunnels = bind_tunnels;
56    }
57
58    /// Start the server, listening for new connections.
59    pub async fn listen(self) -> Result<()> {
60        let this = Arc::new(self);
61        let listener = TcpListener::bind((this.bind_addr, CONTROL_PORT)).await?;
62        info!(addr = ?this.bind_addr, "server listening");
63
64        loop {
65            let (stream, addr) = listener.accept().await?;
66            let this = Arc::clone(&this);
67            tokio::spawn(
68                async move {
69                    info!("incoming connection");
70                    if let Err(err) = this.handle_connection(stream).await {
71                        warn!(%err, "connection exited with error");
72                    } else {
73                        info!("connection exited");
74                    }
75                }
76                .instrument(info_span!("control", ?addr)),
77            );
78        }
79    }
80
81    async fn create_listener(&self, port: u16) -> Result<TcpListener, &'static str> {
82        let try_bind = |port: u16| async move {
83            TcpListener::bind((self.bind_tunnels, port))
84                .await
85                .map_err(|err| match err.kind() {
86                    io::ErrorKind::AddrInUse => "port already in use",
87                    io::ErrorKind::PermissionDenied => "permission denied",
88                    _ => "failed to bind to port",
89                })
90        };
91        if port > 0 {
92            // Client requests a specific port number.
93            if !self.port_range.contains(&port) {
94                return Err("client port number not in allowed range");
95            }
96            try_bind(port).await
97        } else {
98            // Client requests any available port in range.
99            //
100            // In this case, we bind to 150 random port numbers. We choose this value because in
101            // order to find a free port with probability at least 1-δ, when ε proportion of the
102            // ports are currently available, it suffices to check approximately -2 ln(δ) / ε
103            // independently and uniformly chosen ports (up to a second-order term in ε).
104            //
105            // Checking 150 times gives us 99.999% success at utilizing 85% of ports under these
106            // conditions, when ε=0.15 and δ=0.00001.
107            for _ in 0..150 {
108                let port = fastrand::u16(self.port_range.clone());
109                match try_bind(port).await {
110                    Ok(listener) => return Ok(listener),
111                    Err(_) => continue,
112                }
113            }
114            Err("failed to find an available port")
115        }
116    }
117
118    async fn handle_connection(&self, stream: TcpStream) -> Result<()> {
119        let mut stream = Delimited::new(stream);
120        if let Some(auth) = &self.auth {
121            if let Err(err) = auth.server_handshake(&mut stream).await {
122                warn!(%err, "server handshake failed");
123                stream.send(ServerMessage::Error(err.to_string())).await?;
124                return Ok(());
125            }
126        }
127
128        match stream.recv_timeout().await? {
129            Some(ClientMessage::Authenticate(_)) => {
130                warn!("unexpected authenticate");
131                Ok(())
132            }
133            Some(ClientMessage::Hello(port)) => {
134                let listener = match self.create_listener(port).await {
135                    Ok(listener) => listener,
136                    Err(err) => {
137                        stream.send(ServerMessage::Error(err.into())).await?;
138                        return Ok(());
139                    }
140                };
141                let host = listener.local_addr()?.ip();
142                let port = listener.local_addr()?.port();
143                info!(?host, ?port, "new client");
144                stream.send(ServerMessage::Hello(port)).await?;
145
146                loop {
147                    if stream.send(ServerMessage::Heartbeat).await.is_err() {
148                        // Assume that the TCP connection has been dropped.
149                        return Ok(());
150                    }
151                    const TIMEOUT: Duration = Duration::from_millis(500);
152                    if let Ok(result) = timeout(TIMEOUT, listener.accept()).await {
153                        let (stream2, addr) = result?;
154                        info!(?addr, ?port, "new connection");
155
156                        let id = Uuid::new_v4();
157                        let conns = Arc::clone(&self.conns);
158
159                        conns.insert(id, stream2);
160                        tokio::spawn(async move {
161                            // Remove stale entries to avoid memory leaks.
162                            sleep(Duration::from_secs(10)).await;
163                            if conns.remove(&id).is_some() {
164                                warn!(%id, "removed stale connection");
165                            }
166                        });
167                        stream.send(ServerMessage::Connection(id)).await?;
168                    }
169                }
170            }
171            Some(ClientMessage::Accept(id)) => {
172                info!(%id, "forwarding connection");
173                match self.conns.remove(&id) {
174                    Some((_, mut stream2)) => {
175                        let mut parts = stream.into_parts();
176                        debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
177                        stream2.write_all(&parts.read_buf).await?;
178                        tokio::io::copy_bidirectional(&mut parts.io, &mut stream2).await?;
179                    }
180                    None => warn!(%id, "missing connection"),
181                }
182                Ok(())
183            }
184            None => Ok(()),
185        }
186    }
187}