use std;
use std::result::Result;
use std::io;
use std::net::SocketAddr;
use std::mem::size_of;
use bytes::{BytesMut, BigEndian, ByteOrder};
use futures::{self, Future};
use tokio_core::reactor::{Remote, Handle};
use tokio_core::net::{TcpListener, TcpStream};
use tokio_io::codec::{Encoder, Decoder};
use slog::Logger;
use serde_json;
use super::{
GameMessage,
WireMessage,
RecvWireMessage,
NewPeer,
};
/// Integer type of the big-endian length header written before every JSON
/// message body; it can represent body lengths up to `u16::MAX` bytes.
type MessageLengthPrefix = u16;

/// Codec framing `WireMessage<G>` values for one TCP peer as a length header
/// followed by a JSON body.
struct Codec<G> {
    /// Logger used to report frames that fail to parse.
    log: Logger,
    /// Address of the remote peer; copied into each decoded message's `src`.
    peer_addr: SocketAddr,
    /// Binds the codec to its game message type without storing a `G`.
    _phantom_game_message: std::marker::PhantomData<G>,
}
impl<G: GameMessage> Encoder for Codec<G> {
type Item = WireMessage<G>;
type Error = io::Error;
fn encode(&mut self, message: WireMessage<G>, buf: &mut BytesMut) -> Result<(), io::Error> {
use bytes::BufMut;
buf.reserve(1024 * 1024);
let length_header_index = buf.len();
buf.put_u16::<BigEndian>(0);
{
let reference = buf.by_ref();
let writer = reference.writer();
serde_json::to_writer(writer, &message).expect("Error encoding message");
}
let message_length = (buf.len() - length_header_index - size_of::<MessageLengthPrefix>()) as u16;
BigEndian::write_u16(&mut buf[length_header_index..], message_length);
Ok(())
}
}
impl<G: GameMessage> Decoder for Codec<G> {
    type Item = RecvWireMessage<G>;
    type Error = io::Error;

    /// Decodes one length-prefixed JSON frame from the front of `buf`.
    ///
    /// Returns `Ok(None)` until a complete frame (header plus body) is
    /// buffered. A complete frame is always removed from `buf`, whether or not
    /// it parses, so the buffer stays aligned on a frame boundary.
    ///
    /// # Errors
    ///
    /// Returns an `Other` error (and logs a warning) when the body is not a
    /// valid `WireMessage<G>`.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<RecvWireMessage<G>>, io::Error> {
        let prefix_length = size_of::<MessageLengthPrefix>();
        if buf.len() < prefix_length {
            return Ok(None);
        }
        let message_length = BigEndian::read_u16(&buf[..prefix_length]) as usize;
        let frame_length = prefix_length + message_length;
        if buf.len() < frame_length {
            return Ok(None);
        }

        // Take the whole frame out before parsing so that a parse failure
        // cannot leave the header consumed but the body still buffered.
        let mut body = buf.split_to(frame_length);
        body.split_to(prefix_length);

