#![deny(missing_docs)]
extern crate mio;
use std::cell::RefCell;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::io;
use std::io::{Read, Write};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use mio::*;
/// Token under which every private `Poll` in this module reports the stop registration that
/// a [`Canceller`] triggers.
const STOP_TOKEN: Token = Token(0);
/// Token under which every private `Poll` in this module reports the wrapped socket or
/// listener.
const OBJECT_TOKEN: Token = Token(1);
/// Handle that interrupts blocking operations on the socket it was created with.
///
/// Every constructor in this module (`connect`, `bind`, `accept`, `try_clone`) returns one
/// alongside the socket. Calling [`Canceller::cancel`] makes the socket's pending operation,
/// or the next one to poll if none is in progress, fail with an error for which
/// [`is_cancelled`] returns `true`.
#[derive(Clone, Debug)]
pub struct Canceller {
    // Drives the socket's stop registration (reported under `STOP_TOKEN`).
    set_readiness: SetReadiness,
}
/// A TCP stream whose blocking reads and writes can be interrupted by a [`Canceller`].
///
/// The API mirrors `std::net::TcpStream`. Each instance owns a private `Poll`, so a clone
/// made with [`TcpStream::try_clone`] has its own, independent [`Canceller`].
pub struct TcpStream {
    stream: mio::net::TcpStream,
    // Private poll watching `stream` (OBJECT_TOKEN) and the stop registration (STOP_TOKEN).
    poll: Poll,
    // Held only so the stop registration lives as long as the stream; never read.
    _stop_registration: Registration,
    // Reusable buffer that `poll` fills on every wait.
    events: Events,
    // Shared with clones created by `try_clone`, so settings apply to all of them.
    options: Arc<RwLock<TcpStreamOptions>>,
}
/// Per-stream settings, shared through `Arc<RwLock<_>>` between a stream and its clones.
struct TcpStreamOptions {
    // Maximum time a blocking read waits for readiness; `None` waits indefinitely.
    read_timeout: Option<Duration>,
    // Maximum time a blocking write waits for readiness; `None` waits indefinitely.
    write_timeout: Option<Duration>,
    // When true, operations poll once and return `WouldBlock` instead of waiting.
    nonblocking: bool,
}
/// A TCP listener whose blocking `accept` can be interrupted by a [`Canceller`].
///
/// The API mirrors `std::net::TcpListener`. Each instance owns a private `Poll`, so a clone
/// made with [`TcpListener::try_clone`] has its own, independent [`Canceller`].
pub struct TcpListener {
    listener: mio::net::TcpListener,
    // Private poll watching `listener` (OBJECT_TOKEN) and the stop registration (STOP_TOKEN).
    poll: Poll,
    // Held only so the stop registration lives as long as the listener; never read.
    _stop_registration: Registration,
    // `RefCell` because `accept` takes `&self` but must refill the buffer.
    events: RefCell<Events>,
    // Shared with clones created by `try_clone`.
    options: Arc<RwLock<TcpListenerOptions>>,
}
/// Per-listener settings, shared through `Arc<RwLock<_>>` between a listener and its clones.
struct TcpListenerOptions {
    // Maximum time a blocking `accept` waits; `None` waits indefinitely.
    // NOTE(review): no setter for this field is visible in this file, so it stays `None`.
    timeout: Option<Duration>,
    // When true, `accept` polls once and returns `WouldBlock` instead of waiting.
    nonblocking: bool,
}
/// Iterator over connections accepted by a [`TcpListener`], created by
/// [`TcpListener::incoming`].
///
/// It never yields `None`; stop iterating when an error (for example a cancellation) arrives.
pub struct Incoming<'a> {
    listener: &'a TcpListener,
}
/// A UDP socket whose blocking sends and receives can be interrupted by a [`Canceller`].
///
/// The API mirrors `std::net::UdpSocket`. Each instance owns a private `Poll`, so a clone
/// made with [`UdpSocket::try_clone`] has its own, independent [`Canceller`].
pub struct UdpSocket {
    socket: mio::net::UdpSocket,
    // Private poll watching `socket` (OBJECT_TOKEN) and the stop registration (STOP_TOKEN).
    poll: Poll,
    // Held only so the stop registration lives as long as the socket; never read.
    _stop_registration: Registration,
    // `RefCell` because the I/O methods take `&self` but must refill the buffer.
    events: RefCell<Events>,
    // Shared with clones created by `try_clone`.
    options: Arc<RwLock<UdpSocketOptions>>,
}
/// Per-socket settings, shared through `Arc<RwLock<_>>` between a socket and its clones.
struct UdpSocketOptions {
    // Maximum time `recv`/`recv_from` wait for readiness; `None` waits indefinitely.
    read_timeout: Option<Duration>,
    // Maximum time `send`/`send_to` wait for readiness; `None` waits indefinitely.
    write_timeout: Option<Duration>,
    // When true, operations poll once and return `WouldBlock` instead of waiting.
    nonblocking: bool,
}
impl Canceller {
    /// Signals the associated socket to abort its pending (or next) blocking operation.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by mio's `SetReadiness::set_readiness`.
    pub fn cancel(&self) -> io::Result<()> {
        let stop_signal = Ready::readable();
        self.set_readiness.set_readiness(stop_signal)
    }
}
/// Marker payload carried by errors produced through a [`Canceller`].
///
/// A dedicated type lets [`is_cancelled`] recognise cancellations by type rather than by
/// message text, so an unrelated error that happens to read "cancelled" is not mistaken
/// for one.
#[derive(Debug)]
struct Cancelled;

