use std::time::Instant;
use crate::error::Result;
use crate::io::Machine;
/// Buffer-oriented interface a TLS implementation exposes so that `TlsClient`
/// can run a plaintext `Machine` on top of it.
///
/// Every method works on in-memory slices and vectors; moving bytes to and from
/// the transport is left to the caller.
pub(crate) trait TlsEngine {
    /// `true` while the handshake has not finished. `TlsClient` exchanges no
    /// application data with its inner machine until this turns `false`.
    fn is_handshaking(&self) -> bool;

    /// Hands bytes received from the peer to the engine.
    fn feed_incoming(&mut self, ciphertext: &[u8]) -> Result<()>;

    /// Copies decrypted application data into `dst` and returns how many bytes
    /// were written; `TlsClient` treats `0` as "nothing more available now".
    fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize>;

    /// Queues application data for encryption. There is no way to report a
    /// failure through this method.
    fn write_plaintext(&mut self, plaintext: &[u8]);

    /// Appends every record the engine currently wants to send to `out`.
    fn drain_outgoing(&mut self, out: &mut Vec<u8>);
}
/// Runs a plaintext protocol machine `M` on top of a TLS engine `E`.
///
/// Wire bytes are fed into the engine. Once the handshake is complete, the
/// plaintext it yields is forwarded to `inner`, and `inner`'s outgoing bytes are
/// handed to the engine for encryption.
pub(crate) struct TlsClient<E, M> {
    /// The TLS engine that owns record encryption/decryption.
    tls: E,
    /// The wrapped protocol machine; it only ever sees plaintext.
    inner: M,
    /// Reusable buffer decrypted plaintext is read into.
    scratch: Vec<u8>,
    /// Decrypted plaintext that `inner` has not consumed yet. It is offered
    /// again, ahead of newer bytes, instead of being dropped.
    pending: Vec<u8>,
}

impl<E: TlsEngine, M: Machine> TlsClient<E, M> {
    /// Wraps `inner` so that it talks through `tls`.
    pub(crate) fn new(tls: E, inner: M) -> TlsClient<E, M> {
        TlsClient {
            tls,
            inner,
            scratch: vec![0u8; 16 * 1024],
            pending: Vec::new(),
        }
    }

    /// Moves all plaintext the engine currently has into `inner`.
    ///
    /// # Errors
    /// Propagates errors from the engine's `read_plaintext` and from
    /// `inner.handle_input`.
    fn pump_inner_input(&mut self) -> Result<()> {
        loop {
            let n = self.tls.read_plaintext(&mut self.scratch)?;
            if n == 0 {
                return Ok(());
            }
            self.pending.extend_from_slice(&self.scratch[..n]);
            self.deliver_pending()?;
        }
    }

    /// Offers `pending` to `inner` until it is empty or `inner` stops consuming.
    ///
    /// `Machine::handle_input` returns how many bytes it consumed. Bytes it left
    /// behind stay in `pending` so they are presented again with later data.
    fn deliver_pending(&mut self) -> Result<()> {
        while !self.pending.is_empty() {
            // Clamp so a misreported count can never make `drain` panic.
            let used = self
                .inner
                .handle_input(&self.pending)?
                .min(self.pending.len());
            if used == 0 {
                if self.inner.is_finished() {
                    // A finished machine will not consume these bytes; discard
                    // them so trailing data cannot grow `pending` without bound.
                    self.pending.clear();
                }
                break;
            }
            self.pending.drain(..used);
        }
        Ok(())
    }
}
impl<E: TlsEngine, M: Machine> Machine for TlsClient<E, M> {
    type Event = M::Event;

