use log::debug;
use crate::error::AtlsVerificationError;
use crate::policy::Policy;
use crate::verifier::{AsyncByteStream, Report};
use crate::AtlsVerifier;
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, RootCertStore};
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
pub use tokio_rustls::client::TlsStream;
#[cfg(not(target_arch = "wasm32"))]
use tokio_rustls::TlsConnector;
#[cfg(target_arch = "wasm32")]
pub use futures_rustls::client::TlsStream;
#[cfg(target_arch = "wasm32")]
use futures_rustls::TlsConnector;
/// Performs a TLS client handshake over `stream` and captures the material
/// needed to bind a later attestation to this TLS session.
///
/// The server certificate is validated against the bundled `webpki_roots`
/// trust anchors, and no client certificate is offered. When `alpn` is
/// `Some`, each entry is advertised as an ALPN protocol identifier (its UTF-8
/// bytes).
///
/// On success the returned tuple contains:
/// 1. the established [`TlsStream`];
/// 2. the bytes of the first certificate in the chain the peer presented;
/// 3. 32 bytes of keying material exported from the session with the label
///    `EXPORTER-Channel-Binding` and no context. This is the label used by
///    the RFC 9266 `tls-exporter` channel binding.
///
/// # Errors
///
/// - [`AtlsVerificationError::InvalidServerName`] if `server_name` cannot be
///   parsed as a [`ServerName`].
/// - [`AtlsVerificationError::TlsHandshake`] if the handshake fails or the
///   keying material cannot be exported.
/// - [`AtlsVerificationError::MissingCertificate`] if the peer presented no
///   certificate.
pub async fn tls_handshake<S>(
    stream: S,
    server_name: &str,
    alpn: Option<Vec<String>>,
) -> Result<(TlsStream<S>, Vec<u8>, Vec<u8>), AtlsVerificationError>
where
    S: AsyncByteStream + 'static,
{
    debug!("Starting TLS handshake to {}", server_name);

    // Trust store: the webpki_roots anchors and nothing else.
    let trust_anchors = {
        let mut store = RootCertStore::empty();
        store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        store
    };

    let mut client_config = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();
    if let Some(protocol_names) = alpn {
        client_config.alpn_protocols = protocol_names
            .into_iter()
            .map(String::into_bytes)
            .collect();
    }
    let tls_connector = TlsConnector::from(Arc::new(client_config));

    // `connect` takes an owned `ServerName<'static>`, hence `to_owned()`.
    let target = match ServerName::try_from(server_name.to_owned()) {
        Ok(name) => name,
        Err(err) => return Err(AtlsVerificationError::InvalidServerName(err.to_string())),
    };

    let tls_stream = match tls_connector.connect(target, stream).await {
        Ok(established) => established,
        Err(err) => return Err(AtlsVerificationError::TlsHandshake(err.to_string())),
    };

    // Inspect the client-side session state of the established connection.
    let (_, session) = tls_stream.get_ref();

    let peer_cert = match session.peer_certificates().and_then(|chain| chain.first()) {
        Some(cert) => cert.as_ref().to_vec(),
        None => return Err(AtlsVerificationError::MissingCertificate),
    };
    debug!(
        "TLS handshake complete, certificate received ({} bytes)",
        peer_cert.len()
    );

    // The exporter fills the whole buffer, so its length (32) fixes the EKM size.
    let mut session_ekm = vec![0u8; 32];
    if let Err(err) =
        session.export_keying_material(&mut session_ekm, b"EXPORTER-Channel-Binding", None)
    {
        return Err(AtlsVerificationError::TlsHandshake(format!(
            "Failed to extract session EKM: {}",
            err
        )));
    }
    debug!("Session EKM extracted ({} bytes)", session_ekm.len());

    Ok((tls_stream, peer_cert, session_ekm))
}
/// Establishes an attested TLS (aTLS) connection over `stream`.
///
/// The steps are:
/// 1. Initialise crate logging.
/// 2. Turn `policy` into a verifier.
/// 3. Run [`tls_handshake`] (optionally advertising `alpn`).
/// 4. Hand the TLS stream, the peer certificate bytes and the exported
///    session keying material to the verifier.
///
/// The verifier receives `&mut` access to the TLS stream. Presumably it
/// exchanges attestation messages over it; confirm against
/// `AtlsVerifier::verify`.
///
/// On success returns the TLS stream, ready for application traffic, together
/// with the verifier's [`Report`].
///
/// # Errors
///
/// Returns an error in three cases:
/// - `policy` cannot be turned into a verifier. This is checked before any
///   network traffic.
/// - The TLS handshake fails (see [`tls_handshake`]).
/// - Attestation verification fails.
pub async fn atls_connect<S>(
    stream: S,
    server_name: &str,
    policy: Policy,
    alpn: Option<Vec<String>>,
) -> Result<(TlsStream<S>, Report), AtlsVerificationError>
where
    S: AsyncByteStream + 'static,
{
    crate::logging::init();

    // Build the verifier before touching the network. An unusable policy is a
    // local configuration error and should not cost a full TLS handshake whose
    // session would then be discarded.
    let verifier = policy.into_verifier()?;

    let (mut tls_stream, peer_cert, session_ekm) = tls_handshake(stream, server_name, alpn).await?;

    debug!("Starting attestation verification");
    let report = verifier
        .verify(&mut tls_stream, &peer_cert, &session_ekm, server_name)
        .await?;
    debug!("Attestation verification successful");

    Ok((tls_stream, report))
}