1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::{atomic::Ordering, Arc};
4
5use ckb_logger::{debug, error, trace, warn};
6use ckb_systemtime::{Duration, Instant};
7use p2p::{
8 async_trait,
9 bytes::Bytes,
10 context::{ProtocolContext, ProtocolContextMutRef, SessionContext},
11 multiaddr::{Multiaddr, Protocol},
12 service::{SessionType, TargetProtocol},
13 traits::ServiceProtocol,
14 utils::{extract_peer_id, is_reachable, multiaddr_to_socketaddr},
15 SessionId,
16};
17
18mod protocol;
19
20use crate::{peer_store::required_flags_filter, NetworkState, PeerIdentifyInfo, SupportProtocols};
21use ckb_types::{packed, prelude::*};
22
23use protocol::IdentifyMessage;
24
/// Upper bound on the number of addresses `IdentifyCallback::listen_addrs` hands back.
const MAX_RETURN_LISTEN_ADDRS: usize = 10;
/// Ban duration (five minutes) passed to `ban_session` when a peer's identify payload
/// fails verification.
const BAN_ON_NOT_SAME_NET: Duration = Duration::from_secs(5 * 60);
/// Token registered with `set_service_notify` for the periodic timeout check.
const CHECK_TIMEOUT_TOKEN: u64 = 100;
/// Period of the timeout check, in seconds.
const CHECK_TIMEOUT_INTERVAL: u64 = 1;
/// Seconds a session may stay connected without sending an identify message before it is
/// reported as `Misbehavior::Timeout`.
const DEFAULT_TIMEOUT: u64 = 8;
/// Maximum number of listen addresses sent to, or accepted from, a single peer.
const MAX_ADDRS: usize = 10;
32
/// Ways a remote peer can violate the identify protocol.
///
/// Values are handed to `Callback::misbehave`, which decides whether the session survives.
// NOTE(review): `dead_code` is presumably silenced because the `TooManyAddresses` payload is
// only ever read through `Debug` — confirm before removing the attribute.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Misbehavior {
    /// A second identify message arrived on a session that had already sent one.
    DuplicateReceived,
    /// No identify message arrived before the session's timeout elapsed.
    Timeout,
    /// The received bytes could not be decoded as an identify message.
    InvalidData,
    /// The peer advertised more listen addresses than allowed; carries the received count.
    TooManyAddresses(usize),
}
46
/// Verdict returned by the callback after handling a message or a reported misbehavior.
///
/// The common traits are derived so the verdict can be logged, compared in tests and copied
/// freely; the type is a plain two-state enum with no payload.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MisbehaveResult {
    /// Keep the session open and carry on processing.
    Continue,
    /// Close the session.
    Disconnect,
}

