use tox_crypto::*;
use tox_packet::onion::InnerOnionResponse;
use crate::relay::server::client::Client;
use tox_packet::relay::connection_id::ConnectionId;
use crate::relay::links::*;
use tox_packet::relay::*;
use std::io::{Error, ErrorKind};
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use futures::SinkExt;
use futures::channel::mpsc;
use tokio::sync::RwLock;
/// TCP relay server: keeps track of connected clients and relays packets
/// between them.
///
/// Cloning is cheap and every clone shares the same client state, because the
/// state lives behind an `Arc`.
#[derive(Clone, Default)]
pub struct Server {
    /// State of all connected clients, guarded by an async `RwLock`.
    state: Arc<RwLock<ServerState>>,
    /// Channel that receives onion requests from relay clients together with
    /// the sending client's socket address; `None` disables forwarding.
    onion_sink: Option<mpsc::Sender<(OnionRequest, SocketAddr)>>,
}
/// Mutable state shared by all clones of a `Server`.
#[derive(Default)]
struct ServerState {
    /// Every connected client, keyed by its public key.
    pub connected_clients: HashMap<PublicKey, Client>,
    /// Reverse index from a client's `(ip address, port)` pair to its public
    /// key; used to route UDP onion responses back to the right client.
    pub keys_by_addr: HashMap<(IpAddr, u16), PublicKey>,
}
impl Server {
    /// Creates a server with no connected clients and no UDP onion sink.
    pub fn new() -> Server {
        Server::default()
    }

    /// Sets the channel through which `OnionRequest` packets received from
    /// relay clients are forwarded, paired with the sender's socket address.
    ///
    /// While no sink is set, `OnionRequest` packets are accepted and dropped.
    pub fn set_udp_onion_sink(&mut self, onion_sink: mpsc::Sender<(OnionRequest, SocketAddr)>) {
        self.onion_sink = Some(onion_sink)
    }

    /// Registers a newly connected client.
    ///
    /// A client already connected under the same public key is shut down
    /// first (its online peers get a `DisconnectNotification`) and is then
    /// replaced by `client`.
    ///
    /// # Errors
    ///
    /// Propagates any error from shutting down the replaced client.
    pub async fn insert(&self, client: Client) -> Result<(), Error> {
        let mut state = self.state.write().await;
        let pk = client.pk();
        if state.connected_clients.contains_key(&pk) {
            self.shutdown_client_inner(&pk, &mut state).await?;
        }
        state.keys_by_addr.insert((client.ip_addr(), client.port()), pk);
        state.connected_clients.insert(pk, client);
        Ok(())
    }

    /// Dispatches a packet received from the client identified by `pk` to the
    /// matching handler and returns the handler's result.
    pub async fn handle_packet(&self, pk: &PublicKey, packet: Packet) -> Result<(), Error> {
        match packet {
            Packet::RouteRequest(packet) => self.handle_route_request(pk, &packet).await,
            Packet::RouteResponse(packet) => self.handle_route_response(pk, &packet).await,
            Packet::ConnectNotification(packet) => self.handle_connect_notification(pk, &packet).await,
            Packet::DisconnectNotification(packet) => self.handle_disconnect_notification(pk, &packet).await,
            Packet::PingRequest(packet) => self.handle_ping_request(pk, &packet).await,
            Packet::PongResponse(packet) => self.handle_pong_response(pk, &packet).await,
            Packet::OobSend(packet) => self.handle_oob_send(pk, packet).await,
            Packet::OobReceive(packet) => self.handle_oob_receive(pk, &packet).await,
            Packet::OnionRequest(packet) => self.handle_onion_request(pk, packet).await,
            Packet::OnionResponse(packet) => self.handle_onion_response(pk, &packet).await,
            Packet::Data(packet) => self.handle_data(pk, packet).await,
        }
    }

    /// Delivers an onion response that arrived over UDP to the relay client
    /// connected from `ip_addr:port`.
    ///
    /// # Errors
    ///
    /// Fails if no client is connected from that address, or if sending the
    /// response to the client fails.
    pub async fn handle_udp_onion_response(&self, ip_addr: IpAddr, port: u16, payload: InnerOnionResponse) -> Result<(), Error> {
        let state = self.state.read().await;
        let client = state.keys_by_addr
            .get(&(ip_addr, port))
            .and_then(|pk| state.connected_clients.get(pk))
            .ok_or_else(|| Error::new(ErrorKind::Other, "Cannot find client by ip_addr to send onion response"))?;
        client.send_onion_response(payload).await
    }

    /// Shuts down the client `pk`, but only if it is connected from
    /// `ip_addr:port`.
    ///
    /// The address check keeps a stale connection from removing a client that
    /// has since reconnected under the same key from another address.
    ///
    /// # Errors
    ///
    /// Fails if no client with `pk` is connected, or if it is connected from a
    /// different address.
    pub async fn shutdown_client(&self, pk: &PublicKey, ip_addr: IpAddr, port: u16) -> Result<(), Error> {
        let mut state = self.state.write().await;
        match state.connected_clients.get(pk) {
            None => return Err(Error::new(ErrorKind::Other, "Cannot find client by pk to shutdown it")),
            Some(client) if client.ip_addr() != ip_addr || client.port() != port =>
                return Err(Error::new(ErrorKind::Other, "Client with pk has different address")),
            Some(_) => {},
        }
        self.shutdown_client_inner(pk, &mut state).await
    }

