use std::io::{self, Read, Write};
use purecrypto::tls::{Config, Connection, HandshakeStatus};
use super::common::ProtocolVersion;
use super::pc_roots;
use crate::error::{Error, Result};
pub use purecrypto::tls::RootCertStore;
/// Size in bytes (16 KiB) of the stack buffer used for each read from the transport.
const READ_CHUNK: usize = 16 << 10;
/// Options controlling how [`connect_over_tls`] configures a client connection.
///
/// NOTE: the derived `Default` sets `verify` to `false`, i.e. certificate
/// verification is *disabled*. Start from [`TlsOpts::verifying`] when you want
/// a verifying configuration.
#[derive(Default, Clone)]
pub struct TlsOpts {
    /// ALPN protocol identifiers to offer. When empty, `connect_over_tls` does
    /// not call the builder's `alpn` method at all.
    pub alpn: Vec<Vec<u8>>,
    /// Forwarded to the config builder's `verify_certificates`.
    pub verify: bool,
    /// Trust anchors to use. `None` makes `connect_over_tls` load the system
    /// roots at connect time.
    pub roots: Option<RootCertStore>,
}

impl TlsOpts {
    /// Options with certificate verification enabled, no ALPN protocols, and
    /// no explicit root store (so the system roots are used).
    pub fn verifying() -> Self {
        Self {
            verify: true,
            ..Self::default()
        }
    }
}
/// Loads a root certificate store via `pc_roots::load_system_roots`.
///
/// Presumably this reads the operating system's trust store; confirm in `pc_roots`.
///
/// # Errors
/// Propagates whatever error `pc_roots::load_system_roots` reports.
pub fn load_system_roots() -> Result<RootCertStore> {
    let store = pc_roots::load_system_roots()?;
    Ok(store)
}
/// Loads a root certificate store from the file at `path` via
/// `pc_roots::load_from_file`.
///
/// The accepted file format is decided by `pc_roots`; presumably PEM (confirm).
///
/// # Errors
/// Propagates whatever error `pc_roots::load_from_file` reports.
pub fn load_roots_from_file(path: &str) -> Result<RootCertStore> {
    let store = pc_roots::load_from_file(path)?;
    Ok(store)
}
/// Client-side TLS stream layered over a caller-supplied `Read + Write` transport.
///
/// Application data is exchanged through the `Read` and `Write` impls below;
/// encrypted records travel over `sock`.
// Field order is also drop order (conn before sock); keep it stable.
pub struct TlsStream<S: Read + Write> {
    // TLS state machine that produces/consumes records.
    conn: Connection,
    // Underlying byte transport the records are written to / read from.
    sock: S,
    // Decrypted bytes already obtained from `conn.recv()` but not yet returned to the caller.
    plaintext: Vec<u8>,
    // Ciphertext read from `sock` that `conn.feed` has not accepted yet (it returned 0).
    pending_wire: Vec<u8>,
    // Set once a read on `sock` returned 0 bytes.
    seen_eof: bool,
}
/// Performs a TLS client handshake over `transport` for server name `sni`,
/// verifying certificates against the system roots and offering no ALPN.
///
/// # Errors
/// See [`connect_over_tls`].
pub fn connect_over<S: Read + Write>(transport: S, sni: &str) -> Result<TlsStream<S>> {
    let opts = TlsOpts::verifying();
    connect_over_tls(transport, sni, opts)
}
pub fn connect_over_with_alpn<S: Read + Write>(
transport: S,
sni: &str,
alpn: &[&[u8]],
) -> Result<TlsStream<S>> {
let mut opts = TlsOpts::verifying();
opts.alpn = alpn.iter().map(|p| p.to_vec()).collect();
connect_over_tls(transport, sni, opts)
}
pub fn connect_over_tls<S: Read + Write>(
transport: S,
sni: &str,
opts: TlsOpts,
) -> Result<TlsStream<S>> {
let roots = match opts.roots {
Some(r) => r,
None => load_system_roots()?,
};
let mut builder = Config::builder()
.tls_only()
.roots(roots)
.server_name(sni.to_string())
.verify_certificates(opts.verify);
if !opts.alpn.is_empty() {
builder = builder.alpn(opts.alpn);
}
let cfg = builder.build();
let conn = Connection::client(&cfg).map_err(tls_err)?;
let mut s = TlsStream {
conn,
sock: transport,
plaintext: Vec::new(),
pending_wire: Vec::new(),
seen_eof: false,
};
s.run_handshake()?;
Ok(s)
}
impl<S: Read + Write> TlsStream<S> {
pub fn negotiated_version(&self) -> Option<ProtocolVersion> {
self.conn.negotiated_version().map(map_pc_version)
}
pub fn alpn_selected(&self) -> Option<&[u8]> {
self.conn.alpn_selected()
}
pub fn peer_certificates(&self) -> &[Vec<u8>] {
self.conn.peer_certificates()
}
fn run_handshake(&mut self) -> Result<()> {
let mut buf = [0u8; READ_CHUNK];
loop {
self.drain_outgoing().map_err(Error::Io)?;
match self.conn.handshake().map_err(tls_err)? {
HandshakeStatus::Complete => return Ok(()),
HandshakeStatus::WantWrite => continue,
HandshakeStatus::WantRead => {
let n = self.sock.read(&mut buf)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
self.feed_all(&buf[..n]).map_err(Error::Io)?;
}
}
}
}
fn drain_outgoing(&mut self) -> io::Result<()> {
loop {
let out = self.conn.pop().map_err(io_tls)?;
if out.is_empty() {
return Ok(());
}
self.sock.write_all(&out)?;
}
}
fn feed_all(&mut self, wire: &[u8]) -> io::Result<()> {
if !self.pending_wire.is_empty() {
self.pending_wire.extend_from_slice(wire);
let mut taken = 0;
while taken < self.pending_wire.len() {
let n = self
.conn
.feed(&self.pending_wire[taken..])
.map_err(io_tls)?;
if n == 0 {
break;
}
taken += n;
}
self.pending_wire.drain(..taken);
return Ok(());
}
let mut taken = 0;
while taken < wire.len() {
let n = self.conn.feed(&wire[taken..]).map_err(io_tls)?;
if n == 0 {
self.pending_wire.extend_from_slice(&wire[taken..]);
return Ok(());
}
taken += n;
}
Ok(())
}
}
impl<S: Read + Write> Write for TlsStream<S> {
    /// Hands `data` to the connection as application data, then writes every
    /// resulting record to the transport. Reports the whole buffer as written.
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let accepted = data.len();
        self.conn.send(data).map_err(io_tls)?;
        self.drain_outgoing().map(|()| accepted)
    }

    /// Writes any queued records, then flushes the underlying transport.
    fn flush(&mut self) -> io::Result<()> {
        self.drain_outgoing().and_then(|()| self.sock.flush())
    }
}
impl<S: Read + Write> Read for TlsStream<S> {
    /// Copies decrypted application data into `dst`.
    ///
    /// Sources are tried in this order until some plaintext is available:
    /// 1. plaintext left over from a previous call;
    /// 2. `conn.recv()`;
    /// 3. ciphertext previously refused by `conn.feed` (`pending_wire`);
    /// 4. a blocking read from the transport.
    ///
    /// Returns `Ok(0)` when `dst` is empty, or once the transport has reported
    /// EOF and neither `recv()` nor `pending_wire` can yield more plaintext.
    /// Transport errors, including `Interrupted`, are returned to the caller.
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        if dst.is_empty() {
            return Ok(0);
        }
        let mut buf = [0u8; READ_CHUNK];
        while self.plaintext.is_empty() {
            let app = self.conn.recv().map_err(io_tls)?;
            if !app.is_empty() {
                self.plaintext = app;
                break;
            }
            // Retry ciphertext the connection refused earlier before blocking
            // on the socket. The peer may already have sent everything it
            // intends to, in which case a blocking read would deadlock.
            if !self.pending_wire.is_empty() {
                let before = self.pending_wire.len();
                self.feed_all(&[])?;
                if self.pending_wire.len() < before {
                    self.drain_outgoing()?;
                    continue;
                }
            }
            if self.seen_eof {
                return Ok(0);
            }
            let n = self.sock.read(&mut buf)?;
            if n == 0 {
                // Don't stop yet: loop so recv() and pending_wire are drained
                // before end-of-stream is reported.
                self.seen_eof = true;
                continue;
            }
            self.feed_all(&buf[..n])?;
            self.drain_outgoing()?;
        }
        let take = dst.len().min(self.plaintext.len());
        dst[..take].copy_from_slice(&self.plaintext[..take]);
        self.plaintext.drain(..take);
        Ok(take)
    }
}
/// Translates purecrypto's protocol version into this crate's `ProtocolVersion`.
/// Versions other than TLS 1.2/1.3 are carried through as their raw `u16` code.
fn map_pc_version(v: purecrypto::tls::ProtocolVersion) -> ProtocolVersion {
    use purecrypto::tls::ProtocolVersion as PcVersion;
    match v {
        PcVersion::TLSv1_3 => ProtocolVersion::TLSv1_3,
        PcVersion::TLSv1_2 => ProtocolVersion::TLSv1_2,
        unknown => ProtocolVersion::Other(unknown.as_u16()),
    }
}
fn tls_err(e: purecrypto::tls::Error) -> Error {
Error::BadResponse(format!("tls: {e:?}"))
}
/// Wraps a purecrypto TLS error into an `io::Error` of kind `Other`, for use
/// inside the `Read`/`Write` impls.
fn io_tls(e: purecrypto::tls::Error) -> io::Error {
    let message = format!("tls: {e:?}");
    io::Error::other(message)
}