use std::io;
use std::io::{IoSlice, IoSliceMut};
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::os::fd::{AsRawFd, RawFd};
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll, Waker};
use mio::{event::Source, unix::SourceFd, Interest, Registry, Token};
use mfio::error::State;
use mfio::io::{Read as RdPerm, Write as WrPerm, *};
use mfio::tarc::BaseArc;
use super::super::unix_extra::set_nonblock;
use super::{BlockTrack, Key, MioState};
use crate::util::{from_io_error, io_err, stream::StreamBuf};
use crate::{Shutdown, TcpStreamHandle};
use mio::net;
/// Per-connection state kept by the mio backend for one TCP socket.
pub struct StreamInner {
/// The non-blocking mio socket all reads, writes and address queries go through.
fd: net::TcpStream,
/// Buffer of queued read and write operations serviced by `do_ops`.
stream: StreamBuf,
/// Bookkeeping for blocked directions, current poller interests and the
/// "already queued for processing" flag.
track: BlockTrack,
/// Waker stored by `TcpConnectFuture::poll` while a connect is pending; it is
/// taken and woken at the start of `do_ops`.
poll_waker: Option<Waker>,
}
impl AsRawFd for StreamInner {
    /// Returns the raw file descriptor of the wrapped socket.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.fd)
    }
}
/// Registers the socket with a mio [`Registry`] through its raw descriptor and
/// mirrors the outcome in `track.cur_interests`.
///
/// The tracked interests are updated only after the registry call succeeds, so
/// that a failed call never leaves the bookkeeping claiming a registration
/// state the poller does not actually have.
impl Source for StreamInner {
    /// Adds the socket to `registry` under `token` with the given `interests`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the registry; the tracked interests are
    /// left untouched in that case.
    fn register(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        registry.register(&mut SourceFd(&self.fd.as_raw_fd()), token, interests)?;
        self.track.cur_interests = Some(interests);
        Ok(())
    }

    /// Replaces the interests of an already registered socket.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the registry; the previously tracked
    /// interests are kept in that case.
    fn reregister(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        registry.reregister(&mut SourceFd(&self.fd.as_raw_fd()), token, interests)?;
        self.track.cur_interests = Some(interests);
        Ok(())
    }

