1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use ckb_logger::{debug, error, trace, warn};
6use ckb_systemtime::{Duration, Instant};
7use p2p::{
8 SessionId, async_trait,
9 bytes::Bytes,
10 context::{ProtocolContext, ProtocolContextMutRef, SessionContext},
11 multiaddr::{Multiaddr, Protocol},
12 service::TargetProtocol,
13 traits::ServiceProtocol,
14 utils::{extract_peer_id, is_reachable, multiaddr_to_socketaddr},
15};
16
17mod protocol;
18
19use crate::{NetworkState, PeerIdentifyInfo, SupportProtocols, peer_store::required_flags_filter};
20use ckb_types::{packed, prelude::*};
21
22use protocol::IdentifyMessage;
23
/// Most of our own listen addresses that the callback hands out for one identify message.
const MAX_RETURN_LISTEN_ADDRS: usize = 10;
/// Ban length applied to a peer whose identify payload fails verification.
const BAN_ON_NOT_SAME_NET: Duration = Duration::from_secs(300);
/// Token registered with `set_service_notify` for the periodic identify-timeout sweep.
const CHECK_TIMEOUT_TOKEN: u64 = 100;
/// Seconds between two identify-timeout sweeps.
const CHECK_TIMEOUT_INTERVAL: u64 = 1;
/// Seconds a freshly connected peer has to deliver its identify message.
const DEFAULT_TIMEOUT: u64 = 8;
/// Cap on listen addresses exchanged in either direction (sent and accepted).
const MAX_ADDRS: usize = 10;
31
/// Protocol violations that the identify protocol reports through `Callback::misbehave`.
// The `TooManyAddresses` payload is only ever read by the derived `Debug` impl (when the
// misbehavior is logged), and the dead-code lint does not count derived-impl reads as uses.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Misbehavior {
    /// A second identify message arrived on a session that already sent one.
    DuplicateReceived,
    /// No identify message arrived before the session's timeout elapsed.
    Timeout,
    /// The received identify message could not be decoded.
    InvalidData,
    /// The peer advertised more listen addresses than allowed; holds the count it sent.
    TooManyAddresses(usize),
}
45
/// Verdict returned by `Callback` hooks: keep the session alive or drop it.
///
/// Derives `Debug`/`PartialEq` so verdicts can be logged and compared, and `Copy` because it
/// is a fieldless enum that is cheap to pass around by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MisbehaveResult {
    /// Keep the session open.
    Continue,
    /// Disconnect the session.
    Disconnect,
}

