use std::io::Error as IoError;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use failure::Fail;
use futures::{future, FutureExt, TryFutureExt, SinkExt, StreamExt, TryStreamExt};
use futures::channel::mpsc;
use tokio::net::{TcpStream, TcpListener};
use tokio_util::codec::Framed;
use tokio::time::Error as TimerError;
use tox_crypto::*;
use crate::relay::codec::{DecodeError, EncodeError, Codec};
use crate::relay::handshake::make_server_handshake;
use crate::relay::server::{Client, Server};
use crate::stats::*;
/// Period of the wakeup timer that drives `Server::send_pings` in `tcp_run`.
const TCP_PING_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound on how long an accepted client may take to finish the server handshake.
const TCP_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Buffer size passed to `mpsc::channel` for packets queued towards a single client.
const SERVER_CHANNEL_SIZE: usize = 64;
#[derive(Debug, Fail)]
pub enum ServerRunError {
#[fail(display = "Incoming IO error: {:?}", error)]
IncomingError {
#[fail(cause)]
error: IoError
},
#[fail(display = "Ping wakeups timer error: {:?}", error)]
PingWakeupsError {
error: TimerError
},
#[fail(display = "Send pings error: {:?}", error)]
SendPingsError {
#[fail(cause)]
error: IoError
},
}
#[derive(Debug, Fail)]
pub enum ConnectionError {
#[fail(display = "Failed to get peer address: {}", error)]
PeerAddrError {
#[fail(cause)]
error: IoError,
},
#[fail(display = "Failed to send TCP packet: {}", error)]
SendPacketError {
error: EncodeError
},
#[fail(display = "Failed to decode incoming packet: {}", error)]
DecodePacketError {
error: DecodeError
},
#[fail(display = "Incoming IO error: {:?}", error)]
IncomingError {
#[fail(cause)]
error: IoError
},
#[fail(display = "Server handshake error: {:?}", error)]
ServerHandshakeTimeoutError {
#[fail(cause)]
error: tokio::time::Elapsed
},
#[fail(display = "Server handshake error: {:?}", error)]
ServerHandshakeIoError {
#[fail(cause)]
error: IoError,
},
#[fail(display = "Packet handling error: {:?}", error)]
PacketHandlingError {
#[fail(cause)]
error: IoError
},
#[fail(display = "Packet handling error: {:?}", error)]
InsertClientError {
#[fail(cause)]
error: IoError
},
#[fail(display = "Packet handling error: {:?}", error)]
ShutdownError {
#[fail(cause)]
error: IoError
},
}
pub async fn tcp_run(server: &Server, mut listener: TcpListener, dht_sk: SecretKey, stats: Stats, connections_limit: usize) -> Result<(), ServerRunError> {
let connections_count = Arc::new(AtomicUsize::new(0));
let connections_future = async {
listener.incoming()
.map_err(|error| ServerRunError::IncomingError { error })
.try_for_each(|stream| {
if connections_count.load(Ordering::SeqCst) < connections_limit {
connections_count.fetch_add(1, Ordering::SeqCst);
let connections_count_c = connections_count.clone();
let dht_sk = dht_sk.clone();
let stats = stats.clone();
let server = server.clone();
tokio::spawn(
async move {
let res = tcp_run_connection(&server, stream, dht_sk, stats).await;
if let Err(ref e) = res {
error!("Error while running tcp connection: {:?}", e)
}
connections_count_c.fetch_sub(1, Ordering::SeqCst);
res
}
);
} else {
trace!("Tcp server has reached the limit of {} connections", connections_limit);
}
future::ok(())
}).await
};
let mut wakeups = tokio::time::interval(TCP_PING_INTERVAL);
let ping_future = async {
while wakeups.next().await.is_some() {
trace!("Tcp server ping sender wake up");
server.send_pings().await
.map_err(|error| ServerRunError::SendPingsError { error })?;
}
Ok(())
};
futures::select! {
res = connections_future.fuse() => res,
res = ping_future.fuse() => res,
}
}
pub async fn tcp_run_connection(server: &Server, stream: TcpStream, dht_sk: SecretKey, stats: Stats) -> Result<(), ConnectionError> {
let addr = match stream.peer_addr() {
Ok(addr) => addr,
Err(error) => return Err(ConnectionError::PeerAddrError {
error
}),
};
debug!("A new TCP client connected from {}", addr);
let fut = tokio::time::timeout(
TCP_HANDSHAKE_TIMEOUT,
make_server_handshake(stream, dht_sk.clone())
);
let (stream, channel, client_pk) = match fut.await {
Err(error) => Err(
ConnectionError::ServerHandshakeTimeoutError { error }
),
Ok(Err(error)) => Err(
ConnectionError::ServerHandshakeIoError { error }
),
Ok(Ok(res)) => Ok(res)
}?;
debug!("Handshake for TCP client {:?} is completed", client_pk);
let secure_socket = Framed::new(stream, Codec::new(channel, stats));
let (mut to_client, from_client) = secure_socket.split();
let (to_client_tx, mut to_client_rx) = mpsc::channel(SERVER_CHANNEL_SIZE);
let processor = from_client
.map_err(|error| ConnectionError::DecodePacketError { error })
.try_for_each(|packet| {
debug!("Handle {:?} => {:?}", client_pk, packet);
server.handle_packet(&client_pk, packet)
.map_err(|error| ConnectionError::PacketHandlingError { error } )
});
let writer = async {
while let Some(packet) = to_client_rx.next().await {
trace!("Sending TCP packet {:?} to {:?}", packet, client_pk);
to_client.send(packet).await
.map_err(|error| ConnectionError::SendPacketError {
error
})?;
}
Ok(())
};
let client = Client::new(
to_client_tx,
&client_pk,
addr.ip(),
addr.port()
);
server.insert(client).await
.map_err(|error| ConnectionError::InsertClientError { error })?;
let r_processing = futures::select! {
res = processor.fuse() => res,
res = writer.fuse() => res
};
debug!("Shutdown a client with PK {:?}", &client_pk);
server.shutdown_client(&client_pk, addr.ip(), addr.port())
.await
.map_err(|error| ConnectionError::ShutdownError { error })?;
r_processing
}
// Each test binds a real TCP listener on an ephemeral localhost port and
// drives a client through the relay handshake against it.
#[cfg(test)]
mod tests {
use super::*;
use tox_binary_io::*;
use failure::Error;
use crate::relay::codec::Codec;
use crate::relay::handshake::make_client_handshake;
use tox_packet::relay::{Packet, PingRequest, PongResponse};
// NOTE(review): presumably the source of `TCP_PING_FREQUENCY` used in `run` below — verify.
use crate::relay::server::client::*;
// A single connection served by `tcp_run_connection` must answer a
// PingRequest with a PongResponse carrying the same `ping_id`.
#[tokio::test]
async fn run_connection() {
crypto_init().unwrap();
let (client_pk, client_sk) = gen_keypair();
let (server_pk, server_sk) = gen_keypair();
// Port 0 lets the OS pick a free port; the actual address is read back below.
let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
let mut listener = TcpListener::bind(&addr).await.unwrap();
let addr = listener.local_addr().unwrap();
let stats = Stats::new();
let stats_c = stats.clone();
// Server side: accept exactly one connection and serve it.
let server = async {
let connection = listener.incoming().next().await.unwrap().unwrap();
tcp_run_connection(&Server::new(), connection, server_sk, stats.clone())
.map_err(Error::from).await
};
// Client side: connect, handshake, send a ping and check the pong.
let client = async {
let socket = TcpStream::connect(&addr).map_err(Error::from).await?;
let (stream, channel) = make_client_handshake(socket, &client_pk, &client_sk, &server_pk)
.map_err(Error::from).await?;
let secure_socket = Framed::new(stream, Codec::new(channel, stats_c));
let (mut to_server, mut from_server) = secure_socket.split();
let packet = Packet::PingRequest(PingRequest {
ping_id: 42
});
to_server.send(packet).map_err(Error::from).await.unwrap();
let packet = from_server.next().await.unwrap();
assert_eq!(packet.unwrap(), Packet::PongResponse(PongResponse {
ping_id: 42
}));
Ok(())
};
// Whichever future completes first decides the outcome. The client completes
// with Ok(()) only after the pong has been verified.
let result = futures::select!(
res = server.fuse() => res,
res = client.fuse() => res,
);
assert!(result.is_ok());
}
// Full `tcp_run` loop with a connection limit of 1. After a ping/pong round trip,
// the server's own periodic pings are observed once the paused clock is advanced.
#[tokio::test]
async fn run() {
// Freeze tokio's clock so time only moves through `advance` below.
tokio::time::pause();
crypto_init().unwrap();
let (client_pk, client_sk) = gen_keypair();
let (server_pk, server_sk) = gen_keypair();
// Port 0 lets the OS pick a free port; the actual address is read back below.
let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
let listener = TcpListener::bind(&addr).await.unwrap();
let addr = listener.local_addr().unwrap();
let stats = Stats::new();
// Server side: the complete accept + ping loop.
let server = async {
tcp_run(&Server::new(), listener, server_sk, stats.clone(), 1).await
.map_err(Error::from)
};
let client = async {
let socket = TcpStream::connect(&addr).map_err(Error::from).await?;
let (stream, channel) = make_client_handshake(socket, &client_pk, &client_sk, &server_pk)
.map_err(Error::from).await?;
let secure_socket = Framed::new(stream, Codec::new(channel, stats.clone()));
let (mut to_server, mut from_server) = secure_socket.split();
let packet = Packet::PingRequest(PingRequest {
ping_id: 42
});
to_server.send(packet).map_err(Error::from).await?;
let packet = from_server.next().await.unwrap();
assert_eq!(packet.unwrap(), Packet::PongResponse(PongResponse {
ping_id: 42
}));
// Move the paused clock past the ping frequency so the server starts sending pings.
tokio::time::advance(TCP_PING_FREQUENCY + Duration::from_secs(1)).await;
// Every further packet is expected to be a PingRequest. The client never answers,
// and the loop ends when the server closes the stream (presumably after the
// unanswered pings time out — verify against Server::send_pings).
while let Some(packet) = from_server.next().await {
// NOTE(review): `unpack!` presumably comes from tox_binary_io and fails on a non-PingRequest packet — verify.
let _ping_packet = unpack!(packet.unwrap(), Packet::PingRequest);
}
Ok(())
};
// Whichever future completes first decides the outcome; `tcp_run` itself only
// returns on an accept or ping error.
let result = futures::select!(
res = server.fuse() => res,
res = client.fuse() => res,
);
assert!(result.is_ok());
}
}