    /// Feeds wire bytes to the TLS engine. Once the handshake has completed, the
    /// decrypted plaintext is forwarded to the inner machine.
    ///
    /// Always reports the whole of `wire` as consumed.
    ///
    /// # Errors
    /// Propagates engine errors and errors from the inner machine.
    fn handle_input(&mut self, wire: &[u8]) -> Result<usize> {
        // rustls' `read_tls` returns an error once its buffer of received but
        // unread plaintext is full. Handing it a large burst in one call can
        // therefore fail mid-stream. Feeding bounded slices and draining the
        // plaintext into `inner` after each one keeps that buffer from filling.
        const FEED_SLICE: usize = 1024;

        // `loop` rather than `chunks` so an empty `wire` still results in one
        // `feed_incoming` call.
        let mut rest = wire;
        loop {
            let (slice, tail) = rest.split_at(rest.len().min(FEED_SLICE));
            self.tls.feed_incoming(slice)?;
            if !self.tls.is_handshaking() {
                self.pump_inner_input()?;
            }
            if tail.is_empty() {
                return Ok(wire.len());
            }
            rest = tail;
        }
    }

    /// Flushes any plaintext still buffered in the engine (post-handshake only),
    /// then signals EOF to the inner machine.
    fn handle_eof(&mut self) -> Result<()> {
        if !self.tls.is_handshaking() {
            self.pump_inner_input()?;
        }
        self.inner.handle_eof()
    }

    /// Collects the inner machine's output (once the handshake is done), hands it
    /// to the engine, and appends every pending TLS record to `out`.
    ///
    /// Returns `true` if anything was appended to `out`.
    fn poll_transmit(&mut self, out: &mut Vec<u8>) -> bool {
        if !self.tls.is_handshaking() {
            let mut plaintext = Vec::new();
            while self.inner.poll_transmit(&mut plaintext) {}
            if !plaintext.is_empty() {
                self.tls.write_plaintext(&plaintext);
            }
        }
        let before = out.len();
        self.tls.drain_outgoing(out);
        out.len() > before
    }

    fn poll_event(&mut self) -> Option<M::Event> {
        self.inner.poll_event()
    }

    fn handle_timeout(&mut self, now: Instant) {
        self.inner.handle_timeout(now);
    }

    fn next_timeout(&self) -> Option<Instant> {
        self.inner.next_timeout()
    }

    fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }
}
#[cfg(feature = "rustls-tls")]
/// `TlsEngine` backed by a rustls `ClientConnection`.
///
/// The connection is a `pub(crate)` tuple field, so callers construct the engine
/// directly as `RustlsEngine(conn)`.
pub(crate) struct RustlsEngine(pub(crate) rustls::ClientConnection);
#[cfg(feature = "rustls-tls")]
impl TlsEngine for RustlsEngine {
    fn is_handshaking(&self) -> bool {
        self.0.is_handshaking()
    }

    /// Pushes `ciphertext` through rustls' record layer, processing newly
    /// accepted records after every `read_tls` call.
    ///
    /// # Errors
    /// I/O errors from `read_tls` are returned as-is. Errors from
    /// `process_new_packets` are wrapped in an I/O error whose message starts
    /// with `tls: `.
    fn feed_incoming(&mut self, mut ciphertext: &[u8]) -> Result<()> {
        while !ciphertext.is_empty() {
            // `read_tls` advances `ciphertext` past the bytes it accepted.
            let used = self
                .0
                .read_tls(&mut ciphertext)
                .map_err(crate::error::Error::Io)?;
            if used == 0 {
                // rustls accepted nothing; stop instead of spinning.
                break;
            }
            self.0
                .process_new_packets()
                .map_err(|e| crate::error::Error::Io(std::io::Error::other(format!("tls: {e}"))))?;
        }
        Ok(())
    }

    /// Appends every TLS record rustls wants to send to `out`.
    fn drain_outgoing(&mut self, out: &mut Vec<u8>) {
        while self.0.wants_write() {
            // Writing into a `Vec` should always make progress; bail out on an
            // error or a zero-length write rather than looping forever.
            match self.0.write_tls(out) {
                Ok(0) | Err(_) => break,
                Ok(_) => {}
            }
        }
    }

