pub mod hole_punching;
mod errors;
use failure::Fail;
use futures::{TryFutureExt, StreamExt, SinkExt, future};
use futures::channel::mpsc;
use tokio::sync::RwLock;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{iter, mem};
use crate::time::*;
use tox_crypto::*;
use tox_packet::dht::*;
use tox_packet::dht::packed_node::*;
use crate::dht::kbucket::*;
use crate::dht::ktree::*;
use crate::dht::forced_ktree::*;
use crate::dht::precomputed_cache::*;
use crate::onion::client::*;
use tox_packet::onion::*;
use crate::onion::onion_announce::*;
use crate::dht::request_queue::*;
use tox_packet::ip_port::*;
use crate::dht::dht_friend::*;
use crate::dht::dht_node::*;
use crate::dht::server::hole_punching::*;
use tox_packet::relay::OnionRequest;
use crate::net_crypto::*;
use crate::dht::ip_port::IsGlobal;
use crate::utils::*;
use crate::dht::server::errors::*;
use crate::io_tokio::*;
/// Sender half of the channel through which every outgoing DHT packet is
/// handed off together with its destination address (see `send_to`).
type Tx = mpsc::Sender<(Packet, SocketAddr)>;
/// Sender half of a channel carrying onion responses with an address.
/// NOTE(review): presumably drained by the TCP relay side; confirm against
/// the users of `tcp_onion_sink`.
type TcpOnionTx = mpsc::Sender<(InnerOnionResponse, SocketAddr)>;

