use core::{net::SocketAddr, time::Duration};
use rustls::{
ClientConfig, Error as Terror, RootCertStore, SignatureScheme,
client::{
danger::{ServerCertVerified, ServerCertVerifier},
verify_server_cert_signed_by_trust_anchor, verify_server_name,
},
crypto::{CryptoProvider, verify_tls12_signature, verify_tls13_signature},
pki_types::ServerName,
server::ParsedCertificate,
};
use std::{io::ErrorKind, sync::Arc, time::Instant};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tokio_rustls::TlsConnector;
use x509_parser::oid_registry::OID_X509_COMMON_NAME;
/// Emits a debug-level log record under the `gload::tls` target.
///
/// The forwarded `log::debug!` call is gated on the `logging` feature; when
/// the feature is disabled the macro expands to an empty block, so its
/// arguments are never evaluated.
macro_rules! debug {
($($arg:tt)+) => {{
#[cfg(feature = "logging")]
::log::debug!(target: "gload::tls", $($arg)+)
}}
}
/// Emits a warn-level log record under the `gload::tls` target.
///
/// The forwarded `log::warn!` call is gated on the `logging` feature; when
/// the feature is disabled the macro expands to an empty block, so its
/// arguments are never evaluated.
macro_rules! warn {
($($arg:tt)+) => {{
#[cfg(feature = "logging")]
::log::warn!(target: "gload::tls", $($arg)+)
}}
}
/// Client-side TLS stream layered over a Tokio TCP socket.
pub(crate) type Stream = tokio_rustls::client::TlsStream<TcpStream>;
/// Opens a TCP connection to `addr` and performs a TLS handshake for
/// `authority`, returning the established stream.
///
/// TLS 1.2 and TLS 1.3 are offered, no client certificate is presented, and
/// the server certificate is checked by this module's custom [`Verifier`].
///
/// The TCP connect is bounded by a 15 second timeout. The TLS handshake
/// itself is not wrapped in a timeout.
///
/// # Errors
///
/// * [`TlsError::TimedOut`] if the TCP connect times out.
/// * [`TlsError::Open`] if the TCP connect fails for any other reason.
/// * [`TlsError::ServerCertificate`] if the verifier rejects the certificate.
/// * [`TlsError::InappropriateHandshakeMessage`] if the server sends an
///   unexpected handshake message.
/// * [`TlsError::InitialConnect`] for every other handshake failure,
///   including TLS-level errors reported by rustls.
pub(crate) async fn open_stream(
    addr: &SocketAddr,
    authority: ServerName<'static>,
) -> Result<Stream, TlsError> {
    let verifier = Arc::new(Verifier::new());
    let versions = [&rustls::version::TLS12, &rustls::version::TLS13];
    let cfg = ClientConfig::builder_with_protocol_versions(&versions)
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    debug!("* ALPN: gload offers none");
    let connector = TlsConnector::from(Arc::new(cfg));

    let timeout = Duration::from_secs(15);
    let start = Instant::now();
    let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
        // The outer `Err` is the timer firing before the connect finished.
        Err(_) => {
            return Err(TlsError::TimedOut(authority, addr.port(), start.elapsed()));
        }
        // The OS may also report its own connect timeout.
        Ok(Err(err)) if err.kind() == ErrorKind::TimedOut => {
            return Err(TlsError::TimedOut(authority, addr.port(), start.elapsed()));
        }
        Ok(Err(err)) => return Err(TlsError::Open(err)),
        Ok(Ok(s)) => s,
    };

    match connector.connect(authority.clone(), stream).await {
        Ok(tls) => Ok(tls),
        // Handshake failures arrive as `io::Error`; TLS-level problems carry a
        // `rustls::Error` payload that we unwrap to classify.
        Err(err) => Err(match err.downcast::<Terror>() {
            Ok(Terror::InvalidCertificate(err)) => TlsError::ServerCertificate(err),
            Ok(Terror::InappropriateHandshakeMessage { .. }) => {
                TlsError::InappropriateHandshakeMessage
            }
            // Any other TLS error (alerts, incompatible peers, ...) is driven
            // by the remote side and must not abort the process: hand it back
            // to the caller, re-wrapped so the message is preserved.
            Ok(other) => TlsError::InitialConnect(
                std::io::Error::new(ErrorKind::InvalidData, other),
                authority,
                addr.port(),
            ),
            Err(err) => TlsError::InitialConnect(err, authority, addr.port()),
        }),
    }
}
pub(crate) async fn send(
addr: &SocketAddr,
authority: &ServerName<'static>,
mut tcp_stream: Stream,
payload: Vec<u8>,
allow_truncation: bool,
) -> Result<Vec<u8>, TlsError> {
debug!(
"* Connected to {} ({}) port {}",
authority.to_str(),
addr.ip(),
addr.port()
);
debug!("* using Gemini/0.24.x");
let start = Instant::now();
match tcp_stream.write_all(&payload).await {
Err(err) if err.kind() == ErrorKind::TimedOut => {
return Err(TlsError::TimedOut(
authority.to_owned(),
addr.port(),
start.elapsed(),
));
}
Err(err) => match err.downcast::<Terror>() {
Ok(Terror::InvalidCertificate(err)) => {
return Err(TlsError::ServerCertificate(err));
}
Ok(Terror::InappropriateHandshakeMessage { .. }) => {
return Err(TlsError::InappropriateHandshakeMessage);
}
Ok(err) => panic!("something went wrong: {err}"), Err(err) => return Err(TlsError::Send(err)),
},
Ok(()) => {
debug!(
"> {}",
str::from_utf8(&payload)
.expect("the payload came from a utf-8 string")
.replace('\r', "")
.replace('\n', "\n> ") );
}
}
debug!("* Request completely sent off");
let mut buffer: Vec<u8> = Vec::with_capacity(512);
let timeout = Duration::from_secs(15);
let start = Instant::now();
match tokio::time::timeout(timeout, tcp_stream.read_to_end(&mut buffer)).await {
Err(_) => {
return Err(TlsError::TimedOut(
authority.clone(),
addr.port(),
start.elapsed(),
));
}
Ok(Err(e)) if e.kind() == ErrorKind::UnexpectedEof => {
warn!("* rustls: server closed abruptly (missing close_notify)");
if allow_truncation {
} else {
return Err(TlsError::ClosedWithoutNotify);
}
}
Ok(Err(e)) => return Err(TlsError::Receive(e)),
Ok(Ok(_)) => {} }
tcp_stream.get_mut().1.send_close_notify();
let timeout = Duration::from_secs(3);
let start = Instant::now();
match tokio::time::timeout(timeout, tcp_stream.write(&[])).await {
Err(_) => {
return Err(TlsError::TimedOut(
authority.clone(),
addr.port(),
start.elapsed(),
));
}
Ok(Err(err)) => {
return Err(TlsError::ShutdownFailed(err));
}
Ok(Ok(_)) => {} }
if buffer.is_empty() {
return Err(TlsError::NoResponse);
}
Ok(buffer)
}
/// Server-certificate verifier that carries the crypto provider whose
/// signature verification algorithms it uses.
#[derive(Debug)]
struct Verifier(CryptoProvider);

