#[cfg(feature = "std")]
use std::{
io::{Error as IoError, ErrorKind as IoErrorKind, Read, Result as IoResult, Write},
net::UdpSocket,
result::Result as StdResult,
};
use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
use mbedtls_sys::types::size_t;
use super::context::Context;
use crate::error::Result;
#[cfg(feature = "std")]
use crate::error::{codes, Error};
/// Low-level adapter exposing a transport as the raw `extern "C"` receive/send
/// callbacks that mbedtls invokes.
///
/// `T` is a marker type (e.g. `AnyIo` or `Stream`) that selects which safe
/// [`IoCallback<T>`] implementation backs the callbacks. This trait is
/// blanket-implemented for every `IoCallback<T>` type in this module, so it is
/// rarely implemented by hand.
pub trait IoCallbackUnsafe<T> {
/// C-ABI receive callback: reads up to `len` bytes into `data` from the object
/// behind `user_data`.
///
/// Returns the number of bytes read, or the error's `to_int()` code on failure.
///
/// # Safety
/// `user_data` must be a pointer obtained from [`IoCallbackUnsafe::data_ptr`] on a
/// live `Self`, and `data` must be valid for writes of `len` bytes.
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int
where
Self: Sized;
/// C-ABI send callback: writes up to `len` bytes from `data` to the object behind
/// `user_data`.
///
/// Returns the number of bytes written, or the error's `to_int()` code on failure.
///
/// # Safety
/// `user_data` must be a pointer obtained from [`IoCallbackUnsafe::data_ptr`] on a
/// live `Self`, and `data` must be valid for reads of `len` bytes.
unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int
where
Self: Sized;
/// Returns `self` as an opaque pointer, suitable as the `user_data` argument of
/// `call_recv` / `call_send`.
fn data_ptr(&mut self) -> *mut c_void;
}
/// Safe transport interface that an mbedtls context is driven through.
///
/// The marker parameter `T` (e.g. `AnyIo`, `Stream`) lets one concrete type carry
/// several blanket implementations (one per marker) without them overlapping.
/// Errors returned here are handed back to mbedtls as `Error::to_int()` codes by
/// the blanket `IoCallbackUnsafe` implementation.
pub trait IoCallback<T> {
/// Reads into `buf`, returning the number of bytes received.
fn recv(&mut self, buf: &mut [u8]) -> Result<usize>;
/// Writes from `buf`, returning the number of bytes sent.
fn send(&mut self, buf: &[u8]) -> Result<usize>;
}
/// Bridges every safe [`IoCallback<T>`] implementation to the raw C callbacks
/// mbedtls expects.
///
/// The callbacks recover `&mut IO` from `user_data` (which must come from
/// [`IoCallbackUnsafe::data_ptr`]), wrap the C buffer in a Rust slice and forward
/// to the safe `recv` / `send`. Successful byte counts are returned as-is; errors
/// are converted with `Error::to_int()`.
impl<IO: IoCallback<T>, T> IoCallbackUnsafe<T> for IO {
    unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
        // Byte counts share the `c_int` return channel, so never hand out a buffer
        // longer than `c_int::MAX`.
        let len = len.min(c_int::MAX as size_t);
        // `slice::from_raw_parts_mut` requires a non-null pointer even for an empty
        // slice, so do not touch `data` at all when there is nothing to read.
        let buf: &mut [u8] = if len == 0 {
            &mut []
        } else {
            // SAFETY: the caller guarantees `data` is valid for `len` writable bytes.
            ::core::slice::from_raw_parts_mut(data, len)
        };
        // SAFETY: `user_data` was produced by `data_ptr` on a live `IO`.
        let io = &mut *(user_data as *mut IO);
        match io.recv(buf) {
            Ok(i) => i as c_int,
            Err(e) => e.to_int(),
        }
    }

    unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
        // See `call_recv`: keep the length representable in the `c_int` result.
        let len = len.min(c_int::MAX as size_t);
        // Avoid building a slice from a possibly-null pointer when `len == 0`.
        let buf: &[u8] = if len == 0 {
            &[]
        } else {
            // SAFETY: the caller guarantees `data` is valid for `len` readable bytes.
            ::core::slice::from_raw_parts(data, len)
        };
        // SAFETY: `user_data` was produced by `data_ptr` on a live `IO`.
        let io = &mut *(user_data as *mut IO);
        match io.send(buf) {
            Ok(i) => i as c_int,
            Err(e) => e.to_int(),
        }
    }

    fn data_ptr(&mut self) -> *mut c_void {
        self as *mut IO as *mut _
    }
}
/// Uninhabited marker type selecting the [`Io`]-based [`IoCallback`]
/// implementation; used only as a type parameter.
pub enum AnyIo {}
/// Uninhabited marker type selecting the `std::io::Read` + `std::io::Write`-based
/// [`IoCallback`] implementation; used only as a type parameter.
#[cfg(feature = "std")]
pub enum Stream {}
/// Transport interface whose operations report failures directly as this crate's
/// `Result` (mbedtls error codes) rather than `std::io::Error`.
///
/// Any `Io` implementor automatically implements `IoCallback<AnyIo>`.
pub trait Io {
/// Reads into `buf`, returning the number of bytes received.
fn recv(&mut self, buf: &mut [u8]) -> Result<usize>;
/// Writes from `buf`, returning the number of bytes sent.
fn send(&mut self, buf: &[u8]) -> Result<usize>;
}
/// Every [`Io`] implementor can back an mbedtls context through the `AnyIo`
/// marker; both directions forward to the `Io` methods unchanged.
impl<IO: Io> IoCallback<AnyIo> for IO {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        <IO as Io>::recv(self, buf)
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        <IO as Io>::send(self, buf)
    }
}
#[cfg(feature = "std")]
/// Adapts any `std::io::Read + Write` stream to the callback interface.
///
/// `WouldBlock` becomes `SslWantRead` / `SslWantWrite` so mbedtls can retry later;
/// every other I/O error collapses to `NetRecvFailed` / `NetSendFailed`.
impl<IO: Read + Write> IoCallback<Stream> for IO {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        match self.read(buf) {
            Ok(count) => Ok(count),
            Err(io_err) if io_err.kind() == std::io::ErrorKind::WouldBlock => {
                Err(Error::from(codes::SslWantRead))
            }
            Err(_) => Err(Error::from(codes::NetRecvFailed)),
        }
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        match self.write(buf) {
            Ok(count) => Ok(count),
            Err(io_err) if io_err.kind() == std::io::ErrorKind::WouldBlock => {
                Err(Error::from(codes::SslWantWrite))
            }
            Err(_) => Err(Error::from(codes::NetSendFailed)),
        }
    }
}
/// A [`UdpSocket`] that has been `connect`ed to a single peer, so plain
/// `send` / `recv` can be used on it.
///
/// Only constructible through [`ConnectedUdpSocket::connect`], which succeeds only
/// if `UdpSocket::connect` succeeds; recover the socket with
/// [`ConnectedUdpSocket::into_socket`].
#[cfg(feature = "std")]
pub struct ConnectedUdpSocket {
socket: UdpSocket,
}
#[cfg(feature = "std")]
impl ConnectedUdpSocket {
    /// Connects `socket` to `addr` and wraps it.
    ///
    /// # Errors
    /// If `UdpSocket::connect` fails, the I/O error is returned together with the
    /// original socket so the caller keeps ownership of it.
    pub fn connect<A: std::net::ToSocketAddrs>(socket: UdpSocket, addr: A) -> StdResult<Self, (IoError, UdpSocket)> {
        if let Err(connect_err) = socket.connect(addr) {
            return Err((connect_err, socket));
        }
        Ok(Self { socket })
    }

