use openssl::error::ErrorStack;
use openssl::ssl::{self, SslStream};
use std::io::{self, Read, Write};
use std::os::fd::AsRawFd;
use std::time::Duration;
use tarantool::coio::CoIOStream;
use thiserror::Error;
/// Error raised while establishing a TLS session on a stream.
///
/// Both variants keep the underlying OpenSSL error and include it in the
/// `Display` output. NOTE(review): the code that constructs these values is not
/// visible here; presumably it is the TLS acceptor/connector setup.
#[derive(Error, Debug)]
pub enum TlsHandshakeError {
    /// Preparing the TLS session failed before the handshake could run
    /// (wraps an OpenSSL [`ErrorStack`]).
    #[error("setup failure: {0}")]
    SetupFailure(ErrorStack),
    /// The TLS handshake itself failed (wraps an [`ssl::Error`]).
    #[error("handshake error: {0}")]
    Failure(ssl::Error),
}
/// Top-level error for creating or using a [`PicoStream`].
///
/// `#[from]` on `Io` and `Tls` makes `thiserror` generate the matching `From`
/// impls, so `?` converts `io::Error` and `TlsHandshakeError` into this type.
#[derive(Error, Debug)]
pub enum PicoStreamError {
    /// A configuration problem, described by the contained message.
    /// NOTE(review): constructed outside this view; presumably used for invalid
    /// stream/TLS settings — confirm.
    #[error("configuration error: {0}")]
    Config(String),
    /// An I/O error from the underlying socket.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
    /// A failure while setting up or performing the TLS handshake.
    #[error("tls error: {0}")]
    Tls(#[from] TlsHandshakeError),
}
/// A byte stream over a tarantool cooperative socket ([`CoIOStream`]),
/// optionally wrapped in TLS.
///
/// Build one with `PicoStream::plain` or `PicoStream::tls`. It implements
/// [`Read`] and [`Write`] by delegating to whichever stream it holds.
pub struct PicoStream {
    // Private so callers cannot depend on which transport is in use.
    inner: PicoStreamImpl,
}
/// The concrete transport held by a [`PicoStream`].
enum PicoStreamImpl {
    /// Unencrypted traffic directly on the cooperative socket.
    Plain(CoIOStream),
    /// TLS session layered over the cooperative socket.
    Tls(SslStream<CoIOStream>),
}
impl PicoStream {
    /// Wraps an unencrypted [`CoIOStream`].
    pub fn plain(stream: CoIOStream) -> Self {
        Self {
            inner: PicoStreamImpl::Plain(stream),
        }
    }

    /// Wraps an already-constructed TLS stream layered over a [`CoIOStream`].
    pub fn tls(stream: SslStream<CoIOStream>) -> Self {
        Self {
            inner: PicoStreamImpl::Tls(stream),
        }
    }

    /// Returns `true` if this stream was created with [`PicoStream::tls`].
    pub fn is_tls(&self) -> bool {
        matches!(self.inner, PicoStreamImpl::Tls(_))
    }

    /// Enables or disables `TCP_NODELAY` on the underlying socket.
    ///
    /// For TLS streams the option is applied to the socket beneath the TLS
    /// layer.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported when the socket option cannot be set.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        use std::os::fd::BorrowedFd;

        let raw_fd = self.as_inner().as_raw_fd();
        // SAFETY: `raw_fd` belongs to the `CoIOStream` owned by `self`, and
        // `self` stays borrowed for this entire call. The descriptor therefore
        // remains open for as long as `fd` is alive, and `fd` never escapes this
        // function.
        // NOTE(review): `borrow_raw` also requires `raw_fd != -1`. This assumes
        // a constructed `CoIOStream` always holds an open descriptor — confirm.
        let fd = unsafe { BorrowedFd::borrow_raw(raw_fd) };
        socket2::SockRef::from(&fd).set_nodelay(nodelay)
    }

    /// Returns the underlying cooperative socket, looking through the TLS
    /// layer when there is one.
    fn as_inner(&self) -> &CoIOStream {
        match &self.inner {
            PicoStreamImpl::Plain(s) => s,
            PicoStreamImpl::Tls(s) => s.get_ref(),
        }
    }