impl Verifier {
    /// Creates a verifier backed by the aws-lc-rs default crypto provider.
    fn new() -> Self {
        Self(rustls::crypto::aws_lc_rs::default_provider())
    }
}
impl ServerCertVerifier for Verifier {
    /// Validates the server's end-entity certificate.
    ///
    /// Checks, in order:
    /// 1. the certificate parses as X.509;
    /// 2. `now` lies within the validity period (a `not_before` earlier than
    ///    the Unix epoch counts as already valid, a `not_after` earlier than
    ///    the epoch counts as expired);
    /// 3. the certificate names `server_name`: through subjectAltName when
    ///    that extension is present, otherwise through the first subject
    ///    Common Name, compared ASCII case-insensitively;
    /// 4. the certificate is either self-signed (subject equals issuer, which
    ///    is accepted without a chain check) or chains to a trusted root. The
    ///    roots come from the OS trust store, falling back to the bundled
    ///    webpki roots when the OS store yields no usable certificate.
    ///
    /// # Errors
    ///
    /// Returns `rustls::Error::InvalidCertificate` with `BadEncoding`,
    /// `NotValidYet`, `Expired`, `NotValidForName`, `UnknownIssuer` or
    /// `Other`, or whatever error rustls reports while parsing the
    /// certificate or matching the subjectAltName.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, Terror> {
        use rustls::{CertificateError, OtherError};

