use std::{fs, io, path::PathBuf, sync::Arc};
use rustls_pki_types::ServerName;
use tokio_rustls::rustls;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::{AvesterraError, CallError, HGTPError, String255};
use super::message::*;
/// A client-side HGTP connection carried over TLS on top of a TCP socket.
///
/// [`HGTPStream::send`] and [`HGTPStream::recv`] shut the stream down when the
/// underlying transport fails; [`HGTPStream::closed`] then reports `true`.
#[derive(Debug)]
pub struct HGTPStream {
    /// TLS session wrapping the TCP connection.
    tlsstream: TlsStream<TcpStream>,
    /// Set by [`HGTPStream::close`]; never reset.
    closed: bool,
}

impl HGTPStream {
    /// Opens a TCP connection to `address`:`port` and performs a TLS handshake.
    ///
    /// PEM certificates read from `pem_filepath` are loaded into the client's root
    /// store. The file is read with blocking `std::fs` I/O.
    ///
    /// # Errors
    /// Fails if the PEM file cannot be opened or parsed, a certificate is rejected
    /// by the root store, `address` is not a valid TLS server name, or the TCP
    /// connect / `set_nodelay` / TLS handshake fails.
    ///
    /// # Security
    /// NOTE(review): the certificate verifier is unconditionally replaced with
    /// `danger::NoCertificateVerification`, which accepts any server certificate.
    /// The root store built above is therefore never consulted, and the server's
    /// identity is not authenticated. Only the handshake signatures are checked.
    /// Confirm this is intended.
    pub async fn new(pem_filepath: PathBuf, address: &str, port: u16) -> anyhow::Result<Self> {
        let mut root_store = rustls::RootCertStore::empty();
        for pem in get_pem_objects_content(&pem_filepath)? {
            root_store.add(pem)?;
        }

        let mut config = rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        {
            let mut dangerous = config.dangerous();
            if cfg!(feature = "dangerous_enable_ssl_keylog_file") {
                // KeyLogFile does nothing when SSLKEYLOGFILE is absent, so a missing
                // variable must not abort the connection with a panic.
                match std::env::var("SSLKEYLOGFILE") {
                    Ok(value) => println!("Using SSLKEYLOGFILE, value: {}", value),
                    Err(_) => println!(
                        "SSLKEYLOGFILE is not set (or not valid Unicode); TLS keys will not be logged"
                    ),
                }
                dangerous.cfg.key_log = Arc::new(rustls::KeyLogFile::new());
            }
            dangerous.set_certificate_verifier(Arc::new(danger::NoCertificateVerification {}));
        }

        let dnsname = ServerName::try_from(address.to_string())?;
        let connector = TlsConnector::from(Arc::new(config));
        // A (host, port) tuple lets the resolver handle IPv6 literals and avoids
        // formatting a "host:port" string by hand.
        let stream = TcpStream::connect((address, port)).await?;
        stream.set_nodelay(true)?;
        let tlsstream = connector.connect(dnsname, stream).await?;
        Ok(Self {
            tlsstream,
            closed: false,
        })
    }

    /// Sends `msg`. On any write error, the stream is closed before the error is returned.
    pub async fn send(&mut self, msg: &HGTPMessage) -> Result<(), std::io::Error> {
        let result = self._send(msg).await;
        if result.is_err() {
            self.close().await;
        }
        result
    }

    /// Writes `msg`'s frame and then flushes.
    ///
    /// When the frame's extension field is positive, the whole `unbounded` payload
    /// is written after the frame. Unlike [`HGTPStream::send`], this does not close
    /// the stream on error.
    pub async fn _send(&mut self, msg: &HGTPMessage) -> Result<(), std::io::Error> {
        const PACK_SIZE: usize = 1024;

        self.tlsstream.write_all(&msg.frame).await?;
        if msg.unpack_extension() > 0 {
            // The payload is pushed through the TLS layer in PACK_SIZE pieces.
            for chunk in msg.unbounded.chunks(PACK_SIZE) {
                self.tlsstream.write_all(chunk).await?;
                self.tlsstream.flush().await?;
            }
        }
        // Always flush. Otherwise a frame without an extension may remain buffered
        // in the TLS session until the next receive.
        self.tlsstream.flush().await?;
        Ok(())
    }

    /// Receives one message into `msg` and checks its error code.
    ///
    /// On a transport error, a best-effort `Bye` is sent, the stream is closed,
    /// and the I/O error is converted into a [`CallError`]. On success, the
    /// message's error code is checked with `validate`, and a non-OK code becomes
    /// an error.
    pub async fn recv(&mut self, msg: &mut HGTPMessage) -> Result<(), CallError> {
        if let Err(e) = self._recv(msg).await {
            self.send_bye().await;
            self.close().await;
            return Err(e.into());
        }
        validate(msg)?;
        Ok(())
    }

