use std::fmt;
use std::io::{self, Read, Write, Cursor};
use std::mem;
use std::net::{self, SocketAddr};
use std::os::windows::prelude::*;
use std::sync::{Mutex, MutexGuard};
use net2::{self, TcpBuilder};
use net::tcp::Shutdown;
use miow::iocp::CompletionStatus;
use miow::net::*;
use winapi::*;
use {Evented, EventSet, PollOpt, Selector, Token};
use event::IoEvent;
use sys::windows::selector::{Overlapped, Registration};
use sys::windows::{wouldblock, Family};
use sys::windows::from_raw_arc::FromRawArc;
/// TCP stream whose reads, writes and connects are performed as overlapped
/// operations and reported through the selector.
pub struct TcpStream {
    /// Shared handle to the stream's I/O state; completion callbacks
    /// reconstruct a `StreamImp` pointing at the same state.
    imp: StreamImp,
}
/// TCP listener whose accepts are performed as overlapped operations and
/// reported through the selector.
pub struct TcpListener {
    /// Shared handle to the listener's I/O state; the accept completion
    /// callback reconstructs a `ListenerImp` pointing at the same state.
    imp: ListenerImp,
}
/// Cloneable handle to a stream's shared I/O state.
///
/// Each time an overlapped operation is successfully scheduled, one clone of
/// this handle is passed to `mem::forget`; the matching completion callback
/// rebuilds a handle from the overlapped pointer via `overlapped2arc!`.
// NOTE(review): presumably the forgotten clone is the reference that keeps
// `StreamIo` alive while the kernel owns the overlapped pointer — confirm
// against `FromRawArc` and `overlapped2arc!`.
#[derive(Clone)]
struct StreamImp {
    inner: FromRawArc<StreamIo>,
}
/// Cloneable handle to a listener's shared I/O state.
///
/// A clone is passed to `mem::forget` whenever an overlapped accept is
/// successfully scheduled; `accept_done` rebuilds a handle from the
/// overlapped pointer via `overlapped2arc!`.
#[derive(Clone)]
struct ListenerImp {
    inner: FromRawArc<ListenerIo>,
}
/// Shared allocation backing a `TcpStream`.
struct StreamIo {
    /// Mutable stream state, guarded by a mutex.
    inner: Mutex<StreamInner>,
    /// Overlapped slot used for reads and for the initial connect; its
    /// completion callback is `read_done`.
    read: Overlapped,
    /// Overlapped slot used for writes; its completion callback is
    /// `write_done`.
    write: Overlapped,
}
/// Shared allocation backing a `TcpListener`.
struct ListenerIo {
    /// Mutable listener state, guarded by a mutex.
    inner: Mutex<ListenerInner>,
    /// Overlapped slot used for accepts; its completion callback is
    /// `accept_done`.
    accept: Overlapped,
}
/// Mutex-protected state of a stream.
struct StreamInner {
    /// The underlying socket. Replaced by an `INVALID_SOCKET` placeholder
    /// when the owning `TcpStream` is dropped.
    socket: net::TcpStream,
    /// Registration with the selector; also the source of pooled buffers.
    iocp: Registration,
    /// Address to connect to on first registration, if the stream was
    /// created through `TcpStream::connect`.
    deferred_connect: Option<SocketAddr>,
    /// Read side: `Pending` holds the buffer handed to the overlapped read,
    /// `Ready` holds completed data that has not been fully consumed yet.
    read: State<Vec<u8>, Cursor<Vec<u8>>>,
    /// Write side: `Pending` holds the buffer being written together with
    /// the offset at which the outstanding write started. The `Ready`
    /// variant is not constructed by the code visible in this file.
    write: State<(Vec<u8>, usize), (Vec<u8>, usize)>,
}
/// Mutex-protected state of a listener.
struct ListenerInner {
    /// The listening socket. Replaced by an `INVALID_SOCKET` placeholder
    /// when the owning `TcpListener` is dropped.
    socket: net::TcpListener,
    /// Address family used to create the socket for each scheduled accept.
    family: Family,
    /// Registration with the selector.
    iocp: Registration,
    /// Accept side: `Pending` holds the socket returned when the overlapped
    /// accept was scheduled, `Ready` holds that socket plus the peer address.
    accept: State<net::TcpStream, (net::TcpStream, SocketAddr)>,
    /// Buffer passed to the overlapped accept and parsed on completion to
    /// obtain the peer address.
    accept_buf: AcceptAddrsBuf,
}
/// Progress of a single overlapped operation slot.
enum State<T, U> {
    /// Nothing is scheduled and no result is waiting.
    Empty,
    /// An operation has been scheduled; `T` is kept until it completes.
    Pending(T),
    /// The operation completed; `U` is the result awaiting consumption.
    Ready(U),
    /// The operation failed; the error is kept until it is reported.
    Error(io::Error),
}
impl TcpStream {
    /// Wraps `socket` with empty read/write state and a fresh registration.
    ///
    /// `deferred_connect`, when set, is the address that `register` will
    /// connect to instead of scheduling the first read.
    fn new(socket: net::TcpStream,
           deferred_connect: Option<SocketAddr>) -> TcpStream {
        let state = StreamInner {
            socket: socket,
            iocp: Registration::new(),
            deferred_connect: deferred_connect,
            read: State::Empty,
            write: State::Empty,
        };
        TcpStream {
            imp: StreamImp {
                inner: FromRawArc::new(StreamIo {
                    read: Overlapped::new(read_done),
                    write: Overlapped::new(write_done),
                    inner: Mutex::new(state),
                }),
            },
        }
    }