    /// Removes the socket from `registry`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the registry; the socket is still
    /// considered registered in that case so a later attempt can retry.
    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
        registry.deregister(&mut SourceFd(&self.fd.as_raw_fd()))?;
        self.track.cur_interests = None;
        Ok(())
    }
}
impl From<net::TcpStream> for StreamInner {
fn from(fd: net::TcpStream) -> Self {
Self {
fd,
stream: StreamBuf::default(),
track: Default::default(),
poll_waker: None,
}
}
}
impl StreamInner {
fn read(mut stream: &net::TcpStream, iov: &mut [IoSliceMut]) -> io::Result<usize> {
use std::io::Read;
stream.read_vectored(iov)
}
fn write(mut stream: &net::TcpStream, iov: &[IoSlice]) -> io::Result<usize> {
use std::io::Write;
stream.write_vectored(iov)
}
pub fn update_interests(&mut self, key: usize, registry: &Registry) -> std::io::Result<()> {
let expected_interests = self.track.expected_interests();
if self.track.cur_interests != expected_interests {
if let Some(i) = expected_interests {
if self.track.cur_interests.is_some() {
self.reregister(registry, Token(key), i)?;
} else {
self.register(registry, Token(key), i)?;
}
} else {
self.deregister(registry)?;
}
}
Ok(())
}
pub fn cancel_all_ops(&mut self) {
self.stream
.on_read(Err(io::ErrorKind::Interrupted.into()), None);
self.stream
.on_write(Err(io::ErrorKind::Interrupted.into()), None);
}
#[tracing::instrument(skip(self))]
pub fn do_ops(&mut self, read: bool, write: bool) {
log::trace!(
"Do ops file={:?} read={read} write={write} (to read={} to write={})",
self.fd.as_raw_fd(),
self.stream.read_ops(),
self.stream.write_ops()
);
if let Some(waker) = self.poll_waker.take() {
waker.wake();
}
if read || !self.track.read_blocked {
while self.stream.read_ops() > 0 {
let rd_span =
tracing::span!(tracing::Level::TRACE, "read", ops = self.stream.read_ops());
let _span = rd_span.enter();
self.track.read_blocked = false;
let queue = self.stream.read_queue();
if !queue.is_empty() {
let res = Self::read(&self.fd, queue);
if res
.as_ref()
.err()
.map(|e| e.kind() != io::ErrorKind::WouldBlock)
.unwrap_or(true)
{
self.stream.on_read(res, None);
} else {
tracing::event!(tracing::Level::INFO, "read blocked");
self.track.read_blocked = true;
break;
}
}
}
}
if write || !self.track.write_blocked {
while self.stream.write_ops() > 0 {
let wr_span = tracing::span!(
tracing::Level::TRACE,
"write",
ops = self.stream.write_ops()
);
let _span = wr_span.enter();
self.track.write_blocked = false;
let queue = self.stream.write_queue();
if !queue.is_empty() {
let res = Self::write(&self.fd, queue);
if res
.as_ref()
.err()
.map(|e| e.kind() != io::ErrorKind::WouldBlock)
.unwrap_or(true)
{
self.stream.on_write(res, None);
} else {
tracing::event!(tracing::Level::INFO, "write blocked");
self.track.write_blocked = true;
break;
}
}
}
}
}
pub fn on_queue(&mut self) {
self.track.update_queued = false;
self.do_ops(true, true);
}
}
/// Maps a packet permission type onto the stream-buffer queue that services it.
trait IntoOp: PacketPerms {
/// Enqueues `pkt` on the appropriate queue of `stream`.
fn push_op(stream: &mut StreamInner, pkt: BoundPacketView<Self>);
}
impl IntoOp for RdPerm {
    /// Queues a read-permission packet on the write queue of the stream buffer.
    // NOTE(review): presumably a readable packet supplies the bytes to send —
    // confirm against the mfio permission semantics.
    fn push_op(inner: &mut StreamInner, packet: BoundPacketView<Self>) {
        inner.stream.queue_write(packet, None);
    }
}
impl IntoOp for WrPerm {
    /// Queues a write-permission packet on the read queue of the stream buffer.
    // NOTE(review): presumably a writable packet receives the bytes read from
    // the socket — confirm against the mfio permission semantics.
    fn push_op(inner: &mut StreamInner, packet: BoundPacketView<Self>) {
        inner.stream.queue_read(packet, None);
    }
}
impl<Perms: IntoOp> PacketIo<Perms, NoPos> for TcpStream {
    /// Enqueues `packet` on this stream and, unless the stream is already
    /// flagged as queued, pushes the stream's key onto the backend op queue.
    ///
    /// # Panics
    ///
    /// Panics if the stream table has no entry for this handle's index.
    fn send_io(&self, _: NoPos, packet: BoundPacketView<Perms>) {
        let table = self.state.streams.read();
        let entry = table.get(self.idx).unwrap();
        let mut guard = entry.lock();
        let inner = &mut *guard;

        Perms::push_op(inner, packet);

        // Only one queue entry per stream is needed until it gets processed.
        if inner.track.update_queued {
            return;
        }
        inner.track.update_queued = true;
        self.state.opqueue.lock().push(Key::Stream(self.idx));
    }
}
/// Handle to a TCP stream owned by the mio backend.
///
/// Dropping the handle removes the stream from the backend's stream table.
pub struct TcpStream {
/// Index of this stream's entry in `state.streams`.
idx: usize,
/// Shared backend state holding the stream table, poller and op queue.
state: BaseArc<MioState>,
}
impl TcpStream {
    /// Builds a handle for the stream stored at `idx` in `state`.
    pub(super) fn new(idx: usize, state: BaseArc<MioState>) -> Self {
        Self { idx, state }
    }

