1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::{Arc, atomic::Ordering};
4
5use ckb_logger::{debug, error, trace, warn};
6use ckb_systemtime::{Duration, Instant};
7use p2p::{
8 SessionId, async_trait,
9 bytes::Bytes,
10 context::{ProtocolContext, ProtocolContextMutRef, SessionContext},
11 multiaddr::{Multiaddr, Protocol},
12 service::{SessionType, TargetProtocol},
13 traits::ServiceProtocol,
14 utils::{extract_peer_id, is_reachable, multiaddr_to_socketaddr},
15};
16
17mod protocol;
18
19use crate::{NetworkState, PeerIdentifyInfo, SupportProtocols, peer_store::required_flags_filter};
20use ckb_types::{packed, prelude::*};
21
22use protocol::IdentifyMessage;
23
/// Upper bound on the number of addresses `IdentifyCallback::listen_addrs` returns.
const MAX_RETURN_LISTEN_ADDRS: usize = 10;
/// Ban duration applied to a session whose identify payload fails verification.
const BAN_ON_NOT_SAME_NET: Duration = Duration::from_secs(5 * 60);
/// Token passed to `set_service_notify` for the periodic timeout check.
const CHECK_TIMEOUT_TOKEN: u64 = 100;
/// Interval, in seconds, between two timeout checks.
const CHECK_TIMEOUT_INTERVAL: u64 = 1;
/// Seconds a freshly connected session is given to deliver its identify message.
const DEFAULT_TIMEOUT: u64 = 8;
/// Maximum number of listen addresses sent to, or accepted from, a single peer.
const MAX_ADDRS: usize = 10;
31
/// Kinds of protocol violations the identify protocol reports through
/// `Callback::misbehave`.
// NOTE(review): the lint is silenced presumably because not every variant or
// payload is read on every build configuration — confirm it is still needed.
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Misbehavior {
    /// A second identify message arrived on a session that already delivered one.
    DuplicateReceived,
    /// No identify message arrived before the session's timeout elapsed.
    Timeout,
    /// The received bytes could not be decoded as an identify message.
    InvalidData,
    /// The peer advertised more listen addresses than allowed; carries the
    /// number of addresses it sent.
    TooManyAddresses(usize),
}
45
/// Verdict produced by the identify callbacks: keep the session or drop it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MisbehaveResult {
    /// Keep the session open and carry on processing.
    Continue,
    /// Close the session.
    Disconnect,
}

