use std::io::{self, Cursor, Read, Write};
use bytes::{Bytes, BytesMut};
use rustls::ClientConnection;
/// Byte-buffer bridge around a rustls [`ClientConnection`].
///
/// The adapter performs no I/O of its own. Callers queue bytes with
/// [`push_encrypted`](Self::push_encrypted) / [`push_plaintext`](Self::push_plaintext),
/// drive the session with [`step`](Self::step), and collect the results with
/// [`take_plaintext`](Self::take_plaintext) /
/// [`take_encrypted_outbound`](Self::take_encrypted_outbound).
pub struct RustlsByteAdapter {
    /// The rustls client connection being driven.
    session: ClientConnection,

    // --- wire side (TLS records) ---
    /// Ciphertext received from the peer that has not yet been fed to rustls.
    inbox_encrypted: BytesMut,
    /// TLS records produced by rustls, waiting for the caller to send them.
    outbox_encrypted: BytesMut,

    // --- application side (plaintext) ---
    /// Decrypted application data waiting for the caller to take it.
    inbox_plaintext: BytesMut,
    /// Application data from the caller that has not yet been handed to rustls.
    outbox_plaintext: BytesMut,
}
impl std::fmt::Debug for RustlsByteAdapter {
    /// Formats a summary of the adapter: the length of each staging buffer and
    /// the session's handshake / read / write flags.
    ///
    /// Buffer contents are never printed. Presumably this keeps payload bytes,
    /// which may be sensitive, out of logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffer_lens = [
            ("inbox_encrypted_len", self.inbox_encrypted.len()),
            ("inbox_plaintext_len", self.inbox_plaintext.len()),
            ("outbox_plaintext_len", self.outbox_plaintext.len()),
            ("outbox_encrypted_len", self.outbox_encrypted.len()),
        ];
        let session_flags = [
            ("is_handshaking", self.session.is_handshaking()),
            ("wants_read", self.session.wants_read()),
            ("wants_write", self.session.wants_write()),
        ];

        let mut builder = f.debug_struct("RustlsByteAdapter");
        for &(name, len) in &buffer_lens {
            builder.field(name, &len);
        }
        for &(name, flag) in &session_flags {
            builder.field(name, &flag);
        }
        builder.finish()
    }
}
impl RustlsByteAdapter {
#[must_use]
pub fn new(session: ClientConnection) -> Self {
Self {
session,
inbox_encrypted: BytesMut::with_capacity(16 * 1024),
inbox_plaintext: BytesMut::with_capacity(16 * 1024),
outbox_plaintext: BytesMut::with_capacity(16 * 1024),
outbox_encrypted: BytesMut::with_capacity(16 * 1024),
}
}
#[must_use]
pub fn is_handshaking(&self) -> bool {
self.session.is_handshaking()
}
pub fn push_encrypted(&mut self, bytes: &[u8]) {
self.inbox_encrypted.extend_from_slice(bytes);
}
pub fn push_plaintext(&mut self, bytes: &[u8]) {
self.outbox_plaintext.extend_from_slice(bytes);
}
pub fn step(&mut self) -> Result<(), rustls::Error> {
if !self.inbox_encrypted.is_empty() {
let mut cursor = Cursor::new(self.inbox_encrypted.as_ref());
let consumed = self
.session
.read_tls(&mut cursor)
.map_err(|err| rustls::Error::General(format!("read_tls: {err}")))?;
let _ = self.inbox_encrypted.split_to(consumed);
let _state = self.session.process_new_packets()?;
}
let mut buf = [0u8; 8192];
loop {
match self.session.reader().read(&mut buf) {
Ok(0) => break,
Ok(n) => self.inbox_plaintext.extend_from_slice(&buf[..n]),
Err(err) if err.kind() == io::ErrorKind::WouldBlock => break,
Err(_) => break,
}
}
if !self.outbox_plaintext.is_empty() {
let written = self
.session
.writer()
.write(self.outbox_plaintext.as_ref())
.unwrap_or(0);
let _ = self.outbox_plaintext.split_to(written);
}
let mut sink = Vec::with_capacity(8192);
let _ = self
.session
.write_tls(&mut sink)
.map_err(|err| rustls::Error::General(format!("write_tls: {err}")))?;
self.outbox_encrypted.extend_from_slice(&sink);
Ok(())
}
#[must_use]
pub fn take_plaintext(&mut self) -> Bytes {
self.inbox_plaintext.split().freeze()
}
#[must_use]
pub fn take_encrypted_outbound(&mut self) -> Bytes {
self.outbox_encrypted.split().freeze()
}
}
#[cfg(test)]
mod tests {
    use super::RustlsByteAdapter;

    /// Builds a client session for `example.com` using the crate's active
    /// crypto provider.
    ///
    /// The root store is empty; these tests never go beyond the first flight
    /// of the handshake.
    fn make_session() -> rustls::ClientConnection {
        crate::tls_crypto::install_default_provider();
        let client_config = rustls::ClientConfig::builder_with_provider(
            crate::tls_crypto::active_provider(),
        )
        .with_safe_default_protocol_versions()
        .expect("rustls default protocol versions are valid")
        .with_root_certificates(rustls::RootCertStore::empty())
        .with_no_client_auth();
        let server_name = rustls::pki_types::ServerName::try_from("example.com").unwrap();
        rustls::ClientConnection::new(std::sync::Arc::new(client_config), server_name)
            .expect("rustls client session")
    }

    /// Wraps a freshly built session in an adapter.
    fn fresh_adapter() -> RustlsByteAdapter {
        RustlsByteAdapter::new(make_session())
    }

    #[test]
    fn adapter_compiles_and_starts_handshaking() {
        let mut adapter = fresh_adapter();
        assert!(adapter.is_handshaking());

        // Pushing nothing must not disturb the session; step() should still
        // produce the client's first flight.
        adapter.push_encrypted(&[]);
        adapter.step().unwrap();

        let client_hello = adapter.take_encrypted_outbound();
        assert!(
            !client_hello.is_empty(),
            "client should have produced ClientHello bytes"
        );
    }

    #[test]
    fn plaintext_push_and_take_round_trip() {
        let mut adapter = fresh_adapter();
        adapter.push_plaintext(b"hello");
        adapter.step().unwrap();

        let received = adapter.take_plaintext();
        assert!(
            received.is_empty(),
            "no decrypted plaintext should appear pre-handshake"
        );
    }

    #[test]
    fn adapter_step_propagates_decrypt_error() {
        let mut adapter = fresh_adapter();
        adapter.step().unwrap();
        let _ = adapter.take_encrypted_outbound();

        // Record header: content type 0x17 (application_data), legacy version
        // 0x0303, length 5 — followed by five bytes of garbage.
        let bogus: [u8; 10] = [0x17, 0x03, 0x03, 0x00, 0x05, 0xff, 0xff, 0xff, 0xff, 0xff];
        adapter.push_encrypted(&bogus);

        let outcome = adapter.step();
        assert!(
            outcome.is_err(),
            "rustls must reject the bogus record, got {outcome:?}"
        );
    }
}