use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use mfio::error::{Error, State};
use mfio::io::*;
use mfio::mferr;
use mfio::tarc::BaseArc;
use parking_lot::Mutex;
use std::io;
use std::net::{self, SocketAddr, ToSocketAddrs};
use std::os::windows::io::{AsRawHandle, AsRawSocket, RawHandle};
use super::{CSockAddr, IocpState, Operation, OperationHeader, OperationMode, TcpGetSock, WsaOp};
use crate::util::{from_io_error, io_err, stream::StreamBuf, DeferredPackets};
use crate::{Shutdown, TcpStreamHandle};
use ::windows::Win32::Foundation::HANDLE;
use ::windows::Win32::Networking::WinSock::{
shutdown, WSAGetLastError, SD_BOTH, SD_RECEIVE, SD_SEND, SOCKET, WSABUF, WSAECONNRESET,
};
use ::windows::Win32::System::IO::CancelIoEx;
use ::windows::Win32::System::IO::OVERLAPPED;
/// Handle to a TCP stream managed by the IOCP backend.
///
/// The socket itself lives in the shared state's `streams` collection at `idx`;
/// dropping the handle removes that entry (and `StreamInner`'s `Drop` then shuts
/// the socket down).
pub struct TcpStream {
    /// Index of this stream's entry in `IocpState::streams`.
    idx: usize,
    /// Shared backend state; every access to the stream goes through this lock.
    state: BaseArc<Mutex<IocpState>>,
}
impl Drop for TcpStream {
    /// Removes this stream's entry from the shared state.
    ///
    /// The removed `StreamInner` is dropped at the end of this function, before the
    /// state lock is released, so its own `Drop` (which shuts the socket down) runs
    /// under the lock.
    ///
    /// NOTE(review): no outstanding overlapped I/O is cancelled here; confirm the
    /// completion path tolerates `self.idx` no longer being present.
    fn drop(&mut self) {
        let mut guard = self.state.lock();
        let removed = guard.streams.remove(self.idx);
        log::trace!("Dropping {} {}", self.idx, removed.socket.as_raw_socket());
    }
}
impl TcpStream {
pub(super) fn new(idx: usize, state: BaseArc<Mutex<IocpState>>) -> Self {
Self { idx, state }
}
pub(super) fn tcp_connect<'a, A: ToSocketAddrs + Send + 'a>(
state: &'a BaseArc<Mutex<IocpState>>,
addrs: A,
) -> TcpConnectFuture<'a, A> {
TcpConnectFuture {
state,
addrs: addrs.to_socket_addrs().ok(),
idx: None,
}
}
}
/// Per-socket state owned by the IOCP backend and stored in `IocpState::streams`.
///
/// NOTE(review): fields drop in declaration order, so `socket` is closed before
/// `stream`'s buffers are released; keep `socket` first unless that ordering is
/// confirmed irrelevant to in-flight overlapped I/O.
pub struct StreamInner {
    /// The underlying socket.
    pub(super) socket: net::TcpStream,
    /// Stream buffer that packets are queued into and whose buffer slices are
    /// handed to the overlapped WSA operations.
    stream: StreamBuf,
    /// Whether an overlapped read is currently outstanding.
    in_read: bool,
    /// Number of outstanding overlapped writes (`on_queue` only starts a write
    /// when this is zero).
    in_write: usize,
    /// Receive-side packets (`Write` permission: data is written into them) that
    /// arrived while a read was in flight; moved into `stream` by `on_queue`.
    read_queue: Vec<BoundPacketView<Write>>,
    /// Send-side packets (`Read` permission: data is read out of them) that
    /// arrived while a write was in flight; moved into `stream` by `on_queue`.
    write_queue: Vec<BoundPacketView<Read>>,
}
impl Drop for StreamInner {
fn drop(&mut self) {
let _ = self.shutdown(Shutdown::Both);
}
}
impl StreamInner {
    /// Shuts down the read half, the write half, or both halves of the socket.
    ///
    /// A `WSAECONNRESET` failure counts as success because the peer has already reset
    /// the connection. Any other WinSock error is logged and returned as a generic
    /// network I/O error.
    fn shutdown(&self, how: Shutdown) -> Result<(), Error> {
        let sd_how = match how {
            Shutdown::Read => SD_RECEIVE,
            Shutdown::Write => SD_SEND,
            Shutdown::Both => SD_BOTH,
        };
        let sock = SOCKET(self.socket.as_raw_socket() as _);
        // SAFETY: `sock` is the raw handle of `self.socket`, which stays open for the
        // duration of this call because `self` is borrowed.
        let ret = unsafe { shutdown(sock, sd_how) };
        if ret == 0 {
            return Ok(());
        }
        // SAFETY: reads the calling thread's last WinSock error, which the failed
        // `shutdown` call above just set.
        match unsafe { WSAGetLastError() } {
            WSAECONNRESET => Ok(()),
            v => {
                log::error!("Unable to shutdown stream: {ret} {v:?}");
                Err(mferr!(500, Io, Other, Network))
            }
        }
    }
}
impl From<net::TcpStream> for StreamInner {
fn from(socket: net::TcpStream) -> Self {
Self {
socket,
stream: StreamBuf::default(),
in_read: false,
in_write: 0,
read_queue: Default::default(),
write_queue: Default::default(),
}
}
}
// SAFETY: NOTE(review): these impls assert that `StreamInner` may be moved and shared
// across threads even though at least one field (presumably raw buffer pointers inside
// `StreamBuf` / `BoundPacketView`) is not automatically `Send`/`Sync`. Every access to a
// `StreamInner` in this file goes through the `Mutex<IocpState>` lock; confirm that no
// field is thread-affine before relying on these impls.
unsafe impl Send for StreamInner {}
unsafe impl Sync for StreamInner {}
impl TcpStreamHandle for TcpStream {
    /// Local address the socket is bound to.
    ///
    /// Fails with `NotFound` if this stream's entry is missing from the shared state.
    fn local_addr(&self) -> Result<SocketAddr, Error> {
        self.state
            .lock()
            .streams
            .get(self.idx)
            .ok_or_else(|| io_err(State::NotFound))
            .and_then(|inner| inner.socket.local_addr().map_err(from_io_error))
    }