    /// Creates a stream that will connect to `addr`.
    ///
    /// No connection is started here: the address is stored and the
    /// overlapped connect is scheduled when the stream is first registered.
    /// This function itself always returns `Ok`.
    pub fn connect(socket: net::TcpStream, addr: &SocketAddr)
                   -> io::Result<TcpStream> {
        Ok(TcpStream::new(socket, Some(*addr)))
    }

    /// Returns the peer address reported by the underlying socket.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner().socket.peer_addr()
    }

    /// Returns the local address reported by the underlying socket.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner().socket.local_addr()
    }

    /// Clones the underlying socket into a new `TcpStream`.
    ///
    /// The new stream starts with its own empty read/write state and its own
    /// registration; nothing buffered in `self` is carried over.
    pub fn try_clone(&self) -> io::Result<TcpStream> {
        self.inner().socket.try_clone().map(|s| TcpStream::new(s, None))
    }

    /// Shuts down the read, write, or both halves of the underlying socket.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.inner().socket.shutdown(how)
    }

    /// Forwards `nodelay` to `net2::TcpStreamExt::set_nodelay`.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        net2::TcpStreamExt::set_nodelay(&self.inner().socket, nodelay)
    }

    /// Sets the keepalive timeout, given in seconds (`None` is forwarded
    /// unchanged).
    ///
    /// The value is converted to milliseconds for `set_keepalive_ms`. A
    /// plain `s * 1000` overflows `u32` for large inputs (panicking in debug
    /// builds and silently wrapping to an unrelated timeout in release
    /// builds), so the product is clamped to the largest representable
    /// value instead.
    pub fn set_keepalive(&self, seconds: Option<u32>) -> io::Result<()> {
        let millis = seconds.map(|s| {
            s.checked_mul(1000).unwrap_or(u32::max_value())
        });
        net2::TcpStreamExt::set_keepalive_ms(&self.inner().socket, millis)
    }

    /// Takes the error stored on the socket, if any, and returns it as `Err`.
    pub fn take_socket_error(&self) -> io::Result<()> {
        net2::TcpStreamExt::take_error(&self.inner().socket).and_then(|e| {
            match e {
                Some(e) => Err(e),
                None => Ok(()),
            }
        })
    }

    /// Locks and returns the shared stream state.
    fn inner(&self) -> MutexGuard<StreamInner> {
        self.imp.inner()
    }

    /// Kicks off I/O matching `interest` after a (re)registration.
    fn post_register(&self, interest: EventSet, me: &mut StreamInner) {
        if interest.is_readable() {
            self.imp.schedule_read(me);
        }

        // With no write outstanding the stream can accept data right away,
        // so queue a writable event.
        if interest.is_writable() {
            if let State::Empty = me.write {
                me.iocp.defer(EventSet::writable());
            }
        }
    }
}
impl StreamImp {
    /// Locks the shared stream state.
    ///
    /// Panics if the mutex has been poisoned.
    fn inner(&self) -> MutexGuard<StreamInner> {
        self.inner.inner.lock().unwrap()
    }