    /// Reads decrypted application data into `dst`; "no data buffered yet"
    /// (`WouldBlock`) is reported as `Ok(0)`.
    fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize> {
        use std::io::Read;
        match self.0.reader().read(dst) {
            Ok(n) => Ok(n),
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
            Err(e) => Err(crate::error::Error::Io(e)),
        }
    }

    /// Queues `plaintext` for encryption.
    fn write_plaintext(&mut self, plaintext: &[u8]) {
        use std::io::Write;
        // This trait method cannot report back-pressure, and rustls' writer
        // stops accepting data once its outgoing buffer limit is reached
        // (`write_all` then fails and the rest of `plaintext` would be lost).
        // Lift the limit so every byte is queued; the caller empties the buffer
        // through `drain_outgoing`.
        self.0.set_buffer_limit(None);
        // With no limit the write is not expected to fail, and this method has
        // no channel to report an error anyway.
        let _ = self.0.writer().write_all(plaintext);
    }
}
#[cfg(feature = "purecrypto-tls")]
/// `TlsEngine` backed by a purecrypto TLS client connection.
pub(crate) struct PurecryptoEngine {
    /// The underlying purecrypto connection.
    conn: purecrypto::tls::Connection,
    /// Latched to `true` once `conn.handshake()` reports `Complete`; never reset.
    done: bool,
    /// Ciphertext received from the peer that `conn.feed` has not accepted yet.
    pending_wire: Vec<u8>,
    /// Plaintext returned by `conn.recv()` that has not been copied out yet.
    plaintext: Vec<u8>,
}
#[cfg(feature = "purecrypto-tls")]
/// Wraps any purecrypto error as the crate's I/O error. The message is the
/// error's `Debug` rendering prefixed with `tls: `.
fn pc_err(e: impl std::fmt::Debug) -> crate::error::Error {
    let message = format!("tls: {e:?}");
    crate::error::Error::Io(std::io::Error::other(message))
}
#[cfg(feature = "purecrypto-tls")]
impl PurecryptoEngine {
    /// Wraps `conn` with empty buffers and calls `conn.handshake()` once right
    /// away, so that `done` reflects the connection's initial status.
    ///
    /// # Errors
    /// Returns the error from `conn.handshake()`, converted via `pc_err`.
    pub(crate) fn new(conn: purecrypto::tls::Connection) -> Result<PurecryptoEngine> {
        let mut engine = PurecryptoEngine {
            conn,
            pending_wire: Vec::new(),
            plaintext: Vec::new(),
            done: false,
        };
        engine.refresh_done()?;
        Ok(engine)
    }

    /// Calls `conn.handshake()` and latches `done` once it reports `Complete`.
    /// A later non-complete status never clears the flag.
    fn refresh_done(&mut self) -> Result<()> {
        let status = self.conn.handshake().map_err(pc_err)?;
        self.done |= matches!(status, purecrypto::tls::HandshakeStatus::Complete);
        Ok(())
    }
}
#[cfg(feature = "purecrypto-tls")]
impl TlsEngine for PurecryptoEngine {
    fn is_handshaking(&self) -> bool {
        !self.done
    }

    /// Appends `ciphertext` to the unaccepted backlog and feeds the backlog to
    /// the connection until it is empty or `feed` accepts nothing. The accepted
    /// prefix is then removed and the handshake status refreshed.
    fn feed_incoming(&mut self, ciphertext: &[u8]) -> Result<()> {
        self.pending_wire.extend_from_slice(ciphertext);
        let mut consumed = 0;
        while let Some(rest) = self.pending_wire.get(consumed..) {
            if rest.is_empty() {
                break;
            }
            let accepted = self.conn.feed(rest).map_err(pc_err)?;
            if accepted == 0 {
                break;
            }
            consumed += accepted;
        }
        self.pending_wire.drain(..consumed);
        self.refresh_done()
    }