    /// Address of the remote peer.
    ///
    /// Fails with `NotFound` if this stream's entry is missing from the shared state.
    fn peer_addr(&self) -> Result<SocketAddr, Error> {
        self.state
            .lock()
            .streams
            .get(self.idx)
            .ok_or_else(|| io_err(State::NotFound))
            .and_then(|inner| inner.socket.peer_addr().map_err(from_io_error))
    }

    /// Shuts down one or both directions of the stream while holding the state lock.
    ///
    /// Fails with `NotFound` if this stream's entry is missing from the shared state.
    fn shutdown(&self, how: Shutdown) -> Result<(), Error> {
        self.state
            .lock()
            .streams
            .get(self.idx)
            .ok_or_else(|| io_err(State::NotFound))
            .and_then(|inner| inner.shutdown(how))
    }
}
impl StreamInner {
    /// Completion handler for an overlapped read on this socket.
    ///
    /// Clears the read-in-flight flag, so `on_queue` may submit the next read. It then
    /// forwards `res` to the stream buffer together with `deferred_pkts`.
    pub fn on_read(&mut self, res: io::Result<usize>, deferred_pkts: &mut DeferredPackets) {
        log::debug!("On read {res:?} {}", self.stream.read_ops());
        self.in_read = false;
        self.stream.on_read(res, Some(deferred_pkts))
    }

    /// Completion handler for an overlapped write on this socket.
    ///
    /// Decrements the in-flight write counter, then forwards `res` to the stream buffer
    /// together with `deferred_pkts`.
    pub fn on_write(&mut self, res: io::Result<usize>, deferred_pkts: &mut DeferredPackets) {
        // This is the write path, so report the pending *write* backlog.
        log::debug!("On write {res:?} {}", self.stream.write_ops());
        self.in_write -= 1;
        self.stream.on_write(res, Some(deferred_pkts))
    }

