1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use ckb_logger::{debug, error, trace, warn};
6use ckb_systemtime::{Duration, Instant};
7use p2p::{
8 SessionId, async_trait,
9 bytes::Bytes,
10 context::{ProtocolContext, ProtocolContextMutRef, SessionContext},
11 multiaddr::{Multiaddr, Protocol},
12 service::TargetProtocol,
13 traits::ServiceProtocol,
14 utils::{extract_peer_id, is_reachable, multiaddr_to_socketaddr},
15};
16
17mod protocol;
18
19use crate::{NetworkState, PeerIdentifyInfo, SupportProtocols, peer_store::required_flags_filter};
20use ckb_types::{packed, prelude::*};
21
22use protocol::IdentifyMessage;
23
// --- Address limits -------------------------------------------------------

/// Upper bound on the number of local addresses `IdentifyCallback` gathers
/// (public addresses first, then observed addresses) for advertisement.
const MAX_RETURN_LISTEN_ADDRS: usize = 10;
/// Upper bound on listen addresses per identify message: more than this from a
/// remote peer is reported as misbehavior, and at most this many are sent out.
const MAX_ADDRS: usize = 10;

// --- Timing ---------------------------------------------------------------

/// Ban duration applied when a peer's identify payload fails verification.
const BAN_ON_NOT_SAME_NET: Duration = Duration::from_secs(300);
/// Token registered with the service notify that drives the timeout check.
const CHECK_TIMEOUT_TOKEN: u64 = 100;
/// Interval, in seconds, between two timeout checks.
const CHECK_TIMEOUT_INTERVAL: u64 = 1;
/// Time, in seconds, a session is given to deliver its identify message.
const DEFAULT_TIMEOUT: u64 = 8;
31
/// Kinds of protocol violations the identify protocol reports to
/// [`Callback::misbehave`].
// NOTE(review): the lint allowance is presumably needed because the payload of
// `TooManyAddresses` is only ever read through `Debug` — confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Misbehavior {
    /// A session delivered an identify message after one was already received.
    DuplicateReceived,
    /// A session did not deliver an identify message before its timeout elapsed.
    Timeout,
    /// The received bytes could not be decoded as an identify message.
    InvalidData,
    /// The peer advertised more listen addresses than allowed; carries the
    /// number of addresses it sent.
    TooManyAddresses(usize),
}
45
/// Verdict returned by callback hooks: what to do with the session that
/// triggered the hook.
pub enum MisbehaveResult {
    /// Keep the session open.
    Continue,
    /// Close the session.
    Disconnect,
}

