use crate::audit::tables::iana_groups::named_group_for_code_point;
use crate::probe::handshake::{build_client_hello, parse_server_response, ServerResponse};
use crate::probe::hrr::is_hrr;
use crate::{CipherSuite, PqcHandshakeResult, ProbeError, TlsVersion};
use rustls::{ClientConfig, RootCertStore};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_rustls::TlsConnector;
/// Tunables shared by the TLS probes in this module.
#[derive(Debug, Clone)]
pub struct ProbeConfig {
    /// Timeout, in milliseconds, applied to probe network operations (default: 5000).
    pub timeout_ms: u64,
    /// SNI to present instead of the target host name, when set.
    /// NOTE(review): `pqc_probe` also accepts an explicit `sni_override` argument — confirm
    /// which source callers rely on.
    pub sni_override: Option<String>,
}

impl Default for ProbeConfig {
    /// A five-second timeout and no SNI override.
    fn default() -> Self {
        let timeout_ms = 5_000;
        ProbeConfig {
            timeout_ms,
            sni_override: Option::None,
        }
    }
}
/// Cipher suites offered in the raw ClientHello, in the order they are sent
/// (IANA TLS Cipher Suite code points).
const DEFAULT_CIPHER_SUITES: &[u16] = &[
    0x1302, // TLS_AES_256_GCM_SHA384
    0x1301, // TLS_AES_128_GCM_SHA256
    0x1303, // TLS_CHACHA20_POLY1305_SHA256
    0xC02C, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xC02B, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xC030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xC02F, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xCCA8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xCCA9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
];

/// Named groups offered in the raw ClientHello, in the order they are sent
/// (IANA TLS Supported Groups code points). The post-quantum and hybrid groups
/// come before the classical ones.
const DEFAULT_NAMED_GROUPS: &[u16] = &[
    0x11EC, // X25519MLKEM768
    0x11EB, // SecP256r1MLKEM768
    0x11ED, // SecP384r1MLKEM1024
    0x0201, // MLKEM768
    0x0202, // MLKEM1024
    0x001D, // x25519
    0x0017, // secp256r1
];