    /// Appends records popped from the connection until `pop` yields an empty
    /// record or an error. Errors end the drain and are not reported.
    fn drain_outgoing(&mut self, out: &mut Vec<u8>) {
        while let Ok(record) = self.conn.pop() {
            if record.is_empty() {
                break;
            }
            out.extend_from_slice(&record);
        }
    }

    /// Refills the local plaintext buffer from `conn.recv()` when it is empty,
    /// then copies as much as fits into `dst`.
    fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize> {
        if self.plaintext.is_empty() {
            let fresh = self.conn.recv().map_err(pc_err)?;
            self.plaintext.extend_from_slice(&fresh);
        }
        let count = self.plaintext.len().min(dst.len());
        for (slot, byte) in dst.iter_mut().zip(self.plaintext.drain(..count)) {
            *slot = byte;
        }
        Ok(count)
    }

    /// Hands `plaintext` to the connection; the trait has no way to surface a
    /// send failure, so the result is discarded.
    fn write_plaintext(&mut self, plaintext: &[u8]) {
        let _ = self.conn.send(plaintext);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::http1::{ClientExchange, Event};

    /// Scripted fake TLS engine for exercising `TlsClient` without real crypto.
    ///
    /// Its "client hello" is the literal `CHLO`, emitted once. The server's reply
    /// is any chunk starting with `SHLO`; bytes after that prefix count as
    /// application data. Once "established", bytes pass through untouched in
    /// both directions.
    #[derive(Default)]
    struct MockTls {
        sent_hello: bool,
        established: bool,
        outbox: Vec<u8>,
        plaintext_in: Vec<u8>,
    }

    impl TlsEngine for MockTls {
        fn is_handshaking(&self) -> bool {
            !self.established
        }

        fn feed_incoming(&mut self, ciphertext: &[u8]) -> Result<()> {
            let payload = if self.established {
                ciphertext
            } else if let Some(after_hello) = ciphertext.strip_prefix(b"SHLO") {
                self.established = true;
                after_hello
            } else {
                // During the handshake, anything but a server hello is ignored.
                return Ok(());
            };
            self.plaintext_in.extend_from_slice(payload);
            Ok(())
        }

        fn drain_outgoing(&mut self, out: &mut Vec<u8>) {
            let already_sent = std::mem::replace(&mut self.sent_hello, true);
            if !already_sent {
                out.extend_from_slice(b"CHLO");
            }
            out.append(&mut self.outbox);
        }

        fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize> {
            let count = dst.len().min(self.plaintext_in.len());
            for (slot, byte) in dst.iter_mut().zip(self.plaintext_in.drain(..count)) {
                *slot = byte;
            }
            Ok(count)
        }

        fn write_plaintext(&mut self, plaintext: &[u8]) {
            self.outbox.extend_from_slice(plaintext);
        }
    }

    /// The encoded `GET /` request every test expects the client to send.
    fn request() -> Vec<u8> {
        ClientExchange::encode_request("GET", "/", &[("Host".into(), "x".into())], b"")
    }

    #[test]
    fn layered_handshake_then_request_then_response() {
        let mut tls = TlsClient::new(MockTls::default(), ClientExchange::new("GET", request()));

        // Before the handshake only the hello goes out.
        let mut wire = Vec::new();
        assert!(tls.poll_transmit(&mut wire));
        assert_eq!(wire, b"CHLO");

        // After the server hello the inner request is released.
        tls.handle_input(b"SHLO").unwrap();
        wire.clear();
        assert!(tls.poll_transmit(&mut wire));
        assert_eq!(wire, request());

        let reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nhi";
        tls.handle_input(reply).unwrap();
        let Event::Response { head, body } = tls.poll_event().expect("response");
        assert_eq!(head.status, 200);
        assert_eq!(body, b"hi");
        assert!(tls.is_finished());
    }

    #[test]
    fn handshake_reply_may_carry_app_data() {
        let mut tls = TlsClient::new(MockTls::default(), ClientExchange::new("GET", request()));
        let mut hello = Vec::new();
        tls.poll_transmit(&mut hello);

        tls.handle_input(b"SHLOHTTP/1.1 204 No Content\r\n\r\n")
            .unwrap();
        let Event::Response { head, .. } = tls.poll_event().expect("response");
        assert_eq!(head.status, 204);
    }

    #[test]
    fn drives_through_the_real_blocking_driver() {
        use std::io::{Read, Write};
        use std::net::{TcpListener, TcpStream};
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // Fake server: expect `CHLO`, answer `SHLO`, read one request head, reply.
        thread::spawn(move || {
            let Ok((mut sock, _)) = listener.accept() else {
                return;
            };
            let mut hello = [0u8; 4];
            let hello_ok = sock.read_exact(&mut hello).is_ok() && &hello == b"CHLO";
            if !hello_ok {
                return;
            }
            let _ = sock.write_all(b"SHLO");

            let mut request_head: Vec<u8> = Vec::new();
            let mut byte = [0u8; 1];
            while !request_head.ends_with(b"\r\n\r\n") {
                match sock.read(&mut byte) {
                    Ok(1) => request_head.push(byte[0]),
                    _ => break,
                }
            }
            let _ = sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello");
        });

        let mut sock = TcpStream::connect(("127.0.0.1", port)).unwrap();
        let mut tls = TlsClient::new(MockTls::default(), ClientExchange::new("GET", request()));
        let events = crate::io::blocking::drive(&mut tls, &mut sock).unwrap();
        assert_eq!(events.len(), 1);
        let Event::Response { head, body } = &events[0];
        assert_eq!(head.status, 200);
        assert_eq!(body, b"hello");
    }
}
#[cfg(all(test, feature = "rustls-tls"))]
mod rustls_tests {
    use std::io::{Read, Write};
    use std::sync::Arc;

