use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use futures::{Async, Future, Poll, Stream};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_core::net::TcpStream as TokioTcpStream;
use tokio_core::reactor::Handle;
use BufClientStreamHandle;
use tcp::TcpStream;
use client::ClientStreamHandle;
/// Client-side TCP stream to a name server.
///
/// Wraps a `TcpStream<S>`; its `Stream` implementation yields each received
/// message as a `Vec<u8>`, dropping the source address reported by the transport.
#[must_use = "futures do nothing unless polled"]
pub struct TcpClientStream<S> {
    // Underlying transport; it yields `(buffer, source address)` pairs.
    tcp_stream: TcpStream<S>,
}
impl TcpClientStream<TokioTcpStream> {
    /// Starts a TCP connection to `name_server` with a five-second connect timeout.
    ///
    /// Shorthand for `with_timeout(name_server, loop_handle, Duration::from_secs(5))`.
    pub fn new(name_server: SocketAddr,
               loop_handle: &Handle)
               -> (Box<Future<Item = TcpClientStream<TokioTcpStream>, Error = io::Error>>,
                   Box<ClientStreamHandle>) {
        let default_timeout = Duration::from_secs(5);
        Self::with_timeout(name_server, loop_handle, default_timeout)
    }

    /// Starts a TCP connection to `name_server` on `loop_handle`, bounded by `timeout`.
    ///
    /// Returns a pair of:
    /// * a future that resolves to the connected `TcpClientStream` (or an `io::Error`);
    /// * a boxed `ClientStreamHandle` used to send messages addressed to `name_server`.
    pub fn with_timeout(name_server: SocketAddr,
                        loop_handle: &Handle,
                        timeout: Duration)
                        -> (Box<Future<Item = TcpClientStream<TokioTcpStream>, Error = io::Error>>,
                            Box<ClientStreamHandle>) {
        let (connect, raw_sender) = TcpStream::with_timeout(name_server, loop_handle, timeout);

        // The handle pairs the raw sender with the destination it talks to.
        let handle: Box<ClientStreamHandle> = Box::new(BufClientStreamHandle {
            name_server: name_server,
            sender: raw_sender,
        });

        // Once connected, wrap the transport in the client-facing stream type.
        let client_future: Box<Future<Item = TcpClientStream<TokioTcpStream>,
                                      Error = io::Error>> =
            Box::new(connect.map(|tcp_stream| TcpClientStream { tcp_stream: tcp_stream }));

        (client_future, handle)
    }
}
impl<S> TcpClientStream<S> {
    /// Wraps an existing `TcpStream<S>` in a `TcpClientStream` without any further setup.
    pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
        TcpClientStream { tcp_stream: tcp_stream }
    }
}
/// Yields each message buffer received from the underlying `TcpStream`.
///
/// A message whose reported source address differs from the connection's peer
/// address is logged with `warn!` but still yielded. Not-ready and error results
/// from the transport are propagated unchanged (via `try_ready!`), as is end-of-stream.
impl<S: AsyncRead + AsyncWrite> Stream for TcpClientStream<S> {
    type Item = Vec<u8>;
    type Error = io::Error;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        let (buffer, src_addr) = match try_ready!(self.tcp_stream.poll()) {
            Some(message) => message,
            None => return Ok(Async::Ready(None)),
        };

        let peer = self.tcp_stream.peer_addr();
        if src_addr != peer {
            warn!("{} does not match name_server: {}", src_addr, peer)
        }

