use async_trait::async_trait;
use crate::dht::types::Chord;
use crate::dht::types::CorrectChord;
use crate::dht::PeerRingAction;
use crate::dht::TopoInfo;
use crate::error::Error;
use crate::error::Result;
use crate::message::types::ConnectNodeReport;
use crate::message::types::ConnectNodeSend;
use crate::message::types::FindSuccessorReport;
use crate::message::types::FindSuccessorSend;
use crate::message::types::Message;
use crate::message::types::QueryForTopoInfoReport;
use crate::message::types::QueryForTopoInfoSend;
use crate::message::types::Then;
use crate::message::FindSuccessorReportHandler;
use crate::message::FindSuccessorThen;
use crate::message::HandleMsg;
use crate::message::MessageHandler;
use crate::message::MessagePayload;
use crate::message::PayloadSender;
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl HandleMsg<QueryForTopoInfoSend> for MessageHandler {
    /// Answer a topology query addressed to this node.
    ///
    /// If `msg.did` is not this node's DID the message is ignored and `Ok(())` is returned.
    /// Otherwise a [`TopoInfo`] snapshot of the local DHT is taken and sent back as a
    /// `QueryForTopoInfoReport` (built by `msg.resp`) via `send_report_message`.
    ///
    /// # Errors
    /// Propagates failures from building the [`TopoInfo`] snapshot or sending the report.
    async fn handle(&self, ctx: &MessagePayload, msg: &QueryForTopoInfoSend) -> Result<()> {
        // Check the target first: building the snapshot reads the DHT, and doing it for a
        // query we will not answer is wasted work and could surface an unrelated error.
        if msg.did != self.dht.did {
            return Ok(());
        }
        let info: TopoInfo = TopoInfo::try_from(self.dht.as_ref())?;
        self.transport
            .send_report_message(ctx, Message::QueryForTopoInfoReport(msg.resp(info)))
            .await
    }
}
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl HandleMsg<QueryForTopoInfoReport> for MessageHandler {
    /// Act on a topology report received from a remote peer.
    ///
    /// The follow-up action is selected by `msg.then`:
    /// * `Stabilization`: the reported [`TopoInfo`] is passed to `self.dht.stabilize`, and the
    ///   resulting DHT events are processed by `handle_dht_events`.
    /// * `SyncSuccessor`: every DID in the reported successor list is joined via `join_dht`,
    ///   in list order. The first failure aborts the remaining joins.
    ///
    /// # Errors
    /// Propagates the first error from stabilization, event handling, or a join.
    async fn handle(&self, _ctx: &MessagePayload, msg: &QueryForTopoInfoReport) -> Result<()> {
        match msg.then {
            <QueryForTopoInfoReport as Then>::Then::Stabilization => {
                let events = self.dht.stabilize(msg.info.clone())?;
                self.handle_dht_events(&events).await?;
            }
            <QueryForTopoInfoReport as Then>::Then::SyncSuccessor => {
                for successor in msg.info.successors.iter().copied() {
                    self.join_dht(successor).await?;
                }
            }
        }
        Ok(())
    }
}
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl HandleMsg<ConnectNodeSend> for MessageHandler {
    /// Handle an incoming connection offer.
    ///
    /// * Offers carrying a different `network_id` than this transport are dropped silently
    ///   (`Ok(())`).
    /// * Offers whose relay destination is another node are forwarded with `forward_payload`.
    /// * Offers addressed to this node are answered with `answer_remote_connection`. The
    ///   answer is sent as a `ConnectNodeReport` through `send_report_message`.
    ///
    /// # Errors
    /// Propagates forwarding, answering, or report-sending failures.
    async fn handle(&self, ctx: &MessagePayload, msg: &ConnectNodeSend) -> Result<()> {
        let foreign_network = msg.network_id != self.transport.network_id;
        if foreign_network {
            return Ok(());
        }

        let for_someone_else = self.dht.did != ctx.relay.destination;
        if for_someone_else {
            return self.transport.forward_payload(ctx, None).await;
        }

        let origin = ctx.relay.origin_sender();
        let answer = self
            .transport
            .answer_remote_connection(origin, self.inner_callback(), msg)
            .await?;
        let report = Message::ConnectNodeReport(answer);
        self.transport.send_report_message(ctx, report).await
    }
}
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl HandleMsg<ConnectNodeReport> for MessageHandler {
    /// Handle the answer to a connection offer this node (or another node) sent earlier.
    ///
    /// When the relay destination is this node, the answer is handed to
    /// `accept_remote_connection` together with the relay's origin sender. Otherwise the
    /// payload is forwarded with `forward_payload`.
    ///
    /// # Errors
    /// Propagates the forwarding or acceptance failure.
    async fn handle(&self, ctx: &MessagePayload, msg: &ConnectNodeReport) -> Result<()> {
        if self.dht.did == ctx.relay.destination {
            let origin = ctx.relay.origin_sender();
            return self.transport.accept_remote_connection(origin, msg).await;
        }
        self.transport.forward_payload(ctx, None).await
    }
}
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl HandleMsg<FindSuccessorSend> for MessageHandler {
async fn handle(&self, ctx: &MessagePayload, msg: &FindSuccessorSend) -> Result<()> {
match self.dht.find_successor(msg.did)? {
PeerRingAction::Some(did) => {
if !msg.strict || self.dht.did == msg.did {
match &msg.then {
FindSuccessorThen::Report(handler) => {
self.transport
.send_report_message(
ctx,
Message::FindSuccessorReport(FindSuccessorReport {
did,
handler: handler.clone(),
}),
)
.await
}
}
} else {
self.transport.forward_payload(ctx, Some(did)).await
}
}
PeerRingAction::RemoteAction(next, _) => {
self.transport.reset_destination(ctx, next).await
}
act => Err(Error::PeerRingUnexpectedAction(act)),
}
}
}
#[cfg_attr(feature = "wasm", async_trait(?Send))]
#[cfg_attr(not(feature = "wasm"), async_trait)]
impl HandleMsg<FindSuccessorReport> for MessageHandler {
    /// Handle the result of a find-successor lookup.
    ///
    /// Reports relayed to another node are forwarded with `forward_payload`. For reports
    /// addressed to this node whose handler is `FixFingerTable` or `Connect`, a connection
    /// offer to the reported DID is prepared and sent as `ConnectNodeSend`. This is skipped
    /// when the reported DID is this node's own. Other handlers are a no-op here.
    ///
    /// # Errors
    /// Propagates forwarding, offer-preparation, or send failures.
    async fn handle(&self, ctx: &MessagePayload, msg: &FindSuccessorReport) -> Result<()> {
        if self.dht.did != ctx.relay.destination {
            return self.transport.forward_payload(ctx, None).await;
        }

        let wants_connection = matches!(
            msg.handler,
            FindSuccessorReportHandler::FixFingerTable | FindSuccessorReportHandler::Connect
        );
        if !wants_connection || msg.did == self.dht.did {
            return Ok(());
        }

        let offer = self
            .transport
            .prepare_connection_offer(msg.did, self.inner_callback())
            .await?;
        self.transport
            .send_message(Message::ConnectNodeSend(offer), msg.did)
            .await?;
        Ok(())
    }
}
#[cfg(not(feature = "wasm"))]
#[cfg(test)]
pub mod tests {
    //! Tests for the connection handlers. Nodes are linked either manually or through a
    //! DHT-routed `connect`. After each step the tests check transports, successor lists and
    //! finger tables.