    use rustls::pki_types::ServerName;
    use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection};

    use super::*;
    use crate::proto::http1::{ClientExchange, Event};

    /// CA certificate the test clients trust; `LEAF_CERT_PEM` chains to it.
    pub(super) const CA_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----
MIIBhzCCAS2gAwIBAgIUEJAJGguFhUu6Wi64F9FYb6oJ9bkwCgYIKoZIzj0EAwIw
GDEWMBQGA1UEAwwNcnN1cmwtdGVzdC1jYTAgFw0yNjA2MjEyMzI2MjFaGA8yMTI2
MDUyODIzMjYyMVowGDEWMBQGA1UEAwwNcnN1cmwtdGVzdC1jYTBZMBMGByqGSM49
AgEGCCqGSM49AwEHA0IABGvezLhNMu/DJw3ClBkhcK571eQz/QctqGAf1whkMiXf
Sj46b9bBymWIV706DP/x2nXzSJgiXTv9rnTli35el0CjUzBRMB0GA1UdDgQWBBQU
AOFhWcYfxuM+R86kRFZWr/KATzAfBgNVHSMEGDAWgBQUAOFhWcYfxuM+R86kRFZW
r/KATzAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBWUfubWKWST
arQvZPn0jqXOwKG0x+xYs5UtcjVf3vOiAiEAlxoTAAh0nVLMrmTsnJXD131iPHz7
Uk3Wt1xw1blCE/8=
-----END CERTIFICATE-----
";