/// Number of random `NodesRequest`s sent on consecutive main-loop ticks
/// before falling back to one every `NODES_REQ_INTERVAL`.
pub const MAX_BOOTSTRAP_TIMES: u32 = 5;
/// Period of `run_onion_key_refreshing`.
pub const ONION_REFRESH_KEY_INTERVAL: Duration = Duration::from_secs(7200);
/// Minimum delay between random `NodesRequest`s once bootstrapping is over.
pub const NODES_REQ_INTERVAL: Duration = Duration::from_secs(20);
/// Timeout passed to the `RequestQueue` that tracks outstanding ping ids.
pub const PING_TIMEOUT: Duration = Duration::from_secs(5);
/// Capacity of the `nodes_to_ping` bucket.
pub const MAX_TO_PING: u8 = 32;
/// Capacity of the `nodes_to_bootstrap` bucket.
pub const MAX_TO_BOOTSTRAP: u8 = 8;
/// Period of `run_pings_sending`.
pub const TIME_TO_PING: Duration = Duration::from_secs(2);
/// Period of `run_bootstrap_requests_sending`.
pub const BOOTSTRAP_INTERVAL: Duration = Duration::from_secs(1);
/// Number of random keys registered as friends in `Server::new`.
pub const FAKE_FRIENDS_NUMBER: usize = 2;
/// Size of the LRU cache of precomputed shared keys.
pub const PRECOMPUTED_LRU_CACHE_SIZE: usize = KBUCKET_DEFAULT_SIZE as usize * KBUCKET_MAX_ENTRIES as usize + KBUCKET_DEFAULT_SIZE as usize * (2 + 10);
/// Period of `run_main_loop`, in seconds.
const MAIN_LOOP_INTERVAL: u64 = 1;
/// Information stored by the server for answering `BootstrapInfo` packets.
/// NOTE(review): consumed by `handle_bootstrap_info`, which is dispatched from
/// `handle_packet` but defined outside this view; confirm details there.
#[derive(Clone)]
struct ServerBootstrapInfo {
    /// Version number to report. Presumably the software/daemon version
    /// (TODO confirm).
    version: u32,
    /// Callback that builds the "message of the day" bytes from the current
    /// server state.
    motd_cb: Arc<dyn Fn(&Server) -> Vec<u8> + Send + Sync>,
}
/// DHT server state.
///
/// The mutable parts are shared behind `Arc<RwLock<..>>`, so a cloned `Server`
/// operates on the same state. When several locks are needed, the methods in
/// this file take them in the order request_queue → nodes_to_bootstrap →
/// close_nodes → friends.
#[derive(Clone)]
pub struct Server {
    /// Our DHT secret key.
    pub sk: SecretKey,
    /// Our DHT public key.
    pub pk: PublicKey,
    /// Channel for outgoing packets and their destinations (used by `send_to`).
    pub tx: Tx,
    /// Optional sink notified with a friend's `PackedNode` when that friend is
    /// added to the close lists (see `try_add_to_close`).
    friend_saddr_sink: Arc<RwLock<Option<mpsc::UnboundedSender<PackedNode>>>>,
    /// Outstanding ping ids for PingRequest/NodesRequest, keyed by the
    /// recipient's public key.
    request_queue: Arc<RwLock<RequestQueue<PublicKey>>>,
    /// Nodes closest to our own key.
    pub close_nodes: Arc<RwLock<ForcedKtree>>,
    /// Symmetric key used to seal and open `OnionReturn` blocks.
    onion_symmetric_key: Arc<RwLock<secretbox::Key>>,
    /// State for handling onion announce and data requests.
    onion_announce: Arc<RwLock<OnionAnnounce>>,
    /// Random keys registered as friends in `new`.
    fake_friends_keys: Vec<PublicKey>,
    /// Tracked friends (including the fake ones), keyed by DHT public key.
    friends: Arc<RwLock<HashMap<PublicKey, DhtFriend>>>,
    /// Nodes to send a `NodesRequest` for our key on the next main-loop tick.
    nodes_to_bootstrap: Arc<RwLock<Kbucket<PackedNode>>>,
    /// Random `NodesRequest`s sent so far for our own key.
    random_requests_count: Arc<RwLock<u32>>,
    /// When the last random `NodesRequest` for our own key was sent.
    last_nodes_req_time: Arc<RwLock<Instant>>,
    /// Nodes to send a `PingRequest` to on the next `run_pings_sending` tick.
    nodes_to_ping: Arc<RwLock<Kbucket<PackedNode>>>,
    /// Bootstrap info (version / MOTD), if configured.
    bootstrap_info: Option<ServerBootstrapInfo>,
    /// Optional channel for onion responses. NOTE(review): not used in the
    /// visible part of the file; presumably for TCP relays.
    tcp_onion_sink: Option<TcpOnionTx>,
    /// Handler for cookie/handshake/crypto-data packets, if attached.
    net_crypto: Option<NetCrypto>,
    /// Handler for onion announce/data responses addressed to us, if attached.
    onion_client: Option<Box<OnionClient>>,
    /// Whether `LanDiscovery` packets are acted upon.
    lan_discovery_enabled: bool,
    /// Whether IPv6 nodes learned from `NodesResponse` are accepted.
    is_ipv6_enabled: bool,
    /// Nodes used for bootstrap requests while all close nodes are discarded.
    initial_bootstrap: Vec<PackedNode>,
    /// Cache of precomputed shared keys derived from `sk`.
    precomputed_keys: PrecomputedCache,
}
impl Server {
pub fn new(tx: Tx, pk: PublicKey, sk: SecretKey) -> Server {
debug!("Created new Server instance");
let fake_friends_keys = iter::repeat_with(|| gen_keypair().0)
.take(FAKE_FRIENDS_NUMBER)
.collect::<Vec<_>>();
let friends = fake_friends_keys.iter()
.map(|&pk| (pk, DhtFriend::new(pk)))
.collect();
let precomputed_keys = PrecomputedCache::new(sk.clone(), PRECOMPUTED_LRU_CACHE_SIZE);
Server {
sk,
pk,
tx,
friend_saddr_sink: Default::default(),
request_queue: Arc::new(RwLock::new(RequestQueue::new(PING_TIMEOUT))),
close_nodes: Arc::new(RwLock::new(ForcedKtree::new(&pk))),
onion_symmetric_key: Arc::new(RwLock::new(secretbox::gen_key())),
onion_announce: Arc::new(RwLock::new(OnionAnnounce::new(pk))),
fake_friends_keys,
friends: Arc::new(RwLock::new(friends)),
nodes_to_bootstrap: Arc::new(RwLock::new(Kbucket::new(MAX_TO_BOOTSTRAP))),
random_requests_count: Arc::new(RwLock::new(0)),
last_nodes_req_time: Arc::new(RwLock::new(clock_now())),
nodes_to_ping: Arc::new(RwLock::new(Kbucket::new(MAX_TO_PING))),
bootstrap_info: None,
tcp_onion_sink: None,
net_crypto: None,
onion_client: None,
lan_discovery_enabled: true,
is_ipv6_enabled: false,
initial_bootstrap: Vec::new(),
precomputed_keys,
}
}
pub fn enable_ipv6_mode(&mut self, enable: bool) {
    // Controls whether IPv6 nodes received in `NodesResponse` are accepted
    // (checked in `add_bootstrap_nodes`).
    let ipv6_flag = &mut self.is_ipv6_enabled;
    *ipv6_flag = enable;
}
pub fn is_ipv6_enabled(&self) -> bool {
self.is_ipv6_enabled
}
pub fn enable_lan_discovery(&mut self, enable: bool) {
    // When disabled, incoming `LanDiscovery` packets are ignored
    // (see `handle_lan_discovery`).
    let lan_flag = &mut self.lan_discovery_enabled;
    *lan_flag = enable;
}
pub async fn is_connected(&self) -> bool {
    // Returns `true` when at least one close node is not marked bad.
    let close_nodes = self.close_nodes.read().await;
    !close_nodes.iter().all(|node| node.is_bad())
}
fn get_closest_inner(
    close_nodes: &ForcedKtree,
    friends: &HashMap<PublicKey, DhtFriend>,
    base_pk: &PublicKey,
    count: u8,
    only_global: bool
) -> Kbucket<PackedNode> {
    // Collects the `count` nodes closest to `base_pk` from our own close list
    // and then from every friend's close list. With `only_global`, friend
    // nodes with non-global IPs are skipped; the same flag is passed to
    // `ForcedKtree::get_closest`.
    let mut closest = close_nodes.get_closest(base_pk, count, only_global);
    let friend_candidates = friends
        .values()
        .flat_map(|friend| friend.close_nodes.iter())
        .filter_map(|node| node.to_packed_node())
        .filter(|candidate| !only_global || IsGlobal::is_global(&candidate.saddr.ip()));
    for candidate in friend_candidates {
        closest.try_add(base_pk, candidate, true);
    }
    closest
}
pub async fn get_closest(&self, base_pk: &PublicKey, count: u8, only_global: bool) -> Kbucket<PackedNode> {
let close_nodes = self.close_nodes.read().await;
let friends = self.friends.read().await;
Server::get_closest_inner(&close_nodes, &friends, base_pk, count, only_global)
}
pub async fn add_friend(&self, friend_pk: PublicKey) {
    // Starts tracking `friend_pk` as a DHT friend. Does nothing if the key is
    // already tracked. The new friend's `nodes_to_bootstrap` is seeded with
    // up to 4 of the closest known nodes (global addresses only) to its key.
    //
    // Lock `close_nodes` before `friends`. `dht_main_loop`,
    // `try_add_to_close` and `get_closest` use this order, and taking
    // `friends` first could deadlock against them.
    let close_nodes = self.close_nodes.read().await;
    let mut friends = self.friends.write().await;
    if friends.contains_key(&friend_pk) {
        return;
    }
    let mut friend = DhtFriend::new(friend_pk);
    let closest = Server::get_closest_inner(&close_nodes, &friends, &friend.pk, 4, true);
    for &node in closest.iter() {
        friend.nodes_to_bootstrap.try_add(&friend.pk, node, true);
    }
    friends.insert(friend_pk, friend);
}
pub async fn remove_friend(&self, friend_pk: PublicKey) {
    // Stops tracking `friend_pk`. Unknown keys are ignored.
    let mut friends = self.friends.write().await;
    friends.remove(&friend_pk);
}
async fn dht_main_loop(&self) -> Result<(), RunError> {
fn send_random_request(last_nodes_req_time: &mut Instant, random_requests_count: &mut u32) -> bool {
if clock_elapsed(*last_nodes_req_time) > NODES_REQ_INTERVAL || *random_requests_count < MAX_BOOTSTRAP_TIMES {
*random_requests_count = random_requests_count.saturating_add(1);
*last_nodes_req_time = clock_now();
true
} else {
false
}
}
let mut request_queue = self.request_queue.write().await;
let mut nodes_to_bootstrap = self.nodes_to_bootstrap.write().await;
let mut close_nodes = self.close_nodes.write().await;
let mut friends = self.friends.write().await;
request_queue.clear_timed_out();
self.ping_nodes_to_bootstrap(&mut request_queue, &mut nodes_to_bootstrap, self.pk).await
.map_err(|e| e.context(RunErrorKind::SendTo))?;
self.ping_close_nodes(&mut request_queue, close_nodes.iter_mut(), self.pk).await
.map_err(|e| e.context(RunErrorKind::SendTo))?;
if send_random_request(&mut *self.last_nodes_req_time.write().await, &mut *self.random_requests_count.write().await) {
self.send_nodes_req_random(&mut request_queue, close_nodes.iter(), self.pk).await
.map_err(|e| e.context(RunErrorKind::SendTo))?;
}
for friend in friends.values_mut() {
self.ping_nodes_to_bootstrap(&mut request_queue, &mut friend.nodes_to_bootstrap, friend.pk).await
.map_err(|e| e.context(RunErrorKind::SendTo))?;
self.ping_close_nodes(&mut request_queue, friend.close_nodes.nodes.iter_mut(), friend.pk).await
.map_err(|e| e.context(RunErrorKind::SendTo))?;
if send_random_request(&mut friend.last_nodes_req_time, &mut friend.random_requests_count) {
self.send_nodes_req_random(&mut request_queue, friend.close_nodes.nodes.iter(), friend.pk).await
.map_err(|e| e.context(RunErrorKind::SendTo))?
}
}
self.send_nat_ping_req(&mut request_queue, &mut friends).await
.map_err(|e| RunError::from(e.context(RunErrorKind::SendTo)))
}
pub async fn run(&self) -> Result<(), RunError> {
    // Runs the periodic DHT tasks concurrently: ping sending, onion key
    // refreshing, the main DHT loop and bootstrap requests.
    //
    // Returns as soon as any task fails, with that task's error; the other
    // tasks are dropped. `futures::join!` would instead wait for all four
    // to finish. `run_onion_key_refreshing` loops over a tokio interval that
    // never ends, so with `join!` an error from another task was never
    // returned to the caller.
    futures::try_join!(
        self.run_pings_sending(),
        self.run_onion_key_refreshing(),
        self.run_main_loop(),
        self.run_bootstrap_requests_sending(),
    )?;
    Ok(())
}
pub fn add_initial_bootstrap(&mut self, pn: PackedNode) {
    // Registers a node for `send_bootstrap_requests`, which uses these
    // nodes while every close node is discarded.
    Vec::push(&mut self.initial_bootstrap, pn);
}
async fn run_bootstrap_requests_sending(&self) -> Result<(), RunError> {
let interval = BOOTSTRAP_INTERVAL;
let mut wakeups = tokio::time::interval(interval);
while wakeups.next().await.is_some() {
trace!("Bootstrap wake up");
let send_res = tokio::time::timeout(
interval,
self.send_bootstrap_requests(),
).await;
let res =
match send_res {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) =>
Err(e.context(RunErrorKind::SendTo).into()),
Err(e) =>
Err(e.context(RunErrorKind::SendTo).into()),
};
if let Err(ref e) = res {
warn!("Failed to send initial bootstrap packets: {}", e);
return res
}
}
Ok(())
}
async fn send_bootstrap_requests(&self) -> Result<(), mpsc::SendError> {
    // While every close node is discarded, sends a NodesRequest for our own
    // key to all addresses of the close nodes and to every initial
    // bootstrap node. Otherwise it does nothing.
    let mut request_queue = self.request_queue.write().await;
    let close_nodes = self.close_nodes.read().await;
    if close_nodes.is_all_discarded() {
        let known_addrs = close_nodes.iter().flat_map(|node| node.to_all_packed_nodes());
        let targets = known_addrs.chain(self.initial_bootstrap.iter().cloned());
        for target in targets {
            self.send_nodes_req(&target, &mut request_queue, self.pk).await?;
        }
    }
    Ok(())
}
async fn run_main_loop(&self) -> Result<(), RunError> {
let interval = Duration::from_secs(MAIN_LOOP_INTERVAL);
let mut wakeups = tokio::time::interval(interval);
while wakeups.next().await.is_some() {
trace!("DHT server wake up");
let loop_res =
tokio::time::timeout(interval, self.dht_main_loop()).await;
let res = match loop_res {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) =>
Err(e.context(RunErrorKind::SendTo).into()),
Err(e) =>
Err(e.context(RunErrorKind::SendTo).into()),
};
if let Err(ref e) = res {
warn!("Failed to send DHT periodical packets: {}", e);
return res
}
}
Ok(())
}
async fn run_onion_key_refreshing(&self) -> Result<(), RunError> {
    // Calls `refresh_onion_key` every `ONION_REFRESH_KEY_INTERVAL`. The
    // first refresh happens one full interval after start.
    let period = ONION_REFRESH_KEY_INTERVAL;
    let first_tick = tokio::time::Instant::now() + period;
    let mut ticks = tokio::time::interval_at(first_tick, period);
    while let Some(_tick) = ticks.next().await {
        trace!("Refreshing onion key");
        self.refresh_onion_key().await;
    }
    Ok(())
}
async fn run_pings_sending(&self) -> Result<(), RunError> {
    // Calls `send_pings` every `TIME_TO_PING`, starting one interval after
    // start. Stops at the first send error.
    let period = TIME_TO_PING;
    let first_tick = tokio::time::Instant::now() + period;
    let mut ticks = tokio::time::interval_at(first_tick, period);
    while let Some(_tick) = ticks.next().await {
        if let Err(send_error) = self.send_pings().await {
            return Err(send_error.context(RunErrorKind::SendTo).into());
        }
    }
    Ok(())
}
async fn send_pings(&self) -> Result<(), mpsc::SendError> {
    // Takes every node queued in `nodes_to_ping`, leaving an empty bucket in
    // its place, and sends each one a PingRequest.
    let pending = {
        let mut nodes_to_ping = self.nodes_to_ping.write().await;
        mem::replace(&mut *nodes_to_ping, Kbucket::<PackedNode>::new(MAX_TO_PING))
    };
    if !pending.is_empty() {
        let mut request_queue = self.request_queue.write().await;
        for node in pending.iter() {
            self.send_ping_req(node, &mut request_queue).await?;
        }
    }
    Ok(())
}
async fn ping_add(&self, node: &PackedNode) -> Result<(), mpsc::SendError> {
    // Queues `node` to be pinged on the next `send_pings` tick. If `node` is
    // a friend whose address is not yet known, it is pinged right away.
    // Nothing happens if `close_nodes` has no room for the node. Errors come
    // only from sending the immediate PingRequest.
    //
    // Each guard below is a temporary, dropped at the end of its statement,
    // so no lock is held while waiting for `request_queue`. Holding
    // `close_nodes`/`friends` here while `dht_main_loop` holds
    // `request_queue` and waits for `close_nodes` would deadlock.
    let can_add = self.close_nodes.read().await.can_add(node);
    if !can_add {
        return Ok(());
    }
    let is_unknown_friend = self.friends.read().await
        .get(&node.pk)
        .map_or(false, |friend| !friend.is_addr_known());
    if is_unknown_friend {
        let mut request_queue = self.request_queue.write().await;
        return self.send_ping_req(node, &mut request_queue).await;
    }
    self.nodes_to_ping.write().await.try_add(&self.pk, *node, true);
    Ok(())
}
async fn ping_nodes_to_bootstrap(&self, request_queue: &mut RequestQueue<PublicKey>, nodes_to_bootstrap: &mut Kbucket<PackedNode>, pk: PublicKey)
    -> Result<(), mpsc::SendError> {
    // Empties `nodes_to_bootstrap` (keeping its capacity) and sends each
    // removed node a NodesRequest searching for `pk`.
    let fresh_bucket = Kbucket::new(nodes_to_bootstrap.capacity() as u8);
    let drained = mem::replace(nodes_to_bootstrap, fresh_bucket);
    for node in drained.iter() {
        self.send_nodes_req(node, request_queue, pk).await?;
    }
    Ok(())
}
async fn ping_close_nodes<'a, T>(&self, request_queue: &mut RequestQueue<PublicKey>, nodes: T, pk: PublicKey)
    -> Result<(), mpsc::SendError>
    where T: Iterator<Item = &'a mut DhtNode>
{
    // For each node, asks its IPv4 and then its IPv6 association for an
    // address due a ping, and sends a NodesRequest searching for `pk` to
    // each address returned.
    for node in nodes {
        let v4_target = node.assoc4
            .ping_addr()
            .map(|addr| PackedNode::new(addr.into(), &node.pk));
        let v6_target = node.assoc6
            .ping_addr()
            .map(|addr| PackedNode::new(addr.into(), &node.pk));
        for target in v4_target.into_iter().chain(v6_target.into_iter()) {
            self.send_nodes_req(&target, request_queue, pk).await?;
        }
    }
    Ok(())
}
async fn send_nodes_req_random<'a, T>(&self, request_queue: &mut RequestQueue<PublicKey>, nodes: T, pk: PublicKey)
    -> Result<(), mpsc::SendError>
    where T: Iterator<Item = &'a DhtNode>
{
    // Sends one NodesRequest for `pk` to a randomly chosen address of a
    // non-bad node. The index is drawn as r1 - rand(0..=r1), which favours
    // lower indices. Does nothing if there are no good nodes.
    let candidates = nodes
        .filter(|node| !node.is_bad())
        .flat_map(|node| node.to_all_packed_nodes())
        .collect::<Vec<_>>();
    if candidates.is_empty() {
        return Ok(());
    }
    let upper = random_limit_usize(candidates.len());
    let chosen = if upper == 0 {
        0
    } else {
        upper - random_limit_usize(upper + 1)
    };
    self.send_nodes_req(&candidates[chosen], request_queue, pk).await
}
pub async fn ping_node(&self, node: &PackedNode) -> Result<(), PingError> {
    // Sends `node` a NodesRequest for our own key. This is the public way to
    // probe a node.
    let mut request_queue = self.request_queue.write().await;
    let sent = self.send_nodes_req(node, &mut request_queue, self.pk).await;
    sent.map_err(|e| e.context(PingErrorKind::SendTo).into())
}
async fn send_ping_req(&self, node: &PackedNode, request_queue: &mut RequestQueue<PublicKey>)
    -> Result<(), mpsc::SendError> {
    // Registers a new ping id for `node.pk` and sends `node` a PingRequest
    // carrying that id.
    let id = request_queue.new_ping_id(node.pk);
    let precomputed_key = self.precomputed_keys.get(node.pk).await;
    let request = PingRequest::new(&precomputed_key, &self.pk, &PingRequestPayload { id });
    self.send_to(node.saddr, Packet::PingRequest(request)).await
}
async fn send_nodes_req(&self, node: &PackedNode, request_queue: &mut RequestQueue<PublicKey>, search_pk: PublicKey)
    -> Result<(), mpsc::SendError> {
    // Sends `node` a NodesRequest searching for `search_pk`, with a freshly
    // registered ping id. Requests addressed to ourselves are skipped.
    if self.pk == node.pk {
        trace!("Attempt to send NodesRequest to ourselves.");
        return Ok(());
    }
    let id = request_queue.new_ping_id(node.pk);
    let precomputed_key = self.precomputed_keys.get(node.pk).await;
    let request = NodesRequest::new(
        &precomputed_key,
        &self.pk,
        &NodesRequestPayload { pk: search_pk, id },
    );
    self.send_to(node.saddr, Packet::NodesRequest(request)).await
}
async fn send_nat_ping_req(&self, request_queue: &mut RequestQueue<PublicKey>, friends: &mut HashMap<PublicKey, DhtFriend>)
    -> Result<(), mpsc::SendError> {
    // For every friend whose address is unknown but for which at least half
    // of FRIEND_CLOSE_NODES_COUNT nodes reported an address: attempt hole
    // punching and, at most once per PUNCH_INTERVAL, send a NatPingRequest
    // through the friend's close nodes.
    for friend in friends.values_mut().filter(|friend| !friend.is_addr_known()) {
        let returned_addrs = friend.get_returned_addrs();
        if returned_addrs.len() < FRIEND_CLOSE_NODES_COUNT as usize / 2 {
            continue;
        }
        self.punch_holes(request_queue, friend, &returned_addrs).await?;
        let ping_due = match friend.hole_punch.last_send_ping_time {
            None => true,
            Some(sent_at) => clock_elapsed(sent_at) >= PUNCH_INTERVAL,
        };
        if !ping_due {
            continue;
        }
        friend.hole_punch.last_send_ping_time = Some(clock_now());
        let payload = DhtRequestPayload::NatPingRequest(NatPingRequest {
            id: friend.hole_punch.ping_id,
        });
        let precomputed_key = self.precomputed_keys.get(friend.pk).await;
        let request = DhtRequest::new(&precomputed_key, &friend.pk, &self.pk, &payload);
        self.send_nat_ping_req_inner(friend, request).await?;
    }
    Ok(())
}
async fn punch_holes(&self, request_queue: &mut RequestQueue<PublicKey>, friend: &mut DhtFriend, returned_addrs: &[SocketAddr])
    -> Result<(), mpsc::SendError> {
    // Asks the friend's hole-punch state which addresses to try next, given
    // the addresses other nodes reported, and sends the same PingRequest
    // (one new ping id) to all of them in one batch.
    let punch_addrs = friend.hole_punch.next_punch_addrs(returned_addrs);
    let mut tx = self.tx.clone();
    let ping_id = request_queue.new_ping_id(friend.pk);
    let precomputed_key = self.precomputed_keys.get(friend.pk).await;
    let ping_req = Packet::PingRequest(PingRequest::new(
        &precomputed_key,
        &self.pk,
        &PingRequestPayload { id: ping_id },
    ));
    let mut batch = Vec::new();
    for addr in punch_addrs {
        batch.push((ping_req.clone(), addr));
    }
    let mut stream = futures::stream::iter(batch.into_iter().map(Ok));
    tx.send_all(&mut stream).await
}
async fn send_nat_ping_req_inner(&self, friend: &DhtFriend, nat_ping_req_packet: DhtRequest)
    -> Result<(), mpsc::SendError> {
    // Sends the NatPingRequest to every friend close node that has a known
    // address; those nodes relay it to the friend.
    let packet = Packet::DhtRequest(nat_ping_req_packet);
    let relays = friend.close_nodes.nodes
        .iter()
        .filter_map(|node| node.to_packed_node());
    for relay in relays {
        self.send_to(relay.saddr, packet.clone()).await?;
    }
    Ok(())
}
pub async fn handle_packet(&self, packet: Packet, addr: SocketAddr) -> Result<(), HandlePacketError> {
    // Dispatches an incoming UDP packet from `addr` to its handler.
    match packet {
        // Core DHT traffic.
        Packet::PingRequest(p) => self.handle_ping_req(p, addr).await,
        Packet::PingResponse(p) => self.handle_ping_resp(p, addr).await,
        Packet::NodesRequest(p) => self.handle_nodes_req(p, addr).await,
        Packet::NodesResponse(p) => self.handle_nodes_resp(p, addr).await,
        Packet::DhtRequest(p) => self.handle_dht_req(p, addr).await,
        Packet::LanDiscovery(p) => self.handle_lan_discovery(&p, addr).await,
        Packet::BootstrapInfo(p) => self.handle_bootstrap_info(&p, addr).await,
        // Forwarded to `net_crypto`.
        Packet::CookieRequest(p) => self.handle_cookie_request(&p, addr).await,
        Packet::CookieResponse(p) => self.handle_cookie_response(&p, addr).await,
        Packet::CryptoHandshake(p) => self.handle_crypto_handshake(&p, addr).await,
        Packet::CryptoData(p) => self.handle_crypto_data(&p, addr).await,
        // Onion requests travelling towards their destination.
        Packet::OnionRequest0(p) => self.handle_onion_request_0(p, addr).await,
        Packet::OnionRequest1(p) => self.handle_onion_request_1(p, addr).await,
        Packet::OnionRequest2(p) => self.handle_onion_request_2(p, addr).await,
        Packet::OnionAnnounceRequest(p) => self.handle_onion_announce_request(p, addr).await,
        Packet::OnionDataRequest(p) => self.handle_onion_data_request(p).await,
        // Onion responses travelling back along the path.
        Packet::OnionResponse3(p) => self.handle_onion_response_3(p).await,
        Packet::OnionResponse2(p) => self.handle_onion_response_2(p).await,
        Packet::OnionResponse1(p) => self.handle_onion_response_1(p).await,
        // Forwarded to `onion_client`.
        Packet::OnionAnnounceResponse(p) => self.handle_onion_announce_response(&p, addr).await,
        Packet::OnionDataResponse(p) => self.handle_onion_data_response(&p).await,
    }
}
async fn send_to(&self, addr: SocketAddr, packet: Packet) -> Result<(), mpsc::SendError> {
    // Queues `packet` for `addr` on the outgoing channel, waiting if the
    // channel is full.
    let mut sink = self.tx.clone();
    sink.send((packet, addr)).await
}
async fn handle_ping_req(&self, packet: PingRequest, addr: SocketAddr)
-> Result<(), HandlePacketError> {
let precomputed_key = self.precomputed_keys.get(packet.pk).await;
let payload = match packet.get_payload(&precomputed_key) {
Err(e) => return future::err(e.context(HandlePacketErrorKind::GetPayload).into()).await,
Ok(payload) => payload,
};
let resp_payload = PingResponsePayload {
id: payload.id,
};
let ping_resp = Packet::PingResponse(PingResponse::new(
&precomputed_key,
&self.pk,
&resp_payload,
));
future::try_join(
self.ping_add(&PackedNode::new(addr, &packet.pk)),
self.send_to(addr, ping_resp),
)
.map_ok(drop)
.map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
.await
}
async fn check_ping_id(&self, ping_id: u64, packet_pk: &PublicKey) -> bool {
    // Returns whether `ping_id` is an outstanding request to `packet_pk`.
    // `RequestQueue::check_ping_id` gets `&mut` access and presumably
    // consumes the matching entry (TODO confirm).
    self.request_queue
        .write()
        .await
        .check_ping_id(ping_id, |pk| *packet_pk == *pk)
        .is_some()
}
async fn try_add_to_close(&self, payload_id: u64, node: PackedNode, check_ping_id: bool) -> Result<(), HandlePacketError> {
    // Offers `node` to our close list and to every friend's close list. If
    // `check_ping_id` is set, `payload_id` must match an outstanding request
    // first. If the node is itself a friend, it is also sent to
    // `friend_saddr_sink`.
    let ping_id_ok = !check_ping_id || self.check_ping_id(payload_id, &node.pk).await;
    if !ping_id_ok {
        return Err(HandlePacketError::from(HandlePacketErrorKind::PingIdMismatch));
    }
    let mut close_nodes = self.close_nodes.write().await;
    let mut friends = self.friends.write().await;
    close_nodes.try_add(node);
    friends.values_mut().for_each(|friend| {
        friend.try_add_to_close(node);
    });
    if !friends.contains_key(&node.pk) {
        return Ok(());
    }
    let sink = self.friend_saddr_sink.read().await.clone();
    maybe_send_unbounded(sink, node).await
        .map_err(|e| e.context(HandlePacketErrorKind::FriendSaddr).into())
}
async fn handle_ping_resp(&self, packet: PingResponse, addr: SocketAddr) -> Result<(), HandlePacketError> {
    // Accepts a PingResponse: the id must be non-zero and match a request we
    // sent. On success the sender is offered to the close lists.
    let precomputed_key = self.precomputed_keys.get(packet.pk).await;
    let response = packet.get_payload(&precomputed_key)
        .map_err(|e| HandlePacketError::from(e.context(HandlePacketErrorKind::GetPayload)))?;
    if response.id == 0u64 {
        return Err(HandlePacketError::from(HandlePacketErrorKind::ZeroPingId));
    }
    let sender = PackedNode::new(addr, &packet.pk);
    self.try_add_to_close(response.id, sender, true).await
}
async fn handle_nodes_req(&self, packet: NodesRequest, addr: SocketAddr)
-> Result<(), HandlePacketError> {
let precomputed_key = self.precomputed_keys.get(packet.pk).await;
let payload = match packet.get_payload(&precomputed_key) {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
let close_nodes = self.get_closest(&payload.pk, 4, IsGlobal::is_global(&addr.ip())).await;
let resp_payload = NodesResponsePayload {
nodes: close_nodes.into(),
id: payload.id,
};
let nodes_resp = Packet::NodesResponse(NodesResponse::new(
&precomputed_key,
&self.pk,
&resp_payload,
));
future::try_join(
self.ping_add(&PackedNode::new(addr, &packet.pk)),
self.send_to(addr, nodes_resp),
)
.map_ok(drop)
.map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
.await
}
async fn add_bootstrap_nodes(&self, nodes: &[PackedNode], packet_pk: &PublicKey) {
    // Processes the nodes from a NodesResponse sent by `packet_pk`.
    // IPv6 nodes are skipped unless IPv6 is enabled. Every other node is
    // queued for bootstrapping for us and/or any friend whose close list
    // could take it, and its address is recorded as "returned" by the
    // sender (see `update_returned_addr`).
    //
    // Locks are taken in the same order as `dht_main_loop`
    // (nodes_to_bootstrap -> close_nodes -> friends). Taking `close_nodes`
    // before `nodes_to_bootstrap` could deadlock with the main loop.
    let mut nodes_to_bootstrap = self.nodes_to_bootstrap.write().await;
    let mut close_nodes = self.close_nodes.write().await;
    let mut friends = self.friends.write().await;
    for &node in nodes {
        if !self.is_ipv6_enabled && node.saddr.is_ipv6() {
            continue;
        }
        if close_nodes.can_add(&node) {
            nodes_to_bootstrap.try_add(&self.pk, node, true);
        }
        for friend in friends.values_mut() {
            if friend.can_add_to_close(&node) {
                friend.nodes_to_bootstrap.try_add(&friend.pk, node, true);
            }
        }
        self.update_returned_addr(&node, packet_pk, &mut close_nodes, &mut friends);
    }
}
async fn handle_nodes_resp(&self, packet: NodesResponse, addr: SocketAddr)
-> Result<(), HandlePacketError> {
let precomputed_key = self.precomputed_keys.get(packet.pk).await;
let payload = match packet.get_payload(&precomputed_key) {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
if self.check_ping_id(payload.id, &packet.pk).await {
trace!("Received nodes with NodesResponse from {}: {:?}", addr, payload.nodes);
self.try_add_to_close(payload.id, PackedNode::new(addr, &packet.pk), false).await?;
self.add_bootstrap_nodes(&payload.nodes, &packet.pk).await;
} else {
trace!("NodesResponse.ping_id does not match");
}
Ok(())
}
fn update_returned_addr(&self, node: &PackedNode, packet_pk: &PublicKey, close_nodes: &mut ForcedKtree, friends: &mut HashMap<PublicKey, DhtFriend>) {
    // Records the address under which `packet_pk` (the responder) sees
    // `node`. The address goes into our close list when `node` is us, and
    // into the friend's close list when `node` is a friend.
    let reported_addr = node.saddr;
    if self.pk == node.pk {
        if let Some(responder) = close_nodes.get_node_mut(packet_pk) {
            responder.update_returned_addr(reported_addr);
        }
    }
    let friend = match friends.get_mut(&node.pk) {
        Some(friend) => friend,
        None => return,
    };
    if let Some(responder) = friend.close_nodes.get_node_mut(&friend.pk, packet_pk) {
        responder.update_returned_addr(reported_addr);
    }
}
async fn handle_cookie_request(&self, packet: &CookieRequest, addr: SocketAddr)
-> Result<(), HandlePacketError> {
if let Some(ref net_crypto) = self.net_crypto {
net_crypto.handle_udp_cookie_request(packet, addr).await
.map_err(|e| e.context(HandlePacketErrorKind::HandleNetCrypto).into())
} else {
Err(
HandlePacketError::from(HandlePacketErrorKind::NetCrypto)
)
}
}
async fn handle_cookie_response(&self, packet: &CookieResponse, addr: SocketAddr)
-> Result<(), HandlePacketError> {
if let Some(ref net_crypto) = self.net_crypto {
net_crypto.handle_udp_cookie_response(packet, addr).await
.map_err(|e| e.context(HandlePacketErrorKind::HandleNetCrypto).into())
} else {
Err(HandlePacketError::from(HandlePacketErrorKind::NetCrypto))
}
}
async fn handle_crypto_handshake(&self, packet: &CryptoHandshake, addr: SocketAddr)
-> Result<(), HandlePacketError> {
if let Some(ref net_crypto) = self.net_crypto {
net_crypto.handle_udp_crypto_handshake(packet, addr).await
.map_err(|e| e.context(HandlePacketErrorKind::HandleNetCrypto).into())
} else {
Err(HandlePacketError::from(HandlePacketErrorKind::NetCrypto))
}
}
async fn handle_crypto_data(&self, packet: &CryptoData, addr: SocketAddr) -> Result<(), HandlePacketError> {
if let Some(ref net_crypto) = self.net_crypto {
net_crypto.handle_udp_crypto_data(packet, addr).await
.map_err(|e| e.context(HandlePacketErrorKind::HandleNetCrypto).into())
} else {
Err(HandlePacketError::from(HandlePacketErrorKind::NetCrypto))
}
}
async fn handle_onion_data_response(&self, packet: &OnionDataResponse) -> Result<(), HandlePacketError> {
if let Some(ref onion_client) = self.onion_client {
onion_client.handle_data_response(packet).await
.map_err(|e| e.context(HandlePacketErrorKind::HandleOnionClient).into())
} else {
Err(HandlePacketError::from(HandlePacketErrorKind::OnionClient))
}
}
async fn handle_onion_announce_response(&self, packet: &OnionAnnounceResponse, addr: SocketAddr) -> Result<(), HandlePacketError> {
if let Some(ref onion_client) = self.onion_client {
onion_client.handle_announce_response(packet, IsGlobal::is_global(&addr.ip())).await
.map_err(|e| e.context(HandlePacketErrorKind::HandleOnionClient).into())
} else {
Err(HandlePacketError::from(HandlePacketErrorKind::OnionClient))
}
}
async fn handle_dht_req(&self, packet: DhtRequest, addr: SocketAddr)
    -> Result<(), HandlePacketError> {
    // Handles a DhtRequest addressed to us locally and relays any other one
    // towards its receiver.
    let addressed_to_us = packet.rpk == self.pk;
    if !addressed_to_us {
        return self.handle_dht_req_for_others(packet).await;
    }
    self.handle_dht_req_for_us(packet, addr).await
}
async fn handle_dht_req_for_us(&self, packet: DhtRequest, addr: SocketAddr) -> Result<(), HandlePacketError> {
let precomputed_key = self.precomputed_keys.get(packet.spk).await;
let payload = packet.get_payload(&precomputed_key);
let payload = match payload {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
match payload {
DhtRequestPayload::NatPingRequest(nat_payload) => {
debug!("Received nat ping request");
self.handle_nat_ping_req(nat_payload, packet.spk, addr).await
}
DhtRequestPayload::NatPingResponse(nat_payload) => {
debug!("Received nat ping response");
self.handle_nat_ping_resp(nat_payload, &packet.spk).await
}
DhtRequestPayload::DhtPkAnnounce(_dht_pk_payload) => {
debug!("Received DHT PublicKey Announce");
Ok(())
}
DhtRequestPayload::HardeningRequest(_dht_pk_payload) => {
debug!("Received Hardening request");
Ok(())
}
DhtRequestPayload::HardeningResponse(_dht_pk_payload) => {
debug!("Received Hardening response");
Ok(())
}
}
}
async fn handle_dht_req_for_others(&self, packet: DhtRequest) -> Result<(), HandlePacketError> {
    // Relays the packet unchanged to its receiver if that receiver is in our
    // close list with a known address. Otherwise it is silently dropped.
    let close_nodes = self.close_nodes.read().await;
    let receiver = close_nodes
        .get_node(&packet.rpk)
        .and_then(|node| node.to_packed_node());
    match receiver {
        Some(node) => self.send_to(node.saddr, Packet::DhtRequest(packet)).await
            .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into()),
        None => Ok(()),
    }
}
async fn set_friend_hole_punch_last_recv_ping_time(&self, spk: &PublicKey, ping_time: Instant)
    -> Result<(), HandlePacketError> {
    // Stores when the last NAT ping from friend `spk` arrived. Returns
    // `NoFriend` if `spk` is not tracked.
    let mut friends = self.friends.write().await;
    let friend = friends.get_mut(spk)
        .ok_or_else(|| HandlePacketError::from(HandlePacketErrorKind::NoFriend))?;
    friend.hole_punch.last_recv_ping_time = ping_time;
    Ok(())
}
async fn handle_nat_ping_req(&self, payload: NatPingRequest, spk: PublicKey, addr: SocketAddr) -> Result<(), HandlePacketError> {
    // Records the receive time for friend `spk` (failing if it is not a
    // friend) and replies to `addr` with a NatPingResponse carrying the same
    // id.
    self.set_friend_hole_punch_last_recv_ping_time(&spk, clock_now()).await?;
    let reply_payload = DhtRequestPayload::NatPingResponse(NatPingResponse { id: payload.id });
    let precomputed_key = self.precomputed_keys.get(spk).await;
    let reply = DhtRequest::new(&precomputed_key, &spk, &self.pk, &reply_payload);
    self.send_to(addr, Packet::DhtRequest(reply)).await
        .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
async fn handle_nat_ping_resp(&self, payload: NatPingResponse, spk: &PublicKey) -> Result<(), HandlePacketError> {
    // Validates a NatPingResponse from friend `spk`. On a matching, non-zero
    // id, generates a fresh ping id and clears `is_punching_done` so hole
    // punching is attempted again.
    if payload.id == 0 {
        return Err(HandlePacketError::from(HandlePacketErrorKind::ZeroPingId));
    }
    let mut friends = self.friends.write().await;
    let friend = friends.get_mut(spk)
        .ok_or_else(|| HandlePacketError::from(HandlePacketErrorKind::NoFriend))?;
    if friend.hole_punch.ping_id != payload.id {
        return Err(HandlePacketError::from(HandlePacketErrorKind::PingIdMismatch));
    }
    friend.hole_punch.ping_id = gen_ping_id();
    friend.hole_punch.is_punching_done = false;
    Ok(())
}
async fn handle_lan_discovery(&self, packet: &LanDiscovery, addr: SocketAddr)
    -> Result<(), HandlePacketError> {
    // If LAN discovery is enabled and the packet is not our own broadcast,
    // sends the LAN peer a NodesRequest for our key.
    if !self.lan_discovery_enabled || packet.pk == self.pk {
        return Ok(());
    }
    let peer = PackedNode::new(addr, &packet.pk);
    let mut request_queue = self.request_queue.write().await;
    self.send_nodes_req(&peer, &mut request_queue, self.pk)
        .await
        .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
async fn handle_onion_request_0(&self, packet: OnionRequest0, addr: SocketAddr) -> Result<(), HandlePacketError> {
let onion_symmetric_key = self.onion_symmetric_key.read().await;
let onion_return = OnionReturn::new(
&onion_symmetric_key,
&IpPort::from_udp_saddr(addr),
None, );
let shared_secret = self.precomputed_keys.get(packet.temporary_pk).await;
let payload = packet.get_payload(&shared_secret);
let payload = match payload {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
let next_packet = Packet::OnionRequest1(OnionRequest1 {
nonce: packet.nonce,
temporary_pk: payload.temporary_pk,
payload: payload.inner,
onion_return,
});
self.send_to(payload.ip_port.to_saddr(), next_packet).await
.map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
async fn handle_onion_request_1(&self, packet: OnionRequest1, addr: SocketAddr) -> Result<(), HandlePacketError> {
let onion_symmetric_key = self.onion_symmetric_key.read().await;
let onion_return = OnionReturn::new(
&onion_symmetric_key,
&IpPort::from_udp_saddr(addr),
Some(&packet.onion_return)
);
let shared_secret = self.precomputed_keys.get(packet.temporary_pk).await;
let payload = packet.get_payload(&shared_secret);
let payload = match payload {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
let next_packet = Packet::OnionRequest2(OnionRequest2 {
nonce: packet.nonce,
temporary_pk: payload.temporary_pk,
payload: payload.inner,
onion_return,
});
self.send_to(payload.ip_port.to_saddr(), next_packet).await
.map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
async fn handle_onion_request_2(&self, packet: OnionRequest2, addr: SocketAddr) -> Result<(), HandlePacketError> {
let onion_symmetric_key = self.onion_symmetric_key.read().await;
let onion_return = OnionReturn::new(
&onion_symmetric_key,
&IpPort::from_udp_saddr(addr),
Some(&packet.onion_return),
);
let shared_secret = self.precomputed_keys.get(packet.temporary_pk).await;
let payload = packet.get_payload(&shared_secret);
let payload = match payload {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
let next_packet = match payload.inner {
InnerOnionRequest::InnerOnionAnnounceRequest(inner) => Packet::OnionAnnounceRequest(OnionAnnounceRequest {
inner,
onion_return,
}),
InnerOnionRequest::InnerOnionDataRequest(inner) => Packet::OnionDataRequest(OnionDataRequest {
inner,
onion_return,
}),
};
self.send_to(payload.ip_port.to_saddr(), next_packet).await
.map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
async fn get_onion_announce_ping_id_or_pk(
    &self,
    payload: &OnionAnnounceRequestPayload,
    packet: &OnionAnnounceRequest,
    addr: SocketAddr
) -> (AnnounceStatus, sha256::Digest) {
    // Hands the announce request to `onion_announce` (under a write lock)
    // and returns its status with the accompanying ping id or key digest.
    let sender_pk = packet.inner.pk;
    let onion_return = packet.onion_return.clone();
    let mut onion_announce = self.onion_announce.write().await;
    onion_announce.handle_onion_announce_request(payload, sender_pk, onion_return, addr)
}
async fn handle_onion_announce_request(&self, packet: OnionAnnounceRequest, addr: SocketAddr) -> Result<(), HandlePacketError> {
let shared_secret = self.precomputed_keys.get(packet.inner.pk).await;
let payload = match packet.inner.get_payload(&shared_secret) {
Err(e) => return Err(e.context(HandlePacketErrorKind::GetPayload).into()),
Ok(payload) => payload,
};
let (announce_status, ping_id_or_pk) = self.get_onion_announce_ping_id_or_pk(
&payload,
&packet,
addr,
).await;
let close_nodes = self.get_closest(&payload.search_pk, 4, IsGlobal::is_global(&addr.ip())).await;
let response_payload = OnionAnnounceResponsePayload {
announce_status,
ping_id_or_pk,
nodes: close_nodes.into(),
};
let response = OnionAnnounceResponse::new(&shared_secret, payload.sendback_data, &response_payload);
self.send_to(addr, Packet::OnionResponse3(OnionResponse3 {
onion_return: packet.onion_return,
payload: InnerOnionResponse::OnionAnnounceResponse(response),
}))
.await
.map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
/// Handles an `OnionDataRequest`: the onion announce state resolves it into
/// an `OnionResponse3` and a destination address, and the response is
/// forwarded there.
///
/// The read lock on the onion announce state is released before the send is
/// awaited. Otherwise a back-pressured outgoing channel would keep the lock
/// held, and tokio's fair `RwLock` would then also stall announce writers and
/// every reader queued behind them.
///
/// # Errors
/// `OnionOrNetCrypto` when the announce state rejects the request, `SendTo`
/// when forwarding fails.
async fn handle_onion_data_request(&self, packet: OnionDataRequest)
    -> Result<(), HandlePacketError> {
    // The read guard is a temporary of this statement and is dropped at its end.
    let handled = self.onion_announce.read().await.handle_data_request(packet);
    let (response, addr) = handled
        .map_err(|e| e.context(HandlePacketErrorKind::OnionOrNetCrypto))?;
    self.send_to(addr, Packet::OnionResponse3(response)).await
        .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
/// Handles an `OnionResponse3`. It peels one layer of the onion return path
/// with this node's onion symmetric key and forwards the payload to the next
/// hop as an `OnionResponse2`.
///
/// A return path that fails to decrypt is logged at trace level and dropped
/// without reporting an error. The key's read lock is released before
/// forwarding, so a back-pressured send cannot hold up `refresh_onion_key`
/// (and, through the fair `RwLock`, the other onion handlers).
///
/// # Errors
/// `OnionResponseNext` when the decrypted return path has no next hop,
/// `SendTo` when forwarding fails.
async fn handle_onion_response_3(&self, packet: OnionResponse3) -> Result<(), HandlePacketError> {
    let onion_symmetric_key = self.onion_symmetric_key.read().await;
    let decrypted = packet.onion_return.get_payload(&onion_symmetric_key);
    drop(onion_symmetric_key);

    let (ip_port, next_onion_return) = match decrypted {
        Ok((ip_port, Some(next_onion_return))) => (ip_port, next_onion_return),
        Ok((_, None)) => return Err(HandlePacketErrorKind::OnionResponseNext.into()),
        Err(e) => {
            trace!("Failed to decrypt onion_return from OnionResponse3: {}", e);
            return Ok(());
        },
    };

    let next_packet = Packet::OnionResponse2(OnionResponse2 {
        onion_return: next_onion_return,
        payload: packet.payload
    });
    self.send_to(ip_port.to_saddr(), next_packet).await
        .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
/// Handles an `OnionResponse2`. It peels one layer of the onion return path
/// with this node's onion symmetric key and forwards the payload to the next
/// hop as an `OnionResponse1`.
///
/// A return path that fails to decrypt is logged at trace level and dropped
/// without reporting an error. The key's read lock is released before
/// forwarding, so a back-pressured send cannot hold up `refresh_onion_key`
/// (and, through the fair `RwLock`, the other onion handlers).
///
/// # Errors
/// `OnionResponseNext` when the decrypted return path has no next hop,
/// `SendTo` when forwarding fails.
async fn handle_onion_response_2(&self, packet: OnionResponse2) -> Result<(), HandlePacketError> {
    let onion_symmetric_key = self.onion_symmetric_key.read().await;
    let decrypted = packet.onion_return.get_payload(&onion_symmetric_key);
    drop(onion_symmetric_key);

    let (ip_port, next_onion_return) = match decrypted {
        Ok((ip_port, Some(next_onion_return))) => (ip_port, next_onion_return),
        Ok((_, None)) => return Err(HandlePacketErrorKind::OnionResponseNext.into()),
        Err(e) => {
            trace!("Failed to decrypt onion_return from OnionResponse2: {}", e);
            return Ok(());
        },
    };

    let next_packet = Packet::OnionResponse1(OnionResponse1 {
        onion_return: next_onion_return,
        payload: packet.payload
    });
    self.send_to(ip_port.to_saddr(), next_packet).await
        .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
}
/// Handles an `OnionResponse1`, the last hop of the onion return path.
///
/// The return path is decrypted with this node's onion symmetric key and must
/// not contain a further return path. For a UDP destination, the inner
/// response is sent directly as an `OnionAnnounceResponse` or
/// `OnionDataResponse`. For a TCP destination, it is handed to
/// `tcp_onion_sink` together with the destination address.
///
/// A return path that fails to decrypt is logged at trace level and dropped
/// without reporting an error. The key's read lock is released before any
/// send is awaited, so a back-pressured UDP or TCP channel cannot hold up
/// `refresh_onion_key`.
///
/// # Errors
/// - `OnionResponseNext` when a further return path is present.
/// - `SendTo` when the UDP send fails.
/// - `OnionResponseRedirect` when the TCP sink is missing or the hand-off fails.
async fn handle_onion_response_1(&self, packet: OnionResponse1) -> Result<(), HandlePacketError> {
    let onion_symmetric_key = self.onion_symmetric_key.read().await;
    let decrypted = packet.onion_return.get_payload(&onion_symmetric_key);
    drop(onion_symmetric_key);

    let ip_port = match decrypted {
        Ok((ip_port, None)) => ip_port,
        // This is the final hop back, so a nested return path is invalid.
        Ok((_, Some(_))) => return Err(HandlePacketErrorKind::OnionResponseNext.into()),
        Err(e) => {
            trace!("Failed to decrypt onion_return from OnionResponse1: {}", e);
            return Ok(())
        },
    };

    match ip_port.protocol {
        ProtocolType::UDP => {
            let next_packet = match packet.payload {
                InnerOnionResponse::OnionAnnounceResponse(inner) => Packet::OnionAnnounceResponse(inner),
                InnerOnionResponse::OnionDataResponse(inner) => Packet::OnionDataResponse(inner),
            };
            self.send_to(ip_port.to_saddr(), next_packet).await
                .map_err(|e| e.context(HandlePacketErrorKind::SendTo).into())
        },
        ProtocolType::TCP => {
            if let Some(ref tcp_onion_sink) = self.tcp_onion_sink {
                tcp_onion_sink.clone().send((packet.payload, ip_port.to_saddr())).await
                    .map_err(|e| e.context(HandlePacketErrorKind::OnionResponseRedirect).into())
            } else {
                Err(HandlePacketErrorKind::OnionResponseRedirect.into())
            }
        },
    }
}
/// Replaces this node's onion symmetric key with a freshly generated one.
///
/// NOTE(review): onion return paths sealed with the previous key presumably
/// stop decrypting after this call.
async fn refresh_onion_key(&self) {
    // The new key is generated before the write lock is taken.
    let fresh_key = secretbox::gen_key();
    let mut current_key = self.onion_symmetric_key.write().await;
    *current_key = fresh_key;
}
/// Wraps an `OnionRequest` received from a TCP relay into an `OnionRequest1`
/// and sends it to the `ip_port` named in the request.
///
/// The attached onion return encodes `addr` as a TCP `IpPort` with no inner
/// return path. As a result, the matching `OnionResponse1` is later routed to
/// `tcp_onion_sink`. The read lock on the onion symmetric key is released
/// before the send is awaited, so a back-pressured channel cannot hold up
/// `refresh_onion_key`.
///
/// # Errors
/// Returns the channel's `SendError` if the packet cannot be queued.
pub async fn handle_tcp_onion_request(&self, packet: OnionRequest, addr: SocketAddr)
    -> Result<(), mpsc::SendError> {
    let onion_symmetric_key = self.onion_symmetric_key.read().await;
    let onion_return = OnionReturn::new(
        &onion_symmetric_key,
        &IpPort::from_tcp_saddr(addr),
        None, // no inner return path: this node is the first hop
    );
    drop(onion_symmetric_key);

    let next_packet = Packet::OnionRequest1(OnionRequest1 {
        nonce: packet.nonce,
        temporary_pk: packet.temporary_pk,
        payload: packet.payload,
        onion_return
    });
    self.send_to(packet.ip_port.to_saddr(), next_packet).await
}
/// Answers a `BootstrapInfo` request with this node's version and MOTD.
///
/// The request's `motd` must be exactly `BOOSTRAP_CLIENT_MAX_MOTD_LENGTH`
/// bytes long. If no bootstrap info was configured with `set_bootstrap_info`,
/// the request is accepted but not answered. The MOTD is produced by the
/// configured callback; if it is longer than `BOOSTRAP_SERVER_MAX_MOTD_LENGTH`,
/// it is truncated to that length and a warning is logged.
///
/// # Errors
/// `BootstrapInfoLength` for a wrongly sized request, `SendTo` when the reply
/// cannot be sent.
async fn handle_bootstrap_info(&self, packet: &BootstrapInfo, addr: SocketAddr) -> Result<(), HandlePacketError> {
    if packet.motd.len() != BOOSTRAP_CLIENT_MAX_MOTD_LENGTH {
        return Err(HandlePacketErrorKind::BootstrapInfoLength.into())
    }

    let bootstrap_info = match self.bootstrap_info {
        Some(ref configured) => configured,
        None => return Ok(()),
    };

    let mut motd = (bootstrap_info.motd_cb)(self);
    let motd_len = motd.len();
    if motd_len > BOOSTRAP_SERVER_MAX_MOTD_LENGTH {
        warn!(
            "Too long MOTD: {} bytes. Truncating to {} bytes",
            motd_len,
            BOOSTRAP_SERVER_MAX_MOTD_LENGTH
        );
        motd.truncate(BOOSTRAP_SERVER_MAX_MOTD_LENGTH);
    }

    let reply = Packet::BootstrapInfo(BootstrapInfo {
        version: bootstrap_info.version,
        motd,
    });
    self.send_to(addr, reply)
        .await
        .map_err(|e| e.context(HandlePacketErrorKind::SendTo))?;
    Ok(())
}
/// Returns up to `count` nodes drawn from the close-node lists of the fake
/// friends listed in `fake_friends_keys`.
///
/// Both the first fake friend visited and the starting offset inside each
/// friend's list are chosen at random. Within one friend's list, each node
/// that converts to a `PackedNode` is returned at most once.
///
/// # Panics
/// Panics if a key in `fake_friends_keys` is missing from `friends`, because
/// the friends map is indexed directly.
pub async fn random_friend_nodes(&self, count: u8) -> Vec<PackedNode> {
    let count = count as usize;
    let friends = self.friends.read().await;
    let mut nodes = Vec::new();
    let skip = random_limit_usize(FAKE_FRIENDS_NUMBER);
    for pk in self.fake_friends_keys.iter().cycle().skip(skip).take(FAKE_FRIENDS_NUMBER) {
        if nodes.len() >= count {
            break;
        }
        let friend = &friends[pk];
        // Materialize the addressable nodes first so `take` is bounded by how
        // many nodes are really available. Bounding it by `close_nodes.len()`
        // would let `cycle` repeat nodes whenever some entries don't yield a
        // `PackedNode`.
        let candidates: Vec<PackedNode> = friend.close_nodes
            .iter()
            .flat_map(|node| node.to_packed_node())
            .collect();
        let skip = random_limit_usize(FRIEND_CLOSE_NODES_COUNT as usize);
        let take = (count - nodes.len()).min(candidates.len());
        nodes.extend(candidates.iter().cycle().skip(skip).take(take).copied());
    }
    nodes
}
pub fn set_bootstrap_info(&mut self, version: u32, motd_cb: Box<dyn Fn(&Server) -> Vec<u8> + Send + Sync>) {
self.bootstrap_info = Some(ServerBootstrapInfo {
version,
motd_cb: motd_cb.into(),
});
}
/// Sets the channel used by `handle_onion_response_1` to hand onion responses
/// with a TCP destination back to the TCP side.
pub fn set_tcp_onion_sink(&mut self, tcp_onion_sink: TcpOnionTx) {
    self.tcp_onion_sink = tcp_onion_sink.into()
}
/// Attaches the `NetCrypto` instance; it replaces any previously set one.
pub fn set_net_crypto(&mut self, net_crypto: NetCrypto) {
    self.net_crypto = net_crypto.into();
}
/// Attaches the onion client (stored boxed); it replaces any previously set one.
pub fn set_onion_client(&mut self, onion_client: OnionClient) {
    let boxed_client = Box::new(onion_client);
    self.onion_client = Some(boxed_client);
}
/// Installs the sink that receives `PackedNode`s with friends' addresses.
/// It replaces any previously installed sink.
pub async fn set_friend_saddr_sink(&self, friend_saddr_sink: mpsc::UnboundedSender<PackedNode>) {
    let new_sink = Some(friend_saddr_sink);
    let mut sink_slot = self.friend_saddr_sink.write().await;
    *sink_slot = new_sink;
}
pub fn get_precomputed_keys(&self) -> PrecomputedCache {
self.precomputed_keys.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tox_binary_io::*;
use std::net::SocketAddr;
// Sizes of the encrypted `payload` part of an `OnionReturn` at each hop:
// the full onion-return size minus the secretbox nonce that precedes it.
const ONION_RETURN_1_PAYLOAD_SIZE: usize = ONION_RETURN_1_SIZE - secretbox::NONCEBYTES;
const ONION_RETURN_2_PAYLOAD_SIZE: usize = ONION_RETURN_2_SIZE - secretbox::NONCEBYTES;
const ONION_RETURN_3_PAYLOAD_SIZE: usize = ONION_RETURN_3_SIZE - secretbox::NONCEBYTES;
impl Server {
    /// Test helper: reports whether `pk` is present in the friends map.
    pub async fn has_friend(&self, pk: &PublicKey) -> bool {
        let friends = self.friends.read().await;
        friends.contains_key(pk)
    }

    /// Test helper: inserts `node` into `close_nodes`, panicking if rejected.
    pub async fn add_node(&self, node: PackedNode) {
        assert!(self.close_nodes.write().await.try_add(node));
    }
}
/// Builds the standard test fixture and returns, in order:
/// - a fresh `Server` ("alice") whose outgoing packets go to the returned receiver;
/// - the key precomputed between alice's public key and bob's secret key;
/// - bob's key pair;
/// - the receiver;
/// - a loopback address used as the peer address.
fn create_node() -> (Server, PrecomputedKey, PublicKey, SecretKey,
    mpsc::Receiver<(Packet, SocketAddr)>, SocketAddr) {
    crypto_init().unwrap();

    let (alice_pk, alice_sk) = gen_keypair();
    let (packet_tx, packet_rx) = mpsc::channel(32);
    let alice = Server::new(packet_tx, alice_pk, alice_sk);

    let (bob_pk, bob_sk) = gen_keypair();
    let alice_bob_key = precompute(&alice.pk, &bob_sk);

    let peer_addr = "127.0.0.1:12346".parse::<SocketAddr>().unwrap();
    (alice, alice_bob_key, bob_pk, bob_sk, packet_rx, peer_addr)
}
/// Adding a friend seeds that friend's `nodes_to_bootstrap` with the nodes
/// already present in alice's `close_nodes`.
#[tokio::test]
async fn add_friend() {
let (alice, _precomp, bob_pk, _bob_sk, _rx, _addr) = create_node();
let packed_node = PackedNode::new("211.192.153.67:33445".parse().unwrap(), &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let friend_pk = gen_keypair().0;
alice.add_friend(friend_pk).await;
let inserted_friend = &alice.friends.read().await[&friend_pk];
assert!(inserted_friend.nodes_to_bootstrap.contains(&friend_pk, &bob_pk));
}
/// Calling `add_friend` again for an existing friend keeps the close nodes
/// collected for it so far.
#[tokio::test]
async fn readd_friend() {
let (alice, _precomp, bob_pk, _bob_sk, _rx, _addr) = create_node();
let friend_pk = gen_keypair().0;
alice.add_friend(friend_pk).await;
let packed_node = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(alice.friends.write().await.get_mut(&friend_pk).unwrap().try_add_to_close(packed_node));
alice.add_friend(friend_pk).await;
assert!(alice.friends.read().await[&friend_pk].close_nodes.contains(&friend_pk, &bob_pk));
}
/// `remove_friend` deletes a friend previously added with `add_friend`.
#[tokio::test]
async fn remove_friend() {
let (alice, _precomp, _bob_pk, _bob_sk, _rx, _addr) = create_node();
let friend_pk = gen_keypair().0;
alice.add_friend(friend_pk).await;
assert!(alice.friends.read().await.contains_key(&friend_pk));
alice.remove_friend(friend_pk).await;
assert!(!alice.friends.read().await.contains_key(&friend_pk));
}
/// A correctly sized `BootstrapInfo` request is answered to the sender with
/// the configured version and the MOTD returned by the callback.
#[tokio::test]
async fn handle_bootstrap_info() {
let (mut alice, _precomp, _bob_pk, _bob_sk, rx, addr) = create_node();
let version = 42;
let motd = b"motd".to_vec();
let motd_c = motd.clone();
alice.set_bootstrap_info(version, Box::new(move |_| motd_c.clone()));
let packet = Packet::BootstrapInfo(BootstrapInfo {
version: 00,
motd: vec![0; BOOSTRAP_CLIENT_MAX_MOTD_LENGTH],
});
alice.handle_packet(packet, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let bootstrap_info = unpack!(packet, Packet::BootstrapInfo);
assert_eq!(bootstrap_info.version, version);
assert_eq!(bootstrap_info.motd, motd);
}
/// A `BootstrapInfo` request whose MOTD has the wrong length is rejected with
/// `BootstrapInfoLength`, and nothing is sent.
#[tokio::test]
async fn handle_bootstrap_info_wrong_length() {
let (mut alice, _precomp, _bob_pk, _bob_sk, rx, addr) = create_node();
let version = 42;
let motd = b"motd".to_vec();
alice.set_bootstrap_info(version, Box::new(move |_| motd.clone()));
let packet = Packet::BootstrapInfo(BootstrapInfo {
version: 00,
motd: Vec::new(),
});
let res = alice.handle_packet(packet, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::BootstrapInfoLength);
// Dropping alice drops the sender she owns. Assuming no other clones exist,
// the receiver stream then ends and `collect` can finish.
drop(alice);
assert!(rx.collect::<Vec<_>>().await.is_empty());
}
/// A valid `PingRequest` is answered with a `PingResponse` that echoes the
/// request id, and the sender is queued in `nodes_to_ping`.
#[tokio::test]
async fn handle_ping_req() {
let (alice, precomp, bob_pk, bob_sk, rx, addr) = create_node();
let req_payload = PingRequestPayload { id: 42 };
let ping_req = Packet::PingRequest(PingRequest::new(&precomp, &bob_pk, &req_payload));
alice.handle_packet(ping_req, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let ping_resp = unpack!(packet, Packet::PingResponse);
let precomputed_key = precompute(&ping_resp.pk, &bob_sk);
let ping_resp_payload = ping_resp.get_payload(&precomputed_key).unwrap();
assert_eq!(ping_resp_payload.id, req_payload.id);
assert!(alice.nodes_to_ping.read().await.contains(&alice.pk, &bob_pk));
}
/// When a friend pings alice, she sends two packets to the friend's address:
/// - a `PingResponse` echoing the request id;
/// - her own `PingRequest`, whose ping id is registered in `request_queue`.
///
/// The friend is not added to `nodes_to_ping`.
#[tokio::test]
async fn handle_ping_req_from_friend_with_unknown_addr() {
let (alice, precomp, bob_pk, bob_sk, rx, addr) = create_node();
alice.add_friend(bob_pk).await;
let req_payload = PingRequestPayload { id: 42 };
let ping_req = Packet::PingRequest(PingRequest::new(&precomp, &bob_pk, &req_payload));
alice.handle_packet(ping_req, addr).await.unwrap();
// Held while draining `rx` so the closure can validate alice's outgoing ping id.
let mut request_queue = alice.request_queue.write().await;
rx.take(2).map(|(packet, addr_to_send)| {
assert_eq!(addr_to_send, addr);
if let Packet::PingResponse(ping_resp) = packet {
let precomputed_key = precompute(&ping_resp.pk, &bob_sk);
let ping_resp_payload = ping_resp.get_payload(&precomputed_key).unwrap();
assert_eq!(ping_resp_payload.id, req_payload.id);
} else {
let ping_req = unpack!(packet, Packet::PingRequest);
let precomputed_key = precompute(&ping_req.pk, &bob_sk);
let ping_req_payload = ping_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(ping_req_payload.id, |&pk| pk == bob_pk).is_some());
}
}).collect::<Vec<_>>().await;
assert!(!alice.nodes_to_ping.read().await.contains(&alice.pk, &bob_pk));
}
/// A `PingRequest` that names alice's own key as the sender cannot be
/// decrypted with the matching shared key and fails with `GetPayload`.
#[tokio::test]
async fn handle_ping_req_invalid_payload() {
let (alice, precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let req_payload = PingRequestPayload { id: 42 };
let ping_req = Packet::PingRequest(PingRequest::new(&precomp, &alice.pk, &req_payload));
let res = alice.handle_packet(ping_req, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::GetPayload);
}
/// A `PingResponse` carrying a pending ping id has two effects:
/// - bob is added to the friend's close nodes;
/// - bob's `last_resp_time` in `close_nodes` is stamped with the current clock.
#[tokio::test]
async fn handle_ping_resp() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
alice.add_friend(bob_pk).await;
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let resp_payload = PingResponsePayload { id: ping_id };
let ping_resp = Packet::PingResponse(PingResponse::new(&precomp, &bob_pk, &resp_payload));
// Freeze the clock (after moving it forward) so `clock_now()` below equals
// the timestamp recorded while handling the packet.
tokio::time::pause();
tokio::time::advance(Duration::from_secs(1)).await;
alice.handle_packet(ping_resp, addr).await.unwrap();
let friends = alice.friends.read().await;
let friend = friends.values().next().unwrap();
assert!(friend.close_nodes.contains(&bob_pk, &bob_pk));
let close_nodes = alice.close_nodes.read().await;
let node = close_nodes.get_node(&bob_pk).unwrap();
let time = clock_now();
assert_eq!(node.assoc4.last_resp_time.unwrap(), time);
}
/// A `PingResponse` from a node that is not a friend puts nothing on the
/// friend-address sink.
#[tokio::test]
async fn handle_ping_resp_not_a_friend() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let (friend_saddr_tx, friend_saddr_rx) = mpsc::unbounded();
alice.set_friend_saddr_sink(friend_saddr_tx).await;
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let resp_payload = PingResponsePayload { id: ping_id };
let ping_resp = Packet::PingResponse(PingResponse::new(&precomp, &bob_pk, &resp_payload));
alice.handle_packet(ping_resp, addr).await.unwrap();
// Dropping alice drops the sink sender so the receiver stream can end.
drop(alice);
assert!(friend_saddr_rx.collect::<Vec<_>>().await.is_empty());
}
/// A `PingResponse` from a friend reports the friend's public key and address
/// on the friend-address sink.
#[tokio::test]
async fn handle_ping_resp_friend_saddr() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let (friend_saddr_tx, friend_saddr_rx) = mpsc::unbounded();
alice.set_friend_saddr_sink(friend_saddr_tx).await;
alice.add_friend(bob_pk).await;
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let resp_payload = PingResponsePayload { id: ping_id };
let ping_resp = Packet::PingResponse(PingResponse::new(&precomp, &bob_pk, &resp_payload));
alice.handle_packet(ping_resp, addr).await.unwrap();
let (received_node, _friend_saddr_rx) = friend_saddr_rx.into_future().await;
let received_node = received_node.unwrap();
assert_eq!(received_node.pk, bob_pk);
assert_eq!(received_node.saddr, addr);
}
/// A `PingResponse` that names alice's own key as the sender fails to decrypt
/// with `GetPayload`.
#[tokio::test]
async fn handle_ping_resp_invalid_payload() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let payload = PingResponsePayload { id: ping_id };
let ping_resp = Packet::PingResponse(PingResponse::new(&precomp, &alice.pk, &payload));
let res = alice.handle_packet(ping_resp, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::GetPayload);
}
/// A `PingResponse` with ping id 0 is rejected with `ZeroPingId`.
#[tokio::test]
async fn handle_ping_resp_ping_id_is_0() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let payload = PingResponsePayload { id: 0 };
let ping_resp = Packet::PingResponse(PingResponse::new(&precomp, &bob_pk, &payload));
let res = alice.handle_packet(ping_resp, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::ZeroPingId);
}
/// A `PingResponse` whose id does not match the pending one is rejected with
/// `PingIdMismatch`.
#[tokio::test]
async fn handle_ping_resp_invalid_ping_id() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let payload = PingResponsePayload { id: ping_id + 1 };
let ping_resp = Packet::PingResponse(PingResponse::new(&precomp, &bob_pk, &payload));
let res = alice.handle_packet(ping_resp, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::PingIdMismatch);
}
/// A `NodesRequest` is answered with the known close nodes and the request
/// id, and the requester is queued in `nodes_to_ping`.
#[tokio::test]
async fn handle_nodes_req() {
let (alice, precomp, bob_pk, bob_sk, rx, addr) = create_node();
let packed_node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let req_payload = NodesRequestPayload { pk: bob_pk, id: 42 };
let nodes_req = Packet::NodesRequest(NodesRequest::new(&precomp, &bob_pk, &req_payload));
alice.handle_packet(nodes_req, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let nodes_resp = unpack!(packet, Packet::NodesResponse);
let precomputed_key = precompute(&nodes_resp.pk, &bob_sk);
let nodes_resp_payload = nodes_resp.get_payload(&precomputed_key).unwrap();
assert_eq!(nodes_resp_payload.id, req_payload.id);
assert_eq!(nodes_resp_payload.nodes, vec!(packed_node));
assert!(alice.nodes_to_ping.read().await.contains(&alice.pk, &bob_pk));
}
/// Nodes known only from a friend's close list are included in the
/// `NodesResponse` as well.
#[tokio::test]
async fn handle_nodes_req_should_return_nodes_from_friends() {
let (alice, precomp, bob_pk, bob_sk, rx, addr) = create_node();
alice.add_friend(bob_pk).await;
let packed_node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &bob_pk);
assert!(alice.friends.write().await.get_mut(&bob_pk).unwrap().try_add_to_close(packed_node));
let req_payload = NodesRequestPayload { pk: bob_pk, id: 42 };
let nodes_req = Packet::NodesRequest(NodesRequest::new(&precomp, &bob_pk, &req_payload));
alice.handle_packet(nodes_req, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let nodes_resp = unpack!(packet, Packet::NodesResponse);
let precomputed_key = precompute(&nodes_resp.pk, &bob_sk);
let nodes_resp_payload = nodes_resp.get_payload(&precomputed_key).unwrap();
assert_eq!(nodes_resp_payload.id, req_payload.id);
assert_eq!(nodes_resp_payload.nodes, vec!(packed_node));
assert!(alice.nodes_to_ping.read().await.contains(&alice.pk, &bob_pk));
}
/// Once the clock has moved past `BAD_NODE_TIMEOUT`, the stored close node is
/// left out of the `NodesResponse`.
#[tokio::test]
async fn handle_nodes_req_should_not_return_bad_nodes() {
let (alice, precomp, bob_pk, bob_sk, rx, addr) = create_node();
let packed_node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let req_payload = NodesRequestPayload { pk: bob_pk, id: 42 };
let nodes_req = Packet::NodesRequest(NodesRequest::new(&precomp, &bob_pk, &req_payload));
let delay = BAD_NODE_TIMEOUT + Duration::from_secs(1);
tokio::time::pause();
tokio::time::advance(delay).await;
alice.handle_packet(nodes_req, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let nodes_resp = unpack!(packet, Packet::NodesResponse);
let precomputed_key = precompute(&nodes_resp.pk, &bob_sk);
let nodes_resp_payload = nodes_resp.get_payload(&precomputed_key).unwrap();
assert_eq!(nodes_resp_payload.id, req_payload.id);
assert!(nodes_resp_payload.nodes.is_empty());
assert!(alice.nodes_to_ping.read().await.contains(&alice.pk, &bob_pk));
}
/// A requester with a global address does not receive LAN (192.168.x.x)
/// nodes in the `NodesResponse`.
#[tokio::test]
async fn handle_nodes_req_should_not_return_lan_nodes_when_address_is_global() {
let (alice, precomp, bob_pk, bob_sk, rx, _addr) = create_node();
let addr = "8.10.8.10:12345".parse().unwrap();
let packed_node = PackedNode::new("192.168.42.42:12345".parse().unwrap(), &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let req_payload = NodesRequestPayload { pk: bob_pk, id: 42 };
let nodes_req = Packet::NodesRequest(NodesRequest::new(&precomp, &bob_pk, &req_payload));
alice.handle_packet(nodes_req, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let nodes_resp = unpack!(packet, Packet::NodesResponse);
let precomputed_key = precompute(&nodes_resp.pk, &bob_sk);
let nodes_resp_payload = nodes_resp.get_payload(&precomputed_key).unwrap();
assert_eq!(nodes_resp_payload.id, req_payload.id);
assert!(nodes_resp_payload.nodes.is_empty());
assert!(alice.nodes_to_ping.read().await.contains(&alice.pk, &bob_pk));
}
/// A `NodesRequest` that names alice's own key as the sender fails to decrypt
/// with `GetPayload`.
#[tokio::test]
async fn handle_nodes_req_invalid_payload() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let req_payload = NodesRequestPayload { pk: bob_pk, id: 42 };
let nodes_req = Packet::NodesRequest(NodesRequest::new(&precomp, &alice.pk, &req_payload));
let res = alice.handle_packet(nodes_req, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::GetPayload);
}
/// A `NodesResponse` carrying a pending ping id has the following effects:
/// - the returned node is queued for bootstrapping, both in alice's
///   `nodes_to_bootstrap` and in the friend's;
/// - bob is added to the friend's close nodes;
/// - bob's `last_resp_time` in `close_nodes` is stamped with the current clock.
#[tokio::test]
async fn handle_nodes_resp() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
alice.add_friend(bob_pk).await;
let node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0);
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let resp_payload = NodesResponsePayload { nodes: vec![node], id: ping_id };
let nodes_resp = Packet::NodesResponse(NodesResponse::new(&precomp, &bob_pk, &resp_payload));
// Freeze the clock (after moving it forward) so `clock_now()` below equals
// the timestamp recorded while handling the packet.
tokio::time::pause();
tokio::time::advance(Duration::from_secs(1)).await;
alice.handle_packet(nodes_resp, addr).await.unwrap();
assert!(alice.nodes_to_bootstrap.read().await.contains(&alice.pk, &node.pk));
let friends = alice.friends.read().await;
let friend = friends.values().next().unwrap();
assert!(friend.nodes_to_bootstrap.contains(&bob_pk, &node.pk));
assert!(friend.close_nodes.contains(&bob_pk, &bob_pk));
let close_nodes = alice.close_nodes.read().await;
let node = close_nodes.get_node(&bob_pk).unwrap();
let time = clock_now();
assert_eq!(node.assoc4.last_resp_time.unwrap(), time);
}
/// A `NodesResponse` from a friend reports the friend's public key and
/// address on the friend-address sink.
#[tokio::test]
async fn handle_nodes_resp_friend_saddr() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let (friend_saddr_tx, friend_saddr_rx) = mpsc::unbounded();
alice.set_friend_saddr_sink(friend_saddr_tx).await;
alice.add_friend(bob_pk).await;
let packed_node = PackedNode::new(addr, &bob_pk);
assert!(alice.close_nodes.write().await.try_add(packed_node));
let node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0);
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let resp_payload = NodesResponsePayload { nodes: vec![node], id: ping_id };
let nodes_resp = Packet::NodesResponse(NodesResponse::new(&precomp, &bob_pk, &resp_payload));
alice.handle_packet(nodes_resp, addr).await.unwrap();
let (received_node, _friend_saddr_rx) = friend_saddr_rx.into_future().await;
let received_node = received_node.unwrap();
assert_eq!(received_node.pk, bob_pk);
assert_eq!(received_node.saddr, addr);
}
/// A `NodesResponse` that names alice's own key as the sender fails to decrypt
/// with `GetPayload`.
#[tokio::test]
async fn handle_nodes_resp_invalid_payload() {
let (alice, precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let resp_payload = NodesResponsePayload { nodes: vec![
PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0)
], id: 38 };
let nodes_resp = Packet::NodesResponse(NodesResponse::new(&precomp, &alice.pk, &resp_payload));
let res = alice.handle_packet(nodes_resp, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::GetPayload);
}
/// A `NodesResponse` with ping id 0 is handled without error, and nothing is
/// sent. This differs from `PingResponse`, which reports `ZeroPingId` as an error.
#[tokio::test]
async fn handle_nodes_resp_ping_id_is_0() {
let (alice, precomp, bob_pk, _bob_sk, rx, addr) = create_node();
let resp_payload = NodesResponsePayload { nodes: vec![
PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0)
], id: 0 };
let nodes_resp = Packet::NodesResponse(NodesResponse::new(&precomp, &bob_pk, &resp_payload));
alice.handle_packet(nodes_resp, addr).await.unwrap();
// Dropping alice drops her sender so the receiver stream can end.
drop(alice);
assert!(rx.collect::<Vec<_>>().await.is_empty());
}
/// A `NodesResponse` with a non-matching ping id is handled without error,
/// and nothing is sent.
#[tokio::test]
async fn handle_nodes_resp_invalid_ping_id() {
let (alice, precomp, bob_pk, _bob_sk, rx, addr) = create_node();
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let resp_payload = NodesResponsePayload {
nodes: vec![
PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0)
],
id: ping_id + 1
};
let nodes_resp = Packet::NodesResponse(NodesResponse::new(&precomp, &bob_pk, &resp_payload));
alice.handle_packet(nodes_resp, addr).await.unwrap();
// Dropping alice drops her sender so the receiver stream can end.
drop(alice);
assert!(rx.collect::<Vec<_>>().await.is_empty());
}
/// With `NetCrypto` configured, a `CookieRequest` is answered to the sender
/// with a `CookieResponse` that carries the request id.
#[tokio::test]
async fn handle_cookie_request() {
crypto_init().unwrap();
let (udp_tx, udp_rx) = mpsc::channel(1);
let (dht_pk, dht_sk) = gen_keypair();
let mut alice = Server::new(udp_tx.clone(), dht_pk, dht_sk.clone());
let (lossless_tx, _lossless_rx) = mpsc::unbounded();
let (lossy_tx, _lossy_rx) = mpsc::unbounded();
let (real_pk, real_sk) = gen_keypair();
let (bob_pk, bob_sk) = gen_keypair();
let (bob_real_pk, _bob_real_sk) = gen_keypair();
let precomp = precompute(&alice.pk, &bob_sk);
// NetCrypto is given the same UDP channel and the server's precomputed-key cache.
let net_crypto = NetCrypto::new(NetCryptoNewArgs {
udp_tx,
lossless_tx,
lossy_tx,
dht_pk,
dht_sk,
real_pk,
real_sk,
precomputed_keys: alice.get_precomputed_keys(),
});
alice.set_net_crypto(net_crypto);
let addr = "127.0.0.1:12346".parse().unwrap();
let cookie_request_id = 12345;
let cookie_request_payload = CookieRequestPayload {
pk: bob_real_pk,
id: cookie_request_id,
};
let cookie_request = Packet::CookieRequest(CookieRequest::new(&precomp, &bob_pk, &cookie_request_payload));
alice.handle_packet(cookie_request, addr).await.unwrap();
let (received, _udp_rx) = udp_rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let packet = unpack!(packet, Packet::CookieResponse);
let payload = packet.get_payload(&precomp).unwrap();
assert_eq!(payload.id, cookie_request_id);
}
/// Without `NetCrypto` configured, a `CookieRequest` fails with `NetCrypto`.
#[tokio::test]
async fn handle_cookie_request_uninitialized() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let (bob_real_pk, _bob_real_sk) = gen_keypair();
let cookie_request_payload = CookieRequestPayload {
pk: bob_real_pk,
id: 12345,
};
let cookie_request = Packet::CookieRequest(CookieRequest::new(&precomp, &bob_pk, &cookie_request_payload));
let res = alice.handle_packet(cookie_request, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::NetCrypto);
}
/// Without `NetCrypto` configured, a `CookieResponse` fails with `NetCrypto`.
#[tokio::test]
async fn handle_cookie_response_uninitialized() {
let (alice, precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let cookie = EncryptedCookie {
nonce: secretbox::gen_nonce(),
payload: vec![43; 88]
};
let cookie_response_payload = CookieResponsePayload {
cookie: cookie.clone(),
id: 12345
};
let cookie_response = Packet::CookieResponse(CookieResponse::new(&precomp, &cookie_response_payload));
let res = alice.handle_packet(cookie_response, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::NetCrypto);
}
/// Without `NetCrypto` configured, a `CryptoHandshake` fails with `NetCrypto`.
#[tokio::test]
async fn handle_crypto_handshake_uninitialized() {
let (alice, precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let cookie = EncryptedCookie {
nonce: secretbox::gen_nonce(),
payload: vec![43; 88]
};
let crypto_handshake_payload = CryptoHandshakePayload {
base_nonce: gen_nonce(),
session_pk: gen_keypair().0,
cookie_hash: cookie.hash(),
cookie: cookie.clone()
};
let crypto_handshake = Packet::CryptoHandshake(CryptoHandshake::new(&precomp, &crypto_handshake_payload, cookie));
let res = alice.handle_packet(crypto_handshake, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::NetCrypto);
}
/// A `DhtRequest` addressed to a node (charlie) that alice does not know is
/// handled without error, and nothing is forwarded.
#[tokio::test]
async fn handle_dht_req_for_unknown_node() {
let (alice, _precomp, bob_pk, bob_sk, rx, addr) = create_node();
let (charlie_pk, _charlie_sk) = gen_keypair();
let precomp = precompute(&charlie_pk, &bob_sk);
let nat_req = NatPingRequest { id: 42 };
let nat_payload = DhtRequestPayload::NatPingRequest(nat_req);
let dht_req = Packet::DhtRequest(DhtRequest::new(&precomp, &charlie_pk, &bob_pk, &nat_payload));
alice.handle_packet(dht_req, addr).await.unwrap();
// Dropping alice drops her sender so the receiver stream can end.
drop(alice);
assert!(rx.collect::<Vec<_>>().await.is_empty());
}
/// A `DhtRequest` addressed to a node in alice's `close_nodes` is forwarded
/// unchanged to that node's address.
#[tokio::test]
async fn handle_dht_req_for_known_node() {
let (alice, _precomp, bob_pk, bob_sk, rx, addr) = create_node();
let charlie_addr = "1.2.3.4:12345".parse().unwrap();
let (charlie_pk, _charlie_sk) = gen_keypair();
let precomp = precompute(&charlie_pk, &bob_sk);
let pn = PackedNode::new(charlie_addr, &charlie_pk);
assert!(alice.close_nodes.write().await.try_add(pn));
let nat_req = NatPingRequest { id: 42 };
let nat_payload = DhtRequestPayload::NatPingRequest(nat_req);
let dht_req = Packet::DhtRequest(DhtRequest::new(&precomp, &charlie_pk, &bob_pk, &nat_payload));
alice.handle_packet(dht_req.clone(), addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, charlie_addr);
assert_eq!(packet, dht_req);
}
/// A `DhtRequest` addressed to alice with a garbage payload fails with
/// `GetPayload`.
#[tokio::test]
async fn handle_dht_req_invalid_payload() {
let (alice, _precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let dht_req = Packet::DhtRequest(DhtRequest {
rpk: alice.pk,
spk: bob_pk,
nonce: gen_nonce(),
payload: vec![42; 123]
});
let res = alice.handle_packet(dht_req, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::GetPayload);
}
/// A `NatPingRequest` from a friend is answered with a `NatPingResponse` that
/// echoes the id, and the friend's `last_recv_ping_time` is set to the
/// current clock.
#[tokio::test]
async fn handle_nat_ping_req() {
let (alice, precomp, bob_pk, bob_sk, rx, addr) = create_node();
alice.add_friend(bob_pk).await;
let nat_req = NatPingRequest { id: 42 };
let nat_payload = DhtRequestPayload::NatPingRequest(nat_req);
let dht_req = Packet::DhtRequest(DhtRequest::new(&precomp, &alice.pk, &bob_pk, &nat_payload));
// Freeze the clock (after moving it forward) so `clock_now()` below equals
// the timestamp recorded while handling the packet.
tokio::time::pause();
tokio::time::advance(Duration::from_secs(1)).await;
alice.handle_packet(dht_req, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let dht_req = unpack!(packet, Packet::DhtRequest);
let precomputed_key = precompute(&dht_req.spk, &bob_sk);
let dht_payload = dht_req.get_payload(&precomputed_key).unwrap();
let nat_ping_resp_payload = unpack!(dht_payload, DhtRequestPayload::NatPingResponse);
assert_eq!(nat_ping_resp_payload.id, nat_req.id);
let friends = alice.friends.read().await;
let time = clock_now();
assert_eq!(friends[&bob_pk].hole_punch.last_recv_ping_time, time);
}
/// A `NatPingResponse` carrying the friend's current hole-punch ping id is
/// handled without error.
///
/// NOTE(review): the final assertion expects `is_punching_done` to still be
/// `false` after a matching response — confirm this is the intended behavior.
#[tokio::test]
async fn handle_nat_ping_resp() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
alice.add_friend(bob_pk).await;
let ping_id = alice.friends.read().await[&bob_pk].hole_punch.ping_id;
let nat_res = NatPingResponse { id: ping_id };
let nat_payload = DhtRequestPayload::NatPingResponse(nat_res);
let dht_req = Packet::DhtRequest(DhtRequest::new(&precomp, &alice.pk, &bob_pk, &nat_payload));
alice.handle_packet(dht_req, addr).await.unwrap();
let friends = alice.friends.read().await;
assert!(!friends[&bob_pk].hole_punch.is_punching_done);
}
/// A `NatPingResponse` with ping id 0 is rejected with `ZeroPingId`.
#[tokio::test]
async fn handle_nat_ping_resp_ping_id_is_0() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let nat_res = NatPingResponse { id: 0 };
let nat_payload = DhtRequestPayload::NatPingResponse(nat_res);
let dht_req = Packet::DhtRequest(DhtRequest::new(&precomp, &alice.pk, &bob_pk, &nat_payload));
let res = alice.handle_packet(dht_req, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::ZeroPingId);
}
/// A `NatPingResponse` from bob is rejected with `NoFriend`.
///
/// NOTE(review): despite the test name, the failure comes from bob never
/// having been added as a friend. The ping id is taken from `request_queue`,
/// not from a friend's hole-punch state.
#[tokio::test]
async fn handle_nat_ping_resp_invalid_ping_id() {
let (alice, precomp, bob_pk, _bob_sk, _rx, addr) = create_node();
let ping_id = alice.request_queue.write().await.new_ping_id(bob_pk);
let nat_res = NatPingResponse { id: ping_id + 1 };
let nat_payload = DhtRequestPayload::NatPingResponse(nat_res);
let dht_req = Packet::DhtRequest(DhtRequest::new(&precomp, &alice.pk, &bob_pk, &nat_payload));
let res = alice.handle_packet(dht_req, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::NoFriend);
}
/// An `OnionRequest0` is forwarded as an `OnionRequest1` to the embedded
/// `ip_port`, with the temporary key and inner payload preserved. Its onion
/// return decrypts, under alice's onion symmetric key, to the sender's UDP
/// address.
#[tokio::test]
async fn handle_onion_request_0() {
let (alice, precomp, bob_pk, _bob_sk, rx, addr) = create_node();
let temporary_pk = gen_keypair().0;
let inner = vec![42; 123];
let ip_port = IpPort {
protocol: ProtocolType::UDP,
ip_addr: "5.6.7.8".parse().unwrap(),
port: 12345
};
let payload = OnionRequest0Payload {
ip_port: ip_port.clone(),
temporary_pk,
inner: inner.clone()
};
let packet = Packet::OnionRequest0(OnionRequest0::new(&precomp, &bob_pk, &payload));
alice.handle_packet(packet, addr).await.unwrap();
let (received, _rx) = rx.into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, ip_port.to_saddr());
let next_packet = unpack!(packet, Packet::OnionRequest1);
assert_eq!(next_packet.temporary_pk, temporary_pk);
assert_eq!(next_packet.payload, inner);
let onion_symmetric_key = alice.onion_symmetric_key.read().await;
let onion_return_payload = next_packet.onion_return.get_payload(&onion_symmetric_key).unwrap();
assert_eq!(onion_return_payload.0, IpPort::from_udp_saddr(addr));
}
#[tokio::test]
async fn handle_onion_request_0_invalid_payload() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();

    // The "encrypted" payload is filler bytes, so decryption must fail.
    let bogus = OnionRequest0 {
        nonce: gen_nonce(),
        temporary_pk: gen_keypair().0,
        payload: vec![42; 123],
    };

    let err = alice
        .handle_packet(Packet::OnionRequest0(bogus), addr)
        .await
        .unwrap_err();
    assert_eq!(err.kind(), &HandlePacketErrorKind::GetPayload);
}
#[tokio::test]
async fn handle_onion_request_1() {
    let (alice, precomp, bob_pk, _bob_sk, mut rx, addr) = create_node();

    let temporary_pk = gen_keypair().0;
    let inner = vec![42; 123];
    let target = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let request_payload = OnionRequest1Payload {
        ip_port: target.clone(),
        temporary_pk,
        inner: inner.clone(),
    };
    // Return path left by the previous hop; Alice treats it as opaque bytes.
    let previous_return = OnionReturn {
        nonce: secretbox::gen_nonce(),
        payload: vec![42; ONION_RETURN_1_PAYLOAD_SIZE],
    };
    let request = OnionRequest1::new(&precomp, &bob_pk, &request_payload, previous_return);

    alice.handle_packet(Packet::OnionRequest1(request), addr).await.unwrap();

    // Alice forwards an OnionRequest2 to the target named in the payload.
    let (forwarded, forwarded_to) = rx.next().await.unwrap();
    assert_eq!(forwarded_to, target.to_saddr());
    let request2 = unpack!(forwarded, Packet::OnionRequest2);
    assert_eq!(request2.temporary_pk, temporary_pk);
    assert_eq!(request2.payload, inner);

    // The new return layer must decrypt to the sender's address.
    let symmetric_key = alice.onion_symmetric_key.read().await;
    let return_payload = request2.onion_return.get_payload(&symmetric_key).unwrap();
    assert_eq!(return_payload.0, IpPort::from_udp_saddr(addr));
}
#[tokio::test]
async fn handle_onion_request_1_invalid_payload() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();

    // Filler bytes instead of a real ciphertext: Alice cannot open the payload.
    let bogus = OnionRequest1 {
        nonce: gen_nonce(),
        temporary_pk: gen_keypair().0,
        payload: vec![42; 123],
        onion_return: OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_1_PAYLOAD_SIZE],
        },
    };

    let outcome = alice.handle_packet(Packet::OnionRequest1(bogus), addr).await;
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err().kind(), &HandlePacketErrorKind::GetPayload);
}
#[tokio::test]
async fn handle_onion_request_2_with_onion_announce_request() {
    let (alice, precomp, bob_pk, _bob_sk, mut rx, addr) = create_node();

    let announce = InnerOnionAnnounceRequest {
        nonce: gen_nonce(),
        pk: gen_keypair().0,
        payload: vec![42; 123],
    };
    let next_hop = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let layer_payload = OnionRequest2Payload {
        ip_port: next_hop.clone(),
        inner: InnerOnionRequest::InnerOnionAnnounceRequest(announce.clone()),
    };
    let request = OnionRequest2::new(
        &precomp,
        &bob_pk,
        &layer_payload,
        OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_2_PAYLOAD_SIZE],
        },
    );

    alice.handle_packet(Packet::OnionRequest2(request), addr).await.unwrap();

    // The inner announce request is unwrapped and sent to the next hop.
    let (sent, sent_to) = rx.next().await.unwrap();
    assert_eq!(sent_to, next_hop.to_saddr());
    let announce_request = unpack!(sent, Packet::OnionAnnounceRequest);
    assert_eq!(announce_request.inner, announce);

    let key = alice.onion_symmetric_key.read().await;
    let return_payload = announce_request.onion_return.get_payload(&key).unwrap();
    assert_eq!(return_payload.0, IpPort::from_udp_saddr(addr));
}
#[tokio::test]
async fn handle_onion_request_2_with_onion_data_request() {
    let (alice, precomp, bob_pk, _bob_sk, mut rx, addr) = create_node();

    let data_request = InnerOnionDataRequest {
        destination_pk: gen_keypair().0,
        nonce: gen_nonce(),
        temporary_pk: gen_keypair().0,
        payload: vec![42; 123],
    };
    let next_hop = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let wrapped = OnionRequest2::new(
        &precomp,
        &bob_pk,
        &OnionRequest2Payload {
            ip_port: next_hop.clone(),
            inner: InnerOnionRequest::InnerOnionDataRequest(data_request.clone()),
        },
        OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_2_PAYLOAD_SIZE],
        },
    );

    alice.handle_packet(Packet::OnionRequest2(wrapped), addr).await.unwrap();

    // The inner data request is unwrapped and sent to the next hop.
    let (sent, sent_to) = rx.next().await.unwrap();
    assert_eq!(sent_to, next_hop.to_saddr());
    let forwarded = unpack!(sent, Packet::OnionDataRequest);
    assert_eq!(forwarded.inner, data_request);

    let key = alice.onion_symmetric_key.read().await;
    let back = forwarded.onion_return.get_payload(&key).unwrap();
    assert_eq!(back.0, IpPort::from_udp_saddr(addr));
}
#[tokio::test]
async fn handle_onion_request_2_invalid_payload() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();

    // The payload is not a valid ciphertext for Alice, so it cannot be decrypted.
    let undecryptable = Packet::OnionRequest2(OnionRequest2 {
        nonce: gen_nonce(),
        temporary_pk: gen_keypair().0,
        payload: vec![42; 123],
        onion_return: OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_2_PAYLOAD_SIZE],
        },
    });

    let err = alice.handle_packet(undecryptable, addr).await.unwrap_err();
    assert_eq!(*err.kind(), HandlePacketErrorKind::GetPayload);
}
#[tokio::test]
async fn handle_onion_announce_request() {
    let (alice, precomp, bob_pk, _bob_sk, mut rx, addr) = create_node();

    let sendback_data = 42;
    let announce_payload = OnionAnnounceRequestPayload {
        ping_id: initial_ping_id(),
        search_pk: gen_keypair().0,
        data_pk: gen_keypair().0,
        sendback_data,
    };
    let inner = InnerOnionAnnounceRequest::new(&precomp, &bob_pk, &announce_payload);
    let onion_return = OnionReturn {
        nonce: secretbox::gen_nonce(),
        payload: vec![42; ONION_RETURN_3_PAYLOAD_SIZE],
    };
    let request = OnionAnnounceRequest {
        inner,
        onion_return: onion_return.clone(),
    };

    alice.handle_packet(Packet::OnionAnnounceRequest(request), addr).await.unwrap();

    // The reply goes back to the sender as OnionResponse3 with the original return.
    let (reply, reply_to) = rx.next().await.unwrap();
    assert_eq!(reply_to, addr);
    let reply = unpack!(reply, Packet::OnionResponse3);
    assert_eq!(reply.onion_return, onion_return);

    let announce_response = unpack!(reply.payload, InnerOnionResponse::OnionAnnounceResponse);
    assert_eq!(announce_response.sendback_data, sendback_data);
    // With the initial ping id the announce is not accepted.
    let decrypted = announce_response.get_payload(&precomp).unwrap();
    assert_eq!(decrypted.announce_status, AnnounceStatus::Failed);
}
#[tokio::test]
async fn handle_onion_announce_request_invalid_payload() {
    let (alice, _precomp, bob_pk, _bob_sk, _rx, addr) = create_node();

    // An announce request whose payload is filler bytes cannot be decrypted.
    let request = OnionAnnounceRequest {
        inner: InnerOnionAnnounceRequest {
            nonce: gen_nonce(),
            pk: bob_pk,
            payload: vec![42; 123],
        },
        onion_return: OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_3_PAYLOAD_SIZE],
        },
    };

    let err = alice
        .handle_packet(Packet::OnionAnnounceRequest(request), addr)
        .await
        .unwrap_err();
    assert_eq!(err.kind(), &HandlePacketErrorKind::GetPayload);
}
// Announces Bob's key on Alice (two announce round-trips), then sends an
// OnionDataRequest addressed to Bob's key and checks that Alice delivers it as
// an OnionDataResponse inside an OnionResponse3 to the announcing address.
#[tokio::test]
async fn handle_onion_data_request() {
let (alice, precomp, bob_pk, _bob_sk, rx, addr) = create_node();
// First announce uses `initial_ping_id()`; its reply carries the ping id to
// use in the next announce (read below from `ping_id_or_pk`).
let payload = OnionAnnounceRequestPayload {
ping_id: initial_ping_id(),
search_pk: gen_keypair().0,
data_pk: gen_keypair().0,
sendback_data: 42
};
let inner = InnerOnionAnnounceRequest::new(&precomp, &bob_pk, &payload);
let onion_return = OnionReturn {
nonce: secretbox::gen_nonce(),
payload: vec![42; ONION_RETURN_3_PAYLOAD_SIZE]
};
let packet = Packet::OnionAnnounceRequest(OnionAnnounceRequest {
inner,
onion_return: onion_return.clone()
});
alice.handle_packet(packet, addr).await.unwrap();
// Consume the first announce reply and extract the ping id from it.
let (received, rx) = rx.into_future().await;
let (packet, _addr_to_send) = received.unwrap();
let response = unpack!(packet, Packet::OnionResponse3);
let response = unpack!(response.payload, InnerOnionResponse::OnionAnnounceResponse);
let payload = response.get_payload(&precomp).unwrap();
let ping_id = payload.ping_id_or_pk;
// Second announce with the returned ping id. NOTE(review): presumably this
// stores Bob's key together with `onion_return` so that data requests for
// `bob_pk` can be routed back; confirm against onion_announce.
let payload = OnionAnnounceRequestPayload {
ping_id,
search_pk: gen_keypair().0,
data_pk: gen_keypair().0,
sendback_data: 42
};
let inner = InnerOnionAnnounceRequest::new(&precomp, &bob_pk, &payload);
let packet = Packet::OnionAnnounceRequest(OnionAnnounceRequest {
inner,
onion_return: onion_return.clone()
});
alice.handle_packet(packet, addr).await.unwrap();
// Data request addressed to the announced key.
let nonce = gen_nonce();
let temporary_pk = gen_keypair().0;
let payload = vec![42; 123];
let inner = InnerOnionDataRequest {
destination_pk: bob_pk,
nonce,
temporary_pk,
payload: payload.clone()
};
let packet = Packet::OnionDataRequest(OnionDataRequest {
inner,
onion_return: onion_return.clone()
});
alice.handle_packet(packet, addr).await.unwrap();
// Skip the reply to the second announce; the next packet is expected to be
// the delivered data response.
let (received, _rx) = rx.skip(1).into_future().await;
let (packet, addr_to_send) = received.unwrap();
assert_eq!(addr_to_send, addr);
let response = unpack!(packet, Packet::OnionResponse3);
assert_eq!(response.onion_return, onion_return);
// Nonce, temporary key and payload are copied unchanged into the response.
let response = unpack!(response.payload, InnerOnionResponse::OnionDataResponse);
assert_eq!(response.nonce, nonce);
assert_eq!(response.temporary_pk, temporary_pk);
assert_eq!(response.payload, payload);
}
#[tokio::test]
async fn handle_onion_response_3() {
    let (alice, _precomp, _bob_pk, _bob_sk, mut rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    let destination = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    // Return layer for the following hop; opaque to Alice.
    let next_onion_return = OnionReturn {
        nonce: secretbox::gen_nonce(),
        payload: vec![42; ONION_RETURN_2_PAYLOAD_SIZE],
    };
    let outer_return = OnionReturn::new(&key, &destination, Some(&next_onion_return));
    let payload = InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    });

    let incoming = OnionResponse3 {
        onion_return: outer_return,
        payload: payload.clone(),
    };
    alice.handle_packet(Packet::OnionResponse3(incoming), addr).await.unwrap();

    // Alice peels her layer and forwards an OnionResponse2 to the decrypted address.
    let (sent, sent_to) = rx.next().await.unwrap();
    assert_eq!(sent_to, destination.to_saddr());
    let sent = unpack!(sent, Packet::OnionResponse2);
    assert_eq!(sent.payload, payload);
    assert_eq!(sent.onion_return, next_onion_return);
}
#[tokio::test]
async fn handle_onion_response_3_invalid_onion_return() {
    let (alice, _precomp, _bob_pk, _bob_sk, rx, addr) = create_node();

    // A return layer that was not sealed with Alice's symmetric key.
    let forged = Packet::OnionResponse3(OnionResponse3 {
        onion_return: OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_3_PAYLOAD_SIZE],
        },
        payload: InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
            sendback_data: 12345,
            nonce: gen_nonce(),
            payload: vec![42; 123],
        }),
    });

    // Handling succeeds, but nothing may be forwarded.
    alice.handle_packet(forged, addr).await.unwrap();

    // Dropping Alice releases her sender (`tx`) so the stream can finish.
    drop(alice);
    let forwarded: Vec<_> = rx.collect().await;
    assert!(forwarded.is_empty());
}
#[tokio::test]
async fn handle_onion_response_3_invalid_next_onion_return() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    // The outer layer is valid but carries no next return, which this hop needs.
    let ip_port = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let response = OnionResponse3 {
        onion_return: OnionReturn::new(&key, &ip_port, None),
        payload: InnerOnionResponse::OnionDataResponse(OnionDataResponse {
            nonce: gen_nonce(),
            temporary_pk: gen_keypair().0,
            payload: vec![42; 123],
        }),
    };

    let outcome = alice.handle_packet(Packet::OnionResponse3(response), addr).await;
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err().kind(), &HandlePacketErrorKind::OnionResponseNext);
}
#[tokio::test]
async fn handle_onion_response_2() {
    let (alice, _precomp, _bob_pk, _bob_sk, mut rx, addr) = create_node();
    let symmetric_key = alice.onion_symmetric_key.read().await;

    let hop = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let inner_return = OnionReturn {
        nonce: secretbox::gen_nonce(),
        payload: vec![42; ONION_RETURN_1_PAYLOAD_SIZE],
    };
    let announce = InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    });
    let packet = Packet::OnionResponse2(OnionResponse2 {
        onion_return: OnionReturn::new(&symmetric_key, &hop, Some(&inner_return)),
        payload: announce.clone(),
    });

    alice.handle_packet(packet, addr).await.unwrap();

    // Response is passed one hop back as OnionResponse1 with the inner return.
    let (sent, sent_to) = rx.next().await.unwrap();
    assert_eq!(sent_to, hop.to_saddr());
    let response1 = unpack!(sent, Packet::OnionResponse1);
    assert_eq!(response1.payload, announce);
    assert_eq!(response1.onion_return, inner_return);
}
#[tokio::test]
async fn handle_onion_response_2_invalid_onion_return() {
    let (alice, _precomp, _bob_pk, _bob_sk, rx, addr) = create_node();

    let bogus_return = OnionReturn {
        nonce: secretbox::gen_nonce(),
        payload: vec![42; ONION_RETURN_2_PAYLOAD_SIZE],
    };
    let announce = InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    });
    let packet = Packet::OnionResponse2(OnionResponse2 {
        onion_return: bogus_return,
        payload: announce,
    });

    // An undecryptable return is dropped quietly: Ok result, no forwarded packet.
    alice.handle_packet(packet, addr).await.unwrap();
    drop(alice);
    let leftovers: Vec<_> = rx.collect().await;
    assert!(leftovers.is_empty());
}
#[tokio::test]
async fn handle_onion_response_2_invalid_next_onion_return() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    // A valid outer layer whose embedded next return is missing (`None`).
    let ip_port = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let response = OnionResponse2 {
        onion_return: OnionReturn::new(&key, &ip_port, None),
        payload: InnerOnionResponse::OnionDataResponse(OnionDataResponse {
            nonce: gen_nonce(),
            temporary_pk: gen_keypair().0,
            payload: vec![42; 123],
        }),
    };

    let err = alice
        .handle_packet(Packet::OnionResponse2(response), addr)
        .await
        .unwrap_err();
    assert_eq!(err.kind(), &HandlePacketErrorKind::OnionResponseNext);
}
#[tokio::test]
async fn handle_onion_response_1_with_onion_announce_response() {
    let (alice, _precomp, _bob_pk, _bob_sk, mut rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    let origin = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let announce = OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    };
    // Last hop: the return layer holds only the origin, with no further return.
    let packet = Packet::OnionResponse1(OnionResponse1 {
        onion_return: OnionReturn::new(&key, &origin, None),
        payload: InnerOnionResponse::OnionAnnounceResponse(announce.clone()),
    });

    alice.handle_packet(packet, addr).await.unwrap();

    // The bare announce response is delivered to the origin.
    let (delivered, delivered_to) = rx.next().await.unwrap();
    assert_eq!(delivered_to, origin.to_saddr());
    assert_eq!(unpack!(delivered, Packet::OnionAnnounceResponse), announce);
}
#[tokio::test]
async fn server_handle_onion_response_1_with_onion_data_response_test() {
    let (alice, _precomp, _bob_pk, _bob_sk, mut rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    let origin = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let data = OnionDataResponse {
        nonce: gen_nonce(),
        temporary_pk: gen_keypair().0,
        payload: vec![42; 123],
    };
    // Last hop: no further return layer after the origin.
    let packet = Packet::OnionResponse1(OnionResponse1 {
        onion_return: OnionReturn::new(&key, &origin, None),
        payload: InnerOnionResponse::OnionDataResponse(data.clone()),
    });

    alice.handle_packet(packet, addr).await.unwrap();

    // The bare data response is delivered to the origin.
    let (delivered, delivered_to) = rx.next().await.unwrap();
    assert_eq!(delivered_to, origin.to_saddr());
    assert_eq!(unpack!(delivered, Packet::OnionDataResponse), data);
}
#[tokio::test]
async fn handle_onion_response_1_redirect_to_tcp() {
    let (mut alice, _precomp, _bob_pk, _bob_sk, _rx, _addr) = create_node();

    // Install a TCP onion sink; responses whose return address is TCP go there.
    let (tcp_onion_tx, mut tcp_onion_rx) = mpsc::channel(1);
    alice.set_tcp_onion_sink(tcp_onion_tx);

    let addr: SocketAddr = "127.0.0.1:12346".parse().unwrap();
    let key = alice.onion_symmetric_key.read().await;
    let tcp_origin = IpPort {
        protocol: ProtocolType::TCP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let onion_return = OnionReturn::new(&key, &tcp_origin, None);
    let response = InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
        sendback_data: 12345,
        nonce: gen_nonce(),
        payload: vec![42; 123],
    });

    let packet = Packet::OnionResponse1(OnionResponse1 {
        onion_return,
        payload: response.clone(),
    });
    alice.handle_packet(packet, addr).await.unwrap();

    let (redirected, redirected_to) = tcp_onion_rx.next().await.unwrap();
    assert_eq!(redirected_to, tcp_origin.to_saddr());
    assert_eq!(redirected, response);
}
#[tokio::test]
async fn handle_onion_response_1_can_not_redirect_to_tcp() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    // TCP return address, but no TCP onion sink has been installed.
    let tcp_origin = IpPort {
        protocol: ProtocolType::TCP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let packet = Packet::OnionResponse1(OnionResponse1 {
        onion_return: OnionReturn::new(&key, &tcp_origin, None),
        payload: InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
            sendback_data: 12345,
            nonce: gen_nonce(),
            payload: vec![42; 123],
        }),
    });

    let err = alice.handle_packet(packet, addr).await.unwrap_err();
    assert_eq!(*err.kind(), HandlePacketErrorKind::OnionResponseRedirect);
}
#[tokio::test]
async fn handle_onion_response_1_invalid_onion_return() {
    let (alice, _precomp, _bob_pk, _bob_sk, rx, addr) = create_node();

    // The return layer is filler, so Alice cannot recover an address from it.
    let packet = Packet::OnionResponse1(OnionResponse1 {
        onion_return: OnionReturn {
            nonce: secretbox::gen_nonce(),
            payload: vec![42; ONION_RETURN_1_PAYLOAD_SIZE],
        },
        payload: InnerOnionResponse::OnionAnnounceResponse(OnionAnnounceResponse {
            sendback_data: 12345,
            nonce: gen_nonce(),
            payload: vec![42; 123],
        }),
    });

    alice.handle_packet(packet, addr).await.unwrap();

    // After dropping Alice's sender, the channel must have stayed empty.
    drop(alice);
    let emitted: Vec<_> = rx.collect().await;
    assert!(emitted.is_empty());
}
#[tokio::test]
async fn handle_onion_response_1_invalid_next_onion_return() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
    let key = alice.onion_symmetric_key.read().await;

    let ip_port = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    // OnionResponse1 is the last hop; the test expects a return that still
    // carries a next layer to be rejected.
    let unexpected_next = OnionReturn {
        nonce: secretbox::gen_nonce(),
        payload: vec![42; ONION_RETURN_1_PAYLOAD_SIZE],
    };
    let response = OnionResponse1 {
        onion_return: OnionReturn::new(&key, &ip_port, Some(&unexpected_next)),
        payload: InnerOnionResponse::OnionDataResponse(OnionDataResponse {
            nonce: gen_nonce(),
            temporary_pk: gen_keypair().0,
            payload: vec![42; 123],
        }),
    };

    let outcome = alice.handle_packet(Packet::OnionResponse1(response), addr).await;
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err().kind(), &HandlePacketErrorKind::OnionResponseNext);
}
// Adds a friend with four close nodes, each with a returned address, and
// checks that the DHT main loop sends a NatPingRequest whose id equals the
// friend's current hole-punch ping id.
#[tokio::test]
async fn send_nat_ping_req() {
let (alice, _precomp, _bob_pk, _bob_sk, mut rx, _addr) = create_node();
let (friend_pk, friend_sk) = gen_keypair();
let nodes = [
PackedNode::new("127.1.1.1:12345".parse().unwrap(), &gen_keypair().0),
PackedNode::new("127.1.1.2:12345".parse().unwrap(), &gen_keypair().0),
PackedNode::new("127.1.1.3:12345".parse().unwrap(), &gen_keypair().0),
PackedNode::new("127.1.1.4:12345".parse().unwrap(), &gen_keypair().0),
];
alice.add_friend(friend_pk).await;
let mut friends = alice.friends.write().await;
for &node in &nodes {
let friend = friends.get_mut(&friend_pk).unwrap();
friend.try_add_to_close(node);
// Record an address this node returned for the friend. NOTE(review):
// presumably hole punching only starts once enough close nodes have
// returned an address; confirm against hole_punching.
let dht_node = friend.close_nodes.get_node_mut(&friend_pk, &node.pk).unwrap();
dht_node.update_returned_addr(node.saddr);
}
// Release the write lock before running the main loop (which presumably
// locks `friends` itself).
drop(friends);
alice.dht_main_loop().await.unwrap();
// The main loop may emit other packets as well; skip them until the first
// DhtRequest, which is expected to carry the NatPingRequest.
loop {
let (received, rx1) = rx.into_future().await;
let (packet, _addr_to_send) = received.unwrap();
if let Packet::DhtRequest(nat_ping_req) = packet {
// Decrypt as the friend would: sender key plus friend's secret key.
let precomputed_key = precompute(&nat_ping_req.spk, &friend_sk);
let nat_ping_req_payload = nat_ping_req.get_payload(&precomputed_key).unwrap();
let nat_ping_req_payload = unpack!(nat_ping_req_payload, DhtRequestPayload::NatPingRequest);
assert_eq!(alice.friends.read().await[&friend_pk].hole_punch.ping_id, nat_ping_req_payload.id);
break;
}
rx = rx1;
}
}
#[tokio::test]
async fn handle_lan_discovery() {
    let (alice, _precomp, bob_pk, bob_sk, mut rx, addr) = create_node();

    let announcement = Packet::LanDiscovery(LanDiscovery { pk: bob_pk });
    alice.handle_packet(announcement, addr).await.unwrap();

    // Alice answers a foreign LanDiscovery by sending a NodesRequest back to it.
    let (reply, reply_to) = rx.next().await.unwrap();
    assert_eq!(reply_to, addr);
    let nodes_request = unpack!(reply, Packet::NodesRequest);
    let shared_key = precompute(&nodes_request.pk, &bob_sk);
    let decrypted = nodes_request.get_payload(&shared_key).unwrap();
    assert_eq!(decrypted.pk, alice.pk);
}
#[tokio::test]
async fn handle_lan_discovery_for_ourselves() {
    let (alice, _precomp, _bob_pk, _bob_sk, rx, addr) = create_node();

    // A LanDiscovery carrying Alice's own key must not trigger any reply.
    let own_announcement = LanDiscovery { pk: alice.pk };
    alice.handle_packet(Packet::LanDiscovery(own_announcement), addr).await.unwrap();

    // Dropping Alice releases her sender (`tx`) so `collect` can complete.
    drop(alice);
    let sent: Vec<_> = rx.collect().await;
    assert!(sent.is_empty());
}
#[tokio::test]
async fn handle_lan_discovery_when_disabled() {
    let (mut alice, _precomp, bob_pk, _bob_sk, rx, addr) = create_node();

    alice.enable_lan_discovery(false);
    assert!(!alice.lan_discovery_enabled);

    // Use a foreign key. A LanDiscovery carrying Alice's own key is ignored
    // even when LAN discovery is enabled (see
    // `handle_lan_discovery_for_ourselves`), so it could not show whether the
    // disabled flag is honoured. Bob's key normally triggers a NodesRequest
    // (see `handle_lan_discovery`).
    let lan = Packet::LanDiscovery(LanDiscovery { pk: bob_pk });
    alice.handle_packet(lan, addr).await.unwrap();

    // With LAN discovery disabled, nothing may be sent.
    drop(alice);
    assert!(rx.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn refresh_onion_key() {
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, _addr) = create_node();

    // Snapshot the symmetric key before and after a refresh; they must differ.
    let before = alice.onion_symmetric_key.read().await.clone();
    alice.refresh_onion_key().await;
    let after = alice.onion_symmetric_key.read().await.clone();

    assert_ne!(before, after);
}
#[tokio::test]
async fn handle_tcp_onion_request() {
    let (alice, _precomp, _bob_pk, _bob_sk, mut rx, addr) = create_node();

    let temporary_pk = gen_keypair().0;
    let payload = vec![42; 123];
    let next_hop = IpPort {
        protocol: ProtocolType::UDP,
        ip_addr: "5.6.7.8".parse().unwrap(),
        port: 12345,
    };
    let request = OnionRequest {
        nonce: gen_nonce(),
        ip_port: next_hop.clone(),
        temporary_pk,
        payload: payload.clone(),
    };

    alice.handle_tcp_onion_request(request, addr).await.unwrap();

    // A TCP onion request enters the UDP onion network as OnionRequest1.
    let (forwarded, forwarded_to) = rx.next().await.unwrap();
    assert_eq!(forwarded_to, next_hop.to_saddr());
    let request1 = unpack!(forwarded, Packet::OnionRequest1);
    assert_eq!(request1.temporary_pk, temporary_pk);
    assert_eq!(request1.payload, payload);

    // The return layer records the sender as a TCP address.
    let key = alice.onion_symmetric_key.read().await;
    let back = request1.onion_return.get_payload(&key).unwrap();
    assert_eq!(back.0, IpPort::from_tcp_saddr(addr));
}
#[tokio::test]
async fn ping_nodes_to_bootstrap() {
let (alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
let (node_pk, node_sk) = gen_keypair();
let pn = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
assert!(alice.nodes_to_bootstrap.write().await.try_add(&alice.pk, pn, true));
let pn = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(alice.nodes_to_bootstrap.write().await.try_add(&alice.pk, pn, true));
alice.dht_main_loop().await.unwrap();
let mut request_queue = alice.request_queue.write().await;
rx.take(2).map(|(packet, addr)| {
let nodes_req = unpack!(packet, Packet::NodesRequest);
if addr == "127.0.0.1:33445".parse().unwrap() {
let precomputed_key = precompute(&nodes_req.pk, &bob_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == bob_pk).is_some());
assert_eq!(nodes_req_payload.pk, alice.pk);
} else {
let precomputed_key = precompute(&nodes_req.pk, &node_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == node_pk).is_some());
assert_eq!(nodes_req_payload.pk, alice.pk);
}
}).collect::<Vec<_>>().await;
}
#[tokio::test]
async fn ping_nodes_from_nodes_to_ping_list() {
let (alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
let (node_pk, node_sk) = gen_keypair();
let pn = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
assert!(alice.nodes_to_ping.write().await.try_add(&alice.pk, pn, true));
let pn = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(alice.nodes_to_ping.write().await.try_add(&alice.pk, pn, true));
alice.send_pings().await.unwrap();
let mut request_queue = alice.request_queue.write().await;
rx.take(2).map(|(packet, addr)| {
let nodes_req = unpack!(packet, Packet::PingRequest);
if addr == "127.0.0.1:33445".parse().unwrap() {
let precomputed_key = precompute(&nodes_req.pk, &bob_sk);
let ping_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(ping_req_payload.id, |&pk| pk == bob_pk).is_some());
} else {
let precomputed_key = precompute(&nodes_req.pk, &node_sk);
let ping_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(ping_req_payload.id, |&pk| pk == node_pk).is_some());
}
}).collect::<Vec<_>>().await;
}
#[tokio::test]
async fn ping_nodes_when_nodes_to_ping_list_is_empty() {
    let (alice, _precomp, _bob_pk, _bob_sk, rx, _addr) = create_node();

    // With nothing queued for pinging, `send_pings` must not emit packets.
    alice.send_pings().await.unwrap();

    drop(alice);
    let emitted: Vec<_> = rx.collect().await;
    assert!(emitted.is_empty());
}
#[tokio::test]
async fn ping_close_nodes() {
let (alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
let (node_pk, node_sk) = gen_keypair();
let pn = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
assert!(alice.close_nodes.write().await.try_add(pn));
let pn = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(alice.close_nodes.write().await.try_add(pn));
alice.dht_main_loop().await.unwrap();
let mut request_queue = alice.request_queue.write().await;
rx.take(3).map(|(packet, addr)| {
let nodes_req = unpack!(packet, Packet::NodesRequest);
if addr == "127.0.0.1:33445".parse().unwrap() {
let precomputed_key = precompute(&nodes_req.pk, &bob_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == bob_pk).is_some());
assert_eq!(nodes_req_payload.pk, alice.pk);
} else {
let precomputed_key = precompute(&nodes_req.pk, &node_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == node_pk).is_some());
assert_eq!(nodes_req_payload.pk, alice.pk);
}
}).collect::<Vec<_>>().await;
}
// Checks that the main loop sends random NodesRequests to close nodes exactly
// MAX_BOOTSTRAP_TIMES times (one per iteration, one second apart) and then
// stops.
#[tokio::test]
async fn send_nodes_req_random_periodicity() {
let (alice, _precomp, bob_pk, _bob_sk, mut rx, _addr) = create_node();
{
let mut close_nodes = alice.close_nodes.write().await;
let pn = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &bob_pk);
assert!(close_nodes.try_add(pn));
// Mark both address families as just pinged. NOTE(review): presumably this
// keeps the per-node ping timer from firing, so only the random NodesRequest
// path is observed; confirm.
let node = close_nodes.get_node_mut(&bob_pk).unwrap();
node.assoc4.last_ping_req_time = Some(clock_now());
node.assoc6.last_ping_req_time = Some(clock_now());
}
// Freeze tokio's clock; `advance` below controls how much time passes.
tokio::time::pause();
// One NodesRequest per iteration. The 1s step matches BOOTSTRAP_INTERVAL.
for _ in 0 .. MAX_BOOTSTRAP_TIMES {
alice.dht_main_loop().await.unwrap();
let (received, rx1) = rx.into_future().await;
let (packet, _) = received.unwrap();
unpack!(packet, Packet::NodesRequest);
tokio::time::advance(Duration::from_secs(1)).await;
rx = rx1;
}
// After MAX_BOOTSTRAP_TIMES requests, another loop run must send nothing.
alice.dht_main_loop().await.unwrap();
drop(alice);
assert!(rx.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn ping_nodes_to_bootstrap_of_friend() {
let (alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
let (node_pk, node_sk) = gen_keypair();
let friend_pk = gen_keypair().0;
alice.add_friend(friend_pk).await;
let mut friends = alice.friends.write().await;
let friend = friends.get_mut(&friend_pk).unwrap();
let pn = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
assert!(friend.nodes_to_bootstrap.try_add(&alice.pk, pn, true));
let pn = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(friend.nodes_to_bootstrap.try_add(&alice.pk, pn, true));
drop(friends);
alice.dht_main_loop().await.unwrap();
let mut request_queue = alice.request_queue.write().await;
rx.take(2).map(|(packet, addr)| {
let nodes_req = unpack!(packet, Packet::NodesRequest);
if addr == "127.0.0.1:33445".parse().unwrap() {
let precomputed_key = precompute(&nodes_req.pk, &bob_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == bob_pk).is_some());
assert_eq!(nodes_req_payload.pk, friend_pk);
} else {
let precomputed_key = precompute(&nodes_req.pk, &node_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == node_pk).is_some());
assert_eq!(nodes_req_payload.pk, friend_pk);
}
}).collect::<Vec<_>>().await;
}
#[tokio::test]
async fn ping_close_nodes_of_friend() {
let (alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
let (node_pk, node_sk) = gen_keypair();
let friend_pk = gen_keypair().0;
alice.add_friend(friend_pk).await;
{
let mut friends = alice.friends.write().await;
let friend = friends.get_mut(&friend_pk).unwrap();
let pn = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
assert!(friend.try_add_to_close(pn));
let pn = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(friend.try_add_to_close(pn));
}
alice.dht_main_loop().await.unwrap();
let mut request_queue = alice.request_queue.write().await;
rx.take(3).map(|(packet, addr)| {
let nodes_req = unpack!(packet, Packet::NodesRequest);
if addr == "127.0.0.1:33445".parse().unwrap() {
let precomputed_key = precompute(&nodes_req.pk, &bob_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == bob_pk).is_some());
assert_eq!(nodes_req_payload.pk, friend_pk);
} else {
let precomputed_key = precompute(&nodes_req.pk, &node_sk);
let nodes_req_payload = nodes_req.get_payload(&precomputed_key).unwrap();
assert!(request_queue.check_ping_id(nodes_req_payload.id, |&pk| pk == node_pk).is_some());
assert_eq!(nodes_req_payload.pk, friend_pk);
}
}).collect::<Vec<_>>().await;
}
// Checks that the main loop sends random NodesRequests to a friend's close
// nodes exactly MAX_BOOTSTRAP_TIMES times (one per iteration, one second
// apart) and then stops.
#[tokio::test]
async fn send_nodes_req_random_friend_periodicity() {
let (alice, _precomp, bob_pk, _bob_sk, mut rx, _addr) = create_node();
let friend_pk = gen_keypair().0;
alice.add_friend(friend_pk).await;
let mut friends = alice.friends.write().await;
let friend = friends.get_mut(&friend_pk).unwrap();
let pn = PackedNode::new("127.0.0.1:33445".parse().unwrap(), &bob_pk);
assert!(friend.try_add_to_close(pn));
// Mark the close node as just pinged. NOTE(review): presumably this keeps the
// per-node ping timer from firing, so only random NodesRequests are observed.
friend.close_nodes.nodes[0].assoc4.last_ping_req_time = Some(clock_now());
friend.close_nodes.nodes[0].assoc6.last_ping_req_time = Some(clock_now());
drop(friends);
// Freeze tokio's clock; `advance` below controls how much time passes.
tokio::time::pause();
for _ in 0 .. MAX_BOOTSTRAP_TIMES {
// Reset the hole-punch ping timestamp each round. NOTE(review): presumably
// this keeps a NatPingRequest from being sent inside this loop; confirm.
alice.friends.write().await.get_mut(&friend_pk).unwrap().hole_punch.last_send_ping_time = Some(clock_now());
alice.dht_main_loop().await.unwrap();
let (received, rx1) = rx.into_future().await;
let (packet, _) = received.unwrap();
unpack!(packet, Packet::NodesRequest);
// 1s step matches BOOTSTRAP_INTERVAL.
tokio::time::advance(Duration::from_secs(1)).await;
rx = rx1;
}
// After MAX_BOOTSTRAP_TIMES requests, another loop run must send nothing.
alice.dht_main_loop().await.unwrap();
drop(alice);
assert!(rx.collect::<Vec<_>>().await.is_empty());
}
#[tokio::test]
async fn enable_ipv6_mode() {
    // Calling `enable_ipv6_mode(true)` must set the `is_ipv6_enabled` flag.
    let (mut alice, _precomp, _bob_pk, _bob_sk, _rx, _addr) = create_node();
    alice.enable_ipv6_mode(true);
    // `assert!` on the bool directly instead of `assert_eq!(.., true)`
    // (clippy::bool_assert_comparison).
    assert!(alice.is_ipv6_enabled);
}
#[tokio::test]
async fn send_to() {
    // bob sits at an IPv6 address and the extra node at an IPv4 one. With IPv6
    // mode on, one main-loop pass must send a NodesRequest to each of them, and
    // each request's ping id must be registered for its receiver's key.
    let (mut alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
    let (node_pk, node_sk) = gen_keypair();
    let bob_node = PackedNode::new("[FF::01]:33445".parse().unwrap(), &bob_pk);
    assert!(alice.close_nodes.write().await.try_add(bob_node));
    let other_node = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
    assert!(alice.close_nodes.write().await.try_add(other_node));
    alice.enable_ipv6_mode(true);
    alice.dht_main_loop().await.unwrap();

    let mut request_queue = alice.request_queue.write().await;
    let sent = rx.take(2).collect::<Vec<_>>().await;
    for (packet, addr) in sent {
        let nodes_req = unpack!(packet, Packet::NodesRequest);
        // Pick the key pair of whichever node the packet was addressed to.
        let (receiver_pk, receiver_sk) = if addr == "[FF::01]:33445".parse().unwrap() {
            (bob_pk, &bob_sk)
        } else {
            (node_pk, &node_sk)
        };
        let shared_key = precompute(&nodes_req.pk, receiver_sk);
        let payload = nodes_req.get_payload(&shared_key).unwrap();
        assert!(request_queue.check_ping_id(payload.id, |&pk| pk == receiver_pk).is_some());
    }
}
#[tokio::test]
async fn send_bootstrap_requests() {
    // Both initial bootstrap nodes (one IPv6, one IPv4) must receive a
    // NodesRequest whose ping id is registered for that node's key.
    let (mut alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
    let (node_pk, node_sk) = gen_keypair();
    alice.add_initial_bootstrap(PackedNode::new("[FF::01]:33445".parse().unwrap(), &bob_pk));
    alice.add_initial_bootstrap(PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk));
    alice.enable_ipv6_mode(true);
    alice.send_bootstrap_requests().await.unwrap();

    let mut request_queue = alice.request_queue.write().await;
    let sent = rx.take(2).collect::<Vec<_>>().await;
    for (packet, addr) in sent {
        let nodes_req = unpack!(packet, Packet::NodesRequest);
        // Pick the key pair of whichever node the packet was addressed to.
        let (receiver_pk, receiver_sk) = if addr == "[FF::01]:33445".parse().unwrap() {
            (bob_pk, &bob_sk)
        } else {
            (node_pk, &node_sk)
        };
        let shared_key = precompute(&nodes_req.pk, receiver_sk);
        let payload = nodes_req.get_payload(&shared_key).unwrap();
        assert!(request_queue.check_ping_id(payload.id, |&pk| pk == receiver_pk).is_some());
    }
}
#[tokio::test]
async fn send_bootstrap_requests_when_ktree_has_good_node() {
    // With a freshly added node already in `close_nodes`, calling
    // send_bootstrap_requests must send no packets at all.
    let (mut alice, _precomp, bob_pk, _bob_sk, rx, _addr) = create_node();
    let (node_pk, _node_sk) = gen_keypair();
    let bootstrap_node = PackedNode::new("[FF::01]:33445".parse().unwrap(), &bob_pk);
    alice.add_initial_bootstrap(bootstrap_node);
    let good_node = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
    assert!(alice.close_nodes.write().await.try_add(good_node));
    alice.enable_ipv6_mode(true);
    alice.send_bootstrap_requests().await.unwrap();
    // Dropping `alice` releases its sender so the stream can end.
    drop(alice);
    let sent: Vec<_> = rx.collect().await;
    assert!(sent.is_empty());
}
#[tokio::test]
async fn send_bootstrap_requests_with_discarded() {
    // After advancing the paused clock past KILL_NODE_TIMEOUT, both close nodes
    // (presumably now considered discarded) must get a NodesRequest whose ping
    // id is registered for their key.
    let (mut alice, _precomp, bob_pk, bob_sk, rx, _addr) = create_node();
    let (node_pk, node_sk) = gen_keypair();
    {
        let mut close_nodes = alice.close_nodes.write().await;
        let bob_node = PackedNode::new("[FF::01]:33445".parse().unwrap(), &bob_pk);
        assert!(close_nodes.try_add(bob_node));
        let other_node = PackedNode::new("127.1.1.1:12345".parse().unwrap(), &node_pk);
        assert!(close_nodes.try_add(other_node));
    }
    alice.enable_ipv6_mode(true);
    tokio::time::pause();
    tokio::time::advance(KILL_NODE_TIMEOUT + Duration::from_secs(1)).await;
    alice.send_bootstrap_requests().await.unwrap();

    let mut request_queue = alice.request_queue.write().await;
    let sent = rx.take(2).collect::<Vec<_>>().await;
    for (packet, addr) in sent {
        let nodes_req = unpack!(packet, Packet::NodesRequest);
        // Pick the key pair of whichever node the packet was addressed to.
        let (receiver_pk, receiver_sk) = if addr == "[FF::01]:33445".parse().unwrap() {
            (bob_pk, &bob_sk)
        } else {
            (node_pk, &node_sk)
        };
        let shared_key = precompute(&nodes_req.pk, receiver_sk);
        let payload = nodes_req.get_payload(&shared_key).unwrap();
        assert!(request_queue.check_ping_id(payload.id, |&pk| pk == receiver_pk).is_some());
    }
}
#[tokio::test]
async fn handle_crypto_data_uninitialized() {
let (alice, precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let data_payload = CryptoDataPayload {
buffer_start: 1,
packet_number: 0,
data: vec![1, 2, 3, 4]
};
let data = Packet::CryptoData(CryptoData::new(&precomp, gen_nonce(), &data_payload));
let res = alice.handle_packet(data, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::NetCrypto);
}
#[tokio::test]
async fn handle_onion_data_response_uninitialized() {
let (alice, _precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let data = Packet::OnionDataResponse(OnionDataResponse {
nonce: gen_nonce(),
temporary_pk: gen_keypair().0,
payload: vec![42; 123]
});
let res = alice.handle_packet(data, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::OnionClient);
}
#[tokio::test]
async fn handle_onion_announce_response_uninitialized() {
let (alice, precomp, _bob_pk, _bob_sk, _rx, addr) = create_node();
let payload = OnionAnnounceResponsePayload {
announce_status: AnnounceStatus::Found,
ping_id_or_pk: sha256::hash(&[1, 2, 3]),
nodes: vec![
PackedNode::new(SocketAddr::V4("5.6.7.8:12345".parse().unwrap()), &gen_keypair().0)
]
};
let data = Packet::OnionAnnounceResponse(OnionAnnounceResponse::new(&precomp, 12345, &payload));
let res = alice.handle_packet(data, addr).await;
assert!(res.is_err());
assert_eq!(*res.err().unwrap().kind(), HandlePacketErrorKind::OnionClient);
}
#[tokio::test]
async fn ping_node() {
    // ping_node must send a NodesRequest to the node's address. The request is
    // decryptable with `precomp` and asks about alice's own public key.
    let (alice, precomp, bob_pk, _bob_sk, mut rx, addr) = create_node();
    let bob_node = PackedNode::new(addr, &bob_pk);
    alice.ping_node(&bob_node).await.unwrap();
    let (packet, sent_to) = rx.next().await.unwrap();
    assert_eq!(sent_to, addr);
    let nodes_req = unpack!(packet, Packet::NodesRequest);
    let payload = nodes_req.get_payload(&precomp).unwrap();
    assert_eq!(payload.pk, alice.pk);
}
#[tokio::test]
async fn random_friend_nodes() {
    // Each fake friend gets one close node and the real friend gets one too.
    // random_friend_nodes must never return the real friend's node. Asking for
    // more than FAKE_FRIENDS_NUMBER yields exactly FAKE_FRIENDS_NUMBER nodes,
    // and asking for fewer yields exactly that many.
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, _addr) = create_node();
    let friend_pk = gen_keypair().0;
    alice.add_friend(friend_pk).await;
    let friend_node = {
        let mut friends = alice.friends.write().await;
        for fake_pk in &alice.fake_friends_keys {
            let fake_node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0);
            let fake_friend = friends.get_mut(fake_pk).unwrap();
            assert!(fake_friend.close_nodes.try_add(fake_pk, fake_node, true));
        }
        let friend_node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0);
        let friend = friends.get_mut(&friend_pk).unwrap();
        assert!(friend.close_nodes.try_add(&friend_pk, friend_node, true));
        friend_node
    };

    let returned = alice.random_friend_nodes(FAKE_FRIENDS_NUMBER as u8 + 1).await;
    assert_eq!(returned.len(), FAKE_FRIENDS_NUMBER);
    assert!(!returned.contains(&friend_node));

    let returned = alice.random_friend_nodes(FAKE_FRIENDS_NUMBER as u8 - 1).await;
    assert_eq!(returned.len(), FAKE_FRIENDS_NUMBER - 1);
    assert!(!returned.contains(&friend_node));
}
#[tokio::test]
async fn is_connected() {
    // A fresh server reports not connected; after add_node it reports connected.
    let (alice, _precomp, _bob_pk, _bob_sk, _rx, _addr) = create_node();
    assert!(!alice.is_connected().await);
    let node = PackedNode::new("127.0.0.1:12345".parse().unwrap(), &gen_keypair().0);
    alice.add_node(node).await;
    assert!(alice.is_connected().await);
}
}