use std::cell::Cell;
use std::ffi::{CString, NulError};
use std::future::Future;
use std::mem::{self, MaybeUninit};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::RawFd;
use std::os::unix::prelude::IntoRawFd;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{io, ptr};
use futures::{AsyncRead, AsyncWrite};
use crate::ffi::tarantool as ffi;
use crate::fiber::r#async::context::ContextExt;
use crate::fiber::{self, r#async};
/// Errors reported while resolving an address, connecting, or operating on a
/// [`TcpStream`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Host name resolution reported a failure.
#[error("failed to resolve host address by domain name")]
ResolveAddress,
/// The host string could not be converted to a C string because it
/// contains an interior NUL byte.
#[error("input parameters contain ffi incompatible strings: {0}")]
ConstructCString(NulError),
/// Establishing the TCP connection failed; carries the underlying I/O error.
#[error("failed to connect to supplied address: {0}")]
Connect(io::Error),
/// Switching the connected socket to non-blocking mode failed.
#[error("failed to set socket to nonblocking mode: {0}")]
SetNonBlock(io::Error),
/// A resolved address had a family other than `AF_INET` / `AF_INET6`;
/// carries the raw family value.
#[error("unknown address family: {0}")]
UnknownAddressFamily(u16),
/// The write half of the stream is closed.
// NOTE(review): this variant is not constructed by any code visible in this
// part of the file — confirm it is still used elsewhere.
#[error("write half of the stream is closed")]
WriteClosed,
}
#[derive(Debug)]
pub struct TcpStream {
fd: RawFd,
}
impl TcpStream {
pub async fn connect(url: &str, port: u16) -> Result<TcpStream, Error> {
let addrs = unsafe {
let addr_info = get_address_info(url).await?;
let addrs = get_addrs_from_info(addr_info, port);
libc::freeaddrinfo(addr_info);
addrs
};
let addrs = addrs?;
let stream = std::net::TcpStream::connect(addrs.as_slice()).map_err(Error::Connect)?;
stream.set_nonblocking(true).map_err(Error::SetNonBlock)?;
Ok(Self {
fd: stream.into_raw_fd(),
})
}
pub fn close_token(&self) -> CloseToken {
CloseToken(self.fd)
}
}
/// Holds a copy of a socket descriptor so that the socket can be closed
/// without access to the stream value itself.
#[derive(Debug)]
pub struct CloseToken(RawFd);