    /// Reads into `buf`, bounded by `timeout` where one is given.
    ///
    /// - Plain streams forward `buf` and `timeout` (including `None`)
    ///   unchanged to `CoIOStream::read_with_timeout`.
    /// - TLS streams with `Some(timeout)` perform a deadline-bounded TLS read.
    /// - TLS streams with `None` perform an ordinary [`Read::read`] on the
    ///   TLS stream.
    ///
    /// Returns the number of bytes placed in `buf`.
    ///
    /// # Errors
    ///
    /// For TLS streams with a timeout, [`io::ErrorKind::TimedOut`] when no data
    /// arrives in time. Otherwise, any error reported by the underlying stream.
    pub fn read_with_timeout(
        &mut self,
        buf: &mut [u8],
        timeout: Option<Duration>,
    ) -> io::Result<usize> {
        match (&mut self.inner, timeout) {
            (PicoStreamImpl::Plain(s), timeout) => s.read_with_timeout(buf, timeout),
            (PicoStreamImpl::Tls(s), Some(timeout)) => read_tls_with_timeout(s, buf, timeout),
            (PicoStreamImpl::Tls(s), None) => s.read(buf),
        }
    }
}
/// Reads decrypted data from `ssl_stream`, failing with
/// [`io::ErrorKind::TimedOut`] if nothing can be read before `timeout` elapses.
///
/// A read is attempted first, so data already buffered by OpenSSL is returned
/// without waiting. While the read reports [`io::ErrorKind::WouldBlock`], the
/// current fiber waits for the socket to become readable via `coio_wait` and
/// retries, until the overall deadline expires.
///
/// Retrying in a loop matters: readability only means *some* bytes arrived.
/// A TLS record split across several TCP segments can make the retried read
/// report `WouldBlock` again, and that must not be surfaced as an error while
/// time remains.
///
/// # Errors
///
/// - `TimedOut` (message `"read timeout"`) once the deadline passes without
///   a successful read.
/// - Any other error from the TLS read is returned unchanged.
fn read_tls_with_timeout(
    ssl_stream: &mut SslStream<CoIOStream>,
    buf: &mut [u8],
    timeout: Duration,
) -> io::Result<usize> {
    use std::time::Instant;
    use tarantool::ffi::tarantool as ffi;

    // `None` only when `now + timeout` overflows, i.e. an effectively infinite
    // timeout; in that case the full `timeout` is reused for every wait.
    let deadline = Instant::now().checked_add(timeout);
    let fd = ssl_stream.get_ref().as_raw_fd();
    let mut waited = false;

    // NOTE(review): the timeout only takes effect if `CoIOStream`'s `Read` impl
    // surfaces `WouldBlock` instead of waiting internally. Confirm this against
    // the tarantool crate.
    loop {
        match ssl_stream.read(buf) {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
            other => return other,
        }

        let remaining = deadline.map_or(timeout, |d| d.saturating_duration_since(Instant::now()));
        // Always wait at least once, so a zero timeout still performs a single
        // readiness poll. After that, an exhausted budget ends the loop; this
        // also prevents spinning when the socket stays readable but the TLS
        // read keeps reporting `WouldBlock`.
        if waited && remaining.is_zero() {
            return Err(io::Error::new(io::ErrorKind::TimedOut, "read timeout"));
        }
        waited = true;

        // NOTE(review): this waits only for readability. If OpenSSL actually
        // needs to *write* (e.g. renegotiation / post-handshake messages), the
        // wait may not be satisfied until the deadline — confirm whether that
        // case matters here.
        // SAFETY: `fd` is owned by `ssl_stream`, which is mutably borrowed for
        // this whole call, so the descriptor stays open while we wait. The call
        // passes only plain values. NOTE(review): `coio_wait` presumably must
        // run on a tarantool fiber — confirm callers guarantee this.
        let events =
            unsafe { ffi::coio_wait(fd, ffi::CoIOFlags::READ.bits(), remaining.as_secs_f64()) };
        if events == 0 {
            return Err(io::Error::new(io::ErrorKind::TimedOut, "read timeout"));
        }
    }
}
/// Reading delegates to the wrapped transport's own [`Read`] implementation:
/// raw socket bytes for plain streams, decrypted application data for TLS
/// streams.
impl Read for PicoStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let source: &mut dyn Read = match &mut self.inner {
            PicoStreamImpl::Plain(plain) => plain,
            PicoStreamImpl::Tls(tls) => tls,
        };
        source.read(buf)
    }
}
/// Writing and flushing delegate to the wrapped transport's own [`Write`]
/// implementation. For TLS streams, the data is encrypted by the TLS layer
/// before it reaches the socket.
impl Write for PicoStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let sink: &mut dyn Write = match &mut self.inner {
            PicoStreamImpl::Plain(plain) => plain,
            PicoStreamImpl::Tls(tls) => tls,
        };
        sink.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        let sink: &mut dyn Write = match &mut self.inner {
            PicoStreamImpl::Plain(plain) => plain,
            PicoStreamImpl::Tls(tls) => tls,
        };
        sink.flush()
    }
}