impl fmt::Display for Cancelled {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.write_str("cancelled")
    }
}

impl std::error::Error for Cancelled {}

/// Builds the error returned when an operation is interrupted by a [`Canceller`].
///
/// The error has kind `Other` and displays as `"cancelled"`.
fn cancelled_error() -> io::Error {
    io::Error::new(io::ErrorKind::Other, Cancelled)
}

/// Returns `true` if `e` was produced because the operation was cancelled through a
/// [`Canceller`].
pub fn is_cancelled(e: &io::Error) -> bool {
    e.kind() == io::ErrorKind::Other
        && e.get_ref().map_or(false, |inner| inner.is::<Cancelled>())
}
impl TcpStream {
fn simple_connect(address: &SocketAddr) -> io::Result<(Self, Canceller)> {
let poll = Poll::new()?;
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
STOP_TOKEN,
Ready::readable(),
PollOpt::edge(),
)?;
let stream = mio::net::TcpStream::connect(address)?;
poll.register(
&stream,
OBJECT_TOKEN,
Ready::readable() | Ready::writable(),
PollOpt::level(),
)?;
let events = Events::with_capacity(4);
Ok((
TcpStream {
stream,
poll,
_stop_registration: stop_registration,
events,
options: Arc::new(RwLock::new(TcpStreamOptions {
read_timeout: None,
write_timeout: None,
nonblocking: false,
})),
},
Canceller {
set_readiness: stop_set_readiness,
},
))
}
pub fn connect<A: ToSocketAddrs>(address: A) -> io::Result<(Self, Canceller)> {
let mut error = io::Error::from(io::ErrorKind::InvalidInput);
for a in address.to_socket_addrs()? {
match Self::simple_connect(&a) {
Ok(r) => return Ok(r),
Err(e) => error = e,
}
}
Err(error)
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.stream.peer_addr()
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.stream.local_addr()
}
pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
self.stream.shutdown(how)
}
pub fn try_clone(&self) -> io::Result<(Self, Canceller)> {
let poll = Poll::new()?;
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
STOP_TOKEN,
Ready::readable(),
PollOpt::edge(),
)?;
let stream = self.stream.try_clone()?;
poll.register(
&stream,
OBJECT_TOKEN,
Ready::readable() | Ready::writable(),
PollOpt::level(),
)?;
let events = Events::with_capacity(4);
Ok((
TcpStream {
stream,
poll,
_stop_registration: stop_registration,
events,
options: self.options.clone(),
},
Canceller {
set_readiness: stop_set_readiness,
},
))
}
pub fn set_read_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.options.write().unwrap().read_timeout = duration;
Ok(())
}
pub fn set_write_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.options.write().unwrap().write_timeout = duration;
Ok(())
}
pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
Ok(self.options.read().unwrap().read_timeout)
}
pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
Ok(self.options.read().unwrap().write_timeout)
}
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.stream.set_nodelay(nodelay)
}
pub fn nodelay(&self) -> io::Result<bool> {
self.stream.nodelay()
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.stream.set_ttl(ttl)
}
pub fn ttl(&self) -> io::Result<u32> {
self.stream.ttl()
}
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
self.stream.take_error()
}
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.options.write().unwrap().nonblocking = nonblocking;
Ok(())
}
}
impl Read for TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.poll.reregister(
&self.stream,
OBJECT_TOKEN,
Ready::readable(),
PollOpt::level(),
)?;
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut self.events, Some(Duration::from_millis(0)))?;
for event in self.events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_readable() {
return self.stream.read(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let read_timeout = self.options.read().unwrap().read_timeout;
loop {
self.poll.poll(&mut self.events, read_timeout)?;
for event in self.events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_readable() {
return self.stream.read(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if read_timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
}
impl Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll.reregister(
&self.stream,
OBJECT_TOKEN,
Ready::writable(),
PollOpt::level(),
)?;
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut self.events, Some(Duration::from_millis(0)))?;
for event in self.events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.stream.write(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let write_timeout = self.options.read().unwrap().write_timeout;
loop {
self.poll.poll(&mut self.events, write_timeout)?;
for event in self.events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.stream.write(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if write_timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
fn flush(&mut self) -> io::Result<()> {
self.poll.reregister(
&self.stream,
OBJECT_TOKEN,
Ready::writable(),
PollOpt::level(),
)?;
loop {
self.poll.poll(&mut self.events, None)?;
for event in self.events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.stream.flush();
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
}
}
}
impl Debug for TcpStream {
    /// Formats exactly like the wrapped `mio::net::TcpStream`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        Debug::fmt(&self.stream, f)
    }
}
impl TcpListener {
fn simple_bind(address: &SocketAddr) -> io::Result<(Self, Canceller)> {
let poll = Poll::new()?;
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
STOP_TOKEN,
Ready::readable(),
PollOpt::edge(),
)?;
let listener = mio::net::TcpListener::bind(address)?;
poll.register(&listener, OBJECT_TOKEN, Ready::readable(), PollOpt::level())?;
let events = Events::with_capacity(4);
Ok((
TcpListener {
listener,
poll,
_stop_registration: stop_registration,
events: RefCell::new(events),
options: Arc::new(RwLock::new(TcpListenerOptions {
timeout: None,
nonblocking: false,
})),
},
Canceller {
set_readiness: stop_set_readiness,
},
))
}
pub fn bind<A: ToSocketAddrs>(address: A) -> io::Result<(Self, Canceller)> {
let mut error = io::Error::from(io::ErrorKind::InvalidInput);
for a in address.to_socket_addrs()? {
match Self::simple_bind(&a) {
Ok(r) => return Ok(r),
Err(e) => error = e,
}
}
Err(error)
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
pub fn try_clone(&self) -> io::Result<(Self, Canceller)> {
let poll = Poll::new()?;
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
STOP_TOKEN,
Ready::readable(),
PollOpt::edge(),
)?;
let listener = self.listener.try_clone()?;
poll.register(&listener, OBJECT_TOKEN, Ready::readable(), PollOpt::level())?;
let events = Events::with_capacity(4);
Ok((
TcpListener {
listener,
poll,
_stop_registration: stop_registration,
events: RefCell::new(events),
options: self.options.clone(),
},
Canceller {
set_readiness: stop_set_readiness,
},
))
}
pub fn accept(&self) -> io::Result<(TcpStream, Canceller, SocketAddr)> {
let mut events = self.events.borrow_mut();
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut events, Some(Duration::from_millis(0)))?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
let (stream, addr) = self.listener.accept()?;
let poll = Poll::new()?;
let stop_token = Token(0);
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
stop_token,
Ready::readable(),
PollOpt::edge(),
)?;
let stream_token = Token(1);
poll.register(
&stream,
stream_token,
Ready::readable() | Ready::writable(),
PollOpt::level(),
)?;
let events = Events::with_capacity(4);
return Ok((
TcpStream {
stream,
poll,
_stop_registration: stop_registration,
events,
options: Arc::new(RwLock::new(TcpStreamOptions {
read_timeout: None,
write_timeout: None,
nonblocking: false,
})),
},
Canceller {
set_readiness: stop_set_readiness,
},
addr,
));
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let timeout = self.options.read().unwrap().timeout;
loop {
self.poll.poll(&mut events, timeout)?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
let (stream, addr) = self.listener.accept()?;
let poll = Poll::new()?;
let stop_token = Token(0);
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
stop_token,
Ready::readable(),
PollOpt::edge(),
)?;
let stream_token = Token(1);
poll.register(
&stream,
stream_token,
Ready::readable() | Ready::writable(),
PollOpt::level(),
)?;
let events = Events::with_capacity(4);
return Ok((
TcpStream {
stream,
poll,
_stop_registration: stop_registration,
events,
options: Arc::new(RwLock::new(TcpStreamOptions {
read_timeout: None,
write_timeout: None,
nonblocking: false,
})),
},
Canceller {
set_readiness: stop_set_readiness,
},
addr,
));
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
pub fn incoming(&self) -> Incoming {
Incoming { listener: self }
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.listener.set_ttl(ttl)
}
pub fn ttl(&self) -> io::Result<u32> {
self.listener.ttl()
}
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
self.listener.take_error()
}
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.options.write().unwrap().nonblocking = nonblocking;
Ok(())
}
}
impl Debug for TcpListener {
    /// Formats exactly like the wrapped `mio::net::TcpListener`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        Debug::fmt(&self.listener, f)
    }
}
impl<'a> Iterator for Incoming<'a> {
    type Item = io::Result<(TcpStream, Canceller, SocketAddr)>;

