use std::cell::UnsafeCell;
use std::io::{Error, ErrorKind, IoSlice, Read, Result, Write};
use std::time::Duration;
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::error::LunaticError;
use crate::host;
const TIMEOUT: u32 = 9027;
#[derive(Debug)]
pub struct TlsStream {
id: u64,
consumed: UnsafeCell<bool>,
}
impl Drop for TlsStream {
fn drop(&mut self) {
if unsafe { !*self.consumed.get() } {
unsafe { host::api::networking::drop_tls_stream(self.id) };
}
}
}
impl Clone for TlsStream {
fn clone(&self) -> Self {
let id = unsafe { host::api::networking::clone_tls_stream(self.id) };
Self {
id,
consumed: UnsafeCell::new(false),
}
}
}
impl Serialize for TlsStream {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
unsafe { *self.consumed.get() = true };
let index = unsafe { host::api::message::push_tls_stream(self.id) };
serializer.serialize_u64(index)
}
}
struct TlsStreamVisitor;
impl<'de> Visitor<'de> for TlsStreamVisitor {
type Value = TlsStream;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an u64 index")
}
fn visit_u64<E>(self, index: u64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
let id = unsafe { host::api::message::take_tls_stream(index) };
Ok(TlsStream::from(id))
}
}
impl<'de> Deserialize<'de> for TlsStream {
fn deserialize<D>(deserializer: D) -> std::result::Result<TlsStream, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_u64(TlsStreamVisitor)
}
}
impl TlsStream {
pub(crate) fn from(id: u64) -> Self {
TlsStream {
id,
consumed: UnsafeCell::new(false),
}
}
pub fn connect(addr: &str, port: u32) -> Result<Self> {
TlsStream::connect_timeout_(addr, None, port, vec![])
}
pub fn connect_with_certs(addr: &str, port: u32, certs: Vec<Vec<u8>>) -> Result<Self> {
TlsStream::connect_timeout_(addr, None, port, certs)
}
pub fn connect_timeout(
addr: &str,
timeout: Duration,
port: u32,
certs: Vec<Vec<u8>>,
) -> Result<Self> {
TlsStream::connect_timeout_(addr, Some(timeout), port, certs)
}
fn connect_timeout_(
addr: &str,
timeout: Option<Duration>,
port: u32,
certs: Vec<Vec<u8>>,
) -> Result<Self> {
let mut id = 0;
let timeout_ms = match timeout {
Some(timeout) => timeout.as_millis() as u64,
None => u64::MAX,
};
let certs_mapped: Vec<(u32, *mut u32)> = certs
.iter()
.map(|cert| (cert.len() as u32, cert.as_ptr() as *mut u32))
.collect();
let result = unsafe {
host::api::networking::tls_connect(
addr.as_ptr(),
addr.len() as u32,
port,
timeout_ms,
&mut id as *mut u64,
certs_mapped.as_ptr() as *mut u8,
certs_mapped.len() as u32,
)
};
if result == 0 {
return Ok(TlsStream::from(id));
}
let lunatic_error = LunaticError::Error(id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
pub fn set_write_timeout(&mut self, duration: Option<Duration>) -> Result<()> {
unsafe {
host::api::networking::set_tls_write_timeout(
self.id,
duration.map_or(u64::MAX, |d| d.as_millis() as u64),
);
}
Ok(())
}
pub fn write_timeout(&self) -> Option<Duration> {
unsafe {
match host::api::networking::get_tls_write_timeout(self.id) {
u64::MAX => None,
millis => Some(Duration::from_millis(millis)),
}
}
}
pub fn set_read_timeout(&mut self, duration: Option<Duration>) -> Result<()> {
unsafe {
host::api::networking::set_tls_read_timeout(
self.id,
duration.map_or(u64::MAX, |d| d.as_millis() as u64),
);
}
Ok(())
}
pub fn read_timeout(&self) -> Option<Duration> {
unsafe {
match host::api::networking::get_tls_read_timeout(self.id) {
u64::MAX => None,
millis => Some(Duration::from_millis(millis)),
}
}
}
}
impl Write for TlsStream {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
let io_slice = IoSlice::new(buf);
self.write_vectored(&[io_slice])
}
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
let mut nwritten_or_error_id: u64 = 0;
let result = unsafe {
host::api::networking::tls_write_vectored(
self.id,
bufs.as_ptr() as *const u32,
bufs.len(),
&mut nwritten_or_error_id as *mut u64,
)
};
if result == 0 {
Ok(nwritten_or_error_id as usize)
} else if result == TIMEOUT {
Err(Error::new(ErrorKind::TimedOut, "TlsStream write timed out"))
} else {
let lunatic_error = LunaticError::Error(nwritten_or_error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
fn flush(&mut self) -> Result<()> {
let mut error_id = 0;
match unsafe { host::api::networking::tls_flush(self.id, &mut error_id as *mut u64) } {
0 => Ok(()),
_ => {
let lunatic_error = LunaticError::Error(error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
}
}
impl Read for TlsStream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut nread_or_error_id: u64 = 0;
let result = unsafe {
host::api::networking::tls_read(
self.id,
buf.as_mut_ptr(),
buf.len(),
&mut nread_or_error_id as *mut u64,
)
};
if result == 0 {
Ok(nread_or_error_id as usize)
} else if result == TIMEOUT {
Err(Error::new(ErrorKind::TimedOut, "TlsStream read timed out"))
} else {
let lunatic_error = LunaticError::Error(nread_or_error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
}