use std::io;
use std::io::{Read, Write};
use std::mem::ManuallyDrop;
use std::net::SocketAddr;
use std::os::unix::io::{AsRawFd, RawFd};
use std::rc::Rc;
use std::time::Duration;
use ignore_result::Ignore;
use mio::{net, Token};
use static_assertions::{assert_impl_all, assert_not_impl_any};
use crate::channel::parallel;
use crate::channel::prelude::*;
use crate::runtime::Scheduler;
/// A TCP listening socket registered with the runtime's scheduler.
///
/// Dropping it passes the socket and its token to the registry's `deregister_event_source`.
pub struct TcpListener {
    // Wrapped in `ManuallyDrop` so `Drop` can move the socket out and hand it to the registry.
    listener: ManuallyDrop<net::TcpListener>,
    // Readiness notifications that `accept` waits on after a `WouldBlock`.
    readable: parallel::Receiver<()>,
    // Token obtained at registration; passed back when deregistering.
    token: Token,
}
assert_impl_all!(TcpListener: Send, Sync);
impl Drop for TcpListener {
    /// Moves the socket out of `self` and passes it, with its token, to the registry's
    /// `deregister_event_source`.
    fn drop(&mut self) {
        let scheduler_registry = unsafe { Scheduler::registry() };
        // SAFETY: `self.listener` is never touched again once it has been taken here.
        scheduler_registry
            .deregister_event_source(self.token, unsafe { ManuallyDrop::take(&mut self.listener) });
    }
}
impl TcpListener {
    /// Binds a listening socket to `addr` and registers it with the runtime's scheduler.
    ///
    /// # Errors
    /// Fails if binding the address or registering the socket fails.
    pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
        let mut socket = net::TcpListener::bind(addr)?;
        let (token, readable) =
            unsafe { Scheduler::registry() }.register_tcp_listener(&mut socket)?;
        Ok(TcpListener { token, readable, listener: ManuallyDrop::new(socket) })
    }

    /// Accepts the next incoming connection, waiting on the readable channel while none is
    /// pending.
    ///
    /// The accepted socket is registered through `TcpStream::new` before being returned with
    /// the peer's address.
    ///
    /// # Errors
    /// Returns any non-`WouldBlock` error from `accept`, or an error from registering the
    /// accepted socket.
    ///
    /// # Panics
    /// Panics with "runtime closing" if the readable channel's `recv` fails.
    pub fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
        loop {
            let err = match self.listener.accept() {
                Ok((socket, peer)) => return TcpStream::new(socket).map(|stream| (stream, peer)),
                Err(err) => err,
            };
            if err.kind() != io::ErrorKind::WouldBlock {
                return Err(err);
            }
            // No connection queued yet: wait for the next readiness notification, then retry.
            self.readable.recv().expect("runtime closing");
        }
    }

    /// Returns the local address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        net::TcpListener::local_addr(&self.listener)
    }

    /// Sets the IP time-to-live for packets sent from this socket.
    pub fn set_ttl(&self, ttl: u8) -> io::Result<()> {
        self.listener.set_ttl(u32::from(ttl))
    }

    /// Returns the IP time-to-live for packets sent from this socket.
    pub fn ttl(&self) -> io::Result<u8> {
        let raw = self.listener.ttl()?;
        Ok(raw as u8)
    }
}
/// A connected TCP socket registered with the runtime's scheduler.
///
/// Its `Read`/`Write` impls wait on readiness channels instead of surfacing `WouldBlock`.
/// [`TcpStream::into_split`] turns it into separate read and write halves.
pub struct TcpStream {
    // Wrapped in `ManuallyDrop` so `Drop` and `into_split` can move the socket out.
    stream: ManuallyDrop<net::TcpStream>,
    // Readiness notifications waited on when a read would block.
    readable: parallel::Receiver<()>,
    // Readiness notifications waited on when a write would block (and once in `new`).
    writable: parallel::Receiver<()>,
    // Token obtained at registration; passed back when deregistering.
    token: Token,
}
assert_impl_all!(TcpStream: Send, Sync);
impl Drop for TcpStream {
    /// Moves the socket out of `self` and passes it, with its token, to the registry's
    /// `deregister_event_source`.
    fn drop(&mut self) {
        let scheduler_registry = unsafe { Scheduler::registry() };
        // SAFETY: `self.stream` is never touched again once it has been taken here.
        scheduler_registry
            .deregister_event_source(self.token, unsafe { ManuallyDrop::take(&mut self.stream) });
    }
}
impl TcpStream {
    /// Registers `stream` with the runtime and waits for its first writable notification.
    ///
    /// Used for sockets returned by [`TcpListener::accept`] and for the non-blocking socket
    /// created in [`TcpStream::connect`].
    ///
    /// # Errors
    /// Propagates any error from registering the socket with the scheduler.
    ///
    /// # Panics
    /// Panics with "runtime closing" if the writable channel's `recv` fails.
    fn new(mut stream: net::TcpStream) -> io::Result<Self> {
        // NOTE(review): assumes `Scheduler::registry` is sound to call from any code running on
        // the runtime — confirm against its safety contract.
        let registry = unsafe { Scheduler::registry() };
        let (token, readable, mut writable) = registry.register_tcp_stream(&mut stream)?;
        writable.recv().expect("runtime closing");
        Ok(TcpStream { stream: ManuallyDrop::new(stream), readable, writable, token })
    }