        match serde_json::from_slice::<WireMessage<G>>(&body[..]) {
            Ok(message) => Ok(Some(RecvWireMessage {
                src: self.peer_addr,
                message: Result::Ok(message),
            })),
            Err(error) => {
                warn!(
                    self.log,
                    "Got a bad message from peer";
                    "peer_addr" => format!("{:?}", self.peer_addr),
                    "message_length" => message_length,
                    "buffer" => format!("{:?}", body),
                    "error" => format!("{:?}", error)
                );
                Err(io::Error::new(io::ErrorKind::Other, "Couldn't parse message"))
            }
        }
    }
}
pub fn start_tcp_server<G: GameMessage, MaybePort>(
parent_log: &Logger,
recv_system_sender: std::sync::mpsc::Sender<RecvWireMessage<G>>,
send_system_new_peer_sender:
std::sync::mpsc::Sender<NewPeer<G>>,
remote: Remote,
port: MaybePort
) -> u16
where MaybePort: Into<Option<u16>>
{
use futures::Stream;
let (actual_port_tx, actual_port_rx) = std::sync::mpsc::channel::<u16>();
let addr = format!("0.0.0.0:{}", port.into().unwrap_or(0));
let addr = addr.parse::<SocketAddr>().unwrap();
let server_log = parent_log.new(o!());
let server_error_log = server_log.new(o!());
remote.spawn(move |handle| {
let socket = TcpListener::bind(&addr, &handle).expect("Failed to bind server socket");
let actual_addr = socket.local_addr().expect("Socket isn't bound");
info!(server_log, "TCP server listening"; "addr" => format!("{}", actual_addr));
actual_port_tx.send(actual_addr.port()).expect("Receiver hung up");
let cloned_handle = handle.clone();
let f = socket.incoming().for_each(move |(socket, peer_addr)| {
info!(server_log, "New client connected"; "addr" => format!("{}", peer_addr));
handle_tcp_stream(
&cloned_handle,
socket,
peer_addr,
&server_log,
recv_system_sender.clone(),
send_system_new_peer_sender.clone(),
)
}).or_else(move |error| {
info!(server_error_log, "Something broke in listening for connections"; "error" => format!("{}", error));
futures::future::ok(())
});
f
});
actual_port_rx.recv().expect("Sender hung up")
}
/// Connects to the server at `addr` on the reactor behind `remote` and returns
/// the local port of the new connection.
///
/// Once connected, the socket is set up with `handle_tcp_stream`. Connection
/// or stream errors are logged and end the task quietly.
///
/// # Panics
///
/// Panics with "Sender hung up" if the connection attempt fails, because the
/// local port is then never reported back.
pub fn connect_to_server<G: GameMessage>(
    parent_log: &Logger,
    recv_system_sender: std::sync::mpsc::Sender<RecvWireMessage<G>>,
    send_system_new_peer_sender:
        std::sync::mpsc::Sender<NewPeer<G>>,
    remote: Remote,
    addr: SocketAddr,
) -> u16 {
    let client_log = parent_log.new(o!());
    let client_error_log = client_log.new(o!());
    let (local_port_tx, local_port_rx) = std::sync::mpsc::channel::<u16>();

    remote.spawn(move |handle| {
        info!(client_log, "Connecting to server"; "addr" => format!("{}", addr));
        let connecting = TcpStream::connect(&addr, handle);
        let reactor_handle = handle.clone();
        connecting
            .and_then(move |stream| {
                info!(client_log, "Connected!");
                let local_port = stream
                    .local_addr()
                    .expect("Somehow we didn't actually bind a local port?")
                    .port();
                local_port_tx.send(local_port).expect("Receiver hung up?");
                handle_tcp_stream(
                    &reactor_handle,
                    stream,
                    addr,
                    &client_log,
                    recv_system_sender,
                    send_system_new_peer_sender,
                )
            })
            .or_else(move |error| {
                info!(client_error_log, "Something broke in connecting to server, or handling connection"; "error" => format!("{}", error));
                futures::future::ok(())
            })
    });

    local_port_rx.recv().expect("Sender hung up")
}
fn handle_tcp_stream<G: GameMessage>(
handle: &Handle,
socket: TcpStream,
peer_addr: SocketAddr,
parent_log: &Logger,
recv_system_sender: std::sync::mpsc::Sender<RecvWireMessage<G>>,
send_system_new_peer_sender: std::sync::mpsc::Sender<NewPeer<G>>,
) -> Box<Future<Item=(), Error=std::io::Error>> {
use futures::Stream;
use futures::Sink;
use tokio_io::AsyncRead;
let codec = Codec::<G>{
peer_addr: peer_addr,
log: parent_log.new(o!()),
_phantom_game_message: std::marker::PhantomData,
};
let (sink, stream) = socket.framed(codec).split();
let sink_error_log = parent_log.new(o!("peer_addr" => format!("{}", peer_addr)));
let sink = sink.sink_map_err(move |err| {
error!(sink_error_log, "Unexpected error in sending to sink"; "err" => format!("{}", err));
()
});
let (tcp_tx, tcp_rx) = futures::sync::mpsc::channel::<WireMessage<G>>(1000);
let (rtr_tx, rtr_rx) = futures::sync::oneshot::channel::<()>();
let new_peer = NewPeer {
tcp_sender: tcp_tx,
socket_addr: peer_addr,
ready_to_receive_tx: rtr_tx,
};
send_system_new_peer_sender.send(new_peer).expect("Receiver hung up?");
let tx_f = sink.send_all(tcp_rx).map(|_| ());
handle.spawn(tx_f);
let peer_server_log = parent_log.new(o!("peer_addr" => format!("{}", peer_addr)));
let peer_server_error_log = peer_server_log.clone();
let f = rtr_rx.then(|_| {
stream.filter(|recv_wire_message| {
match recv_wire_message.message {
Result::Err(_) => {
println!("Got a bad message from peer");
false
}
_ => true,
}
})
.for_each(move |recv_wire_message| {
trace!(peer_server_log, "Got recv_wire_message"; "recv_wire_message" => format!("{:?}", recv_wire_message));
recv_system_sender.send(recv_wire_message).expect("Receiver hung up?");
futures::future::ok(())
}).or_else(move |error| {
info!(peer_server_error_log, "Peer broke pipe"; "error" => format!("{}", error));
futures::future::ok(())
})
});
Box::new(f)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std;
    use std::thread;
    use futures::{self, Future};
    use tokio_core::reactor::Core;
    use tokio_core::net::TcpStream;
    use tokio_io::io::write_all;
    use slog;
    use bytes::BufMut;

    #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
    struct TestMessage {}

    impl GameMessage for TestMessage{}

    /// Upper bound on how long to wait for the server thread to accept a
    /// connection and announce it as a new peer.
    fn accept_timeout() -> std::time::Duration {
        std::time::Duration::from_millis(1000)
    }

    /// Spawns a thread running an otherwise idle reactor for the server under
    /// test and returns a `Remote` for scheduling work on it.
    fn spawn_server_reactor() -> Remote {
        let (remote_tx, remote_rx) = std::sync::mpsc::channel::<Remote>();
        thread::Builder::new()
            .name("tcp_server".to_string())
            .spawn(move || {
                let mut reactor = Core::new().expect("Failed to create reactor for network server");
                remote_tx.send(reactor.remote()).expect("Receiver hung up");
                reactor.run(futures::future::empty::<(), ()>()).expect("Network server reactor failed");
            }).expect("Failed to spawn server thread");
        remote_rx.recv().expect("Sender hung up")
    }

    /// Root logger that discards every record.
    fn test_logger() -> slog::Logger {
        let drain = slog::Discard;
        slog::Logger::root(drain, o!("pk_version" => env!("CARGO_PKG_VERSION")))
    }

    #[test]
    fn receive_corrupt_message() {
        let remote = spawn_server_reactor();
        let log = test_logger();
        let (tx, rx) = std::sync::mpsc::channel::<RecvWireMessage<TestMessage>>();
        let (new_peer_tx, new_peer_rx) = std::sync::mpsc::channel::<NewPeer<TestMessage>>();
        let server_port = start_tcp_server(&log, tx, new_peer_tx, remote, None);
        let connect_addr = format!("127.0.0.1:{}", server_port);
        let connect_addr: SocketAddr = connect_addr.parse().unwrap();
        let mut reactor = Core::new().expect("Failed to create reactor");
        let handle = reactor.handle();
        let socket_future = TcpStream::connect(&connect_addr, &handle);
        let mut buf = BytesMut::with_capacity(1000);
        let mut buf2 = BytesMut::with_capacity(1000);
        let f = socket_future.and_then(|tcp_stream| {
            let message = b"\"hello\"";
            buf.put_u16::<BigEndian>(message.len() as u16);
            buf.put_slice(message);
            write_all(tcp_stream, &mut buf)
        }).and_then(|stream_and_buffer| {
            let tcp_stream = stream_and_buffer.0;
            let message = b"{\"Game\":{}}";
            buf2.put_u16::<BigEndian>(message.len() as u16);
            buf2.put_slice(message);
            write_all(tcp_stream, &mut buf2)
        });
        reactor.run(f).expect("Test reactor failed");
        // The server does not read from a peer until it is told the peer is
        // ready. Without this the test would pass without ever decoding the
        // corrupt frame.
        let new_peer = new_peer_rx.recv_timeout(accept_timeout()).expect("Should've been a new peer connected");
        new_peer.ready_to_receive_tx.send(()).expect("Receiver hung up?");
        std::thread::sleep(std::time::Duration::from_millis(100));
        // The corrupt first frame ends the connection, so the valid second
        // frame must never be delivered either.
        assert_eq!(rx.try_recv(), Err(std::sync::mpsc::TryRecvError::Empty));
    }

    #[test]
    fn receive_two_messages_in_one_segment() {
        let remote = spawn_server_reactor();
        let log = test_logger();
        let (tx, rx) = std::sync::mpsc::channel::<RecvWireMessage<TestMessage>>();
        let (new_peer_tx, new_peer_rx) = std::sync::mpsc::channel::<NewPeer<TestMessage>>();
        let server_port = start_tcp_server(&log, tx, new_peer_tx, remote, None);
        let connect_addr = format!("127.0.0.1:{}", server_port);
        let connect_addr: SocketAddr = connect_addr.parse().unwrap();
        let mut reactor = Core::new().expect("Failed to create reactor");
        let handle = reactor.handle();
        let socket_future = TcpStream::connect(&connect_addr, &handle);
        // Wait for the accept on the server thread instead of sleeping for a
        // fixed time and hoping it already happened.
        let new_peer = new_peer_rx.recv_timeout(accept_timeout()).expect("Should've been a new peer connected");
        new_peer.ready_to_receive_tx.send(()).expect("Receiver hung up?");
        let mut buf = BytesMut::with_capacity(1000);
        let f = socket_future.and_then(|tcp_stream| {
            let message = b"{\"Game\":{}}";
            buf.put_u16::<BigEndian>(message.len() as u16);
            buf.put_slice(message);
            buf.put_u16::<BigEndian>(message.len() as u16);
            buf.put_slice(message);
            write_all(tcp_stream, &mut buf)
        });
        reactor.run(f).expect("Test reactor failed");
        let blink = std::time::Duration::from_millis(100);
        std::thread::sleep(blink);
        let recv_wire_message = rx.recv_timeout(blink).expect("Should have found our first message on the channel");
        assert_eq!(recv_wire_message.message, Ok(WireMessage::Game(TestMessage{})));
        let recv_wire_message = rx.recv_timeout(blink).expect("Should have found our second message on the channel");
        assert_eq!(recv_wire_message.message, Ok(WireMessage::Game(TestMessage{})));
        assert_eq!(rx.try_recv(), Err(std::sync::mpsc::TryRecvError::Empty));
    }
}