use std::io;
use std::os::fd::AsRawFd as _;
use std::os::raw::c_int;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, Interest, ReadBuf};
use tokio::net::TcpStream;
use crate::error::{last_error, Error, KtlsError, Result};
use crate::ffi::Ssl;
use crate::session::{export_keying_material, KtlsEligibility, NegotiatedSession};
/// A TLS session running over a tokio [`TcpStream`].
///
/// While `ktls_active` is `false`, reads, writes and shutdown are driven through the
/// `Ssl` handle; once it is `true` they go directly to the TCP socket.
pub struct TlsStream {
// TLS session handle used for all `aws_lc_sys` calls made by this type.
ssl: Ssl,
// The socket the session runs over; also used for readiness polling.
tcp: TcpStream,
// Set to `true` by `try_auto_install_ktls` after `crate::ktls::install_ktls` succeeds.
ktls_active: bool,
// When `true`, `try_auto_install_ktls` returns immediately without attempting an install.
ktls_disabled: bool,
}
/// Debug output shows only the socket's raw fd and the kTLS flag; the remaining
/// fields are elided (rendered as `..`).
impl std::fmt::Debug for TlsStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fd = self.tcp.as_raw_fd();
        let mut out = f.debug_struct("TlsStream");
        out.field("fd", &fd);
        out.field("ktls_active", &self.ktls_active);
        out.finish_non_exhaustive()
    }
}
// SAFETY: NOTE(review): the justification is not visible in this file. Presumably `Ssl`
// wraps a raw pointer (it exposes `as_ptr()`) and is therefore not auto-`Send`, while the
// session is only ever used from one thread at a time — confirm against `crate::ffi::Ssl`.
unsafe impl Send for TlsStream {}
impl TlsStream {
pub(crate) fn from_parts(ssl: Ssl, tcp: TcpStream, ktls_disabled: bool) -> Self {
Self {
ssl,
tcp,
ktls_active: false,
ktls_disabled,
}
}
#[must_use]
pub fn negotiated(&self) -> NegotiatedSession {
unsafe { NegotiatedSession::from_ssl(self.ssl.as_ptr()) }
}
#[must_use]
pub fn ktls_eligibility(&self) -> KtlsEligibility {
unsafe { KtlsEligibility::from_ssl(&self.ssl) }
}
#[must_use]
pub fn ktls_active(&self) -> bool {
self.ktls_active
}
#[must_use]
pub fn ktls_disabled(&self) -> bool {
self.ktls_disabled
}
pub(crate) fn try_auto_install_ktls(&mut self) -> Result<()> {
if self.ktls_active || self.ktls_disabled {
return Ok(());
}
crate::ktls::check_no_buffered_plaintext(&self.ssl)?;
let raw = self.tcp.as_raw_fd();
match crate::ktls::install_ktls(&self.ssl, raw) {
Ok(()) => {
self.ktls_active = true;
Ok(())
}
Err(Error::Ktls(
KtlsError::Unsupported
| KtlsError::IneligibleCipher { .. }
| KtlsError::TlsUlpUnavailable(_)
| KtlsError::SocketUnattachable(_),
)) => Ok(()),
Err(e) => Err(e),
}
}
pub fn export_keying_material(
&self,
out: &mut [u8],
label: &[u8],
context: Option<&[u8]>,
) -> Result<()> {
unsafe { export_keying_material(self.ssl.as_ptr(), out, label, context) }
}
#[must_use]
pub fn has_peer_certificate(&self) -> bool {
let cert = unsafe { aws_lc_sys::SSL_get_peer_certificate(self.ssl.as_ptr()) };
if cert.is_null() {
false
} else {
unsafe {
aws_lc_sys::X509_free(cert);
}
true
}
}
}
impl AsyncRead for TlsStream {
    /// Reads decrypted bytes into `buf`.
    ///
    /// With kTLS active the read goes straight to the socket. Otherwise `SSL_read` is
    /// retried until it produces data, reports a clean close (`SSL_ERROR_ZERO_RETURN`,
    /// surfaced as a zero-byte read), fails, or the socket is not ready.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let stream = &mut *self;
        if stream.ktls_active {
            return ktls_poll_read(&stream.tcp, cx, buf);
        }

