use std::cell::UnsafeCell;
use std::io::{Error, ErrorKind, IoSlice, Read, Result, Write};
use std::net::SocketAddr;
use std::time::Duration;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::SocketAddrIterator;
use crate::error::LunaticError;
use crate::host;
/// Status code that the host networking calls in this file report when an
/// operation timed out; it is mapped to `ErrorKind::TimedOut`.
const TIMEOUT: u32 = 9_027;
/// A TCP connection backed by a stream resource owned by the host and
/// referenced through `host::api::networking` calls.
#[derive(Debug)]
pub struct TcpStream {
/// Host-side handle of the underlying stream; passed to every host call.
id: u64,
/// Set to `true` once the handle has been pushed into a message by the
/// `Serialize` impl; while set, `Drop` does not release the handle.
/// An `UnsafeCell` is used because serialization only has `&self`.
consumed: UnsafeCell<bool>,
}
impl Drop for TcpStream {
    /// Releases the host stream handle, unless ownership of it was already
    /// handed off by serializing the stream into a message.
    fn drop(&mut self) {
        // With `&mut self` the flag can be read through the safe `get_mut`.
        let handed_off = *self.consumed.get_mut();
        if handed_off {
            return;
        }
        // SAFETY: `self.id` is a handle this value still owns (not handed off).
        unsafe { host::api::networking::drop_tcp_stream(self.id) };
    }
}
impl Clone for TcpStream {
fn clone(&self) -> Self {
let id = unsafe { host::api::networking::clone_tcp_stream(self.id) };
Self {
id,
consumed: UnsafeCell::new(false),
}
}
}
impl Serialize for TcpStream {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
unsafe { *self.consumed.get() = true };
let index = unsafe { host::api::message::push_tcp_stream(self.id) };
serializer.serialize_u64(index)
}
}
impl<'de> Deserialize<'de> for TcpStream {
fn deserialize<D>(deserializer: D) -> std::result::Result<TcpStream, D::Error>
where
D: Deserializer<'de>,
{
let index = Deserialize::deserialize(deserializer)?;
let id = unsafe { host::api::message::take_tcp_stream(index) };
Ok(TcpStream::from(id))
}
}
impl TcpStream {
pub(crate) fn from(id: u64) -> Self {
TcpStream {
id,
consumed: UnsafeCell::new(false),
}
}
pub fn connect<A>(addr: A) -> Result<Self>
where
A: super::ToSocketAddrs,
{
TcpStream::connect_timeout_(addr, None)
}
pub fn connect_timeout<A>(addr: A, timeout: Duration) -> Result<Self>
where
A: super::ToSocketAddrs,
{
TcpStream::connect_timeout_(addr, Some(timeout))
}
fn connect_timeout_<A>(addr: A, timeout: Option<Duration>) -> Result<Self>
where
A: super::ToSocketAddrs,
{
let mut id = 0;
for addr in addr.to_socket_addrs()? {
let timeout_ms = match timeout {
Some(timeout) => timeout.as_millis() as u64,
None => u64::MAX,
};
let result = match addr {
SocketAddr::V4(v4_addr) => {
let ip = v4_addr.ip().octets();
let port = v4_addr.port() as u32;
unsafe {
host::api::networking::tcp_connect(
4,
ip.as_ptr(),
port,
0,
0,
timeout_ms,
&mut id as *mut u64,
)
}
}
SocketAddr::V6(v6_addr) => {
let ip = v6_addr.ip().octets();
let port = v6_addr.port() as u32;
let flow_info = v6_addr.flowinfo();
let scope_id = v6_addr.scope_id();
unsafe {
host::api::networking::tcp_connect(
6,
ip.as_ptr(),
port,
flow_info,
scope_id,
timeout_ms,
&mut id as *mut u64,
)
}
}
};
if result == 0 {
return Ok(TcpStream::from(id));
}
}
let lunatic_error = LunaticError::Error(id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
pub fn peer_addr(&self) -> Result<SocketAddr> {
let mut dns_iter_or_error_id = 0;
let result = unsafe {
host::api::networking::tcp_peer_addr(self.id, &mut dns_iter_or_error_id as *mut u64)
};
if result == 0 {
let mut dns_iter = SocketAddrIterator::from(dns_iter_or_error_id);
let addr = dns_iter.next().expect("must contain one element");
Ok(addr)
} else {
let lunatic_error = LunaticError::Error(dns_iter_or_error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
pub fn set_write_timeout(&mut self, duration: Option<Duration>) -> Result<()> {
unsafe {
host::api::networking::set_write_timeout(
self.id,
duration.map_or(u64::MAX, |d| d.as_millis() as u64),
);
}
Ok(())
}
pub fn write_timeout(&self) -> Option<Duration> {
unsafe {
match host::api::networking::get_write_timeout(self.id) {
u64::MAX => None,
millis => Some(Duration::from_millis(millis)),
}
}
}
pub fn set_read_timeout(&mut self, duration: Option<Duration>) -> Result<()> {
unsafe {
host::api::networking::set_read_timeout(
self.id,
duration.map_or(u64::MAX, |d| d.as_millis() as u64),
);
}
Ok(())
}
pub fn read_timeout(&self) -> Option<Duration> {
unsafe {
match host::api::networking::get_read_timeout(self.id) {
u64::MAX => None,
millis => Some(Duration::from_millis(millis)),
}
}
}
pub fn set_peek_timeout(&mut self, duration: Option<Duration>) -> Result<()> {
unsafe {
host::api::networking::set_peek_timeout(
self.id,
duration.map_or(u64::MAX, |d| d.as_millis() as u64),
);
}
Ok(())
}
pub fn peek_timeout(&self) -> Option<Duration> {
unsafe {
match host::api::networking::get_peek_timeout(self.id) {
u64::MAX => None,
millis => Some(Duration::from_millis(millis)),
}
}
}
pub fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut nread_or_error_id: u64 = 0;
let result = unsafe {
host::api::networking::tcp_peek(
self.id,
buf.as_mut_ptr(),
buf.len(),
&mut nread_or_error_id as *mut u64,
)
};
if result == 0 {
Ok(nread_or_error_id as usize)
} else if result == TIMEOUT {
Err(Error::new(ErrorKind::TimedOut, "TcpStream peek timed out"))
} else {
let lunatic_error = LunaticError::Error(nread_or_error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
}
impl Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
let io_slice = IoSlice::new(buf);
self.write_vectored(&[io_slice])
}
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
let mut nwritten_or_error_id: u64 = 0;
let result = unsafe {
host::api::networking::tcp_write_vectored(
self.id,
bufs.as_ptr() as *const u32,
bufs.len(),
&mut nwritten_or_error_id as *mut u64,
)
};
if result == 0 {
Ok(nwritten_or_error_id as usize)
} else if result == TIMEOUT {
Err(Error::new(ErrorKind::TimedOut, "TcpStream write timed out"))
} else {
let lunatic_error = LunaticError::Error(nwritten_or_error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
fn flush(&mut self) -> Result<()> {
let mut error_id = 0;
match unsafe { host::api::networking::tcp_flush(self.id, &mut error_id as *mut u64) } {
0 => Ok(()),
_ => {
let lunatic_error = LunaticError::Error(error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
}
}
impl Read for TcpStream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut nread_or_error_id: u64 = 0;
let result = unsafe {
host::api::networking::tcp_read(
self.id,
buf.as_mut_ptr(),
buf.len(),
&mut nread_or_error_id as *mut u64,
)
};
if result == 0 {
Ok(nread_or_error_id as usize)
} else if result == TIMEOUT {
Err(Error::new(ErrorKind::TimedOut, "TcpStream read timed out"))
} else {
let lunatic_error = LunaticError::Error(nread_or_error_id);
Err(Error::new(ErrorKind::Other, lunatic_error))
}
}
}