use std::{
ffi::{CString, c_int},
fmt,
io::{self, Read, Write},
marker::PhantomData,
mem::ManuallyDrop,
panic::resume_unwind,
};
use crate::{
SSL_CTX,
bio::{self, BioMethod},
cvt, cvt_p,
error::ErrorStack,
sys as ffi,
};
use error::InnerError;
pub use error::{Error, ErrorCode, HandshakeError};
/// Owning wrapper around a raw OpenSSL `SSL*` handle.
///
/// The handle is released with `SSL_free` when the wrapper is dropped.
pub struct Ssl(*mut ffi::SSL);

impl Drop for Ssl {
    fn drop(&mut self) {
        let lib = crate::get();
        // SAFETY: the only constructor visible in this module (`Ssl::new`) stores the
        // pointer produced by `SSL_new` after `cvt_p` accepted it; it is freed exactly
        // once, here.
        unsafe { (lib.SSL_free)(self.0) };
    }
}
impl Ssl {
pub fn new(ctx: *mut SSL_CTX) -> Result<Ssl, ErrorStack> {
let ossl = crate::get();
cvt_p(unsafe { (ossl.SSL_new)(ctx) }).map(Self)
}
pub fn connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
where
S: Read + Write,
{
let mut stream = SslStream::new(self, stream)?;
match stream.connect() {
Ok(()) => Ok(stream),
Err(error) => match error.code() {
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
stream,
error,
}))
}
_ => Err(HandshakeError::Failure(MidHandshakeSslStream {
stream,
error,
})),
},
}
}
pub fn accept<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
where
S: Read + Write,
{
let mut stream = SslStream::new(self, stream)?;
match stream.accept() {
Ok(()) => Ok(stream),
Err(error) => match error.code() {
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
stream,
error,
}))
}
_ => Err(HandshakeError::Failure(MidHandshakeSslStream {
stream,
error,
})),
},
}
}
fn get_raw_rbio(&self) -> *mut ffi::BIO {
let ffi = crate::get();
unsafe { (ffi.SSL_get_rbio)(self.0) }
}
fn get_error(&self, ret: c_int) -> ErrorCode {
let ffi = crate::get();
unsafe { ErrorCode::from_raw((ffi.SSL_get_error)(self.0, ret)) }
}
pub fn set_hostname(&mut self, hostname: &str) -> Result<(), ErrorStack> {
let ffi = crate::get();
let cstr = CString::new(hostname).unwrap();
unsafe {
cvt(ffi.SSL_set_tlsext_host_name(self.0, cstr.as_ptr() as *mut _) as c_int).map(|_| ())
}
}
}
/// A TLS stream whose handshake has not finished yet, together with the error that
/// interrupted the most recent handshake step.
pub struct MidHandshakeSslStream<S> {
    // The partially set-up stream; handed back once the handshake succeeds.
    stream: SslStream<S>,
    // Error from the last `connect`/`accept`/`do_handshake` attempt.
    error: Error,
}

impl<S> MidHandshakeSslStream<S> {
    /// Mutable access to the underlying stream.
    pub fn get_mut(&mut self) -> &mut S {
        let inner = &mut self.stream;
        inner.get_mut()
    }

    /// Consumes the value and returns the error from the last handshake attempt.
    pub fn into_error(self) -> Error {
        let Self { error, .. } = self;
        error
    }
}

