1use crate::core::hub::HubEvent;
2use crate::pb::RejectMessage;
3use crate::pb::{kaspad_message::Payload as KaspadMessagePayload, KaspadMessage};
4use crate::{common::ProtocolError, KaspadMessagePayloadType};
5use crate::{make_message, Peer};
6use kaspa_core::{debug, error, info, trace, warn};
7use kaspa_utils::networking::PeerId;
8use parking_lot::{Mutex, RwLock};
9use seqlock::SeqLock;
10use std::fmt::{Debug, Display};
11use std::net::SocketAddr;
12use std::ops::{Deref, DerefMut};
13use std::sync::atomic::{AtomicU32, Ordering};
14use std::time::Instant;
15use std::{collections::HashMap, sync::Arc};
16use tokio::select;
17use tokio::sync::mpsc::error::TrySendError;
18use tokio::sync::mpsc::{channel as mpsc_channel, Receiver as MpscReceiver, Sender as MpscSender};
19use tokio::sync::oneshot::{channel as oneshot_channel, Sender as OneshotSender};
20use tonic::Streaming;
21
22use super::peer::{PeerKey, PeerProperties};
23
/// Receiving end of a message route, paired with a numeric route id.
///
/// The id is assigned in [`IncomingRoute::new`] and read back through [`IncomingRoute::id`].
pub struct IncomingRoute {
    /// Channel receiver from which the messages sent to this route are read.
    rx: MpscReceiver<KaspadMessage>,
    /// Identifier of this route, fixed at construction time.
    id: u32,
}
28
/// Route id value meaning "no specific route": `Router::route_to_flow` routes a message by its
/// message type (rather than by route id) when the message's `response_id` equals this value.
pub const BLANK_ROUTE_ID: u32 = 0;
/// Global counter from which route ids are drawn; it starts one above [`BLANK_ROUTE_ID`].
static ROUTE_ID: AtomicU32 = AtomicU32::new(BLANK_ROUTE_ID + 1);
34
35impl IncomingRoute {
36 pub fn new(rx: MpscReceiver<KaspadMessage>) -> Self {
37 let id = ROUTE_ID.fetch_add(1, Ordering::SeqCst);
38 Self { rx, id }
39 }
40
41 pub fn id(&self) -> u32 {
42 self.id
43 }
44}
45
46impl Deref for IncomingRoute {
47 type Target = MpscReceiver<KaspadMessage>;
48
49 fn deref(&self) -> &Self::Target {
50 &self.rx
51 }
52}
53
54impl DerefMut for IncomingRoute {
55 fn deref_mut(&mut self) -> &mut Self::Target {
56 &mut self.rx
57 }
58}
59
/// A cloneable handle to an [`IncomingRoute`]: clones share one route through an `Arc`, and
/// access to it is serialized by a tokio (async) mutex.
#[derive(Clone)]
pub struct SharedIncomingRoute(Arc<tokio::sync::Mutex<IncomingRoute>>);
62
63impl SharedIncomingRoute {
64 pub fn new(incoming_route: IncomingRoute) -> Self {
65 Self(Arc::new(tokio::sync::Mutex::new(incoming_route)))
66 }
67
68 pub async fn recv(&mut self) -> Option<KaspadMessage> {
69 self.0.lock().await.recv().await
70 }
71}
72
/// What `Router::route_to_flow` does when a message cannot be delivered because the target
/// route's channel is full.
pub enum IncomingRouteOverflowPolicy {
    /// The message is discarded and routing still reports success.
    Drop,

    /// Routing fails with `ProtocolError::IncomingRouteCapacityReached`; the router's receive
    /// loop exits on any routing error and then closes the router.
    Disconnect,
}
81
82impl From<KaspadMessagePayloadType> for IncomingRouteOverflowPolicy {
83 fn from(msg_type: KaspadMessagePayloadType) -> Self {
84 match msg_type {
85 KaspadMessagePayloadType::InvTransactions | KaspadMessagePayloadType::InvRelayBlock => IncomingRouteOverflowPolicy::Drop,
87 _ => IncomingRouteOverflowPolicy::Disconnect,
88 }
89 }
90}
91
/// The parts of a `Router` that change after construction; the router keeps them behind a mutex.
#[derive(Debug, Default)]
struct RouterMutableState {
    /// One-shot sender that releases the router's receive loop. It is taken (left as `None`)
    /// and fired by `Router::start`, or by `Router::close` if start was never called.
    start_signal: Option<OneshotSender<()>>,