    /// Removes client `pk` from `state`.
    ///
    /// Every peer that has an online link to the removed client gets its side
    /// of the link downgraded and receives a `DisconnectNotification`.
    ///
    /// # Errors
    ///
    /// Fails if no client with `pk` is connected.
    async fn shutdown_client_inner(&self, pk: &PublicKey, state: &mut ServerState) -> Result<(), Error> {
        let client_a = state.connected_clients.remove(pk).ok_or_else(|| Error::new(
            ErrorKind::Other,
            "Cannot find client by pk to shutdown it"
        ))?;
        // Drop the address mapping only if it still points at this client.
        // Another client may have been registered on the same (ip, port) since,
        // and its mapping must survive this shutdown.
        let addr = (client_a.ip_addr(), client_a.port());
        if state.keys_by_addr.get(&addr) == Some(pk) {
            state.keys_by_addr.remove(&addr);
        }
        let links = client_a.links();
        for link in links.iter_links() {
            match link.status {
                // The peer never linked back, so there is nothing to notify.
                LinkStatus::Registered => {},
                LinkStatus::Online => Self::downgrade_peer_link(state, &link.pk, pk).await,
            }
        }
        Ok(())
    }

    /// Downgrades the link that client `peer_pk` holds towards client `pk`,
    /// and sends `peer_pk` a `DisconnectNotification` for that link.
    ///
    /// Does nothing if `peer_pk` is not connected or has no link to `pk`.
    async fn downgrade_peer_link(state: &mut ServerState, peer_pk: &PublicKey, pk: &PublicKey) {
        if let Some(peer) = state.connected_clients.get_mut(peer_pk) {
            if let Some(id_in_peer) = peer.links().id_by_pk(pk) {
                peer.links_mut().downgrade(id_in_peer);
                peer.send_disconnect_notification(ConnectionId::from_index(id_in_peer)).await;
            }
        }
    }

    /// Handles a `RouteRequest` from client A (`pk`) asking for a route to
    /// client B (`packet.pk`).
    ///
    /// A always gets a `RouteResponse`. The zero connection id means the
    /// request was refused (route to itself, or no link could be allocated).
    /// If B has already registered a link to A, both links go online and both
    /// clients receive a `ConnectNotification`.
    async fn handle_route_request(&self, pk: &PublicKey, packet: &RouteRequest) -> Result<(), Error> {
        let mut state = self.state.write().await;
        let client_a = state.connected_clients.get_mut(pk)
            .ok_or_else(|| Error::new(ErrorKind::Other, "RouteRequest: no such PK"))?;
        if pk == &packet.pk {
            return client_a.send_route_response(pk, ConnectionId::zero()).await
        }
        // The link already exists: repeat its connection id.
        if let Some(index) = client_a.links().id_by_pk(&packet.pk) {
            return client_a.send_route_response(&packet.pk, ConnectionId::from_index(index)).await
        }
        let b_id_in_client_a = match client_a.links_mut().insert(&packet.pk) {
            Some(index) => index,
            // `insert` could not allocate a link (e.g. the table is full).
            None => return client_a.send_route_response(&packet.pk, ConnectionId::zero()).await,
        };
        client_a.send_route_response(&packet.pk, ConnectionId::from_index(b_id_in_client_a)).await?;

        // Go online only when B is connected and has already linked to A.
        let a_id_in_client_b = match state.connected_clients.get(&packet.pk) {
            Some(client_b) => match client_b.links().id_by_pk(pk) {
                Some(index) => index,
                None => return Ok(()),
            },
            None => return Ok(()),
        };
        let client_a = state.connected_clients.get_mut(pk)
            .expect("client A was found above and the write lock is still held");
        client_a.links_mut().upgrade(b_id_in_client_a);
        client_a.send_connect_notification(ConnectionId::from_index(b_id_in_client_a)).await;
        let client_b = state.connected_clients.get_mut(&packet.pk)
            .expect("client B was found above and the write lock is still held");
        client_b.links_mut().upgrade(a_id_in_client_b);
        client_b.send_connect_notification(ConnectionId::from_index(a_id_in_client_b)).await;
        Ok(())
    }

    /// Rejects `RouteResponse`: only the server may send it.
    async fn handle_route_response(&self, _pk: &PublicKey, _packet: &RouteResponse) -> Result<(), Error> {
        Err(Error::new(ErrorKind::Other, "Client must not send RouteResponse to server"))
    }

    /// Accepts and ignores `ConnectNotification` sent by a client.
    async fn handle_connect_notification(&self, _pk: &PublicKey, _packet: &ConnectNotification) -> Result<(), Error> {
        Ok(())
    }

    /// Handles a `DisconnectNotification` from client A (`pk`).
    ///
    /// Removes A's link identified by the connection id. If that link was
    /// online, the peer's link back to A is downgraded and the peer is
    /// notified. An unknown connection id is logged and ignored.
    ///
    /// # Errors
    ///
    /// Fails for the zero connection id or an unknown sender `pk`.
    async fn handle_disconnect_notification(&self, pk: &PublicKey, packet: &DisconnectNotification) -> Result<(), Error> {
        let index = packet.connection_id.index()
            .ok_or_else(|| Error::new(ErrorKind::Other, "DisconnectNotification: connection id is zero"))?;
        let mut state = self.state.write().await;
        let a_link = {
            let client_a = state.connected_clients.get_mut(pk)
                .ok_or_else(|| Error::new(ErrorKind::Other, "DisconnectNotification: no such PK"))?;
            match client_a.links_mut().take(index) {
                Some(link) => link,
                None => {
                    trace!("DisconnectNotification.connection_id is not linked for the client {:?}", pk);
                    return Ok(())
                }
            }
        };
        match a_link.status {
            LinkStatus::Registered => Ok(()),
            LinkStatus::Online => {
                Self::downgrade_peer_link(&mut state, &a_link.pk, pk).await;
                Ok(())
            }
        }
    }

