use super::{
disconnect_message::DisconnectMessageProtocol,
discovery::{DiscoveryAddressManager, DiscoveryProtocol},
feeler::Feeler,
identify::{Flags, IdentifyCallback, IdentifyProtocol},
ping::PingHandler,
};
use crate::{
network::{DefaultExitHandler, EventHandler},
services::protocol_type_checker::ProtocolTypeCheckerService,
NetworkState, PeerIdentifyInfo, SupportProtocols,
};
use std::{
borrow::Cow,
sync::Arc,
thread,
time::{Duration, Instant},
};
use ckb_app_config::NetworkConfig;
use p2p::{
builder::ServiceBuilder,
multiaddr::{Multiaddr, Protocol},
service::{ProtocolHandle, ServiceControl, TargetProtocol},
utils::multiaddr_to_socketaddr,
ProtocolId, SessionId,
};
use tempfile::tempdir;
mod discovery;
/// A running test network service plus the handles used to drive and inspect it.
struct Node {
    /// Address the service listens on, with this node's `/p2p/<peer id>` component appended.
    listen_addr: Multiaddr,
    /// Control handle of the running p2p service (dial, disconnect, open protocols).
    control: ServiceControl,
    /// Shared network state of this node: peer registry, peer store and ban list.
    network_state: Arc<NetworkState>,
}

impl Node {
    /// Dials `node`'s listen address, requesting `protocol` on the resulting session.
    ///
    /// # Panics
    ///
    /// Panics if the service control rejects the dial request.
    fn dial(&self, node: &Node, protocol: TargetProtocol) {
        self.dial_addr(node.listen_addr.clone(), protocol);
    }

    /// Dials an arbitrary `addr`, requesting `protocol` on the resulting session.
    ///
    /// # Panics
    ///
    /// Panics if the service control rejects the dial request.
    fn dial_addr(&self, addr: Multiaddr, protocol: TargetProtocol) {
        self.control.dial(addr, protocol).unwrap();
    }

    /// Requests disconnection of every session currently in the peer registry.
    ///
    /// This does not wait for the sessions to close; poll (e.g. with `wait_connect_state`)
    /// to observe the effect.
    fn disconnect_all(&self) {
        for id in self.connected_sessions() {
            self.control.disconnect(id).unwrap();
        }
    }

    /// Number of sessions currently in the peer registry.
    fn session_num(&self) -> usize {
        // Count directly instead of materializing the id list.
        self.network_state.peer_registry.read().peers().len()
    }

    /// Snapshot of the ids of all sessions currently in the peer registry.
    fn connected_sessions(&self) -> Vec<SessionId> {
        self.network_state
            .peer_registry
            .read()
            .peers()
            .keys()
            .cloned()
            .collect()
    }

