use futures::{AsyncBufRead, AsyncRead, AsyncWrite};
use futures_rustls::{TlsConnector, rustls};
use std::error::Error;
use std::sync::Arc;
/// Establishes a rustls client TLS session over `stream`, trusting the
/// platform's native root certificates, and returns the session split into a
/// buffered read half and a write half.
///
/// # Arguments
///
/// * `stream` - an already-connected transport to run TLS over.
/// * `host_name` - the server name used for SNI and certificate verification.
///   It is required: `None` is rejected with an error.
///
/// # Errors
///
/// Returns an error if:
/// * `host_name` is `None` or cannot be parsed as a `ServerName`;
/// * the native certificate store cannot be loaded, or contains no
///   certificate rustls can use;
/// * the TLS handshake fails.
pub async fn connect_rustls<S>(
    stream: S,
    host_name: Option<String>,
) -> Result<(impl AsyncBufRead + Unpin, impl AsyncWrite + Unpin), Box<dyn Error>>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    use futures::io::{AsyncReadExt, BufReader};
    use rustls::pki_types::ServerName;
    use rustls::{ClientConfig, RootCertStore};
    use rustls_native_certs::load_native_certs;

    // Validate the cheap, caller-supplied argument first so a missing or
    // invalid host name fails before the platform certificate store is read.
    let host = host_name.ok_or("Host name is required for TLS connection")?;
    let server_name = ServerName::try_from(host.clone())
        .map_err(|e| format!("Invalid server name '{}': {}", host, e))?;

    // NOTE(review): this is a synchronous call inside an async fn; it presumably
    // reads the OS trust store and may briefly block the executor — confirm
    // whether callers need it moved to a blocking-task pool or cached.
    let certs =
        load_native_certs().map_err(|e| format!("Failed to load native certificates: {}", e))?;

    // Platform stores can contain certificates rustls cannot parse. Skip those
    // instead of letting a single malformed entry make every connection fail,
    // but still report an error if nothing usable was found.
    let mut root_store = RootCertStore::empty();
    let (added, _ignored) = root_store.add_parsable_certificates(certs);
    if added == 0 {
        return Err("Failed to add certificate: no usable native root certificates found".into());
    }

    let config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    let connector = TlsConnector::from(Arc::new(config));
    let tls_stream = connector.connect(server_name, stream).await?;

    // Split the TLS stream so reads and writes can be driven independently;
    // the read half is buffered to satisfy the `AsyncBufRead` return bound.
    let (reader, writer) = tls_stream.split();
    Ok((BufReader::new(reader), writer))
}