    /// Answers a `PingRequest` with a `PongResponse` carrying the same id.
    ///
    /// # Errors
    ///
    /// Fails for `ping_id == 0`, an unknown sender, or a failed send.
    async fn handle_ping_request(&self, pk: &PublicKey, packet: &PingRequest) -> Result<(), Error> {
        if packet.ping_id == 0 {
            return Err(Error::new(ErrorKind::Other, "PingRequest.ping_id == 0"))
        }
        let state = self.state.read().await;
        let client_a = state.connected_clients.get(pk)
            .ok_or_else(|| Error::new(ErrorKind::Other, "PingRequest: no such PK"))?;
        client_a.send_pong_response(packet.ping_id).await
    }

    /// Records the time of a `PongResponse` that matches the client's
    /// outstanding ping id.
    ///
    /// # Errors
    ///
    /// Fails for `ping_id == 0`, an unknown sender, or a mismatching ping id.
    async fn handle_pong_response(&self, pk: &PublicKey, packet: &PongResponse) -> Result<(), Error> {
        if packet.ping_id == 0 {
            return Err(Error::new(ErrorKind::Other, "PongResponse.ping_id == 0"))
        }
        let mut state = self.state.write().await;
        let client_a = state.connected_clients.get_mut(pk)
            .ok_or_else(|| Error::new(ErrorKind::Other, "PongResponse: no such PK"))?;
        if packet.ping_id != client_a.ping_id() {
            return Err(Error::new(ErrorKind::Other, "PongResponse.ping_id does not match"))
        }
        client_a.set_last_pong_resp(Instant::now());
        Ok(())
    }

    /// Relays out-of-band data from `pk` to `packet.destination_pk`.
    ///
    /// If the destination is not connected, the data is silently dropped.
    ///
    /// # Errors
    ///
    /// Fails when the payload is empty or longer than the allowed maximum.
    async fn handle_oob_send(&self, pk: &PublicKey, packet: OobSend) -> Result<(), Error> {
        // Largest OOB payload, in bytes, accepted from a client.
        const MAX_OOB_DATA_LEN: usize = 1024;
        if packet.data.is_empty() || packet.data.len() > MAX_OOB_DATA_LEN {
            return Err(Error::new(ErrorKind::Other, "OobSend wrong data length"))
        }
        let state = self.state.read().await;
        if let Some(client_b) = state.connected_clients.get(&packet.destination_pk) {
            client_b.send_oob(pk, packet.data).await;
        }
        Ok(())
    }

    /// Rejects `OobReceive`: only the server may send it.
    async fn handle_oob_receive(&self, _pk: &PublicKey, _packet: &OobReceive) -> Result<(), Error> {
        Err(Error::new(ErrorKind::Other, "Client must not send OobReceive to server"))
    }

    /// Forwards an `OnionRequest` to the UDP onion sink, tagged with the
    /// sending client's socket address.
    ///
    /// When no sink is configured, the request is dropped and `Ok` is returned.
    ///
    /// # Errors
    ///
    /// Fails if the sender is unknown, or with `UnexpectedEof` if the sink's
    /// receiver is gone.
    async fn handle_onion_request(&self, pk: &PublicKey, packet: OnionRequest) -> Result<(), Error> {
        let onion_sink = match self.onion_sink.as_ref() {
            Some(onion_sink) => onion_sink,
            None => return Ok(()),
        };
        // Hold the state lock only for the lookup, not while waiting for room
        // in the (bounded) channel below.
        let saddr = {
            let state = self.state.read().await;
            let client = state.connected_clients.get(pk)
                .ok_or_else(|| Error::new(ErrorKind::Other, "OnionRequest: no such PK"))?;
            SocketAddr::new(client.ip_addr(), client.port())
        };
        // `send` needs a mutable sender, so send through a clone of the sink.
        let mut tx = onion_sink.clone();
        tx.send((packet, saddr)).await
            .map_err(|_| Error::from(ErrorKind::UnexpectedEof))
    }

    /// Rejects `OnionResponse`: only the server may send it.
    async fn handle_onion_response(&self, _pk: &PublicKey, _packet: &OnionResponse) -> Result<(), Error> {
        Err(Error::new(ErrorKind::Other, "Client must not send OnionResponse to server"))
    }

    /// Relays a `Data` packet from client A (`pk`) to the peer on the other
    /// end of A's link, if that link is online.
    ///
    /// Unknown or not-yet-online links are ignored, with a trace log for
    /// unknown ones.
    ///
    /// # Errors
    ///
    /// Fails for the zero connection id, an unknown sender, or a failed send
    /// to the peer.
    async fn handle_data(&self, pk: &PublicKey, packet: Data) -> Result<(), Error> {
        let index = packet.connection_id.index()
            .ok_or_else(|| Error::new(ErrorKind::Other, "Data: connection id is zero"))?;
        let state = self.state.read().await;
        let client_a = state.connected_clients.get(pk)
            .ok_or_else(|| Error::new(ErrorKind::Other, "Data: no such PK"))?;
        let a_link = match client_a.links().by_id(index) {
            Some(link) => *link,
            None => {
                trace!("Data.connection_id is not linked for the client {:?}", pk);
                return Ok(())
            }
        };
        match a_link.status {
            LinkStatus::Registered => Ok(()),
            LinkStatus::Online => {
                let client_b = match state.connected_clients.get(&a_link.pk) {
                    Some(client) => client,
                    None => return Ok(()),
                };
                let a_id_in_client_b = match client_b.links().id_by_pk(pk) {
                    Some(id) => id,
                    None => return Ok(()),
                };
                client_b.send_data(ConnectionId::from_index(a_id_in_client_b), packet.data).await
            }
        }
    }

    /// Shuts down every client whose pong response has timed out.
    ///
    /// Individual shutdown failures are ignored, so this always returns `Ok`.
    async fn remove_timedout_clients(&self, state: &mut ServerState) -> Result<(), Error> {
        let timed_out = state.connected_clients.iter()
            .filter(|(_, client)| client.is_pong_timedout())
            .map(|(&pk, _)| pk)
            .collect::<Vec<PublicKey>>();
        for pk in timed_out {
            // Best effort: one failing shutdown must not stop the sweep.
            self.shutdown_client_inner(&pk, state).await.ok();
        }
        Ok(())
    }

