compio_rustls/
connector.rs1use std::{
2 io,
3 sync::Arc,
4};
5
6use compio_io::{
7 AsyncRead,
8 AsyncWrite,
9};
10use rustls::{
11 ClientConfig,
12 ClientConnection,
13 pki_types::ServerName,
14};
15
16use crate::stream::TlsStream;
17
18#[derive(Clone)]
22pub struct TlsConnector {
23 rustls_client_config: Arc<ClientConfig>,
24}
25
26impl TlsConnector {
27 pub fn new(rustls_client_config: Arc<ClientConfig>) -> Self {
28 Self {
29 rustls_client_config,
30 }
31 }
32
33 pub async fn connect<S>(&self, domain: ServerName<'static>, stream: S) -> io::Result<TlsStream<S, ClientConnection>>
34 where
35 S: AsyncRead + AsyncWrite,
36 {
37 self.connect_impl(domain, stream, None, |_| ()).await
38 }
39
40 pub async fn connect_with<S, F>(
41 &self,
42 domain: ServerName<'static>,
43 stream: S,
44 f: F,
45 ) -> io::Result<TlsStream<S, ClientConnection>>
46 where
47 S: AsyncRead + AsyncWrite,
48 F: FnOnce(&mut ClientConnection),
49 {
50 self.connect_impl(domain, stream, None, f).await
51 }
52
53 async fn connect_impl<S, F>(
54 &self,
55 domain: ServerName<'static>,
56 stream: S,
57 alpn_protocols: Option<Vec<Vec<u8>>>,
58 f: F,
59 ) -> io::Result<TlsStream<S, ClientConnection>>
60 where
61 S: AsyncRead + AsyncWrite,
62 F: FnOnce(&mut ClientConnection),
63 {
64 let alpn = alpn_protocols.unwrap_or_else(|| self.rustls_client_config.alpn_protocols.clone());
65 let mut session = ClientConnection::new_with_alpn(self.rustls_client_config.clone(), domain, alpn)
66 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
67
68 f(&mut session);
69
70 let mut tls_stream = TlsStream::new(stream, session);
71 tls_stream.handshake().await?;
72
73 Ok(tls_stream)
74 }
75
76 pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
77 TlsConnectorWithAlpn {
78 inner: self,
79 alpn_protocols,
80 }
81 }
82
83 pub fn config(&self) -> &Arc<ClientConfig> {
85 &self.rustls_client_config
86 }
87}
88
89pub struct TlsConnectorWithAlpn<'c> {
90 inner: &'c TlsConnector,
91 alpn_protocols: Vec<Vec<u8>>,
92}
93
94impl<'c> TlsConnectorWithAlpn<'c> {
95 pub async fn connect<S>(self, domain: ServerName<'static>, stream: S) -> io::Result<TlsStream<S, ClientConnection>>
96 where
97 S: AsyncRead + AsyncWrite,
98 {
99 self.inner
100 .connect_impl(domain, stream, Some(self.alpn_protocols), |_| ())
101 .await
102 }
103
104 pub async fn connect_with<S, F>(
105 self,
106 domain: ServerName<'static>,
107 stream: S,
108 f: F,
109 ) -> io::Result<TlsStream<S, ClientConnection>>
110 where
111 S: AsyncRead + AsyncWrite,
112 F: FnOnce(&mut ClientConnection),
113 {
114 self.inner
115 .connect_impl(domain, stream, Some(self.alpn_protocols), f)
116 .await
117 }
118}