        loop {
            let dst = buf.initialize_unfilled();
            if dst.is_empty() {
                return Poll::Ready(Ok(()));
            }
            // `SSL_read` takes a `c_int` length; clamp oversized buffers to `c_int::MAX`.
            let want = c_int::try_from(dst.len()).unwrap_or(c_int::MAX);
            // SAFETY: `dst` is an initialized, writable slice of at least `want` bytes.
            // NOTE(review): assumes `Ssl::as_ptr` yields a valid SSL pointer.
            let rc =
                unsafe { aws_lc_sys::SSL_read(stream.ssl.as_ptr(), dst.as_mut_ptr().cast(), want) };
            if rc > 0 {
                // `rc` is strictly positive here, so the cast cannot lose the sign.
                #[allow(clippy::cast_sign_loss)]
                let filled = rc as usize;
                buf.advance(filled);
                return Poll::Ready(Ok(()));
            }

            // SAFETY: same handle as above; `rc` is the result of the call just made.
            let code = unsafe { aws_lc_sys::SSL_get_error(stream.ssl.as_ptr(), rc) };
            match code {
                // Leaving `buf` untouched signals end of stream to the caller.
                aws_lc_sys::SSL_ERROR_ZERO_RETURN => return Poll::Ready(Ok(())),
                // On the two WANT_* codes: wait for the socket, then loop and retry.
                aws_lc_sys::SSL_ERROR_WANT_READ => {
                    std::task::ready!(poll_clear_read_ready(&stream.tcp, cx))?;
                }
                aws_lc_sys::SSL_ERROR_WANT_WRITE => {
                    std::task::ready!(poll_clear_write_ready(&stream.tcp, cx))?;
                }
                _ => return Poll::Ready(Err(ssl_io_error("SSL_read", code))),
            }
        }
    }
}
impl AsyncWrite for TlsStream {
    /// Writes plaintext from `buf`, returning how many bytes were accepted.
    ///
    /// An empty `buf` completes immediately with `0`. With kTLS active the write goes
    /// straight to the socket. Otherwise `SSL_write` is retried until it accepts data,
    /// fails, or the socket is not ready. `SSL_ERROR_ZERO_RETURN` is reported as
    /// `io::ErrorKind::WriteZero`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = &mut *self;
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        if stream.ktls_active {
            return ktls_poll_write(&stream.tcp, cx, buf);
        }