    /// Removes timed-out clients, then sends a `PingRequest` to every
    /// remaining client whose ping interval has passed.
    ///
    /// Failures to send individual pings are ignored.
    pub async fn send_pings(&self) -> Result<(), Error> {
        let mut state = self.state.write().await;
        self.remove_timedout_clients(&mut state).await?;
        for client in state.connected_clients.values_mut() {
            if client.is_ping_interval_passed() {
                client.send_ping_request().await.ok();
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tox_packet::dht::CryptoData;
use tox_packet::ip_port::*;
use tox_packet::onion::*;
use crate::relay::server::{Client, Server};
use crate::relay::server::client::*;
use futures::channel::mpsc;
use futures::StreamExt;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use crate::time::*;
#[tokio::test]
async fn server_is_clonable() {
crypto_init().unwrap();
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
server.insert(client_1).await.unwrap();
let _cloned = server.clone();
}
fn create_random_client(saddr: SocketAddr) -> (Client, mpsc::Receiver<Packet>) {
crypto_init().unwrap();
let (client_pk, _) = gen_keypair();
let (tx, rx) = mpsc::channel(32);
let client = Client::new(tx, &client_pk, saddr.ip(), saddr.port());
(client, rx)
}
#[tokio::test]
async fn normal_communication_scenario() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
let client_ip_addr_1 = client_1.ip_addr();
let client_port_1 = client_1.port();
server.insert(client_1).await.unwrap();
let (client_2, rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let (packet, rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_2, connection_id: ConnectionId::from_index(0) }
));
{
let state = server.state.read().await;
let client_a = &state.connected_clients[&client_pk_1];
let link_id = client_a.links().id_by_pk(&client_pk_2).unwrap();
assert_eq!(client_a.links().by_id(link_id).unwrap().status, LinkStatus::Registered);
}
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let (packet, rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_2, connection_id: ConnectionId::from_index(0) }
));
{
let state = server.state.read().await;
let client_b = &state.connected_clients[&client_pk_2];
assert!(client_b.links().id_by_pk(&client_pk_1).is_none());
}
server.handle_packet(&client_pk_2, Packet::RouteRequest(
RouteRequest { pk: client_pk_1 }
)).await.unwrap();
let (packet, rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_1, connection_id: ConnectionId::from_index(0) }
));
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::ConnectNotification(
ConnectNotification { connection_id: ConnectionId::from_index(0) }
));
let (packet, rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::ConnectNotification(
ConnectNotification { connection_id: ConnectionId::from_index(0) }
));
{
let state = server.state.read().await;
let client_a = &state.connected_clients[&client_pk_1];
let link_id = client_a.links().id_by_pk(&client_pk_2).unwrap();
assert_eq!(client_a.links().by_id(link_id).unwrap().status, LinkStatus::Online);
let client_b = &state.connected_clients[&client_pk_2];
let link_id = client_b.links().id_by_pk(&client_pk_1).unwrap();
assert_eq!(client_a.links().by_id(link_id).unwrap().status, LinkStatus::Online);
}
server.handle_packet(&client_pk_1, Packet::Data(
Data {
connection_id: ConnectionId::from_index(0),
data: DataPayload::CryptoData(CryptoData {
nonce_last_bytes: 42,
payload: vec![42; 123],
}),
}
)).await.unwrap();
let (packet, rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::Data(
Data {
connection_id: ConnectionId::from_index(0),
data: DataPayload::CryptoData(CryptoData {
nonce_last_bytes: 42,
payload: vec![42; 123],
}),
}
));
server.shutdown_client(&client_pk_1, client_ip_addr_1, client_port_1).await.unwrap();
let (packet, _rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(0) }
));
let state = server.state.read().await;
let client_b = &state.connected_clients[&client_pk_2];
assert_eq!(client_b.links().by_id(0).unwrap().status, LinkStatus::Registered);
}
#[tokio::test]
async fn handle_route_request() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, _rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_2, connection_id: ConnectionId::from_index(0) }
));
{
let state = server.state.read().await;
let client_a = &state.connected_clients[&client_pk_1];
let link_id = client_a.links().id_by_pk(&client_pk_2).unwrap();
assert_eq!(client_a.links().by_id(link_id).unwrap().status, LinkStatus::Registered);
let client_b = &state.connected_clients[&client_pk_2];
assert!(client_b.links().id_by_pk(&client_pk_1).is_none());
}
}
#[tokio::test]
async fn handle_route_request_to_itself() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_1 }
)).await.unwrap();
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_1, connection_id: ConnectionId::zero() }
));
}
#[tokio::test]
async fn handle_route_request_too_many_connections() {
let server = Server::new();
let (client_1, mut rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
for i in 0..240 {
let saddr = SocketAddr::new("1.2.3.4".parse().unwrap(), 12346 + u16::from(i));
let (other_client, _other_rx) = create_random_client(saddr);
let other_client_pk = other_client.pk();
server.insert(other_client).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: other_client_pk }
)).await.unwrap();
let (packet, rx_1_nested) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: other_client_pk, connection_id: ConnectionId::from_index(i) }
));
rx_1 = rx_1_nested;
}
let (other_client, _other_rx) = create_random_client("1.2.3.5:12345".parse().unwrap());
let other_client_pk = other_client.pk();
server.insert(other_client).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: other_client_pk }
)).await.unwrap();
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: other_client_pk, connection_id: ConnectionId::zero() }
));
}
#[tokio::test]
async fn handle_connect_notification() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::ConnectNotification(
ConnectNotification { connection_id: ConnectionId::from_index(42) }
)).await;
assert!(handle_res.is_ok());
}
#[tokio::test]
async fn handle_disconnect_notification() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let (packet, rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_2, connection_id: ConnectionId::from_index(0) }
));
server.handle_packet(&client_pk_2, Packet::RouteRequest(
RouteRequest { pk: client_pk_1 }
)).await.unwrap();
let (packet, rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_1, connection_id: ConnectionId::from_index(0) }
));
let (packet, rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::ConnectNotification(
ConnectNotification { connection_id: ConnectionId::from_index(0) }
));
let (packet, rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::ConnectNotification(
ConnectNotification { connection_id: ConnectionId::from_index(0) }
));
{
let state = server.state.read().await;
let client_a = &state.connected_clients[&client_pk_1];
let link_id = client_a.links().id_by_pk(&client_pk_2).unwrap();
assert_eq!(client_a.links().by_id(link_id).unwrap().status, LinkStatus::Online);
let client_b = &state.connected_clients[&client_pk_2];
let link_id = client_b.links().id_by_pk(&client_pk_1).unwrap();
assert_eq!(client_a.links().by_id(link_id).unwrap().status, LinkStatus::Online);
}
server.handle_packet(&client_pk_1, Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(0) }
)).await.unwrap();
let (packet, _rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(0) }
));
{
let state = server.state.read().await;
let client_a = &state.connected_clients[&client_pk_1];
assert!(client_a.links().id_by_pk(&client_pk_2).is_none());
let client_b = &state.connected_clients[&client_pk_2];
let link_id = client_b.links().id_by_pk(&client_pk_1).unwrap();
assert_eq!(client_b.links().by_id(link_id).unwrap().status, LinkStatus::Registered);
}
server.handle_packet(&client_pk_2, Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(0) }
)).await.unwrap();
{
let state = server.state.read().await;
let client_b = &state.connected_clients[&client_pk_2];
assert!(client_b.links().id_by_pk(&client_pk_2).is_none());
}
drop(server);
assert!(rx_1.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn handle_disconnect_notification_other_not_linked() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(0) }
)).await;
assert!(handle_res.is_ok());
drop(server);
assert!(rx_2.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn handle_disconnect_notification_0() {
crypto_init().unwrap();
let server = Server::new();
let (client_pk, _) = gen_keypair();
let handle_res = server.handle_packet(&client_pk, Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::zero() }
)).await;
assert!(handle_res.is_err());
}
#[tokio::test]
async fn handle_ping_request() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
server.handle_packet(&client_pk_1, Packet::PingRequest(
PingRequest { ping_id: 42 }
)).await.unwrap();
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::PongResponse(
PongResponse { ping_id: 42 }
));
}
#[tokio::test]
async fn handle_oob_send() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::OobSend(
OobSend { destination_pk: client_pk_2, data: vec![13; 1024] }
)).await.unwrap();
let (packet, _rx_2) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::OobReceive(
OobReceive { sender_pk: client_pk_1, data: vec![13; 1024] }
));
}
#[tokio::test]
async fn handle_onion_request() {
crypto_init().unwrap();
let (udp_onion_sink, udp_onion_stream) = mpsc::channel(1);
let mut server = Server::new();
server.set_udp_onion_sink(udp_onion_sink);
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
let client_addr_1 = client_1.ip_addr();
let client_port_1 = client_1.port();
server.insert(client_1).await.unwrap();
let request = OnionRequest {
nonce: gen_nonce(),
ip_port: IpPort {
protocol: ProtocolType::TCP,
ip_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
port: 12345,
},
temporary_pk: gen_keypair().0,
payload: vec![13; 170]
};
let handle_res = server
.handle_packet(&client_pk_1, Packet::OnionRequest(request.clone()))
.await;
assert!(handle_res.is_ok());
let (packet, _) = udp_onion_stream.into_future().await;
let (packet, saddr) = packet.unwrap();
assert_eq!(saddr.ip(), client_addr_1);
assert_eq!(saddr.port(), client_port_1);
assert_eq!(packet, request);
}
#[tokio::test]
async fn handle_udp_onion_response() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_addr_1 = client_1.ip_addr();
let client_port_1 = client_1.port();
server.insert(client_1).await.unwrap();
let payload = InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
sendback_data: 12345,
nonce: gen_nonce(),
payload: vec![42; 123]
});
let handle_res = server
.handle_udp_onion_response(client_addr_1, client_port_1, payload.clone())
.await;
assert!(handle_res.is_ok());
let (packet, _) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::OnionResponse(
OnionResponse { payload }
));
}
#[tokio::test]
async fn insert_with_same_pk() {
let server = Server::new();
let (mut client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let (mut client_2, rx_2) = create_random_client("1.2.3.4:12346".parse().unwrap());
let index_1 = client_1.links_mut().insert(&client_2.pk()).unwrap();
assert!(client_1.links_mut().upgrade(index_1));
let index_2 = client_2.links_mut().insert(&client_1.pk()).unwrap();
assert!(client_2.links_mut().upgrade(index_2));
let client_pk_1 = client_1.pk();
let client_addr_3 = "1.2.3.4".parse().unwrap();
let client_port_3 = 12347;
let (tx_3, _rx_3) = mpsc::channel(32);
let client_3 = Client::new(tx_3, &client_pk_1, client_addr_3, client_port_3);
server.insert(client_1).await.unwrap();
server.insert(client_2).await.unwrap();
server.insert(client_3).await.unwrap();
let (packet, _) = rx_2.into_future().await;
assert_eq!(packet.unwrap(), Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(index_2) }
));
let state = server.state.read().await;
let client = &state.connected_clients[&client_pk_1];
assert_eq!(client.ip_addr(), client_addr_3);
assert_eq!(client.port(), client_port_3);
}
#[tokio::test]
async fn shutdown_other_not_linked() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
let client_ip_addr_1 = client_1.ip_addr();
let client_port_1 = client_1.port();
server.insert(client_1).await.unwrap();
let (client_2, _rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_2, connection_id: ConnectionId::from_index(0) }
));
let handle_res = server.shutdown_client(&client_pk_1, client_ip_addr_1, client_port_1).await;
assert!(handle_res.is_ok());
}
#[tokio::test]
async fn handle_data_other_not_linked() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, _rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
server.handle_packet(&client_pk_1, Packet::RouteRequest(
RouteRequest { pk: client_pk_2 }
)).await.unwrap();
let (packet, _rx_1) = rx_1.into_future().await;
assert_eq!(packet.unwrap(), Packet::RouteResponse(
RouteResponse { pk: client_pk_2, connection_id: ConnectionId::from_index(0) }
));
let handle_res = server.handle_packet(&client_pk_1, Packet::Data(
Data {
connection_id: ConnectionId::from_index(0),
data: DataPayload::CryptoData(CryptoData {
nonce_last_bytes: 42,
payload: vec![42; 123],
}),
}
)).await;
assert!(handle_res.is_ok());
}
#[tokio::test]
async fn handle_data_0() {
crypto_init().unwrap();
let server = Server::new();
let (client_pk, _) = gen_keypair();
let handle_res = server.handle_packet(&client_pk, Packet::Data(
Data {
connection_id: ConnectionId::zero(),
data: DataPayload::CryptoData(CryptoData {
nonce_last_bytes: 42,
payload: vec![42; 123],
}),
}
)).await;
assert!(handle_res.is_err());
}
#[tokio::test]
async fn handle_route_response() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::RouteResponse(
RouteResponse { pk: client_pk_1, connection_id: ConnectionId::from_index(42) }
)).await;
assert!(handle_res.is_err());
}
#[tokio::test]
async fn handle_disconnect_notification_not_linked() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::DisconnectNotification(
DisconnectNotification { connection_id: ConnectionId::from_index(0) }
)).await;
assert!(handle_res.is_ok());
drop(server);
assert!(rx_1.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn handle_ping_request_0() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::PingRequest(
PingRequest { ping_id: 0 }
)).await;
assert!(handle_res.is_err());
}
#[tokio::test]
async fn handle_pong_response_0() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::PongResponse(
PongResponse { ping_id: 0 }
)).await;
assert!(handle_res.is_err());
}
#[tokio::test]
async fn handle_oob_send_empty_data() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, _rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::OobSend(
OobSend { destination_pk: client_pk_2, data: vec![] }
)).await;
assert!(handle_res.is_err());
}
#[tokio::test]
async fn handle_data_self_not_linked() {
let server = Server::new();
let (client_1, rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::Data(
Data {
connection_id: ConnectionId::from_index(0),
data: DataPayload::CryptoData(CryptoData {
nonce_last_bytes: 42,
payload: vec![42; 123],
}),
}
)).await;
assert!(handle_res.is_ok());
drop(server);
assert!(rx_1.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn handle_oob_send_to_loooong_data() {
let server = Server::new();
let (client_1, _rx_1) = create_random_client("1.2.3.4:12345".parse().unwrap());
let client_pk_1 = client_1.pk();
server.insert(client_1).await.unwrap();
let (client_2, _rx_2) = create_random_client("1.2.3.5:12345".parse().unwrap());
let client_pk_2 = client_2.pk();
server.insert(client_2).await.unwrap();
let handle_res = server.handle_packet(&client_pk_1, Packet::OobSend(
OobSend { destination_pk: client_pk_2, data: vec![42; 1024 + 1] }
)).await;
assert!(handle_res.is_err());
}
/// A client that sends an `OobReceive` packet to the server gets an error.
#[tokio::test]
async fn handle_oob_recv() {
    let server = Server::new();

    let (first, _first_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let first_pk = first.pk();
    server.insert(first).await.unwrap();

    let (second, _second_rx) = create_random_client("1.2.3.5:12345".parse().unwrap());
    let second_pk = second.pk();
    server.insert(second).await.unwrap();

    let packet = Packet::OobReceive(OobReceive { sender_pk: second_pk, data: vec![42; 1024] });
    let result = server.handle_packet(&first_pk, packet).await;
    assert!(result.is_err());
}
/// With no UDP onion sink configured (`Server::new` leaves it unset), an onion
/// request carrying a long 1500-byte payload is accepted without error.
#[tokio::test]
async fn handle_onion_request_disabled_onion_loooong_data() {
    let server = Server::new();
    let (client, _client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let pk = client.pk();
    server.insert(client).await.unwrap();

    let target = IpPort {
        protocol: ProtocolType::TCP,
        ip_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
        port: 12345,
    };
    let request = OnionRequest {
        nonce: gen_nonce(),
        ip_port: target,
        temporary_pk: gen_keypair().0,
        payload: vec![13; 1500],
    };
    let result = server.handle_packet(&pk, Packet::OnionRequest(request)).await;
    assert!(result.is_ok());
}
/// A client that sends an `OnionResponse` packet to the server gets an error.
#[tokio::test]
async fn handle_onion_response() {
    let server = Server::new();
    let (client, _client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let pk = client.pk();
    server.insert(client).await.unwrap();

    let announce = OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    };
    let payload = InnerOnionResponse::OnionAnnounceResponse(announce);
    let result = server
        .handle_packet(&pk, Packet::OnionResponse(OnionResponse { payload }))
        .await;
    assert!(result.is_err());
}
/// A UDP onion response addressed to an (ip, port) pair that no inserted
/// client uses is rejected with an error.
#[tokio::test]
async fn handle_udp_onion_response_for_unknown_client() {
    crypto_init().unwrap();
    // The receiving half is discarded immediately; this test never forwards
    // onion requests through the sink.
    let (onion_tx, _) = mpsc::channel(1);
    let mut server = Server::new();
    server.set_udp_onion_sink(onion_tx);

    let known_ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
    let known_port = 12345u16;
    let (known_pk, _) = gen_keypair();
    let (client_tx, _client_rx) = mpsc::channel(1);
    let known_client = Client::new(client_tx, &known_pk, known_ip, known_port);
    server.insert(known_client).await.unwrap();

    let unknown_ip = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8));
    let unknown_port = 54321u16;
    let response = InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    });
    let result = server
        .handle_udp_onion_response(unknown_ip, unknown_port, response)
        .await;
    assert!(result.is_err());
}
/// A route request from a key that was never inserted into the server fails.
#[tokio::test]
async fn handle_route_request_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (sender_pk, _) = gen_keypair();
    let (target_pk, _) = gen_keypair();

    let packet = Packet::RouteRequest(RouteRequest { pk: target_pk });
    let result = server.handle_packet(&sender_pk, packet).await;
    assert!(result.is_err());
}
/// A disconnect notification from a key that was never inserted fails.
#[tokio::test]
async fn handle_disconnect_notification_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (sender_pk, _) = gen_keypair();

    let notification = DisconnectNotification { connection_id: ConnectionId::from_index(42) };
    let result = server
        .handle_packet(&sender_pk, Packet::DisconnectNotification(notification))
        .await;
    assert!(result.is_err());
}
/// After routing to a peer that is not connected, disconnecting from that
/// route (connection id 0) succeeds.
#[tokio::test]
async fn handle_disconnect_notification_other_not_connected() {
    let server = Server::new();
    let (client, _client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let pk = client.pk();
    server.insert(client).await.unwrap();

    let (peer_pk, _) = gen_keypair();
    let route_request = Packet::RouteRequest(RouteRequest { pk: peer_pk });
    server.handle_packet(&pk, route_request).await.unwrap();

    let notification = DisconnectNotification { connection_id: ConnectionId::from_index(0) };
    let result = server
        .handle_packet(&pk, Packet::DisconnectNotification(notification))
        .await;
    assert!(result.is_ok());
}
/// A ping request from a key that was never inserted fails.
#[tokio::test]
async fn handle_ping_request_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (sender_pk, _) = gen_keypair();

    let packet = Packet::PingRequest(PingRequest { ping_id: 42 });
    let result = server.handle_packet(&sender_pk, packet).await;
    assert!(result.is_err());
}
/// A pong response from a key that was never inserted fails.
#[tokio::test]
async fn handle_pong_response_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (sender_pk, _) = gen_keypair();

    let packet = Packet::PongResponse(PongResponse { ping_id: 42 });
    let result = server.handle_packet(&sender_pk, packet).await;
    assert!(result.is_err());
}
/// An OOB send between two keys that are not inserted into the server is
/// accepted (`Ok`).
#[tokio::test]
async fn handle_oob_send_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (sender_pk, _) = gen_keypair();
    let (destination_pk, _) = gen_keypair();

    let packet = Packet::OobSend(OobSend { destination_pk, data: vec![42; 1024] });
    let result = server.handle_packet(&sender_pk, packet).await;
    assert!(result.is_ok());
}
/// A data packet from a key that was never inserted fails.
#[tokio::test]
async fn handle_data_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (sender_pk, _) = gen_keypair();

    let payload = DataPayload::CryptoData(CryptoData {
        nonce_last_bytes: 42,
        payload: vec![42; 123],
    });
    let packet = Packet::Data(Data {
        connection_id: ConnectionId::from_index(0),
        data: payload,
    });
    let result = server.handle_packet(&sender_pk, packet).await;
    assert!(result.is_err());
}
/// After a route request to a peer that is not connected, data sent on the
/// returned connection id is accepted (`Ok`).
#[tokio::test]
async fn handle_data_other_not_connected() {
    let server = Server::new();
    let (client, client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let pk = client.pk();
    server.insert(client).await.unwrap();

    let (peer_pk, _) = gen_keypair();
    let route_request = Packet::RouteRequest(RouteRequest { pk: peer_pk });
    server.handle_packet(&pk, route_request).await.unwrap();

    // The server answers the route request with connection id 0.
    let (response, _client_rx) = client_rx.into_future().await;
    let expected = Packet::RouteResponse(RouteResponse {
        pk: peer_pk,
        connection_id: ConnectionId::from_index(0),
    });
    assert_eq!(response.unwrap(), expected);

    let payload = DataPayload::CryptoData(CryptoData {
        nonce_last_bytes: 42,
        payload: vec![42; 123],
    });
    let data = Packet::Data(Data {
        connection_id: ConnectionId::from_index(0),
        data: payload,
    });
    let result = server.handle_packet(&pk, data).await;
    assert!(result.is_ok());
}
/// Shutting down a client by the right key but the wrong port fails and
/// leaves the client registered.
#[tokio::test]
async fn shutdown_different_addr() {
    let server = Server::new();
    let (client, _client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let pk = client.pk();
    server.insert(client).await.unwrap();

    let wrong_port = 12346;
    let result = server.shutdown_client(&pk, "1.2.3.4".parse().unwrap(), wrong_port).await;
    assert!(result.is_err());

    let state = server.state.read().await;
    assert!(state.connected_clients.contains_key(&pk));
}
/// Shutting down a key that was never inserted fails.
#[tokio::test]
async fn shutdown_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (pk, _) = gen_keypair();

    let ip_addr = "1.2.3.4".parse().unwrap();
    let port = 12345;
    let result = server.shutdown_client(&pk, ip_addr, port).await;
    assert!(result.is_err());
}
/// `shutdown_client_inner`, called with the state write-locked, fails for a key
/// that was never inserted.
#[tokio::test]
async fn shutdown_inner_not_connected() {
    crypto_init().unwrap();
    let server = Server::new();
    let (pk, _) = gen_keypair();

    let mut guard = server.state.write().await;
    let result = server.shutdown_client_inner(&pk, &mut guard).await;
    assert!(result.is_err());
}
/// A client that has routed to a peer that is not connected can still be shut
/// down by its own key, ip and port.
#[tokio::test]
async fn shutdown_other_not_connected() {
    let server = Server::new();
    let (client, client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let (pk, ip_addr, port) = (client.pk(), client.ip_addr(), client.port());
    server.insert(client).await.unwrap();

    let (peer_pk, _) = gen_keypair();
    let route_request = Packet::RouteRequest(RouteRequest { pk: peer_pk });
    server.handle_packet(&pk, route_request).await.unwrap();

    // The server answers the route request with connection id 0.
    let (response, _client_rx) = client_rx.into_future().await;
    let expected = Packet::RouteResponse(RouteResponse {
        pk: peer_pk,
        connection_id: ConnectionId::from_index(0),
    });
    assert_eq!(response.unwrap(), expected);

    let result = server.shutdown_client(&pk, ip_addr, port).await;
    assert!(result.is_ok());
}
/// Once a client's receiving half is dropped, a route request from that client
/// fails. The route response cannot be delivered back to it.
#[tokio::test]
async fn send_anything_to_dropped_client() {
    let server = Server::new();

    let (first, first_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let first_pk = first.pk();
    server.insert(first).await.unwrap();

    let (second, _second_rx) = create_random_client("1.2.3.5:12345".parse().unwrap());
    let second_pk = second.pk();
    server.insert(second).await.unwrap();

    drop(first_rx);

    let packet = Packet::RouteRequest(RouteRequest { pk: second_pk });
    let result = server.handle_packet(&first_pk, packet).await;
    assert!(result.is_err());
}
/// Handling a client's onion request fails when the receiving end of the UDP
/// onion channel has been dropped.
#[tokio::test]
async fn send_onion_request_to_dropped_stream() {
    crypto_init().unwrap();
    let (onion_tx, onion_rx) = mpsc::channel(1);
    let mut server = Server::new();
    server.set_udp_onion_sink(onion_tx);

    let (client, _client_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let pk = client.pk();
    server.insert(client).await.unwrap();

    drop(onion_rx);

    let target = IpPort {
        protocol: ProtocolType::TCP,
        ip_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
        port: 12345,
    };
    let request = OnionRequest {
        nonce: gen_nonce(),
        ip_port: target,
        temporary_pk: gen_keypair().0,
        payload: vec![13; 170],
    };
    let result = server.handle_packet(&pk, Packet::OnionRequest(request)).await;
    assert!(result.is_err());
}
/// After the (paused) clock is advanced past `TCP_PING_FREQUENCY`,
/// `send_pings` sends every connected client a `PingRequest`. Its id equals
/// the ping id the server then stores for that client.
#[tokio::test]
async fn tcp_send_pings_test() {
    let server = Server::new();

    let (first, first_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let first_pk = first.pk();
    server.insert(first).await.unwrap();

    let (second, second_rx) = create_random_client("1.2.3.5:12345".parse().unwrap());
    let second_pk = second.pk();
    server.insert(second).await.unwrap();

    let (third, third_rx) = create_random_client("1.2.3.6:12345".parse().unwrap());
    let third_pk = third.pk();
    server.insert(third).await.unwrap();

    tokio::time::pause();
    tokio::time::advance(TCP_PING_FREQUENCY + Duration::from_secs(1)).await;

    assert!(server.send_pings().await.is_ok());

    let (first_packet, _first_rx) = first_rx.into_future().await;
    let (second_packet, _second_rx) = second_rx.into_future().await;
    let (third_packet, _third_rx) = third_rx.into_future().await;

    // `send_pings` has finished, so the stored ping ids no longer change.
    let state = server.state.read().await;
    let clients = &state.connected_clients;
    assert_eq!(
        first_packet.unwrap(),
        Packet::PingRequest(PingRequest { ping_id: clients[&first_pk].ping_id() })
    );
    assert_eq!(
        second_packet.unwrap(),
        Packet::PingRequest(PingRequest { ping_id: clients[&second_pk].ping_id() })
    );
    assert_eq!(
        third_packet.unwrap(),
        Packet::PingRequest(PingRequest { ping_id: clients[&third_pk].ping_id() })
    );
}
/// `send_pings` evicts clients whose state predates a clock advance of
/// `TCP_PING_FREQUENCY + TCP_PING_TIMEOUT`. A client whose last pong was
/// recorded after the advance stays connected.
#[tokio::test]
async fn tcp_send_remove_timedouts() {
    let server = Server::new();

    let (stale_1, _stale_1_rx) = create_random_client("1.2.3.4:12345".parse().unwrap());
    let stale_pk_1 = stale_1.pk();
    server.insert(stale_1).await.unwrap();

    let (stale_2, _stale_2_rx) = create_random_client("1.2.3.5:12345".parse().unwrap());
    let stale_pk_2 = stale_2.pk();
    server.insert(stale_2).await.unwrap();

    let (mut fresh, _fresh_rx) = create_random_client("1.2.3.6:12345".parse().unwrap());
    let fresh_pk = fresh.pk();

    tokio::time::pause();
    tokio::time::advance(TCP_PING_FREQUENCY + TCP_PING_TIMEOUT + Duration::from_secs(1)).await;

    // Refresh only the third client's pong timestamp, then register it.
    fresh.set_last_pong_resp(clock_now());
    server.insert(fresh).await.unwrap();

    assert!(server.send_pings().await.is_ok());

    let state = server.state.read().await;
    assert!(!state.connected_clients.contains_key(&stale_pk_1));
    assert!(!state.connected_clients.contains_key(&stale_pk_2));
    assert!(state.connected_clients.contains_key(&fresh_pk));
}
}