impl MisbehaveResult {
    /// Returns `true` when the verdict is [`MisbehaveResult::Disconnect`].
    pub fn is_disconnect(&self) -> bool {
        matches!(self, MisbehaveResult::Disconnect)
    }
}
59
60#[async_trait]
62pub trait Callback: Clone + Send {
63 fn register(&self, context: &ProtocolContextMutRef, version: &str);
65 fn unregister(&self, context: &ProtocolContextMutRef);
67 async fn received_identify(
69 &mut self,
70 context: &mut ProtocolContextMutRef<'_>,
71 identify: &[u8],
72 ) -> MisbehaveResult;
73 fn identify(&mut self) -> &[u8];
75 fn local_listen_addrs(&mut self) -> Vec<Multiaddr>;
77 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>);
79 fn add_observed_addr(&mut self, addr: Multiaddr, ty: SessionType) -> MisbehaveResult;
81 fn misbehave(&mut self, session: &SessionContext, kind: Misbehavior) -> MisbehaveResult;
83}
84
85pub struct IdentifyProtocol<T> {
87 callback: T,
88 remote_infos: HashMap<SessionId, RemoteInfo>,
89 global_ip_only: bool,
90}
91
92impl<T: Callback> IdentifyProtocol<T> {
93 pub fn new(callback: T) -> IdentifyProtocol<T> {
94 IdentifyProtocol {
95 callback,
96 remote_infos: HashMap::default(),
97 global_ip_only: true,
98 }
99 }
100
101 #[cfg(test)]
102 pub fn global_ip_only(mut self, only: bool) -> Self {
103 self.global_ip_only = only;
104 self
105 }
106
107 fn check_duplicate(&mut self, context: &mut ProtocolContextMutRef) -> MisbehaveResult {
108 let session = context.session;
109 let info = self
110 .remote_infos
111 .get_mut(&session.id)
112 .expect("RemoteInfo must exists");
113
114 if info.has_received {
115 self.callback
116 .misbehave(&info.session, Misbehavior::DuplicateReceived)
117 } else {
118 info.has_received = true;
119 MisbehaveResult::Continue
120 }
121 }
122
123 fn process_listens(
124 &mut self,
125 context: &mut ProtocolContextMutRef,
126 listens: Vec<Multiaddr>,
127 ) -> MisbehaveResult {
128 let session = context.session;
129 let info = self
130 .remote_infos
131 .get_mut(&session.id)
132 .expect("RemoteInfo must exists");
133
134 if listens.len() > MAX_ADDRS {
135 self.callback
136 .misbehave(&info.session, Misbehavior::TooManyAddresses(listens.len()))
137 } else {
138 let global_ip_only = self.global_ip_only;
139 let reachable_addrs = listens
140 .into_iter()
141 .filter(|addr| match multiaddr_to_socketaddr(addr) {
142 Some(socket_addr) => !global_ip_only || is_reachable(socket_addr.ip()),
143 None => true,
144 })
145 .collect::<Vec<_>>();
146 self.callback
147 .add_remote_listen_addrs(session, reachable_addrs);
148 MisbehaveResult::Continue
149 }
150 }
151
152 fn process_observed(
153 &mut self,
154 context: &mut ProtocolContextMutRef,
155 observed: Multiaddr,
156 ) -> MisbehaveResult {
157 debug!(
158 "IdentifyProtocol process observed address, session: {:?}, observed: {}",
159 context.session, observed,
160 );
161
162 let session = context.session;
163 let info = self
164 .remote_infos
165 .get_mut(&session.id)
166 .expect("RemoteInfo must exists");
167 let global_ip_only = self.global_ip_only;
168 if multiaddr_to_socketaddr(&observed)
169 .map(|socket_addr| socket_addr.ip())
170 .filter(|ip_addr| !global_ip_only || is_reachable(*ip_addr))
171 .is_none()
172 {
173 return MisbehaveResult::Continue;
174 }
175
176 self.callback.add_observed_addr(observed, info.session.ty)
177 }
178}
179
180pub(crate) struct RemoteInfo {
181 session: SessionContext,
182 connected_at: Instant,
183 timeout: Duration,
184 has_received: bool,
185}
186
187impl RemoteInfo {
188 fn new(session: SessionContext, timeout: Duration) -> RemoteInfo {
189 RemoteInfo {
190 session,
191 connected_at: Instant::now(),
192 timeout,
193 has_received: false,
194 }
195 }
196}
197
198#[async_trait]
199impl<T: Callback> ServiceProtocol for IdentifyProtocol<T> {
200 async fn init(&mut self, context: &mut ProtocolContext) {
201 let proto_id = context.proto_id;
202 if let Err(err) = context
203 .set_service_notify(
204 proto_id,
205 Duration::from_secs(CHECK_TIMEOUT_INTERVAL),
206 CHECK_TIMEOUT_TOKEN,
207 )
208 .await
209 {
210 error!("IdentifyProtocol init error: {:?}", err)
211 }
212 }
213
214 async fn connected(&mut self, context: ProtocolContextMutRef<'_>, version: &str) {
215 let session = context.session;
216 debug!("IdentifyProtocol connected, session: {:?}", session);
217
218 self.callback.register(&context, version);
219
220 let remote_info = RemoteInfo::new(session.clone(), Duration::from_secs(DEFAULT_TIMEOUT));
221 self.remote_infos.insert(session.id, remote_info);
222
223 let listen_addrs: Vec<Multiaddr> = self
224 .callback
225 .local_listen_addrs()
226 .iter()
227 .filter(|addr| {
228 multiaddr_to_socketaddr(addr)
229 .map(|socket_addr| !self.global_ip_only || is_reachable(socket_addr.ip()))
230 .unwrap_or(false)
231 })
232 .take(MAX_ADDRS)
233 .cloned()
234 .collect();
235
236 let identify = self.callback.identify();
237 let data = IdentifyMessage::new(listen_addrs, session.address.clone(), identify).encode();
238 let _ = context
239 .quick_send_message(data)
240 .await
241 .map_err(|err| error!("IdentifyProtocol quick_send_message, error: {:?}", err));
242 }
243
244 async fn disconnected(&mut self, context: ProtocolContextMutRef<'_>) {
245 self.remote_infos
246 .remove(&context.session.id)
247 .expect("RemoteInfo must exists");
248 debug!(
249 "IdentifyProtocol disconnected, session: {:?}",
250 context.session
251 );
252 self.callback.unregister(&context);
253 }
254
255 async fn received(&mut self, mut context: ProtocolContextMutRef<'_>, data: Bytes) {
256 let session = context.session;
257 match IdentifyMessage::decode(&data) {
258 Some(message) => {
259 trace!(
260 "IdentifyProtocol received, session: {:?}, listen_addrs: {:?}, observed_addr: {}",
261 context.session, message.listen_addrs, message.observed_addr
262 );
263
264 if let MisbehaveResult::Disconnect = self.check_duplicate(&mut context) {
266 error!(
267 "Disconnect IdentifyProtocol session {:?} due to duplication.",
268 session
269 );
270 let _ = context.disconnect(session.id).await;
271 return;
272 }
273 if let MisbehaveResult::Disconnect = self
274 .callback
275 .received_identify(&mut context, message.identify)
276 .await
277 {
278 error!(
279 "Disconnect IdentifyProtocol session {:?} due to invalid identify message.",
280 session,
281 );
282 let _ = context.disconnect(session.id).await;
283 return;
284 }
285 if let MisbehaveResult::Disconnect =
286 self.process_listens(&mut context, message.listen_addrs.clone())
287 {
288 error!(
289 "Disconnect IdentifyProtocol session {:?} due to invalid listen addrs: {:?}.",
290 session, message.listen_addrs,
291 );
292 let _ = context.disconnect(session.id).await;
293 return;
294 }
295 if let MisbehaveResult::Disconnect =
296 self.process_observed(&mut context, message.observed_addr.clone())
297 {
298 error!(
299 "Disconnect IdentifyProtocol session {:?} due to invalid observed addr: {}.",
300 session, message.observed_addr,
301 );
302 let _ = context.disconnect(session.id).await;
303 }
304 }
305 None => {
306 let info = self
307 .remote_infos
308 .get(&session.id)
309 .expect("RemoteInfo must exists");
310 if self
311 .callback
312 .misbehave(&info.session, Misbehavior::InvalidData)
313 .is_disconnect()
314 {
315 let _ = context.disconnect(session.id).await;
316 }
317 }
318 }
319 }
320
321 async fn notify(&mut self, context: &mut ProtocolContext, _token: u64) {
322 for (session_id, info) in &self.remote_infos {
323 if !info.has_received && (info.connected_at + info.timeout) <= Instant::now() {
324 let misbehave_result = self.callback.misbehave(&info.session, Misbehavior::Timeout);
325 if misbehave_result.is_disconnect() {
326 let _ = context.disconnect(*session_id).await;
327 }
328 }
329 }
330 }
331}
332
333#[derive(Clone)]
334pub struct IdentifyCallback {
335 network_state: Arc<NetworkState>,
336 identify: Identify,
337}
338
339impl IdentifyCallback {
340 pub(crate) fn new(
341 network_state: Arc<NetworkState>,
342 name: String,
343 client_version: String,
344 flags: Flags,
345 ) -> IdentifyCallback {
346 IdentifyCallback {
347 network_state,
348 identify: Identify::new(name, flags, client_version),
349 }
350 }
351
352 fn listen_addrs(&self) -> Vec<Multiaddr> {
353 let addrs = self.network_state.public_addrs(MAX_RETURN_LISTEN_ADDRS * 2);
354 addrs
355 .into_iter()
356 .take(MAX_RETURN_LISTEN_ADDRS)
357 .collect::<Vec<_>>()
358 }
359}
360
361#[async_trait]
362impl Callback for IdentifyCallback {
363 fn register(&self, context: &ProtocolContextMutRef, version: &str) {
364 self.network_state.with_peer_registry_mut(|reg| {
365 reg.get_peer_mut(context.session.id).map(|peer| {
366 peer.protocols.insert(context.proto_id, version.to_owned());
367 })
368 });
369 }
370
371 fn unregister(&self, context: &ProtocolContextMutRef) {
372 let protocol_version_match = self
373 .network_state
374 .with_peer_registry(|reg| {
375 reg.get_peer(context.session.id)
376 .map(|p| p.protocol_version(context.proto_id))
377 })
378 .flatten()
379 .map(|version| version != "3")
380 .unwrap_or_default();
381
382 if self.network_state.ckb2023.load(Ordering::SeqCst) && protocol_version_match {
383 } else if context.session.ty.is_outbound() {
384 self.network_state.with_peer_store_mut(|peer_store| {
389 peer_store.update_outbound_addr_last_connected_ms(context.session.address.clone());
390 });
391 }
392 }
393
394 fn identify(&mut self) -> &[u8] {
395 self.identify.encode()
396 }
397
398 async fn received_identify(
399 &mut self,
400 context: &mut ProtocolContextMutRef<'_>,
401 identify: &[u8],
402 ) -> MisbehaveResult {
403 match self.identify.verify(identify) {
404 None => {
405 self.network_state.ban_session(
406 &context.control().clone().into(),
407 context.session.id,
408 BAN_ON_NOT_SAME_NET,
409 "The nodes are not on the same network".to_string(),
410 );
411 MisbehaveResult::Disconnect
412 }
413 Some((flags, client_version)) => {
414 let registry_client_version = |version: String| {
415 self.network_state.with_peer_registry_mut(|registry| {
416 if let Some(peer) = registry.get_peer_mut(context.session.id) {
417 peer.identify_info = Some(PeerIdentifyInfo {
418 client_version: version,
419 flags,
420 })
421 }
422 });
423 };
424
425 registry_client_version(client_version);
426
427 let required_flags = self.network_state.required_flags;
428
429 let protocol_version_match = self
430 .network_state
431 .with_peer_registry(|reg| {
432 reg.get_peer(context.session.id)
433 .map(|p| p.protocol_version(context.proto_id))
434 })
435 .flatten()
436 .map(|version| version != "3")
437 .unwrap_or_default();
438 let ckb2023 = self.network_state.ckb2023.load(Ordering::SeqCst);
439
440 let renew = if ckb2023 && protocol_version_match {
441 if context.session.ty.is_outbound() {
442 self.network_state
443 .peer_store
444 .lock()
445 .mut_addr_manager()
446 .remove(&context.session.address);
447 }
448 false
449 } else {
450 true
451 };
452
453 if context.session.ty.is_outbound() {
454 if renew {
459 self.network_state.with_peer_store_mut(|peer_store| {
460 peer_store.add_outbound_addr(context.session.address.clone(), flags);
461 });
462 }
463
464 if self.network_state.with_peer_registry_mut(|reg| {
465 reg.change_feeler_flags(&context.session.address, flags)
466 }) {
467 let _ = context
468 .open_protocols(
469 context.session.id,
470 TargetProtocol::Single(SupportProtocols::Feeler.protocol_id()),
471 )
472 .await;
473 } else if required_flags_filter(required_flags, flags) {
474 let _ = context
476 .open_protocols(
477 context.session.id,
478 TargetProtocol::Filter(Box::new(move |id| {
479 if ckb2023 {
480 id != &SupportProtocols::Feeler.protocol_id()
481 && id != &SupportProtocols::RelayV2.protocol_id()
482 } else {
483 id != &SupportProtocols::Feeler.protocol_id()
484 }
485 })),
486 )
487 .await;
488 } else {
489 warn!(
491 "Session closed from IdentifyProtocol due to peer's flag not meeting the requirements"
492 );
493 return MisbehaveResult::Disconnect;
494 }
495 }
496 MisbehaveResult::Continue
497 }
498 }
499 }
500
501 fn local_listen_addrs(&mut self) -> Vec<Multiaddr> {
503 self.listen_addrs()
504 }
505
506 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>) {
507 trace!(
508 "IdentifyProtocol add remote listening addresses, session: {:?}, addresses : {:?}",
509 session, addrs,
510 );
511 let flags = self.network_state.with_peer_registry_mut(|reg| {
512 if let Some(peer) = reg.get_peer_mut(session.id) {
513 peer.listened_addrs = addrs.clone();
514 peer.identify_info
515 .as_ref()
516 .map(|a| a.flags)
517 .unwrap_or(Flags::COMPATIBILITY)
518 } else {
519 Flags::COMPATIBILITY
520 }
521 });
522 self.network_state.with_peer_store_mut(|peer_store| {
523 for addr in addrs {
524 if let Err(err) = peer_store.add_addr(addr.clone(), flags) {
525 error!("IdentifyProtocol failed to add address to peer store, address: {}, error: {:?}", addr, err);
526 }
527 }
528 })
529 }
530
531 fn add_observed_addr(&mut self, mut addr: Multiaddr, ty: SessionType) -> MisbehaveResult {
532 if ty.is_inbound() {
533 return MisbehaveResult::Continue;
535 }
536
537 if !multiaddr_to_socketaddr(&addr)
539 .map(|socket_addr| is_reachable(socket_addr.ip()))
540 .unwrap_or(false)
541 {
542 return MisbehaveResult::Continue;
543 }
544
545 if extract_peer_id(&addr).is_none() {
546 addr.push(Protocol::P2P(Cow::Borrowed(
547 self.network_state.local_peer_id().as_bytes(),
548 )))
549 }
550
551 let source_addr = addr.clone();
552 let observed_addrs_iter = self
553 .listen_addrs()
554 .into_iter()
555 .filter_map(|listen_addr| multiaddr_to_socketaddr(&listen_addr))
556 .map(|socket_addr| {
557 addr.iter()
558 .map(|proto| match proto {
559 Protocol::Tcp(_) => Protocol::Tcp(socket_addr.port()),
560 value => value,
561 })
562 .collect::<Multiaddr>()
563 })
564 .chain(::std::iter::once(source_addr));
565
566 self.network_state.add_observed_addrs(observed_addrs_iter);
567 MisbehaveResult::Continue
569 }
570
571 fn misbehave(&mut self, session: &SessionContext, reason: Misbehavior) -> MisbehaveResult {
572 error!(
573 "IdentifyProtocol detects abnormal behavior, session: {:?}, reason: {:?}",
574 session, reason
575 );
576 MisbehaveResult::Disconnect
577 }
578}
579
580#[derive(Clone)]
581struct Identify {
582 name: String,
583 encode_data: ckb_types::bytes::Bytes,
584}
585
586impl Identify {
587 fn new(name: String, flags: Flags, client_version: String) -> Self {
588 Identify {
589 encode_data: packed::Identify::new_builder()
590 .name(name.as_str().pack())
591 .flag(flags.bits().pack())
592 .client_version(client_version.as_str().pack())
593 .build()
594 .as_bytes(),
595 name,
596 }
597 }
598
599 fn encode(&mut self) -> &[u8] {
600 &self.encode_data
601 }
602
603 fn verify(&self, data: &[u8]) -> Option<(Flags, String)> {
604 let reader = packed::IdentifyReader::from_slice(data).ok()?;
605
606 let name = reader.name().as_utf8().ok()?.to_owned();
607 if self.name != name {
608 warn!(
609 "IdentifyProtocol detects peer has different network identifiers, local network id: {}, remote network id: {}",
610 self.name, name,
611 );
612 return None;
613 }
614
615 let flag: u64 = reader.flag().unpack();
616 if flag == 0 {
617 return None;
618 }
619
620 let raw_client_version = reader.client_version().as_utf8().ok()?.to_owned();
621
622 Some((Flags::from_bits_truncate(flag), raw_client_version))
623 }
624}
625
626bitflags::bitflags! {
627 pub struct Flags: u64 {
629 const COMPATIBILITY = 0b1;
631 const DISCOVERY = 0b10;
633 const SYNC = 0b100;
635 const RELAY = 0b1000;
637 const LIGHT_CLIENT = 0b10000;
639 const BLOCK_FILTER = 0b100000;
641 }
642}