    /// Sets the IP time-to-live for packets sent on this socket.
    pub fn set_ttl(&self, ttl: u8) -> io::Result<()> {
        self.stream.set_ttl(ttl.into())
    }

    /// Returns the IP time-to-live for packets sent on this socket.
    pub fn ttl(&self) -> io::Result<u8> {
        // The IPv4 TTL header field is 8 bits wide, matching `set_ttl`'s `u8` argument.
        self.stream.ttl().map(|ttl| ttl as u8)
    }

    /// Opens a TCP connection to `addr`, blocking the caller until it is established.
    ///
    /// # Errors
    /// Returns the socket's pending error if the connection attempt failed, or any error from
    /// creating, registering or querying the socket.
    ///
    /// # Panics
    /// Panics with "runtime closing" if a writable-channel `recv` fails.
    pub fn connect(addr: SocketAddr) -> io::Result<Self> {
        let mut stream = Self::new(net::TcpStream::connect(addr)?)?;
        // mio's connect is non-blocking, and a writable notification alone does not prove the
        // handshake finished. Per mio's documented procedure: a pending socket error means the
        // attempt failed; otherwise `peer_addr` succeeds only once the socket is connected, and
        // `NotConnected`/`EINPROGRESS` means "still in progress, wait for writable again".
        loop {
            if let Some(err) = stream.stream.take_error()? {
                return Err(err);
            }
            match stream.stream.peer_addr() {
                Ok(_) => return Ok(stream),
                Err(err)
                    if err.kind() == io::ErrorKind::NotConnected
                        || err.raw_os_error() == Some(libc::EINPROGRESS) =>
                {
                    stream.writable.recv().expect("runtime closing");
                },
                Err(err) => return Err(err),
            }
        }
    }

