use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::{atomic::Ordering, Arc};
use std::time::{Duration, Instant};
use ckb_logger::{debug, error, trace, warn};
use p2p::{
async_trait,
bytes::Bytes,
context::{ProtocolContext, ProtocolContextMutRef, SessionContext},
multiaddr::{Multiaddr, Protocol},
service::{SessionType, TargetProtocol},
traits::ServiceProtocol,
utils::{extract_peer_id, is_reachable, multiaddr_to_socketaddr},
SessionId,
};
mod protocol;
use crate::{peer_store::required_flags_filter, NetworkState, PeerIdentifyInfo, SupportProtocols};
use ckb_types::{packed, prelude::*};
use protocol::IdentifyMessage;
const MAX_RETURN_LISTEN_ADDRS: usize = 10;
const BAN_ON_NOT_SAME_NET: Duration = Duration::from_secs(5 * 60);
const CHECK_TIMEOUT_TOKEN: u64 = 100;
const CHECK_TIMEOUT_INTERVAL: u64 = 1;
const DEFAULT_TIMEOUT: u64 = 8;
const MAX_ADDRS: usize = 10;
#[derive(Clone, Debug)]
pub enum Misbehavior {
DuplicateReceived,
Timeout,
InvalidData,
TooManyAddresses(usize),
}
pub enum MisbehaveResult {
Continue,
Disconnect,
}
impl MisbehaveResult {
pub fn is_disconnect(&self) -> bool {
matches!(self, MisbehaveResult::Disconnect)
}
}
#[async_trait]
pub trait Callback: Clone + Send {
fn register(&self, context: &ProtocolContextMutRef, version: &str);
fn unregister(&self, context: &ProtocolContextMutRef);
async fn received_identify(
&mut self,
context: &mut ProtocolContextMutRef<'_>,
identify: &[u8],
) -> MisbehaveResult;
fn identify(&mut self) -> &[u8];
fn local_listen_addrs(&mut self) -> Vec<Multiaddr>;
fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>);
fn add_observed_addr(&mut self, addr: Multiaddr, ty: SessionType) -> MisbehaveResult;
fn misbehave(&mut self, session: &SessionContext, kind: Misbehavior) -> MisbehaveResult;
}
pub struct IdentifyProtocol<T> {
callback: T,
remote_infos: HashMap<SessionId, RemoteInfo>,
global_ip_only: bool,
}
impl<T: Callback> IdentifyProtocol<T> {
pub fn new(callback: T) -> IdentifyProtocol<T> {
IdentifyProtocol {
callback,
remote_infos: HashMap::default(),
global_ip_only: true,
}
}
#[cfg(test)]
pub fn global_ip_only(mut self, only: bool) -> Self {
self.global_ip_only = only;
self
}
fn check_duplicate(&mut self, context: &mut ProtocolContextMutRef) -> MisbehaveResult {
let session = context.session;
let info = self
.remote_infos
.get_mut(&session.id)
.expect("RemoteInfo must exists");
if info.has_received {
self.callback
.misbehave(&info.session, Misbehavior::DuplicateReceived)
} else {
info.has_received = true;
MisbehaveResult::Continue
}
}
fn process_listens(
&mut self,
context: &mut ProtocolContextMutRef,
listens: Vec<Multiaddr>,
) -> MisbehaveResult {
let session = context.session;
let info = self
.remote_infos
.get_mut(&session.id)
.expect("RemoteInfo must exists");
if listens.len() > MAX_ADDRS {
self.callback
.misbehave(&info.session, Misbehavior::TooManyAddresses(listens.len()))
} else {
let global_ip_only = self.global_ip_only;
let reachable_addrs = listens
.into_iter()
.filter(|addr| {
multiaddr_to_socketaddr(addr)
.map(|socket_addr| !global_ip_only || is_reachable(socket_addr.ip()))
.unwrap_or(false)
})
.collect::<Vec<_>>();
self.callback
.add_remote_listen_addrs(session, reachable_addrs);
MisbehaveResult::Continue
}
}
fn process_observed(
&mut self,
context: &mut ProtocolContextMutRef,
observed: Multiaddr,
) -> MisbehaveResult {
debug!(
"IdentifyProtocol process observed address, session: {:?}, observed: {}",
context.session, observed,
);
let session = context.session;
let info = self
.remote_infos
.get_mut(&session.id)
.expect("RemoteInfo must exists");
let global_ip_only = self.global_ip_only;
if multiaddr_to_socketaddr(&observed)
.map(|socket_addr| socket_addr.ip())
.filter(|ip_addr| !global_ip_only || is_reachable(*ip_addr))
.is_none()
{
return MisbehaveResult::Continue;
}
self.callback.add_observed_addr(observed, info.session.ty)
}
}
pub(crate) struct RemoteInfo {
session: SessionContext,
connected_at: Instant,
timeout: Duration,
has_received: bool,
}
impl RemoteInfo {
fn new(session: SessionContext, timeout: Duration) -> RemoteInfo {
RemoteInfo {
session,
connected_at: Instant::now(),
timeout,
has_received: false,
}
}
}
#[async_trait]
impl<T: Callback> ServiceProtocol for IdentifyProtocol<T> {
async fn init(&mut self, context: &mut ProtocolContext) {
let proto_id = context.proto_id;
if let Err(err) = context
.set_service_notify(
proto_id,
Duration::from_secs(CHECK_TIMEOUT_INTERVAL),
CHECK_TIMEOUT_TOKEN,
)
.await
{
error!("IdentifyProtocol init error: {:?}", err)
}
}
async fn connected(&mut self, context: ProtocolContextMutRef<'_>, version: &str) {
let session = context.session;
debug!("IdentifyProtocol connected, session: {:?}", session);
self.callback.register(&context, version);
let remote_info = RemoteInfo::new(session.clone(), Duration::from_secs(DEFAULT_TIMEOUT));
self.remote_infos.insert(session.id, remote_info);
let listen_addrs: Vec<Multiaddr> = self
.callback
.local_listen_addrs()
.iter()
.filter(|addr| {
multiaddr_to_socketaddr(addr)
.map(|socket_addr| !self.global_ip_only || is_reachable(socket_addr.ip()))
.unwrap_or(false)
})
.take(MAX_ADDRS)
.cloned()
.collect();
let identify = self.callback.identify();
let data = IdentifyMessage::new(listen_addrs, session.address.clone(), identify).encode();
let _ = context
.quick_send_message(data)
.await
.map_err(|err| error!("IdentifyProtocol quick_send_message, error: {:?}", err));
}
async fn disconnected(&mut self, context: ProtocolContextMutRef<'_>) {
self.remote_infos
.remove(&context.session.id)
.expect("RemoteInfo must exists");
debug!(
"IdentifyProtocol disconnected, session: {:?}",
context.session
);
self.callback.unregister(&context);
}
async fn received(&mut self, mut context: ProtocolContextMutRef<'_>, data: Bytes) {
let session = context.session;
match IdentifyMessage::decode(&data) {
Some(message) => {
trace!(
"IdentifyProtocol received, session: {:?}, listen_addrs: {:?}, observed_addr: {}",
context.session, message.listen_addrs, message.observed_addr
);
if let MisbehaveResult::Disconnect = self.check_duplicate(&mut context) {
error!(
"Disconnect IdentifyProtocol session {:?} due to duplication.",
session
);
let _ = context.disconnect(session.id).await;
return;
}
if let MisbehaveResult::Disconnect = self
.callback
.received_identify(&mut context, message.identify)
.await
{
error!(
"Disconnect IdentifyProtocol session {:?} due to invalid identify message.",
session,
);
let _ = context.disconnect(session.id).await;
return;
}
if let MisbehaveResult::Disconnect =
self.process_listens(&mut context, message.listen_addrs.clone())
{
error!(
"Disconnect IdentifyProtocol session {:?} due to invalid listen addrs: {:?}.",
session, message.listen_addrs,
);
let _ = context.disconnect(session.id).await;
return;
}
if let MisbehaveResult::Disconnect =
self.process_observed(&mut context, message.observed_addr.clone())
{
error!(
"Disconnect IdentifyProtocol session {:?} due to invalid observed addr: {}.",
session, message.observed_addr,
);
let _ = context.disconnect(session.id).await;
}
}
None => {
let info = self
.remote_infos
.get(&session.id)
.expect("RemoteInfo must exists");
if self
.callback
.misbehave(&info.session, Misbehavior::InvalidData)
.is_disconnect()
{
let _ = context.disconnect(session.id).await;
}
}
}
}
async fn notify(&mut self, context: &mut ProtocolContext, _token: u64) {
for (session_id, info) in &self.remote_infos {
if !info.has_received && (info.connected_at + info.timeout) <= Instant::now() {
let misbehave_result = self.callback.misbehave(&info.session, Misbehavior::Timeout);
if misbehave_result.is_disconnect() {
let _ = context.disconnect(*session_id).await;
}
}
}
}
}
#[derive(Clone)]
pub struct IdentifyCallback {
network_state: Arc<NetworkState>,
identify: Identify,
}
impl IdentifyCallback {
pub(crate) fn new(
network_state: Arc<NetworkState>,
name: String,
client_version: String,
flags: Flags,
) -> IdentifyCallback {
IdentifyCallback {
network_state,
identify: Identify::new(name, flags, client_version),
}
}
fn listen_addrs(&self) -> Vec<Multiaddr> {
let addrs = self.network_state.public_addrs(MAX_RETURN_LISTEN_ADDRS * 2);
addrs
.into_iter()
.take(MAX_RETURN_LISTEN_ADDRS)
.collect::<Vec<_>>()
}
}
#[async_trait]
impl Callback for IdentifyCallback {
fn register(&self, context: &ProtocolContextMutRef, version: &str) {
self.network_state.with_peer_registry_mut(|reg| {
reg.get_peer_mut(context.session.id).map(|peer| {
peer.protocols.insert(context.proto_id, version.to_owned());
})
});
}
fn unregister(&self, context: &ProtocolContextMutRef) {
let protocol_version_match = self
.network_state
.with_peer_registry(|reg| {
reg.get_peer(context.session.id)
.map(|p| p.protocol_version(context.proto_id))
})
.flatten()
.map(|version| version != "3")
.unwrap_or_default();
if self.network_state.ckb2023.load(Ordering::SeqCst) && protocol_version_match {
} else if context.session.ty.is_outbound() {
self.network_state.with_peer_store_mut(|peer_store| {
peer_store.update_outbound_addr_last_connected_ms(context.session.address.clone());
});
}
}
fn identify(&mut self) -> &[u8] {
self.identify.encode()
}
async fn received_identify(
&mut self,
context: &mut ProtocolContextMutRef<'_>,
identify: &[u8],
) -> MisbehaveResult {
match self.identify.verify(identify) {
None => {
self.network_state.ban_session(
&context.control().clone().into(),
context.session.id,
BAN_ON_NOT_SAME_NET,
"The nodes are not on the same network".to_string(),
);
MisbehaveResult::Disconnect
}
Some((flags, client_version)) => {
let registry_client_version = |version: String| {
self.network_state.with_peer_registry_mut(|registry| {
if let Some(peer) = registry.get_peer_mut(context.session.id) {
peer.identify_info = Some(PeerIdentifyInfo {
client_version: version,
flags,
})
}
});
};
registry_client_version(client_version);
let required_flags = self.network_state.required_flags;
let protocol_version_match = self
.network_state
.with_peer_registry(|reg| {
reg.get_peer(context.session.id)
.map(|p| p.protocol_version(context.proto_id))
})
.flatten()
.map(|version| version != "3")
.unwrap_or_default();
let ckb2023 = self.network_state.ckb2023.load(Ordering::SeqCst);
let renew = if ckb2023 && protocol_version_match {
if context.session.ty.is_outbound() {
self.network_state
.peer_store
.lock()
.mut_addr_manager()
.remove(&context.session.address);
}
false
} else {
true
};
if context.session.ty.is_outbound() {
if renew {
self.network_state.with_peer_store_mut(|peer_store| {
peer_store.add_outbound_addr(context.session.address.clone(), flags);
});
}
if self
.network_state
.with_peer_registry(|reg| reg.is_feeler(&context.session.address))
{
let _ = context
.open_protocols(
context.session.id,
TargetProtocol::Single(SupportProtocols::Feeler.protocol_id()),
)
.await;
} else if required_flags_filter(required_flags, flags) {
let _ = context
.open_protocols(
context.session.id,
TargetProtocol::Filter(Box::new(move |id| {
if ckb2023 {
id != &SupportProtocols::Feeler.protocol_id()
&& id != &SupportProtocols::RelayV2.protocol_id()
} else {
id != &SupportProtocols::Feeler.protocol_id()
}
})),
)
.await;
} else {
warn!("Session closed from IdentifyProtocol due to peer's flag not meeting the requirements");
return MisbehaveResult::Disconnect;
}
}
MisbehaveResult::Continue
}
}
}
fn local_listen_addrs(&mut self) -> Vec<Multiaddr> {
self.listen_addrs()
}
fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>) {
trace!(
"IdentifyProtocol add remote listening addresses, session: {:?}, addresses : {:?}",
session,
addrs,
);
let flags = self.network_state.with_peer_registry_mut(|reg| {
if let Some(peer) = reg.get_peer_mut(session.id) {
peer.listened_addrs = addrs.clone();
peer.identify_info
.as_ref()
.map(|a| a.flags)
.unwrap_or(Flags::COMPATIBILITY)
} else {
Flags::COMPATIBILITY
}
});
self.network_state.with_peer_store_mut(|peer_store| {
for addr in addrs {
if let Err(err) = peer_store.add_addr(addr.clone(), flags) {
error!("IdentifyProtocol failed to add address to peer store, address: {}, error: {:?}", addr, err);
}
}
})
}
fn add_observed_addr(&mut self, mut addr: Multiaddr, ty: SessionType) -> MisbehaveResult {
if ty.is_inbound() {
return MisbehaveResult::Continue;
}
if !multiaddr_to_socketaddr(&addr)
.map(|socket_addr| is_reachable(socket_addr.ip()))
.unwrap_or(false)
{
return MisbehaveResult::Continue;
}
if extract_peer_id(&addr).is_none() {
addr.push(Protocol::P2P(Cow::Borrowed(
self.network_state.local_peer_id().as_bytes(),
)))
}
let source_addr = addr.clone();
let observed_addrs_iter = self
.listen_addrs()
.into_iter()
.filter_map(|listen_addr| multiaddr_to_socketaddr(&listen_addr))
.map(|socket_addr| {
addr.iter()
.map(|proto| match proto {
Protocol::Tcp(_) => Protocol::Tcp(socket_addr.port()),
value => value,
})
.collect::<Multiaddr>()
})
.chain(::std::iter::once(source_addr));
self.network_state.add_observed_addrs(observed_addrs_iter);
MisbehaveResult::Continue
}
fn misbehave(&mut self, session: &SessionContext, reason: Misbehavior) -> MisbehaveResult {
error!(
"IdentifyProtocol detects abnormal behavior, session: {:?}, reason: {:?}",
session, reason
);
MisbehaveResult::Disconnect
}
}
#[derive(Clone)]
struct Identify {
name: String,
encode_data: ckb_types::bytes::Bytes,
}
impl Identify {
fn new(name: String, flags: Flags, client_version: String) -> Self {
Identify {
encode_data: packed::Identify::new_builder()
.name(name.as_str().pack())
.flag(flags.bits().pack())
.client_version(client_version.as_str().pack())
.build()
.as_bytes(),
name,
}
}
fn encode(&mut self) -> &[u8] {
&self.encode_data
}
fn verify(&self, data: &[u8]) -> Option<(Flags, String)> {
let reader = packed::IdentifyReader::from_slice(data).ok()?;
let name = reader.name().as_utf8().ok()?.to_owned();
if self.name != name {
warn!(
"IdentifyProtocol detects peer has different network identifiers, local network id: {}, remote network id: {}",
self.name, name,
);
return None;
}
let flag: u64 = reader.flag().unpack();
if flag == 0 {
return None;
}
let raw_client_version = reader.client_version().as_utf8().ok()?.to_owned();
Some((Flags::from_bits_truncate(flag), raw_client_version))
}
}
bitflags::bitflags! {
pub struct Flags: u64 {
const COMPATIBILITY = 0b1;
const DISCOVERY = 0b10;
const SYNC = 0b100;
const RELAY = 0b1000;
const LIGHT_CLIENT = 0b10000;
const BLOCK_FILTER = 0b100000;
}
}