// arc_malachitebft_engine/network.rs

1use std::collections::{BTreeSet, HashMap};
2use std::marker::PhantomData;
3
4use async_trait::async_trait;
5use derive_where::derive_where;
6use eyre::eyre;
7use libp2p::request_response;
8use ractor::{Actor, ActorProcessingErr, ActorRef, RpcReplyPort};
9use tokio::task::JoinHandle;
10use tracing::{error, info, trace, warn};
11
12use malachitebft_codec as codec;
13use malachitebft_core_consensus::{LivenessMsg, SignedConsensusMsg};
14use malachitebft_core_types::{
15    Context, PolkaCertificate, RoundCertificate, SignedProposal, SignedVote, Validator,
16    ValidatorSet,
17};
18use malachitebft_metrics::SharedRegistry;
19use malachitebft_network::handle::CtrlHandle;
20use malachitebft_network::{Channel, Config, Event, PeerId};
21
22pub use malachitebft_network::{
23    Multiaddr, NetworkIdentity, NetworkStateDump, PersistentPeerError, PersistentPeersOp,
24};
25
26use malachitebft_sync::{
27    self as sync, InboundRequestId, OutboundRequestId, RawMessage, Request, Response,
28};
29
30use crate::consensus::ConsensusCodec;
31use crate::sync::SyncCodec;
32use crate::util::output_port::{OutputPort, OutputPortSubscriberTrait};
33use crate::util::streaming::StreamMessage;
34
35pub type NetworkRef<Ctx> = ActorRef<Msg<Ctx>>;
36pub type NetworkMsg<Ctx> = Msg<Ctx>;
37
38pub trait Subscriber<Msg>: OutputPortSubscriberTrait<Msg>
39where
40    Msg: Clone + ractor::Message,
41{
42    fn send(&self, msg: Msg);
43}
44
45impl<Msg, To> Subscriber<Msg> for ActorRef<To>
46where
47    Msg: Clone + ractor::Message,
48    To: From<Msg> + ractor::Message,
49{
50    fn send(&self, msg: Msg) {
51        if let Err(e) = self.cast(To::from(msg)) {
52            error!("Failed to send message to subscriber: {e:?}");
53        }
54    }
55}
56
57pub struct Network<Ctx, Codec> {
58    codec: Codec,
59    span: tracing::Span,
60    marker: PhantomData<Ctx>,
61}
62
63impl<Ctx, Codec> Network<Ctx, Codec> {
64    pub fn new(codec: Codec, span: tracing::Span) -> Self {
65        Self {
66            codec,
67            span,
68            marker: PhantomData,
69        }
70    }
71}
72
73impl<Ctx, Codec> Network<Ctx, Codec>
74where
75    Ctx: Context,
76    Codec: ConsensusCodec<Ctx>,
77    Codec: SyncCodec<Ctx>,
78    Codec: codec::HasEncodedLen<sync::Response<Ctx>>,
79{
80    pub async fn spawn(
81        identity: NetworkIdentity,
82        config: Config,
83        metrics: SharedRegistry,
84        codec: Codec,
85        span: tracing::Span,
86    ) -> Result<ActorRef<Msg<Ctx>>, ractor::SpawnErr> {
87        let args = Args {
88            identity,
89            config: config.clone(),
90            metrics,
91        };
92
93        let (actor_ref, _) = Actor::spawn(None, Self::new(codec, span), args).await?;
94        Ok(actor_ref)
95    }
96}
97
98pub struct Args {
99    pub identity: NetworkIdentity,
100    pub config: Config,
101    pub metrics: SharedRegistry,
102}
103
104#[derive_where(Clone, Debug, PartialEq, Eq)]
105pub enum NetworkEvent<Ctx: Context> {
106    Listening(Multiaddr),
107
108    PeerConnected(PeerId),
109    PeerDisconnected(PeerId),
110
111    Vote(PeerId, SignedVote<Ctx>),
112
113    Proposal(PeerId, SignedProposal<Ctx>),
114    ProposalPart(PeerId, StreamMessage<Ctx::ProposalPart>),
115
116    PolkaCertificate(PeerId, PolkaCertificate<Ctx>),
117
118    RoundCertificate(PeerId, RoundCertificate<Ctx>),
119
120    Status(PeerId, Status<Ctx>),
121
122    SyncRequest(InboundRequestId, PeerId, Request<Ctx>),
123    SyncResponse(OutboundRequestId, PeerId, Option<Response<Ctx>>),
124}
125
126pub enum State<Ctx: Context> {
127    Stopped,
128    Running {
129        listen_addrs: Vec<Multiaddr>,
130        peers: BTreeSet<PeerId>,
131        output_port: OutputPort<NetworkEvent<Ctx>>,
132        ctrl_handle: Box<CtrlHandle>,
133        recv_task: JoinHandle<()>,
134        inbound_requests: HashMap<InboundRequestId, request_response::InboundRequestId>,
135    },
136}
137
138#[derive_where(Clone, Debug, PartialEq, Eq)]
139pub struct Status<Ctx: Context> {
140    pub tip_height: Ctx::Height,
141    pub history_min_height: Ctx::Height,
142}
143
144impl<Ctx: Context> Status<Ctx> {
145    pub fn new(tip_height: Ctx::Height, history_min_height: Ctx::Height) -> Self {
146        Self {
147            tip_height,
148            history_min_height,
149        }
150    }
151}
152
153pub enum Msg<Ctx: Context> {
154    /// Subscribe this actor to receive gossip events
155    Subscribe(Box<dyn Subscriber<NetworkEvent<Ctx>>>),
156
157    /// Publish a signed consensus message
158    PublishConsensusMsg(SignedConsensusMsg<Ctx>),
159
160    /// Publish a liveness message
161    PublishLivenessMsg(LivenessMsg<Ctx>),
162
163    /// Publish a proposal part
164    PublishProposalPart(StreamMessage<Ctx::ProposalPart>),
165
166    /// Broadcast status to all direct peers
167    BroadcastStatus(Status<Ctx>),
168
169    /// Send a request to a peer, returning the outbound request ID
170    OutgoingRequest(PeerId, Request<Ctx>, RpcReplyPort<OutboundRequestId>),
171
172    /// Send a response for a request to a peer
173    OutgoingResponse(InboundRequestId, Response<Ctx>),
174
175    /// Request to dump the current network state
176    DumpState(RpcReplyPort<Option<NetworkStateDump>>),
177
178    /// Add or remove a persistent peer at runtime
179    UpdatePersistentPeers(
180        PersistentPeersOp,
181        RpcReplyPort<Result<(), PersistentPeerError>>,
182    ),
183
184    /// Update the validator set for the current height
185    UpdateValidatorSet(Ctx::ValidatorSet),
186
187    // Event emitted by the gossip layer
188    #[doc(hidden)]
189    NewEvent(Event),
190}
191
192#[async_trait]
193impl<Ctx, Codec> Actor for Network<Ctx, Codec>
194where
195    Ctx: Context,
196    Codec: Send + Sync + 'static,
197    Codec: codec::Codec<Ctx::ProposalPart>,
198    Codec: codec::Codec<SignedConsensusMsg<Ctx>>,
199    Codec: codec::Codec<StreamMessage<Ctx::ProposalPart>>,
200    Codec: codec::Codec<LivenessMsg<Ctx>>,
201    Codec: SyncCodec<Ctx>,
202{
203    type Msg = Msg<Ctx>;
204    type State = State<Ctx>;
205    type Arguments = Args;
206
207    async fn pre_start(
208        &self,
209        myself: ActorRef<Msg<Ctx>>,
210        args: Args,
211    ) -> Result<Self::State, ActorProcessingErr> {
212        let handle = malachitebft_network::spawn(args.identity, args.config, args.metrics).await?;
213
214        let (mut recv_handle, ctrl_handle) = handle.split();
215
216        let recv_task = tokio::spawn(async move {
217            while let Some(event) = recv_handle.recv().await {
218                if let Err(e) = myself.cast(Msg::NewEvent(event)) {
219                    error!("Actor has died, stopping network: {e:?}");
220                    break;
221                }
222            }
223        });
224
225        Ok(State::Running {
226            listen_addrs: Vec::new(),
227            peers: BTreeSet::new(),
228            output_port: OutputPort::with_capacity(128),
229            ctrl_handle: Box::new(ctrl_handle),
230            recv_task,
231            inbound_requests: HashMap::new(),
232        })
233    }
234
235    async fn post_start(
236        &self,
237        _myself: ActorRef<Msg<Ctx>>,
238        _state: &mut State<Ctx>,
239    ) -> Result<(), ActorProcessingErr> {
240        Ok(())
241    }
242
243    #[tracing::instrument(name = "network", parent = &self.span, skip_all)]
244    async fn handle(
245        &self,
246        _myself: ActorRef<Msg<Ctx>>,
247        msg: Msg<Ctx>,
248        state: &mut State<Ctx>,
249    ) -> Result<(), ActorProcessingErr> {
250        // We need to handle before deconstructing `state` to always reply.
251        if let Msg::DumpState(reply_to) = msg {
252            handle_dump_state(state, reply_to).await;
253            return Ok(());
254        }
255
256        if let Msg::UpdatePersistentPeers(op, reply_to) = msg {
257            handle_update_persistent_peers(state, op, reply_to).await;
258            return Ok(());
259        }
260
261        let State::Running {
262            listen_addrs,
263            peers,
264            output_port,
265            ctrl_handle,
266            inbound_requests,
267            ..
268        } = state
269        else {
270            return Ok(());
271        };
272
273        match msg {
274            Msg::Subscribe(subscriber) => {
275                for addr in listen_addrs.iter() {
276                    subscriber.send(NetworkEvent::Listening(addr.clone()));
277                }
278
279                for peer in peers.iter() {
280                    subscriber.send(NetworkEvent::PeerConnected(*peer));
281                }
282
283                subscriber.subscribe_to_port(output_port);
284            }
285
286            Msg::PublishConsensusMsg(msg) => match self.codec.encode(&msg) {
287                Ok(data) => ctrl_handle.publish(Channel::Consensus, data).await?,
288                Err(e) => error!("Failed to encode consensus message: {e:?}"),
289            },
290
291            Msg::PublishLivenessMsg(msg) => match self.codec.encode(&msg) {
292                Ok(data) => ctrl_handle.publish(Channel::Liveness, data).await?,
293                Err(e) => error!("Failed to encode liveness message: {e:?}"),
294            },
295
296            Msg::PublishProposalPart(msg) => {
297                trace!(
298                    stream_id = %msg.stream_id,
299                    sequence = %msg.sequence,
300                    "Broadcasting proposal part"
301                );
302
303                let data = self.codec.encode(&msg);
304                match data {
305                    Ok(data) => ctrl_handle.publish(Channel::ProposalParts, data).await?,
306                    Err(e) => error!("Failed to encode proposal part: {e:?}"),
307                }
308            }
309
310            Msg::BroadcastStatus(status) => {
311                let status = sync::Status {
312                    peer_id: ctrl_handle.peer_id(),
313                    tip_height: status.tip_height,
314                    history_min_height: status.history_min_height,
315                };
316
317                let data = self.codec.encode(&status);
318                match data {
319                    Ok(data) => ctrl_handle.broadcast(Channel::Sync, data).await?,
320                    Err(e) => error!("Failed to encode status message: {e:?}"),
321                }
322            }
323
324            Msg::OutgoingRequest(peer_id, request, reply_to) => {
325                let request = self.codec.encode(&request);
326
327                match request {
328                    Ok(data) => {
329                        let p2p_request_id = ctrl_handle.sync_request(peer_id, data).await?;
330                        reply_to.send(OutboundRequestId::new(p2p_request_id))?;
331                    }
332                    Err(e) => error!("Failed to encode request message: {e:?}"),
333                }
334            }
335
336            Msg::OutgoingResponse(request_id, response) => {
337                let response = self.codec.encode(&response);
338
339                match response {
340                    Ok(data) => {
341                        let request_id = inbound_requests
342                            .remove(&request_id)
343                            .ok_or_else(|| eyre!("Unknown inbound request ID: {request_id}"))?;
344
345                        ctrl_handle.sync_reply(request_id, data).await?
346                    }
347                    Err(e) => {
348                        error!(%request_id, "Failed to encode response message: {e:?}");
349                        return Ok(());
350                    }
351                };
352            }
353
354            Msg::NewEvent(Event::Listening(addr)) => {
355                listen_addrs.push(addr.clone());
356                output_port.send(NetworkEvent::Listening(addr));
357            }
358
359            Msg::NewEvent(Event::PeerConnected(peer_id)) => {
360                peers.insert(peer_id);
361                output_port.send(NetworkEvent::PeerConnected(peer_id));
362            }
363
364            Msg::NewEvent(Event::PeerDisconnected(peer_id)) => {
365                peers.remove(&peer_id);
366                output_port.send(NetworkEvent::PeerDisconnected(peer_id));
367            }
368
369            Msg::NewEvent(Event::LivenessMessage(Channel::Liveness, from, data)) => {
370                let msg = match self.codec.decode(data) {
371                    Ok(msg) => msg,
372                    Err(e) => {
373                        error!(%from, "Failed to decode liveness message: {e:?}");
374                        return Ok(());
375                    }
376                };
377
378                let event = match msg {
379                    LivenessMsg::PolkaCertificate(polka_cert) => {
380                        NetworkEvent::PolkaCertificate(from, polka_cert)
381                    }
382                    LivenessMsg::SkipRoundCertificate(round_cert) => {
383                        NetworkEvent::RoundCertificate(from, round_cert)
384                    }
385                    LivenessMsg::Vote(vote) => NetworkEvent::Vote(from, vote),
386                };
387
388                output_port.send(event);
389            }
390
391            Msg::NewEvent(Event::LivenessMessage(channel, from, _)) => {
392                error!(%from, "Unexpected liveness message on {channel} channel");
393                return Ok(());
394            }
395
396            Msg::NewEvent(Event::ConsensusMessage(Channel::Consensus, from, data)) => {
397                let msg = match self.codec.decode(data) {
398                    Ok(msg) => msg,
399                    Err(e) => {
400                        error!(%from, "Failed to decode consensus message: {e:?}");
401                        return Ok(());
402                    }
403                };
404
405                let event = match msg {
406                    SignedConsensusMsg::Vote(vote) => NetworkEvent::Vote(from, vote),
407                    SignedConsensusMsg::Proposal(proposal) => {
408                        NetworkEvent::Proposal(from, proposal)
409                    }
410                };
411
412                output_port.send(event);
413            }
414
415            Msg::NewEvent(Event::ConsensusMessage(Channel::ProposalParts, from, data)) => {
416                let msg: StreamMessage<Ctx::ProposalPart> = match self.codec.decode(data) {
417                    Ok(stream_msg) => stream_msg,
418                    Err(e) => {
419                        error!(%from, "Failed to decode stream message: {e:?}");
420                        return Ok(());
421                    }
422                };
423
424                trace!(
425                    %from,
426                    stream_id = %msg.stream_id,
427                    sequence = %msg.sequence,
428                    "Received proposal part"
429                );
430
431                output_port.send(NetworkEvent::ProposalPart(from, msg));
432            }
433
434            Msg::NewEvent(Event::ConsensusMessage(Channel::Sync, from, data)) => {
435                let status: sync::Status<Ctx> = match self.codec.decode(data) {
436                    Ok(status) => status,
437                    Err(e) => {
438                        error!(%from, "Failed to decode status message: {e:?}");
439                        return Ok(());
440                    }
441                };
442
443                if from != status.peer_id {
444                    error!(%from, %status.peer_id, "Mismatched peer ID in status message");
445                    return Ok(());
446                }
447
448                trace!(%from, tip_height = %status.tip_height, "Received status");
449
450                output_port.send(NetworkEvent::Status(
451                    status.peer_id,
452                    Status::new(status.tip_height, status.history_min_height),
453                ));
454            }
455
456            Msg::NewEvent(Event::ConsensusMessage(channel, from, _)) => {
457                error!(%from, "Unexpected consensus message on {channel} channel");
458                return Ok(());
459            }
460
461            Msg::NewEvent(Event::Sync(raw_msg)) => match raw_msg {
462                RawMessage::Request {
463                    request_id,
464                    peer,
465                    body,
466                } => {
467                    let request = match self.codec.decode(body) {
468                        Ok(request) => request,
469                        Err(e) => {
470                            error!(%peer, "Failed to decode sync request: {e:?}");
471                            return Ok(());
472                        }
473                    };
474
475                    inbound_requests.insert(InboundRequestId::new(request_id), request_id);
476
477                    output_port.send(NetworkEvent::SyncRequest(
478                        InboundRequestId::new(request_id),
479                        peer,
480                        request,
481                    ));
482                }
483
484                RawMessage::Response {
485                    request_id,
486                    peer,
487                    body,
488                } => {
489                    let response = match self.codec.decode(body) {
490                        Ok(response) => Some(response),
491                        Err(e) => {
492                            error!(%peer, "Failed to decode sync response: {e:?}");
493                            None
494                        }
495                    };
496
497                    output_port.send(NetworkEvent::SyncResponse(
498                        OutboundRequestId::new(request_id),
499                        peer,
500                        response,
501                    ));
502                }
503            },
504
505            Msg::UpdateValidatorSet(validator_set) => {
506                info!(
507                    "Updating validator set: {} validators",
508                    validator_set.count()
509                );
510                // Convert ValidatorSet to Vec<ValidatorInfo>
511                // Note: We don't pass the Ctx to the network layer
512                let validators: Vec<_> = validator_set
513                    .iter()
514                    .map(|v| malachitebft_network::ValidatorInfo {
515                        address: v.address().to_string(),
516                        voting_power: v.voting_power(),
517                    })
518                    .collect();
519                ctrl_handle.update_validator_set(validators).await?;
520            }
521
522            Msg::DumpState(_) => unreachable!("DumpState handled above to ensure a reply"),
523            Msg::UpdatePersistentPeers(_, _) => {
524                unreachable!("UpdatePersistentPeers handled above to ensure a reply")
525            }
526        }
527
528        Ok(())
529    }
530
531    async fn post_stop(
532        &self,
533        _myself: ActorRef<Msg<Ctx>>,
534        state: &mut State<Ctx>,
535    ) -> Result<(), ActorProcessingErr> {
536        let state = std::mem::replace(state, State::Stopped);
537
538        if let State::Running {
539            ctrl_handle,
540            recv_task,
541            ..
542        } = state
543        {
544            ctrl_handle.wait_shutdown().await?;
545            recv_task.await?;
546        }
547
548        Ok(())
549    }
550}
551
552async fn handle_dump_state<Ctx>(
553    state: &mut State<Ctx>,
554    reply_to: RpcReplyPort<Option<NetworkStateDump>>,
555) where
556    Ctx: Context,
557{
558    let dump = match state {
559        State::Stopped => {
560            info!("Dumping network state: not started");
561            None
562        }
563        State::Running { ctrl_handle, .. } => match ctrl_handle.dump_state().await {
564            Ok(snapshot) => Some(snapshot),
565            Err(error) => {
566                error!(%error, "Failed to obtain network dump");
567                None
568            }
569        },
570    };
571
572    if let Err(error) = reply_to.send(dump) {
573        error!(%error, "Failed to reply with network state dump");
574    }
575}
576
577async fn handle_update_persistent_peers<Ctx>(
578    state: &mut State<Ctx>,
579    op: PersistentPeersOp,
580    reply_to: RpcReplyPort<Result<(), PersistentPeerError>>,
581) where
582    Ctx: Context,
583{
584    fn log_result(result: &Result<(), PersistentPeerError>, op: &PersistentPeersOp) {
585        match result {
586            Ok(_) => match op {
587                PersistentPeersOp::Add(addr) => {
588                    info!("Successfully added persistent peer: {addr}");
589                }
590                PersistentPeersOp::Remove(addr) => {
591                    info!("Successfully removed persistent peer: {addr}");
592                }
593            },
594            Err(error) => {
595                error!(%error, "Failed to update persistent peers");
596            }
597        }
598    }
599
600    let result = match state {
601        State::Stopped => {
602            warn!("Cannot update persistent peers: network not started");
603            Err(PersistentPeerError::NetworkStopped)
604        }
605        State::Running { ctrl_handle, .. } => {
606            let op_result = match &op {
607                PersistentPeersOp::Add(addr) => ctrl_handle.add_persistent_peer(addr.clone()).await,
608                PersistentPeersOp::Remove(addr) => {
609                    ctrl_handle.remove_persistent_peer(addr.clone()).await
610                }
611            };
612
613            op_result
614                .inspect(|res| log_result(res, &op))
615                .unwrap_or_else(|error| {
616                    error!(%error, "Internal error: failed to update persistent peers");
617                    Err(PersistentPeerError::InternalError(error.to_string()))
618                })
619        }
620    };
621
622    if let Err(error) = reply_to.send(result) {
623        error!(%error, "Failed to reply to UpdatePersistentPeers");
624    }
625}