    /// Starts an overlapped connect to `addr`, using the `read` overlapped
    /// slot, so the completion is delivered to `read_done`.
    fn schedule_connect(&self, addr: &SocketAddr, me: &mut StreamInner)
                        -> io::Result<()> {
        trace!("scheduling a connect");
        let scheduled = unsafe {
            me.socket.connect_overlapped(addr, self.inner.read.get_mut())
        };
        match scheduled {
            Ok(_) => {
                // One clone is forgotten per scheduled operation; the
                // completion callback rebuilds it from the overlapped pointer.
                mem::forget(self.clone());
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    /// Starts an overlapped read unless the read slot is already in use.
    ///
    /// If scheduling fails, the error is stored in `me.read` and a readable
    /// event (plus hup for `WSAECONNRESET`) is deferred so it gets noticed.
    fn schedule_read(&self, me: &mut StreamInner) {
        let idle = match me.read {
            State::Empty => true,
            _ => false,
        };
        if !idle {
            return
        }

        // Clear readable readiness before issuing the new read.
        me.iocp.unset_readiness(EventSet::readable());

        let mut buf = me.iocp.get_buffer(64 * 1024);
        trace!("scheduling a read");
        let res = unsafe {
            // Expose the buffer's whole capacity as the read destination.
            let cap = buf.capacity();
            buf.set_len(cap);
            me.socket.read_overlapped(&mut buf, self.inner.read.get_mut())
        };

        match res {
            Err(e) => {
                let set = if e.raw_os_error() == Some(WSAECONNRESET as i32) {
                    EventSet::readable() | EventSet::hup()
                } else {
                    EventSet::readable()
                };
                me.read = State::Error(e);
                me.iocp.defer(set);
                me.iocp.put_buffer(buf);
            }
            Ok(_) => {
                me.read = State::Pending(buf);
                // See `schedule_connect` for why a clone is forgotten.
                mem::forget(self.clone());
            }
        }
    }

    /// Starts an overlapped write of `buf[pos..]`.
    ///
    /// If scheduling fails, the error is stored in `me.write`, a writable
    /// event is deferred, and the buffer is handed back to the registration.
    fn schedule_write(&self, buf: Vec<u8>, pos: usize,
                      me: &mut StreamInner) {
        // Clear writable readiness before issuing the write.
        me.iocp.unset_readiness(EventSet::writable());

        trace!("scheduling a write");
        let outcome = unsafe {
            me.socket.write_overlapped(&buf[pos..], self.inner.write.get_mut())
        };

        match outcome {
            Err(e) => {
                me.write = State::Error(e);
                me.iocp.defer(EventSet::writable());
                me.iocp.put_buffer(buf);
            }
            Ok(_) => {
                me.write = State::Pending((buf, pos));
                // See `schedule_connect` for why a clone is forgotten.
                mem::forget(self.clone());
            }
        }
    }

    /// Queues `set` into `into`, unless the socket has already been replaced
    /// by the `INVALID_SOCKET` placeholder installed when the stream drops.
    fn push(&self, me: &mut StreamInner, set: EventSet,
            into: &mut Vec<IoEvent>) {
        if me.socket.as_raw_socket() == INVALID_SOCKET {
            return
        }
        me.iocp.push_event(set, into);
    }
}
/// Completion callback for the `read` overlapped slot of a stream.
///
/// The slot is shared by reads and by the initial connect: if a read was
/// pending this finishes it, otherwise the completion is treated as the end
/// of a connect.
fn read_done(status: &CompletionStatus, dst: &mut Vec<IoEvent>) {
    let me2 = StreamImp {
        inner: unsafe { overlapped2arc!(status.overlapped(), StreamIo, read) },
    };
    let mut me = me2.inner();

    match mem::replace(&mut me.read, State::Empty) {
        State::Pending(mut buf) => {
            let transferred = status.bytes_transferred();
            trace!("finished a read: {}", transferred);
            // Shrink the buffer to the bytes that were actually received.
            unsafe {
                buf.set_len(transferred as usize);
            }
            me.read = State::Ready(Cursor::new(buf));

            // A zero-byte read is additionally reported as hup.
            let events = if transferred == 0 {
                EventSet::readable() | EventSet::hup()
            } else {
                EventSet::readable()
            };
            me2.push(&mut me, events, dst);
        }
        other => {
            // No read was pending: restore the state untouched.
            me.read = other;
            trace!("finished a connect");
            me2.push(&mut me, EventSet::writable(), dst);
            me2.schedule_read(&mut me);
        }
    }
}
/// Completion callback for the `write` overlapped slot of a stream.
///
/// Advances the write offset by the bytes transferred. When the whole buffer
/// has been written a writable event is queued; otherwise the remainder is
/// scheduled as a new overlapped write.
///
/// Panics (`unreachable!`) if no write was pending.
fn write_done(status: &CompletionStatus, dst: &mut Vec<IoEvent>) {
    trace!("finished a write {}", status.bytes_transferred());
    let me2 = StreamImp {
        inner: unsafe { overlapped2arc!(status.overlapped(), StreamIo, write) },
    };
    let mut me = me2.inner();

    let in_flight = mem::replace(&mut me.write, State::Empty);
    let (buf, start) = if let State::Pending(pair) = in_flight {
        pair
    } else {
        unreachable!()
    };

    let written = start + (status.bytes_transferred() as usize);
    if written != buf.len() {
        // Partial write: continue from where this one stopped.
        me2.schedule_write(buf, written, &mut me);
    } else {
        me2.push(&mut me, EventSet::writable(), dst);
    }
}
impl Read for TcpStream {
    /// Copies already-received data into `buf`.
    ///
    /// Returns a would-block error while no completed read is available.
    /// Once the buffered data is fully consumed, the buffer is returned to
    /// the registration and the next overlapped read is scheduled. A stored
    /// read error is returned once, after which a new read is scheduled.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut me = self.inner();

        match mem::replace(&mut me.read, State::Empty) {
            State::Ready(mut cursor) => {
                let amt = try!(cursor.read(buf));
                let drained =
                    cursor.position() as usize == cursor.get_ref().len();
                if drained {
                    me.iocp.put_buffer(cursor.into_inner());
                    self.imp.schedule_read(&mut me);
                } else {
                    // Keep the unread remainder for the next call.
                    me.read = State::Ready(cursor);
                }
                Ok(amt)
            }
            State::Error(e) => {
                self.imp.schedule_read(&mut me);
                Err(e)
            }
            State::Pending(in_flight) => {
                // A read is still outstanding; put its buffer back.
                me.read = State::Pending(in_flight);
                Err(wouldblock())
            }
            State::Empty => Err(wouldblock()),
        }
    }
}
impl Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut me = self.inner();
let me = &mut *me;
match me.write {
State::Empty => {}
_ => return Err(wouldblock())
}
if me.iocp.port().is_none() {
return Err(wouldblock())
}
let mut intermediate = me.iocp.get_buffer(64 * 1024);
let amt = try!(intermediate.write(buf));
self.imp.schedule_write(intermediate, 0, me);
Ok(amt)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl Evented for TcpStream {
    /// Registers the socket with `selector`.
    ///
    /// If a deferred connect address is stored, the overlapped connect is
    /// scheduled and nothing else is started. Otherwise I/O matching
    /// `interest` is started.
    fn register(&self, selector: &mut Selector, token: Token,
                interest: EventSet, opts: PollOpt) -> io::Result<()> {
        let mut guard = self.inner();
        let me = &mut *guard;
        try!(me.iocp.register_socket(&me.socket, selector, token, interest,
                                     opts));

        match me.deferred_connect.take() {
            Some(addr) => self.imp.schedule_connect(&addr, me),
            None => {
                self.post_register(interest, me);
                Ok(())
            }
        }
    }

    /// Updates the registration and starts I/O matching the new `interest`.
    fn reregister(&self, selector: &mut Selector, token: Token,
                  interest: EventSet, opts: PollOpt) -> io::Result<()> {
        let mut guard = self.inner();
        let me = &mut *guard;
        try!(me.iocp.reregister_socket(&me.socket, selector, token,
                                       interest, opts));
        self.post_register(interest, me);
        Ok(())
    }

    /// Removes the registration via `checked_deregister`.
    fn deregister(&self, selector: &mut Selector) -> io::Result<()> {
        let mut me = self.inner();
        me.iocp.checked_deregister(selector)
    }
}
impl fmt::Debug for TcpStream {
    /// Formats a fixed placeholder string; no stream state is read.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let placeholder = "TcpStream { ... }";
        placeholder.fmt(formatter)
    }
}
impl Drop for TcpStream {
fn drop(&mut self) {
let mut inner = self.inner();
inner.socket = unsafe {
net::TcpStream::from_raw_socket(INVALID_SOCKET)
};
inner.iocp.deregister();
}
}
impl TcpListener {
    /// Wraps an already-bound listening socket.
    ///
    /// `addr` is only used to pick the address family for the sockets that
    /// will receive accepted connections. Always returns `Ok`.
    pub fn new(socket: net::TcpListener, addr: &SocketAddr)
               -> io::Result<TcpListener> {
        let family = match *addr {
            SocketAddr::V4(..) => Family::V4,
            SocketAddr::V6(..) => Family::V6,
        };
        Ok(TcpListener::new_family(socket, family))
    }

