bonsaidb_server/server/
tcp.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use rustls::server::ResolvesServerCert;
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::net::TcpListener;
7
8use crate::{Backend, CustomServer, Error};
9
10impl<B: Backend> CustomServer<B> {
11    /// Listens for HTTP traffic on `port`. This port will also receive
12    /// `WebSocket` connections if feature `websockets` is enabled.
13    pub async fn listen_for_tcp_on<S: TcpService, T: tokio::net::ToSocketAddrs + Send + Sync>(
14        &self,
15        addr: T,
16        service: S,
17    ) -> Result<(), Error> {
18        let listener = TcpListener::bind(&addr).await?;
19        let mut shutdown_watcher = self
20            .data
21            .shutdown
22            .watcher()
23            .await
24            .expect("server already shutdown");
25
26        loop {
27            tokio::select! {
28                _ = shutdown_watcher.wait_for_shutdown() => {
29                    break;
30                }
31                incoming = listener.accept() => {
32                    if incoming.is_err() {
33                        continue;
34                    }
35                    let (connection, remote_addr) = incoming.unwrap();
36
37                    let peer = Peer {
38                        address: remote_addr,
39                        protocol: service.available_protocols()[0].clone(),
40                        secure: false,
41                    };
42
43                    let task_self = self.clone();
44                    let task_service = service.clone();
45                    tokio::spawn(async move {
46                        if let Err(err) = task_self.handle_tcp_connection(connection, peer, &task_service).await {
47                            log::error!("[server] closing connection {}: {:?}", remote_addr, err);
48                        }
49                    });
50                }
51            }
52        }
53
54        Ok(())
55    }
56
57    /// Listens for HTTPS traffic on `port`. This port will also receive
58    /// `WebSocket` connections if feature `websockets` is enabled. If feature
59    /// `acme` is enabled, this connection will automatically manage the
60    /// server's private key and certificate, which is also used for the
61    /// QUIC-based protocol.
62    #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
63    #[cfg_attr(not(feature = "acme"), allow(unused_mut))]
64    pub async fn listen_for_secure_tcp_on<
65        S: TcpService,
66        T: tokio::net::ToSocketAddrs + Send + Sync,
67    >(
68        &self,
69        addr: T,
70        service: S,
71    ) -> Result<(), Error> {
72        // We may not have a certificate yet, so we ignore any errors.
73        drop(self.refresh_certified_key().await);
74
75        #[cfg(feature = "acme")]
76        {
77            let task_self = self.clone();
78            tokio::task::spawn(async move {
79                if let Err(err) = task_self.update_acme_certificates().await {
80                    log::error!("[server] acme task error: {0}", err);
81                }
82            });
83        }
84
85        let mut config = rustls::ServerConfig::builder()
86            .with_safe_defaults()
87            .with_no_client_auth()
88            .with_cert_resolver(Arc::new(self.clone()));
89        config.alpn_protocols = service
90            .available_protocols()
91            .iter()
92            .map(|proto| proto.alpn_name().to_vec())
93            .collect();
94
95        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
96        let listener = TcpListener::bind(&addr).await?;
97        loop {
98            let (stream, peer_addr) = listener.accept().await?;
99            let acceptor = acceptor.clone();
100
101            let task_self = self.clone();
102            let task_service = service.clone();
103            tokio::task::spawn(async move {
104                let stream = match acceptor.accept(stream).await {
105                    Ok(stream) => stream,
106                    Err(err) => {
107                        log::error!("[server] error during tls handshake: {:?}", err);
108                        return;
109                    }
110                };
111
112                let available_protocols = task_service.available_protocols();
113                let protocol = stream
114                    .get_ref()
115                    .1
116                    .alpn_protocol()
117                    .and_then(|protocol| {
118                        available_protocols
119                            .iter()
120                            .find(|p| p.alpn_name() == protocol)
121                            .cloned()
122                    })
123                    .unwrap_or_else(|| available_protocols[0].clone());
124                let peer = Peer {
125                    address: peer_addr,
126                    secure: true,
127                    protocol,
128                };
129                if let Err(err) = task_self
130                    .handle_tcp_connection(stream, peer, &task_service)
131                    .await
132                {
133                    log::error!("[server] error for client {}: {:?}", peer_addr, err);
134                }
135            });
136        }
137    }
138
139    #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
140    async fn handle_tcp_connection<
141        S: TcpService,
142        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
143    >(
144        &self,
145        connection: C,
146        peer: Peer<S::ApplicationProtocols>,
147        service: &S,
148    ) -> Result<(), Error> {
149        // For ACME, don't send any traffic over the connection.
150        #[cfg(feature = "acme")]
151        if peer.protocol.alpn_name() == async_acme::acme::ACME_TLS_ALPN_NAME {
152            log::info!("received acme challenge connection");
153            return Ok(());
154        }
155
156        if let Err(connection) = service.handle_connection(connection, &peer).await {
157            #[cfg(feature = "websockets")]
158            if let Err(err) = self
159                .handle_raw_websocket_connection(connection, peer.address)
160                .await
161            {
162                log::error!(
163                    "[server] error on websocket for {}: {:?}",
164                    peer.address,
165                    err
166                );
167            }
168        }
169
170        Ok(())
171    }
172}
173
174impl<B: Backend> ResolvesServerCert for CustomServer<B> {
175    #[cfg_attr(not(feature = "acme"), allow(unused_variables))]
176    fn resolve(
177        &self,
178        client_hello: rustls::server::ClientHello<'_>,
179    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
180        #[cfg(feature = "acme")]
181        if client_hello
182            .alpn()
183            .map(|mut iter| iter.any(|n| n == async_acme::acme::ACME_TLS_ALPN_NAME))
184            .unwrap_or_default()
185        {
186            let server_name = client_hello.server_name()?.to_owned();
187            let keys = self.data.alpn_keys.lock();
188            if let Some(key) = keys.get(AsRef::<str>::as_ref(&server_name)) {
189                log::info!("returning acme challenge");
190                return Some(key.clone());
191            }
192
193            log::error!(
194                "acme alpn challenge received with no key for {}",
195                server_name
196            );
197            return None;
198        }
199
200        let cached_key = self.data.primary_tls_key.lock();
201        if let Some(key) = cached_key.as_ref() {
202            Some(key.clone())
203        } else {
204            log::error!("[server] inbound tls connection with no certificate installed");
205            None
206        }
207    }
208}
209
210/// A service that can handle incoming TCP connections.
211#[async_trait]
212pub trait TcpService: Clone + Send + Sync + 'static {
213    /// The application layer protocols that this service supports.
214    type ApplicationProtocols: ApplicationProtocols;
215
216    /// Returns all available protocols for this service. The first will be the
217    /// default used if a connection is made without negotiating the application
218    /// protocol.
219    fn available_protocols(&self) -> &[Self::ApplicationProtocols];
220
221    /// Handle an incoming `connection` for `peer`. Return `Err(connection)` to
222    /// have BonsaiDb handle the connection internally.
223    async fn handle_connection<
224        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
225    >(
226        &self,
227        connection: S,
228        peer: &Peer<Self::ApplicationProtocols>,
229    ) -> Result<(), S>;
230}
231
232/// A service that can handle incoming HTTP connections. A convenience
233/// implementation of [`TcpService`] that is useful is you are only serving HTTP
234/// and WebSockets over a service.
235#[async_trait]
236pub trait HttpService: Clone + Send + Sync + 'static {
237    /// Handle an incoming `connection` for `peer`. Return `Err(connection)` to
238    /// have BonsaiDb handle the connection internally.
239    async fn handle_connection<
240        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
241    >(
242        &self,
243        connection: S,
244        peer: &Peer,
245    ) -> Result<(), S>;
246}
247
248#[async_trait]
249impl<T> TcpService for T
250where
251    T: HttpService,
252{
253    type ApplicationProtocols = StandardTcpProtocols;
254
255    fn available_protocols(&self) -> &[Self::ApplicationProtocols] {
256        StandardTcpProtocols::all()
257    }
258
259    async fn handle_connection<
260        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
261    >(
262        &self,
263        connection: S,
264        peer: &Peer<Self::ApplicationProtocols>,
265    ) -> Result<(), S> {
266        HttpService::handle_connection(self, connection, peer).await
267    }
268}
269
270#[async_trait]
271impl HttpService for () {
272    async fn handle_connection<
273        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
274    >(
275        &self,
276        connection: S,
277        _peer: &Peer<StandardTcpProtocols>,
278    ) -> Result<(), S> {
279        Err(connection)
280    }
281}
282
/// The set of application-layer protocols a network service can speak.
pub trait ApplicationProtocols
where
    Self: Clone + std::fmt::Debug + Send + Sync,
{
    /// The identifier this protocol is advertised under during TLS ALPN
    /// negotiation.
    fn alpn_name(&self) -> &'static [u8];
}
288
289/// A connected network peer.
290#[derive(Debug, Clone)]
291pub struct Peer<P: ApplicationProtocols = StandardTcpProtocols> {
292    /// The remote address of the peer.
293    pub address: std::net::SocketAddr,
294    /// If true, the connection is secured with TLS.
295    pub secure: bool,
296    /// The application protocol to use for this connection.
297    pub protocol: P,
298}
299
/// TCP [`ApplicationProtocols`] that BonsaiDb has some knowledge of.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StandardTcpProtocols {
    /// HTTP/1.1.
    Http1,
    /// The ACME TLS-ALPN challenge protocol, used for automatic certificate
    /// management. Only available with feature `acme`.
    #[cfg(feature = "acme")]
    Acme,
    /// A protocol other than the ones listed above.
    Other,
}
309
310impl StandardTcpProtocols {
311    #[cfg(feature = "acme")]
312    const fn all() -> &'static [Self] {
313        &[Self::Http1, Self::Acme]
314    }
315
316    #[cfg(not(feature = "acme"))]
317    const fn all() -> &'static [Self] {
318        &[Self::Http1]
319    }
320}
321
322impl Default for StandardTcpProtocols {
323    fn default() -> Self {
324        Self::Http1
325    }
326}
327
328impl ApplicationProtocols for StandardTcpProtocols {
329    fn alpn_name(&self) -> &'static [u8] {
330        match self {
331            Self::Http1 => b"http/1.1",
332            #[cfg(feature = "acme")]
333            Self::Acme => async_acme::acme::ACME_TLS_ALPN_NAME,
334            Self::Other => unreachable!(),
335        }
336    }
337}