    /// Enables or disables `TCP_NODELAY` (Nagle's algorithm off when `true`).
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.stream.set_nodelay(nodelay)
    }

    /// Returns whether `TCP_NODELAY` is set.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.stream.nodelay()
    }

    /// Sets `SO_LINGER`: `Some(d)` enables lingering with a timeout of `d` in whole seconds,
    /// `None` disables it (see socket(7)).
    ///
    /// Timeouts longer than `c_int::MAX` seconds are clamped to `c_int::MAX`.
    ///
    /// # Errors
    /// Returns the OS error if `setsockopt` fails.
    pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
        let fd = self.stream.as_raw_fd();
        let linger = libc::linger {
            l_onoff: if linger.is_some() { 1 } else { 0 },
            // Saturate rather than wrap: a plain `as` cast maps e.g. 2^32 seconds to 0, and a
            // zero linger timeout makes `close` reset the connection instead of lingering.
            l_linger: linger
                .map(|d| d.as_secs().min(libc::c_int::MAX as u64) as libc::c_int)
                .unwrap_or_default(),
        };
        // SAFETY: `fd` belongs to `self.stream`, which stays open for this call; `linger` is a
        // valid `libc::linger` and the length passed is exactly its size.
        let rc = unsafe {
            libc::setsockopt(
                fd,
                libc::SOL_SOCKET,
                libc::SO_LINGER,
                &linger as *const _ as *const libc::c_void,
                std::mem::size_of::<libc::linger>() as libc::socklen_t,
            )
        };
        if rc != 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }

    /// Returns the current `SO_LINGER` setting: `None` when lingering is disabled, otherwise
    /// the timeout in whole seconds.
    ///
    /// # Errors
    /// Returns the OS error if `getsockopt` fails.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        let fd = self.stream.as_raw_fd();
        // SAFETY: `libc::linger` consists only of integers, so all-zero is a valid value.
        let mut linger: libc::linger = unsafe { std::mem::zeroed() };
        let mut optlen = std::mem::size_of::<libc::linger>() as libc::socklen_t;
        // SAFETY: `linger` is a writable `libc::linger` and `optlen` holds its size.
        let rc = unsafe {
            libc::getsockopt(
                fd,
                libc::SOL_SOCKET,
                libc::SO_LINGER,
                &mut linger as *mut _ as *mut libc::c_void,
                &mut optlen,
            )
        };
        if rc != 0 {
            return Err(io::Error::last_os_error());
        }
        Ok((linger.l_onoff != 0).then(|| Duration::from_secs(linger.l_linger as u64)))
    }

    /// Returns the local address of this connection.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.stream.local_addr()
    }

    /// Returns the remote address of this connection.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.stream.peer_addr()
    }

    /// Splits the stream into a read half and a write half that share the socket.
    ///
    /// Both halves hold the socket through an `Rc`, so neither is `Send`. Each half shuts down
    /// its own direction when dropped, and whichever is dropped last passes the socket to the
    /// registry for deregistration.
    pub fn into_split(mut self) -> (TcpReader, TcpWriter) {
        // SAFETY (for the `take` and both `ptr::read`s below): `self` is passed to
        // `mem::forget` before returning, so its `Drop` never runs and each of `stream`,
        // `readable` and `writable` is moved out exactly once.
        let stream = Rc::new(unsafe { ManuallyDrop::take(&mut self.stream) });
        let reader = TcpReader {
            stream: ManuallyDrop::new(stream.clone()),
            readable: unsafe { std::ptr::read(&self.readable) },
            token: self.token,
        };
        let writer = TcpWriter {
            stream: ManuallyDrop::new(stream),
            writable: unsafe { std::ptr::read(&self.writable) },
            token: self.token,
        };
        std::mem::forget(self);
        (reader, writer)
    }

    /// Shuts down the read direction of the connection.
    pub fn shutdown_read(&self) -> io::Result<()> {
        self.stream.shutdown(std::net::Shutdown::Read)
    }

    /// Shuts down the write direction of the connection.
    pub fn shutdown_write(&self) -> io::Result<()> {
        self.stream.shutdown(std::net::Shutdown::Write)
    }

    /// Reads from `stream` into `buf`, waiting on `readable` and retrying whenever the read
    /// would block.
    ///
    /// # Panics
    /// Panics with "runtime closing" if `readable.recv()` fails.
    fn read(stream: &mut net::TcpStream, readable: &mut parallel::Receiver<()>, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            match stream.read(buf) {
                Ok(n) => return Ok(n),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    readable.recv().expect("runtime closing");
                },
                Err(err) => return Err(err),
            }
        }
    }

    /// Writes `buf` to `stream`, waiting on `writable` and retrying whenever the write would
    /// block. Returns the number of bytes written by the first successful `write`.
    ///
    /// # Panics
    /// Panics with "runtime closing" if `writable.recv()` fails.
    fn write(stream: &mut net::TcpStream, writable: &mut parallel::Receiver<()>, buf: &[u8]) -> io::Result<usize> {
        loop {
            match stream.write(buf) {
                Ok(n) => return Ok(n),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    writable.recv().expect("runtime closing");
                },
                Err(err) => return Err(err),
            }
        }
    }
}
impl AsRawFd for TcpStream {
    /// Returns the file descriptor of the underlying socket; ownership stays with `self`.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&*self.stream)
    }
}
/// Read half of a [`TcpStream`], produced by [`TcpStream::into_split`].
///
/// Shares the socket with the matching write half through an `Rc`, so it is neither `Send`
/// nor `Sync` (asserted below).
pub struct TcpReader {
    // Shared socket; wrapped in `ManuallyDrop` so `Drop` can take the `Rc` and try to unwrap it.
    stream: ManuallyDrop<Rc<net::TcpStream>>,
    // Readiness notifications waited on when a read would block.
    readable: parallel::Receiver<()>,
    // Registration token, shared with the write half; used for deregistration.
    token: Token,
}
assert_not_impl_any!(TcpReader: Send, Sync);
impl Drop for TcpReader {
    /// Shuts down the read direction. If this is the last half holding the socket, it is also
    /// passed to the registry for deregistration.
    fn drop(&mut self) {
        // SAFETY: `self.stream` is never touched again once it has been taken here.
        let shared = unsafe { ManuallyDrop::take(&mut self.stream) };
        shared.shutdown(std::net::Shutdown::Read).ignore();
        let Some(last) = Rc::into_inner(shared) else {
            // The write half still holds the socket and deregisters it when it is dropped.
            return;
        };
        let registry = unsafe { Scheduler::registry() };
        registry.deregister_event_source(self.token, last);
    }
}
impl io::Read for TcpReader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let stream = Rc::as_ptr(&self.stream) as *mut _;
TcpStream::read(unsafe { &mut *stream }, &mut self.readable, buf)
}
}
/// Write half of a [`TcpStream`], produced by [`TcpStream::into_split`].
///
/// Shares the socket with the matching read half through an `Rc`, so it is neither `Send`
/// nor `Sync` (asserted below).
pub struct TcpWriter {
    // Shared socket; wrapped in `ManuallyDrop` so `Drop` can take the `Rc` and try to unwrap it.
    stream: ManuallyDrop<Rc<net::TcpStream>>,
    // Readiness notifications waited on when a write would block.
    writable: parallel::Receiver<()>,
    // Registration token, shared with the read half; used for deregistration.
    token: Token,
}
assert_not_impl_any!(TcpWriter: Send, Sync);
impl Drop for TcpWriter {
    /// Shuts down the write direction. If this is the last half holding the socket, it is also
    /// passed to the registry for deregistration.
    fn drop(&mut self) {
        // SAFETY: `self.stream` is never touched again once it has been taken here.
        let shared = unsafe { ManuallyDrop::take(&mut self.stream) };
        shared.shutdown(std::net::Shutdown::Write).ignore();
        let Some(last) = Rc::into_inner(shared) else {
            // The read half still holds the socket and deregisters it when it is dropped.
            return;
        };
        let registry = unsafe { Scheduler::registry() };
        registry.deregister_event_source(self.token, last);
    }
}
impl io::Write for TcpWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let stream = Rc::as_ptr(&self.stream) as *mut _;
TcpStream::write(unsafe { &mut *stream }, &mut self.writable, buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl io::Read for TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
TcpStream::read(&mut self.stream, &mut self.readable, buf)
}
}
impl io::Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
TcpStream::write(&mut self.stream, &mut self.writable, buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}