    /// One-shot sender that tells the receive loop to exit. It is taken and fired by
    /// `Router::close`; finding `None` here means close was already called.
    shutdown_signal: Option<OneshotSender<()>>,

    /// Properties of the peer, read and replaced via `Router::properties` / `Router::set_properties`.
    properties: Arc<PeerProperties>,

    /// Last ping duration stored through `Router::set_last_ping_duration`.
    /// NOTE(review): the time unit is not visible in this file — confirm with the caller.
    last_ping_duration: u64,
}
106
107impl RouterMutableState {
108 fn new(start_signal: Option<OneshotSender<()>>, shutdown_signal: Option<OneshotSender<()>>) -> Self {
109 Self { start_signal, shutdown_signal, ..Default::default() }
110 }
111}
112
/// Per-connection message router: it reads messages from a peer's incoming stream and hands
/// each one to the channel registered for its route id or message type, and it queues
/// outgoing messages for the peer.
#[derive(Debug)]
pub struct Router {
    /// Identity of the peer; starts as the default value and is replaced via `set_identity`.
    identity: SeqLock<PeerId>,

    /// Socket address supplied at construction; also what `Display` prints for the router.
    net_address: SocketAddr,

    /// Flag supplied at construction.
    /// NOTE(review): presumably true when the connection was initiated by this node — confirm.
    is_outbound: bool,

    /// Instant captured when the router was constructed.
    connection_started: Instant,

    /// Senders of subscribed routes, keyed by the message type they subscribed to.
    routing_map_by_type: RwLock<HashMap<KaspadMessagePayloadType, MpscSender<KaspadMessage>>>,

    /// Senders of subscribed routes, keyed by route id; used for messages whose `response_id`
    /// is not `BLANK_ROUTE_ID`.
    routing_map_by_id: RwLock<HashMap<u32, MpscSender<KaspadMessage>>>,

    /// Channel onto which `enqueue` places outgoing messages.
    outgoing_route: MpscSender<KaspadMessage>,

    /// Channel to the hub; `close` sends `HubEvent::PeerClosing` on it.
    hub_sender: MpscSender<HubEvent>,

