use std::io::{self, Read, Write};
use purecrypto::tls::{Config, Connection, CrlStore, HandshakeStatus};
use zeroize::Zeroize;
use super::common::ProtocolVersion;
use super::{client_auth, pc_roots};
use crate::error::{Error, Result};
pub use purecrypto::tls::RootCertStore;
/// Size of the stack buffer used for every transport read, on both the
/// handshake path and the application-data read paths.
const READ_CHUNK: usize = 16 * 1024;
/// Client-side TLS options consumed by [`connect_over_tls`] and `build_client_conn`.
///
/// `Default` is the same as [`TlsOpts::verifying`]. Secret key material is
/// zeroized when the value is dropped (see the `Drop` impl).
#[derive(Clone)]
pub struct TlsOpts {
/// ALPN protocol identifiers to offer; when empty, no ALPN list is configured.
pub alpn: Vec<Vec<u8>>,
/// Enables purecrypto's built-in certificate verification. Built-in
/// verification is also turned off whenever `verify_callback` is set.
pub verify: bool,
/// Trust anchors; `None` means the system roots are loaded at connect time.
pub roots: Option<RootCertStore>,
/// Lowest protocol version to allow. Only TLS 1.3 is passed through as-is;
/// every other value is mapped to TLS 1.2.
pub min_version: Option<ProtocolVersion>,
/// Highest protocol version to allow; same mapping as `min_version`.
pub max_version: Option<ProtocolVersion>,
/// Client certificate chain: DER when `cert_is_der`, otherwise PEM. When
/// `client_key` is `None`, the private key is also parsed from this PEM.
pub client_cert: Option<Vec<u8>>,
/// Client private key: DER when `key_is_der`, otherwise PEM.
pub client_key: Option<Vec<u8>>,
/// Forwarded to the key parser (presumably a passphrase for encrypted keys).
pub client_key_pass: Option<String>,
/// Treat `client_cert` as DER instead of PEM.
pub cert_is_der: bool,
/// Treat `client_key` as DER instead of PEM.
pub key_is_der: bool,
/// Pinned server public keys, compared against the server leaf certificate by
/// `client_auth::spki_pin_matches` in `connect_over_tls`. Presumably SHA-256
/// digests of the SPKI, per the field name.
pub pinned_spki_sha256: Vec<[u8; 32]>,
/// CRL data. Despite the name, it accepts either PEM `X509 CRL` blocks or a
/// single DER-encoded CRL.
pub crl_pem: Option<Vec<u8>>,
/// Cipher suite identifiers; when empty, the builder's default is left untouched.
pub cipher_suites: Vec<u16>,
/// Optional application callback that accepts or rejects the server chain
/// after the handshake. When set, it replaces built-in chain verification.
pub verify_callback: Option<super::common::VerifyCallback>,
}
impl TlsOpts {
pub fn verifying() -> Self {
TlsOpts {
alpn: Vec::new(),
verify: true,
roots: None,
min_version: None,
max_version: None,
client_cert: None,
client_key: None,
client_key_pass: None,
cert_is_der: false,
key_is_der: false,
pinned_spki_sha256: Vec::new(),
crl_pem: None,
cipher_suites: Vec::new(),
verify_callback: None,
}
}
}
impl Default for TlsOpts {
fn default() -> Self {
TlsOpts::verifying()
}
}
impl Drop for TlsOpts {
    /// Wipes secret material before the memory is released.
    ///
    /// `client_cert` is wiped as well. When no separate `client_key` is
    /// supplied, `build_identity` parses the private key out of the
    /// certificate PEM, so that buffer may contain the key too.
    fn drop(&mut self) {
        self.client_key.zeroize();
        self.client_key_pass.zeroize();
        self.client_cert.zeroize();
    }
}
/// Translates the crate's protocol version into purecrypto's.
///
/// Only TLS 1.3 maps to itself. Every other value, including older or
/// unrecognised versions, is mapped to TLS 1.2.
fn to_pc_version(v: ProtocolVersion) -> purecrypto::tls::ProtocolVersion {
    use purecrypto::tls::ProtocolVersion as Pc;
    if matches!(v, ProtocolVersion::TLSv1_3) {
        Pc::TLSv1_3
    } else {
        Pc::TLSv1_2
    }
}
/// Loads the system root store by delegating to `pc_roots::load_system_roots`.
pub fn load_system_roots() -> Result<RootCertStore> {
    let store = pc_roots::load_system_roots()?;
    Ok(store)
}
/// Loads trust anchors from the file at `path`. Parsing is delegated to
/// `pc_roots::load_from_file`.
pub fn load_roots_from_file(path: &str) -> Result<RootCertStore> {
    let store = pc_roots::load_from_file(path)?;
    Ok(store)
}
pub fn load_roots_from_dir(base: Option<RootCertStore>, dir: &str) -> Result<RootCertStore> {
let mut roots = match base {
Some(r) => r,
None => load_system_roots()?,
};
pc_roots::add_from_dir(&mut roots, dir)?;
Ok(roots)
}
/// TLS client stream over a `Read + Write` transport `S`; application data is
/// exchanged through its `Read` / `Write` impls.
pub struct TlsStream<S: Read + Write> {
/// The purecrypto TLS connection driven by this stream.
conn: Connection,
/// Underlying transport that carries the TLS records.
sock: S,
/// Decrypted application data not yet handed to the caller.
plaintext: Vec<u8>,
/// Received wire bytes that `conn.feed` has not consumed yet.
pending_wire: Vec<u8>,
/// Set once a transport read returned 0 bytes (EOF).
seen_eof: bool,
}
/// Runs a TLS client handshake over `transport` using [`TlsOpts::verifying`]
/// defaults, with `sni` as the server name.
pub fn connect_over<S: Read + Write>(transport: S, sni: &str) -> Result<TlsStream<S>> {
    let opts = TlsOpts::verifying();
    connect_over_tls(transport, sni, opts)
}
/// Like [`connect_over`], but also offers the ALPN protocols in `alpn`.
pub fn connect_over_with_alpn<S: Read + Write>(
    transport: S,
    sni: &str,
    alpn: &[&[u8]],
) -> Result<TlsStream<S>> {
    let protocols: Vec<Vec<u8>> = alpn.iter().map(|proto| proto.to_vec()).collect();
    let mut opts = TlsOpts::verifying();
    opts.alpn = protocols;
    connect_over_tls(transport, sni, opts)
}
pub(crate) fn build_client_conn(sni: &str, opts: &mut TlsOpts) -> Result<Connection> {
let roots = match opts.roots.take() {
Some(r) => r,
None => load_system_roots()?,
};
let effective_verify = opts.verify && opts.verify_callback.is_none();
let mut builder = Config::builder()
.tls_only()
.roots(roots)
.server_name(sni.to_string())
.verify_certificates(effective_verify)
.rng(std::sync::Arc::new(purecrypto::rng::OsRng));
if !opts.alpn.is_empty() {
builder = builder.alpn(std::mem::take(&mut opts.alpn));
}
if let Some(v) = opts.min_version {
builder = builder.min_version(to_pc_version(v));
}
if let Some(v) = opts.max_version {
builder = builder.max_version(to_pc_version(v));
}
if !opts.cipher_suites.is_empty() {
builder = builder.cipher_suites(&opts.cipher_suites);
}
if let Some(cert_bytes) = &opts.client_cert {
let (chain, key) = build_identity(
cert_bytes,
opts.client_key.as_deref(),
opts.client_key_pass.as_deref(),
opts.cert_is_der,
opts.key_is_der,
)?;
builder = builder.identity(chain, key);
}
if let Some(crl_bytes) = &opts.crl_pem {
let mut store = CrlStore::new();
let blocks = std::str::from_utf8(crl_bytes)
.ok()
.map(crl_pem_blocks)
.unwrap_or_default();
if !blocks.is_empty() {
for block in &blocks {
store
.add_pem(block)
.map_err(|_| Error::BadResponse("--crlfile: invalid PEM CRL block".into()))?;
}
} else {
store
.add_der(crl_bytes.clone())
.map_err(|_| Error::BadResponse("--crlfile: not a valid PEM or DER CRL".into()))?;
}
builder = builder.crls(store);
}
let cfg = builder.build();
Connection::client(&cfg).map_err(tls_err)
}
/// Performs a TLS client handshake over `transport`, then applies the
/// post-handshake policy from `opts`.
///
/// After the handshake completes:
/// * if `opts.verify_callback` is set, it is called with the peer chain, and a
///   `Reject` verdict aborts the connection (built-in chain verification was
///   disabled for this case in `build_client_conn`);
/// * otherwise, when verification is enabled, the leaf certificate must carry a
///   Subject Alternative Name (there is no CN fallback);
/// * in every case, a non-empty `opts.pinned_spki_sha256` must match the leaf.
///
/// # Errors
/// Returns an error when building the connection or the handshake fails, or
/// when any of the checks above rejects the server.
pub fn connect_over_tls<S: Read + Write>(
    transport: S,
    sni: &str,
    mut opts: TlsOpts,
) -> Result<TlsStream<S>> {
    // Same rule as in `build_client_conn`: a callback replaces built-in verification.
    let effective_verify = opts.verify && opts.verify_callback.is_none();
    let conn = build_client_conn(sni, &mut opts)?;
    let mut s = TlsStream {
        conn,
        sock: transport,
        plaintext: Vec::new(),
        pending_wire: Vec::new(),
        seen_eof: false,
    };
    s.run_handshake()?;

    if let Some(cb) = &opts.verify_callback {
        let chain = s.peer_certificates().to_vec();
        let verdict = cb.call(&super::common::CertVerify {
            server_name: sni,
            chain_der: &chain,
        });
        if verdict == super::common::CertVerdict::Reject {
            return Err(Error::BadResponse(
                "server certificate rejected by verify callback".into(),
            ));
        }
        // Deliberately no early return. An accepting callback stands in for
        // chain verification only; public-key pinning is an independent
        // check and must still be enforced below.
    }

    if effective_verify {
        if let Some(der) = s.peer_certificates().first() {
            if !client_auth::leaf_has_san(der) {
                return Err(Error::BadResponse(
                    "server certificate has no Subject Alternative Name \
                     (CN fallback is not accepted)"
                        .into(),
                ));
            }
        }
    }

    if !opts.pinned_spki_sha256.is_empty() {
        let leaf = s.peer_certificates().first().map(Vec::as_slice);
        match leaf {
            Some(der) if client_auth::spki_pin_matches(der, &opts.pinned_spki_sha256) => {}
            _ => {
                return Err(Error::BadResponse(
                    "pinned public key does not match server certificate".into(),
                ))
            }
        }
    }
    Ok(s)
}
/// Extracts the `X509 CRL` PEM blocks from `pem` via
/// `pc_roots::pem_blocks_labelled`. Surrounding text is skipped, and the
/// result is empty when no such block exists (e.g. DER input).
fn crl_pem_blocks(pem: &str) -> Vec<String> {
    const CRL_LABEL: &str = "X509 CRL";
    pc_roots::pem_blocks_labelled(pem, CRL_LABEL)
}
/// Parses the client certificate chain and its private signing key.
///
/// * `cert_bytes`: the chain, DER when `cert_is_der`, otherwise PEM text.
/// * `key_bytes`: the key, DER when `key_is_der`, otherwise PEM text. When
///   absent, the key is parsed out of the PEM certificate data instead. A DER
///   certificate cannot carry a key, so that combination is an error.
/// * `pass`: forwarded to the key parser (presumably a passphrase for
///   encrypted keys).
///
/// The chain is parsed first, so chain errors take precedence over key errors.
fn build_identity(
    cert_bytes: &[u8],
    key_bytes: Option<&[u8]>,
    pass: Option<&str>,
    cert_is_der: bool,
    key_is_der: bool,
) -> Result<(Vec<Vec<u8>>, purecrypto::tls::SigningKey)> {
    let chain = if !cert_is_der {
        let cert_pem = std::str::from_utf8(cert_bytes)
            .map_err(|_| Error::BadResponse("client cert: PEM file is not valid UTF-8".into()))?;
        client_auth::load_cert_chain(cert_pem)?
    } else {
        client_auth::load_cert_chain_der(cert_bytes)?
    };

    let key = if let Some(raw_key) = key_bytes {
        if key_is_der {
            client_auth::parse_signing_key_der(raw_key, pass)?
        } else {
            let key_pem = std::str::from_utf8(raw_key).map_err(|_| {
                Error::BadResponse("client key: PEM file is not valid UTF-8".into())
            })?;
            client_auth::parse_signing_key(key_pem, pass)?
        }
    } else if cert_is_der {
        return Err(Error::BadResponse(
            "client cert: a DER cert has no embedded key; pass --key".into(),
        ));
    } else {
        // No separate key: look for it inside the certificate PEM itself.
        let cert_pem = std::str::from_utf8(cert_bytes).map_err(|_| {
            Error::BadResponse("client cert: PEM file is not valid UTF-8".into())
        })?;
        client_auth::parse_signing_key(cert_pem, pass)?
    };

    Ok((chain, key))
}
impl<S: Read + Write> TlsStream<S> {
pub fn negotiated_version(&self) -> Option<ProtocolVersion> {
self.conn.negotiated_version().map(map_pc_version)
}
pub fn alpn_selected(&self) -> Option<&[u8]> {
self.conn.alpn_selected()
}
pub fn negotiated_cipher_suite(&self) -> Option<u16> {
self.conn.negotiated_cipher_suite()
}
pub fn peer_certificates(&self) -> &[Vec<u8>] {
self.conn.peer_certificates()
}
pub fn was_truncated(&self) -> bool {
self.seen_eof && !self.conn.received_close_notify()
}
fn run_handshake(&mut self) -> Result<()> {
let mut buf = [0u8; READ_CHUNK];
loop {
self.drain_outgoing().map_err(Error::Io)?;
match self.conn.handshake().map_err(tls_err)? {
HandshakeStatus::Complete => return Ok(()),
HandshakeStatus::WantWrite => continue,
HandshakeStatus::WantRead => {
let n = self.sock.read(&mut buf)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
self.feed_all(&buf[..n]).map_err(Error::Io)?;
}
}
}
}
fn drain_outgoing(&mut self) -> io::Result<()> {
pc_drain_outgoing(&mut self.conn, &mut self.sock)
}
fn feed_all(&mut self, wire: &[u8]) -> io::Result<()> {
pc_feed_all(&mut self.conn, &mut self.pending_wire, wire)
}
}
/// Writes every record `conn` has queued out to `sock`. It stops once `pop`
/// hands back an empty buffer.
fn pc_drain_outgoing(conn: &mut Connection, sock: &mut dyn Write) -> io::Result<()> {
    let mut record = conn.pop().map_err(io_tls)?;
    while !record.is_empty() {
        sock.write_all(&record)?;
        record = conn.pop().map_err(io_tls)?;
    }
    Ok(())
}
/// Feeds freshly received wire bytes into `conn`.
///
/// `conn.feed` reports how many bytes it consumed. A return of 0 is treated as
/// "no progress": feeding stops and the unconsumed remainder is kept in
/// `pending` for a later call. Bytes already in `pending` were received before
/// `wire`, so when `pending` is non-empty, `wire` is appended to it and feeding
/// resumes from its front, preserving stream order.
///
/// # Errors
/// A TLS error from `conn.feed` is returned as an `io::Error` (via `io_tls`).
fn pc_feed_all(conn: &mut Connection, pending: &mut Vec<u8>, wire: &[u8]) -> io::Result<()> {
// Older bytes are still queued, so the new ones go behind them.
if !pending.is_empty() {
pending.extend_from_slice(wire);
let mut taken = 0;
while taken < pending.len() {
// NOTE(review): if `feed` errors here, the already-consumed prefix is not
// drained from `pending`, because `?` returns before the `drain` below.
let n = conn.feed(&pending[taken..]).map_err(io_tls)?;
if n == 0 {
break;
}
taken += n;
}
// Remove only what the connection actually consumed.
pending.drain(..taken);
return Ok(());
}
// Nothing queued: feed straight from `wire` without copying it first.
let mut taken = 0;
while taken < wire.len() {
let n = conn.feed(&wire[taken..]).map_err(io_tls)?;
if n == 0 {
// Stash the remainder so a later call retries it first.
pending.extend_from_slice(&wire[taken..]);
return Ok(());
}
taken += n;
}
Ok(())
}
impl<S: Read + Write> Write for TlsStream<S> {
    /// Encrypts all of `data` and pushes the resulting records to the
    /// transport. On success, the full length is reported as written.
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let accepted = data.len();
        self.conn.send(data).map_err(io_tls)?;
        self.drain_outgoing().map(|()| accepted)
    }
    /// Sends any queued records, then flushes the transport.
    fn flush(&mut self) -> io::Result<()> {
        self.drain_outgoing().and_then(|()| self.sock.flush())
    }
}
impl<S: Read + Write> Read for TlsStream<S> {
    /// Reads decrypted application data into `dst`.
    ///
    /// Buffered plaintext is served first. Otherwise the connection is polled
    /// for application data and any wire bytes it refused earlier are
    /// re-offered; only then does this block on the transport. Returns `Ok(0)`
    /// once the transport has hit EOF and no more application data can be
    /// produced. Use [`TlsStream::was_truncated`] to tell a clean close from
    /// truncation.
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        if dst.is_empty() {
            return Ok(0);
        }
        let mut buf = [0u8; READ_CHUNK];
        while self.plaintext.is_empty() {
            // Poll before honouring EOF, as `TlsConn::read` does. Otherwise
            // records still buffered when the transport closed would be lost.
            let app = self.conn.recv().map_err(io_tls)?;
            if !app.is_empty() {
                self.plaintext = app;
                break;
            }
            // Re-offer previously refused bytes before blocking. The peer may
            // have sent everything already and be waiting on us.
            if !self.pending_wire.is_empty() {
                let before = self.pending_wire.len();
                self.feed_all(&[])?;
                if self.pending_wire.len() < before {
                    self.drain_outgoing()?;
                    continue;
                }
            }
            if self.seen_eof {
                return Ok(0);
            }
            let n = self.sock.read(&mut buf)?;
            if n == 0 {
                // Loop back so `recv` and the pending retry run once more
                // before `Ok(0)` is reported.
                self.seen_eof = true;
                continue;
            }
            self.feed_all(&buf[..n])?;
            self.drain_outgoing()?;
        }
        let take = dst.len().min(self.plaintext.len());
        dst[..take].copy_from_slice(&self.plaintext[..take]);
        self.plaintext.drain(..take);
        Ok(take)
    }
}
/// TLS connection usable from several threads through `&self`.
///
/// The TLS state and the write socket live behind one mutex. The read socket
/// has its own mutex, so a blocking socket read in [`TlsConn::read`] does not
/// hold the engine lock.
pub struct TlsConn {
/// TLS state plus the write-side socket.
engine: std::sync::Mutex<PcEngine>,
/// Read-side socket, locked only for the duration of a transport read.
read_sock: std::sync::Mutex<Box<dyn crate::net::NetStream>>,
}
/// State shared behind `TlsConn::engine`, moved over from a [`TlsStream`].
struct PcEngine {
/// The purecrypto TLS connection.
conn: Connection,
/// Socket used for all outgoing records.
write_sock: Box<dyn crate::net::NetStream>,
/// Received wire bytes `conn.feed` has not consumed yet.
pending_wire: Vec<u8>,
/// Decrypted application data not yet handed to a reader.
plaintext: Vec<u8>,
/// Set once a read-side transport read returned 0 bytes.
seen_eof: bool,
}
impl TlsStream<Box<dyn crate::net::NetStream>> {
pub fn into_concurrent(self, read_sock: Box<dyn crate::net::NetStream>) -> TlsConn {
TlsConn {
engine: std::sync::Mutex::new(PcEngine {
conn: self.conn,
write_sock: self.sock,
pending_wire: self.pending_wire,
plaintext: self.plaintext,
seen_eof: self.seen_eof,
}),
read_sock: std::sync::Mutex::new(read_sock),
}
}
}
impl TlsConn {
    /// Reads decrypted application data into `dst`.
    ///
    /// The engine lock is released while blocking on the read socket, so
    /// concurrent [`TlsConn::write`] calls are not stalled. Returns `Ok(0)` at
    /// EOF once no more application data can be produced.
    pub fn read(&self, dst: &mut [u8]) -> io::Result<usize> {
        if dst.is_empty() {
            return Ok(0);
        }
        loop {
            {
                let mut e = self.engine.lock().unwrap();
                if let Some(n) = e.take_plaintext(dst) {
                    return Ok(n);
                }
                let app = e.conn.recv().map_err(io_tls)?;
                if !app.is_empty() {
                    e.plaintext = app;
                    return Ok(e
                        .take_plaintext(dst)
                        .expect("plaintext was just filled with non-empty data"));
                }
                // Re-offer previously refused wire bytes before blocking on
                // the socket or reporting EOF; they may complete a record.
                if !e.pending_wire.is_empty() {
                    let before = e.pending_wire.len();
                    let PcEngine {
                        conn,
                        write_sock,
                        pending_wire,
                        ..
                    } = &mut *e;
                    pc_feed_all(conn, pending_wire, &[])?;
                    if pending_wire.len() < before {
                        pc_drain_outgoing(conn, write_sock)?;
                        continue;
                    }
                }
                if e.seen_eof {
                    return Ok(0);
                }
            }
            // Engine lock released: block on the socket without stalling writers.
            let mut buf = [0u8; READ_CHUNK];
            let n = self.read_sock.lock().unwrap().read(&mut buf)?;
            let mut e = self.engine.lock().unwrap();
            if n == 0 {
                // Loop back so `recv` and the pending retry run once more
                // before `Ok(0)` is reported.
                e.seen_eof = true;
                continue;
            }
            let PcEngine {
                conn,
                write_sock,
                pending_wire,
                ..
            } = &mut *e;
            pc_feed_all(conn, pending_wire, &buf[..n])?;
            pc_drain_outgoing(conn, write_sock)?;
        }
    }
    /// Encrypts `data` and writes the resulting records to the write socket.
    pub fn write(&self, data: &[u8]) -> io::Result<()> {
        let mut e = self.engine.lock().unwrap();
        e.conn.send(data).map_err(io_tls)?;
        let PcEngine {
            conn, write_sock, ..
        } = &mut *e;
        pc_drain_outgoing(conn, write_sock)
    }
    /// Sends any queued records, then flushes the write socket.
    pub fn flush(&self) -> io::Result<()> {
        let mut e = self.engine.lock().unwrap();
        let PcEngine {
            conn, write_sock, ..
        } = &mut *e;
        pc_drain_outgoing(conn, write_sock)?;
        write_sock.flush()
    }
    /// Sets the read timeout on the read-side socket only.
    pub fn set_read_timeout(&self, dur: Option<std::time::Duration>) -> io::Result<()> {
        self.read_sock.lock().unwrap().set_read_timeout(dur)
    }
}
impl PcEngine {
    /// Moves up to `dst.len()` buffered plaintext bytes into `dst`.
    ///
    /// Returns `None` when nothing is buffered. Otherwise returns the number
    /// of bytes copied, which is `Some(0)` if `dst` is empty.
    fn take_plaintext(&mut self, dst: &mut [u8]) -> Option<usize> {
        if self.plaintext.is_empty() {
            None
        } else {
            let count = self.plaintext.len().min(dst.len());
            let (head, _) = dst.split_at_mut(count);
            head.copy_from_slice(&self.plaintext[..count]);
            self.plaintext.drain(..count);
            Some(count)
        }
    }
}
/// Translates purecrypto's protocol version into the crate's.
///
/// Versions other than TLS 1.2 / 1.3 are kept as their raw `u16` code.
fn map_pc_version(v: purecrypto::tls::ProtocolVersion) -> ProtocolVersion {
    use purecrypto::tls::ProtocolVersion as Pc;
    match v {
        Pc::TLSv1_3 => ProtocolVersion::TLSv1_3,
        Pc::TLSv1_2 => ProtocolVersion::TLSv1_2,
        unknown => ProtocolVersion::Other(unknown.as_u16()),
    }
}
fn tls_err(e: purecrypto::tls::Error) -> Error {
Error::BadResponse(format!("tls: {e:?}"))
}
/// Wraps a purecrypto TLS error as an `io::Error` of kind `Other`, using its
/// `Debug` text behind a `tls: ` prefix.
fn io_tls(e: purecrypto::tls::Error) -> io::Error {
    let message = format!("tls: {:?}", e);
    io::Error::other(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_enables_verification() {
        for opts in [TlsOpts::default(), TlsOpts::verifying()] {
            assert!(opts.verify);
        }
    }

    #[test]
    fn crl_pem_blocks_splits_concatenated_crls() {
        let pem = concat!(
            "-----BEGIN X509 CRL-----\nAAA\n-----END X509 CRL-----\n",
            "noise between blocks\n",
            "-----BEGIN X509 CRL-----\nBBB\n-----END X509 CRL-----\n",
        );
        let blocks = crl_pem_blocks(pem);
        assert_eq!(blocks.len(), 2);
        for (block, marker) in blocks.iter().zip(["AAA", "BBB"]) {
            assert!(block.contains(marker));
        }
    }

    #[test]
    fn crl_pem_blocks_single_unchanged() {
        let blocks =
            crl_pem_blocks("-----BEGIN X509 CRL-----\nAAA\n-----END X509 CRL-----\n");
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("AAA"));
    }

    #[test]
    fn crl_pem_blocks_empty_for_der() {
        let blocks = crl_pem_blocks("not a pem armored crl at all");
        assert!(blocks.is_empty());
    }
}