1use std::collections::{BTreeSet, HashMap};
2use std::marker::PhantomData;
3
4use async_trait::async_trait;
5use derive_where::derive_where;
6use eyre::eyre;
7use libp2p::request_response;
8use ractor::{Actor, ActorProcessingErr, ActorRef, RpcReplyPort};
9use tokio::task::JoinHandle;
10use tracing::{error, info, trace, warn};
11
12use malachitebft_codec as codec;
13use malachitebft_core_consensus::{LivenessMsg, SignedConsensusMsg};
14use malachitebft_core_types::{
15 Context, PolkaCertificate, RoundCertificate, SignedProposal, SignedVote, Validator,
16 ValidatorSet,
17};
18use malachitebft_metrics::SharedRegistry;
19use malachitebft_network::handle::CtrlHandle;
20use malachitebft_network::{Channel, Config, Event, PeerId};
21
22pub use malachitebft_network::{
23 Multiaddr, NetworkIdentity, NetworkStateDump, PersistentPeerError, PersistentPeersOp,
24};
25
26use malachitebft_sync::{
27 self as sync, InboundRequestId, OutboundRequestId, RawMessage, Request, Response,
28};
29
30use crate::consensus::ConsensusCodec;
31use crate::sync::SyncCodec;
32use crate::util::output_port::{OutputPort, OutputPortSubscriberTrait};
33use crate::util::streaming::StreamMessage;
34
35pub type NetworkRef<Ctx> = ActorRef<Msg<Ctx>>;
36pub type NetworkMsg<Ctx> = Msg<Ctx>;
37
38pub trait Subscriber<Msg>: OutputPortSubscriberTrait<Msg>
39where
40 Msg: Clone + ractor::Message,
41{
42 fn send(&self, msg: Msg);
43}
44
45impl<Msg, To> Subscriber<Msg> for ActorRef<To>
46where
47 Msg: Clone + ractor::Message,
48 To: From<Msg> + ractor::Message,
49{
50 fn send(&self, msg: Msg) {
51 if let Err(e) = self.cast(To::from(msg)) {
52 error!("Failed to send message to subscriber: {e:?}");
53 }
54 }
55}
56
57pub struct Network<Ctx, Codec> {
58 codec: Codec,
59 span: tracing::Span,
60 marker: PhantomData<Ctx>,
61}
62
63impl<Ctx, Codec> Network<Ctx, Codec> {
64 pub fn new(codec: Codec, span: tracing::Span) -> Self {
65 Self {
66 codec,
67 span,
68 marker: PhantomData,
69 }
70 }
71}
72
73impl<Ctx, Codec> Network<Ctx, Codec>
74where
75 Ctx: Context,
76 Codec: ConsensusCodec<Ctx>,
77 Codec: SyncCodec<Ctx>,
78 Codec: codec::HasEncodedLen<sync::Response<Ctx>>,
79{
80 pub async fn spawn(
81 identity: NetworkIdentity,
82 config: Config,
83 metrics: SharedRegistry,
84 codec: Codec,
85 span: tracing::Span,
86 ) -> Result<ActorRef<Msg<Ctx>>, ractor::SpawnErr> {
87 let args = Args {
88 identity,
89 config: config.clone(),
90 metrics,
91 };
92
93 let (actor_ref, _) = Actor::spawn(None, Self::new(codec, span), args).await?;
94 Ok(actor_ref)
95 }
96}
97
98pub struct Args {
99 pub identity: NetworkIdentity,
100 pub config: Config,
101 pub metrics: SharedRegistry,
102}
103
104#[derive_where(Clone, Debug, PartialEq, Eq)]
105pub enum NetworkEvent<Ctx: Context> {
106 Listening(Multiaddr),
107
108 PeerConnected(PeerId),
109 PeerDisconnected(PeerId),
110
111 Vote(PeerId, SignedVote<Ctx>),
112
113 Proposal(PeerId, SignedProposal<Ctx>),
114 ProposalPart(PeerId, StreamMessage<Ctx::ProposalPart>),
115
116 PolkaCertificate(PeerId, PolkaCertificate<Ctx>),
117
118 RoundCertificate(PeerId, RoundCertificate<Ctx>),
119
120 Status(PeerId, Status<Ctx>),
121
122 SyncRequest(InboundRequestId, PeerId, Request<Ctx>),
123 SyncResponse(OutboundRequestId, PeerId, Option<Response<Ctx>>),
124}
125
126pub enum State<Ctx: Context> {
127 Stopped,
128 Running {
129 listen_addrs: Vec<Multiaddr>,
130 peers: BTreeSet<PeerId>,
131 output_port: OutputPort<NetworkEvent<Ctx>>,
132 ctrl_handle: Box<CtrlHandle>,
133 recv_task: JoinHandle<()>,
134 inbound_requests: HashMap<InboundRequestId, request_response::InboundRequestId>,
135 },
136}
137
138#[derive_where(Clone, Debug, PartialEq, Eq)]
139pub struct Status<Ctx: Context> {
140 pub tip_height: Ctx::Height,
141 pub history_min_height: Ctx::Height,
142}
143
144impl<Ctx: Context> Status<Ctx> {
145 pub fn new(tip_height: Ctx::Height, history_min_height: Ctx::Height) -> Self {
146 Self {
147 tip_height,
148 history_min_height,
149 }
150 }
151}
152
153pub enum Msg<Ctx: Context> {
154 Subscribe(Box<dyn Subscriber<NetworkEvent<Ctx>>>),
156
157 PublishConsensusMsg(SignedConsensusMsg<Ctx>),
159
160 PublishLivenessMsg(LivenessMsg<Ctx>),
162
163 PublishProposalPart(StreamMessage<Ctx::ProposalPart>),
165
166 BroadcastStatus(Status<Ctx>),
168
169 OutgoingRequest(PeerId, Request<Ctx>, RpcReplyPort<OutboundRequestId>),
171
172 OutgoingResponse(InboundRequestId, Response<Ctx>),
174
175 DumpState(RpcReplyPort<Option<NetworkStateDump>>),
177
178 UpdatePersistentPeers(
180 PersistentPeersOp,
181 RpcReplyPort<Result<(), PersistentPeerError>>,
182 ),
183
184 UpdateValidatorSet(Ctx::ValidatorSet),
186
187 #[doc(hidden)]
189 NewEvent(Event),
190}
191
192#[async_trait]
193impl<Ctx, Codec> Actor for Network<Ctx, Codec>
194where
195 Ctx: Context,
196 Codec: Send + Sync + 'static,
197 Codec: codec::Codec<Ctx::ProposalPart>,
198 Codec: codec::Codec<SignedConsensusMsg<Ctx>>,
199 Codec: codec::Codec<StreamMessage<Ctx::ProposalPart>>,
200 Codec: codec::Codec<LivenessMsg<Ctx>>,
201 Codec: SyncCodec<Ctx>,
202{
203 type Msg = Msg<Ctx>;
204 type State = State<Ctx>;
205 type Arguments = Args;
206
207 async fn pre_start(
208 &self,
209 myself: ActorRef<Msg<Ctx>>,
210 args: Args,
211 ) -> Result<Self::State, ActorProcessingErr> {
212 let handle = malachitebft_network::spawn(args.identity, args.config, args.metrics).await?;
213
214 let (mut recv_handle, ctrl_handle) = handle.split();
215
216 let recv_task = tokio::spawn(async move {
217 while let Some(event) = recv_handle.recv().await {
218 if let Err(e) = myself.cast(Msg::NewEvent(event)) {
219 error!("Actor has died, stopping network: {e:?}");
220 break;
221 }
222 }
223 });
224
225 Ok(State::Running {
226 listen_addrs: Vec::new(),
227 peers: BTreeSet::new(),
228 output_port: OutputPort::with_capacity(128),
229 ctrl_handle: Box::new(ctrl_handle),
230 recv_task,
231 inbound_requests: HashMap::new(),
232 })
233 }
234
235 async fn post_start(
236 &self,
237 _myself: ActorRef<Msg<Ctx>>,
238 _state: &mut State<Ctx>,
239 ) -> Result<(), ActorProcessingErr> {
240 Ok(())
241 }
242
243 #[tracing::instrument(name = "network", parent = &self.span, skip_all)]
244 async fn handle(
245 &self,
246 _myself: ActorRef<Msg<Ctx>>,
247 msg: Msg<Ctx>,
248 state: &mut State<Ctx>,
249 ) -> Result<(), ActorProcessingErr> {
250 if let Msg::DumpState(reply_to) = msg {
252 handle_dump_state(state, reply_to).await;
253 return Ok(());
254 }
255
256 if let Msg::UpdatePersistentPeers(op, reply_to) = msg {
257 handle_update_persistent_peers(state, op, reply_to).await;
258 return Ok(());
259 }
260
261 let State::Running {
262 listen_addrs,
263 peers,
264 output_port,
265 ctrl_handle,
266 inbound_requests,
267 ..
268 } = state
269 else {
270 return Ok(());
271 };
272
273 match msg {
274 Msg::Subscribe(subscriber) => {
275 for addr in listen_addrs.iter() {
276 subscriber.send(NetworkEvent::Listening(addr.clone()));
277 }
278
279 for peer in peers.iter() {
280 subscriber.send(NetworkEvent::PeerConnected(*peer));
281 }
282
283 subscriber.subscribe_to_port(output_port);
284 }
285
286 Msg::PublishConsensusMsg(msg) => match self.codec.encode(&msg) {
287 Ok(data) => ctrl_handle.publish(Channel::Consensus, data).await?,
288 Err(e) => error!("Failed to encode consensus message: {e:?}"),
289 },
290
291 Msg::PublishLivenessMsg(msg) => match self.codec.encode(&msg) {
292 Ok(data) => ctrl_handle.publish(Channel::Liveness, data).await?,
293 Err(e) => error!("Failed to encode liveness message: {e:?}"),
294 },
295
296 Msg::PublishProposalPart(msg) => {
297 trace!(
298 stream_id = %msg.stream_id,
299 sequence = %msg.sequence,
300 "Broadcasting proposal part"
301 );
302
303 let data = self.codec.encode(&msg);
304 match data {
305 Ok(data) => ctrl_handle.publish(Channel::ProposalParts, data).await?,
306 Err(e) => error!("Failed to encode proposal part: {e:?}"),
307 }
308 }
309
310 Msg::BroadcastStatus(status) => {
311 let status = sync::Status {
312 peer_id: ctrl_handle.peer_id(),
313 tip_height: status.tip_height,
314 history_min_height: status.history_min_height,
315 };
316
317 let data = self.codec.encode(&status);
318 match data {
319 Ok(data) => ctrl_handle.broadcast(Channel::Sync, data).await?,
320 Err(e) => error!("Failed to encode status message: {e:?}"),
321 }
322 }
323
324 Msg::OutgoingRequest(peer_id, request, reply_to) => {
325 let request = self.codec.encode(&request);
326
327 match request {
328 Ok(data) => {
329 let p2p_request_id = ctrl_handle.sync_request(peer_id, data).await?;
330 reply_to.send(OutboundRequestId::new(p2p_request_id))?;
331 }
332 Err(e) => error!("Failed to encode request message: {e:?}"),
333 }
334 }
335
336 Msg::OutgoingResponse(request_id, response) => {
337 let response = self.codec.encode(&response);
338
339 match response {
340 Ok(data) => {
341 let request_id = inbound_requests
342 .remove(&request_id)
343 .ok_or_else(|| eyre!("Unknown inbound request ID: {request_id}"))?;
344
345 ctrl_handle.sync_reply(request_id, data).await?
346 }
347 Err(e) => {
348 error!(%request_id, "Failed to encode response message: {e:?}");
349 return Ok(());
350 }
351 };
352 }
353
354 Msg::NewEvent(Event::Listening(addr)) => {
355 listen_addrs.push(addr.clone());
356 output_port.send(NetworkEvent::Listening(addr));
357 }
358
359 Msg::NewEvent(Event::PeerConnected(peer_id)) => {
360 peers.insert(peer_id);
361 output_port.send(NetworkEvent::PeerConnected(peer_id));
362 }
363
364 Msg::NewEvent(Event::PeerDisconnected(peer_id)) => {
365 peers.remove(&peer_id);
366 output_port.send(NetworkEvent::PeerDisconnected(peer_id));
367 }
368
369 Msg::NewEvent(Event::LivenessMessage(Channel::Liveness, from, data)) => {
370 let msg = match self.codec.decode(data) {
371 Ok(msg) => msg,
372 Err(e) => {
373 error!(%from, "Failed to decode liveness message: {e:?}");
374 return Ok(());
375 }
376 };
377
378 let event = match msg {
379 LivenessMsg::PolkaCertificate(polka_cert) => {
380 NetworkEvent::PolkaCertificate(from, polka_cert)
381 }
382 LivenessMsg::SkipRoundCertificate(round_cert) => {
383 NetworkEvent::RoundCertificate(from, round_cert)
384 }
385 LivenessMsg::Vote(vote) => NetworkEvent::Vote(from, vote),
386 };
387
388 output_port.send(event);
389 }
390
391 Msg::NewEvent(Event::LivenessMessage(channel, from, _)) => {
392 error!(%from, "Unexpected liveness message on {channel} channel");
393 return Ok(());
394 }
395
396 Msg::NewEvent(Event::ConsensusMessage(Channel::Consensus, from, data)) => {
397 let msg = match self.codec.decode(data) {
398 Ok(msg) => msg,
399 Err(e) => {
400 error!(%from, "Failed to decode consensus message: {e:?}");
401 return Ok(());
402 }
403 };
404
405 let event = match msg {
406 SignedConsensusMsg::Vote(vote) => NetworkEvent::Vote(from, vote),
407 SignedConsensusMsg::Proposal(proposal) => {
408 NetworkEvent::Proposal(from, proposal)
409 }
410 };
411
412 output_port.send(event);
413 }
414
415 Msg::NewEvent(Event::ConsensusMessage(Channel::ProposalParts, from, data)) => {
416 let msg: StreamMessage<Ctx::ProposalPart> = match self.codec.decode(data) {
417 Ok(stream_msg) => stream_msg,
418 Err(e) => {
419 error!(%from, "Failed to decode stream message: {e:?}");
420 return Ok(());
421 }
422 };
423
424 trace!(
425 %from,
426 stream_id = %msg.stream_id,
427 sequence = %msg.sequence,
428 "Received proposal part"
429 );
430
431 output_port.send(NetworkEvent::ProposalPart(from, msg));
432 }
433
434 Msg::NewEvent(Event::ConsensusMessage(Channel::Sync, from, data)) => {
435 let status: sync::Status<Ctx> = match self.codec.decode(data) {
436 Ok(status) => status,
437 Err(e) => {
438 error!(%from, "Failed to decode status message: {e:?}");
439 return Ok(());
440 }
441 };
442
443 if from != status.peer_id {
444 error!(%from, %status.peer_id, "Mismatched peer ID in status message");
445 return Ok(());
446 }
447
448 trace!(%from, tip_height = %status.tip_height, "Received status");
449
450 output_port.send(NetworkEvent::Status(
451 status.peer_id,
452 Status::new(status.tip_height, status.history_min_height),
453 ));
454 }
455
456 Msg::NewEvent(Event::ConsensusMessage(channel, from, _)) => {
457 error!(%from, "Unexpected consensus message on {channel} channel");
458 return Ok(());
459 }
460
461 Msg::NewEvent(Event::Sync(raw_msg)) => match raw_msg {
462 RawMessage::Request {
463 request_id,
464 peer,
465 body,
466 } => {
467 let request = match self.codec.decode(body) {
468 Ok(request) => request,
469 Err(e) => {
470 error!(%peer, "Failed to decode sync request: {e:?}");
471 return Ok(());
472 }
473 };
474
475 inbound_requests.insert(InboundRequestId::new(request_id), request_id);
476
477 output_port.send(NetworkEvent::SyncRequest(
478 InboundRequestId::new(request_id),
479 peer,
480 request,
481 ));
482 }
483
484 RawMessage::Response {
485 request_id,
486 peer,
487 body,
488 } => {
489 let response = match self.codec.decode(body) {
490 Ok(response) => Some(response),
491 Err(e) => {
492 error!(%peer, "Failed to decode sync response: {e:?}");
493 None
494 }
495 };
496
497 output_port.send(NetworkEvent::SyncResponse(
498 OutboundRequestId::new(request_id),
499 peer,
500 response,
501 ));
502 }
503 },
504
505 Msg::UpdateValidatorSet(validator_set) => {
506 info!(
507 "Updating validator set: {} validators",
508 validator_set.count()
509 );
510 let validators: Vec<_> = validator_set
513 .iter()
514 .map(|v| malachitebft_network::ValidatorInfo {
515 address: v.address().to_string(),
516 voting_power: v.voting_power(),
517 })
518 .collect();
519 ctrl_handle.update_validator_set(validators).await?;
520 }
521
522 Msg::DumpState(_) => unreachable!("DumpState handled above to ensure a reply"),
523 Msg::UpdatePersistentPeers(_, _) => {
524 unreachable!("UpdatePersistentPeers handled above to ensure a reply")
525 }
526 }
527
528 Ok(())
529 }
530
531 async fn post_stop(
532 &self,
533 _myself: ActorRef<Msg<Ctx>>,
534 state: &mut State<Ctx>,
535 ) -> Result<(), ActorProcessingErr> {
536 let state = std::mem::replace(state, State::Stopped);
537
538 if let State::Running {
539 ctrl_handle,
540 recv_task,
541 ..
542 } = state
543 {
544 ctrl_handle.wait_shutdown().await?;
545 recv_task.await?;
546 }
547
548 Ok(())
549 }
550}
551
552async fn handle_dump_state<Ctx>(
553 state: &mut State<Ctx>,
554 reply_to: RpcReplyPort<Option<NetworkStateDump>>,
555) where
556 Ctx: Context,
557{
558 let dump = match state {
559 State::Stopped => {
560 info!("Dumping network state: not started");
561 None
562 }
563 State::Running { ctrl_handle, .. } => match ctrl_handle.dump_state().await {
564 Ok(snapshot) => Some(snapshot),
565 Err(error) => {
566 error!(%error, "Failed to obtain network dump");
567 None
568 }
569 },
570 };
571
572 if let Err(error) = reply_to.send(dump) {
573 error!(%error, "Failed to reply with network state dump");
574 }
575}
576
577async fn handle_update_persistent_peers<Ctx>(
578 state: &mut State<Ctx>,
579 op: PersistentPeersOp,
580 reply_to: RpcReplyPort<Result<(), PersistentPeerError>>,
581) where
582 Ctx: Context,
583{
584 fn log_result(result: &Result<(), PersistentPeerError>, op: &PersistentPeersOp) {
585 match result {
586 Ok(_) => match op {
587 PersistentPeersOp::Add(addr) => {
588 info!("Successfully added persistent peer: {addr}");
589 }
590 PersistentPeersOp::Remove(addr) => {
591 info!("Successfully removed persistent peer: {addr}");
592 }
593 },
594 Err(error) => {
595 error!(%error, "Failed to update persistent peers");
596 }
597 }
598 }
599
600 let result = match state {
601 State::Stopped => {
602 warn!("Cannot update persistent peers: network not started");
603 Err(PersistentPeerError::NetworkStopped)
604 }
605 State::Running { ctrl_handle, .. } => {
606 let op_result = match &op {
607 PersistentPeersOp::Add(addr) => ctrl_handle.add_persistent_peer(addr.clone()).await,
608 PersistentPeersOp::Remove(addr) => {
609 ctrl_handle.remove_persistent_peer(addr.clone()).await
610 }
611 };
612
613 op_result
614 .inspect(|res| log_result(res, &op))
615 .unwrap_or_else(|error| {
616 error!(%error, "Internal error: failed to update persistent peers");
617 Err(PersistentPeerError::InternalError(error.to_string()))
618 })
619 }
620 };
621
622 if let Err(error) = reply_to.send(result) {
623 error!(%error, "Failed to reply to UpdatePersistentPeers");
624 }
625}