impl<S> MidHandshakeSslStream<S>
where
    S: Read + Write,
{
    /// Retries the handshake (`SSL_do_handshake`).
    ///
    /// On success the established [`SslStream`] is returned. On failure the stored
    /// error is replaced with the new one, and the stream is returned as
    /// `HandshakeError::WouldBlock` for `WANT_READ`/`WANT_WRITE` or
    /// `HandshakeError::Failure` otherwise.
    pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> {
        let latest = match self.stream.do_handshake() {
            Ok(()) => return Ok(self.stream),
            Err(e) => e,
        };
        self.error = latest;
        if matches!(
            self.error.code(),
            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE
        ) {
            Err(HandshakeError::WouldBlock(self))
        } else {
            Err(HandshakeError::Failure(self))
        }
    }
}
/// A TLS stream layered over an underlying stream `S`.
///
/// The `S` value is not stored as a field: `SslStream::new` hands it to `bio::new`,
/// the resulting BIO is attached to `ssl`, and `get_ref`/`get_mut` reach `S` through
/// that BIO.
pub struct SslStream<S> {
    // The SSL handle; after `SslStream::new` it has the custom BIO installed via
    // `SSL_set_bio`. Wrapped in `ManuallyDrop` so `Drop` controls teardown order.
    ssl: ManuallyDrop<Ssl>,
    // Method object returned by `bio::new` alongside the BIO.
    // NOTE(review): presumably the BIO method table the BIO refers to, so it must
    // outlive the BIO — confirm against `bio::new`.
    method: ManuallyDrop<BioMethod>,
    // Records the stream type without storing an `S` directly.
    _p: PhantomData<S>,
}
impl<S> Drop for SslStream<S> {
    fn drop(&mut self) {
        // Order matters: free the `SSL` (and, through `SSL_set_bio` in `new`, the BIO it
        // was given) before releasing the method object that BIO was built from.
        unsafe {
            ManuallyDrop::drop(&mut self.ssl);
            ManuallyDrop::drop(&mut self.method);
        }
    }
}
/// Debug output shows the underlying stream and the raw `SSL*` pointer value.
impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let stream = self.get_ref();
        let raw_ssl = self.ssl.0;
        let mut out = f.debug_struct("SslStream");
        out.field("stream", &stream);
        out.field("ssl", &raw_ssl);
        out.finish()
    }
}
impl<S: Read + Write> SslStream<S> {
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
let ffi = crate::get();
let (bio, method) = bio::new(stream)?;
unsafe {
(ffi.SSL_set_bio)(ssl.0, bio, bio);
}
Ok(Self {
ssl: ManuallyDrop::new(ssl),
method: ManuallyDrop::new(method),
_p: PhantomData,
})
}
pub fn connect(&mut self) -> Result<(), Error> {
let ffi = crate::get();
let ret = unsafe { (ffi.SSL_connect)(self.ssl.0) };
if ret > 0 {
Ok(())
} else {
Err(self.make_error(ret))
}
}
pub fn accept(&mut self) -> Result<(), Error> {
let ffi = crate::get();
let ret = unsafe { (ffi.SSL_accept)(self.ssl.0) };
if ret > 0 {
Ok(())
} else {
Err(self.make_error(ret))
}
}
pub fn do_handshake(&mut self) -> Result<(), Error> {
let ffi = crate::get();
let ret = unsafe { (ffi.SSL_do_handshake)(self.ssl.0) };
if ret > 0 {
Ok(())
} else {
Err(self.make_error(ret))
}
}
}
impl<S> SslStream<S> {
    /// Converts the return value `ret` of a failed SSL call into an [`Error`].
    ///
    /// A panic captured by the BIO is re-raised first. The cause is then chosen from
    /// the `SSL_get_error` code:
    /// - `SSL`: the OpenSSL error stack.
    /// - `SYSCALL`: the error stack if it is non-empty, otherwise the BIO's stored I/O error.
    /// - `ZERO_RETURN`: no cause.
    /// - `WANT_READ`/`WANT_WRITE`: the BIO's stored I/O error, if any.
    /// - Any other code: no cause.
    fn make_error(&mut self, ret: c_int) -> Error {
        self.check_panic();
        let code = self.ssl.get_error(ret);
        let cause = if code == ErrorCode::SSL {
            Some(InnerError::Ssl(ErrorStack::get()))
        } else if code == ErrorCode::SYSCALL {
            let stack = ErrorStack::get();
            if stack.errors().is_empty() {
                self.get_bio_error().map(InnerError::Io)
            } else {
                Some(InnerError::Ssl(stack))
            }
        } else if code == ErrorCode::ZERO_RETURN {
            None
        } else if code == ErrorCode::WANT_READ || code == ErrorCode::WANT_WRITE {
            self.get_bio_error().map(InnerError::Io)
        } else {
            None
        };
        Error { code, cause }
    }

    /// Resumes any panic that was caught inside the BIO callbacks and stashed there.
    fn check_panic(&mut self) {
        let rbio = self.ssl.get_raw_rbio();
        // SAFETY: the read BIO was installed by `SslStream::new` from `bio::new::<S>`, so
        // it presumably carries the `S`-typed state `take_panic::<S>` expects.
        let payload = unsafe { bio::take_panic::<S>(rbio) };
        if let Some(panic) = payload {
            resume_unwind(panic);
        }
    }

    /// Takes the I/O error (if any) that the BIO recorded from the underlying stream.
    fn get_bio_error(&mut self) -> Option<io::Error> {
        let rbio = self.ssl.get_raw_rbio();
        // SAFETY: see `check_panic`; same BIO, same `S` state.
        unsafe { bio::take_error::<S>(rbio) }
    }

    /// Shared access to the underlying stream held by the BIO.
    pub fn get_ref(&self) -> &S {
        let rbio = self.ssl.get_raw_rbio();
        // SAFETY: see `check_panic`; the BIO presumably owns the `S` passed to `new`.
        unsafe { bio::get_ref(rbio) }
    }