    /// Builds overlapped WSA operations for any queued socket I/O.
    ///
    /// At most one read is outstanding at a time (guarded by `in_read`). A new write is
    /// only started when no write is outstanding (`in_write == 0`). Packets parked in
    /// `read_queue` / `write_queue` while an operation was in flight are first moved into
    /// the stream buffer. The resulting operations are appended to `ops`; this function
    /// does not submit them itself.
    ///
    /// * `idx` — this stream's index in the backend's `streams`, recorded as `sock_idx`.
    /// * `ops` — output list the new operations are pushed onto.
    /// * `deferred_pkts` — passed through to the stream buffer when queueing packets.
    /// * `event` — event handle stored in each operation's `OVERLAPPED::hEvent`.
    #[tracing::instrument(skip(self, ops, deferred_pkts, event))]
    pub(super) fn on_queue(
        &mut self,
        idx: usize,
        ops: &mut Vec<Operation>,
        deferred_pkts: &mut DeferredPackets,
        event: RawHandle,
    ) {
        log::trace!(
            "Do ops file={:?} (to read={} to write={})",
            self.socket.as_raw_socket(),
            self.stream.read_ops(),
            self.stream.write_ops()
        );
        if (!self.read_queue.is_empty() || self.stream.read_ops() > 0) && !self.in_read {
            let rd_span =
                tracing::span!(tracing::Level::TRACE, "read", ops = self.stream.read_ops());
            let _span = rd_span.enter();
            // No read is in flight, so the stream buffer's read side may be modified again.
            for op in self.read_queue.drain(..) {
                self.stream.queue_read(op, Some(deferred_pkts));
            }
            let queue = self.stream.read_queue();
            if !queue.is_empty() {
                self.in_read = true;
                let hdr = OperationHeader {
                    overlapped: OVERLAPPED {
                        hEvent: HANDLE(event as _),
                        ..Default::default()
                    },
                    idx: !0,
                    handle: HANDLE(self.socket.as_raw_socket() as _),
                };
                let operation = Operation {
                    header: hdr.into(),
                    mode: OperationMode::StreamRead(WsaOp {
                        // The buffer slice is reinterpreted as `[WSABUF]` and handed over by raw
                        // pointer. NOTE(review): this assumes the slice element type is
                        // layout-compatible with `WSABUF`. It also assumes the slice stays
                        // untouched until `on_read`: new receive packets are parked in
                        // `read_queue` while `in_read` is set. Confirm both.
                        bufs: queue as *const [_] as *const [WSABUF],
                        transferred: 0,
                        flags: 0,
                        sock_idx: idx,
                    }),
                };
                ops.push(operation);
            }
        }
        if (!self.write_queue.is_empty() || self.stream.write_ops() > 0) && self.in_write == 0 {
            let wr_span = tracing::span!(
                tracing::Level::TRACE,
                "write",
                ops = self.stream.write_ops()
            );
            let _span = wr_span.enter();
            // No write is in flight, so the stream buffer's write side may be modified again.
            for op in self.write_queue.drain(..) {
                self.stream.queue_write(op, Some(deferred_pkts));
            }
            let queue = self.stream.write_queue();
            if !queue.is_empty() {
                let hdr = OperationHeader {
                    overlapped: OVERLAPPED {
                        hEvent: HANDLE(event as _),
                        ..Default::default()
                    },
                    idx: !0,
                    handle: HANDLE(self.socket.as_raw_socket() as _),
                };
                let operation = Operation {
                    header: hdr.into(),
                    mode: OperationMode::StreamWrite(WsaOp {
                        // Same raw-pointer hand-off as the read path. NOTE(review): this
                        // relies on send packets being parked in `write_queue` while
                        // `in_write != 0`.
                        bufs: queue as *const [_] as *const [WSABUF],
                        transferred: 0,
                        flags: 0,
                        sock_idx: idx,
                    }),
                };
                ops.push(operation);
                self.in_write += 1;
            }
        }
    }