    /// Protocols currently open on session `id`; empty if the session is not registered.
    fn connected_protocols(&self, id: SessionId) -> Vec<ProtocolId> {
        self.network_state
            .peer_registry
            .read()
            .peers()
            .get(&id)
            .map(|peer| peer.protocols.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Identify information recorded for session `id`.
    ///
    /// Returns `None` if the session is not registered or has no identify info recorded.
    fn session_version(&self, id: SessionId) -> Option<PeerIdentifyInfo> {
        self.network_state
            .peer_registry
            .read()
            .peers()
            .get(&id)
            .and_then(|peer| peer.identify_info.clone())
    }

    /// Requests that `protocol` be opened on the existing session `id`.
    ///
    /// # Panics
    ///
    /// Panics if the service control rejects the request.
    fn open_protocols(&self, id: SessionId, protocol: TargetProtocol) {
        self.control.open_protocols(id, protocol).unwrap();
    }

    /// Bans every registered session for 20 seconds via `NetworkState::ban_session`.
    fn ban_all(&self) {
        for id in self.connected_sessions() {
            self.network_state.ban_session(
                &self.control,
                id,
                Duration::from_secs(20),
                Default::default(),
            );
        }
    }
}
fn net_service_start(
name: String,
enable_discovery_push: bool,
required_flags: Flags,
self_flags: Flags,
) -> Node {
let config = NetworkConfig {
max_peers: 19,
max_outbound_peers: 5,
path: tempdir()
.expect("create tempdir failed")
.path()
.to_path_buf(),
ping_interval_secs: 15,
ping_timeout_secs: 20,
connect_outbound_interval_secs: 1,
discovery_local_address: true,
bootnode_mode: true,
reuse_port_on_linux: true,
public_addresses: vec![format!(
"/ip4/225.0.0.1/tcp/42/p2p/{}",
crate::PeerId::random().to_base58()
)
.parse()
.unwrap()],
..Default::default()
};
let network_state = Arc::new(
NetworkState::from_config(config.clone())
.expect("Init network state failed")
.required_flags(required_flags),
);
network_state.protocols.write().push((
SupportProtocols::Ping.protocol_id(),
SupportProtocols::Ping.name(),
SupportProtocols::Ping.support_versions(),
));
network_state.protocols.write().push((
SupportProtocols::Discovery.protocol_id(),
SupportProtocols::Discovery.name(),
SupportProtocols::Discovery.support_versions(),
));
network_state.protocols.write().push((
SupportProtocols::Identify.protocol_id(),
SupportProtocols::Identify.name(),
SupportProtocols::Identify.support_versions(),
));
network_state.protocols.write().push((
SupportProtocols::Feeler.protocol_id(),
SupportProtocols::Feeler.name(),
SupportProtocols::Feeler.support_versions(),
));
let ping_interval = Duration::from_secs(5);
let ping_timeout = Duration::from_secs(10);
let ping_network_state = Arc::clone(&network_state);
let (ping_handler, _ping_controller) =
PingHandler::new(ping_interval, ping_timeout, ping_network_state);
let ping_meta = SupportProtocols::Ping
.build_meta_with_service_handle(move || ProtocolHandle::Callback(Box::new(ping_handler)));
let addr_mgr = DiscoveryAddressManager {
network_state: Arc::clone(&network_state),
discovery_local_address: config.discovery_local_address,
};
let disc_meta = SupportProtocols::Discovery.build_meta_with_service_handle(move || {
ProtocolHandle::Callback(Box::new(DiscoveryProtocol::new(
addr_mgr,
if enable_discovery_push {
Some(Duration::from_secs(1))
} else {
None
},
)))
});
let identify_callback = IdentifyCallback::new(
Arc::clone(&network_state),
name,
"0.1.0".to_string(),
self_flags,
);
let identify_meta = SupportProtocols::Identify.build_meta_with_service_handle(move || {
ProtocolHandle::Callback(Box::new(
IdentifyProtocol::new(identify_callback).global_ip_only(false),
))
});
let disconnect_message_state = Arc::clone(&network_state);
let disconnect_message_meta = SupportProtocols::DisconnectMessage
.build_meta_with_service_handle(move || {
ProtocolHandle::Callback(Box::new(DisconnectMessageProtocol::new(
disconnect_message_state,
)))
});
let feeler_meta = SupportProtocols::Feeler.build_meta_with_service_handle({
let network_state = Arc::clone(&network_state);
move || ProtocolHandle::Callback(Box::new(Feeler::new(Arc::clone(&network_state))))
});
let service_builder = ServiceBuilder::default()
.insert_protocol(ping_meta)
.insert_protocol(disc_meta)
.insert_protocol(identify_meta)
.insert_protocol(disconnect_message_meta)
.insert_protocol(feeler_meta);
let mut p2p_service = service_builder
.key_pair(network_state.local_private_key().clone())
.upnp(config.upnp)
.forever(true)
.build(EventHandler {
network_state: Arc::clone(&network_state),
exit_handler: DefaultExitHandler::default(),
});
let peer_id = network_state.local_peer_id().clone();
let control = p2p_service.control().clone().into();
let (addr_sender, addr_receiver) = ::std::sync::mpsc::channel();
static RT: once_cell::sync::OnceCell<tokio::runtime::Runtime> =
once_cell::sync::OnceCell::new();
let rt = RT.get_or_init(|| {
let num_threads = ::std::cmp::max(num_cpus::get(), 4);
tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.build()
.unwrap()
});
rt.spawn(async move {
let mut listen_addr = p2p_service
.listen("/ip4/127.0.0.1/tcp/0".parse().unwrap())
.await
.unwrap();
listen_addr.push(Protocol::P2P(Cow::Owned(peer_id.into_bytes())));
addr_sender.send(listen_addr).unwrap();
p2p_service.run().await
});
let listen_addr = addr_receiver.recv().unwrap();
Node {
control,
listen_addr,
network_state,
}
}
/// Repeatedly evaluates `f`, sleeping one second between attempts, until it returns `true`
/// or more than `secs` seconds have passed since the call started.
///
/// Returns `true` as soon as `f` succeeds, `false` once the deadline is exceeded.
pub fn wait_until<F>(secs: u64, f: F) -> bool
where
    F: Fn() -> bool,
{
    let deadline = Duration::new(secs, 0);
    let began = Instant::now();
    loop {
        let spent = Instant::now().saturating_duration_since(began);
        if spent > deadline {
            return false;
        }
        if f() {
            return true;
        }
        thread::sleep(Duration::new(1, 0));
    }
}
/// Waits up to 10 seconds for `node` to have exactly `expect_num` registered sessions.
///
/// # Panics
///
/// Panics, reporting the last observed session count, if the target is never reached.
fn wait_connect_state(node: &Node, expect_num: usize) {
    let reached = wait_until(10, || node.session_num() == expect_num);
    if reached {
        return;
    }
    panic!(
        "node session number is {}, not {}",
        node.session_num(),
        expect_num
    );
}
/// Waits up to 100 seconds until `assert` accepts the number of addresses in `node`'s
/// address manager.
///
/// # Panics
///
/// Panics if the predicate never holds within the timeout.
fn wait_discovery(node: &Node, assert: impl Fn(usize) -> bool) {
    let discovered = wait_until(100, || {
        // Read the count in its own statement so the peer-store lock is released before the
        // caller's predicate runs.
        let known_addrs = node
            .network_state
            .peer_store
            .lock()
            .mut_addr_manager()
            .count();
        assert(known_addrs)
    });
    if !discovered {
        panic!("discovery can't find other node")
    }
}
/// Identify keeps sessions between nodes with the same name and flags, drops and bans peers
/// from a different network name, and opens the remaining protocols on accepted sessions.
#[test]
fn test_identify_behavior() {
    let compatible = Flags::COMPATIBILITY;
    let identify = || TargetProtocol::Single(SupportProtocols::Identify.protocol_id());

    let node1 = net_service_start("/test/1".to_string(), false, compatible, compatible);
    let node2 = net_service_start("/test/2".to_string(), false, compatible, compatible);
    let node3 = net_service_start("/test/1".to_string(), false, compatible, compatible);
    let node4 = net_service_start(
        "/test/1".to_string(),
        false,
        Flags::SYNC | Flags::RELAY | Flags::DISCOVERY | Flags::BLOCK_FILTER,
        Flags::SYNC | Flags::RELAY | Flags::DISCOVERY,
    );

    // Number of banned addresses in a node's peer store.
    let banned_addr_count = |node: &Node| {
        node.network_state
            .peer_store
            .lock()
            .ban_list()
            .get_banned_addrs()
            .len()
    };
    // True when at least one of the two nodes has banned something; both are always read.
    let either_banned = |a: &Node, b: &Node| {
        let (count_a, count_b) = (banned_addr_count(a), banned_addr_count(b));
        count_a != 0 || count_b != 0
    };

    // node4 requires flags (including BLOCK_FILTER) beyond node1's COMPATIBILITY flags;
    // the test expects node4 to end up with no session.
    node4.dial(&node1, identify());
    thread::sleep(Duration::from_secs(1));
    wait_connect_state(&node4, 0);

    // Same network name: the session is kept on both sides.
    node1.dial(&node3, identify());
    wait_connect_state(&node1, 1);
    wait_connect_state(&node3, 1);

    // Different network name: no new session, and one side bans the other.
    node2.dial(&node3, identify());
    wait_connect_state(&node2, 0);
    wait_connect_state(&node3, 1);
    if !wait_until(10, || either_banned(&node2, &node3)) {
        panic!("identify can't ban not same net")
    }

    node1.dial(&node2, identify());
    wait_connect_state(&node1, 1);
    wait_connect_state(&node2, 0);
    if !wait_until(10, || either_banned(&node1, &node2)) {
        panic!("identify can't ban not same net")
    }

    // The surviving session on node3 must have the other protocols opened.
    let session = node3.connected_sessions()[0];
    if !wait_until(10, || node3.connected_protocols(session).len() == 4) {
        panic!("identify can't open other protocols")
    }
    assert_eq!(
        node3.session_version(session).unwrap().client_version,
        "0.1.0"
    );

    let mut opened = node3.connected_protocols(session);
    opened.sort();
    let expected = vec![
        SupportProtocols::Ping.protocol_id(),
        SupportProtocols::Discovery.protocol_id(),
        SupportProtocols::Identify.protocol_id(),
        SupportProtocols::DisconnectMessage.protocol_id(),
    ];
    assert_eq!(opened, expected);
}
/// Opening the Feeler protocol on an established session ends up disconnecting both sides.
#[test]
fn test_feeler_behavior() {
    let start = || {
        net_service_start(
            "/test/1".to_string(),
            true,
            Flags::COMPATIBILITY,
            Flags::COMPATIBILITY,
        )
    };
    let node1 = start();
    let node2 = start();

    node1.dial(
        &node2,
        TargetProtocol::Single(SupportProtocols::Identify.protocol_id()),
    );
    for node in [&node1, &node2] {
        wait_connect_state(node, 1);
    }

    let session = node2.connected_sessions()[0];
    node2.open_protocols(
        session,
        TargetProtocol::Single(SupportProtocols::Feeler.protocol_id()),
    );
    for node in [&node1, &node2] {
        wait_connect_state(node, 0);
    }
}
/// node3 connects only to node2, learns further addresses through discovery, dials them, and
/// the three nodes end up fully meshed. Then runs the protocol type checker on each node.
#[test]
fn test_discovery_behavior() {
    let node1 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::COMPATIBILITY,
    );
    let node2 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::COMPATIBILITY,
    );
    let node3 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::COMPATIBILITY,
    );
    node1.dial(
        &node2,
        TargetProtocol::Single(SupportProtocols::Identify.protocol_id()),
    );
    wait_connect_state(&node1, 1);
    node3.dial(
        &node2,
        TargetProtocol::Single(SupportProtocols::Identify.protocol_id()),
    );
    wait_connect_state(&node3, 1);
    wait_connect_state(&node2, 2);
    wait_discovery(&node3, |num| num >= 2);

    // Feeler candidates known to node3, excluding its own listen address (compared by port).
    // Candidates whose address is not a socket address are skipped, as are all of them if
    // node3's own listen address cannot be converted.
    let own_port = multiaddr_to_socketaddr(&node3.listen_addr).map(|socket| socket.port());
    let addrs = {
        let mut peer_store = node3.network_state.peer_store.lock();
        peer_store
            .fetch_addrs_to_feeler(6)
            .into_iter()
            .map(|peer| peer.addr)
            .filter(|addr| {
                matches!(
                    (multiaddr_to_socketaddr(addr), own_port),
                    (Some(candidate), Some(own)) if candidate.port() != own
                )
            })
            .collect::<Vec<_>>()
    };
    for addr in addrs {
        node3.dial_addr(
            addr,
            TargetProtocol::Single(SupportProtocols::Identify.protocol_id()),
        );
    }
    wait_connect_state(&node1, 2);
    wait_connect_state(&node2, 2);
    wait_connect_state(&node3, 2);

    thread::sleep(Duration::from_secs(10));
    let checker = ProtocolTypeCheckerService::new(
        node1.network_state,
        node1.control,
        vec![SupportProtocols::Identify.protocol_id()],
    );
    checker.check_protocol_type();
    let checker = ProtocolTypeCheckerService::new(
        node2.network_state,
        node2.control,
        vec![SupportProtocols::Sync.protocol_id()],
    );
    checker.check_protocol_type();
    let checker = ProtocolTypeCheckerService::new(
        node3.network_state,
        node3.control,
        vec![SupportProtocols::Identify.protocol_id()],
    );
    checker.check_protocol_type();
}
/// Dialing with `TargetProtocol::All` also opens Feeler. As `test_feeler_behavior` shows,
/// Feeler drops the session, so the test expects neither node to keep a registered session.
#[test]
fn test_dial_all() {
    let node1 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::COMPATIBILITY,
    );
    let node2 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::COMPATIBILITY,
    );
    node1.dial(&node2, TargetProtocol::All);
    // Give the dial time to happen. Without this, the zero-session checks below could pass
    // before any connection was even attempted.
    thread::sleep(Duration::from_secs(1));
    wait_connect_state(&node1, 0);
    wait_connect_state(&node2, 0);
}
/// After node1 bans node2, the session closes on both sides and repeated re-dials do not
/// leave either node with a session.
#[test]
fn test_ban() {
    let spawn_node = || {
        net_service_start(
            "/test/1".to_string(),
            true,
            Flags::COMPATIBILITY,
            Flags::COMPATIBILITY,
        )
    };
    let via_identify = || TargetProtocol::Single(SupportProtocols::Identify.protocol_id());
    let node1 = spawn_node();
    let node2 = spawn_node();

    node1.dial(&node2, via_identify());
    wait_connect_state(&node1, 1);
    wait_connect_state(&node2, 1);

    node1.ban_all();
    wait_connect_state(&node1, 0);
    wait_connect_state(&node2, 0);

    // Hammer the banned peer with dials; none of them may yield a session.
    for _ in 0..4 {
        node1.dial(&node2, via_identify());
    }
    wait_connect_state(&node1, 0);
    wait_connect_state(&node2, 0);
}
/// Five peers dial one bootnode-mode node in turn. The test expects the target to reach 3
/// and then 4 inbound sessions, and to drop back to 3 once the fifth peer connects.
#[test]
fn test_bootnode_mode_inbound_eviction() {
    // Six identical nodes started in order; `nodes[0]` is the node everyone else dials.
    let nodes: Vec<Node> = (0..6)
        .map(|_| {
            net_service_start(
                "/test/1".to_string(),
                true,
                Flags::COMPATIBILITY,
                Flags::COMPATIBILITY,
            )
        })
        .collect();
    let target = &nodes[0];
    let dial_target = |dialer: &Node| {
        dialer.dial(
            target,
            TargetProtocol::Single(SupportProtocols::Identify.protocol_id()),
        )
    };

    for dialer in &nodes[1..4] {
        dial_target(dialer);
    }
    wait_connect_state(target, 3);

    dial_target(&nodes[4]);
    wait_connect_state(target, 4);

    dial_target(&nodes[5]);
    wait_connect_state(target, 3);
}
/// Both nodes advertise `Flags::all()`. Every address-manager entry on either node must keep
/// those flags after connecting, after a disconnect, and after reconnecting.
#[test]
fn test_dont_reset_peer_flags_on_disconnect() {
    let node1 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::all(),
    );
    let node2 = net_service_start(
        "/test/1".to_string(),
        true,
        Flags::COMPATIBILITY,
        Flags::all(),
    );
    let identify = || TargetProtocol::Single(SupportProtocols::Identify.protocol_id());
    let check_flags = |node: &Node| {
        for info in node
            .network_state
            .peer_store
            .lock()
            .addr_manager()
            .addrs_iter()
        {
            assert_eq!(info.flags, Flags::all().bits())
        }
    };

    node1.dial(&node2, identify());
    // Wait for both ends of the session to be registered.
    wait_connect_state(&node1, 1);
    wait_connect_state(&node2, 1);
    check_flags(&node1);
    check_flags(&node2);

    node1.disconnect_all();
    // `disconnect_all` only issues the requests. Wait until the disconnect has actually been
    // processed; otherwise the check below would not exercise the disconnect path at all.
    wait_connect_state(&node1, 0);
    wait_connect_state(&node2, 0);
    check_flags(&node1);
    check_flags(&node2);

    node1.dial(&node2, identify());
    wait_connect_state(&node1, 1);
    wait_connect_state(&node2, 1);
    check_flags(&node1);
    check_flags(&node2);
}