        let cert = match x509_parser::parse_x509_certificate(end_entity) {
            Err(_) => return Err(Terror::InvalidCertificate(CertificateError::BadEncoding)),
            Ok((_, cert)) => cert,
        };

        // A negative timestamp (before the Unix epoch) does not fit in a u64
        // and is necessarily in the past, so it cannot be "not valid yet".
        if let Ok(not_before) = u64::try_from(cert.validity.not_before.timestamp())
            && now.as_secs() < not_before
        {
            return Err(Terror::InvalidCertificate(CertificateError::NotValidYet));
        }
        // Conversely, an expiry before the epoch has certainly passed.
        match u64::try_from(cert.validity.not_after.timestamp()) {
            Ok(not_after) if now.as_secs() <= not_after => {}
            _ => return Err(Terror::InvalidCertificate(CertificateError::Expired)),
        }

        debug!("* Server certificate:");
        debug!("*  subject: {}", cert.subject);
        debug!("*  start date: {}", cert.validity.not_before);
        debug!("*  expire date: {}", cert.validity.not_after);

        let parsed_cert = ParsedCertificate::try_from(end_entity)?;
        let subject_alternative_names = match cert.subject_alternative_name() {
            Err(err) => {
                return Err(Terror::InvalidCertificate(CertificateError::Other(
                    OtherError(Arc::new(err)),
                )));
            }
            Ok(None) => None,
            Ok(Some(name)) => Some(name.value.general_names.to_owned()),
        };

        if let Some(names) = subject_alternative_names {
            // When a subjectAltName is present it is authoritative; the
            // Common Name is not consulted.
            match verify_server_name(&parsed_cert, server_name) {
                Err(err) => return Err(err),
                Ok(()) => {
                    debug!(
                        r#"*  subjectAltName: host "{}" matched cert's "{}""#,
                        server_name.to_str(),
                        names
                            .iter()
                            .map(|n| n.to_string())
                            .map(|n| n.strip_prefix("DNSName(").unwrap_or(&n).to_owned())
                            .map(|n| n.strip_suffix(")").unwrap_or(&n).to_owned())
                            .collect::<Vec<_>>()
                            .join(",")
                    );
                }
            }
        } else if let Some(attr) = cert
            .subject
            .iter()
            .flat_map(|n| n.iter())
            .find(|a| a.attr_type() == &OID_X509_COMMON_NAME)
            && let Ok(common_name) = attr.as_str()
            // Host names are case-insensitive, so compare them that way.
            && common_name.eq_ignore_ascii_case(&server_name.to_str())
        {
            debug!("*  subjectAltName: (none)");
        } else {
            return Err(Terror::InvalidCertificate(
                CertificateError::NotValidForName,
            ));
        }

        debug!("*  issuer: {}", cert.issuer);
        if cert.subject == cert.issuer {
            debug!("*  Cert is self-signed");
        } else {
            let system_ca = rustls_native_certs::load_native_certs().certs;
            let mut roots = RootCertStore::empty();
            // OS trust stores routinely contain entries webpki cannot parse;
            // skip those instead of panicking on them.
            let (_added, _ignored) = roots.add_parsable_certificates(system_ca);
            if roots.is_empty() {
                // Nothing usable from the OS: fall back to the bundled roots.
                roots.roots = webpki_roots::TLS_SERVER_ROOTS.to_vec();
            }

            let supported_algs = self.0.signature_verification_algorithms.all;
            let is_verified_by_ca = verify_server_cert_signed_by_trust_anchor(
                &parsed_cert,
                &roots,
                intermediates,
                now,
                supported_algs,
            )
            .is_ok();
            if is_verified_by_ca {
                debug!("*  Cert is signed by a trusted central authority");
            } else {
                return Err(Terror::InvalidCertificate(CertificateError::UnknownIssuer));
            }
        }