    /// Start/shutdown signals, peer properties and last ping duration, guarded by a mutex.
    mutable_state: Mutex<RouterMutableState>,
}
143
144impl Display for Router {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 write!(f, "{}", self.net_address)
147 }
148}
149
150impl From<&Router> for PeerKey {
151 fn from(value: &Router) -> Self {
152 Self::new(value.identity.read(), value.net_address.ip().into())
153 }
154}
155
156impl From<&Router> for Peer {
157 fn from(router: &Router) -> Self {
158 Self::new(
159 router.identity(),
160 router.net_address,
161 router.is_outbound,
162 router.connection_started,
163 router.properties(),
164 router.last_ping_duration(),
165 )
166 }
167}
168
169fn message_summary(msg: &KaspadMessage) -> impl Debug {
170 msg.payload.as_ref().map(std::convert::Into::<KaspadMessagePayloadType>::into)
173}
174
175impl Router {
176 pub(crate) async fn new(
177 net_address: SocketAddr,
178 is_outbound: bool,
179 hub_sender: MpscSender<HubEvent>,
180 mut incoming_stream: Streaming<KaspadMessage>,
181 outgoing_route: MpscSender<KaspadMessage>,
182 ) -> Arc<Self> {
183 let (start_sender, start_receiver) = oneshot_channel();
184 let (shutdown_sender, mut shutdown_receiver) = oneshot_channel();
185
186 let router = Arc::new(Router {
187 identity: Default::default(),
188 net_address,
189 is_outbound,
190 connection_started: Instant::now(),
191 routing_map_by_type: RwLock::new(HashMap::new()),
192 routing_map_by_id: RwLock::new(HashMap::new()),
193 outgoing_route,
194 hub_sender,
195 mutable_state: Mutex::new(RouterMutableState::new(Some(start_sender), Some(shutdown_sender))),
196 });
197
198 let router_clone = router.clone();
199 tokio::spawn(async move {
201 let _ = start_receiver.await;
203 loop {
204 select! {
205 biased; _ = &mut shutdown_receiver => {
208 debug!("P2P, Router receive loop - shutdown signal received, exiting router receive loop, router-id: {}", router.identity());
209 break;
210 }
211
212 res = incoming_stream.message() => match res {
213 Ok(Some(msg)) => {
214 trace!("P2P msg: {:?}, router-id: {}, peer: {}", message_summary(&msg), router.identity(), router);
215 match router.route_to_flow(msg) {
216 Ok(()) => {},
217 Err(e) => {
218 match e {
219 ProtocolError::IgnorableReject(reason) => debug!("P2P, got reject message: {} from peer: {}", reason, router),
220 ProtocolError::Rejected(reason) => warn!("P2P, got reject message: {} from peer: {}", reason, router),
221 e => warn!("P2P, route error: {} for peer: {}", e, router),
222 }
223 break;
224 },
225 }
226 }
227 Ok(None) => {
228 info!("P2P, incoming stream ended from peer {}", router);
229 break;
230 }
231 Err(status) => {
232 if let Some(err) = match_for_io_error(&status) {
233 info!("P2P, network error: {} from peer {}", err, router);
234 } else {
235 info!("P2P, network error: {} from peer {}", status, router);
236 }
237 break;
238 }
239 }
240 }
241 }
242 router.close().await;
243 debug!("P2P, Router receive loop - exited, router-id: {}, router refs: {}", router.identity(), Arc::strong_count(&router));
244 });
245
246 router_clone
247 }
248
249 pub fn identity(&self) -> PeerId {
251 self.identity.read()
252 }
253
254 pub fn set_identity(&self, identity: PeerId) {
255 *self.identity.lock_write() = identity;
256 }
257
258 pub fn net_address(&self) -> SocketAddr {
260 self.net_address
261 }
262
263 pub fn key(&self) -> PeerKey {
264 self.into()
265 }
266
267 pub fn is_outbound(&self) -> bool {
269 self.is_outbound
270 }
271
272 pub fn connection_started(&self) -> Instant {
273 self.connection_started
274 }
275
276 pub fn time_connected(&self) -> u64 {
277 Instant::now().duration_since(self.connection_started).as_millis() as u64
278 }
279
280 pub fn properties(&self) -> Arc<PeerProperties> {
281 self.mutable_state.lock().properties.clone()
282 }
283
284 pub fn set_properties(&self, properties: Arc<PeerProperties>) {
285 self.mutable_state.lock().properties = properties;
286 }
287
288 pub fn set_last_ping_duration(&self, last_ping_duration: u64) {
290 self.mutable_state.lock().last_ping_duration = last_ping_duration;
291 }
292
293 pub fn last_ping_duration(&self) -> u64 {
294 self.mutable_state.lock().last_ping_duration
295 }
296
297 pub fn incoming_flow_baseline_channel_size() -> usize {
298 256
299 }
300
301 pub fn start(&self) {
303 let op = self.mutable_state.lock().start_signal.take();
305 if let Some(signal) = op {
306 let _ = signal.send(());
307 } else {
308 debug!("P2P, Router start was called more than once, router-id: {}", self.identity())
309 }
310 }
311
312 pub fn subscribe(&self, msg_types: Vec<KaspadMessagePayloadType>) -> IncomingRoute {
316 self.subscribe_with_capacity(msg_types, Self::incoming_flow_baseline_channel_size())
317 }
318
319 pub fn subscribe_with_capacity(&self, msg_types: Vec<KaspadMessagePayloadType>, capacity: usize) -> IncomingRoute {
323 let (sender, receiver) = mpsc_channel(capacity);
324 let incoming_route = IncomingRoute::new(receiver);
325 let mut map_by_type = self.routing_map_by_type.write();
326 for msg_type in msg_types {
327 match map_by_type.insert(msg_type, sender.clone()) {
328 Some(_) => {
329 error!(
331 "P2P, Router::subscribe overrides an existing message type: {:?}, router-id: {}",
332 msg_type,
333 self.identity()
334 );
335 panic!("P2P, Tried to subscribe to an existing route");
336 }
337 None => {
338 trace!("P2P, Router::subscribe - msg_type: {:?} route is registered, router-id:{:?}", msg_type, self.identity());
339 }
340 }
341 }
342 let mut map_by_id = self.routing_map_by_id.write();
343 match map_by_id.insert(incoming_route.id, sender.clone()) {
344 Some(_) => {
345 error!(
347 "P2P, Router::subscribe overrides an existing route id: {:?}, router-id: {}",
348 incoming_route.id,
349 self.identity()
350 );
351 panic!("P2P, Tried to subscribe to an existing route");
352 }
353 None => {
354 trace!(
355 "P2P, Router::subscribe - route id: {:?} route is registered, router-id:{:?}",
356 incoming_route.id,
357 self.identity()
358 );
359 }
360 }
361 incoming_route
362 }
363
364 pub fn route_to_flow(&self, msg: KaspadMessage) -> Result<(), ProtocolError> {
366 if msg.payload.is_none() {
367 debug!("P2P, Route to flow got empty payload, peer: {}", self);
368 return Err(ProtocolError::Other("received kaspad p2p message with empty payload"));
369 }
370 let msg_type: KaspadMessagePayloadType = msg.payload.as_ref().expect("payload was just verified").into();
371 if msg_type == KaspadMessagePayloadType::Reject {
373 let Some(KaspadMessagePayload::Reject(reject)) = msg.payload else { unreachable!() };
374 return Err(ProtocolError::from_reject_message(reject.reason));
375 }
376
377 let op = if msg.response_id != BLANK_ROUTE_ID {
378 self.routing_map_by_id.read().get(&msg.response_id).cloned()
379 } else {
380 self.routing_map_by_type.read().get(&msg_type).cloned()
381 };
382
383 if let Some(sender) = op {
384 match sender.try_send(msg) {
385 Ok(_) => Ok(()),
386 Err(TrySendError::Closed(_)) => Err(ProtocolError::ConnectionClosed),
387 Err(TrySendError::Full(_)) => {
388 let overflow_policy: IncomingRouteOverflowPolicy = msg_type.into();
389 match overflow_policy {
390 IncomingRouteOverflowPolicy::Drop => Ok(()),
391 IncomingRouteOverflowPolicy::Disconnect => {
392 Err(ProtocolError::IncomingRouteCapacityReached(msg_type, self.to_string()))
393 }
394 }
395 }
396 }
397 } else {
398 Err(ProtocolError::NoRouteForMessageType(msg_type))
399 }
400 }
401
402 pub async fn enqueue(&self, msg: KaspadMessage) -> Result<(), ProtocolError> {
404 assert!(msg.payload.is_some(), "Kaspad P2P message should always have a value");
405 match self.outgoing_route.try_send(msg) {
406 Ok(_) => Ok(()),
407 Err(TrySendError::Closed(_)) => Err(ProtocolError::ConnectionClosed),
408 Err(TrySendError::Full(_)) => Err(ProtocolError::OutgoingRouteCapacityReached(self.to_string())),
409 }
410 }
411
412 pub async fn try_sending_reject_message(&self, err: &ProtocolError) {
414 if err.can_send_outgoing_message() {
415 let _ = self.enqueue(make_message!(KaspadMessagePayload::Reject, RejectMessage { reason: err.to_reject_message() })).await;
418 }
419 }
420
421 pub async fn close(self: &Arc<Router>) -> bool {
424 {
427 let mut state = self.mutable_state.lock();
428
429 if let Some(signal) = state.start_signal.take() {
431 let _ = signal.send(());
432 }
433
434 if let Some(signal) = state.shutdown_signal.take() {
435 let _ = signal.send(());
436 } else {
437 trace!("P2P, Router close was called more than once, router-id: {}", self.identity());
439 return false;
440 }
441 }
442
443 self.routing_map_by_type.write().clear();
445 self.routing_map_by_id.write().clear();
446
447 self.hub_sender.send(HubEvent::PeerClosing(self.clone())).await.expect("hub receiver should never drop before senders");
449
450 true
451 }
452}
453
454fn match_for_io_error(err_status: &tonic::Status) -> Option<&std::io::Error> {
455 let mut err: &(dyn std::error::Error + 'static) = err_status;
456
457 loop {
458 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
459 return Some(io_err);
460 }
461
462 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
465 if let Some(io_err) = h2_err.get_io() {
466 return Some(io_err);
467 }
468 }
469
470 err = match err.source() {
471 Some(err) => err,
472 None => return None,
473 };
474 }
475}