    /// Mutable access to the underlying stream held by the BIO.
    ///
    /// Reading or writing through it directly bypasses TLS framing.
    pub fn get_mut(&mut self) -> &mut S {
        let rbio = self.ssl.get_raw_rbio();
        // SAFETY: see `check_panic`; `&mut self` guarantees exclusive access.
        unsafe { bio::get_mut(rbio) }
    }
}
impl<S: Read + Write> Read for SslStream<S> {
    /// Reads decrypted application data into `buf` via `SSL_read_ex`.
    ///
    /// - A clean TLS shutdown (`ZERO_RETURN`) is reported as `Ok(0)`.
    /// - A `SYSCALL` error with no underlying I/O error is also reported as `Ok(0)`.
    /// - `WANT_READ` with no underlying I/O error is retried internally.
    /// - Anything else becomes an `io::Error`: the BIO's own I/O error when present,
    ///   otherwise the SSL `Error` wrapped with `ErrorKind::Other`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let ffi = crate::get();
        loop {
            let mut readbytes = 0;
            // SAFETY: `buf` is a live, writable region of `buf.len()` bytes for the whole
            // call, and `self.ssl.0` is the handle owned by this stream.
            let ret = unsafe {
                (ffi.SSL_read_ex)(
                    self.ssl.0,
                    buf.as_mut_ptr().cast(),
                    buf.len(),
                    &mut readbytes,
                )
            };
            if ret > 0 {
                break Ok(readbytes);
            }
            match self.make_error(ret) {
                err if err.code() == ErrorCode::ZERO_RETURN => break Ok(0),
                err if err.code() == ErrorCode::SYSCALL && err.io_error().is_none() => break Ok(0),
                err if err.code() == ErrorCode::WANT_READ && err.io_error().is_none() => continue,
                err => {
                    break Err(err
                        .into_io_error()
                        .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)));
                }
            }
        }
    }
}
impl<S: Read + Write> Write for SslStream<S> {
    /// Encrypts and sends `buf` via `SSL_write_ex`, returning the number of bytes written.
    ///
    /// `WANT_READ` with no underlying I/O error is retried internally. Every other
    /// failure becomes an `io::Error`: the BIO's own I/O error when present, otherwise
    /// the SSL `Error` wrapped with `ErrorKind::Other`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let ffi = crate::get();
        loop {
            let mut written = 0;
            // SAFETY: `buf` is a live, readable region of `buf.len()` bytes for the whole
            // call, and `self.ssl.0` is the handle owned by this stream.
            let ret = unsafe {
                (ffi.SSL_write_ex)(self.ssl.0, buf.as_ptr().cast(), buf.len(), &mut written)
            };
            if ret > 0 {
                break Ok(written);
            }
            let err = self.make_error(ret);
            let should_retry = err.code() == ErrorCode::WANT_READ && err.io_error().is_none();
            if !should_retry {
                break Err(err
                    .into_io_error()
                    .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)));
            }
        }
    }

    /// Delegates to the underlying stream's `flush`.
    fn flush(&mut self) -> io::Result<()> {
        let inner = self.get_mut();
        inner.flush()
    }
}
mod error {
    use std::{error, ffi::c_int, fmt, io};
    use crate::{error::ErrorStack, ssl::MidHandshakeSslStream, sys as ffi};

    /// A result code as returned by `SSL_get_error`.
    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
    pub struct ErrorCode(c_int);

