use std::collections::VecDeque;
use std::io;
use std::net::TcpStream;
use std::sync::{Mutex, Condvar, Arc, mpsc};
use std::sync::atomic::{Ordering, AtomicBool};
use std::fmt;
use std::error;
use std::time::{Instant, Duration};
use std::thread;
use std::io::{Read, Write};
use rand;
use maidsafe_utilities::thread::RaiiThreadJoiner;
use socket_addr::SocketAddr;
use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian};
use nat_traversal::{MappingContext, tcp_punch_hole};
use nat_traversal;
use w_result::{WResult, WOk, WErr};
use crossbeam;
use void::Void;
use utils::DisplaySlice;
use endpoint::{Endpoint, ToEndpoints};
use rendezvous_info::{PubRendezvousInfo, PrivRendezvousInfo, PrivTcpInfo, PrivUdpInfo,
RENDEZVOUS_INFO_EXPIRY_DURATION_SECS};
/// Protocol-specific description of a `Stream`'s transport, as returned inside
/// `StreamInfo` by `Stream::info`.
pub enum StreamProtocolInfo {
/// A TCP connection.
Tcp {
/// Address of our end of the connection.
local_addr: SocketAddr,
/// Address of the remote end of the connection.
peer_addr: SocketAddr,
},
/// A uTP connection.
/// NOTE(review): nothing in this file constructs this variant (`Stream::info` only
/// produces `Tcp`).
Utp {
local_addr: SocketAddr,
peer_addr: SocketAddr,
},
}
/// Summary of a `Stream`: its transport details plus the connection id it was created
/// with. Displayed as `#<16 hex digits> [<protocol info>]`.
pub struct StreamInfo {
/// Transport protocol and local/peer addresses.
pub protocol: StreamProtocolInfo,
/// Connection id the stream was created with (see `Stream::from_tcp_stream`).
pub connection_id: u64,
}
impl fmt::Display for StreamInfo {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "#{:016x} [{}]", self.connection_id, self.protocol)
}
}
impl fmt::Display for StreamProtocolInfo {
    /// Renders `[<proto> <local> -> <peer>]`, where `<proto>` is `tcp` or `utp`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (proto_name, local, peer) = match *self {
            StreamProtocolInfo::Tcp { local_addr, peer_addr } => ("tcp", local_addr, peer_addr),
            StreamProtocolInfo::Utp { local_addr, peer_addr } => ("utp", local_addr, peer_addr),
        };
        write!(f, "[{} {:>015} -> {:<015}]", proto_name, local, peer)
    }
}
/// State shared between a `Stream` and its background tcp writer thread; always accessed
/// through `Buffer::inner`.
struct BufferInner {
/// Chunks queued by `Write::write`; the writer thread pops and writes them in order.
buf: VecDeque<Vec<u8>>,
/// Cleared by `Stream`'s `Drop`. The writer thread exits once this is false AND `buf`
/// is empty, so already-queued data is still written.
running: bool,
/// Write error recorded by the writer thread; returned by the next `write`/`flush`.
error: Option<io::Error>,
}
/// `BufferInner` behind a mutex, plus the condvar the writer thread sleeps on while the
/// queue is empty.
struct Buffer {
inner: Mutex<BufferInner>,
condvar: Condvar,
}
/// Underlying transport of a `Stream`. Only TCP is implemented in this file.
enum StreamInner {
Tcp {
/// Handle used for reads; writes go through a cloned handle owned by the writer thread.
stream: TcpStream,
local_addr: SocketAddr,
peer_addr: SocketAddr,
},
}
/// A connected byte stream. Reads go straight to the socket; writes are queued in
/// `buffer` and performed asynchronously by a dedicated writer thread.
pub struct Stream {
protocol_inner: StreamInner,
buffer: Arc<Buffer>,
/// Handle to the writer thread; presumably joins it when dropped (RAII joiner) — confirm
/// against `maidsafe_utilities`.
_writer_thread: RaiiThreadJoiner,
connection_id: u64,
}
// Errors from `Stream::from_tcp_stream` while wrapping an already-connected `TcpStream`.
quick_error! {
#[derive(Debug)]
pub enum StreamFromTcpStreamError {
// `TcpStream::local_addr` failed.
LocalAddr(err: io::Error) {
description("Error getting local address of tcp stream")
display("Error getting local address of tcp stream: {}", err)
cause(err)
}
// `TcpStream::peer_addr` failed.
PeerAddr(err: io::Error) {
description("Error getting peer address of tcp stream")
display("Error getting peer address of tcp stream: {}", err)
cause(err)
}
// `TcpStream::try_clone` failed; the clone is needed for the writer thread.
CloneStream(err: io::Error) {
description("Error cloning tcp stream")
display("Error cloning tcp stream: {}", err)
cause(err)
}
}
}
/// Failure of `Stream::direct_connect`. `E` is the endpoint-parsing error type of the
/// `ToEndpoints` input.
#[derive(Debug)]
pub enum StreamDirectConnectError<E: error::Error + Send + 'static> {
/// Every endpoint attempt reported an error (one entry per endpoint; empty if no
/// endpoints were supplied).
AllConnectionsFailed(Vec<StreamDirectConnectEndpointError<E>>),
/// The deadline passed before any attempt succeeded or all attempts failed.
TimedOut,
}
impl<E: error::Error + Send + 'static> error::Error for StreamDirectConnectError<E> {
    /// Static one-line summary of the failure kind.
    fn description(&self) -> &str {
        match *self {
            StreamDirectConnectError::TimedOut => "Direct connect timed out.",
            StreamDirectConnectError::AllConnectionsFailed(..) => "All connection attempts failed.",
        }
    }

    /// For `AllConnectionsFailed`, exposes the first per-endpoint error (if any) as the
    /// cause; a timeout has no cause.
    fn cause(&self) -> Option<&error::Error> {
        match *self {
            StreamDirectConnectError::TimedOut => None,
            StreamDirectConnectError::AllConnectionsFailed(ref failures) => {
                if let Some(first_failure) = failures.first() {
                    return Some(first_failure);
                }
                None
            },
        }
    }
}
impl<E: error::Error + Send + 'static> fmt::Display for StreamDirectConnectError<E> {
    /// Human-readable message; for `AllConnectionsFailed` every per-endpoint error is listed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            StreamDirectConnectError::TimedOut => write!(f, "Direct connect timed out"),
            StreamDirectConnectError::AllConnectionsFailed(ref failures) => {
                let listing = DisplaySlice("errors", failures);
                write!(f, "All connection attempts failed: {}", listing)
            },
        }
    }
}
/// Why a single endpoint attempt inside `Stream::direct_connect` failed. Every variant
/// carries the connection id that was being used for the attempt.
#[derive(Debug)]
pub enum StreamDirectConnectEndpointError<E: error::Error + Send + 'static> {
/// The endpoint itself could not be produced by `ToEndpoints`.
ParseEndpoint {
err: E,
connection_id: u64,
},
/// `TcpStream::connect` to `addr` failed.
TcpConnect {
err: io::Error,
addr: SocketAddr,
connection_id: u64,
},
/// Writing our connection id during the handshake failed.
TcpWrite {
err: io::Error,
addr: SocketAddr,
connection_id: u64,
},
/// Reading the peer's connection id during the handshake failed (including read timeout).
TcpRead {
err: io::Error,
addr: SocketAddr,
connection_id: u64,
},
/// Setting or clearing the handshake read timeout failed.
TcpSetTimeout {
err: io::Error,
addr: SocketAddr,
connection_id: u64,
},
/// The handshake succeeded but wrapping the socket in a `Stream` failed.
TcpCreateStream {
err: StreamFromTcpStreamError,
connection_id: u64,
},
/// The peer answered with a different connection id than we sent.
HandshakeError {
connection_id: u64
},
}
impl<E: error::Error + Send + 'static> error::Error for StreamDirectConnectEndpointError<E> {
fn cause(&self) -> Option<&error::Error> {
match *self {
StreamDirectConnectEndpointError::ParseEndpoint { ref err, .. } => Some(err),
StreamDirectConnectEndpointError::TcpConnect { ref err, .. } => Some(err),
StreamDirectConnectEndpointError::TcpWrite { ref err, .. } => Some(err),
StreamDirectConnectEndpointError::TcpCreateStream { ref err, .. } => Some(err),
StreamDirectConnectEndpointError::TcpRead { ref err, .. } => Some(err),
StreamDirectConnectEndpointError::TcpSetTimeout { ref err, .. } => Some(err),
StreamDirectConnectEndpointError::HandshakeError { .. } => None,
}
}
fn description(&self) -> &str {
match *self {
StreamDirectConnectEndpointError::ParseEndpoint { .. }
=> "Error parsing endpoint",
StreamDirectConnectEndpointError::TcpConnect { .. }
=> "Error connecting to tcp endpoint",
StreamDirectConnectEndpointError::TcpWrite { .. }
=> "Error writing to tcp stream",
StreamDirectConnectEndpointError::TcpCreateStream { .. }
=> "Error creating Stream from TcpStream",
StreamDirectConnectEndpointError::TcpRead { .. }
=> "Error reading from tcp stream",
StreamDirectConnectEndpointError::TcpSetTimeout { .. }
=> "Error setting timeout option on tcp stream",
StreamDirectConnectEndpointError::HandshakeError { .. }
=> "Handshake failed.",
}
}
}
impl<E: error::Error + Send + 'static> fmt::Display for StreamDirectConnectEndpointError<E> {
    /// Formats the failed step with the connection id (`#` + 16 hex digits) and, for
    /// socket-level failures, the remote address and the underlying io error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            StreamDirectConnectEndpointError::HandshakeError { connection_id: id }
                => write!(f, "Handshake failed: #{:016x}", id),
            StreamDirectConnectEndpointError::ParseEndpoint { ref err, connection_id: id }
                => write!(f, "Error parsing endpoint: #{:016x} {}", id, err),
            StreamDirectConnectEndpointError::TcpCreateStream { ref err, connection_id: id }
                => write!(f, "Error creating Stream from TcpStream: #{:016x} {}", id, err),
            StreamDirectConnectEndpointError::TcpConnect { ref err, addr, connection_id: id }
                => write!(f, "Error connecting to tcp endpoint: {} #{:016x}: {}", addr, id, err),
            StreamDirectConnectEndpointError::TcpWrite { ref err, addr, connection_id: id }
                => write!(f, "Error writing to tcp stream: {} #{:016x}: {}", addr, id, err),
            StreamDirectConnectEndpointError::TcpRead { ref err, addr, connection_id: id }
                => write!(f, "Error reading from tcp stream: {} #{:016x}: {}", addr, id, err),
            StreamDirectConnectEndpointError::TcpSetTimeout { ref err, addr, connection_id: id }
                => write!(f, "Error setting timeout option on tcp stream: {} #{:016x}: {}", addr, id, err),
        }
    }
}
// Non-fatal warnings returned alongside a successful `Stream::rendezvous_connect`.
quick_error! {
#[derive(Debug)]
pub enum StreamRendezvousConnectWarning {
// A warning reported by `nat_traversal::tcp_punch_hole`.
TcpPunchHole(w: nat_traversal::TcpPunchHoleWarning) {
description("Warning raised when doing tcp hole punching.")
display("Warning raised when doing tcp hole punching: {}", w)
cause(w)
}
}
}
// Errors from the tcp hole-punching half of `Stream::rendezvous_connect`.
quick_error! {
    #[derive(Debug)]
    pub enum StreamRendezvousConnectTcpError {
        // Hole punching succeeded but wrapping the socket in a `Stream` failed.
        CreateStream(err: StreamFromTcpStreamError) {
            description("Error creating tcp stream")
            display("Error creating tcp stream: {}", err)
            cause(err)
        }
        // `nat_traversal::tcp_punch_hole` failed. The display message includes the
        // underlying error (like every other error type here) so it is not lost when the
        // error is logged or shown.
        PunchHole(err: nat_traversal::TcpPunchHoleError) {
            description("Error doing tcp hole punching")
            display("Error doing tcp hole punching: {}", err)
            cause(err)
        }
    }
}
/// Diagnostics from `Stream::gen_rendezvous_info`: the outcome of creating the mapped TCP
/// and UDP sockets. A `WErr` means that protocol is simply absent from the generated
/// rendezvous info; it does not fail the whole call.
#[derive(Debug)]
pub struct StreamGenRendezvousInfoDiagnostics {
/// Warnings (on success) or the error from `nat_traversal::MappedTcpSocket::new`.
pub tcp_diags: WResult<(), nat_traversal::MappedTcpSocketMapWarning,
nat_traversal::MappedTcpSocketNewError>,
/// Warnings (on success) or the error from `nat_traversal::MappedUdpSocket::new`.
pub udp_diags: WResult<(), nat_traversal::MappedUdpSocketMapWarning,
nat_traversal::MappedUdpSocketNewError>,
}
impl fmt::Display for StreamGenRendezvousInfoDiagnostics {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
try!(write!(f, "gen_rendezvous_info diagnostic info:"));
match self.tcp_diags {
WOk((), ref ws) => try!(write!(f, " tcp {}.", DisplaySlice("warning", &ws[..]))),
WErr(ref e) => try!(write!(f, " tcp error: {}.", e)),
};
match self.udp_diags {
WOk((), ref ws) => try!(write!(f, " udp {}.", DisplaySlice("warning", &ws[..]))),
WErr(ref e) => try!(write!(f, " udp error: {}.", e)),
};
Ok(())
}
}
/// Failure of `Stream::rendezvous_connect`.
#[derive(Debug)]
pub enum StreamRendezvousConnectError {
/// Our private rendezvous info is older than `RENDEZVOUS_INFO_EXPIRY_DURATION_SECS`.
Expired,
/// Neither the direct connection nor the tcp rendezvous produced a stream. A field is
/// `None` when that method was not attempted (no static endpoints / no compatible tcp
/// rendezvous info).
AllProtocolsFailed {
tcp_err: Option<StreamRendezvousConnectTcpError>,
direct_err: Option<StreamDirectConnectError<Void>>,
},
}
impl error::Error for StreamRendezvousConnectError {
    /// The direct-connect error takes precedence as the cause; the tcp error is used only
    /// when there is no direct-connect error.
    fn cause(&self) -> Option<&error::Error> {
        let (tcp_err, direct_err) = match *self {
            StreamRendezvousConnectError::Expired => return None,
            StreamRendezvousConnectError::AllProtocolsFailed { ref tcp_err, ref direct_err }
                => (tcp_err, direct_err),
        };
        if let Some(ref direct) = *direct_err {
            return Some(direct);
        }
        if let Some(ref tcp) = *tcp_err {
            return Some(tcp);
        }
        None
    }