impl MisbehaveResult {
    /// Returns `true` when the verdict is [`MisbehaveResult::Disconnect`].
    pub fn is_disconnect(&self) -> bool {
        matches!(self, MisbehaveResult::Disconnect)
    }
}
59
60#[async_trait]
62pub trait Callback: Clone + Send {
63 fn register(&self, context: &ProtocolContextMutRef, version: &str) -> bool;
65 fn unregister(&self, context: &ProtocolContextMutRef);
67 async fn received_identify(
69 &mut self,
70 context: &mut ProtocolContextMutRef<'_>,
71 identify: &[u8],
72 ) -> MisbehaveResult;
73 fn identify(&mut self) -> &[u8];
75 fn local_listen_addrs(&mut self) -> Vec<Multiaddr>;
77 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>);
79 fn add_observed_addr(&mut self, addr: Multiaddr, session_id: SessionId) -> MisbehaveResult;
81 fn misbehave(&mut self, session: &SessionContext, kind: Misbehavior) -> MisbehaveResult;
83}
84
85pub struct IdentifyProtocol<T> {
87 callback: T,
88 remote_infos: HashMap<SessionId, RemoteInfo>,
89 global_ip_only: bool,
90}
91
92impl<T: Callback> IdentifyProtocol<T> {
93 pub fn new(callback: T) -> IdentifyProtocol<T> {
94 IdentifyProtocol {
95 callback,
96 remote_infos: HashMap::default(),
97 global_ip_only: true,
98 }
99 }
100
101 #[cfg(test)]
102 pub fn global_ip_only(mut self, only: bool) -> Self {
103 self.global_ip_only = only;
104 self
105 }
106
107 fn check_duplicate(&mut self, context: &mut ProtocolContextMutRef) -> MisbehaveResult {
108 let session = context.session;
109 let info = self
110 .remote_infos
111 .get_mut(&session.id)
112 .expect("RemoteInfo must exists");
113
114 if info.has_received {
115 self.callback
116 .misbehave(&info.session, Misbehavior::DuplicateReceived)
117 } else {
118 info.has_received = true;
119 MisbehaveResult::Continue
120 }
121 }
122
123 fn process_listens(
124 &mut self,
125 context: &mut ProtocolContextMutRef,
126 listens: Vec<Multiaddr>,
127 ) -> MisbehaveResult {
128 let session = context.session;
129 let info = self
130 .remote_infos
131 .get_mut(&session.id)
132 .expect("RemoteInfo must exists");
133
134 if listens.len() > MAX_ADDRS {
135 self.callback
136 .misbehave(&info.session, Misbehavior::TooManyAddresses(listens.len()))
137 } else {
138 let global_ip_only = self.global_ip_only;
139 let reachable_addrs = listens
140 .into_iter()
141 .filter(|addr| match multiaddr_to_socketaddr(addr) {
142 Some(socket_addr) => !global_ip_only || is_reachable(socket_addr.ip()),
143 None => true,
144 })
145 .collect::<Vec<_>>();
146 self.callback
147 .add_remote_listen_addrs(session, reachable_addrs);
148 MisbehaveResult::Continue
149 }
150 }
151
152 fn process_observed(
153 &mut self,
154 context: &mut ProtocolContextMutRef,
155 observed: Multiaddr,
156 ) -> MisbehaveResult {
157 debug!(
158 "IdentifyProtocol process observed address, session: {:?}, observed: {}",
159 context.session, observed,
160 );
161
162 let session = context.session;
163 let info = self
164 .remote_infos
165 .get_mut(&session.id)
166 .expect("RemoteInfo must exists");
167 self.callback.add_observed_addr(observed, info.session.id);
168 MisbehaveResult::Continue
169 }
170}
171
172pub(crate) struct RemoteInfo {
173 session: SessionContext,
174 connected_at: Instant,
175 timeout: Duration,
176 has_received: bool,
177}
178
179impl RemoteInfo {
180 fn new(session: SessionContext, timeout: Duration) -> RemoteInfo {
181 RemoteInfo {
182 session,
183 connected_at: Instant::now(),
184 timeout,
185 has_received: false,
186 }
187 }
188}
189
190#[async_trait]
191impl<T: Callback> ServiceProtocol for IdentifyProtocol<T> {
192 async fn init(&mut self, context: &mut ProtocolContext) {
193 let proto_id = context.proto_id;
194 if let Err(err) = context
195 .set_service_notify(
196 proto_id,
197 Duration::from_secs(CHECK_TIMEOUT_INTERVAL),
198 CHECK_TIMEOUT_TOKEN,
199 )
200 .await
201 {
202 error!("IdentifyProtocol init error: {:?}", err)
203 }
204 }
205
206 async fn connected(&mut self, context: ProtocolContextMutRef<'_>, version: &str) {
207 let session = context.session;
208 debug!("IdentifyProtocol connected, session: {:?}", session);
209 let remote_info = RemoteInfo::new(session.clone(), Duration::from_secs(DEFAULT_TIMEOUT));
210 self.remote_infos.insert(session.id, remote_info);
211 let listen_addrs = if self.callback.register(&context, version) {
212 Vec::new()
213 } else {
214 self.callback
215 .local_listen_addrs()
216 .iter()
217 .filter(|addr| {
218 multiaddr_to_socketaddr(addr)
219 .map(|socket_addr| !self.global_ip_only || is_reachable(socket_addr.ip()))
220 .unwrap_or(false)
221 })
222 .take(MAX_ADDRS)
223 .cloned()
224 .collect()
225 };
226
227 let identify = self.callback.identify();
228 let data = IdentifyMessage::new(listen_addrs, session.address.clone(), identify).encode();
229 let _ = context
230 .quick_send_message(data)
231 .await
232 .map_err(|err| error!("IdentifyProtocol quick_send_message, error: {:?}", err));
233 }
234
235 async fn disconnected(&mut self, context: ProtocolContextMutRef<'_>) {
236 self.remote_infos
237 .remove(&context.session.id)
238 .expect("RemoteInfo must exists");
239 debug!(
240 "IdentifyProtocol disconnected, session: {:?}",
241 context.session
242 );
243 self.callback.unregister(&context);
244 }
245
246 async fn received(&mut self, mut context: ProtocolContextMutRef<'_>, data: Bytes) {
247 let session = context.session;
248 match IdentifyMessage::decode(&data) {
249 Some(message) => {
250 trace!(
251 "IdentifyProtocol received, session: {:?}, listen_addrs: {:?}, observed_addr: {}",
252 context.session, message.listen_addrs, message.observed_addr
253 );
254
255 if let MisbehaveResult::Disconnect = self.check_duplicate(&mut context) {
257 error!(
258 "Disconnect IdentifyProtocol session {:?} due to duplication.",
259 session
260 );
261 let _ = context.disconnect(session.id).await;
262 return;
263 }
264 if let MisbehaveResult::Disconnect = self
265 .callback
266 .received_identify(&mut context, message.identify)
267 .await
268 {
269 error!(
270 "Disconnect IdentifyProtocol session {:?} due to invalid identify message.",
271 session,
272 );
273 let _ = context.disconnect(session.id).await;
274 return;
275 }
276 if let MisbehaveResult::Disconnect =
277 self.process_listens(&mut context, message.listen_addrs.clone())
278 {
279 error!(
280 "Disconnect IdentifyProtocol session {:?} due to invalid listen addrs: {:?}.",
281 session, message.listen_addrs,
282 );
283 let _ = context.disconnect(session.id).await;
284 return;
285 }
286 if let MisbehaveResult::Disconnect =
287 self.process_observed(&mut context, message.observed_addr.clone())
288 {
289 error!(
290 "Disconnect IdentifyProtocol session {:?} due to invalid observed addr: {}.",
291 session, message.observed_addr,
292 );
293 let _ = context.disconnect(session.id).await;
294 }
295 }
296 None => {
297 let info = self
298 .remote_infos
299 .get(&session.id)
300 .expect("RemoteInfo must exists");
301 if self
302 .callback
303 .misbehave(&info.session, Misbehavior::InvalidData)
304 .is_disconnect()
305 {
306 let _ = context.disconnect(session.id).await;
307 }
308 }
309 }
310 }
311
312 async fn notify(&mut self, context: &mut ProtocolContext, _token: u64) {
313 for (session_id, info) in &self.remote_infos {
314 if !info.has_received && (info.connected_at + info.timeout) <= Instant::now() {
315 let misbehave_result = self.callback.misbehave(&info.session, Misbehavior::Timeout);
316 if misbehave_result.is_disconnect() {
317 let _ = context.disconnect(*session_id).await;
318 }
319 }
320 }
321 }
322}
323
324#[derive(Clone)]
325pub struct IdentifyCallback {
326 network_state: Arc<NetworkState>,
327 identify: Identify,
328}
329
330impl IdentifyCallback {
331 pub(crate) fn new(
332 network_state: Arc<NetworkState>,
333 name: String,
334 client_version: String,
335 flags: Flags,
336 ) -> IdentifyCallback {
337 IdentifyCallback {
338 network_state,
339 identify: Identify::new(name, flags, client_version),
340 }
341 }
342
343 fn listen_addrs(&self) -> Vec<Multiaddr> {
344 let addrs = self.network_state.public_addrs(MAX_RETURN_LISTEN_ADDRS * 2);
345 addrs
346 .into_iter()
347 .take(MAX_RETURN_LISTEN_ADDRS)
348 .collect::<Vec<_>>()
349 }
350}
351
352#[async_trait]
353impl Callback for IdentifyCallback {
354 fn register(&self, context: &ProtocolContextMutRef, version: &str) -> bool {
355 let session_id = context.session.id;
356 self.network_state.with_peer_registry_mut(|reg| {
357 if let Some(peer) = reg.get_peer_mut(session_id) {
358 peer.protocols.insert(context.proto_id, version.to_owned());
359 }
360 reg.is_anchor(session_id)
361 })
362 }
363
364 fn unregister(&self, context: &ProtocolContextMutRef) {
365 if context.session.ty.is_outbound() {
366 self.network_state.with_peer_store_mut(|peer_store| {
371 peer_store.update_outbound_addr_last_connected_ms(context.session.address.clone());
372 });
373 }
374 }
375
376 fn identify(&mut self) -> &[u8] {
377 self.identify.encode()
378 }
379
380 async fn received_identify(
381 &mut self,
382 context: &mut ProtocolContextMutRef<'_>,
383 identify: &[u8],
384 ) -> MisbehaveResult {
385 match self.identify.verify(identify) {
386 None => {
387 self.network_state.ban_session(
388 &context.control().clone().into(),
389 context.session.id,
390 BAN_ON_NOT_SAME_NET,
391 "The nodes are not on the same network".to_string(),
392 );
393 MisbehaveResult::Disconnect
394 }
395 Some((flags, client_version)) => {
396 let registry_client_version = |version: String| {
397 self.network_state.with_peer_registry_mut(|registry| {
398 if let Some(peer) = registry.get_peer_mut(context.session.id) {
399 peer.identify_info = Some(PeerIdentifyInfo {
400 client_version: version,
401 flags,
402 })
403 }
404 });
405 };
406
407 registry_client_version(client_version);
408
409 let required_flags = self.network_state.required_flags;
410
411 if context.session.ty.is_outbound() {
412 self.network_state.with_peer_store_mut(|peer_store| {
418 peer_store.add_outbound_addr(context.session.address.clone(), flags);
419 });
420
421 if self.network_state.with_peer_registry_mut(|reg| {
422 reg.change_feeler_flags(&context.session.address, flags)
423 }) {
424 let _ = context
425 .open_protocols(
426 context.session.id,
427 TargetProtocol::Single(SupportProtocols::Feeler.protocol_id()),
428 )
429 .await;
430 } else if required_flags_filter(required_flags, flags) {
431 let _ = context
433 .open_protocols(
434 context.session.id,
435 TargetProtocol::Filter(Box::new(move |id| {
436 id != &SupportProtocols::Feeler.protocol_id()
437 })),
438 )
439 .await;
440 } else {
441 warn!(
443 "Session closed from IdentifyProtocol due to peer's flag not meeting the requirements"
444 );
445 return MisbehaveResult::Disconnect;
446 }
447 }
448 MisbehaveResult::Continue
449 }
450 }
451 }
452
453 fn local_listen_addrs(&mut self) -> Vec<Multiaddr> {
455 let mut listens = self.listen_addrs();
456
457 if listens.len() < MAX_RETURN_LISTEN_ADDRS {
458 let observe_addrs = self
459 .network_state
460 .observed_addrs(MAX_RETURN_LISTEN_ADDRS - listens.len());
461 listens.extend(observe_addrs);
462 listens
463 } else {
464 listens
465 }
466 }
467
468 fn add_remote_listen_addrs(&mut self, session: &SessionContext, addrs: Vec<Multiaddr>) {
469 trace!(
470 "IdentifyProtocol add remote listening addresses, session: {:?}, addresses : {:?}",
471 session, addrs,
472 );
473 let flags = self.network_state.with_peer_registry_mut(|reg| {
474 if let Some(peer) = reg.get_peer_mut(session.id) {
475 peer.listened_addrs = addrs.clone();
476 peer.identify_info
477 .as_ref()
478 .map(|a| a.flags)
479 .unwrap_or(Flags::COMPATIBILITY)
480 } else {
481 Flags::COMPATIBILITY
482 }
483 });
484 self.network_state.with_peer_store_mut(|peer_store| {
485 for addr in addrs {
486 if let Err(err) = peer_store.add_addr(addr.clone(), flags) {
487 error!("IdentifyProtocol failed to add address to peer store, address: {}, error: {:?}", addr, err);
488 }
489 }
490 })
491 }
492
493 fn add_observed_addr(&mut self, mut addr: Multiaddr, session_id: SessionId) -> MisbehaveResult {
494 if extract_peer_id(&addr).is_none() {
495 addr.push(Protocol::P2P(Cow::Borrowed(
496 self.network_state.local_peer_id().as_bytes(),
497 )))
498 }
499
500 self.network_state.add_observed_addr(session_id, addr);
501 MisbehaveResult::Continue
503 }
504
505 fn misbehave(&mut self, session: &SessionContext, reason: Misbehavior) -> MisbehaveResult {
506 error!(
507 "IdentifyProtocol detects abnormal behavior, session: {:?}, reason: {:?}",
508 session, reason
509 );
510 MisbehaveResult::Disconnect
511 }
512}
513
514#[derive(Clone)]
515struct Identify {
516 name: String,
517 encode_data: ckb_types::bytes::Bytes,
518}
519
520impl Identify {
521 fn new(name: String, flags: Flags, client_version: String) -> Self {
522 Identify {
523 encode_data: packed::Identify::new_builder()
524 .name(name.as_str())
525 .flag(flags.bits())
526 .client_version(client_version.as_str())
527 .build()
528 .as_bytes(),
529 name,
530 }
531 }
532
533 fn encode(&mut self) -> &[u8] {
534 &self.encode_data
535 }
536
537 fn verify(&self, data: &[u8]) -> Option<(Flags, String)> {
538 let reader = packed::IdentifyReader::from_slice(data).ok()?;
539
540 let name = reader.name().as_utf8().ok()?.to_owned();
541 if self.name != name {
542 warn!(
543 "IdentifyProtocol detects peer has different network identifiers, local network id: {}, remote network id: {}",
544 self.name, name,
545 );
546 return None;
547 }
548
549 let flag: u64 = reader.flag().into();
550 if flag == 0 {
551 return None;
552 }
553
554 let raw_client_version = reader.client_version().as_utf8().ok()?.to_owned();
555
556 Some((Flags::from_bits_truncate(flag), raw_client_version))
557 }
558}
559
560bitflags::bitflags! {
561 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
563 pub struct Flags: u64 {
564 const COMPATIBILITY = 0b1;
566 const DISCOVERY = 0b10;
568 const SYNC = 0b100;
570 const RELAY = 0b1000;
572 const LIGHT_CLIENT = 0b10000;
574 const BLOCK_FILTER = 0b100000;
576 }
577}