impl MisbehaveResult {
    /// Returns `true` only for [`MisbehaveResult::Disconnect`].
    pub fn is_disconnect(&self) -> bool {
        match self {
            Self::Disconnect => true,
            Self::Continue => false,
        }
    }
}
59
60#[async_trait]
62pub trait Callback: Clone + Send {
63 fn register(&self, context: &ProtocolContextMutRef, version: &str) -> bool;
65 fn unregister(&self, context: &ProtocolContextMutRef);
67 async fn received_identify(
69 &mut self,
70 context: &mut ProtocolContextMutRef<'_>,
71 identify: &[u8],
72 ) -> MisbehaveResult;
73 fn identify(&mut self) -> &[u8];
75 fn local_listen_addrs(&mut self) -> Vec<Multiaddr>;
77 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>);
79 fn add_observed_addr(&mut self, addr: Multiaddr, session_id: SessionId) -> MisbehaveResult;
81 fn misbehave(&mut self, session: &SessionContext, kind: Misbehavior) -> MisbehaveResult;
83}
84
85pub struct IdentifyProtocol<T> {
87 callback: T,
88 remote_infos: HashMap<SessionId, RemoteInfo>,
89 global_ip_only: bool,
90}
91
92impl<T: Callback> IdentifyProtocol<T> {
93 pub fn new(callback: T) -> IdentifyProtocol<T> {
94 IdentifyProtocol {
95 callback,
96 remote_infos: HashMap::default(),
97 global_ip_only: true,
98 }
99 }
100
101 #[cfg(test)]
102 pub fn global_ip_only(mut self, only: bool) -> Self {
103 self.global_ip_only = only;
104 self
105 }
106
107 fn check_duplicate(&mut self, context: &mut ProtocolContextMutRef) -> MisbehaveResult {
108 let session = context.session;
109 let info = self
110 .remote_infos
111 .get_mut(&session.id)
112 .expect("RemoteInfo must exists");
113
114 if info.has_received {
115 self.callback
116 .misbehave(&info.session, Misbehavior::DuplicateReceived)
117 } else {
118 info.has_received = true;
119 MisbehaveResult::Continue
120 }
121 }
122
123 fn process_listens(
124 &mut self,
125 context: &mut ProtocolContextMutRef,
126 listens: Vec<Multiaddr>,
127 ) -> MisbehaveResult {
128 let session = context.session;
129 let info = self
130 .remote_infos
131 .get_mut(&session.id)
132 .expect("RemoteInfo must exists");
133
134 if listens.len() > MAX_ADDRS {
135 self.callback
136 .misbehave(&info.session, Misbehavior::TooManyAddresses(listens.len()))
137 } else {
138 let global_ip_only = self.global_ip_only;
139 let reachable_addrs = listens
140 .into_iter()
141 .filter(|addr| match multiaddr_to_socketaddr(addr) {
142 Some(socket_addr) => !global_ip_only || is_reachable(socket_addr.ip()),
143 None => true,
144 })
145 .collect::<Vec<_>>();
146 self.callback
147 .add_remote_listen_addrs(session, reachable_addrs);
148 MisbehaveResult::Continue
149 }
150 }
151
152 fn process_observed(
153 &mut self,
154 context: &mut ProtocolContextMutRef,
155 observed: Multiaddr,
156 ) -> MisbehaveResult {
157 debug!(
158 "IdentifyProtocol process observed address, session: {:?}, observed: {}",
159 context.session, observed,
160 );
161
162 let session = context.session;
163 let info = self
164 .remote_infos
165 .get_mut(&session.id)
166 .expect("RemoteInfo must exists");
167 self.callback.add_observed_addr(observed, info.session.id);
168 MisbehaveResult::Continue
169 }
170}
171
172pub(crate) struct RemoteInfo {
173 session: SessionContext,
174 connected_at: Instant,
175 timeout: Duration,
176 has_received: bool,
177}
178
179impl RemoteInfo {
180 fn new(session: SessionContext, timeout: Duration) -> RemoteInfo {
181 RemoteInfo {
182 session,
183 connected_at: Instant::now(),
184 timeout,
185 has_received: false,
186 }
187 }
188}
189
190#[async_trait]
191impl<T: Callback> ServiceProtocol for IdentifyProtocol<T> {
192 async fn init(&mut self, context: &mut ProtocolContext) {
193 let proto_id = context.proto_id;
194 if let Err(err) = context
195 .set_service_notify(
196 proto_id,
197 Duration::from_secs(CHECK_TIMEOUT_INTERVAL),
198 CHECK_TIMEOUT_TOKEN,
199 )
200 .await
201 {
202 error!("IdentifyProtocol init error: {:?}", err)
203 }
204 }
205
206 async fn connected(&mut self, context: ProtocolContextMutRef<'_>, version: &str) {
207 let session = context.session;
208 debug!("IdentifyProtocol connected, session: {:?}", session);
209 let remote_info = RemoteInfo::new(session.clone(), Duration::from_secs(DEFAULT_TIMEOUT));
210 self.remote_infos.insert(session.id, remote_info);
211 let listen_addrs = if self.callback.register(&context, version) {
212 Vec::new()
213 } else {
214 self.callback
215 .local_listen_addrs()
216 .iter()
217 .filter(|addr| {
218 if let Some(socket_addr) = multiaddr_to_socketaddr(addr) {
219 !self.global_ip_only || is_reachable(socket_addr.ip())
220 } else {
221 addr.iter()
223 .any(|protocol| matches!(protocol, Protocol::Onion3(_)))
224 }
225 })
226 .take(MAX_ADDRS)
227 .cloned()
228 .collect()
229 };
230
231 let identify = self.callback.identify();
232 let data = IdentifyMessage::new(listen_addrs, session.address.clone(), identify).encode();
233 let _ = context
234 .quick_send_message(data)
235 .await
236 .map_err(|err| error!("IdentifyProtocol quick_send_message, error: {:?}", err));
237 }
238
239 async fn disconnected(&mut self, context: ProtocolContextMutRef<'_>) {
240 self.remote_infos
241 .remove(&context.session.id)
242 .expect("RemoteInfo must exists");
243 debug!(
244 "IdentifyProtocol disconnected, session: {:?}",
245 context.session
246 );
247 self.callback.unregister(&context);
248 }
249
250 async fn received(&mut self, mut context: ProtocolContextMutRef<'_>, data: Bytes) {
251 let session = context.session;
252 match IdentifyMessage::decode(&data) {
253 Some(message) => {
254 trace!(
255 "IdentifyProtocol received, session: {:?}, listen_addrs: {:?}, observed_addr: {}",
256 context.session, message.listen_addrs, message.observed_addr
257 );
258
259 if let MisbehaveResult::Disconnect = self.check_duplicate(&mut context) {
261 error!(
262 "Disconnect IdentifyProtocol session {:?} due to duplication.",
263 session
264 );
265 let _ = context.disconnect(session.id).await;
266 return;
267 }
268 if let MisbehaveResult::Disconnect = self
269 .callback
270 .received_identify(&mut context, message.identify)
271 .await
272 {
273 error!(
274 "Disconnect IdentifyProtocol session {:?} due to invalid identify message.",
275 session,
276 );
277 let _ = context.disconnect(session.id).await;
278 return;
279 }
280 if let MisbehaveResult::Disconnect =
281 self.process_listens(&mut context, message.listen_addrs.clone())
282 {
283 error!(
284 "Disconnect IdentifyProtocol session {:?} due to invalid listen addrs: {:?}.",
285 session, message.listen_addrs,
286 );
287 let _ = context.disconnect(session.id).await;
288 return;
289 }
290 if let MisbehaveResult::Disconnect =
291 self.process_observed(&mut context, message.observed_addr.clone())
292 {
293 error!(
294 "Disconnect IdentifyProtocol session {:?} due to invalid observed addr: {}.",
295 session, message.observed_addr,
296 );
297 let _ = context.disconnect(session.id).await;
298 }
299 }
300 None => {
301 let info = self
302 .remote_infos
303 .get(&session.id)
304 .expect("RemoteInfo must exists");
305 if self
306 .callback
307 .misbehave(&info.session, Misbehavior::InvalidData)
308 .is_disconnect()
309 {
310 let _ = context.disconnect(session.id).await;
311 }
312 }
313 }
314 }
315
316 async fn notify(&mut self, context: &mut ProtocolContext, _token: u64) {
317 for (session_id, info) in &self.remote_infos {
318 if !info.has_received && (info.connected_at + info.timeout) <= Instant::now() {
319 let misbehave_result = self.callback.misbehave(&info.session, Misbehavior::Timeout);
320 if misbehave_result.is_disconnect() {
321 let _ = context.disconnect(*session_id).await;
322 }
323 }
324 }
325 }
326}
327
328#[derive(Clone)]
329pub struct IdentifyCallback {
330 network_state: Arc<NetworkState>,
331 identify: Identify,
332}
333
334impl IdentifyCallback {
335 pub(crate) fn new(
336 network_state: Arc<NetworkState>,
337 name: String,
338 client_version: String,
339 flags: Flags,
340 ) -> IdentifyCallback {
341 IdentifyCallback {
342 network_state,
343 identify: Identify::new(name, flags, client_version),
344 }
345 }
346
347 fn listen_addrs(&self) -> Vec<Multiaddr> {
348 let addrs = self.network_state.public_addrs(MAX_RETURN_LISTEN_ADDRS * 2);
349 addrs
350 .into_iter()
351 .take(MAX_RETURN_LISTEN_ADDRS)
352 .collect::<Vec<_>>()
353 }
354}
355
356#[async_trait]
357impl Callback for IdentifyCallback {
358 fn register(&self, context: &ProtocolContextMutRef, version: &str) -> bool {
359 let session_id = context.session.id;
360 self.network_state.with_peer_registry_mut(|reg| {
361 if let Some(peer) = reg.get_peer_mut(session_id) {
362 peer.protocols.insert(context.proto_id, version.to_owned());
363 }
364 reg.is_anchor(session_id)
365 })
366 }
367
368 fn unregister(&self, context: &ProtocolContextMutRef) {
369 if context.session.ty.is_outbound() {
370 self.network_state.with_peer_store_mut(|peer_store| {
375 peer_store.update_outbound_addr_last_connected_ms(context.session.address.clone());
376 });
377 }
378 }
379
380 fn identify(&mut self) -> &[u8] {
381 self.identify.encode()
382 }
383
384 async fn received_identify(
385 &mut self,
386 context: &mut ProtocolContextMutRef<'_>,
387 identify: &[u8],
388 ) -> MisbehaveResult {
389 match self.identify.verify(identify) {
390 None => {
391 self.network_state.ban_session(
392 &context.control().clone().into(),
393 context.session.id,
394 BAN_ON_NOT_SAME_NET,
395 "The nodes are not on the same network".to_string(),
396 );
397 MisbehaveResult::Disconnect
398 }
399 Some((flags, client_version)) => {
400 let registry_client_version = |version: String| {
401 self.network_state.with_peer_registry_mut(|registry| {
402 if let Some(peer) = registry.get_peer_mut(context.session.id) {
403 peer.identify_info = Some(PeerIdentifyInfo {
404 client_version: version,
405 flags,
406 })
407 }
408 });
409 };
410
411 registry_client_version(client_version);
412
413 let required_flags = self.network_state.required_flags;
414
415 if context.session.ty.is_outbound() {
416 self.network_state.with_peer_store_mut(|peer_store| {
422 peer_store.add_outbound_addr(context.session.address.clone(), flags);
423 });
424
425 if self.network_state.with_peer_registry_mut(|reg| {
426 reg.change_feeler_flags(&context.session.address, flags)
427 }) {
428 let _ = context
429 .open_protocols(
430 context.session.id,
431 TargetProtocol::Single(SupportProtocols::Feeler.protocol_id()),
432 )
433 .await;
434 } else if required_flags_filter(required_flags, flags) {
435 let _ = context
437 .open_protocols(
438 context.session.id,
439 TargetProtocol::Filter(Box::new(move |id| {
440 id != &SupportProtocols::Feeler.protocol_id()
441 })),
442 )
443 .await;
444 } else {
445 warn!(
447 "Session closed from IdentifyProtocol due to peer's flag not meeting the requirements"
448 );
449 return MisbehaveResult::Disconnect;
450 }
451 }
452 MisbehaveResult::Continue
453 }
454 }
455 }
456
457 fn local_listen_addrs(&mut self) -> Vec<Multiaddr> {
459 let mut listens = self.listen_addrs();
460
461 if listens.len() < MAX_RETURN_LISTEN_ADDRS {
462 let observe_addrs = self
463 .network_state
464 .observed_addrs(MAX_RETURN_LISTEN_ADDRS - listens.len());
465 listens.extend(observe_addrs);
466 listens
467 } else {
468 listens
469 }
470 }
471
472 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>) {
473 trace!(
474 "IdentifyProtocol add remote listening addresses, session: {:?}, addresses : {:?}",
475 session, addrs,
476 );
477 let flags = self.network_state.with_peer_registry_mut(|reg| {
478 if let Some(peer) = reg.get_peer_mut(session.id) {
479 peer.listened_addrs = addrs.clone();
480 peer.identify_info
481 .as_ref()
482 .map(|a| a.flags)
483 .unwrap_or(Flags::COMPATIBILITY)
484 } else {
485 Flags::COMPATIBILITY
486 }
487 });
488 self.network_state.with_peer_store_mut(|peer_store| {
489 for addr in addrs {
490 if let Err(err) = peer_store.add_addr(addr.clone(), flags) {
491 error!("IdentifyProtocol failed to add address to peer store, address: {}, error: {:?}", addr, err);
492 }
493 }
494 })
495 }
496
497 fn add_observed_addr(&mut self, mut addr: Multiaddr, session_id: SessionId) -> MisbehaveResult {
498 if extract_peer_id(&addr).is_none() {
499 addr.push(Protocol::P2P(Cow::Borrowed(
500 self.network_state.local_peer_id().as_bytes(),
501 )))
502 }
503
504 self.network_state.add_observed_addr(session_id, addr);
505 MisbehaveResult::Continue
507 }
508
509 fn misbehave(&mut self, session: &SessionContext, reason: Misbehavior) -> MisbehaveResult {
510 error!(
511 "IdentifyProtocol detects abnormal behavior, session: {:?}, reason: {:?}",
512 session, reason
513 );
514 MisbehaveResult::Disconnect
515 }
516}
517
518#[derive(Clone)]
519struct Identify {
520 name: String,
521 encode_data: ckb_types::bytes::Bytes,
522}
523
524impl Identify {
525 fn new(name: String, flags: Flags, client_version: String) -> Self {
526 Identify {
527 encode_data: packed::Identify::new_builder()
528 .name(name.as_str())
529 .flag(flags.bits())
530 .client_version(client_version.as_str())
531 .build()
532 .as_bytes(),
533 name,
534 }
535 }
536
537 fn encode(&mut self) -> &[u8] {
538 &self.encode_data
539 }
540
541 fn verify(&self, data: &[u8]) -> Option<(Flags, String)> {
542 let reader = packed::IdentifyReader::from_slice(data).ok()?;
543
544 let name = reader.name().as_utf8().ok()?.to_owned();
545 if self.name != name {
546 warn!(
547 "IdentifyProtocol detects peer has different network identifiers, local network id: {}, remote network id: {}",
548 self.name, name,
549 );
550 return None;
551 }
552
553 let flag: u64 = reader.flag().into();
554 if flag == 0 {
555 return None;
556 }
557
558 let raw_client_version = reader.client_version().as_utf8().ok()?.to_owned();
559
560 Some((Flags::from_bits_truncate(flag), raw_client_version))
561 }
562}
563
564bitflags::bitflags! {
565 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
567 pub struct Flags: u64 {
568 const COMPATIBILITY = 0b1;
570 const DISCOVERY = 0b10;
572 const SYNC = 0b100;
574 const RELAY = 0b1000;
576 const LIGHT_CLIENT = 0b10000;
578 const BLOCK_FILTER = 0b100000;
580 }
581}