    /// Server certificate presented by `server_config()`.
    const LEAF_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----
MIIBuDCCAV2gAwIBAgIUcMudt8JBWAsDX8h+3CC46SiY14EwCgYIKoZIzj0EAwIw
GDEWMBQGA1UEAwwNcnN1cmwtdGVzdC1jYTAgFw0yNjA2MjEyMzI2MjFaGA8yMTI2
MDUyODIzMjYyMVowFDESMBAGA1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAEuBVdUYNtZqpWDO9h4nw0HF9sTKT3R7p/WJYsNgIfeO4hi/AM
9x+n7MP1tYi6zPlfR6qG/ZbEJLFDzZShfHPc/KOBhjCBgzAUBgNVHREEDTALggls
b2NhbGhvc3QwCQYDVR0TBAIwADALBgNVHQ8EBAMCB4AwEwYDVR0lBAwwCgYIKwYB
BQUHAwEwHQYDVR0OBBYEFAAZvjmK2EXoiEDqFV3wFGMS8GBJMB8GA1UdIwQYMBaA
FBQA4WFZxh/G4z5HzqREVlav8oBPMAoGCCqGSM49BAMCA0kAMEYCIQCPQPF3G07F
EhDmMDPLFGbF/ZdfuDFfBN6Sjs3DuIgSXAIhAMGqymq6vFwXRbvrhbGljFfJQjtz
98VOQz3xfzdRnPC2
-----END CERTIFICATE-----
";

    /// Private key matching `LEAF_CERT_PEM`.
    const LEAF_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg8mp/gpytQtzNMwlE
fXfhylHGgcKzHtmkPeil9MKfoSyhRANCAAS4FV1Rg21mqlYM72HifDQcX2xMpPdH
un9Yliw2Ah947iGL8Az3H6fsw/W1iLrM+V9Hqob9lsQksUPNlKF8c9z8
-----END PRIVATE KEY-----
";

    /// rustls server configuration presenting the leaf certificate, without
    /// client authentication.
    pub(super) fn server_config() -> Arc<ServerConfig> {
        let chain: Vec<_> = rustls_pemfile::certs(&mut LEAF_CERT_PEM.as_bytes())
            .collect::<std::result::Result<_, _>>()
            .unwrap();
        let key = rustls_pemfile::private_key(&mut LEAF_KEY_PEM.as_bytes())
            .unwrap()
            .unwrap();
        let config = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, key)
            .unwrap();
        Arc::new(config)
    }

    /// A rustls client for `localhost` that trusts only the test CA.
    fn client_conn() -> ClientConnection {
        let mut roots = RootCertStore::empty();
        rustls_pemfile::certs(&mut CA_CERT_PEM.as_bytes())
            .for_each(|cert| roots.add(cert.unwrap()).unwrap());
        let config = ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        let name = ServerName::try_from("localhost").unwrap();
        ClientConnection::new(Arc::new(config), name).unwrap()
    }

    /// Shuttles records between `client` and `server` for at most 64 rounds.
    ///
    /// The server answers with `response` once it has read a full request head.
    /// Returns the first event the client produces, or `None` if the budget runs out.
    fn run_exchange<E: TlsEngine>(
        client: &mut TlsClient<E, ClientExchange>,
        server: &mut ServerConnection,
        response: &[u8],
    ) -> Option<<ClientExchange as Machine>::Event> {
        let mut server_req = Vec::new();
        let mut replied = false;
        for _ in 0..64 {
            let mut c2s = Vec::new();
            while client.poll_transmit(&mut c2s) {}

            let mut unread = &c2s[..];
            while !unread.is_empty() {
                if server.read_tls(&mut unread).unwrap() == 0 {
                    break;
                }
                server.process_new_packets().unwrap();
            }

            if !server.is_handshaking() {
                let mut tmp = [0u8; 4096];
                loop {
                    match server.reader().read(&mut tmp) {
                        Ok(0) => break,
                        Ok(n) => server_req.extend_from_slice(&tmp[..n]),
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                        Err(e) => panic!("server read: {e}"),
                    }
                }
                if !replied && server_req.windows(4).any(|w| w == b"\r\n\r\n") {
                    server.writer().write_all(response).unwrap();
                    replied = true;
                }
            }

            let mut s2c = Vec::new();
            while server.wants_write() {
                server.write_tls(&mut s2c).unwrap();
            }
            if !s2c.is_empty() {
                client.handle_input(&s2c).unwrap();
            }
            if let Some(event) = client.poll_event() {
                return Some(event);
            }
        }
        None
    }

    #[test]
    fn real_rustls_handshake_carries_http_exchange() {
        let req =
            ClientExchange::encode_request("GET", "/", &[("Host".into(), "localhost".into())], b"");
        let mut client =
            TlsClient::new(RustlsEngine(client_conn()), ClientExchange::new("GET", req));
        let mut server = ServerConnection::new(server_config()).unwrap();
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nhello rustls";

        let Some(Event::Response { head, body }) = run_exchange(&mut client, &mut server, response)
        else {
            panic!("TLS exchange did not complete within the iteration budget");
        };
        assert_eq!(head.status, 200);
        assert_eq!(body, b"hello rustls");
    }
}
#[cfg(all(test, feature = "purecrypto-tls"))]
mod purecrypto_tests {
    use super::*;

