compio_rustls/
acceptor.rs1use std::{
2 io,
3 sync::Arc,
4};
5
6use compio_io::{
7 AsyncRead,
8 AsyncWrite,
9};
10use rustls::{
11 ServerConfig,
12 ServerConnection,
13};
14
15use crate::stream::TlsStream;
16
17#[derive(Clone)]
21pub struct TlsAcceptor {
22 rustls_server_config: Arc<ServerConfig>,
23}
24
25impl TlsAcceptor {
26 pub fn new(rustls_server_config: Arc<ServerConfig>) -> Self {
27 Self {
28 rustls_server_config,
29 }
30 }
31
32 pub async fn accept<S>(&self, stream: S) -> io::Result<TlsStream<S, ServerConnection>>
33 where
34 S: AsyncRead + AsyncWrite,
35 {
36 let session = ServerConnection::new(self.rustls_server_config.clone())
37 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
38
39 let mut tls_stream = TlsStream::new(stream, session);
40 tls_stream.handshake().await?;
41
42 Ok(tls_stream)
43 }
44
45 pub async fn accept_with<S, F>(&self, stream: S, f: F) -> io::Result<TlsStream<S, ServerConnection>>
46 where
47 S: AsyncRead + AsyncWrite,
48 F: FnOnce(&mut ServerConnection),
49 {
50 let mut session = ServerConnection::new(self.rustls_server_config.clone())
51 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
52
53 f(&mut session);
54
55 let mut tls_stream = TlsStream::new(stream, session);
56 tls_stream.handshake().await?;
57
58 Ok(tls_stream)
59 }
60
61 pub fn config(&self) -> &Arc<ServerConfig> {
63 &self.rustls_server_config
64 }
65}