    /// Builds the listener with empty accept state and a fresh registration.
    fn new_family(socket: net::TcpListener, family: Family) -> TcpListener {
        let state = ListenerInner {
            socket: socket,
            iocp: Registration::new(),
            accept: State::Empty,
            accept_buf: AcceptAddrsBuf::new(),
            family: family,
        };
        TcpListener {
            imp: ListenerImp {
                inner: FromRawArc::new(ListenerIo {
                    accept: Overlapped::new(accept_done),
                    inner: Mutex::new(state),
                }),
            },
        }
    }

    /// Returns a completed accept, if one is available.
    ///
    /// Yields `Ok(None)` when nothing has been scheduled or the accept is
    /// still outstanding. When a result (connection or error) is consumed,
    /// the next overlapped accept is scheduled before returning.
    pub fn accept(&self) -> io::Result<Option<(TcpStream, SocketAddr)>> {
        let mut me = self.inner();

        let outcome = match mem::replace(&mut me.accept, State::Empty) {
            State::Ready((s, a)) => Ok(Some((TcpStream::new(s, None), a))),
            State::Error(e) => Err(e),
            State::Pending(t) => {
                // Still in flight: restore the state and report nothing.
                me.accept = State::Pending(t);
                return Ok(None)
            }
            State::Empty => return Ok(None),
        };

        self.imp.schedule_accept(&mut me);
        outcome
    }