        Ok(Async::Ready(Some(buffer)))
    }
}
#[cfg(test)]
use std::net::{IpAddr, Ipv4Addr};
#[cfg(not(target_os = "linux"))]
#[cfg(test)]
use std::net::Ipv6Addr;
/// Runs `tcp_client_stream_test` against the IPv4 loopback address 127.0.0.1.
#[test]
fn test_tcp_client_stream_ipv4() {
    let loopback = Ipv4Addr::new(127, 0, 0, 1);
    tcp_client_stream_test(IpAddr::V4(loopback))
}
/// Runs `tcp_client_stream_test` against the IPv6 loopback address `::1`.
// NOTE(review): excluded on Linux, presumably because IPv6 loopback is not
// available on some Linux build hosts — confirm before re-enabling.
#[test]
#[cfg(not(target_os = "linux"))]
fn test_tcp_client_stream_ipv6() {
    let loopback = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
    tcp_client_stream_test(IpAddr::V6(loopback))
}
/// Payload the test client sends and expects to receive back from the echo server.
///
/// Typed as an array of exactly `TEST_BYTES_LEN` bytes so the two constants cannot
/// drift apart: changing the literal's length without updating `TEST_BYTES_LEN`
/// is a compile-time error.
#[cfg(test)]
const TEST_BYTES: &'static [u8; TEST_BYTES_LEN] = b"DEADBEEF";
/// Length of `TEST_BYTES`; the test server checks each received length prefix against it.
#[cfg(test)]
const TEST_BYTES_LEN: usize = 8;
/// Shared body of the IPv4/IPv6 client tests.
///
/// Binds a blocking echo server to an ephemeral port on `server_addr` and connects
/// a `TcpClientStream` to it. It then performs `send_recv_times` round trips of
/// `TEST_BYTES`, asserting that each echoed buffer matches what was sent. The server
/// reads a two-byte big-endian length prefix followed by the payload. It writes both
/// back unchanged.
///
/// # Panics
/// Panics (failing the test) on any I/O error or mismatch; the panic message
/// includes the underlying `io::Error` where one is available.
#[cfg(test)]
fn tcp_client_stream_test(server_addr: IpAddr) {
    use std::io::{Read, Write};
    use tokio_core::reactor::Core;
    use std;

    // Watchdog: reports "timeout" if the exchange has not completed within ~15s.
    // NOTE(review): a panic on a spawned thread only unwinds that thread, so this
    // reports the timeout but does not by itself abort or fail the test.
    let succeeded = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let succeeded_clone = succeeded.clone();
    std::thread::Builder::new()
        .name("thread_killer".to_string())
        .spawn(move || {
            for _ in 0..15 {
                std::thread::sleep(std::time::Duration::from_secs(1));
                if succeeded_clone.load(std::sync::atomic::Ordering::Relaxed) {
                    return;
                }
            }
            panic!("timeout");
        })
        .unwrap();

    // Port 0 lets the OS pick a free port; read back the actual bound address.
    let server = std::net::TcpListener::bind(SocketAddr::new(server_addr, 0)).unwrap();
    let server_addr = server.local_addr().unwrap();

    let send_recv_times = 4;
    let server_handle = std::thread::Builder::new()
        .name("test_tcp_client_stream:server".to_string())
        .spawn(move || {
            let (mut socket, _) = server.accept().expect("accept failed");
            socket
                .set_read_timeout(Some(std::time::Duration::from_secs(5)))
                .unwrap();
            socket
                .set_write_timeout(Some(std::time::Duration::from_secs(5)))
                .unwrap();

            for _ in 0..send_recv_times {
                // Two-byte big-endian length prefix.
                let mut len_bytes = [0_u8; 2];
                socket
                    .read_exact(&mut len_bytes)
                    .expect("SERVER: receive failed");
                let length = ((len_bytes[0] as u16) << 8) | (len_bytes[1] as u16);
                assert_eq!(length as usize, TEST_BYTES_LEN);

                let mut buffer = [0_u8; TEST_BYTES_LEN];
                socket
                    .read_exact(&mut buffer)
                    .expect("SERVER: receive buffer failed");
                assert_eq!(&buffer, TEST_BYTES);

                // Echo the frame back exactly as it was received.
                socket
                    .write_all(&len_bytes)
                    .expect("SERVER: send length failed");
                socket
                    .write_all(&buffer)
                    .expect("SERVER: send buffer failed");
                std::thread::yield_now();
            }
        })
        .unwrap();

    let mut io_loop = Core::new().unwrap();
    let (stream, mut sender) = TcpClientStream::new(server_addr, &io_loop.handle());
    // `expect` directly on the `Result` keeps the `io::Error` in the panic message,
    // unlike `.ok().expect(..)` which discards it.
    let mut stream = io_loop.run(stream).expect("run failed to get stream");

    for _ in 0..send_recv_times {
        sender.send(TEST_BYTES.to_vec()).expect("send failed");
        // `StreamFuture` fails with `(error, stream)`; keep only the error, since the
        // stream type does not implement `Debug`.
        let (buffer, stream_tmp) = io_loop
            .run(stream.into_future())
            .map_err(|(err, _)| err)
            .expect("future iteration run failed");
        stream = stream_tmp;
        let buffer = buffer.expect("no buffer received");
        assert_eq!(&buffer, TEST_BYTES);
    }

    succeeded.store(true, std::sync::atomic::Ordering::Relaxed);
    server_handle.join().expect("server thread failed");
}