        // `SSL_write` takes a `c_int` length; clamp oversized buffers to `c_int::MAX`.
        let len = c_int::try_from(buf.len()).unwrap_or(c_int::MAX);
        loop {
            // SAFETY: `buf` is a readable slice of at least `len` bytes.
            // NOTE(review): assumes `Ssl::as_ptr` yields a valid SSL pointer.
            let rc =
                unsafe { aws_lc_sys::SSL_write(stream.ssl.as_ptr(), buf.as_ptr().cast(), len) };
            if rc > 0 {
                // `rc` is strictly positive here, so the cast cannot lose the sign.
                #[allow(clippy::cast_sign_loss)]
                let written = rc as usize;
                return Poll::Ready(Ok(written));
            }

            // SAFETY: same handle as above; `rc` is the result of the call just made.
            let code = unsafe { aws_lc_sys::SSL_get_error(stream.ssl.as_ptr(), rc) };
            match code {
                // On the two WANT_* codes: wait for the socket, then loop and retry.
                aws_lc_sys::SSL_ERROR_WANT_READ => {
                    std::task::ready!(poll_clear_read_ready(&stream.tcp, cx))?;
                }
                aws_lc_sys::SSL_ERROR_WANT_WRITE => {
                    std::task::ready!(poll_clear_write_ready(&stream.tcp, cx))?;
                }
                aws_lc_sys::SSL_ERROR_ZERO_RETURN => {
                    let closed =
                        io::Error::new(io::ErrorKind::WriteZero, "peer closed TLS session");
                    return Poll::Ready(Err(closed));
                }
                _ => return Poll::Ready(Err(ssl_io_error("SSL_write", code))),
            }
        }
    }

    /// Reports success immediately without doing any work.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts the stream down.
    ///
    /// With kTLS active this shuts down the TCP stream directly; `SSL_shutdown` is not
    /// called on that path. Otherwise it drives `SSL_shutdown` and treats any
    /// non-negative return as success; the TCP socket itself is not shut down here.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = &mut *self;
        if stream.ktls_active {
            return Pin::new(&mut stream.tcp).poll_shutdown(cx);
        }

        loop {
            // SAFETY: NOTE(review): assumes `Ssl::as_ptr` yields a valid SSL pointer.
            let rc = unsafe { aws_lc_sys::SSL_shutdown(stream.ssl.as_ptr()) };
            if rc >= 0 {
                return Poll::Ready(Ok(()));
            }

            // SAFETY: same handle as above; `rc` is the result of the call just made.
            let code = unsafe { aws_lc_sys::SSL_get_error(stream.ssl.as_ptr(), rc) };
            match code {
                aws_lc_sys::SSL_ERROR_WANT_READ => {
                    std::task::ready!(poll_clear_read_ready(&stream.tcp, cx))?;
                }
                aws_lc_sys::SSL_ERROR_WANT_WRITE => {
                    std::task::ready!(poll_clear_write_ready(&stream.tcp, cx))?;
                }
                _ => return Poll::Ready(Err(ssl_io_error("SSL_shutdown", code))),
            }
        }
    }
}
/// Reads directly from `tcp` into `buf`, bypassing the SSL handle.
///
/// Completes without reading when `buf` has no unfilled capacity. A zero-byte read
/// completes with `buf` unchanged. `WouldBlock` waits for read readiness and retries;
/// `Interrupted` retries immediately; any other error is returned.
fn ktls_poll_read(
    tcp: &TcpStream,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
    loop {
        let dst = buf.initialize_unfilled();
        if dst.is_empty() {
            return Poll::Ready(Ok(()));
        }

        let failure = match tcp.try_read(dst) {
            Ok(count) => {
                // Advancing by zero is a no-op, so end-of-stream needs no special case.
                buf.advance(count);
                return Poll::Ready(Ok(()));
            }
            Err(failure) => failure,
        };

        match failure.kind() {
            io::ErrorKind::WouldBlock => {
                std::task::ready!(tcp.poll_read_ready(cx))?;
            }
            io::ErrorKind::Interrupted => {}
            _ => return Poll::Ready(Err(failure)),
        }
    }
}
/// Writes `buf` directly to `tcp`, bypassing the SSL handle.
///
/// Returns the byte count of the first successful `try_write`. `WouldBlock` waits for
/// write readiness and retries; `Interrupted` retries immediately; any other error is
/// returned.
fn ktls_poll_write(tcp: &TcpStream, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
    loop {
        let failure = match tcp.try_write(buf) {
            Ok(written) => return Poll::Ready(Ok(written)),
            Err(failure) => failure,
        };

        match failure.kind() {
            io::ErrorKind::WouldBlock => {
                std::task::ready!(tcp.poll_write_ready(cx))?;
            }
            io::ErrorKind::Interrupted => {}
            _ => return Poll::Ready(Err(failure)),
        }
    }
}
/// Polls `tcp` for read readiness and, when it is reported, clears the cached flag.
///
/// The `try_io` closure performs no I/O; it only returns `WouldBlock`, which makes tokio
/// drop its cached read readiness. The result of that call is deliberately discarded.
/// `Pending` and error results from `poll_read_ready` are passed through unchanged.
fn poll_clear_read_ready(tcp: &TcpStream, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    std::task::ready!(tcp.poll_read_ready(cx))?;
    let _: io::Result<()> =
        tcp.try_io(Interest::READABLE, || Err(io::ErrorKind::WouldBlock.into()));
    Poll::Ready(Ok(()))
}
/// Polls `tcp` for write readiness and, when it is reported, clears the cached flag.
///
/// The `try_io` closure performs no I/O; it only returns `WouldBlock`, which makes tokio
/// drop its cached write readiness. The result of that call is deliberately discarded.
/// `Pending` and error results from `poll_write_ready` are passed through unchanged.
fn poll_clear_write_ready(tcp: &TcpStream, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    std::task::ready!(tcp.poll_write_ready(cx))?;
    let _: io::Result<()> =
        tcp.try_io(Interest::WRITABLE, || Err(io::ErrorKind::WouldBlock.into()));
    Poll::Ready(Ok(()))
}
/// Builds an `io::Error` for a failed SSL call.
///
/// The message contains the operation name `op`, the `SSL_get_error` code, and the
/// text returned by `last_error()`.
fn ssl_io_error(op: &'static str, code: c_int) -> io::Error {
    let detail = last_error();
    let message = format!("{op}: ssl_error={code} {detail}");
    io::Error::other(message)
}
pub(crate) unsafe fn attach_socket_bio(ssl: &mut Ssl, fd: c_int) -> Result<()> {
let bio = unsafe { aws_lc_sys::BIO_new_socket(fd, aws_lc_sys::BIO_NOCLOSE) };
if bio.is_null() {
return Err(Error::Init(format!("BIO_new_socket: {}", last_error())));
}
unsafe {
aws_lc_sys::SSL_set_bio(ssl.as_ptr(), bio, bio);
}
Ok(())
}
pub(crate) async unsafe fn drive_handshake(ssl: &mut Ssl, tcp: &TcpStream) -> Result<()> {
loop {
let r = unsafe { aws_lc_sys::SSL_do_handshake(ssl.as_ptr()) };
if r == 1 {
return Ok(());
}
let err = unsafe { aws_lc_sys::SSL_get_error(ssl.as_ptr(), r) };
match err {
aws_lc_sys::SSL_ERROR_WANT_READ => {
tcp.readable()
.await
.map_err(|e| Error::Handshake(format!("waiting readable: {e}")))?;
let _: io::Result<()> =
tcp.try_io(Interest::READABLE, || Err(io::ErrorKind::WouldBlock.into()));
}
aws_lc_sys::SSL_ERROR_WANT_WRITE => {
tcp.writable()
.await
.map_err(|e| Error::Handshake(format!("waiting writable: {e}")))?;
let _: io::Result<()> =
tcp.try_io(Interest::WRITABLE, || Err(io::ErrorKind::WouldBlock.into()));
}
aws_lc_sys::SSL_ERROR_ZERO_RETURN => {
return Err(Error::Handshake(
"peer closed connection during handshake".into(),
))
}
_ => {
return Err(Error::Handshake(format!(
"SSL_do_handshake: ssl_error={err} {}",
last_error()
)))
}
}
}
}