use std::{
future::Future,
io::{ErrorKind, Read, Write},
net::ToSocketAddrs,
sync::{Arc, Mutex},
task::Poll,
};
use mio::Interest;
use crate::{pool::thread_pool::HoochPool, reactor::Reactor};
/// A non-blocking TCP stream driven by the hooch [`Reactor`].
///
/// Created by [`HoochTcpStream::connect`], which registers the socket with the
/// reactor for readable and writable readiness under `token`.
#[derive(Debug)]
pub struct HoochTcpStream {
    /// The underlying non-blocking mio socket.
    pub stream: mio::net::TcpStream,
    /// Token this socket is registered under in the reactor's registry.
    pub token: mio::Token,
}
impl HoochTcpStream {
    /// Opens a TCP connection to `addr` and registers it with the reactor.
    ///
    /// The blocking connect is delegated to a background job, so the calling task is
    /// suspended rather than blocked while the handshake completes. The resulting
    /// socket is registered for both readable and writable readiness under a fresh
    /// token obtained from the reactor.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection attempt or socket setup fails, or if the
    /// socket cannot be registered with the reactor's registry.
    pub async fn connect(addr: impl ToSocketAddrs) -> std::io::Result<Self> {
        let reactor_tag = Reactor::generate_reactor_tag();
        let reactor = Reactor::get();
        // NOTE(review): the tag is registered but not otherwise used in this function —
        // confirm the reactor relies on this side effect.
        reactor.register_reactor_tag(reactor_tag);

        // The connect future is awaited directly; `.await` pins it in place, so no
        // heap allocation or `poll_fn` shim is needed.
        let connecting = AsyncHoochTcpStream {
            addr,
            state: Arc::new(Mutex::new(None)),
            has_polled: false,
        };
        let mut stream = connecting.await?;

        let token = reactor.unique_token();
        reactor.registry().register(
            &mut stream,
            token,
            Interest::READABLE | Interest::WRITABLE,
        )?;
        Ok(Self { stream, token })
    }

    /// Reads bytes from the socket into `buf`, suspending until data is available.
    ///
    /// Returns the number of bytes read. As with [`Read::read`], `Ok(0)` means end of
    /// stream (or that `buf` was empty).
    ///
    /// # Errors
    ///
    /// Returns any I/O error other than `WouldBlock`/`Interrupted`, or an error
    /// reported by the reactor while waiting for readiness.
    pub async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.retry_until_ready(|stream| stream.read(buf)).await
    }

    /// Writes bytes from `buf` to the socket, suspending until it can accept data.
    ///
    /// Like [`Write::write`], this may write fewer than `buf.len()` bytes. Callers
    /// must loop if they need the whole buffer sent.
    ///
    /// # Errors
    ///
    /// Returns any I/O error other than `WouldBlock`/`Interrupted`, or an error
    /// reported by the reactor while waiting for readiness.
    pub async fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.retry_until_ready(|stream| stream.write(buf)).await
    }

    /// Runs `op` against the socket until it succeeds or fails with a real error,
    /// waiting on the reactor each time it reports `WouldBlock`.
    async fn retry_until_ready<F>(&mut self, mut op: F) -> std::io::Result<usize>
    where
        F: FnMut(&mut mio::net::TcpStream) -> std::io::Result<usize>,
    {
        loop {
            match op(&mut self.stream) {
                Ok(num) => return Ok(num),
                // A signal interrupted the syscall before any data moved; just retry.
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                // Not ready yet: park this task until the reactor reports our token.
                // The loop re-attempts afterwards, which also tolerates spurious wakeups.
                Err(e) if e.kind() == ErrorKind::WouldBlock => {
                    let token = self.token;
                    std::future::poll_fn(|cx| Reactor::get().poll(token, cx)).await?
                }
                Err(e) => return Err(e),
            }
        }
    }
}
struct AsyncHoochTcpStream<T: ToSocketAddrs> {
addr: T,
state: Arc<Mutex<Option<std::io::Result<mio::net::TcpStream>>>>,
has_polled: bool,
}
// Nothing in `AsyncHoochTcpStream` is structurally pinned: `poll` never creates a
// `Pin` to any field, so moving the future is fine even when `T` is `!Unpin`.
// Implementing `Unpin` lets `poll` use the safe `Pin::get_mut` instead of `unsafe`.
impl<T: ToSocketAddrs> Unpin for AsyncHoochTcpStream<T> {}

/// Resolves `addr` and performs a blocking `connect` as a [`HoochPool`] job.
///
/// The first poll resolves the address on the polling thread, hands the connect job
/// to the pool, and returns `Pending`. The job stores its result in `state` and
/// wakes the waker captured on that first poll. Later polls return the stored
/// result once it is present.
///
/// NOTE(review): only the waker from the first poll is ever woken. If the task's
/// waker can change between polls, the latest waker should be stored next to the
/// result instead.
impl<T: ToSocketAddrs> Future for AsyncHoochTcpStream<T> {
    type Output = Result<mio::net::TcpStream, std::io::Error>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();

        if !this.has_polled {
            this.has_polled = true;

            // Propagate resolution failures instead of panicking. Every resolved address
            // is kept so the connect can fall back to the next one, matching what
            // `std::net::TcpStream::connect` does for any `ToSocketAddrs`.
            // NOTE(review): resolving a hostname may block the polling thread on DNS.
            let addrs: Vec<std::net::SocketAddr> = match this.addr.to_socket_addrs() {
                Ok(resolved) => resolved.collect(),
                Err(e) => return Poll::Ready(Err(e)),
            };

            let result_slot = Arc::clone(&this.state);
            let waker = cx.waker().clone();
            let connect_fn = move || {
                // An empty `addrs` makes `connect` return an error rather than panic.
                let stream_result: std::io::Result<mio::net::TcpStream> = (|| {
                    let stream = std::net::TcpStream::connect(&addrs[..])?;
                    // mio's `from_std` expects the socket to already be non-blocking.
                    stream.set_nonblocking(true)?;
                    Ok(mio::net::TcpStream::from_std(stream))
                })();
                *result_slot.lock().unwrap() = Some(stream_result);
                waker.wake();
            };
            HoochPool::get().execute(Box::new(connect_fn));
            return Poll::Pending;
        }

        // Check and take under a single lock acquisition. `None` means the pool job has
        // not finished yet.
        let finished = this.state.lock().unwrap().take();
        match finished {
            Some(result) => Poll::Ready(result),
            None => Poll::Pending,
        }
    }
}