impl MisbehaveResult {
    /// Returns `true` when the verdict is [`MisbehaveResult::Disconnect`].
    pub fn is_disconnect(&self) -> bool {
        matches!(self, MisbehaveResult::Disconnect)
    }
}
60
61#[async_trait]
63pub trait Callback: Clone + Send {
64 fn register(&self, context: &ProtocolContextMutRef, version: &str);
66 fn unregister(&self, context: &ProtocolContextMutRef);
68 async fn received_identify(
70 &mut self,
71 context: &mut ProtocolContextMutRef<'_>,
72 identify: &[u8],
73 ) -> MisbehaveResult;
74 fn identify(&mut self) -> &[u8];
76 fn local_listen_addrs(&mut self) -> Vec<Multiaddr>;
78 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>);
80 fn add_observed_addr(&mut self, addr: Multiaddr, ty: SessionType) -> MisbehaveResult;
82 fn misbehave(&mut self, session: &SessionContext, kind: Misbehavior) -> MisbehaveResult;
84}
85
86pub struct IdentifyProtocol<T> {
88 callback: T,
89 remote_infos: HashMap<SessionId, RemoteInfo>,
90 global_ip_only: bool,
91}
92
93impl<T: Callback> IdentifyProtocol<T> {
94 pub fn new(callback: T) -> IdentifyProtocol<T> {
95 IdentifyProtocol {
96 callback,
97 remote_infos: HashMap::default(),
98 global_ip_only: true,
99 }
100 }
101
102 #[cfg(test)]
103 pub fn global_ip_only(mut self, only: bool) -> Self {
104 self.global_ip_only = only;
105 self
106 }
107
108 fn check_duplicate(&mut self, context: &mut ProtocolContextMutRef) -> MisbehaveResult {
109 let session = context.session;
110 let info = self
111 .remote_infos
112 .get_mut(&session.id)
113 .expect("RemoteInfo must exists");
114
115 if info.has_received {
116 self.callback
117 .misbehave(&info.session, Misbehavior::DuplicateReceived)
118 } else {
119 info.has_received = true;
120 MisbehaveResult::Continue
121 }
122 }
123
124 fn process_listens(
125 &mut self,
126 context: &mut ProtocolContextMutRef,
127 listens: Vec<Multiaddr>,
128 ) -> MisbehaveResult {
129 let session = context.session;
130 let info = self
131 .remote_infos
132 .get_mut(&session.id)
133 .expect("RemoteInfo must exists");
134
135 if listens.len() > MAX_ADDRS {
136 self.callback
137 .misbehave(&info.session, Misbehavior::TooManyAddresses(listens.len()))
138 } else {
139 let global_ip_only = self.global_ip_only;
140 let reachable_addrs = listens
141 .into_iter()
142 .filter(|addr| match multiaddr_to_socketaddr(addr) {
143 Some(socket_addr) => !global_ip_only || is_reachable(socket_addr.ip()),
144 None => true,
145 })
146 .collect::<Vec<_>>();
147 self.callback
148 .add_remote_listen_addrs(session, reachable_addrs);
149 MisbehaveResult::Continue
150 }
151 }
152
153 fn process_observed(
154 &mut self,
155 context: &mut ProtocolContextMutRef,
156 observed: Multiaddr,
157 ) -> MisbehaveResult {
158 debug!(
159 "IdentifyProtocol process observed address, session: {:?}, observed: {}",
160 context.session, observed,
161 );
162
163 let session = context.session;
164 let info = self
165 .remote_infos
166 .get_mut(&session.id)
167 .expect("RemoteInfo must exists");
168 let global_ip_only = self.global_ip_only;
169 if multiaddr_to_socketaddr(&observed)
170 .map(|socket_addr| socket_addr.ip())
171 .filter(|ip_addr| !global_ip_only || is_reachable(*ip_addr))
172 .is_none()
173 {
174 return MisbehaveResult::Continue;
175 }
176
177 self.callback.add_observed_addr(observed, info.session.ty)
178 }
179}
180
181pub(crate) struct RemoteInfo {
182 session: SessionContext,
183 connected_at: Instant,
184 timeout: Duration,
185 has_received: bool,
186}
187
188impl RemoteInfo {
189 fn new(session: SessionContext, timeout: Duration) -> RemoteInfo {
190 RemoteInfo {
191 session,
192 connected_at: Instant::now(),
193 timeout,
194 has_received: false,
195 }
196 }
197}
198
199#[async_trait]
200impl<T: Callback> ServiceProtocol for IdentifyProtocol<T> {
201 async fn init(&mut self, context: &mut ProtocolContext) {
202 let proto_id = context.proto_id;
203 if let Err(err) = context
204 .set_service_notify(
205 proto_id,
206 Duration::from_secs(CHECK_TIMEOUT_INTERVAL),
207 CHECK_TIMEOUT_TOKEN,
208 )
209 .await
210 {
211 error!("IdentifyProtocol init error: {:?}", err)
212 }
213 }
214
215 async fn connected(&mut self, context: ProtocolContextMutRef<'_>, version: &str) {
216 let session = context.session;
217 debug!("IdentifyProtocol connected, session: {:?}", session);
218
219 self.callback.register(&context, version);
220
221 let remote_info = RemoteInfo::new(session.clone(), Duration::from_secs(DEFAULT_TIMEOUT));
222 self.remote_infos.insert(session.id, remote_info);
223
224 let listen_addrs: Vec<Multiaddr> = self
225 .callback
226 .local_listen_addrs()
227 .iter()
228 .filter(|addr| {
229 multiaddr_to_socketaddr(addr)
230 .map(|socket_addr| !self.global_ip_only || is_reachable(socket_addr.ip()))
231 .unwrap_or(false)
232 })
233 .take(MAX_ADDRS)
234 .cloned()
235 .collect();
236
237 let identify = self.callback.identify();
238 let data = IdentifyMessage::new(listen_addrs, session.address.clone(), identify).encode();
239 let _ = context
240 .quick_send_message(data)
241 .await
242 .map_err(|err| error!("IdentifyProtocol quick_send_message, error: {:?}", err));
243 }
244
245 async fn disconnected(&mut self, context: ProtocolContextMutRef<'_>) {
246 self.remote_infos
247 .remove(&context.session.id)
248 .expect("RemoteInfo must exists");
249 debug!(
250 "IdentifyProtocol disconnected, session: {:?}",
251 context.session
252 );
253 self.callback.unregister(&context);
254 }
255
256 async fn received(&mut self, mut context: ProtocolContextMutRef<'_>, data: Bytes) {
257 let session = context.session;
258 match IdentifyMessage::decode(&data) {
259 Some(message) => {
260 trace!(
261 "IdentifyProtocol received, session: {:?}, listen_addrs: {:?}, observed_addr: {}",
262 context.session, message.listen_addrs, message.observed_addr
263 );
264
265 if let MisbehaveResult::Disconnect = self.check_duplicate(&mut context) {
267 error!(
268 "Disconnect IdentifyProtocol session {:?} due to duplication.",
269 session
270 );
271 let _ = context.disconnect(session.id).await;
272 return;
273 }
274 if let MisbehaveResult::Disconnect = self
275 .callback
276 .received_identify(&mut context, message.identify)
277 .await
278 {
279 error!(
280 "Disconnect IdentifyProtocol session {:?} due to invalid identify message.",
281 session,
282 );
283 let _ = context.disconnect(session.id).await;
284 return;
285 }
286 if let MisbehaveResult::Disconnect =
287 self.process_listens(&mut context, message.listen_addrs.clone())
288 {
289 error!(
290 "Disconnect IdentifyProtocol session {:?} due to invalid listen addrs: {:?}.",
291 session, message.listen_addrs,
292 );
293 let _ = context.disconnect(session.id).await;
294 return;
295 }
296 if let MisbehaveResult::Disconnect =
297 self.process_observed(&mut context, message.observed_addr.clone())
298 {
299 error!(
300 "Disconnect IdentifyProtocol session {:?} due to invalid observed addr: {}.",
301 session, message.observed_addr,
302 );
303 let _ = context.disconnect(session.id).await;
304 }
305 }
306 None => {
307 let info = self
308 .remote_infos
309 .get(&session.id)
310 .expect("RemoteInfo must exists");
311 if self
312 .callback
313 .misbehave(&info.session, Misbehavior::InvalidData)
314 .is_disconnect()
315 {
316 let _ = context.disconnect(session.id).await;
317 }
318 }
319 }
320 }
321
322 async fn notify(&mut self, context: &mut ProtocolContext, _token: u64) {
323 for (session_id, info) in &self.remote_infos {
324 if !info.has_received && (info.connected_at + info.timeout) <= Instant::now() {
325 let misbehave_result = self.callback.misbehave(&info.session, Misbehavior::Timeout);
326 if misbehave_result.is_disconnect() {
327 let _ = context.disconnect(*session_id).await;
328 }
329 }
330 }
331 }
332}
333
334#[derive(Clone)]
335pub struct IdentifyCallback {
336 network_state: Arc<NetworkState>,
337 identify: Identify,
338}
339
340impl IdentifyCallback {
341 pub(crate) fn new(
342 network_state: Arc<NetworkState>,
343 name: String,
344 client_version: String,
345 flags: Flags,
346 ) -> IdentifyCallback {
347 IdentifyCallback {
348 network_state,
349 identify: Identify::new(name, flags, client_version),
350 }
351 }
352
353 fn listen_addrs(&self) -> Vec<Multiaddr> {
354 let addrs = self.network_state.public_addrs(MAX_RETURN_LISTEN_ADDRS * 2);
355 addrs
356 .into_iter()
357 .take(MAX_RETURN_LISTEN_ADDRS)
358 .collect::<Vec<_>>()
359 }
360}
361
362#[async_trait]
363impl Callback for IdentifyCallback {
364 fn register(&self, context: &ProtocolContextMutRef, version: &str) {
365 self.network_state.with_peer_registry_mut(|reg| {
366 reg.get_peer_mut(context.session.id).map(|peer| {
367 peer.protocols.insert(context.proto_id, version.to_owned());
368 })
369 });
370 }
371
372 fn unregister(&self, context: &ProtocolContextMutRef) {
373 let protocol_version_match = self
374 .network_state
375 .with_peer_registry(|reg| {
376 reg.get_peer(context.session.id)
377 .map(|p| p.protocol_version(context.proto_id))
378 })
379 .flatten()
380 .map(|version| version != "3")
381 .unwrap_or_default();
382
383 if self.network_state.ckb2023.load(Ordering::SeqCst) && protocol_version_match {
384 } else if context.session.ty.is_outbound() {
385 self.network_state.with_peer_store_mut(|peer_store| {
390 peer_store.update_outbound_addr_last_connected_ms(context.session.address.clone());
391 });
392 }
393 }
394
395 fn identify(&mut self) -> &[u8] {
396 self.identify.encode()
397 }
398
399 async fn received_identify(
400 &mut self,
401 context: &mut ProtocolContextMutRef<'_>,
402 identify: &[u8],
403 ) -> MisbehaveResult {
404 match self.identify.verify(identify) {
405 None => {
406 self.network_state.ban_session(
407 &context.control().clone().into(),
408 context.session.id,
409 BAN_ON_NOT_SAME_NET,
410 "The nodes are not on the same network".to_string(),
411 );
412 MisbehaveResult::Disconnect
413 }
414 Some((flags, client_version)) => {
415 let registry_client_version = |version: String| {
416 self.network_state.with_peer_registry_mut(|registry| {
417 if let Some(peer) = registry.get_peer_mut(context.session.id) {
418 peer.identify_info = Some(PeerIdentifyInfo {
419 client_version: version,
420 flags,
421 })
422 }
423 });
424 };
425
426 registry_client_version(client_version);
427
428 let required_flags = self.network_state.required_flags;
429
430 let protocol_version_match = self
431 .network_state
432 .with_peer_registry(|reg| {
433 reg.get_peer(context.session.id)
434 .map(|p| p.protocol_version(context.proto_id))
435 })
436 .flatten()
437 .map(|version| version != "3")
438 .unwrap_or_default();
439 let ckb2023 = self.network_state.ckb2023.load(Ordering::SeqCst);
440
441 let renew = if ckb2023 && protocol_version_match {
442 if context.session.ty.is_outbound() {
443 self.network_state
444 .peer_store
445 .lock()
446 .mut_addr_manager()
447 .remove(&context.session.address);
448 }
449 false
450 } else {
451 true
452 };
453
454 if context.session.ty.is_outbound() {
455 if renew {
460 self.network_state.with_peer_store_mut(|peer_store| {
461 peer_store.add_outbound_addr(context.session.address.clone(), flags);
462 });
463 }
464
465 if self.network_state.with_peer_registry_mut(|reg| {
466 reg.change_feeler_flags(&context.session.address, flags)
467 }) {
468 let _ = context
469 .open_protocols(
470 context.session.id,
471 TargetProtocol::Single(SupportProtocols::Feeler.protocol_id()),
472 )
473 .await;
474 } else if required_flags_filter(required_flags, flags) {
475 let _ = context
477 .open_protocols(
478 context.session.id,
479 TargetProtocol::Filter(Box::new(move |id| {
480 if ckb2023 {
481 id != &SupportProtocols::Feeler.protocol_id()
482 && id != &SupportProtocols::RelayV2.protocol_id()
483 } else {
484 id != &SupportProtocols::Feeler.protocol_id()
485 }
486 })),
487 )
488 .await;
489 } else {
490 warn!("Session closed from IdentifyProtocol due to peer's flag not meeting the requirements");
492 return MisbehaveResult::Disconnect;
493 }
494 }
495 MisbehaveResult::Continue
496 }
497 }
498 }
499
500 fn local_listen_addrs(&mut self) -> Vec<Multiaddr> {
502 self.listen_addrs()
503 }
504
505 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>) {
506 trace!(
507 "IdentifyProtocol add remote listening addresses, session: {:?}, addresses : {:?}",
508 session,
509 addrs,
510 );
511 let flags = self.network_state.with_peer_registry_mut(|reg| {
512 if let Some(peer) = reg.get_peer_mut(session.id) {
513 peer.listened_addrs = addrs.clone();
514 peer.identify_info
515 .as_ref()
516 .map(|a| a.flags)
517 .unwrap_or(Flags::COMPATIBILITY)
518 } else {
519 Flags::COMPATIBILITY
520 }
521 });
522 self.network_state.with_peer_store_mut(|peer_store| {
523 for addr in addrs {
524 if let Err(err) = peer_store.add_addr(addr.clone(), flags) {
525 error!("IdentifyProtocol failed to add address to peer store, address: {}, error: {:?}", addr, err);
526 }
527 }
528 })
529 }
530
531 fn add_observed_addr(&mut self, mut addr: Multiaddr, ty: SessionType) -> MisbehaveResult {
532 if ty.is_inbound() {
533 return MisbehaveResult::Continue;
535 }
536
537 if !multiaddr_to_socketaddr(&addr)
539 .map(|socket_addr| is_reachable(socket_addr.ip()))
540 .unwrap_or(false)
541 {
542 return MisbehaveResult::Continue;
543 }
544
545 if extract_peer_id(&addr).is_none() {
546 addr.push(Protocol::P2P(Cow::Borrowed(
547 self.network_state.local_peer_id().as_bytes(),
548 )))
549 }
550
551 let source_addr = addr.clone();
552 let observed_addrs_iter = self
553 .listen_addrs()
554 .into_iter()
555 .filter_map(|listen_addr| multiaddr_to_socketaddr(&listen_addr))
556 .map(|socket_addr| {
557 addr.iter()
558 .map(|proto| match proto {
559 Protocol::Tcp(_) => Protocol::Tcp(socket_addr.port()),
560 value => value,
561 })
562 .collect::<Multiaddr>()
563 })
564 .chain(::std::iter::once(source_addr));
565
566 self.network_state.add_observed_addrs(observed_addrs_iter);
567 MisbehaveResult::Continue
569 }
570
571 fn misbehave(&mut self, session: &SessionContext, reason: Misbehavior) -> MisbehaveResult {
572 error!(
573 "IdentifyProtocol detects abnormal behavior, session: {:?}, reason: {:?}",
574 session, reason
575 );
576 MisbehaveResult::Disconnect
577 }
578}
579
580#[derive(Clone)]
581struct Identify {
582 name: String,
583 encode_data: ckb_types::bytes::Bytes,
584}
585
586impl Identify {
587 fn new(name: String, flags: Flags, client_version: String) -> Self {
588 Identify {
589 encode_data: packed::Identify::new_builder()
590 .name(name.as_str().pack())
591 .flag(flags.bits().pack())
592 .client_version(client_version.as_str().pack())
593 .build()
594 .as_bytes(),
595 name,
596 }
597 }
598
599 fn encode(&mut self) -> &[u8] {
600 &self.encode_data
601 }
602
603 fn verify(&self, data: &[u8]) -> Option<(Flags, String)> {
604 let reader = packed::IdentifyReader::from_slice(data).ok()?;
605
606 let name = reader.name().as_utf8().ok()?.to_owned();
607 if self.name != name {
608 warn!(
609 "IdentifyProtocol detects peer has different network identifiers, local network id: {}, remote network id: {}",
610 self.name, name,
611 );
612 return None;
613 }
614
615 let flag: u64 = reader.flag().unpack();
616 if flag == 0 {
617 return None;
618 }
619
620 let raw_client_version = reader.client_version().as_utf8().ok()?.to_owned();
621
622 Some((Flags::from_bits_truncate(flag), raw_client_version))
623 }
624}
625
626bitflags::bitflags! {
627 pub struct Flags: u64 {
629 const COMPATIBILITY = 0b1;
631 const DISCOVERY = 0b10;
633 const SYNC = 0b100;
635 const RELAY = 0b1000;
637 const LIGHT_CLIENT = 0b10000;
639 const BLOCK_FILTER = 0b100000;
641 }
642}