/// Byte offset of the 32-byte ServerHello random in a buffer that begins with the
/// TLS record: record header (5) + handshake type (1) + handshake length (3)
/// + legacy_version (2).
const SERVER_HELLO_RANDOM_OFFSET: usize = 5 + 1 + 3 + 2;
/// Exclusive end offset of the 32-byte ServerHello random.
const SERVER_HELLO_RANDOM_END: usize = SERVER_HELLO_RANDOM_OFFSET + 32;
async fn probe_raw_group(
host: &str,
port: u16,
sni: &str,
timeout_ms: u64,
) -> Result<Option<(u16, bool)>, ProbeError> {
let hello = build_client_hello(sni, DEFAULT_CIPHER_SUITES, DEFAULT_NAMED_GROUPS, 0x0304);
let mut stream = crate::probe::tcp_connect(host, port, timeout_ms)
.await
.map_err(|e| {
if e.kind() == std::io::ErrorKind::TimedOut {
ProbeError::Timeout {
after_ms: timeout_ms,
}
} else {
ProbeError::ConnectionRefused {
host: host.into(),
port,
}
}
})?;
stream
.write_all(&hello)
.await
.map_err(|e| ProbeError::TlsHandshakeFailed {
reason: e.to_string(),
})?;
let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
let mut buf = Vec::with_capacity(4096);
loop {
let need = if buf.len() >= 5 {
5 + u16::from_be_bytes([buf[3], buf[4]]) as usize
} else {
5
};
if buf.len() >= need {
break;
}
let remaining = deadline
.checked_duration_since(tokio::time::Instant::now())
.ok_or(ProbeError::Timeout {
after_ms: timeout_ms,
})?;
let mut chunk = [0u8; 4096];
match tokio::time::timeout(remaining, stream.read(&mut chunk)).await {
Ok(Ok(0)) => {
return Err(ProbeError::TlsHandshakeFailed {
reason: "connection closed before ServerHello".into(),
})
}
Ok(Ok(n)) => buf.extend_from_slice(&chunk[..n]),
Ok(Err(e)) => {
return Err(ProbeError::TlsHandshakeFailed {
reason: e.to_string(),
})
}
Err(_) => {
return Err(ProbeError::Timeout {
after_ms: timeout_ms,
})
}
}
}
let response =
parse_server_response(&buf).map_err(|e| ProbeError::TlsHandshakeFailed { reason: e })?;
match response {
ServerResponse::ServerHello { selected_group, .. } => {
let hrr = if buf.len() >= SERVER_HELLO_RANDOM_END {
is_hrr(&buf[SERVER_HELLO_RANDOM_OFFSET..SERVER_HELLO_RANDOM_END])
} else {
false
};
match selected_group {
Some(group_code) => Ok(Some((group_code, hrr))),
None => Ok(None),
}
}
ServerResponse::HandshakeFailure => Err(ProbeError::TlsHandshakeFailed {
reason: "server rejected all offered cipher suites".into(),
}),
ServerResponse::ConnectionClose => Err(ProbeError::TlsHandshakeFailed {
reason: "server closed connection during handshake".into(),
}),
ServerResponse::Timeout => Err(ProbeError::Timeout {
after_ms: timeout_ms,
}),
}
}
/// Probes `host:port` for post-quantum key-exchange support, then completes a verified TLS
/// handshake to collect the negotiated session parameters.
///
/// The probe runs in two phases, each over its own TCP connection:
///
/// 1. [`probe_raw_group`] sends a hand-built ClientHello offering [`DEFAULT_NAMED_GROUPS`].
///    It records which group the server selects and whether the reply was a
///    HelloRetryRequest.
/// 2. A rustls client (aws-lc-rs provider, `webpki_roots` trust anchors) performs a full
///    handshake to obtain the protocol version, cipher suite and certificate chain. If
///    phase 1 reported no selected group, rustls is restricted to TLS 1.2.
///
/// `negotiated_group` and `hrr_required` in the result come from phase 1. All other fields
/// come from the rustls session in phase 2. NOTE(review): the two are not cross-checked, so
/// the group rustls actually used may differ from `negotiated_group`.
///
/// # Arguments
/// * `host` - host to connect to; also the last-resort SNI value.
/// * `port` - TCP port.
/// * `sni_override` - SNI to send. Takes precedence over `config.sni_override`.
/// * `config` - timeout (applied separately to each connect, read and handshake step) and
///   an optional default SNI.
///
/// # Errors
/// * [`ProbeError::Timeout`] when a connect, read or handshake step exceeds
///   `config.timeout_ms`.
/// * [`ProbeError::ConnectionRefused`] when a TCP connect fails for a non-timeout reason.
/// * [`ProbeError::TlsHandshakeFailed`] in these cases:
///   - the SNI is not a valid rustls `ServerName`;
///   - the server rejects the raw ClientHello;
///   - the rustls handshake fails, including certificate verification;
///   - no cipher suite is negotiated.
pub async fn pqc_probe(
    host: &str,
    port: u16,
    sni_override: Option<&str>,
    config: &ProbeConfig,
) -> Result<PqcHandshakeResult, ProbeError> {
    // SNI precedence: explicit argument, then `config.sni_override`, then the connect host.
    // (Previously `config.sni_override` was silently ignored.)
    let sni = sni_override
        .or(config.sni_override.as_deref())
        .unwrap_or(host);
    let timeout_ms = config.timeout_ms;

    // Validate the SNI up front, so an unusable name fails before any socket is opened.
    let server_name = rustls::pki_types::ServerName::try_from(sni.to_string()).map_err(|e| {
        ProbeError::TlsHandshakeFailed {
            reason: e.to_string(),
        }
    })?;

    // Phase 1: raw ClientHello to learn the server's group choice and HRR behaviour.
    let raw = probe_raw_group(host, port, sni, timeout_ms).await?;
    // Code point 0 stands in for "no group selected" and is passed to the IANA lookup as-is.
    let (group_code, hrr_required) = raw.unwrap_or((0u16, false));

    // Phase 2: full rustls handshake. With no group selected in phase 1, rustls is limited
    // to TLS 1.2 (presumably that case means a TLS 1.2 ServerHello — confirm against
    // `parse_server_response`).
    let versions: &[&rustls::SupportedProtocolVersion] = if raw.is_none() {
        &[&rustls::version::TLS12]
    } else {
        &[&rustls::version::TLS13, &rustls::version::TLS12]
    };
    let root_store = RootCertStore {
        roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
    };
    let tls_config =
        ClientConfig::builder_with_provider(rustls::crypto::aws_lc_rs::default_provider().into())
            .with_protocol_versions(versions)
            .map_err(|e| ProbeError::TlsHandshakeFailed {
                reason: e.to_string(),
            })?
            .with_root_certificates(root_store)
            .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(tls_config));

    let stream = crate::probe::tcp_connect(host, port, timeout_ms)
        .await
        .map_err(|e| {
            if e.kind() == std::io::ErrorKind::TimedOut {
                ProbeError::Timeout {
                    after_ms: timeout_ms,
                }
            } else {
                ProbeError::ConnectionRefused {
                    host: host.into(),
                    port,
                }
            }
        })?;
    let tls_stream = tokio::time::timeout(
        tokio::time::Duration::from_millis(timeout_ms),
        connector.connect(server_name, stream),
    )
    .await
    .map_err(|_| ProbeError::Timeout {
        after_ms: timeout_ms,
    })?
    .map_err(|e| ProbeError::TlsHandshakeFailed {
        reason: e.to_string(),
    })?;

    let (_, session) = tls_stream.get_ref();
    let suite = session
        .negotiated_cipher_suite()
        .ok_or_else(|| ProbeError::TlsHandshakeFailed {
            reason: "no cipher suite negotiated".into(),
        })?;
    let negotiated_suite = CipherSuite {
        id: u16::from(suite.suite()),
        name: format!("{:?}", suite.suite()),
    };
    let negotiated_version = match session.protocol_version() {
        Some(rustls::ProtocolVersion::TLSv1_3) => TlsVersion::Tls13,
        Some(rustls::ProtocolVersion::TLSv1_2) => TlsVersion::Tls12,
        Some(rustls::ProtocolVersion::TLSv1_1) => TlsVersion::Tls11,
        Some(rustls::ProtocolVersion::TLSv1_0) => TlsVersion::Tls10,
        Some(other) => TlsVersion::Unknown(u16::from(other)),
        None => TlsVersion::Unknown(0),
    };
    // Certificates are copied out as raw DER. An absent chain yields an empty list.
    let cert_chain_der = session
        .peer_certificates()
        .unwrap_or_default()
        .iter()
        .map(|c| c.as_ref().to_vec())
        .collect();

    Ok(PqcHandshakeResult {
        negotiated_version,
        negotiated_suite,
        negotiated_group: named_group_for_code_point(group_code),
        hrr_required,
        cert_chain_der,
    })
}