    /// A freshly built purecrypto client engine must already have a ClientHello
    /// record ready to send.
    #[test]
    fn purecrypto_adapter_emits_client_hello() {
        let cfg = purecrypto::tls::Config::builder()
            .roots(purecrypto::tls::RootCertStore::new())
            .server_name("localhost")
            .rng(std::sync::Arc::new(purecrypto::rng::OsRng))
            .build();
        let mut engine =
            PurecryptoEngine::new(purecrypto::tls::Connection::client(&cfg).unwrap()).unwrap();
        assert!(engine.is_handshaking());

        let mut records = Vec::new();
        engine.drain_outgoing(&mut records);
        assert!(records.len() > 5, "expected a ClientHello record");
        // Byte 0 is the record content type; byte 5 is the first handshake byte.
        let (content_type, handshake_type) = (records[0], records[5]);
        assert_eq!(content_type, 0x16, "record content type should be handshake");
        assert_eq!(handshake_type, 0x01, "handshake message type should be ClientHello");
    }
}
#[cfg(all(test, feature = "purecrypto-tls", feature = "rustls-tls"))]
mod cross_backend_tests {
    use std::io::{Read, Write};

    use rustls::ServerConnection;

    use super::*;
    use crate::proto::http1::{ClientExchange, Event};

    /// Shuttles records between `client` and `server` for at most 64 rounds.
    ///
    /// The server answers with `response` once it has read a full request head.
    /// Returns the first event the client produces, or `None` if the budget runs out.
    fn run_exchange<E: TlsEngine>(
        client: &mut TlsClient<E, ClientExchange>,
        server: &mut ServerConnection,
        response: &[u8],
    ) -> Option<<ClientExchange as Machine>::Event> {
        let mut server_req = Vec::new();
        let mut replied = false;
        for _ in 0..64 {
            let mut c2s = Vec::new();
            while client.poll_transmit(&mut c2s) {}

            let mut unread = &c2s[..];
            while !unread.is_empty() {
                if server.read_tls(&mut unread).unwrap() == 0 {
                    break;
                }
                server.process_new_packets().unwrap();
            }

            if !server.is_handshaking() {
                let mut tmp = [0u8; 4096];
                loop {
                    match server.reader().read(&mut tmp) {
                        Ok(0) => break,
                        Ok(n) => server_req.extend_from_slice(&tmp[..n]),
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                        Err(e) => panic!("server read: {e}"),
                    }
                }
                if !replied && server_req.windows(4).any(|w| w == b"\r\n\r\n") {
                    server.writer().write_all(response).unwrap();
                    replied = true;
                }
            }

            let mut s2c = Vec::new();
            while server.wants_write() {
                server.write_tls(&mut s2c).unwrap();
            }
            if !s2c.is_empty() {
                client.handle_input(&s2c).unwrap();
            }
            if let Some(event) = client.poll_event() {
                return Some(event);
            }
        }
        None
    }

