use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use ckb_logger::{debug, error, trace, warn};
use p2p::{
async_trait, bytes,
context::{ProtocolContext, ProtocolContextMutRef, SessionContext},
multiaddr::Multiaddr,
traits::ServiceProtocol,
utils::{is_reachable, multiaddr_to_socketaddr},
SessionId,
};
use rand::seq::SliceRandom;
pub use self::{
addr::{AddrKnown, AddressManager, MisbehaveResult, Misbehavior},
protocol::{DiscoveryMessage, Node, Nodes},
state::SessionState,
};
use self::{
protocol::{decode, encode},
state::RemoteAddress,
};
use crate::{Flags, NetworkState, ProtocolId};
mod addr;
pub(crate) mod protocol;
mod state;
/// Default period of the service notify timer registered in `init`; used when
/// `DiscoveryProtocol::announce_check_interval` is `None`.
const ANNOUNCE_CHECK_INTERVAL: Duration = Duration::from_secs(60);
/// Upper bound on the number of items accepted in a `Nodes` message with
/// `announce == true` (see `verify_nodes_message`), and on the number of
/// addresses `notify` queues in a session's `announce_multiaddrs`.
const ANNOUNCE_THRESHOLD: usize = 10;
/// Upper bound on the number of addresses sent in reply to `GetNodes`, and on
/// the number of items accepted in a `Nodes` message with `announce == false`.
const MAX_ADDR_TO_SEND: usize = 1000;
/// Upper bound on the number of addresses a single `Node` item may carry
/// (see `verify_nodes_message`).
const MAX_ADDRS: usize = 3;
/// Interval (24 hours) handed to `SessionState::check_timer` in `notify`.
// NOTE(review): presumably the minimum time between two announcements of the
// same session's address; confirm against `state::SessionState::check_timer`.
const ANNOUNCE_INTERVAL: Duration = Duration::from_secs(3600 * 24);
/// Service protocol handler for peer address discovery.
///
/// `M` is the address manager the handler reports to; the `ServiceProtocol`
/// implementation requires it to implement `AddressManager`.
pub struct DiscoveryProtocol<M> {
/// Per-session discovery state, inserted in `connected` and removed in
/// `disconnected`.
sessions: HashMap<SessionId, SessionState>,
/// Optional override for the notify timer period; `init` falls back to
/// `ANNOUNCE_CHECK_INTERVAL` when this is `None`.
announce_check_interval: Option<Duration>,
/// Address manager used to register sessions, store discovered addresses
/// and judge misbehavior.
addr_mgr: M,
}
impl<M: AddressManager + Send> DiscoveryProtocol<M> {
    /// Creates a discovery handler that tracks no sessions yet.
    ///
    /// * `addr_mgr` - address manager the handler delegates to.
    /// * `announce_check_interval` - optional notify timer period; when `None`
    ///   the handler uses `ANNOUNCE_CHECK_INTERVAL` at `init` time.
    pub fn new(addr_mgr: M, announce_check_interval: Option<Duration>) -> Self {
        let sessions = HashMap::new();
        Self {
            sessions,
            announce_check_interval,
            addr_mgr,
        }
    }
}
#[async_trait]
impl<M: AddressManager + Send + Sync> ServiceProtocol for DiscoveryProtocol<M> {
/// Registers the periodic notify timer (token `0`) for this protocol.
///
/// The period is the configured `announce_check_interval`, or
/// `ANNOUNCE_CHECK_INTERVAL` when none was supplied.
///
/// # Panics
///
/// Panics if the service refuses the notify registration.
async fn init(&mut self, context: &mut ProtocolContext) {
    let proto_id = context.proto_id;
    let period = self
        .announce_check_interval
        .unwrap_or(ANNOUNCE_CHECK_INTERVAL);
    debug!("protocol [discovery({})]: init", proto_id);
    let registered = context.set_service_notify(proto_id, period, 0).await;
    registered.expect("set discovery notify fail");
}
/// Called when the discovery protocol opens on a session.
///
/// Registers the session's protocol `version` with the address manager and
/// starts tracking a fresh `SessionState` for it.
async fn connected(&mut self, context: ProtocolContextMutRef<'_>, version: &str) {
    let session = context.session;
    // Copied out now because `context` is moved into `SessionState::new` below.
    let session_id = session.id;
    debug!(
        "DiscoveryProtocol connected, session: {:?}, version: {}",
        session, version
    );
    self.addr_mgr
        .register(session_id, context.proto_id, version);
    let state = SessionState::new(context, &self.addr_mgr).await;
    self.sessions.insert(session_id, state);
}
/// Called when the discovery protocol closes on a session.
///
/// Drops the session's tracked state and unregisters the protocol from the
/// address manager.
async fn disconnected(&mut self, context: ProtocolContextMutRef<'_>) {
    let session = context.session;
    let (session_id, proto_id) = (session.id, context.proto_id);
    self.sessions.remove(&session_id);
    self.addr_mgr.unregister(session_id, proto_id);
    debug!("DiscoveryProtocol disconnected, session {:?}", session);
}
/// Handles one raw discovery message received on `context.session`.
///
/// The payload is decoded into a `DiscoveryMessage`:
///
/// * `GetNodes` - replies with a `Nodes { announce: false, .. }` message holding
///   at most `min(MAX_ADDR_TO_SEND, count)` randomly chosen addresses, and
///   records the peer's listen address when a listen port is supplied.
/// * `Nodes` - validates the message size limits and forwards the contained
///   addresses to the address manager.
///
/// Undecodable data and protocol violations are reported through
/// `AddressManager::misbehave`; the session is disconnected when the returned
/// verdict says so.
async fn received(&mut self, context: ProtocolContextMutRef<'_>, data: bytes::Bytes) {
let session = context.session;
trace!("[received message]: length={}", data.len());
let mgr = &mut self.addr_mgr;
// Reports `behavior` to the address manager and returns `true` when the
// manager's verdict is to disconnect the session.
let mut check =
|behavior: Misbehavior| -> bool { mgr.misbehave(session, &behavior).is_disconnect() };
match decode(&data) {
Some(item) => {
match item {
DiscoveryMessage::GetNodes {
listen_port,
count,
version,
required_flags,
} => {
// Messages from sessions without tracked state are ignored.
if let Some(state) = self.sessions.get_mut(&session.id) {
// A second `GetNodes` on the same session is reported as misbehavior;
// processing only stops if the verdict is to disconnect.
if state.received_get_nodes && check(Misbehavior::DuplicateGetNodes) {
if context.disconnect(session.id).await.is_err() {
debug!("disconnect {:?} send fail", session.id)
}
return;
}
state.received_get_nodes = true;
// The candidate addresses are fetched before the peer's own listen
// address is handed to the manager below.
// NOTE(review): presumably so the reply cannot contain the peer's own
// address; confirm the intent.
let mut items = self.addr_mgr.get_random(2500, required_flags);
debug!("listen port: {:?}", listen_port);
if let Some(port) = listen_port {
// Apply the advertised listen port to the remote address and mark that
// address as known to this peer.
state.remote_addr.update_port(port);
state.addr_known.insert(state.remote_addr.to_inner());
// Only an address in the `Listen` variant is stored in the manager.
if let RemoteAddress::Listen(ref addr) = state.remote_addr {
let flags = self.addr_mgr.node_flags(session.id);
self.addr_mgr
.add_new_addr(session.id, (addr.clone(), flags));
}
}
// NOTE(review): presumably peers at or above `REUSE_PORT_VERSION` connect
// from their listen port, so the address can be treated as a listen
// address; confirm in the `state` module.
if version >= state::REUSE_PORT_VERSION {
state.remote_addr.change_to_listen();
}
// Cap the reply at the smaller of our own limit and the requested count,
// picking a random subset when more candidates were fetched.
let max = ::std::cmp::min(MAX_ADDR_TO_SEND, count as usize);
if items.len() > max {
items = items
.choose_multiple(&mut rand::thread_rng(), max)
.cloned()
.collect();
}
// Everything sent in the reply is recorded as known to this peer.
state.addr_known.extend(items.iter());
let items = items
.into_iter()
.map(|addr| Node {
addresses: vec![addr.0],
flags: addr.1,
})
.collect::<Vec<_>>();
let nodes = Nodes {
announce: false,
items,
};
let msg = encode(DiscoveryMessage::Nodes(nodes));
if context.send_message(msg).await.is_err() {
debug!("{:?} send discovery msg Nodes fail", session.id)
}
}
}
DiscoveryMessage::Nodes(nodes) => {
// Enforce the item/address count limits before touching any state.
if let Some(misbehavior) = verify_nodes_message(&nodes) {
if check(misbehavior) {
if context.disconnect(session.id).await.is_err() {
debug!("disconnect {:?} send fail", session.id)
}
return;
}
}
if let Some(state) = self.sessions.get_mut(&session.id) {
// A non-announce `Nodes` message is accepted only once per session; a
// repeat is reported and its content is not processed.
if !nodes.announce && state.received_nodes {
warn!("already received Nodes(announce=false) message");
if check(Misbehavior::DuplicateFirstNodes)
&& context.disconnect(session.id).await.is_err()
{
debug!("disconnect {:?} send fail", session.id)
}
} else {
// Flatten every node into (address, flags) pairs, one per address.
let addrs = nodes
.items
.into_iter()
.flat_map(|node| {
node.addresses.into_iter().map(move |a| (a, node.flags))
})
.collect::<Vec<_>>();
state.addr_known.extend(addrs.iter());
// Only a non-announce message flips the "already received" flag, so
// announce messages arriving first do not block the first response.
if !nodes.announce {
state.received_nodes = true;
}
self.addr_mgr.add_new_addrs(session.id, addrs);
}
}
}
}
}
None => {
// Undecodable payload: report it and disconnect if the manager asks to.
if self
.addr_mgr
.misbehave(session, &Misbehavior::InvalidData)
.is_disconnect()
&& context.disconnect(session.id).await.is_err()
{
debug!("disconnect {:?} send fail", session.id)
}
}
}
}
/// Timer callback registered in `init`.
///
/// First lets every session send its pending messages and collects the
/// addresses `SessionState::check_timer` reports for `ANNOUNCE_INTERVAL`
/// that the address manager considers valid. Each collected entry is then
/// queued on up to three randomly chosen sessions, skipping sessions that
/// already know the entry or whose announce queue has reached
/// `ANNOUNCE_THRESHOLD`.
async fn notify(&mut self, context: &mut ProtocolContext, _token: u64) {
    let now = Instant::now();

    // Phase 1: flush per-session messages and gather announce candidates.
    let mut pending = Vec::new();
    for (id, state) in self.sessions.iter_mut() {
        state.send_messages(context, *id).await;
        let candidate = state
            .check_timer(now, ANNOUNCE_INTERVAL)
            .filter(|addr| self.addr_mgr.is_valid_addr(addr))
            .map(|addr| (addr.clone(), self.addr_mgr.node_flags(*id)));
        if let Some(entry) = candidate {
            pending.push(entry);
        }
    }
    if pending.is_empty() {
        return;
    }

    // Phase 2: spread every candidate over a fresh random pick of sessions.
    let mut rng = rand::thread_rng();
    let mut session_ids: Vec<_> = self.sessions.keys().cloned().collect();
    for entry in pending {
        session_ids.shuffle(&mut rng);
        for session_id in session_ids.iter().take(3) {
            if let Some(target) = self.sessions.get_mut(session_id) {
                trace!(
                    ">> send {:?} to: {:?}, contains: {}",
                    entry,
                    target.remote_addr,
                    target.addr_known.contains(&entry)
                );
                let has_room = target.announce_multiaddrs.len() < ANNOUNCE_THRESHOLD;
                if has_room && !target.addr_known.contains(&entry) {
                    target.announce_multiaddrs.push(entry.clone());
                    target.addr_known.insert(&entry);
                }
            }
        }
    }
}
}
/// Checks a `Nodes` message against the protocol size limits.
///
/// Returns `Some(Misbehavior::TooManyItems { .. })` when the message carries
/// more items than allowed (`ANNOUNCE_THRESHOLD` for announce messages,
/// `MAX_ADDR_TO_SEND` otherwise). If the item count is fine, returns
/// `Some(Misbehavior::TooManyAddresses(n))` for the first item holding more
/// than `MAX_ADDRS` addresses. Returns `None` when the message is within
/// limits.
fn verify_nodes_message(nodes: &Nodes) -> Option<Misbehavior> {
    let item_count = nodes.items.len();
    match nodes.announce {
        true if item_count > ANNOUNCE_THRESHOLD => {
            warn!("Nodes items more than {}", ANNOUNCE_THRESHOLD);
            return Some(Misbehavior::TooManyItems {
                announce: nodes.announce,
                length: item_count,
            });
        }
        false if item_count > MAX_ADDR_TO_SEND => {
            warn!(
                "Too many items (announce=false) length={}",
                item_count
            );
            return Some(Misbehavior::TooManyItems {
                announce: nodes.announce,
                length: item_count,
            });
        }
        _ => {}
    }

    // Item count is acceptable; report the first oversized item, if any.
    nodes
        .items
        .iter()
        .map(|item| item.addresses.len())
        .find(|&address_count| address_count > MAX_ADDRS)
        .map(Misbehavior::TooManyAddresses)
}
/// `AddressManager` implementation backed by the node's shared `NetworkState`.
pub struct DiscoveryAddressManager {
/// Shared network state; gives access to the peer registry and peer store.
pub network_state: Arc<NetworkState>,
/// When `true`, `is_valid_addr` accepts every address; when `false`, only
/// addresses that resolve to a reachable IP are accepted.
pub discovery_local_address: bool,
}
impl AddressManager for DiscoveryAddressManager {
/// Records that session `id` speaks protocol `pid` at `version`.
///
/// Does nothing when the peer registry has no peer for `id`.
fn register(&self, id: SessionId, pid: ProtocolId, version: &str) {
    self.network_state.with_peer_registry_mut(|reg| {
        if let Some(peer) = reg.get_peer_mut(id) {
            peer.protocols.insert(pid, version.to_owned());
        }
    });
}
/// Removes protocol `pid` from the protocols recorded for session `id`.
///
/// Does nothing when the peer registry has no peer for `id`.
fn unregister(&self, id: SessionId, pid: ProtocolId) {
    self.network_state.with_peer_registry_mut(|reg| {
        if let Some(peer) = reg.get_peer_mut(id) {
            peer.protocols.remove(&pid);
        }
    });
}
/// Returns whether `addr` may be stored or announced.
///
/// With `discovery_local_address` enabled every address is accepted.
/// Otherwise the address must convert to a socket address whose IP passes
/// `is_reachable`; addresses that cannot be converted are rejected.
fn is_valid_addr(&self, addr: &Multiaddr) -> bool {
    if self.discovery_local_address {
        return true;
    }
    match multiaddr_to_socketaddr(addr) {
        Some(socket_addr) => is_reachable(socket_addr.ip()),
        None => false,
    }
}
/// Stores a single discovered address by delegating to `add_new_addrs`.
fn add_new_addr(&mut self, session_id: SessionId, addr: (Multiaddr, Flags)) {
    let single = vec![addr];
    self.add_new_addrs(session_id, single);
}
/// Adds every address in `addrs` that passes `is_valid_addr` to the peer
/// store.
///
/// Failures reported by the peer store are logged at debug level and do not
/// stop the remaining addresses from being processed. The session id is
/// currently unused.
fn add_new_addrs(&mut self, _session_id: SessionId, addrs: Vec<(Multiaddr, Flags)>) {
    // An empty `addrs` simply skips the loop.
    for (addr, flags) in addrs {
        if !self.is_valid_addr(&addr) {
            continue;
        }
        trace!("Add discovered address:{:?}", addr);
        self.network_state.with_peer_store_mut(|peer_store| {
            // Cloned because `addr` is still needed for the failure log below.
            if let Err(err) = peer_store.add_addr(addr.clone(), flags) {
                debug!(
                    "Failed to add discoved address to peer_store {:?} {:?}",
                    err, addr
                );
            }
        });
    }
}
/// Logs the reported misbehavior at error level and returns the verdict.
///
/// This manager answers every kind of misbehavior with
/// `MisbehaveResult::Disconnect`.
fn misbehave(&mut self, session: &SessionContext, behavior: &Misbehavior) -> MisbehaveResult {
    let verdict = MisbehaveResult::Disconnect;
    error!(
        "DiscoveryProtocol detects abnormal behavior, session: {:?}, behavior: {:?}",
        session, behavior
    );
    verdict
}
/// Fetches random addresses from the peer store and keeps the valid ones.
///
/// Asks the peer store for `n` random addresses for `flags`, drops those
/// rejected by `is_valid_addr`, and returns each remaining address with its
/// stored flag bits converted via `Flags::from_bits_truncate`. The result
/// may therefore hold fewer than `n` entries.
fn get_random(&mut self, n: usize, flags: Flags) -> Vec<(Multiaddr, Flags)> {
    let candidates = self
        .network_state
        .with_peer_store_mut(|peer_store| peer_store.fetch_random_addrs(n, flags));
    let addrs: Vec<(Multiaddr, Flags)> = candidates
        .into_iter()
        .filter(|paddr| self.is_valid_addr(&paddr.addr))
        .map(|paddr| {
            let node_flags = Flags::from_bits_truncate(paddr.flags);
            (paddr.addr, node_flags)
        })
        .collect();
    trace!("discovery send random addrs: {:?}", addrs);
    addrs
}
/// Returns the `required_flags` value stored in the shared network state.
fn required_flags(&self) -> super::identify::Flags {
self.network_state.required_flags
}
/// Returns the flags recorded in the identify info of session `id`.
///
/// Falls back to `Flags::COMPATIBILITY` when the peer is not in the registry
/// or has no identify info yet.
fn node_flags(&self, id: SessionId) -> Flags {
    self.network_state.with_peer_registry(|reg| {
        reg.get_peer(id)
            .and_then(|peer| peer.identify_info.as_ref().map(|info| info.flags))
            .unwrap_or(Flags::COMPATIBILITY)
    })
}
}