        Ok(ServerCertVerified::assertion())
    }

    /// Verifies a TLS 1.2 handshake signature with the provider's supported
    /// algorithms, logging the outcome.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, Terror> {
        let supported_schemes = self.0.signature_verification_algorithms;
        match verify_tls12_signature(message, cert, dss, &supported_schemes) {
            Ok(valid) => {
                debug!("* Valid TLSv1.2 signature");
                Ok(valid)
            }
            Err(err) => {
                debug!("* Invalid TLSv1.2 signature!");
                Err(err)
            }
        }
    }

    /// Verifies a TLS 1.3 handshake signature with the provider's supported
    /// algorithms, logging the outcome.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, Terror> {
        let supported_schemes = self.0.signature_verification_algorithms;
        match verify_tls13_signature(message, cert, dss, &supported_schemes) {
            Ok(valid) => {
                debug!("* Valid TLSv1.3 signature");
                Ok(valid)
            }
            Err(err) => {
                debug!("* Invalid TLSv1.3 signature!");
                Err(err)
            }
        }
    }

    /// Lists the signature schemes the underlying crypto provider can verify.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        self.0.signature_verification_algorithms.supported_schemes()
    }
}
/// Errors produced while opening a TLS connection or exchanging data over it.
#[derive(Debug)]
#[non_exhaustive]
pub enum TlsError {
/// The server closed the connection without sending a TLS `close_notify`.
ClosedWithoutNotify,
/// The server sent a handshake message that was not appropriate.
InappropriateHandshakeMessage,
/// The TLS handshake failed; carries the I/O error, the server name and the port.
InitialConnect(std::io::Error, ServerName<'static>, u16),
/// No bytes were received from the server.
NoResponse,
/// Opening the TCP connection failed.
Open(std::io::Error),
/// Reading the response failed.
Receive(std::io::Error),
/// Writing the request failed.
Send(std::io::Error),
/// The server certificate was rejected.
ServerCertificate(rustls::CertificateError),
/// The closing write after the response was read failed.
ShutdownFailed(std::io::Error),
/// An operation timed out; carries the server name, the port and the elapsed time.
TimedOut(ServerName<'static>, u16, Duration),
}
impl core::fmt::Display for TlsError {
    /// Renders a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without payload are fixed messages.
            Self::ClosedWithoutNotify => f.write_str(
                "the server closed the TLS connection without sending close_notify; the response may have been truncated",
            ),
            Self::InappropriateHandshakeMessage => {
                f.write_str("received an inappropriate handshake message from the server")
            }
            Self::NoResponse => {
                f.write_str("the server replied with nothing, or did not reply at all")
            }

            // Variants wrapping a single underlying error.
            Self::Open(err) => write!(f, "failed to open a TCP socket: {err}"),
            Self::Receive(err) => write!(f, "failed to read data from the server: {err}"),
            Self::Send(err) => write!(f, "failed to send data to the server: {err}"),
            Self::ShutdownFailed(err) => {
                write!(f, "failed to shut down the TCP connection: {err}")
            }
            Self::ServerCertificate(err) => {
                write!(f, "could not read or validate the server certificate: {err}")
            }

            // Variants that also name the peer.
            Self::InitialConnect(err, server_name, port) => {
                let host = server_name.to_str();
                write!(f, "failed to connect to {host} on port {port}: {err}")
            }
            Self::TimedOut(server_name, port, duration) => {
                let millis = duration.as_millis();
                let host = server_name.to_str();
                write!(
                    f,
                    "timed out after {millis} ms while opening a socket or awaiting a response from {host} on port {port}"
                )
            }
        }
    }
}
// Marker impl only: `source()` is not overridden, and the `Display` text
// already embeds the message of any wrapped error.
impl core::error::Error for TlsError {}
// Tests for `open_stream` and `send`, run against the in-crate test servers
// from `crate::test_server`, mostly using certificates generated with `rcgen`.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
request::Request,
test_server::{
FRIENDLY, new_simple_runtime, start_friendly_server,
start_unfriendly_server_no_close_notify,
},
};
use rustls::pki_types::DnsName;
// A missing close_notify is an error when truncation is not allowed.
#[test]
fn test_fails_when_close_notify_is_missing() {
let server = start_unfriendly_server_no_close_notify();
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let stream = runtime
.block_on(open_stream(&addr, authority.clone()))
.unwrap();
let req = Request::from_uri_string("gemini://localhost").unwrap();
let err = runtime
.block_on(send(&addr, &authority, stream, req.as_bytes(), false))
.err()
.unwrap();
assert!(matches!(err, TlsError::ClosedWithoutNotify), "{err:?}");
assert_eq!(server.request_count(), 1);
}
// A self-signed certificate whose subjectAltName matches the host is accepted.
#[test]
fn test_validates_cert_with_subject_alt_name() {
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let key = rcgen::generate_simple_self_signed([authority.to_str().into()]).unwrap();
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let stream = runtime
.block_on(open_stream(&addr, authority.clone()))
.unwrap();
let req = Request::from_uri_string("gemini://localhost").unwrap();
let expected = FRIENDLY;
let result = runtime
.block_on(send(&addr, &authority, stream, req.as_bytes(), false))
.unwrap();
assert_eq!(result, expected, "{}", String::from_utf8_lossy(&result));
assert_eq!(server.request_count(), 1);
}
// A matching subjectAltName is enough even when the Common Name differs.
#[test]
fn test_validates_cert_with_subject_alt_name_with_different_common_name() {
let mut cert_params = rcgen::CertificateParams::new(["localhost".into()]).unwrap();
cert_params
.distinguished_name
.push(rcgen::DnType::CommonName, "localhost.local");
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = cert_params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let stream = runtime
.block_on(open_stream(&addr, authority.clone()))
.unwrap();
let req = Request::from_uri_string("gemini://localhost").unwrap();
let expected = FRIENDLY;
let result = runtime
.block_on(send(&addr, &authority, stream, req.as_bytes(), false))
.unwrap();
assert_eq!(result, expected, "{}", String::from_utf8_lossy(&result));
assert_eq!(server.request_count(), 1);
}
// A subjectAltName for a different host is rejected during the handshake.
#[test]
fn test_rejects_cert_with_different_subject_alt_name() {
let key = rcgen::generate_simple_self_signed(["localhost.local".into()]).unwrap();
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(
rustls::CertificateError::NotValidForNameContext {
ref expected,
ref presented,
}) if expected == &authority && presented == &vec![r#"DnsName("localhost.local")"#]),
"{err:?}"
);
}
// When a subjectAltName is present, a matching Common Name does not rescue a mismatch.
#[test]
fn test_rejects_cert_with_different_subject_alt_name_and_matching_common_name() {
let mut cert_params = rcgen::CertificateParams::new(["localhost.local".into()]).unwrap();
cert_params
.distinguished_name
.push(rcgen::DnType::CommonName, "localhost");
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = cert_params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(
rustls::CertificateError::NotValidForNameContext {
ref expected,
ref presented,
}) if expected == &authority && presented == &vec![r#"DnsName("localhost.local")"#]),
"{err:?}"
);
}
// Without a subjectAltName, a Common Name for a different host is rejected.
#[test]
fn test_rejects_cert_with_no_subject_alt_name_and_different_common_name() {
let mut cert_params = rcgen::CertificateParams::new(Vec::new()).unwrap();
cert_params
.distinguished_name
.push(rcgen::DnType::CommonName, "localhost.local");
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = cert_params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(rustls::CertificateError::NotValidForName)
),
"{err:?}"
);
}
// Without a subjectAltName, a matching Common Name is accepted.
#[test]
fn test_validates_cert_with_only_common_name() {
let mut cert_params = rcgen::CertificateParams::new(Vec::new()).unwrap();
cert_params
.distinguished_name
.push(rcgen::DnType::CommonName, "localhost");
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = cert_params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let stream = runtime
.block_on(open_stream(&addr, authority.clone()))
.unwrap();
let req = Request::from_uri_string("gemini://localhost").unwrap();
let expected = FRIENDLY;
let result = runtime
.block_on(send(&addr, &authority, stream, req.as_bytes(), false))
.unwrap();
assert_eq!(result, expected, "{}", String::from_utf8_lossy(&result));
assert_eq!(server.request_count(), 1);
}
// The Common Name is found even when other subject attributes precede it.
#[test]
fn test_validates_cert_with_only_common_name_with_other_details() {
let mut cert_params = rcgen::CertificateParams::new(Vec::new()).unwrap();
cert_params
.distinguished_name
.push(rcgen::DnType::OrganizationName, "yeahhhhhh");
cert_params
.distinguished_name
.push(rcgen::DnType::CommonName, "localhost");
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = cert_params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let stream = runtime
.block_on(open_stream(&addr, authority.clone()))
.unwrap();
let req = Request::from_uri_string("gemini://localhost").unwrap();
let expected = FRIENDLY;
let result = runtime
.block_on(send(&addr, &authority, stream, req.as_bytes(), false))
.unwrap();
assert_eq!(result, expected, "{}", String::from_utf8_lossy(&result));
assert_eq!(server.request_count(), 1);
}
// No subjectAltName entries and no Common Name pushed by the test: rejected.
#[test]
fn test_rejects_cert_with_no_name() {
let cert_params = rcgen::CertificateParams::new(Vec::new()).unwrap();
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = cert_params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(rustls::CertificateError::NotValidForName)
),
"{err:?}"
);
}
// A certificate whose validity starts five days from now is rejected.
#[test]
fn test_rejects_future_cert() {
let mut params = rcgen::CertificateParams::new(["localhost".into()]).unwrap();
params.not_before = time::OffsetDateTime::now_utc().saturating_add(time::Duration::days(5)); params.not_after = params.not_before.saturating_add(time::Duration::days(1));
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(rustls::CertificateError::NotValidYet)
),
"{err:?}"
);
}
// A certificate that expired one day ago is rejected.
#[test]
fn test_rejects_expired_cert() {
let mut params = rcgen::CertificateParams::new(["localhost".into()]).unwrap();
params.not_before = time::OffsetDateTime::now_utc().saturating_sub(time::Duration::days(2));
params.not_after = params.not_before.saturating_add(time::Duration::days(1)); let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(rustls::CertificateError::Expired)
),
"{err:?}"
);
}
// A start date before the Unix epoch is fine as long as the certificate has not expired.
#[test]
fn test_validates_old_cert_still_valid() {
let mut params = rcgen::CertificateParams::new(["localhost".into()]).unwrap();
params.not_before =
time::OffsetDateTime::UNIX_EPOCH.saturating_sub(time::Duration::days(100)); params.not_after = time::OffsetDateTime::now_utc().saturating_add(time::Duration::days(1));
let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let stream = runtime
.block_on(open_stream(&addr, authority.clone()))
.unwrap();
let req = Request::from_uri_string("gemini://localhost").unwrap();
let expected = FRIENDLY;
let result = runtime
.block_on(send(&addr, &authority, stream, req.as_bytes(), false))
.unwrap();
assert_eq!(result, expected, "{}", String::from_utf8_lossy(&result));
assert_eq!(server.request_count(), 1);
}
// An expiry date before the Unix epoch is reported as expired.
#[test]
fn test_rejects_ancient_cert() {
let mut params = rcgen::CertificateParams::new(["localhost".into()]).unwrap();
params.not_before =
time::OffsetDateTime::UNIX_EPOCH.saturating_sub(time::Duration::days(100));
params.not_after = params.not_before.saturating_add(time::Duration::days(1)); let signing_key = rcgen::KeyPair::generate().unwrap();
let cert = params.self_signed(&signing_key).unwrap();
let key = rcgen::CertifiedKey { cert, signing_key };
let server = start_friendly_server(key);
let addr = server.addr();
let runtime = new_simple_runtime("test-client-runtime-worker");
let authority = ServerName::DnsName(DnsName::try_from_str("localhost").unwrap());
let err = runtime
.block_on(open_stream(&addr, authority.clone()))
.err()
.unwrap();
assert!(
matches!(
err,
TlsError::ServerCertificate(rustls::CertificateError::Expired)
),
"{err:?}"
);
}
}