    /// Returns the local address reported by the listening socket.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner().socket.local_addr()
    }

    /// Clones the listening socket into a new `TcpListener` with the same
    /// family, its own empty accept state and its own registration.
    pub fn try_clone(&self) -> io::Result<TcpListener> {
        let me = self.inner();
        match me.socket.try_clone() {
            Ok(s) => Ok(TcpListener::new_family(s, me.family)),
            Err(e) => Err(e),
        }
    }

    /// Takes the error stored on the socket, if any, and returns it as `Err`.
    pub fn take_socket_error(&self) -> io::Result<()> {
        net2::TcpListenerExt::take_error(&self.inner().socket).and_then(|e| {
            match e {
                Some(e) => Err(e),
                None => Ok(()),
            }
        })
    }

    /// Locks and returns the shared listener state.
    fn inner(&self) -> MutexGuard<ListenerInner> {
        self.imp.inner()
    }
}
impl ListenerImp {
    /// Locks the shared listener state.
    ///
    /// Panics if the mutex has been poisoned.
    fn inner(&self) -> MutexGuard<ListenerInner> {
        self.inner.inner.lock().unwrap()
    }

    /// Starts an overlapped accept unless the accept slot is already in use.
    ///
    /// A socket builder of the listener's family is created for the incoming
    /// connection. If building or scheduling fails, the error is stored in
    /// `me.accept` and a readable event is deferred so it gets noticed.
    fn schedule_accept(&self, me: &mut ListenerInner) {
        let idle = match me.accept {
            State::Empty => true,
            _ => false,
        };
        if !idle {
            return
        }

        // Clear readable readiness before issuing the new accept.
        me.iocp.unset_readiness(EventSet::readable());

        let builder = match me.family {
            Family::V4 => TcpBuilder::new_v4(),
            Family::V6 => TcpBuilder::new_v6(),
        };
        let res = match builder {
            Ok(builder) => {
                trace!("scheduling an accept");
                unsafe {
                    me.socket.accept_overlapped(&builder, &mut me.accept_buf,
                                                self.inner.accept.get_mut())
                }
            }
            Err(e) => Err(e),
        };

        match res {
            Err(e) => {
                me.accept = State::Error(e);
                me.iocp.defer(EventSet::readable());
            }
            Ok((socket, _)) => {
                me.accept = State::Pending(socket);
                // One clone is forgotten per scheduled operation;
                // `accept_done` rebuilds it from the overlapped pointer.
                mem::forget(self.clone());
            }
        }
    }