    /// Static summary that depends on which connection methods were attempted.
    fn description(&self) -> &str {
        let (tcp_err, direct_err) = match *self {
            StreamRendezvousConnectError::Expired => return "The supplied rendezvous info has expired",
            StreamRendezvousConnectError::AllProtocolsFailed { ref tcp_err, ref direct_err }
                => (tcp_err, direct_err),
        };
        match (direct_err.is_some(), tcp_err.is_some()) {
            (true, true) => "Error making direct connection and tcp rendezvous connection",
            (true, false) => "Error making direct connection",
            (false, true) => "Error making tcp rendezvous connection",
            (false, false) => "Could not attempt a tcp rendezvous connection due to incompatible \
                               rendezvous infos",
        }
    }
}
impl fmt::Display for StreamRendezvousConnectError {
    /// Human-readable message including whichever underlying errors are present.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (tcp_err, direct_err) = match *self {
            StreamRendezvousConnectError::Expired => {
                return write!(f, "The supplied rendezvous info has expired");
            },
            StreamRendezvousConnectError::AllProtocolsFailed { ref tcp_err, ref direct_err }
                => (tcp_err, direct_err),
        };
        match (direct_err.as_ref(), tcp_err.as_ref()) {
            (Some(direct), Some(tcp)) => write!(f, "Error making direct connection and tcp rendezvous \
                                                   connection. Direct connect error: {}; tcp rendezvous \
                                                   connect error: {}", direct, tcp),
            (Some(direct), None) => write!(f, "Error making direct connection: {}", direct),
            (None, Some(tcp)) => write!(f, "Error making tcp rendezvous connection: {}", tcp),
            (None, None) => write!(f, "Could not attempt a tcp rendezvous connection due to \
                                       incompatible rendezvous infos"),
        }
    }
}
impl Stream {
pub fn info(&self) -> StreamInfo {
match self.protocol_inner {
StreamInner::Tcp {
local_addr,
peer_addr,
..
} => StreamInfo {
protocol: StreamProtocolInfo::Tcp {
local_addr: local_addr,
peer_addr: peer_addr,
},
connection_id: self.connection_id,
}
}
}
pub fn gen_rendezvous_info(mc: &MappingContext, deadline: Instant)
-> (PrivRendezvousInfo, PubRendezvousInfo, StreamGenRendezvousInfoDiagnostics)
{
crossbeam::scope(|scope| {
let (tcp_result_tx, tcp_result_rx) = mpsc::channel();
let _ = scope.spawn(move || {
let res = nat_traversal::MappedTcpSocket::new(mc, deadline);
let _ = tcp_result_tx.send(res);
});
let (udp_result_tx, udp_result_rx) = mpsc::channel();
let _ = scope.spawn(move || {
let res = nat_traversal::MappedUdpSocket::new(mc, deadline);
let _ = udp_result_tx.send(res);
});
let mut priv_tcp_info_opt = None;
let mut pub_tcp_info_opt = None;
let tcp_diags = unwrap_result!(tcp_result_rx.recv()).map(|mapped_tcp_socket| {
let tcp_endpoints = mapped_tcp_socket.endpoints;
let (priv_tcp, pub_tcp) = nat_traversal::gen_rendezvous_info(tcp_endpoints);
priv_tcp_info_opt = Some(PrivTcpInfo {
socket: mapped_tcp_socket.socket,
info: priv_tcp,
});
pub_tcp_info_opt = Some(pub_tcp);
});
let mut priv_udp_info_opt = None;
let mut pub_udp_info_opt = None;
let udp_diags = unwrap_result!(udp_result_rx.recv()).map(|mapped_udp_socket| {
let udp_endpoints = mapped_udp_socket.endpoints;
let (priv_udp, pub_udp) = nat_traversal::gen_rendezvous_info(udp_endpoints);
priv_udp_info_opt = Some(PrivUdpInfo {
socket: mapped_udp_socket.socket,
info: priv_udp,
});
pub_udp_info_opt = Some(pub_udp);
});
let connection_id_half = rand::random();
let priv_info = PrivRendezvousInfo {
priv_tcp_info: priv_tcp_info_opt,
priv_udp_info: priv_udp_info_opt,
connection_id_half: connection_id_half,
creation_time: Instant::now(),
};
let pub_info = PubRendezvousInfo {
pub_tcp_info: pub_tcp_info_opt,
pub_udp_info: pub_udp_info_opt,
connection_id_half: connection_id_half,
static_endpoints: Vec::new(),
};
let diagnostics = StreamGenRendezvousInfoDiagnostics {
tcp_diags: tcp_diags,
udp_diags: udp_diags,
};
(priv_info, pub_info, diagnostics)
})
}
pub fn rendezvous_connect(our_priv_info: PrivRendezvousInfo,
their_pub_info: PubRendezvousInfo,
deadline: Instant)
-> WResult<Stream, StreamRendezvousConnectWarning, StreamRendezvousConnectError>
{
let conn_id_part_a = our_priv_info.connection_id_half;
let conn_id_part_b = their_pub_info.connection_id_half;
let connection_id = conn_id_part_a.wrapping_add(conn_id_part_b);
let static_endpoints = their_pub_info.static_endpoints;
if (Instant::now() - our_priv_info.creation_time) >
Duration::from_secs(RENDEZVOUS_INFO_EXPIRY_DURATION_SECS)
{
return WErr(StreamRendezvousConnectError::Expired);
}
let (direct_result_tx, direct_result_rx) = mpsc::channel();
if static_endpoints.len() > 0 {
let _ = thread!("Stream::rendezvous_connect direct connect", move || {
let result = Stream::direct_connect_inner(conn_id_part_b,
&static_endpoints[..],
deadline);
let _ = direct_result_tx.send(result);
});
}
else {
drop(direct_result_tx);
}
let (tcp_result_tx, tcp_result_rx) = mpsc::channel();
match (our_priv_info.priv_tcp_info, their_pub_info.pub_tcp_info) {
(Some(our_info), Some(their_info)) => {
let _ = thread!("Stream::rendezvous_connect tcp connect", move || {
match tcp_punch_hole(our_info.socket, our_info.info, their_info, deadline) {
WOk(tcp_stream, ws) => {
let ws = ws.into_iter()
.map(|w| StreamRendezvousConnectWarning::TcpPunchHole(w))
.collect();
let stream = match Stream::from_tcp_stream(tcp_stream, connection_id) {
Ok(stream) => stream,
Err(e) => {
let _ = tcp_result_tx.send(WErr(StreamRendezvousConnectTcpError::CreateStream(e)));
return;
},
};
let _ = tcp_result_tx.send(WOk(stream, ws));
},
WErr(e) => {
let _ = tcp_result_tx.send(WErr(StreamRendezvousConnectTcpError::PunchHole(e)));
},
};
});
},
_ => drop(tcp_result_tx),
};
trace!("Stream::rendezvous_connect waiting for direct connect result");
let direct_err = match direct_result_rx.recv() {
Ok(Ok(stream)) => {
trace!("Stream::rendezvous_connect got direct connection");
return WOk(stream, Vec::new());
},
Ok(Err(e)) => Some(e),
Err(mpsc::RecvError) => None,
};
trace!("Stream::rendezvous_connect waiting for tcp result");
let tcp_err = match tcp_result_rx.recv() {
Ok(WOk(stream, ws)) => {
trace!("Stream::rendezvous_connect got tcp connection");
return WOk(stream, ws);
},
Ok(WErr(e)) => Some(e),
Err(mpsc::RecvError) => None,
};
trace!("Stream::rendezvous_connect failed");
WErr(StreamRendezvousConnectError::AllProtocolsFailed {
tcp_err: tcp_err,
direct_err: direct_err,
})
}
pub fn direct_connect<E>(endpoints: E, deadline: Instant)
-> Result<Stream, StreamDirectConnectError<E::Err>>
where E: ToEndpoints
{
let connection_id = rand::random();
Stream::direct_connect_inner(connection_id, endpoints, deadline)
}
fn direct_connect_inner<E>(connection_id: u64, endpoints: E, deadline: Instant)
-> Result<Stream, StreamDirectConnectError<E::Err>>
where E: ToEndpoints
{
let endpoints_iter = endpoints.to_endpoints();
let stop = Arc::new(AtomicBool::new(false));
let (result_tx, result_rx) = mpsc::channel();
let mut num_endpoints = 0;
for endpoint_res in endpoints_iter {
num_endpoints += 1;
let stop = stop.clone();
let result_tx = result_tx.clone();
let _ = thread!("Stream::direct_connect connect", move || {
if stop.load(Ordering::SeqCst) {
return;
}
let endpoint = match endpoint_res {
Ok(endpoint) => endpoint,
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::ParseEndpoint {
err: e,
connection_id: connection_id,
})));
return;
},
};
match endpoint {
Endpoint::Tcp(addr) => {
if stop.load(Ordering::SeqCst) {
return;
}
let mut tcp_stream = match TcpStream::connect(&*addr) {
Ok(tcp_stream) => tcp_stream,
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::TcpConnect {
addr: addr,
err: e,
connection_id: connection_id,
})));
return;
},
};
if stop.load(Ordering::SeqCst) {
return;
}
match tcp_stream.write_u64::<BigEndian>(connection_id) {
Ok(()) => (),
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::TcpWrite {
addr: addr,
err: e,
connection_id: connection_id,
})));
return;
},
};
if stop.load(Ordering::SeqCst) {
return;
}
match tcp_stream.set_read_timeout(Some(Duration::from_millis(400))) {
Ok(()) => (),
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::TcpSetTimeout {
err: e,
addr: addr,
connection_id: connection_id,
})));
return;
},
};
let recv_connection_id = match tcp_stream.read_u64::<BigEndian>() {
Ok(recv_connection_id) => recv_connection_id,
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::TcpWrite {
addr: addr,
err: e,
connection_id: connection_id,
})));
return;
},
};
if connection_id != recv_connection_id {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::HandshakeError {
connection_id: connection_id,
})));
return;
}
match tcp_stream.set_read_timeout(None) {
Ok(()) => (),
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::TcpSetTimeout {
err: e,
addr: addr,
connection_id: connection_id,
})));
return;
},
};
if stop.load(Ordering::SeqCst) {
return;
}
let stream = match Stream::from_tcp_stream(tcp_stream, connection_id) {
Ok(stream) => stream,
Err(e) => {
let _ = result_tx.send(Some(Err(StreamDirectConnectEndpointError::TcpCreateStream {
err: e,
connection_id: connection_id,
})));
return;
},
};
let _ = result_tx.send(Some(Ok(stream)));
},
Endpoint::Utp(..) => unimplemented!(),
}
});
}
let timeout_thread = thread!("Stream::direct_connect timeout", move || {
let now = Instant::now();
if deadline > now {
let timeout = deadline - now;
thread::park_timeout(timeout);
}
let _ = result_tx.send(None);
});
let mut errors = Vec::new();
loop {
if errors.len() == num_endpoints {
stop.store(true, Ordering::SeqCst);
timeout_thread.thread().unpark();
return Err(StreamDirectConnectError::AllConnectionsFailed(errors))
}
let result = result_rx.recv();
match result {
Ok(Some(Ok(stream))) => {
stop.store(true, Ordering::SeqCst);
timeout_thread.thread().unpark();
return Ok(stream);
},
Ok(Some(Err(e))) => {
errors.push(e);
},
Ok(None) => {
stop.store(true, Ordering::SeqCst);
timeout_thread.thread().unpark();
return Err(StreamDirectConnectError::TimedOut);
},
Err(mpsc::RecvError) => {
panic!("Connecting threads panicked!");
},
};
}
}
pub fn from_tcp_stream(stream: TcpStream, connection_id: u64) -> Result<Stream, StreamFromTcpStreamError> {
trace!("Stream::from_tcp_stream(connection_id = #{:016x})", connection_id);
let local_addr = match stream.local_addr() {
Ok(local_addr) => local_addr,
Err(e) => {
debug!("Error getting local address of tcp stream: {}", e);
return Err(StreamFromTcpStreamError::LocalAddr(e))
},
};
let peer_addr = match stream.peer_addr() {
Ok(peer_addr) => peer_addr,
Err(e) => {
debug!("Error getting peer address of tcp stream: {}", e);
return Err(StreamFromTcpStreamError::PeerAddr(e))
},
};
let mut writer_stream = match stream.try_clone() {
Ok(writer_stream) => writer_stream,
Err(e) => {
debug!("Error cloning tcp stream: {}", e);
return Err(StreamFromTcpStreamError::CloneStream(e))
},
};
let buffer = Arc::new(Buffer {
inner: Mutex::new(BufferInner {
buf: VecDeque::new(),
running: true,
error: None,
}),
condvar: Condvar::new(),
});
let buffer_cloned = buffer.clone();
let writer_thread = RaiiThreadJoiner::new(thread!("Stream tcp writer", move || {
let buffer = buffer_cloned;
loop {
trace!("tcp writer thread checking for fresh data (connection_id == #{:016x})", connection_id);
let buf;
{
let mut inner = unwrap_result!(buffer.inner.lock());
loop {
match inner.buf.pop_front() {
Some(b) => {
buf = b;
break;
},
None => {
if !inner.running {
trace!("tcp writer thread exiting normally (connection_id == #{:016x})", connection_id);
return;
}
trace!("tcp writer thread going to sleep (connection_id == #{:016x})", connection_id);
inner = unwrap_result!(buffer.condvar.wait(inner));
trace!("tcp writer thread waking up (connection_id == #{:016x})", connection_id);
},
}
};
};
let len = buf.len();
trace!("tcp writer thread writing {} bytes (connection_id == #{:016x})", len, connection_id);
match writer_stream.write_all(&buf[..]) {
Ok(()) => (),
Err(e) => {
debug!("tcp writer thread exiting due to error (connection_id == #{:016x}): {}", connection_id, e);
let mut inner = unwrap_result!(buffer.inner.lock());
inner.error = Some(e);
return;
},
}
};
}));
Ok(Stream {
protocol_inner: StreamInner::Tcp {
stream: stream,
local_addr: SocketAddr(local_addr),
peer_addr: SocketAddr(peer_addr),
},
buffer: buffer,
_writer_thread: writer_thread,
connection_id: connection_id,
})
}
}
impl Drop for Stream {
fn drop(&mut self) {
let mut inner = unwrap_result!(self.buffer.inner.lock());
inner.running = false;
self.buffer.condvar.notify_all();
}
}
impl Read for Stream {
    /// Reads directly from the underlying socket; unlike writes, reads do not go through
    /// the background writer thread.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.protocol_inner {
            StreamInner::Tcp { stream: ref mut socket, .. } => Read::read(socket, buf),
        }
    }
}
impl Write for Stream {
    /// Queues a copy of `buf` for the writer thread and reports it as fully written.
    ///
    /// If the writer thread has recorded an error, that error is returned instead, and a
    /// sticky `BrokenPipe` error is left behind so later calls also fail.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut guard = unwrap_result!(self.buffer.inner.lock());
        match guard.error.take() {
            Some(writer_err) => {
                guard.error = Some(io::Error::new(io::ErrorKind::BrokenPipe, "Stream has closed"));
                Err(writer_err)
            },
            None => {
                guard.buf.push_back(buf.to_vec());
                self.buffer.condvar.notify_all();
                Ok(buf.len())
            },
        }
    }

    /// Only reports a pending writer-thread error (with the same sticky `BrokenPipe`
    /// behaviour as `write`); it does not wait for queued data to reach the socket.
    fn flush(&mut self) -> io::Result<()> {
        let mut guard = unwrap_result!(self.buffer.inner.lock());
        match guard.error.take() {
            None => Ok(()),
            Some(writer_err) => {
                guard.error = Some(io::Error::new(io::ErrorKind::BrokenPipe, "Stream has closed"));
                Err(writer_err)
            },
        }
    }
}
// Integration-style tests for `Stream`. They open real sockets and spawn threads, and each
// runs inside `timebomb`, which presumably fails the test if it exceeds the given duration.
#[cfg(test)]
mod tests {
use std::time::{Instant, Duration};
use std::net::{TcpListener, TcpStream};
use maidsafe_utilities;
use nat_traversal::MappingContext;
use stream::Stream;
use test_utils::{check_stream, bounce_stream, timebomb};
// Two peers generate rendezvous info, swap the public halves, and connect to each other
// concurrently. One side bounces data back while the other checks it.
#[test]
pub fn rendezvous_connect() {
let _ = maidsafe_utilities::log::init(true);
timebomb(Duration::from_secs(12), || {
let mc = unwrap_result!(MappingContext::new().result_log());
let deadline = Instant::now() + Duration::from_secs(2);
let (priv_info_0, pub_info_0, diags) = Stream::gen_rendezvous_info(&mc, deadline);
info!("info_0: {}", diags);
let deadline = Instant::now() + Duration::from_secs(2);
let (priv_info_1, pub_info_1, diags) = Stream::gen_rendezvous_info(&mc, deadline);
info!("info_1: {}", diags);
// Both connects share one deadline and must run at the same time, since each side's
// hole punch needs the other side to be active.
let deadline = Instant::now() + Duration::from_secs(5);
let thread_0 = thread!("rendezvous_connect 0", move || {
let mut stream = unwrap_result!(Stream::rendezvous_connect(priv_info_0,
pub_info_1,
deadline).result_log());
bounce_stream(&mut stream);
});
let thread_1 = thread!("rendezvous_connect 1", move || {
let mut stream = unwrap_result!(Stream::rendezvous_connect(priv_info_1,
pub_info_0,
deadline).result_log());
check_stream(&mut stream);
});
unwrap_result!(thread_0.join());
unwrap_result!(thread_1.join());
})
}
// Wraps both ends of a loopback TCP connection in `Stream`s and exchanges data through
// the buffered writer thread.
#[test]
pub fn read_write_tcp() {
let _ = maidsafe_utilities::log::init(true);
timebomb(Duration::from_secs(3), || {
let listener = unwrap_result!(TcpListener::bind("127.0.0.1:0"));
let addr = unwrap_result!(listener.local_addr());
let accept_thread = thread!("accept thread", move || {
let (tcp_stream, _) = unwrap_result!(listener.accept());
let mut stream = unwrap_result!(Stream::from_tcp_stream(tcp_stream, 1234));
bounce_stream(&mut stream);
drop(stream);
trace!("acceptor thread dropped stream");
});
let tcp_stream = unwrap_result!(TcpStream::connect(&addr));
let mut stream = unwrap_result!(Stream::from_tcp_stream(tcp_stream, 5678));
check_stream(&mut stream);
trace!("connector thread dropping stream");
// Drop before joining. NOTE(review): presumably `bounce_stream` on the acceptor side
// only finishes once this end closes, so joining first could hang — confirm.
drop(stream);
trace!("connector thread dropped stream");
unwrap_result!(accept_thread.join());
})
}
}