use std::io::{self, Read, Write};
use purecrypto::tls::{Config, Connection, CrlStore, HandshakeStatus};
use super::common::ProtocolVersion;
use super::{client_auth, pc_roots};
use crate::error::{Error, Result};
pub use purecrypto::tls::RootCertStore;
/// Size in bytes of the scratch buffer used for each read from the underlying transport.
const READ_CHUNK: usize = 16 * 1024;
/// Options controlling how [`connect_over_tls`] configures and validates a TLS client session.
///
/// `Default` yields all-empty / `None` / `false` values, which means `verify` is `false`;
/// use [`TlsOpts::verifying`] for a baseline with certificate verification switched on.
#[derive(Default, Clone)]
pub struct TlsOpts {
/// ALPN protocol identifiers to offer; only passed to the TLS config when non-empty.
pub alpn: Vec<Vec<u8>>,
/// Forwarded to the TLS config's `verify_certificates` setting.
pub verify: bool,
/// Trust anchors to use; when `None`, the system roots are loaded instead.
pub roots: Option<RootCertStore>,
/// Lowest protocol version to allow, if set.
pub min_version: Option<ProtocolVersion>,
/// Highest protocol version to allow, if set.
pub max_version: Option<ProtocolVersion>,
/// Client certificate bytes (PEM text unless `cert_is_der` is set); enables client auth.
pub client_cert: Option<Vec<u8>>,
/// Client private key bytes (PEM text unless `key_is_der` is set). When absent and the
/// certificate is PEM, the key is parsed out of the certificate file itself.
pub client_key: Option<Vec<u8>>,
/// Passphrase handed to the private-key parser.
pub client_key_pass: Option<String>,
/// Treat `client_cert` as DER rather than PEM.
pub cert_is_der: bool,
/// Treat `client_key` as DER rather than PEM.
pub key_is_der: bool,
/// SHA-256 SPKI pins; when non-empty, the server's leaf certificate must match one of
/// them after the handshake or the connection is rejected.
pub pinned_spki_sha256: Vec<[u8; 32]>,
/// CRL data; despite the field name it is tried as PEM first and then as DER.
pub crl_pem: Option<Vec<u8>>,
/// Cipher suite identifiers to restrict to; only passed to the TLS config when non-empty.
pub cipher_suites: Vec<u16>,
}
impl TlsOpts {
pub fn verifying() -> Self {
TlsOpts {
alpn: Vec::new(),
verify: true,
roots: None,
min_version: None,
max_version: None,
client_cert: None,
client_key: None,
client_key_pass: None,
cert_is_der: false,
key_is_der: false,
pinned_spki_sha256: Vec::new(),
crl_pem: None,
cipher_suites: Vec::new(),
}
}
}
/// Converts this crate's protocol version into the `purecrypto` equivalent.
///
/// Only TLS 1.3 maps to TLS 1.3; every other input value maps to TLS 1.2.
fn to_pc_version(v: ProtocolVersion) -> purecrypto::tls::ProtocolVersion {
    use purecrypto::tls::ProtocolVersion as Pc;

    if matches!(v, ProtocolVersion::TLSv1_3) {
        Pc::TLSv1_3
    } else {
        Pc::TLSv1_2
    }
}
/// Loads the system root certificate store by delegating to `pc_roots::load_system_roots`.
///
/// # Errors
/// Propagates whatever error the `pc_roots` loader reports.
pub fn load_system_roots() -> Result<RootCertStore> {
pc_roots::load_system_roots()
}
/// Loads a root certificate store from the file at `path` by delegating to
/// `pc_roots::load_from_file`.
///
/// # Errors
/// Propagates whatever error the `pc_roots` loader reports.
pub fn load_roots_from_file(path: &str) -> Result<RootCertStore> {
pc_roots::load_from_file(path)
}
pub fn load_roots_from_dir(base: Option<RootCertStore>, dir: &str) -> Result<RootCertStore> {
let mut roots = match base {
Some(r) => r,
None => load_system_roots()?,
};
pc_roots::add_from_dir(&mut roots, dir)?;
Ok(roots)
}
/// A TLS client session layered over a blocking transport `S`.
///
/// Created by [`connect_over_tls`] (or its convenience wrappers) once the handshake has
/// completed; application data is exchanged through the `Read` and `Write` impls.
pub struct TlsStream<S: Read + Write> {
/// The `purecrypto` TLS connection state machine.
conn: Connection,
/// Underlying transport that carries the TLS wire bytes.
sock: S,
/// Decrypted application data taken from `conn` but not yet handed to the caller.
plaintext: Vec<u8>,
/// Wire bytes read from `sock` that `conn.feed` has not accepted yet.
pending_wire: Vec<u8>,
/// Set once a read on `sock` has returned 0 bytes (end of stream).
seen_eof: bool,
}
/// Establishes a verifying TLS session over `transport`, using `sni` as the server name.
///
/// Shorthand for [`connect_over_tls`] with [`TlsOpts::verifying`].
pub fn connect_over<S: Read + Write>(transport: S, sni: &str) -> Result<TlsStream<S>> {
    let opts = TlsOpts::verifying();
    connect_over_tls(transport, sni, opts)
}
pub fn connect_over_with_alpn<S: Read + Write>(
transport: S,
sni: &str,
alpn: &[&[u8]],
) -> Result<TlsStream<S>> {
let mut opts = TlsOpts::verifying();
opts.alpn = alpn.iter().map(|p| p.to_vec()).collect();
connect_over_tls(transport, sni, opts)
}
/// Performs a TLS client handshake over `transport` as configured by `opts`.
///
/// Steps, in order: pick the trust anchors, build the `purecrypto` config (server name,
/// verification flag, optional ALPN / version bounds / cipher suites / client identity /
/// CRLs), run the handshake to completion, then enforce SPKI pinning if any pins are set.
///
/// # Errors
/// - system roots cannot be loaded (only when `opts.roots` is `None`);
/// - the client certificate or key cannot be parsed;
/// - the CRL data is neither valid PEM nor valid DER;
/// - the handshake fails or the transport hits EOF / an I/O error during it;
/// - pins are configured and the server's leaf certificate matches none of them.
pub fn connect_over_tls<S: Read + Write>(
    transport: S,
    sni: &str,
    opts: TlsOpts,
) -> Result<TlsStream<S>> {
    // NOTE(review): the system roots are loaded (and a failure is fatal) even when
    // `opts.verify` is false.
    let trust_anchors = if let Some(store) = opts.roots {
        store
    } else {
        load_system_roots()?
    };

    let mut cfg = Config::builder()
        .tls_only()
        .roots(trust_anchors)
        .server_name(sni.to_string())
        .verify_certificates(opts.verify);

    // Optional knobs are applied only when the caller actually set them.
    if !opts.alpn.is_empty() {
        cfg = cfg.alpn(opts.alpn);
    }
    if let Some(floor) = opts.min_version {
        cfg = cfg.min_version(to_pc_version(floor));
    }
    if let Some(ceiling) = opts.max_version {
        cfg = cfg.max_version(to_pc_version(ceiling));
    }
    if !opts.cipher_suites.is_empty() {
        cfg = cfg.cipher_suites(&opts.cipher_suites);
    }

    if let Some(cert) = opts.client_cert.as_deref() {
        let (chain, key) = build_identity(
            cert,
            opts.client_key.as_deref(),
            opts.client_key_pass.as_deref(),
            opts.cert_is_der,
            opts.key_is_der,
        )?;
        cfg = cfg.identity(chain, key);
    }

    if let Some(raw_crl) = opts.crl_pem.as_deref() {
        let mut crls = CrlStore::new();
        // Try PEM first; anything that is not UTF-8 or not parseable PEM falls back to DER.
        let parsed_as_pem = match std::str::from_utf8(raw_crl) {
            Ok(text) => crls.add_pem(text).is_ok(),
            Err(_) => false,
        };
        if !parsed_as_pem && crls.add_der(raw_crl.to_vec()).is_err() {
            return Err(Error::BadResponse(
                "--crlfile: not a valid PEM or DER CRL".into(),
            ));
        }
        cfg = cfg.crls(crls);
    }

    let config = cfg.build();
    let mut stream = TlsStream {
        conn: Connection::client(&config).map_err(tls_err)?,
        sock: transport,
        plaintext: Vec::new(),
        pending_wire: Vec::new(),
        seen_eof: false,
    };
    stream.run_handshake()?;

    // Pinning is checked against the first (leaf) peer certificate only; a missing
    // certificate counts as a mismatch.
    if !opts.pinned_spki_sha256.is_empty() {
        let pin_ok = stream
            .peer_certificates()
            .first()
            .map(Vec::as_slice)
            .is_some_and(|der| client_auth::spki_pin_matches(der, &opts.pinned_spki_sha256));
        if !pin_ok {
            return Err(Error::BadResponse(
                "pinned public key does not match server certificate".into(),
            ));
        }
    }

    Ok(stream)
}
fn build_identity(
cert_bytes: &[u8],
key_bytes: Option<&[u8]>,
pass: Option<&str>,
cert_is_der: bool,
key_is_der: bool,
) -> Result<(Vec<Vec<u8>>, purecrypto::tls::SigningKey)> {
let chain = if cert_is_der {
client_auth::load_cert_chain_der(cert_bytes)?
} else {
let pem = std::str::from_utf8(cert_bytes)
.map_err(|_| Error::BadResponse("client cert: PEM file is not valid UTF-8".into()))?;
client_auth::load_cert_chain(pem)?
};
let key = match key_bytes {
Some(kb) if key_is_der => client_auth::parse_signing_key_der(kb, pass)?,
Some(kb) => {
let pem = std::str::from_utf8(kb).map_err(|_| {
Error::BadResponse("client key: PEM file is not valid UTF-8".into())
})?;
client_auth::parse_signing_key(pem, pass)?
}
None if cert_is_der => {
return Err(Error::BadResponse(
"client cert: a DER cert has no embedded key; pass --key".into(),
))
}
None => {
let pem = std::str::from_utf8(cert_bytes).map_err(|_| {
Error::BadResponse("client cert: PEM file is not valid UTF-8".into())
})?;
client_auth::parse_signing_key(pem, pass)?
}
};
Ok((chain, key))
}
impl<S: Read + Write> TlsStream<S> {
pub fn negotiated_version(&self) -> Option<ProtocolVersion> {
self.conn.negotiated_version().map(map_pc_version)
}
pub fn alpn_selected(&self) -> Option<&[u8]> {
self.conn.alpn_selected()
}
pub fn peer_certificates(&self) -> &[Vec<u8>] {
self.conn.peer_certificates()
}
fn run_handshake(&mut self) -> Result<()> {
let mut buf = [0u8; READ_CHUNK];
loop {
self.drain_outgoing().map_err(Error::Io)?;
match self.conn.handshake().map_err(tls_err)? {
HandshakeStatus::Complete => return Ok(()),
HandshakeStatus::WantWrite => continue,
HandshakeStatus::WantRead => {
let n = self.sock.read(&mut buf)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
self.feed_all(&buf[..n]).map_err(Error::Io)?;
}
}
}
}
fn drain_outgoing(&mut self) -> io::Result<()> {
loop {
let out = self.conn.pop().map_err(io_tls)?;
if out.is_empty() {
return Ok(());
}
self.sock.write_all(&out)?;
}
}
fn feed_all(&mut self, wire: &[u8]) -> io::Result<()> {
if !self.pending_wire.is_empty() {
self.pending_wire.extend_from_slice(wire);
let mut taken = 0;
while taken < self.pending_wire.len() {
let n = self
.conn
.feed(&self.pending_wire[taken..])
.map_err(io_tls)?;
if n == 0 {
break;
}
taken += n;
}
self.pending_wire.drain(..taken);
return Ok(());
}
let mut taken = 0;
while taken < wire.len() {
let n = self.conn.feed(&wire[taken..]).map_err(io_tls)?;
if n == 0 {
self.pending_wire.extend_from_slice(&wire[taken..]);
return Ok(());
}
taken += n;
}
Ok(())
}
}
impl<S: Read + Write> Write for TlsStream<S> {
    /// Queues `data` on the TLS connection and pushes the resulting output to the
    /// transport. On success the whole of `data` is reported as written.
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let accepted = data.len();
        self.conn.send(data).map_err(io_tls)?;
        self.drain_outgoing().map(|()| accepted)
    }

    /// Pushes any queued TLS output to the transport, then flushes the transport itself.
    fn flush(&mut self) -> io::Result<()> {
        self.drain_outgoing().and_then(|()| self.sock.flush())
    }
}
impl<S: Read + Write> Read for TlsStream<S> {
    /// Reads decrypted application data into `dst`.
    ///
    /// Returns `Ok(0)` when `dst` is empty or when the transport has reached end of stream
    /// and no buffered plaintext remains. Otherwise blocks until at least one byte of
    /// plaintext is available, reading from the transport as needed.
    ///
    /// # Errors
    /// Transport I/O errors are returned as-is; TLS errors are wrapped in an `io::Error`.
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        if dst.is_empty() {
            return Ok(0);
        }
        let mut buf = [0u8; READ_CHUNK];
        while self.plaintext.is_empty() {
            if self.seen_eof {
                return Ok(0);
            }
            let app = self.conn.recv().map_err(io_tls)?;
            if !app.is_empty() {
                self.plaintext = app;
                break;
            }
            // Wire bytes that `conn` refused earlier are re-offered before touching the
            // socket: blocking (or hitting EOF) while they sit in `pending_wire` would
            // stall the caller or silently drop the tail of the stream. The loop only
            // repeats when the backlog actually shrank, so it cannot spin.
            if !self.pending_wire.is_empty() {
                let backlog = self.pending_wire.len();
                self.feed_all(&[])?;
                if self.pending_wire.len() < backlog {
                    self.drain_outgoing()?;
                    continue;
                }
            }
            let n = self.sock.read(&mut buf)?;
            if n == 0 {
                // End of stream: hand out whatever the connection can still produce.
                self.seen_eof = true;
                let app = self.conn.recv().map_err(io_tls)?;
                if app.is_empty() {
                    return Ok(0);
                }
                self.plaintext = app;
                break;
            }
            self.feed_all(&buf[..n])?;
            // Processing input may have queued output; send it out right away.
            self.drain_outgoing()?;
        }
        let take = dst.len().min(self.plaintext.len());
        dst[..take].copy_from_slice(&self.plaintext[..take]);
        self.plaintext.drain(..take);
        Ok(take)
    }
}
/// Converts a `purecrypto` protocol version into this crate's representation.
///
/// Versions other than TLS 1.2 and TLS 1.3 are carried through as `Other` with their
/// numeric code.
fn map_pc_version(v: purecrypto::tls::ProtocolVersion) -> ProtocolVersion {
    match v {
        purecrypto::tls::ProtocolVersion::TLSv1_3 => ProtocolVersion::TLSv1_3,
        purecrypto::tls::ProtocolVersion::TLSv1_2 => ProtocolVersion::TLSv1_2,
        unknown => ProtocolVersion::Other(unknown.as_u16()),
    }
}
fn tls_err(e: purecrypto::tls::Error) -> Error {
Error::BadResponse(format!("tls: {e:?}"))
}
/// Wraps a `purecrypto` TLS error in an `io::Error`, keeping its `Debug` rendering.
fn io_tls(e: purecrypto::tls::Error) -> io::Error {
    let detail = format!("tls: {e:?}");
    io::Error::other(detail)
}