Skip to main content

compio_rustls/
acceptor.rs

1use std::{
2    io,
3    sync::Arc,
4};
5
6use compio_io::{
7    AsyncRead,
8    AsyncWrite,
9};
10use rustls::{
11    ServerConfig,
12    ServerConnection,
13};
14
15use crate::stream::TlsStream;
16
17/// A wrapper around a [`rustls::ServerConfig`].
18///
19/// **Note:** Clones are cheap.
20#[derive(Clone)]
21pub struct TlsAcceptor {
22    rustls_server_config: Arc<ServerConfig>,
23}
24
25impl TlsAcceptor {
26    pub fn new(rustls_server_config: Arc<ServerConfig>) -> Self {
27        Self {
28            rustls_server_config,
29        }
30    }
31
32    pub async fn accept<S>(&self, stream: S) -> io::Result<TlsStream<S, ServerConnection>>
33    where
34        S: AsyncRead + AsyncWrite,
35    {
36        let session = ServerConnection::new(self.rustls_server_config.clone())
37            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
38
39        let mut tls_stream = TlsStream::new(stream, session);
40        tls_stream.handshake().await?;
41
42        Ok(tls_stream)
43    }
44
45    pub async fn accept_with<S, F>(&self, stream: S, f: F) -> io::Result<TlsStream<S, ServerConnection>>
46    where
47        S: AsyncRead + AsyncWrite,
48        F: FnOnce(&mut ServerConnection),
49    {
50        let mut session = ServerConnection::new(self.rustls_server_config.clone())
51            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
52
53        f(&mut session);
54
55        let mut tls_stream = TlsStream::new(stream, session);
56        tls_stream.handshake().await?;
57
58        Ok(tls_stream)
59    }
60
61    /// Get a read-only reference to underlying config
62    pub fn config(&self) -> &Arc<ServerConfig> {
63        &self.rustls_server_config
64    }
65}