    /// Reads one frame into `msg.frame`.
    ///
    /// When the frame's extension field is positive, exactly that many bytes are
    /// then read into `msg.unbounded`, which is resized first. Pending outgoing data
    /// is flushed before reading.
    pub async fn _recv(&mut self, msg: &mut HGTPMessage) -> Result<(), std::io::Error> {
        self.tlsstream.flush().await?;
        self.tlsstream.read_exact(&mut msg.frame).await?;
        let extension = msg.unpack_extension();
        if extension > 0 {
            // Checked conversion: an `as` cast could silently truncate on narrow targets.
            let len = usize::try_from(extension).map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "HGTP extension length does not fit in usize",
                )
            })?;
            msg.unbounded.resize(len, 0);
            self.tlsstream
                .read_exact(msg.unbounded.as_mut_slice())
                .await?;
        }
        Ok(())
    }

    /// Shuts down the TLS stream, ignoring shutdown errors, and marks it closed.
    pub async fn close(&mut self) {
        self.tlsstream.shutdown().await.ok();
        self.closed = true;
    }

    /// Returns `true` once [`HGTPStream::close`] has run.
    pub fn closed(&self) -> bool {
        self.closed
    }

    /// Best-effort send of a `Bye` command. Errors are ignored, but a failed send
    /// still closes the stream via [`HGTPStream::send`].
    pub async fn send_bye(&mut self) {
        let mut msg = HGTPMessage::default();
        msg.pack_command(crate::Command::Bye);
        self.send(&msg).await.ok();
    }
}
fn get_pem_objects_content(
pem_filepath: &PathBuf,
) -> io::Result<Vec<rustls_pki_types::CertificateDer>> {
let pem_file = fs::File::open(pem_filepath)?;
let mut rd = std::io::BufReader::new(pem_file);
rustls_pemfile::certs(&mut rd).collect()
}
/// Converts a received message's error code into a `Result`.
///
/// `HGTPError::Ok` yields `Ok(())`. Any other code becomes an `AvesterraError`
/// whose text is the message's byte payload, decoded lossily as UTF-8. If that
/// text does not fit into a `String255`, a fixed placeholder is used instead.
/// Errors from decoding the error code itself are propagated.
fn validate(msg: &HGTPMessage) -> Result<(), CallError> {
    let errcode = match msg.unpack_error_code()? {
        HGTPError::Ok => return Ok(()),
        errcode => errcode,
    };
    // The fallback is built lazily, only when the conversion actually fails.
    let message = String::from_utf8_lossy(msg.unpack_bytes())
        .into_owned()
        .try_into()
        .unwrap_or_else(|_| {
            String255::unchecked(
                "<Error message too long, should never happen. Did the server return error message as frame extension??>",
            )
        });
    Err(AvesterraError { errcode, message }.into())
}
/// TLS verification overrides that deliberately weaken connection security.
mod danger {
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::crypto::{verify_tls12_signature, verify_tls13_signature};
    use rustls::{DigitallySignedStruct, SignatureScheme};
    use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
    use tokio_rustls::rustls;

    /// Server certificate verifier that accepts any certificate chain.
    ///
    /// TLS 1.2 and 1.3 handshake signatures are still checked, using the signature
    /// algorithms of the `ring` crypto provider.
    #[derive(Debug)]
    pub struct NoCertificateVerification {}

    impl ServerCertVerifier for NoCertificateVerification {
        /// Unconditionally accepts the server certificate. The chain, server name,
        /// OCSP response and current time are all ignored.
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, rustls::Error> {
            Ok(ServerCertVerified::assertion())
        }

        /// Checks a TLS 1.2 handshake signature with the `ring` provider's algorithms.
        fn verify_tls12_signature(
            &self,
            message: &[u8],
            cert: &CertificateDer<'_>,
            dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            let provider = rustls::crypto::ring::default_provider();
            verify_tls12_signature(message, cert, dss, &provider.signature_verification_algorithms)
        }

        /// Checks a TLS 1.3 handshake signature with the `ring` provider's algorithms.
        fn verify_tls13_signature(
            &self,
            message: &[u8],
            cert: &CertificateDer<'_>,
            dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            let provider = rustls::crypto::ring::default_provider();
            verify_tls13_signature(message, cert, dss, &provider.signature_verification_algorithms)
        }

        /// Advertises the signature schemes the `ring` provider can verify.
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            let provider = rustls::crypto::ring::default_provider();
            provider.signature_verification_algorithms.supported_schemes()
        }
    }
}