impl CloseToken {
    /// Closes the descriptor through `coio_close`.
    ///
    /// Returns the last OS error when `coio_close` reports a non-zero status.
    pub fn close(&self) -> io::Result<()> {
        match unsafe { ffi::coio_close(self.0) } {
            0 => Ok(()),
            // Nothing runs between the call and this read of the OS error.
            _ => Err(io::Error::last_os_error()),
        }
    }
}
/// Walks the `addrinfo` linked list starting at `addrs` and converts every
/// entry into a [`SocketAddr`] that uses `port`.
///
/// Stops at the first entry that cannot be converted and returns that error.
/// The list itself is not freed here.
///
/// # Safety
///
/// `addrs` must be null or point to a valid `addrinfo` list whose `ai_addr`
/// pointers are valid for reads.
unsafe fn get_addrs_from_info(
    addrs: *const libc::addrinfo,
    port: u16,
) -> Result<Vec<SocketAddr>, Error> {
    let mut collected = Vec::new();
    let mut cursor = addrs;
    // `as_ref` yields `None` for the terminating null `ai_next`.
    while let Some(entry) = cursor.as_ref() {
        let socket_addr = to_rs_sockaddr(entry.ai_addr, port)?;
        collected.push(socket_addr);
        cursor = entry.ai_next;
    }
    Ok(collected)
}
/// Resolves `url` to an `addrinfo` list, requesting stream sockets of any
/// address family.
///
/// The lookup itself is delegated to the executor: each poll that finds no
/// result yet hands a clone of the request to the task context and returns
/// `Pending`.
///
/// # Errors
///
/// - [`Error::ConstructCString`] if `url` contains an interior NUL byte;
/// - [`Error::ResolveAddress`] if the request's error flag is set.
///
/// # Safety
///
/// On success the caller receives a raw `addrinfo` pointer and is responsible
/// for releasing it (see `libc::freeaddrinfo`).
async unsafe fn get_address_info(url: &str) -> Result<*mut libc::addrinfo, Error> {
    /// Future wrapper around the shared lookup request state.
    struct GetAddrInfo(r#async::coio::GetAddrInfo);

    impl Future for GetAddrInfo {
        type Output = Result<*mut libc::addrinfo, Error>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // The error flag takes precedence over any result pointer.
            if self.0.err.get() {
                return Poll::Ready(Err(Error::ResolveAddress));
            }

            let resolved = self.0.res.get();
            if !resolved.is_null() {
                return Poll::Ready(Ok(resolved));
            }

            // No result yet: register the request with the context and wait.
            unsafe {
                ContextExt::set_coio_getaddrinfo(cx, self.0.clone());
            }
            Poll::Pending
        }
    }

    let host = CString::new(url).map_err(Error::ConstructCString)?;

    // Start from an all-zero hints structure and set only the fields we need.
    let mut hints: libc::addrinfo = MaybeUninit::zeroed().assume_init();
    hints.ai_family = libc::AF_UNSPEC;
    hints.ai_socktype = libc::SOCK_STREAM;

    let request = r#async::coio::GetAddrInfo {
        host,
        hints,
        res: Rc::new(Cell::new(ptr::null_mut())),
        err: Rc::new(Cell::new(false)),
    };
    GetAddrInfo(request).await
}
/// Converts a C `sockaddr` into a Rust [`SocketAddr`], using `port` (host byte
/// order) instead of whatever port the C structure carries.
///
/// The structure behind `addr` is only read. Earlier revisions also stored
/// `port` into `sin_port` / `sin6_port`; that store had no effect on the
/// returned value, used host byte order for a network-byte-order field, and
/// mutated memory reached through a `*const` pointer, so it was removed.
///
/// # Errors
///
/// Returns [`Error::UnknownAddressFamily`] for any family other than
/// `AF_INET` and `AF_INET6`.
///
/// # Safety
///
/// `addr` must be non-null and point to a valid, suitably aligned socket
/// address whose actual layout matches its `sa_family` field.
unsafe fn to_rs_sockaddr(addr: *const libc::sockaddr, port: u16) -> Result<SocketAddr, Error> {
    match (*addr).sa_family as libc::c_int {
        libc::AF_INET => {
            // Reinterpret as the IPv4 layout; read-only access.
            let v4: *const libc::sockaddr_in = mem::transmute(addr);
            // `s_addr` is stored in network byte order, so its in-memory
            // bytes are already the address octets in the right order.
            let octets: [u8; 4] = (*v4).sin_addr.s_addr.to_ne_bytes();
            Ok(SocketAddr::V4(SocketAddrV4::new(octets.into(), port)))
        }
        libc::AF_INET6 => {
            // Reinterpret as the IPv6 layout; read-only access.
            let v6: *const libc::sockaddr_in6 = mem::transmute(addr);
            let octets = (*v6).sin6_addr.s6_addr;
            let flow_info = (*v6).sin6_flowinfo;
            let scope_id = (*v6).sin6_scope_id;
            Ok(SocketAddr::V6(SocketAddrV6::new(
                octets.into(),
                port,
                flow_info,
                scope_id,
            )))
        }
        af => Err(Error::UnknownAddressFamily(af as u16)),
    }
}
impl AsyncWrite for TcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let (result, err) = (
unsafe { libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len()) },
io::Error::last_os_error(),
);
if result >= 0 {
return Poll::Ready(Ok(result as usize));
}
match err.kind() {
io::ErrorKind::WouldBlock => {
unsafe { ContextExt::set_coio_wait(cx, self.fd, ffi::CoIOFlags::WRITE) }
Poll::Pending
}
io::ErrorKind::Interrupted => {
unsafe { ContextExt::set_deadline(cx, fiber::clock()) }
Poll::Pending
}
_ => Poll::Ready(Err(err)),
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(self.close_token().close())
}
}
impl AsyncRead for TcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let (result, err) = (
unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) },
io::Error::last_os_error(),
);
if result >= 0 {
return Poll::Ready(Ok(result as usize));
}
match err.kind() {
io::ErrorKind::WouldBlock => {
unsafe { ContextExt::set_coio_wait(cx, self.fd, ffi::CoIOFlags::READ) }
Poll::Pending
}
io::ErrorKind::Interrupted => {
unsafe { ContextExt::set_deadline(cx, fiber::clock()) }
Poll::Pending
}
_ => Poll::Ready(Err(err)),
}
}
}
impl Drop for TcpStream {
fn drop(&mut self) {
let _ = self.close_token().close();
}
}
/// Integration tests for the TCP stream; they run inside Tarantool and are
/// only compiled with the `internal_test` feature.
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;
    use crate::fiber;
    use crate::fiber::r#async::timeout::{self, IntoTimeout};
    use crate::test::util::always_pending;
    use crate::test::util::TARANTOOL_LISTEN;
    use std::net::TcpListener;
    use std::thread;
    use std::time::Duration;

    use futures::{AsyncReadExt, AsyncWriteExt, FutureExt};

    const _10_SEC: Duration = Duration::from_secs(10);
    const _0_SEC: Duration = Duration::from_secs(0);

    /// Resolving a well-known host name succeeds.
    #[crate::test(tarantool = "crate")]
    fn resolve_address() {
        unsafe {
            let info = fiber::block_on(get_address_info("localhost").timeout(_10_SEC)).unwrap();
            // The returned list is owned by the caller; release it so the
            // test does not leak the resolver's allocation.
            libc::freeaddrinfo(info);
        }
    }

    /// Resolving a malformed host name reports `Error::ResolveAddress`.
    #[crate::test(tarantool = "crate")]
    fn resolve_address_error() {
        unsafe {
            let err = fiber::block_on(get_address_info("invalid domain name").timeout(_10_SEC))
                .unwrap_err()
                .to_string();
            assert_eq!(err, "failed to resolve host address by domain name")
        }
    }

    /// Connecting to the listening Tarantool instance succeeds.
    #[crate::test(tarantool = "crate")]
    fn connect() {
        let _ = fiber::block_on(TcpStream::connect("localhost", TARANTOOL_LISTEN).timeout(_10_SEC))
            .unwrap();
    }

    /// Reads 128 bytes sent by the server right after connecting
    /// (presumably the Tarantool greeting — confirm against the protocol).
    #[crate::test(tarantool = "crate")]
    async fn read() {
        let mut stream = TcpStream::connect("localhost", TARANTOOL_LISTEN)
            .timeout(_10_SEC)
            .await
            .unwrap();
        let mut buf = vec![0; 128];
        stream.read_exact(&mut buf).timeout(_10_SEC).await.unwrap();
    }

    /// A read wrapped in a zero timeout fails with "deadline expired".
    #[crate::test(tarantool = "crate")]
    async fn read_timeout() {
        let mut stream = TcpStream::connect("localhost", TARANTOOL_LISTEN)
            .timeout(_10_SEC)
            .await
            .unwrap();
        let mut buf = vec![0; 128];
        assert_eq!(
            stream
                .read_exact(&mut buf)
                .timeout(_0_SEC)
                .await
                .unwrap_err()
                .to_string(),
            "deadline expired"
        );
    }

    /// Bytes written in two separate writes arrive in order at a plain std
    /// listener running on another thread.
    #[crate::test(tarantool = "crate")]
    fn write() {
        let (sender, receiver) = std::sync::mpsc::channel();
        let listener = TcpListener::bind("127.0.0.1:3302").unwrap();
        thread::spawn(move || {
            for stream in listener.incoming() {
                let mut stream = stream.unwrap();
                let mut buf = vec![];
                // Returns once the client side closes the connection.
                <std::net::TcpStream as std::io::Read>::read_to_end(&mut stream, &mut buf).unwrap();
                sender.send(buf).unwrap();
            }
        });
        // Scope the stream so it is dropped (and closed) before the server
        // result is awaited.
        {
            fiber::block_on(async {
                let mut stream = TcpStream::connect("localhost", 3302)
                    .timeout(_10_SEC)
                    .await
                    .unwrap();
                timeout::timeout(_10_SEC, stream.write_all(&[1, 2, 3]))
                    .await
                    .unwrap();
                timeout::timeout(_10_SEC, stream.write_all(&[4, 5]))
                    .await
                    .unwrap();
            });
        }
        let buf = receiver.recv_timeout(Duration::from_secs(5)).unwrap();
        assert_eq!(buf, vec![1, 2, 3, 4, 5])
    }

    /// The read and write halves of a split stream work concurrently from
    /// two fibers against an echo server.
    #[crate::test(tarantool = "crate")]
    fn split() {
        let (sender, receiver) = std::sync::mpsc::channel();
        let listener = TcpListener::bind("127.0.0.1:3303").unwrap();
        thread::spawn(move || {
            for stream in listener.incoming() {
                let mut stream = stream.unwrap();
                let mut buf = vec![0; 5];
                <std::net::TcpStream as std::io::Read>::read_exact(&mut stream, &mut buf).unwrap();
                // Echo the payload back to the client.
                <std::net::TcpStream as std::io::Write>::write_all(&mut stream, &buf.clone())
                    .unwrap();
                sender.send(buf).unwrap();
            }
        });
        {
            let stream =
                fiber::block_on(TcpStream::connect("localhost", 3303).timeout(_10_SEC)).unwrap();
            let (mut reader, mut writer) = stream.split();
            let reader_handle = fiber::start_async(async move {
                let mut buf = vec![0; 5];
                timeout::timeout(_10_SEC, reader.read_exact(&mut buf))
                    .await
                    .unwrap();
                assert_eq!(buf, vec![1, 2, 3, 4, 5])
            });
            let writer_handle = fiber::start_async(async move {
                timeout::timeout(_10_SEC, writer.write_all(&[1, 2, 3]))
                    .await
                    .unwrap();
                timeout::timeout(_10_SEC, writer.write_all(&[4, 5]))
                    .await
                    .unwrap();
            });
            writer_handle.join();
            reader_handle.join();
        }
        let buf = receiver.recv_timeout(Duration::from_secs(5)).unwrap();
        assert_eq!(buf, vec![1, 2, 3, 4, 5])
    }

    /// With `join!`, an expired sibling timeout must not prevent the read
    /// from completing, regardless of argument order.
    #[crate::test(tarantool = "crate")]
    fn join_correct_timeout() {
        {
            fiber::block_on(async {
                let mut stream = TcpStream::connect("localhost", TARANTOOL_LISTEN)
                    .timeout(_10_SEC)
                    .await
                    .unwrap();
                let mut buf = vec![0; 128];
                let (is_err, is_ok) = futures::join!(
                    timeout::timeout(_0_SEC, always_pending()),
                    timeout::timeout(_10_SEC, stream.read_exact(&mut buf))
                );
                assert_eq!(is_err.unwrap_err().to_string(), "deadline expired");
                is_ok.unwrap();
            });
        }
        // Same scenario with the futures in the opposite order.
        {
            fiber::block_on(async {
                let mut stream = TcpStream::connect("localhost", TARANTOOL_LISTEN)
                    .timeout(_10_SEC)
                    .await
                    .unwrap();
                let mut buf = vec![0; 128];
                let (is_ok, is_err) = futures::join!(
                    timeout::timeout(_10_SEC, stream.read_exact(&mut buf)),
                    timeout::timeout(_0_SEC, always_pending())
                );
                assert_eq!(is_err.unwrap_err().to_string(), "deadline expired");
                is_ok.unwrap();
            });
        }
    }

    /// With `select!`, the branch that finishes first determines the result:
    /// an immediately expired timeout wins in the first case, the read wins
    /// over a 15 second timeout in the second.
    #[crate::test(tarantool = "crate")]
    fn select_correct_timeout() {
        {
            fiber::block_on(async {
                let mut stream = TcpStream::connect("localhost", TARANTOOL_LISTEN)
                    .timeout(_10_SEC)
                    .await
                    .unwrap();
                let mut buf = vec![0; 128];
                let f1 = timeout::timeout(_0_SEC, always_pending()).fuse();
                let f2 = timeout::timeout(_10_SEC, stream.read_exact(&mut buf)).fuse();
                futures::pin_mut!(f1);
                futures::pin_mut!(f2);
                let is_err = futures::select!(
                    res = f1 => res.is_err(),
                    res = f2 => res.is_err()
                );
                assert!(is_err);
            });
        }
        {
            fiber::block_on(async {
                let mut stream = TcpStream::connect("localhost", TARANTOOL_LISTEN)
                    .timeout(_10_SEC)
                    .await
                    .unwrap();
                let mut buf = vec![0; 128];
                let f1 = timeout::timeout(Duration::from_secs(15), always_pending()).fuse();
                let f2 = timeout::timeout(_10_SEC, stream.read_exact(&mut buf)).fuse();
                futures::pin_mut!(f1);
                futures::pin_mut!(f2);
                let is_ok = futures::select!(
                    res = f1 => res.is_ok(),
                    res = f2 => res.is_ok()
                );
                assert!(is_ok);
            });
        }
    }
}