    use std::matches;

    use rings_transport::core::transport::WebrtcConnectionState;
    use tokio::time::sleep;
    use tokio::time::Duration;

    use super::*;
    use crate::dht::successor::SuccessorReader;
    use crate::ecc::tests::gen_ordered_keys;
    use crate::ecc::SecretKey;
    use crate::tests::default::assert_no_more_msg;
    use crate::tests::default::prepare_node;
    use crate::tests::default::wait_for_msgs;
    use crate::tests::default::Node;
    use crate::tests::manually_establish_connection;

    #[tokio::test]
    async fn test_triple_nodes_connection_1_2_3() -> Result<()> {
        let keys = gen_ordered_keys(3);
        let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
        test_triple_ordered_nodes_connection(key1, key2, key3).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_triple_nodes_connection_2_3_1() -> Result<()> {
        let keys = gen_ordered_keys(3);
        let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
        test_triple_ordered_nodes_connection(key2, key3, key1).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_triple_nodes_connection_3_1_2() -> Result<()> {
        let keys = gen_ordered_keys(3);
        let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
        test_triple_ordered_nodes_connection(key3, key1, key2).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_triple_nodes_connection_3_2_1() -> Result<()> {
        let keys = gen_ordered_keys(3);
        let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
        test_triple_desc_ordered_nodes_connection(key3, key2, key1).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_triple_nodes_connection_2_1_3() -> Result<()> {
        let keys = gen_ordered_keys(3);
        let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
        test_triple_desc_ordered_nodes_connection(key2, key1, key3).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_triple_nodes_connection_1_3_2() -> Result<()> {
        let keys = gen_ordered_keys(3);
        let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
        test_triple_desc_ordered_nodes_connection(key1, key3, key2).await?;
        Ok(())
    }

    /// Build a three-node ring and return the nodes.
    ///
    /// Steps: node1 and node2 are connected manually, then node3 and node2 manually, then
    /// node1 connects to node3 through the DHT.
    ///
    /// The expected successor orders match callers passing a rotation of the
    /// `gen_ordered_keys` output (1-2-3, 2-3-1, 3-1-2).
    async fn test_triple_ordered_nodes_connection(
        key1: SecretKey,
        key2: SecretKey,
        key3: SecretKey,
    ) -> Result<(Node, Node, Node)> {
        let node1 = prepare_node(key1).await;
        let node2 = prepare_node(key2).await;
        let node3 = prepare_node(key3).await;

        println!("========================================");
        println!("||  now we connect node1 and node2    ||");
        println!("========================================");

        manually_establish_connection(&node1.swarm, &node2.swarm).await;
        wait_for_msgs([&node1, &node2, &node3]).await;
        assert_no_more_msg([&node1, &node2, &node3]).await;

        node1.assert_transports(vec![node2.did()]);
        node2.assert_transports(vec![node1.did()]);
        node3.assert_transports(vec![]);
        assert_eq!(node1.dht().successors().list()?, vec![node2.did()]);
        assert_eq!(node2.dht().successors().list()?, vec![node1.did()]);
        assert_eq!(node3.dht().successors().list()?, vec![]);

        println!("========================================");
        println!("||  now we start join node3 to node2  ||");
        println!("========================================");

        manually_establish_connection(&node3.swarm, &node2.swarm).await;
        wait_for_msgs([&node1, &node2, &node3]).await;
        assert_no_more_msg([&node1, &node2, &node3]).await;

        println!("=== Check state before connect via DHT ===");
        node1.assert_transports(vec![node2.did()]);
        node2.assert_transports(vec![node1.did(), node3.did()]);
        node3.assert_transports(vec![node2.did()]);
        assert_eq!(node1.dht().successors().list()?, vec![node2.did(),]);
        assert_eq!(node2.dht().successors().list()?, vec![
            node3.did(),
            node1.did()
        ]);
        assert_eq!(node3.dht().successors().list()?, vec![node2.did()]);

        println!("=============================================");
        println!("||  now we connect node1 to node3 via DHT  ||");
        println!("=============================================");

        assert!(node1.swarm.transport.get_connection(node3.did()).is_none());
        assert_eq!(node1.dht().successors().max()?, node2.did());
        node1.swarm.connect(node3.did()).await.unwrap();
        wait_for_msgs([&node1, &node2, &node3]).await;
        assert_no_more_msg([&node1, &node2, &node3]).await;

        println!("=== Check state after connect via DHT ===");
        node1.assert_transports(vec![node2.did(), node3.did()]);
        node2.assert_transports(vec![node1.did(), node3.did()]);
        node3.assert_transports(vec![node1.did(), node2.did()]);
        assert_eq!(node1.dht().successors().list()?, vec![
            node2.did(),
            node3.did()
        ]);
        assert_eq!(node2.dht().successors().list()?, vec![
            node3.did(),
            node1.did()
        ]);
        assert_eq!(node3.dht().successors().list()?, vec![
            node1.did(),
            node2.did()
        ]);

        Ok((node1, node2, node3))
    }

    /// Same scenario as `test_triple_ordered_nodes_connection`, with different expected
    /// successor orders.
    ///
    /// The expectations here match callers passing a reversed rotation of the
    /// `gen_ordered_keys` output (3-2-1, 2-1-3, 1-3-2).
    async fn test_triple_desc_ordered_nodes_connection(
        key1: SecretKey,
        key2: SecretKey,
        key3: SecretKey,
    ) -> Result<(Node, Node, Node)> {
        let node1 = prepare_node(key1).await;
        let node2 = prepare_node(key2).await;
        let node3 = prepare_node(key3).await;

        println!("========================================");
        println!("||  now we connect node1 and node2    ||");
        println!("========================================");

        manually_establish_connection(&node1.swarm, &node2.swarm).await;
        wait_for_msgs([&node1, &node2, &node3]).await;
        assert_no_more_msg([&node1, &node2, &node3]).await;

        // Transport checks mirror the ascending-order helper after the same first step.
        node1.assert_transports(vec![node2.did()]);
        node2.assert_transports(vec![node1.did()]);
        node3.assert_transports(vec![]);
        assert_eq!(node1.dht().successors().list()?, vec![node2.did()]);
        assert_eq!(node2.dht().successors().list()?, vec![node1.did()]);
        assert_eq!(node3.dht().successors().list()?, vec![]);

        println!("========================================");
        println!("||  now we start join node3 to node2  ||");
        println!("========================================");

        manually_establish_connection(&node3.swarm, &node2.swarm).await;
        wait_for_msgs([&node1, &node2, &node3]).await;
        assert_no_more_msg([&node1, &node2, &node3]).await;

        println!("=== Check state before connect via DHT ===");
        node1.assert_transports(vec![node2.did()]);
        node2.assert_transports(vec![node1.did(), node3.did()]);
        node3.assert_transports(vec![node2.did()]);
        assert_eq!(node1.dht().successors().list()?, vec![node2.did()]);
        assert_eq!(node2.dht().successors().list()?, vec![
            node1.did(),
            node3.did()
        ]);
        assert_eq!(node3.dht().successors().list()?, vec![node2.did()]);

        println!("=============================================");
        println!("||  now we connect node1 to node3 via DHT  ||");
        println!("=============================================");

        assert!(node1.swarm.transport.get_connection(node3.did()).is_none());
        assert_eq!(node1.dht().successors().max()?, node2.did());
        node1.swarm.connect(node3.did()).await.unwrap();
        wait_for_msgs([&node1, &node2, &node3]).await;
        assert_no_more_msg([&node1, &node2, &node3]).await;

        println!("=== Check state after connect via DHT ===");
        node1.assert_transports(vec![node2.did(), node3.did()]);
        node2.assert_transports(vec![node1.did(), node3.did()]);
        node3.assert_transports(vec![node1.did(), node2.did()]);
        assert_eq!(node1.dht().successors().list()?, vec![
            node3.did(),
            node2.did()
        ]);
        assert_eq!(node2.dht().successors().list()?, vec![
            node1.did(),
            node3.did()
        ]);
        assert_eq!(node3.dht().successors().list()?, vec![
            node2.did(),
            node1.did()
        ]);

        Ok((node1, node2, node3))
    }

    #[tokio::test]
    async fn test_fourth_node_connection() -> Result<()> {
        let keys = gen_ordered_keys(4);
        let (key1, key2, key3, key4) = (keys[0], keys[1], keys[2], keys[3]);
        let (node1, node2, node3) = test_triple_ordered_nodes_connection(key1, key2, key3).await?;
        let node4 = prepare_node(key4).await;

        manually_establish_connection(&node4.swarm, &node2.swarm).await;
        // Fixed wait rather than wait_for_msgs, as in the original test.
        tokio::time::sleep(Duration::from_secs(6)).await;

        println!("=== Check state before connect via DHT ===");
        node1.assert_transports(vec![node2.did(), node3.did(), node4.did()]);
        node2.assert_transports(vec![node3.did(), node4.did(), node1.did()]);
        node3.assert_transports(vec![node1.did(), node2.did()]);
        node4.assert_transports(vec![node1.did(), node2.did()]);
        assert_eq!(node1.dht().successors().list().unwrap(), vec![
            node2.did(),
            node3.did(),
            node4.did(),
        ]);
        assert_eq!(node2.dht().successors().list().unwrap(), vec![
            node3.did(),
            node4.did(),
            node1.did(),
        ]);
        assert_eq!(node3.dht().successors().list().unwrap(), vec![
            node1.did(),
            node2.did(),
        ]);
        assert_eq!(node4.dht().successors().list().unwrap(), vec![
            node1.did(),
            node2.did(),
        ]);

        println!("========================================");
        println!("| test node4 connect node3 via dht     |");
        println!("========================================");
        println!(
            "node1.did(): {:?}, node2.did(): {:?}, node3.did(): {:?}, node4.did(): {:?}",
            node1.did(),
            node2.did(),
            node3.did(),
            node4.did(),
        );
        println!("==================================================");

        node4.swarm.connect(node3.did()).await.unwrap();
        tokio::time::sleep(Duration::from_secs(6)).await;

        println!("=== Check state after connect via DHT ===");
        node1.assert_transports(vec![node2.did(), node3.did(), node4.did()]);
        node2.assert_transports(vec![node3.did(), node4.did(), node1.did()]);
        node3.assert_transports(vec![node4.did(), node1.did(), node2.did()]);
        node4.assert_transports(vec![node1.did(), node2.did(), node3.did()]);
        assert_eq!(node1.dht().successors().list().unwrap(), vec![
            node2.did(),
            node3.did(),
            node4.did()
        ]);
        assert_eq!(node2.dht().successors().list().unwrap(), vec![
            node3.did(),
            node4.did(),
            node1.did(),
        ]);
        assert_eq!(node3.dht().successors().list().unwrap(), vec![
            node4.did(),
            node1.did(),
            node2.did(),
        ]);
        assert_eq!(node4.dht().successors().list().unwrap(), vec![
            node1.did(),
            node2.did(),
            node3.did(),
        ]);

        Ok(())
    }

    /// Finger tables start empty, contain the peer after connecting, and are cleared on both
    /// sides once the connection is torn down.
    #[tokio::test]
    async fn test_finger_when_disconnect() -> Result<()> {
        let key1 = SecretKey::random();
        let key2 = SecretKey::random();
        let node1 = prepare_node(key1).await;
        let node2 = prepare_node(key2).await;

        {
            // Both fresh nodes must start with empty finger tables (previously node1 was
            // checked twice and node2 never).
            assert!(node1.dht().lock_finger()?.is_empty());
            assert!(node2.dht().lock_finger()?.is_empty());
        }

        manually_establish_connection(&node1.swarm, &node2.swarm).await;
        wait_for_msgs([&node1, &node2]).await;
        assert_no_more_msg([&node1, &node2]).await;
        node1.assert_transports(vec![node2.did()]);
        node2.assert_transports(vec![node1.did()]);

        {
            let finger1 = node1.dht().lock_finger()?.clone().clone_finger();
            let finger2 = node2.dht().lock_finger()?.clone().clone_finger();
            assert!(finger1.into_iter().any(|x| x == Some(node2.did())));
            assert!(finger2.into_iter().any(|x| x == Some(node1.did())));
        }

        println!("===================================");
        println!("| test disconnect node1 and node2 |");
        println!("===================================");

        node1.swarm.disconnect(node2.did()).await?;

        // Poll (up to 9 times) until node2 sees its transport to node1 closed or removed.
        for _ in 1..10 {
            println!("wait 3 seconds for node2's transport 2to1 closing");
            sleep(Duration::from_secs(3)).await;
            if let Some(t) = node2.swarm.transport.get_connection(node1.did()) {
                if matches!(
                    t.webrtc_connection_state(),
                    WebrtcConnectionState::Disconnected | WebrtcConnectionState::Closed
                ) {
                    println!("transport 2to1 is disconnected!!!!");
                    break;
                }
            } else {
                println!("transport 2to1 is disappeared!!!!");
                break;
            }
        }

        assert_no_more_msg([&node1, &node2]).await;
        node1.assert_transports(vec![]);
        node2.assert_transports(vec![]);

        {
            let finger1 = node1.dht().lock_finger()?.clone().clone_finger();
            let finger2 = node2.dht().lock_finger()?.clone().clone_finger();
            assert!(finger1.into_iter().all(|x| x.is_none()));
            assert!(finger2.into_iter().all(|x| x.is_none()));
        }

        Ok(())
    }
}