    /// Switches `stream` to non-blocking mode, stores it in the backend's
    /// stream table and returns a handle to it.
    ///
    /// # Panics
    ///
    /// Panics if the socket cannot be made non-blocking or if the stream
    /// table has no vacant entry.
    pub(super) fn register_stream(state: &BaseArc<MioState>, stream: net::TcpStream) -> Self {
        set_nonblock(stream.as_raw_fd()).unwrap();

        let table = state.streams.read();
        let slot = table.vacant_entry().unwrap();
        let key = Key::Stream(slot.key());
        let inner = StreamInner::from(stream);

        log::trace!(
            "Register stream={:?} state={:?}: key={key:?}",
            inner.as_raw_fd(),
            state.as_ptr()
        );

        slot.insert(inner.into());

        Self::new(key.idx(), state.clone())
    }

    /// Creates a future that connects to the first reachable address in
    /// `addrs`.
    ///
    /// Address resolution happens eagerly; a resolution failure is not
    /// reported here and leaves the future with no addresses to try.
    pub(super) fn tcp_connect<'a, A: ToSocketAddrs + Send + 'a>(
        backend: &'a BaseArc<MioState>,
        addrs: A,
    ) -> TcpConnectFuture<'a, A> {
        let resolved = addrs.to_socket_addrs().ok();
        TcpConnectFuture {
            backend,
            addrs: resolved,
            idx: None,
        }
    }
}
impl Drop for TcpStream {
    /// Removes the stream from the backend's stream table and deregisters its
    /// socket from the poller.
    ///
    /// A missing table entry is tolerated rather than unwrapped: panicking in
    /// `drop` would abort the process when it happens during unwinding, and
    /// the other accessors on this handle already treat a missing entry as a
    /// recoverable condition.
    fn drop(&mut self) {
        log::trace!("Drop {}", self.idx);

        // Bind the result first so the read guard on the stream table is
        // released before the poller lock is taken.
        let taken = self.state.streams.read().take(self.idx);

        match taken {
            Some(mut stream) => {
                // Best effort: the socket is closed right after, so a failed
                // deregistration is deliberately ignored.
                let _ = self
                    .state
                    .poll
                    .lock()
                    .registry()
                    .deregister(stream.get_mut());
            }
            None => log::trace!("Drop {}: stream already removed", self.idx),
        }
    }
}
impl TcpStreamHandle for TcpStream {
fn local_addr(&self) -> mfio::error::Result<SocketAddr> {
let streams = self.state.streams.read();
let stream = streams
.get(self.idx)
.ok_or_else(|| io_err(State::NotFound))?;
let stream = stream.lock();
stream.fd.local_addr().map_err(from_io_error)
}
fn peer_addr(&self) -> mfio::error::Result<SocketAddr> {
let streams = self.state.streams.read();
let stream = streams
.get(self.idx)
.ok_or_else(|| io_err(State::NotFound))?;
let stream = stream.lock();
stream.fd.peer_addr().map_err(from_io_error)
}
fn shutdown(&self, how: Shutdown) -> mfio::error::Result<()> {
let streams = self.state.streams.read();
let stream = streams
.get(self.idx)
.ok_or_else(|| io_err(State::NotFound))?;
let stream = stream.lock();
stream.fd.shutdown(how.into()).map_err(from_io_error)
}
}
/// Future returned by `TcpStream::tcp_connect`; resolves to a connected
/// [`TcpStream`] or an error once every candidate address has been tried.
pub struct TcpConnectFuture<'a, A: ToSocketAddrs + 'a> {
/// Backend whose stream table and poller the connection is registered with.
backend: &'a BaseArc<MioState>,
/// Remaining candidate addresses; `None` when address resolution failed.
addrs: Option<A::Iter>,
/// Stream-table index of the connection attempt currently in flight, if any.
idx: Option<usize>,
}
impl<'a, A: ToSocketAddrs + 'a> Future for TcpConnectFuture<'a, A> {
    type Output = mfio::error::Result<TcpStream>;

    /// Drives the non-blocking connect.
    ///
    /// Candidate addresses are tried in order. For the attempt in flight the
    /// socket's pending error is checked first and the peer address second,
    /// which is the order mio documents for non-blocking connects: a failed
    /// attempt has no peer address, so checking the peer address first would
    /// make a refused connection look like one that is still in progress.
    ///
    /// Resolves to:
    /// - `Ok(stream)` once an attempt is connected,
    /// - `Err(State::NotFound)` if the in-flight stream vanished from the
    ///   stream table,
    /// - `Err(State::Exhausted)` when no address is left to try.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        /// Outcome of inspecting the connection attempt currently in flight.
        enum Attempt {
            /// The socket is connected and error free.
            Connected,
            /// The connect is still in progress; a waker has been stored.
            Pending,
            /// The attempt failed and its stream has to be discarded.
            Failed,
        }

        // SAFETY: `this` is only used to mutate fields in place (`idx` is
        // replaced, `addrs` is advanced through `&mut`); nothing is moved out
        // of the pinned value.
        let this = unsafe { self.get_unchecked_mut() };

        loop {
            if let Some(idx) = this.idx.take() {
                // Decide what to do while the table entry and the stream lock
                // are held, but act on it only after both are released:
                // discarding the stream goes through `TcpStream`'s `Drop`,
                // which accesses the very same table slot.
                let attempt;
                {
                    let streams = this.backend.streams.read();
                    let entry = match streams.get(idx) {
                        Some(entry) => entry,
                        None => break Poll::Ready(Err(io_err(State::NotFound))),
                    };
                    let mut stream = entry.lock();

                    attempt = match stream.fd.take_error() {
                        // A pending socket error means this attempt failed.
                        Ok(Some(_)) | Err(_) => Attempt::Failed,
                        Ok(None) if stream.fd.peer_addr().is_ok() => Attempt::Connected,
                        Ok(None) => {
                            // Not connected yet: wait for writability.
                            stream.track.write_blocked = true;
                            let registered = stream.update_interests(
                                Key::Stream(idx).key(),
                                this.backend.poll.lock().registry(),
                            );
                            if registered.is_ok() {
                                // TODO: handle multiple wakers?
                                stream.poll_waker = Some(cx.waker().clone());
                                this.idx = Some(idx);
                                Attempt::Pending
                            } else {
                                Attempt::Failed
                            }
                        }
                    };
                }

                match attempt {
                    Attempt::Connected => {
                        break Poll::Ready(Ok(TcpStream::new(idx, this.backend.clone())));
                    }
                    Attempt::Pending => break Poll::Pending,
                    Attempt::Failed => {
                        // Dropping a handle removes and deregisters the
                        // stream; the loop then moves on to the next address.
                        drop(TcpStream::new(idx, this.backend.clone()));
                    }
                }
            } else if let Some(addr) = this.addrs.as_mut().and_then(|v| v.next()) {
                // An address that cannot even start connecting is skipped.
                if let Ok(stream) = net::TcpStream::connect(addr) {
                    let streams = this.backend.streams.read();
                    let entry = streams.vacant_entry().unwrap();
                    let key = Key::Stream(entry.key());
                    let stream = StreamInner::from(stream);
                    log::trace!(
                        "Connect stream={:?} state={:?}: key={key:?}",
                        stream.as_raw_fd(),
                        this.backend.as_ptr()
                    );
                    entry.insert(stream.into());
                    this.idx = Some(key.idx());
                }
            } else {
                break Poll::Ready(Err(io_err(State::Exhausted)));
            }
        }
    }
}