use std::future::poll_fn;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::kernel::{Addr, Domain, Fd, SocketOption, SocketOptionKind, Type};
use crate::shim::tokio::net::tcp::split::{split, ReadHalf, WriteHalf};
use crate::shim::tokio::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::shim::tokio::net::ToSocketAddrs;
use crate::sys;
#[derive(Debug)]
pub struct TcpStream {
fd: Fd,
}
impl TcpStream {
pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let peer = addr.to_socket_addr()?;
let guard = FdGuard::new(sys(|k| k.open(domain_of(&peer), Type::Stream)));
poll_fn(|cx| sys(|k| k.poll_connect(guard.fd, cx, &Addr::Inet(peer)))).await?;
Ok(Self { fd: guard.disarm() })
}
pub(super) fn from_fd(fd: Fd) -> Self {
Self { fd }
}
pub(super) fn fd(&self) -> Fd {
self.fd
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
match sys(|k| k.local_addr(self.fd))? {
Addr::Inet(sa) => Ok(sa),
Addr::Unix(_) => panic!("TcpStream is Addr::Inet"),
}
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
match sys(|k| k.peer_addr(self.fd))? {
Addr::Inet(sa) => Ok(sa),
Addr::Unix(_) => panic!("TcpStream is Addr::Inet"),
}
}
pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
match sys(|k| k.poll_recv(self.fd, &mut noop_cx(), buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
match sys(|k| k.poll_send(self.fd, &mut noop_cx(), buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
poll_fn(|cx| sys(|k| k.poll_peek(self.fd, cx, buf))).await
}
pub fn poll_peek(
&self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
let fd = self.fd;
let unfilled = buf.initialize_unfilled();
match sys(|k| k.poll_peek(fd, cx, unfilled)) {
Poll::Ready(Ok(n)) => {
buf.advance(n);
Poll::Ready(Ok(n))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
let ttl: u8 = ttl
.try_into()
.map_err(|_| io::Error::from(io::ErrorKind::InvalidInput))?;
sys(|k| k.set_option(self.fd, SocketOption::IpTtl(ttl)))
}
pub fn ttl(&self) -> io::Result<u32> {
match sys(|k| k.get_option(self.fd, SocketOptionKind::IpTtl))? {
SocketOption::IpTtl(v) => Ok(v as u32),
_ => unreachable!(),
}
}
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
sys(|k| k.set_option(self.fd, SocketOption::TcpNoDelay(nodelay)))
}
pub fn nodelay(&self) -> io::Result<bool> {
match sys(|k| k.get_option(self.fd, SocketOptionKind::TcpNoDelay))? {
SocketOption::TcpNoDelay(v) => Ok(v),
_ => unreachable!(),
}
}
#[allow(clippy::needless_lifetimes)]
pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) {
split(self)
}
pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
split_owned(self)
}
}
/// Builds a `Context` whose waker does nothing when woken.
///
/// Used (for example by `try_read`/`try_write`) to poll the kernel exactly
/// once without ever waiting for a wakeup.
pub(super) fn noop_cx() -> Context<'static> {
    let waker: &'static Waker = Waker::noop();
    Context::from_waker(waker)
}
impl Drop for TcpStream {
    /// Hands the descriptor back to the simulated kernel for closing.
    fn drop(&mut self) {
        let fd = self.fd;
        sys(move |k| k.close(fd));
    }
}
impl TcpStream {
    /// Shared body of `AsyncRead::poll_read` (also reachable from sibling
    /// modules). Receives into the unfilled part of `buf` and, on success,
    /// marks the received bytes as filled.
    pub(super) fn poll_read_priv(
        &self,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let fd = self.fd;
        // The kernel takes `&mut [u8]`, so the spare capacity must be
        // initialized before it is handed over.
        let dst = buf.initialize_unfilled();
        sys(|k| k.poll_recv(fd, cx, dst)).map_ok(|received| buf.advance(received))
    }

    /// Shared body of `AsyncWrite::poll_write`. Forwards `buf` to the
    /// kernel's `poll_send` and returns its result unchanged.
    pub(super) fn poll_write_priv(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        sys(|k| k.poll_send(self.fd, cx, buf))
    }

    /// Shared body of `AsyncWrite::poll_shutdown`. Asks the kernel to shut
    /// down the write side of this socket.
    pub(super) fn poll_shutdown_priv(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        sys(|k| k.poll_shutdown_write(self.fd, cx))
    }
}
impl AsyncRead for TcpStream {
    /// Reads by delegating to the stream's shared `poll_read_priv`. Only a
    /// shared reference is needed, so the pin is simply dereferenced.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this: &Self = &self;
        this.poll_read_priv(cx, buf)
    }
}
impl AsyncWrite for TcpStream {
    /// Writes by delegating to the stream's shared `poll_write_priv`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this: &Self = &self;
        this.poll_write_priv(cx, buf)
    }

    /// Always ready. Writes go straight to the kernel's `poll_send`, and the
    /// stream holds no userspace buffer that could need flushing.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write side via the stream's shared `poll_shutdown_priv`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this: &Self = &self;
        this.poll_shutdown_priv(cx)
    }
}
/// Picks the socket domain that matches the address family of `peer`:
/// IPv4 addresses map to `Domain::Inet`, IPv6 addresses to `Domain::Inet6`.
fn domain_of(peer: &SocketAddr) -> Domain {
    if peer.is_ipv4() {
        Domain::Inet
    } else {
        Domain::Inet6
    }
}
/// Owns a freshly opened descriptor until ownership is handed off.
///
/// Dropping the guard closes the descriptor; `disarm` releases it instead.
struct FdGuard {
    fd: Fd,
}

impl FdGuard {
    fn new(fd: Fd) -> Self {
        FdGuard { fd }
    }

    /// Gives up the guard without closing the descriptor and returns it.
    fn disarm(self) -> Fd {
        let fd = self.fd;
        // Skip `Drop` so the descriptor stays open. `Fd` was copied out
        // above, and the guard owns nothing else, so forgetting it leaks nothing.
        std::mem::forget(self);
        fd
    }
}

impl Drop for FdGuard {
    fn drop(&mut self) {
        sys(|k| k.close(self.fd));
    }
}