    /// Accepts the next connection.
    ///
    /// Never returns `None`; errors, including cancellations, are yielded as items so the
    /// caller decides when to stop.
    fn next(&mut self) -> Option<Self::Item> {
        let accepted = self.listener.accept();
        Some(accepted)
    }
}
impl UdpSocket {
fn simple_bind(address: &SocketAddr) -> io::Result<(Self, Canceller)> {
let poll = Poll::new()?;
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
STOP_TOKEN,
Ready::readable(),
PollOpt::edge(),
)?;
let socket = mio::net::UdpSocket::bind(address)?;
poll.register(
&socket,
OBJECT_TOKEN,
Ready::readable() | Ready::writable(),
PollOpt::level(),
)?;
let events = Events::with_capacity(4);
Ok((
UdpSocket {
socket,
poll,
_stop_registration: stop_registration,
events: RefCell::new(events),
options: Arc::new(RwLock::new(UdpSocketOptions {
read_timeout: None,
write_timeout: None,
nonblocking: false,
})),
},
Canceller {
set_readiness: stop_set_readiness,
},
))
}
pub fn bind<A: ToSocketAddrs>(address: A) -> io::Result<(Self, Canceller)> {
let mut error = io::Error::from(io::ErrorKind::InvalidInput);
for a in address.to_socket_addrs()? {
match Self::simple_bind(&a) {
Ok(r) => return Ok(r),
Err(e) => error = e,
}
}
Err(error)
}
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.poll.reregister(
&self.socket,
OBJECT_TOKEN,
Ready::readable(),
PollOpt::level(),
)?;
let mut events = self.events.borrow_mut();
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut events, Some(Duration::from_millis(0)))?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_readable() {
return self.socket.recv_from(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let read_timeout = self.options.read().unwrap().read_timeout;
loop {
self.poll.poll(&mut events, read_timeout)?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_readable() {
return self.socket.recv_from(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if read_timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
pub fn send_to(&self, buf: &[u8], addr: &SocketAddr) -> io::Result<usize> {
self.poll.reregister(
&self.socket,
OBJECT_TOKEN,
Ready::writable(),
PollOpt::level(),
)?;
let mut events = self.events.borrow_mut();
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut events, Some(Duration::from_millis(0)))?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.socket.send_to(buf, addr);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let write_timeout = self.options.read().unwrap().write_timeout;
loop {
self.poll.poll(&mut events, write_timeout)?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.socket.send_to(buf, addr);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if write_timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.socket.local_addr()
}
pub fn try_clone(&self) -> io::Result<(Self, Canceller)> {
let poll = Poll::new()?;
let (stop_registration, stop_set_readiness) = Registration::new2();
poll.register(
&stop_registration,
STOP_TOKEN,
Ready::readable(),
PollOpt::edge(),
)?;
let socket = self.socket.try_clone()?;
poll.register(
&socket,
OBJECT_TOKEN,
Ready::readable() | Ready::writable(),
PollOpt::level(),
)?;
let events = Events::with_capacity(4);
Ok((
UdpSocket {
socket,
poll,
_stop_registration: stop_registration,
events: RefCell::new(events),
options: self.options.clone(),
},
Canceller {
set_readiness: stop_set_readiness,
},
))
}
pub fn set_read_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.options.write().unwrap().read_timeout = duration;
Ok(())
}
pub fn set_write_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.options.write().unwrap().write_timeout = duration;
Ok(())
}
pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
Ok(self.options.read().unwrap().read_timeout)
}
pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
Ok(self.options.read().unwrap().write_timeout)
}
pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
self.socket.set_broadcast(broadcast)
}
pub fn broadcast(&self) -> io::Result<bool> {
self.socket.broadcast()
}
pub fn set_multicast_loop_v4(&self, multicast_loop: bool) -> io::Result<()> {
self.socket.set_multicast_loop_v4(multicast_loop)
}
pub fn multicast_loop_v4(&self) -> io::Result<bool> {
self.socket.multicast_loop_v4()
}
pub fn set_multicast_ttl_v4(&self, multicast_ttl: u32) -> io::Result<()> {
self.socket.set_multicast_ttl_v4(multicast_ttl)
}
pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
self.socket.multicast_ttl_v4()
}
pub fn set_multicast_loop_v6(&self, multicast_loop: bool) -> io::Result<()> {
self.socket.set_multicast_loop_v6(multicast_loop)
}
pub fn multicast_loop_v6(&self) -> io::Result<bool> {
self.socket.multicast_loop_v6()
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.socket.set_ttl(ttl)
}
pub fn ttl(&self) -> io::Result<u32> {
self.socket.ttl()
}
pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
self.socket.join_multicast_v4(multiaddr, interface)
}
pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
self.socket.join_multicast_v6(multiaddr, interface)
}
pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
self.socket.leave_multicast_v4(multiaddr, interface)
}
pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
self.socket.leave_multicast_v6(multiaddr, interface)
}
pub fn take_error(&self) -> io::Result<Option<io::Error>> {
self.socket.take_error()
}
pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
self.socket.set_only_v6(only_v6)
}
pub fn only_v6(&self) -> io::Result<bool> {
self.socket.only_v6()
}
fn simple_connect(&self, address: SocketAddr) -> io::Result<()> {
self.socket.connect(address)
}
pub fn connect<A: ToSocketAddrs>(&self, address: A) -> io::Result<()> {
let mut error = io::Error::from(io::ErrorKind::InvalidInput);
for a in address.to_socket_addrs()? {
match self.simple_connect(a) {
Ok(r) => return Ok(r),
Err(e) => error = e,
}
}
Err(error)
}
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.options.write().unwrap().nonblocking = nonblocking;
Ok(())
}
pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
self.poll.reregister(
&self.socket,
OBJECT_TOKEN,
Ready::readable(),
PollOpt::level(),
)?;
let mut events = self.events.borrow_mut();
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut events, Some(Duration::from_millis(0)))?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_readable() {
return self.socket.recv(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let read_timeout = self.options.read().unwrap().read_timeout;
loop {
self.poll.poll(&mut events, read_timeout)?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_readable() {
return self.socket.recv(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if read_timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
self.poll.reregister(
&self.socket,
OBJECT_TOKEN,
Ready::writable(),
PollOpt::level(),
)?;
let mut events = self.events.borrow_mut();
if self.options.read().unwrap().nonblocking {
self.poll
.poll(&mut events, Some(Duration::from_millis(0)))?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.socket.send(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let write_timeout = self.options.read().unwrap().write_timeout;
loop {
self.poll.poll(&mut events, write_timeout)?;
for event in events.iter() {
let t = event.token();
if t == OBJECT_TOKEN {
if event.readiness().is_writable() {
return self.socket.send(buf);
}
} else if t == STOP_TOKEN {
return Err(cancelled_error());
}
}
if write_timeout.is_some() {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
}
}
}
impl Debug for UdpSocket {
    /// Formats exactly like the wrapped `mio::net::UdpSocket`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        Debug::fmt(&self.socket, f)
    }
}
// Loopback tests covering cancellation, nonblocking mode and basic TCP/UDP traffic.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Barrier};
    use std::thread;

    // `is_cancelled` must accept only errors built by `cancelled_error`.
    #[test]
    fn test_is_cancelled() {
        assert_eq!(
            is_cancelled(&io::Error::new(io::ErrorKind::Interrupted, "")),
            false
        );
        assert_eq!(
            is_cancelled(&io::Error::new(io::ErrorKind::Other, "")),
            false
        );
        assert_eq!(is_cancelled(&cancelled_error()), true);
    }

    // Three sequential request/response exchanges through `incoming`, after which the
    // listener is cancelled and its accept loop must exit cleanly.
    #[test]
    fn test_simple_connection() {
        let (listener, listener_canceller) = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = thread::spawn(move || {
            for r in listener.incoming() {
                match r {
                    Ok((mut stream, _canceller, _addr)) => {
                        // Each connection is served on its own thread; dropping `stream`
                        // at the end closes it, which ends the client's `read_to_end`.
                        thread::spawn(move || {
                            let mut buf = [0; 3];
                            stream.read_exact(&mut buf).unwrap();
                            assert_eq!(&buf, b"foo");
                            stream.write_all(b"bar").unwrap();
                        });
                    }
                    // Cancellation is the expected way to stop the accept loop.
                    Err(ref e) if is_cancelled(e) => break,
                    Err(ref e) => panic!("{:?}", e),
                }
            }
        });
        for _ in 0..3 {
            let (mut stream, _stream_canceller) = TcpStream::connect(&addr).unwrap();
            stream.write_all(b"foo").unwrap();
            stream.flush().unwrap();
            let mut buf = Vec::new();
            assert_eq!(stream.read_to_end(&mut buf).unwrap(), 3);
            assert_eq!(&buf[..], b"bar");
        }
        listener_canceller.cancel().unwrap();
        handle.join().unwrap();
    }

    // A read blocked on a silent connection must return a cancellation error once the
    // stream's canceller fires; the server then sees EOF when the client drops its stream.
    #[test]
    fn test_cancel_stream() {
        let (listener, _listener_canceller) = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let server = thread::spawn(move || {
            let (mut stream, _canceller, _addr) = listener.accept().unwrap();
            let mut buf = [0; 16];
            assert_eq!(stream.read(&mut buf).unwrap(), 0);
        });
        let (mut stream, stream_canceller) = TcpStream::connect(&addr).unwrap();
        let client = thread::spawn(move || {
            let mut buf = [0; 16];
            assert!(is_cancelled(&stream.read(&mut buf).unwrap_err()));
        });
        // Presumably gives the client thread time to block inside `read` before cancelling.
        thread::sleep(Duration::from_secs(1));
        stream_canceller.cancel().unwrap();
        client.join().unwrap();
        server.join().unwrap();
    }

    // Nonblocking client read: `WouldBlock` before any data, data after the server writes,
    // and a cancellation error after cancel. The three-party barrier (main, server, client)
    // sequences four phases:
    //   1. connection accepted / first read has returned `WouldBlock`
    //   2. server has written "foo"
    //   3. client has read "foo"
    //   4. main has cancelled the client stream
    #[test]
    fn test_non_blocking() {
        let barrier = Arc::new(Barrier::new(3));
        let barrier_server = barrier.clone();
        let (listener, _listener_canceller) = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let server = thread::spawn(move || {
            let (mut stream, _canceller, _addr) = listener.accept().unwrap();
            barrier_server.wait();
            stream.write(b"foo").unwrap();
            barrier_server.wait();
            barrier_server.wait();
            barrier_server.wait();
            // The client drops its stream after the cancelled read, so this sees EOF.
            let mut buf = [0; 16];
            assert_eq!(stream.read(&mut buf).unwrap(), 0);
        });
        let barrier_client = barrier.clone();
        let (mut stream, stream_canceller) = TcpStream::connect(&addr).unwrap();
        stream.set_nonblocking(true).unwrap();
        let client = thread::spawn(move || {
            let mut buf = [0; 3];
            assert_eq!(
                stream.read(&mut buf).unwrap_err().kind(),
                io::ErrorKind::WouldBlock
            );
            barrier_client.wait();
            barrier_client.wait();
            // NOTE(review): assumes the loopback write from phase 2 has arrived by now;
            // a nonblocking `read_exact` fails with `WouldBlock` otherwise.
            stream.read_exact(&mut buf).unwrap();
            assert_eq!(&buf, b"foo");
            barrier_client.wait();
            barrier_client.wait();
            assert!(is_cancelled(&stream.read(&mut buf).unwrap_err()));
        });
        barrier.wait();
        barrier.wait();
        barrier.wait();
        stream_canceller.cancel().unwrap();
        barrier.wait();
        client.join().unwrap();
        server.join().unwrap();
    }

    // Two UDP sockets exchange one datagram each way, then both are cancelled while
    // blocked in `recv_from`.
    #[test]
    fn test_udp() {
        let barrier = Arc::new(Barrier::new(3));
        let (socket1, canceller1) = UdpSocket::bind("127.0.0.1:0").unwrap();
        let (socket2, canceller2) = UdpSocket::bind("127.0.0.1:0").unwrap();
        let address1 = socket1.local_addr().unwrap();
        let address2 = socket2.local_addr().unwrap();
        let barrier1 = barrier.clone();
        let barrier2 = barrier.clone();
        let thread1 = thread::spawn(move || {
            let mut buf = [0; 16];
            assert_eq!(socket1.recv_from(&mut buf).unwrap(), (3, address2));
            assert_eq!(socket1.send_to(b"bar", &address2).unwrap(), 3);
            barrier1.wait();
            assert!(is_cancelled(&socket1.recv_from(&mut buf).unwrap_err()));
        });
        let thread2 = thread::spawn(move || {
            assert_eq!(socket2.send_to(b"foo", &address1).unwrap(), 3);
            let mut buf = [0; 16];
            assert_eq!(socket2.recv_from(&mut buf).unwrap(), (3, address1));
            barrier2.wait();
            assert!(is_cancelled(&socket2.recv_from(&mut buf).unwrap_err()));
        });
        // Once both exchanges are done, cancel both sockets' next receive.
        barrier.wait();
        canceller1.cancel().unwrap();
        canceller2.cancel().unwrap();
        thread1.join().unwrap();
        thread2.join().unwrap();
    }
}