    impl ErrorCode {
        /// `SSL_ERROR_ZERO_RETURN`: the SSL session has been shut down.
        pub const ZERO_RETURN: ErrorCode = ErrorCode(ffi::SSL_ERROR_ZERO_RETURN);
        /// `SSL_ERROR_WANT_READ`: the operation needs more input and should be retried.
        pub const WANT_READ: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_READ);
        /// `SSL_ERROR_WANT_WRITE`: the operation needs to write and should be retried.
        pub const WANT_WRITE: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_WRITE);
        /// `SSL_ERROR_SYSCALL`: an I/O-level failure (or an unexpected EOF).
        pub const SYSCALL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SYSCALL);
        /// `SSL_ERROR_SSL`: a protocol/library failure; details are on the error stack.
        pub const SSL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SSL);
        /// `SSL_ERROR_WANT_CLIENT_HELLO_CB`.
        pub const WANT_CLIENT_HELLO_CB: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_CLIENT_HELLO_CB);

        /// Wraps a raw `SSL_get_error` value without validating it.
        pub fn from_raw(raw: c_int) -> ErrorCode {
            ErrorCode(raw)
        }
    }

    /// The underlying cause attached to an [`Error`].
    #[derive(Debug)]
    pub(crate) enum InnerError {
        Io(io::Error),
        Ssl(ErrorStack),
    }

    /// An error from an SSL operation: the `SSL_get_error` code plus an optional cause.
    #[derive(Debug)]
    pub struct Error {
        pub(crate) code: ErrorCode,
        pub(crate) cause: Option<InnerError>,
    }

    impl Error {
        /// The `SSL_get_error` code of this error.
        pub fn code(&self) -> ErrorCode {
            self.code
        }

        /// The underlying I/O error, if the cause is one.
        pub fn io_error(&self) -> Option<&io::Error> {
            match self.cause {
                Some(InnerError::Io(ref e)) => Some(e),
                _ => None,
            }
        }

        /// Extracts the underlying I/O error, or gives `self` back if the cause is not one.
        pub fn into_io_error(self) -> Result<io::Error, Error> {
            match self.cause {
                Some(InnerError::Io(e)) => Ok(e),
                _ => Err(self),
            }
        }

        /// The OpenSSL error stack, if the cause is one.
        pub fn ssl_error(&self) -> Option<&ErrorStack> {
            match self.cause {
                Some(InnerError::Ssl(ref e)) => Some(e),
                _ => None,
            }
        }
    }

    impl fmt::Display for Error {
        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self.code {
                ErrorCode::ZERO_RETURN => fmt.write_str("the SSL session has been shut down"),
                ErrorCode::WANT_READ => match self.io_error() {
                    Some(_) => fmt.write_str("a nonblocking read call would have blocked"),
                    None => fmt.write_str("the operation should be retried"),
                },
                ErrorCode::WANT_WRITE => match self.io_error() {
                    Some(_) => fmt.write_str("a nonblocking write call would have blocked"),
                    None => fmt.write_str("the operation should be retried"),
                },
                ErrorCode::SYSCALL => match self.io_error() {
                    Some(err) => write!(fmt, "{}", err),
                    None => fmt.write_str("unexpected EOF"),
                },
                ErrorCode::SSL => match self.ssl_error() {
                    Some(e) => write!(fmt, "{}", e),
                    None => fmt.write_str("OpenSSL error"),
                },
                ErrorCode(code) => write!(fmt, "unknown error code {}", code),
            }
        }
    }

    impl error::Error for Error {
        fn source(&self) -> Option<&(dyn error::Error + 'static)> {
            match self.cause {
                Some(InnerError::Io(ref e)) => Some(e),
                Some(InnerError::Ssl(ref e)) => Some(e),
                None => None,
            }
        }
    }

    /// Failure modes of `Ssl::connect`, `Ssl::accept` and `MidHandshakeSslStream::handshake`.
    pub enum HandshakeError<S> {
        /// The stream could not be set up (BIO creation failed).
        SetupFailure(ErrorStack),
        /// The handshake failed with a non-retryable error.
        Failure(MidHandshakeSslStream<S>),
        /// The handshake stopped with `WANT_READ`/`WANT_WRITE` and can be resumed.
        WouldBlock(MidHandshakeSslStream<S>),
    }

    impl<S> From<ErrorStack> for HandshakeError<S> {
        fn from(e: ErrorStack) -> HandshakeError<S> {
            HandshakeError::SetupFailure(e)
        }
    }

    // Debug/Display/Error are implemented by hand rather than derived, so that `S` does
    // not need to be `Debug`: only the error is shown, never the stream. The private
    // `error` field of `MidHandshakeSslStream` is reachable because this module is a
    // child of the module that defines it.
    impl<S> fmt::Debug for HandshakeError<S> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                HandshakeError::SetupFailure(e) => f.debug_tuple("SetupFailure").field(e).finish(),
                HandshakeError::Failure(s) => f.debug_tuple("Failure").field(&s.error).finish(),
                HandshakeError::WouldBlock(s) => {
                    f.debug_tuple("WouldBlock").field(&s.error).finish()
                }
            }
        }
    }

    impl<S> fmt::Display for HandshakeError<S> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                HandshakeError::SetupFailure(e) => write!(f, "stream setup failed: {}", e),
                HandshakeError::Failure(s) => write!(f, "the handshake failed: {}", s.error),
                HandshakeError::WouldBlock(s) => {
                    write!(f, "the handshake was interrupted: {}", s.error)
                }
            }
        }
    }

    impl<S> error::Error for HandshakeError<S> {
        fn source(&self) -> Option<&(dyn error::Error + 'static)> {
            match self {
                HandshakeError::SetupFailure(e) => Some(e),
                HandshakeError::Failure(s) | HandshakeError::WouldBlock(s) => Some(&s.error),
            }
        }
    }
}