    /// Consumes the wrapper and returns the underlying (still connected) socket.
    pub fn into_socket(self) -> UdpSocket {
        let Self { socket } = self;
        socket
    }
}
#[cfg(feature = "std")]
/// Drives mbedtls over a connected UDP socket.
///
/// Errors map the same way as the `Stream` adapter: `WouldBlock` (a non-blocking
/// socket that cannot make progress yet) becomes `SslWantRead` / `SslWantWrite` so
/// the handshake or record layer can be retried; any other error becomes
/// `NetRecvFailed` / `NetSendFailed`.
impl Io for ConnectedUdpSocket {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.socket.recv(buf).map_err(|e| match e.kind() {
            std::io::ErrorKind::WouldBlock => Error::from(codes::SslWantRead),
            _ => Error::from(codes::NetRecvFailed),
        })
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        // Previously every send error (including `WouldBlock`) became a hard
        // `NetSendFailed`; report "try again" consistently with `recv`.
        self.socket.send(buf).map_err(|e| match e.kind() {
            std::io::ErrorKind::WouldBlock => Error::from(codes::SslWantWrite),
            _ => Error::from(codes::NetSendFailed),
        })
    }
}
/// A `Context` whose transport uses the `AnyIo` marker is itself usable as an
/// [`Io`]; both operations delegate to the context's own `recv` / `send`.
impl<T: IoCallbackUnsafe<AnyIo>> Io for Context<T> {
    fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        Context::<T>::recv(self, buf)
    }

    fn send(&mut self, buf: &[u8]) -> Result<usize> {
        Context::<T>::send(self, buf)
    }
}
#[cfg(feature = "std")]
/// Exposes a stream-backed `Context` as a `std::io::Read`.
impl<T: IoCallbackUnsafe<Stream>> Read for Context<T> {
    /// Reads decrypted data from the context.
    ///
    /// A peer `SslPeerCloseNotify` is reported as end-of-stream (`Ok(0)`).
    /// `SslWantRead` / `SslWantWrite` become `ErrorKind::WouldBlock`. All other
    /// failures go through `crate::private::error_to_io_error`.
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        self.recv(buf).or_else(|err| match err.high_level() {
            Some(codes::SslPeerCloseNotify) => Ok(0),
            Some(codes::SslWantRead) | Some(codes::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()),
            _ => Err(crate::private::error_to_io_error(err)),
        })
    }
}
#[cfg(feature = "std")]
/// Exposes a stream-backed `Context` as a `std::io::Write`.
impl<T: IoCallbackUnsafe<Stream>> Write for Context<T> {
    /// Encrypts and sends data through the context.
    ///
    /// `SslPeerCloseNotify` yields `Ok(0)`; note that `Write::write_all` treats a
    /// zero-length write as an error. `SslWantRead` / `SslWantWrite` become
    /// `ErrorKind::WouldBlock`. Other failures go through
    /// `crate::private::error_to_io_error`.
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        self.send(buf).or_else(|err| match err.high_level() {
            Some(codes::SslPeerCloseNotify) => Ok(0),
            Some(codes::SslWantRead) | Some(codes::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()),
            _ => Err(crate::private::error_to_io_error(err)),
        })
    }

    /// Always succeeds without doing anything.
    ///
    /// NOTE(review): confirm `Context` never retains pending output that would
    /// need to be pushed to the transport here.
    fn flush(&mut self) -> IoResult<()> {
        Ok(())
    }
}