    /// Best-effort cancellation of outstanding I/O on this socket.
    ///
    /// Calls `CancelIoEx` on the socket handle with no specific `OVERLAPPED`, which
    /// requests cancellation of all of its I/O; the result is ignored. It then reports an
    /// `ErrorKind::Interrupted` error to the stream buffer's read side.
    ///
    /// NOTE(review): the write side is not notified here, and `in_read` is left unchanged.
    /// Presumably the cancelled completions still arrive through `on_read` / `on_write`;
    /// confirm.
    pub fn cancel_all_ops(&mut self) {
        // SAFETY: the handle is the raw socket owned by `self.socket`, which is open for the
        // duration of this call.
        let _ = unsafe { CancelIoEx(HANDLE(self.socket.as_raw_socket() as _), None) };
        self.stream
            .on_read(Err(io::ErrorKind::Interrupted.into()), None)
    }
}
/// Routes a packet to the send or receive side of a [`StreamInner`], chosen by the
/// packet's permission type.
trait IntoOp: PacketPerms {
    /// Queues `pkt` on `stream`.
    ///
    /// The packet goes straight into the stream buffer when no operation of the
    /// matching direction is in flight. Otherwise it is parked in the corresponding
    /// `*_queue` for `on_queue` to pick up later.
    fn push_op(
        stream: &mut StreamInner,
        pkt: BoundPacketView<Self>,
        deferred_pkts: &mut DeferredPackets,
    );
}
impl IntoOp for Read {
    /// `Read` packets are data sources, so they feed the socket's write (send) side.
    fn push_op(
        stream: &mut StreamInner,
        pkt: BoundPacketView<Self>,
        deferred_pkts: &mut DeferredPackets,
    ) {
        if stream.in_write > 0 {
            // A write is outstanding; park the packet until `on_queue` drains `write_queue`.
            stream.write_queue.push(pkt);
        } else {
            stream.stream.queue_write(pkt, Some(deferred_pkts));
        }
    }
}
impl IntoOp for Write {
    /// `Write` packets are data destinations, so they feed the socket's read (receive) side.
    fn push_op(
        stream: &mut StreamInner,
        pkt: BoundPacketView<Self>,
        deferred_pkts: &mut DeferredPackets,
    ) {
        if stream.in_read {
            // A read is outstanding; park the packet until `on_queue` drains `read_queue`.
            stream.read_queue.push(pkt);
        } else {
            stream.stream.queue_read(pkt, Some(deferred_pkts));
        }
    }
}
impl<Perms: IntoOp> PacketIo<Perms, NoPos> for TcpStream {
    /// Hands `packet` to this stream's send or receive path (selected by `Perms`).
    ///
    /// # Panics
    ///
    /// Panics if this stream's entry is missing from the shared state.
    fn send_io(&self, _: NoPos, packet: BoundPacketView<Perms>) {
        log::debug!("Send io {}", packet.len());
        // The guard lives until the end of the function (temporary lifetime extension);
        // one `&mut IocpState` lets `streams` and `deferred_pkts` be borrowed disjointly.
        let shared = &mut *self.state.lock();
        let inner = shared.streams.get_mut(self.idx).unwrap();
        <Perms as IntoOp>::push_op(inner, packet, &mut shared.deferred_pkts);
    }
}
/// Future returned by `TcpStream::tcp_connect`.
///
/// Tries each resolved address in turn until one connects or the addresses run out.
pub struct TcpConnectFuture<'a, A: ToSocketAddrs + 'a> {
    /// Shared IOCP state used to register connection attempts and their sockets.
    state: &'a BaseArc<Mutex<IocpState>>,
    /// Remaining addresses to try; `None` if address resolution failed.
    addrs: Option<A::Iter>,
    /// Index into `IocpState::connections` of the in-progress attempt, if any.
    idx: Option<usize>,
}
impl<'a, A: ToSocketAddrs + 'a> Future for TcpConnectFuture<'a, A> {
    type Output = mfio::error::Result<TcpStream>;

    /// Drives the connection attempt.
    ///
    /// Each loop iteration either inspects the result of the current attempt
    /// (`this.idx`) or starts a new attempt on the next resolved address. The future
    /// resolves to one of:
    /// - the connected stream;
    /// - a `NotFound` error if the attempt's connection entry disappeared;
    /// - an `Exhausted` error once every address has been tried (or resolution failed).
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // SAFETY: NOTE(review): nothing below moves out of a field that could be `!Unpin`
        // (`addrs` is only advanced through `&mut`). The future never relies on pinned
        // projections. Re-check if fields are added.
        let this = unsafe { self.get_unchecked_mut() };
        // The lock guard lives until the end of `poll` (temporary lifetime extension).
        let state = &mut *this.state.lock();
        loop {
            // An attempt is registered: check whether it has finished.
            if let Some(idx) = this.idx {
                if let Some(conn) = state.connections.get_mut(idx) {
                    match conn.res.take() {
                        Some(Ok(socket_idx)) => {
                            // Connected: drop the bookkeeping entry and hand out the stream.
                            let _ = state.connections.remove(idx);
                            let stream = TcpStream::new(socket_idx, this.state.clone());
                            return Poll::Ready(Ok(stream));
                        }
                        Some(Err(_)) => {
                            // This attempt failed; its error is discarded. The entry is
                            // removed just below, so this waker is dropped along with it.
                            conn.waker = Some(cx.waker().clone());
                        }
                        None => {
                            // Still in progress: register for wake-up and yield.
                            conn.waker = Some(cx.waker().clone());
                            return Poll::Pending;
                        }
                    }
                } else {
                    return Poll::Ready(Err(io_err(State::NotFound)));
                }
            }
            // Either no attempt was registered or the last one failed: discard it.
            if let Some(idx) = this.idx.take() {
                state.connections.remove(idx);
            }
            if let Some(addr) = this.addrs.as_mut().and_then(|v| v.next()) {
                // Skip addresses for which a connecting socket could not be set up.
                let Ok(connection) = TcpGetSock::new_for_connect(addr, state, cx.waker().clone())
                else {
                    continue;
                };
                // Raw handle of the socket that `new_for_connect` registered in `streams`.
                let handle = HANDLE(
                    state
                        .streams
                        .get(connection.socket_idx.unwrap())
                        .as_ref()
                        .unwrap()
                        .socket
                        .as_raw_socket() as _,
                );
                // `this.idx` was cleared above, so this always inserts a fresh entry.
                let &mut idx = this
                    .idx
                    .get_or_insert_with(|| state.connections.insert(connection));
                let hdr = OperationHeader {
                    overlapped: OVERLAPPED {
                        hEvent: HANDLE(state.event.as_raw_handle() as _),
                        ..Default::default()
                    },
                    idx: !0,
                    handle,
                };
                let operation = Operation {
                    header: hdr.into(),
                    mode: OperationMode::TcpConnect(TcpConnect {
                        conn_id: idx,
                        addr: addr.into(),
                    }),
                };
                match unsafe { state.try_submit_op(operation) } {
                    // Submitted: loop back and check whether it already completed.
                    Ok(()) => {
                        continue;
                    }
                    // NOTE(review): the submission error is discarded and the future stays
                    // pending. Presumably the op is retried or completed elsewhere; confirm.
                    Err(_) => return Poll::Pending,
                }
            } else {
                // Out of addresses (or resolution failed). The last connect error, if
                // any, is not propagated.
                return Poll::Ready(Err(io_err(State::Exhausted)));
            }
        }
    }
}
/// Payload of an overlapped TCP connect operation (`OperationMode::TcpConnect`).
pub(crate) struct TcpConnect {
    /// Target address, converted from the `SocketAddr` being tried.
    pub addr: CSockAddr,
    /// Index of the attempt's entry in `IocpState::connections`.
    pub conn_id: usize,
}