    #[test]
    fn purecrypto_client_against_rustls_server() {
        let mut roots = purecrypto::tls::RootCertStore::new();
        roots.add_pem(super::rustls_tests::CA_CERT_PEM).unwrap();
        let cfg = purecrypto::tls::Config::builder()
            .roots(roots)
            .server_name("localhost")
            .rng(std::sync::Arc::new(purecrypto::rng::OsRng))
            .build();
        let client_conn = purecrypto::tls::Connection::client(&cfg).unwrap();
        let req =
            ClientExchange::encode_request("GET", "/", &[("Host".into(), "localhost".into())], b"");
        let mut client = TlsClient::new(
            PurecryptoEngine::new(client_conn).unwrap(),
            ClientExchange::new("GET", req),
        );
        let mut server = ServerConnection::new(super::rustls_tests::server_config()).unwrap();
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\ninterop ok";

        let Some(Event::Response { head, body }) = run_exchange(&mut client, &mut server, response)
        else {
            panic!("cross-backend TLS exchange did not complete");
        };
        assert_eq!(head.status, 200);
        assert_eq!(body, b"interop ok");
    }
}
#[cfg(all(test, feature = "rustls-tls"))]
mod connect_wiring_tests {
    use std::io::{Read, Write};

    use rustls::ServerConnection;

    use super::*;
    use crate::proto::http1::{ClientExchange, Event};

    /// Shuttles records between `client` and `server` for at most 64 rounds.
    ///
    /// The server answers with `response` once it has read a full request head.
    /// Returns the first event the client produces, or `None` if the budget runs out.
    fn run_exchange<E: TlsEngine>(
        client: &mut TlsClient<E, ClientExchange>,
        server: &mut ServerConnection,
        response: &[u8],
    ) -> Option<<ClientExchange as Machine>::Event> {
        let mut server_req = Vec::new();
        let mut replied = false;
        for _ in 0..64 {
            let mut c2s = Vec::new();
            while client.poll_transmit(&mut c2s) {}

            let mut unread = &c2s[..];
            while !unread.is_empty() {
                if server.read_tls(&mut unread).unwrap() == 0 {
                    break;
                }
                server.process_new_packets().unwrap();
            }

            if !server.is_handshaking() {
                let mut tmp = [0u8; 4096];
                loop {
                    match server.reader().read(&mut tmp) {
                        Ok(0) => break,
                        Ok(n) => server_req.extend_from_slice(&tmp[..n]),
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                        Err(e) => panic!("server read: {e}"),
                    }
                }
                if !replied && server_req.windows(4).any(|w| w == b"\r\n\r\n") {
                    server.writer().write_all(response).unwrap();
                    replied = true;
                }
            }

            let mut s2c = Vec::new();
            while server.wants_write() {
                server.write_tls(&mut s2c).unwrap();
            }
            if !s2c.is_empty() {
                client.handle_input(&s2c).unwrap();
            }
            if let Some(event) = client.poll_event() {
                return Some(event);
            }
        }
        None
    }

    /// An engine built through `crate::tls::build_client_engine` (with
    /// verification off) must complete a real handshake and carry an HTTP exchange.
    #[test]
    fn engine_from_build_client_engine_completes_handshake() {
        let mut opts = crate::tls::TlsOpts::verifying();
        opts.verify = false;
        opts.roots = Some(rustls::RootCertStore::empty());
        let engine = crate::tls::build_client_engine("localhost", &mut opts).unwrap();
        let req =
            ClientExchange::encode_request("GET", "/", &[("Host".into(), "localhost".into())], b"");
        let mut client = TlsClient::new(engine, ClientExchange::new("GET", req));
        let mut server = ServerConnection::new(super::rustls_tests::server_config()).unwrap();
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nwired";

        let Some(Event::Response { head, body }) = run_exchange(&mut client, &mut server, response)
        else {
            panic!("handshake/exchange did not complete");
        };
        assert_eq!(head.status, 200);
        assert_eq!(body, b"wired");
    }
}