    /// Queues `set` into `into`, unless the socket has already been replaced
    /// by the `INVALID_SOCKET` placeholder installed when the listener drops.
    fn push(&self, me: &mut ListenerInner, set: EventSet,
            into: &mut Vec<IoEvent>) {
        if me.socket.as_raw_socket() == INVALID_SOCKET {
            return
        }
        me.iocp.push_event(set, into);
    }
}
/// Completion callback for the `accept` overlapped slot of a listener.
///
/// Parses the accept buffer to obtain the peer address, stores either the
/// accepted socket with that address or the parse error, and queues a
/// readable event.
///
/// Panics if no accept was pending (`unreachable!`), or if the parsed buffer
/// has no remote address (`unwrap`).
fn accept_done(status: &CompletionStatus, dst: &mut Vec<IoEvent>) {
    let me2 = ListenerImp {
        inner: unsafe { overlapped2arc!(status.overlapped(), ListenerIo, accept) },
    };
    let mut me = me2.inner();

    let in_flight = mem::replace(&mut me.accept, State::Empty);
    let socket = if let State::Pending(s) = in_flight {
        s
    } else {
        unreachable!()
    };
    trace!("finished an accept");

    let next = match me.accept_buf.parse(&me.socket) {
        Err(e) => State::Error(e),
        Ok(buf) => State::Ready((socket, buf.remote().unwrap())),
    };
    me.accept = next;

    me2.push(&mut me, EventSet::readable(), dst);
}
impl Evented for TcpListener {
    /// Registers the socket with `selector` and schedules an accept,
    /// whatever `interest` contains.
    fn register(&self, selector: &mut Selector, token: Token,
                interest: EventSet, opts: PollOpt) -> io::Result<()> {
        let mut guard = self.inner();
        let me = &mut *guard;
        try!(me.iocp.register_socket(&me.socket, selector, token, interest,
                                     opts));
        self.imp.schedule_accept(me);
        Ok(())
    }

    /// Updates the registration and schedules an accept if none is in
    /// progress.
    fn reregister(&self, selector: &mut Selector, token: Token,
                  interest: EventSet, opts: PollOpt) -> io::Result<()> {
        let mut guard = self.inner();
        let me = &mut *guard;
        try!(me.iocp.reregister_socket(&me.socket, selector, token,
                                       interest, opts));
        self.imp.schedule_accept(me);
        Ok(())
    }

    /// Removes the registration via `checked_deregister`.
    fn deregister(&self, selector: &mut Selector) -> io::Result<()> {
        let mut me = self.inner();
        me.iocp.checked_deregister(selector)
    }
}
impl fmt::Debug for TcpListener {
    /// Formats a fixed placeholder string; no listener state is read.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let placeholder = "TcpListener { ... }";
        placeholder.fmt(formatter)
    }
}
impl Drop for TcpListener {
fn drop(&mut self) {
let mut inner = self.inner();
inner.socket = unsafe {
net::TcpListener::from_raw_socket(INVALID_SOCKET)
};
inner.iocp.deregister();
}
}