commonware_p2p/simulated/
network.rs

1//! Implementation of a simulated p2p network.
2
3use super::{
4    ingress::{self, Oracle},
5    metrics,
6    transmitter::{self, Completion},
7    Error,
8};
9use crate::{Channel, Message, Recipients};
10use bytes::Bytes;
11use commonware_codec::{DecodeExt, FixedSize};
12use commonware_cryptography::PublicKey;
13use commonware_macros::select;
14use commonware_runtime::{
15    spawn_cell, Clock, ContextCell, Handle, Listener as _, Metrics, Network as RNetwork, Spawner,
16};
17use commonware_stream::utils::codec::{recv_frame, send_frame};
18use commonware_utils::set::Ordered;
19use either::Either;
20use futures::{
21    channel::{mpsc, oneshot},
22    future, SinkExt, StreamExt,
23};
24use prometheus_client::metrics::{counter::Counter, family::Family};
25use rand::Rng;
26use rand_distr::{Distribution, Normal};
27use std::{
28    collections::{BTreeMap, HashMap, HashSet},
29    net::{IpAddr, Ipv4Addr, SocketAddr},
30    time::{Duration, SystemTime},
31};
32use tracing::{debug, error, trace, warn};
33
/// Task type representing a message to be sent within the network.
///
/// Fields, in order: the channel the message is sent on, the sending peer, the intended
/// recipients, the message payload, and a one-shot sender that receives the list of peers the
/// message was actually enqueued for.
type Task<P> = (Channel, P, Recipients<P>, Bytes, oneshot::Sender<Vec<P>>);
36
/// Configuration for the simulated network.
pub struct Config {
    /// Maximum size of a message that can be sent over the network.
    ///
    /// Senders reject larger messages with [Error::MessageTooLarge].
    pub max_size: usize,

    /// True if peers should disconnect upon being blocked. While production networking would
    /// typically disconnect, for testing purposes it may be useful to keep peers connected,
    /// allowing byzantine actors the ability to continue sending messages.
    ///
    /// Blocks are recorded (and reported by the oracle) regardless of this flag; it only
    /// controls whether messages between a blocked pair, in either direction, are dropped.
    pub disconnect_on_block: bool,

    /// The maximum number of peer sets to track. When a new peer set is registered and this
    /// limit is exceeded, the oldest peer set is removed. Peers that are no longer in any
    /// tracked peer set will have their links removed and messages to them will be dropped.
    ///
    /// If [None], peer sets are not considered and attempts to register one are ignored
    /// (with a warning).
    pub tracked_peer_sets: Option<usize>,
}
54
/// Implementation of a simulated network.
///
/// The topology (peers, links, bandwidth limits, blocks) is managed through the [Oracle]
/// returned by [Network::new]; no messages are sent until [Network::start] is called.
pub struct Network<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> {
    // Runtime context, held in a `ContextCell` so `start` can move it into the spawned run loop
    context: ContextCell<E>,

    // Maximum size of a message that can be sent over the network
    // (handed to every sender, peer and link this network creates)
    max_size: usize,

    // True if peers should disconnect upon being blocked.
    // While production networking would typically disconnect, for testing purposes it may be useful
    // to keep peers connected, allowing byzantine actors the ability to continue sending messages.
    disconnect_on_block: bool,

    // Next socket address to assign to a new peer
    // Incremented for each new peer (see `get_next_socket`)
    next_addr: SocketAddr,

    // Channel to receive messages from the oracle
    // The run loop exits once this channel is closed
    ingress: mpsc::UnboundedReceiver<ingress::Message<P>>,

    // A channel to receive tasks from peers
    // The sender is cloned and given to each peer
    // The receiver is polled in the main loop
    sender: mpsc::UnboundedSender<Task<P>>,
    receiver: mpsc::UnboundedReceiver<Task<P>>,

    // A map from a pair of public keys (from, to) to a link between the two peers
    // Links are directional: (a, b) does not imply (b, a)
    links: HashMap<(P, P), Link>,

    // A map from a public key to a peer
    peers: BTreeMap<P, Peer<P>>,

    // Peer sets indexed by their ID (IDs are registered in strictly increasing order)
    peer_sets: BTreeMap<u64, Ordered<P>>,

    // Reference count for each peer (number of peer sets they belong to)
    // A peer is removed from this map when its count drops to zero
    peer_refs: BTreeMap<P, usize>,

    // Maximum number of peer sets to track
    tracked_peer_sets: Option<usize>,

    // (from, to) pairs recorded via the oracle's block message
    // Only consulted when `disconnect_on_block` is set, and then checked in both directions
    blocks: HashSet<(P, P)>,

    // State of the transmitter: queues outgoing messages and yields completions,
    // honouring per-peer bandwidth limits set via the oracle
    transmitter: transmitter::State<P>,

    // Subscribers to peer set updates
    // Each notification is (peer set ID, peers in that set, all currently tracked peers);
    // subscribers whose receiver has been dropped are pruned on the next update
    // The nested tuple type trips clippy's type-complexity lint, hence the allow
    #[allow(clippy::type_complexity)]
    subscribers: Vec<mpsc::UnboundedSender<(u64, Ordered<P>, Ordered<P>)>>,

    // Metrics for received and sent messages, labelled per (origin, recipient, channel)
    received_messages: Family<metrics::Message, Counter>,
    sent_messages: Family<metrics::Message, Counter>,
}
109
110impl<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> Network<E, P> {
111    /// Create a new simulated network with a given runtime and configuration.
112    ///
113    /// Returns a tuple containing the network instance and the oracle that can
114    /// be used to modify the state of the network during context.
115    pub fn new(mut context: E, cfg: Config) -> (Self, Oracle<P>) {
116        let (sender, receiver) = mpsc::unbounded();
117        let (oracle_sender, oracle_receiver) = mpsc::unbounded();
118        let sent_messages = Family::<metrics::Message, Counter>::default();
119        let received_messages = Family::<metrics::Message, Counter>::default();
120        context.register("messages_sent", "messages sent", sent_messages.clone());
121        context.register(
122            "messages_received",
123            "messages received",
124            received_messages.clone(),
125        );
126
127        // Start with a pseudo-random IP address to assign sockets to for new peers
128        let next_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(context.next_u32())), 0);
129        (
130            Self {
131                context: ContextCell::new(context),
132                max_size: cfg.max_size,
133                disconnect_on_block: cfg.disconnect_on_block,
134                tracked_peer_sets: cfg.tracked_peer_sets,
135                next_addr,
136                ingress: oracle_receiver,
137                sender,
138                receiver,
139                links: HashMap::new(),
140                peers: BTreeMap::new(),
141                peer_sets: BTreeMap::new(),
142                peer_refs: BTreeMap::new(),
143                blocks: HashSet::new(),
144                transmitter: transmitter::State::new(),
145                subscribers: Vec::new(),
146                received_messages,
147                sent_messages,
148            },
149            Oracle::new(oracle_sender.clone()),
150        )
151    }
152
153    /// Returns (and increments) the next available socket address.
154    ///
155    /// The port number is incremented for each call, and the IP address is incremented if the port
156    /// number overflows.
157    fn get_next_socket(&mut self) -> SocketAddr {
158        let result = self.next_addr;
159
160        // Increment the port number, or the IP address if the port number overflows.
161        // Allows the ip address to overflow (wrapping).
162        match self.next_addr.port().checked_add(1) {
163            Some(port) => {
164                self.next_addr.set_port(port);
165            }
166            None => {
167                let ip = match self.next_addr.ip() {
168                    IpAddr::V4(ipv4) => ipv4,
169                    _ => unreachable!(),
170                };
171                let next_ip = Ipv4Addr::to_bits(ip).wrapping_add(1);
172                self.next_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(next_ip)), 0);
173            }
174        }
175
176        result
177    }
178
179    /// Handle an ingress message.
180    ///
181    /// This method is called when a message is received from the oracle.
182    async fn handle_ingress(&mut self, message: ingress::Message<P>) {
183        // It is important to ensure that no failed receipt of a message will cause us to exit.
184        // This could happen if the caller drops the `Oracle` after updating the network topology.
185        // Thus, we create a helper function to send the result to the oracle and log any errors.
186        fn send_result<T: std::fmt::Debug>(
187            result: oneshot::Sender<Result<T, Error>>,
188            value: Result<T, Error>,
189        ) {
190            let success = value.is_ok();
191            if let Err(e) = result.send(value) {
192                error!(?e, "failed to send result to oracle (ok = {})", success);
193            }
194        }
195
196        match message {
197            ingress::Message::Update { peer_set, peers } => {
198                let Some(tracked_peer_sets) = self.tracked_peer_sets else {
199                    warn!("attempted to register peer set when tracking is disabled");
200                    return;
201                };
202
203                // Check if peer set already exists
204                if self.peer_sets.contains_key(&peer_set) {
205                    warn!(index = peer_set, "peer set already exists");
206                    return;
207                }
208
209                // Ensure that peer set is monotonically increasing
210                if let Some((last, _)) = self.peer_sets.last_key_value() {
211                    if peer_set <= *last {
212                        warn!(
213                            new_id = peer_set,
214                            old_id = last,
215                            "attempted to register peer set with non-monotonically increasing ID"
216                        );
217                        return;
218                    }
219                }
220
221                // Create and store new peer set
222                for public_key in peers.iter() {
223                    // Create peer if it doesn't exist
224                    if !self.peers.contains_key(public_key) {
225                        let peer = Peer::new(
226                            self.context.with_label("peer"),
227                            public_key.clone(),
228                            self.get_next_socket(),
229                            self.max_size,
230                        );
231                        self.peers.insert(public_key.clone(), peer);
232                    }
233
234                    // Increment reference count
235                    *self.peer_refs.entry(public_key.clone()).or_insert(0) += 1;
236                }
237                self.peer_sets.insert(peer_set, peers.clone());
238
239                // Remove oldest peer set if we exceed the limit
240                while self.peer_sets.len() > tracked_peer_sets {
241                    let (index, set) = self.peer_sets.pop_first().unwrap();
242                    debug!(index, "removed oldest peer set");
243
244                    // Decrement reference counts and clean up peers/links
245                    for public_key in set.iter() {
246                        let refs = self.peer_refs.get_mut(public_key).unwrap();
247                        *refs = refs.checked_sub(1).expect("reference count underflow");
248
249                        // If peer is no longer in any tracked set, remove it
250                        if *refs == 0 {
251                            self.peer_refs.remove(public_key);
252                            self.peers.remove(public_key);
253
254                            debug!(?public_key, "removed peer no longer in any tracked set");
255                        }
256                    }
257                }
258
259                // Notify all subscribers about the new peer set
260                let all = self.peer_refs.keys().cloned().collect();
261                let notification = (peer_set, peers, all);
262                self.subscribers
263                    .retain(|subscriber| subscriber.unbounded_send(notification.clone()).is_ok());
264            }
265            ingress::Message::Register {
266                channel,
267                public_key,
268                result,
269            } => {
270                // If peer does not exist, then create it.
271                if !self.peers.contains_key(&public_key) {
272                    let peer = Peer::new(
273                        self.context.with_label("peer"),
274                        public_key.clone(),
275                        self.get_next_socket(),
276                        self.max_size,
277                    );
278                    self.peers.insert(public_key.clone(), peer);
279                }
280
281                // Create a receiver that allows receiving messages from the network for a certain channel
282                let peer = self.peers.get_mut(&public_key).unwrap();
283                let receiver = match peer.register(channel).await {
284                    Ok(receiver) => Receiver { receiver },
285                    Err(err) => return send_result(result, Err(err)),
286                };
287
288                // Create a sender that allows sending messages to the network for a certain channel
289                let sender = Sender::new(
290                    self.context.with_label("sender"),
291                    public_key,
292                    channel,
293                    self.max_size,
294                    self.sender.clone(),
295                );
296                send_result(result, Ok((sender, receiver)))
297            }
298            ingress::Message::PeerSet { index, response } => {
299                if self.peer_sets.is_empty() {
300                    // Return all peers if no peer sets are registered.
301                    let _ = response.send(Some(self.peers.keys().cloned().collect()));
302                } else {
303                    // Return the peer set at the given index
304                    let _ = response.send(self.peer_sets.get(&index).cloned());
305                }
306            }
307            ingress::Message::Subscribe { response } => {
308                // Create a new subscription channel
309                let (sender, receiver) = mpsc::unbounded();
310
311                // Send the latest peer set upon subscription
312                if let Some((index, peers)) = self.peer_sets.last_key_value() {
313                    let all = self.peer_refs.keys().cloned().collect();
314                    let notification = (*index, peers.clone(), all);
315                    let _ = sender.unbounded_send(notification);
316                }
317                self.subscribers.push(sender);
318
319                // Return the receiver to the caller
320                let _ = response.send(receiver);
321            }
322            ingress::Message::LimitBandwidth {
323                public_key,
324                egress_cap,
325                ingress_cap,
326                result,
327            } => match self.peers.contains_key(&public_key) {
328                true => {
329                    // Update bandwidth limits
330                    let now = self.context.current();
331                    let completions =
332                        self.transmitter
333                            .limit(now, &public_key, egress_cap, ingress_cap);
334                    self.process_completions(completions);
335
336                    // Alert application of update
337                    send_result(result, Ok(()));
338                }
339                false => send_result(result, Err(Error::PeerMissing)),
340            },
341            ingress::Message::AddLink {
342                sender,
343                receiver,
344                sampler,
345                success_rate,
346                result,
347            } => {
348                // Require both peers to be registered
349                if !self.peers.contains_key(&sender) {
350                    return send_result(result, Err(Error::PeerMissing));
351                }
352                let peer = match self.peers.get(&receiver) {
353                    Some(peer) => peer,
354                    None => return send_result(result, Err(Error::PeerMissing)),
355                };
356
357                // Require link to not already exist
358                let key = (sender.clone(), receiver.clone());
359                if self.links.contains_key(&key) {
360                    return send_result(result, Err(Error::LinkExists));
361                }
362
363                let link = Link::new(
364                    &mut self.context,
365                    sender,
366                    receiver,
367                    peer.socket,
368                    sampler,
369                    success_rate,
370                    self.max_size,
371                    self.received_messages.clone(),
372                );
373                self.links.insert(key, link);
374                send_result(result, Ok(()))
375            }
376            ingress::Message::RemoveLink {
377                sender,
378                receiver,
379                result,
380            } => {
381                match self.links.remove(&(sender, receiver)) {
382                    Some(_) => (),
383                    None => return send_result(result, Err(Error::LinkMissing)),
384                }
385                send_result(result, Ok(()))
386            }
387            ingress::Message::Block { from, to } => {
388                self.blocks.insert((from, to));
389            }
390            ingress::Message::Blocked { result } => {
391                send_result(result, Ok(self.blocks.iter().cloned().collect()))
392            }
393        }
394    }
395}
396
397impl<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> Network<E, P> {
398    /// Process completions from the transmitter.
399    fn process_completions(&mut self, completions: Vec<Completion<P>>) {
400        for completion in completions {
401            // If there is no message to deliver, then skip
402            let Some(deliver_at) = completion.deliver_at else {
403                trace!(
404                    origin = ?completion.origin,
405                    recipient = ?completion.recipient,
406                    "message dropped before delivery",
407                );
408                continue;
409            };
410
411            // Send message to link
412            let key = (completion.origin.clone(), completion.recipient.clone());
413            let Some(link) = self.links.get_mut(&key) else {
414                // This can happen if the link is removed before the message is delivered
415                trace!(
416                    origin = ?completion.origin,
417                    recipient = ?completion.recipient,
418                    "missing link for completion",
419                );
420                continue;
421            };
422            if let Err(err) = link.send(completion.channel, completion.message, deliver_at) {
423                error!(?err, "failed to send");
424            }
425        }
426    }
427
428    /// Handle a task.
429    ///
430    /// This method is called when a task is received from the sender, which can come from
431    /// any peer in the network.
432    fn handle_task(&mut self, task: Task<P>) {
433        // If peer sets are enabled and we are not in one, ignore the message (we are disconnected from all)
434        let (channel, origin, recipients, message, reply) = task;
435        if self.tracked_peer_sets.is_some() && !self.peer_refs.contains_key(&origin) {
436            warn!(
437                ?origin,
438                reason = "not in tracked peer set",
439                "dropping message"
440            );
441            return;
442        }
443
444        // Collect recipients
445        let recipients = match recipients {
446            Recipients::All => {
447                // If peer sets have been registered, send only to tracked peers
448                // Otherwise, send to all registered peers (compatibility
449                // with tests that do not register peer sets.)
450                if self.peer_sets.is_empty() {
451                    self.peers.keys().cloned().collect()
452                } else {
453                    self.peer_refs.keys().cloned().collect()
454                }
455            }
456            Recipients::Some(keys) => keys,
457            Recipients::One(key) => vec![key],
458        };
459
460        // Send to all recipients
461        let now = self.context.current();
462        let mut sent = Vec::new();
463        for recipient in recipients {
464            // Skip self
465            if recipient == origin {
466                trace!(?recipient, reason = "self", "dropping message");
467                continue;
468            }
469
470            // If tracking peer sets, ensure recipient and sender are in a tracked peer set
471            if self.tracked_peer_sets.is_some() && !self.peer_refs.contains_key(&recipient) {
472                trace!(
473                    ?origin,
474                    ?recipient,
475                    reason = "not in tracked peer set",
476                    "dropping message"
477                );
478                continue;
479            }
480
481            // Determine if the sender or recipient has blocked the other
482            let o_r = (origin.clone(), recipient.clone());
483            let r_o = (recipient.clone(), origin.clone());
484            if self.disconnect_on_block
485                && (self.blocks.contains(&o_r) || self.blocks.contains(&r_o))
486            {
487                trace!(?origin, ?recipient, reason = "blocked", "dropping message");
488                continue;
489            }
490
491            // Determine if there is a link between the origin and recipient
492            let Some(link) = self.links.get_mut(&o_r) else {
493                trace!(?origin, ?recipient, reason = "no link", "dropping message");
494                continue;
495            };
496
497            // Record sent message as soon as we determine there is a link with recipient (approximates
498            // having an open connection)
499            self.sent_messages
500                .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
501                .inc();
502
503            // Sample latency
504            let latency = Duration::from_millis(link.sampler.sample(&mut self.context) as u64);
505
506            // Determine if the message should be delivered
507            let should_deliver = self.context.gen_bool(link.success_rate);
508
509            // Enqueue message for delivery
510            let completions = self.transmitter.enqueue(
511                now,
512                origin.clone(),
513                recipient.clone(),
514                channel,
515                message.clone(),
516                latency,
517                should_deliver,
518            );
519            self.process_completions(completions);
520
521            sent.push(recipient);
522        }
523
524        // Alert application of sent messages
525        if let Err(err) = reply.send(sent) {
526            error!(?err, "failed to send ack");
527        }
528    }
529
530    /// Run the simulated network.
531    ///
532    /// It is not necessary to invoke this method before modifying the network topology, however,
533    /// no messages will be sent until this method is called.
534    pub fn start(mut self) -> Handle<()> {
535        spawn_cell!(self.context, self.run().await)
536    }
537
538    async fn run(mut self) {
539        loop {
540            let tick = match self.transmitter.next() {
541                Some(when) => Either::Left(self.context.sleep_until(when)),
542                None => Either::Right(future::pending()),
543            };
544            select! {
545                _ = tick => {
546                    let now = self.context.current();
547                    let completions = self.transmitter.advance(now);
548                    self.process_completions(completions);
549                },
550                message = self.ingress.next() => {
551                    // If ingress is closed, exit
552                    let message = match message {
553                        Some(message) => message,
554                        None => break,
555                    };
556                    self.handle_ingress(message).await;
557                },
558                task = self.receiver.next() => {
559                    // If receiver is closed, exit
560                    let task = match task {
561                        Some(task) => task,
562                        None => break,
563                    };
564                    self.handle_task(task);
565                },
566            }
567        }
568    }
569}
570
/// Implementation of a [crate::Sender] for the simulated network.
///
/// Clones share the same high- and low-priority queues, which are drained by a single
/// forwarding task spawned when the sender is created.
#[derive(Clone, Debug)]
pub struct Sender<P: PublicKey> {
    // Public key of the peer this sender sends on behalf of
    me: P,
    // Channel every message from this sender is tagged with
    channel: Channel,
    // Largest message (in bytes) accepted by `send`
    max_size: usize,
    // Queue used for messages sent with `priority = true`
    high: mpsc::UnboundedSender<Task<P>>,
    // Queue used for messages sent with `priority = false`
    low: mpsc::UnboundedSender<Task<P>>,
}
580
581impl<P: PublicKey> Sender<P> {
582    fn new(
583        context: impl Spawner + Metrics,
584        me: P,
585        channel: Channel,
586        max_size: usize,
587        mut sender: mpsc::UnboundedSender<Task<P>>,
588    ) -> Self {
589        // Listen for messages
590        let (high, mut high_receiver) = mpsc::unbounded();
591        let (low, mut low_receiver) = mpsc::unbounded();
592        context.with_label("sender").spawn(move |_| async move {
593            loop {
594                // Wait for task
595                let task;
596                select! {
597                    high_task = high_receiver.next() => {
598                        task = match high_task {
599                            Some(task) => task,
600                            None => break,
601                        };
602                    },
603                    low_task = low_receiver.next() => {
604                        task = match low_task {
605                            Some(task) => task,
606                            None => break,
607                        };
608                    }
609                }
610
611                // Send task
612                if let Err(err) = sender.send(task).await {
613                    error!(?err, channel, "failed to send task");
614                }
615            }
616        });
617
618        // Return sender
619        Self {
620            me,
621            channel,
622            max_size,
623            high,
624            low,
625        }
626    }
627}
628
629impl<P: PublicKey> crate::Sender for Sender<P> {
630    type Error = Error;
631    type PublicKey = P;
632
633    async fn send(
634        &mut self,
635        recipients: Recipients<P>,
636        message: Bytes,
637        priority: bool,
638    ) -> Result<Vec<P>, Error> {
639        // Check message size
640        if message.len() > self.max_size {
641            return Err(Error::MessageTooLarge(message.len()));
642        }
643
644        // Send message
645        let (sender, receiver) = oneshot::channel();
646        let mut channel = if priority { &self.high } else { &self.low };
647        channel
648            .send((self.channel, self.me.clone(), recipients, message, sender))
649            .await
650            .map_err(|_| Error::NetworkClosed)?;
651        receiver.await.map_err(|_| Error::NetworkClosed)
652    }
653}
654
/// Per-channel mailbox from which a registered [Receiver] reads incoming messages.
type MessageReceiver<P> = mpsc::UnboundedReceiver<Message<P>>;
/// Outcome of registering a channel with a peer: the channel's mailbox, or an error such as
/// [Error::ChannelAlreadyRegistered].
type MessageReceiverResult<P> = Result<MessageReceiver<P>, Error>;
657
/// Implementation of a [crate::Receiver] for the simulated network.
///
/// Wraps the mailbox created when a channel is registered for a peer.
#[derive(Debug)]
pub struct Receiver<P: PublicKey> {
    // Mailbox for a single (peer, channel) registration
    receiver: MessageReceiver<P>,
}
663
664impl<P: PublicKey> crate::Receiver for Receiver<P> {
665    type Error = Error;
666    type PublicKey = P;
667
668    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
669        self.receiver.next().await.ok_or(Error::NetworkClosed)
670    }
671}
672
/// A peer in the simulated network.
///
/// The peer can register channels, which allows it to receive messages sent to the channel from other peers.
struct Peer<P: PublicKey> {
    // Socket address that the peer is listening on
    // (passed to every link created towards this peer)
    socket: SocketAddr,

    // Control to register new channels: each request carries the channel and a one-shot
    // sender for the resulting mailbox (or error). The peer's router task exits once this
    // control channel is closed.
    control: mpsc::UnboundedSender<(Channel, oneshot::Sender<MessageReceiverResult<P>>)>,
}
683
684impl<P: PublicKey> Peer<P> {
685    /// Create and return a new peer.
686    ///
687    /// The peer will listen for incoming connections on the given `socket` address.
688    /// `max_size` is the maximum size of a message that can be sent to the peer.
689    fn new<E: Spawner + RNetwork + Metrics + Clock>(
690        context: E,
691        public_key: P,
692        socket: SocketAddr,
693        max_size: usize,
694    ) -> Self {
695        // The control is used to register channels.
696        // There is exactly one mailbox created for each channel that the peer is registered for.
697        let (control_sender, mut control_receiver) = mpsc::unbounded();
698
699        // Whenever a message is received from a peer, it is placed in the inbox.
700        // The router polls the inbox and forwards the message to the appropriate mailbox.
701        let (inbox_sender, mut inbox_receiver) = mpsc::unbounded();
702
703        // Spawn router
704        context.with_label("router").spawn(|_| async move {
705            // Map of channels to mailboxes (senders to particular channels)
706            let mut mailboxes = HashMap::new();
707
708            // Continually listen for control messages and outbound messages
709            loop {
710                select! {
711                    // Listen for control messages, which are used to register channels
712                    control = control_receiver.next() => {
713                        // If control is closed, exit
714                        let (channel, result): (Channel, oneshot::Sender<MessageReceiverResult<P>>) = match control {
715                            Some(control) => control,
716                            None => break,
717                        };
718
719                        // Check if channel is registered
720                        if mailboxes.contains_key(&channel) {
721                            result.send(Err(Error::ChannelAlreadyRegistered(channel))).unwrap();
722                            continue;
723                        }
724
725                        // Register channel
726                        let (sender, receiver) = mpsc::unbounded();
727                        mailboxes.insert(channel, sender);
728                        result.send(Ok(receiver)).unwrap();
729                    },
730
731                    // Listen for messages from the inbox, which are forwarded to the appropriate mailbox
732                    inbox = inbox_receiver.next() => {
733                        // If inbox is closed, exit
734                        let (channel, message) = match inbox {
735                            Some(message) => message,
736                            None => break,
737                        };
738
739                        // Send message to mailbox
740                        match mailboxes.get_mut(&channel) {
741                            Some(mailbox) => {
742                                if let Err(err) = mailbox.send(message).await {
743                                    error!(?err, "failed to send message to mailbox");
744                                }
745                            }
746                            None => {
747                                trace!(
748                                    recipient = ?public_key,
749                                    channel,
750                                    reason = "missing channel",
751                                    "dropping message",
752                                );
753                            }
754                        }
755                    },
756                }
757            }
758        });
759
760        // Spawn a task that accepts new connections and spawns a task for each connection
761        context.with_label("listener").spawn({
762            let inbox_sender = inbox_sender.clone();
763            move |context| async move {
764                // Initialize listener
765                let mut listener = context.bind(socket).await.unwrap();
766
767                // Continually accept new connections
768                while let Ok((_, _, mut stream)) = listener.accept().await {
769                    // New connection accepted. Spawn a task for this connection
770                    context.with_label("receiver").spawn({
771                        let mut inbox_sender = inbox_sender.clone();
772                        move |_| async move {
773                            // Receive dialer's public key as a handshake
774                            let dialer = match recv_frame(&mut stream, max_size).await {
775                                Ok(data) => data,
776                                Err(_) => {
777                                    error!("failed to receive public key from dialer");
778                                    return;
779                                }
780                            };
781                            let Ok(dialer) = P::decode(dialer.as_ref()) else {
782                                error!("received public key is invalid");
783                                return;
784                            };
785
786                            // Continually receive messages from the dialer and send them to the inbox
787                            while let Ok(data) = recv_frame(&mut stream, max_size).await {
788                                let channel = Channel::from_be_bytes(
789                                    data[..Channel::SIZE].try_into().unwrap(),
790                                );
791                                let message = data.slice(Channel::SIZE..);
792                                if let Err(err) = inbox_sender
793                                    .send((channel, (dialer.clone(), message)))
794                                    .await
795                                {
796                                    error!(?err, "failed to send message to mailbox");
797                                    break;
798                                }
799                            }
800                        }
801                    });
802                }
803            }
804        });
805
806        // Return peer
807        Self {
808            socket,
809            control: control_sender,
810        }
811    }
812
813    /// Register a channel with the peer.
814    ///
815    /// This allows the peer to receive messages sent to the channel.
816    /// Returns a receiver that can be used to receive messages sent to the channel.
817    async fn register(&mut self, channel: Channel) -> MessageReceiverResult<P> {
818        let (sender, receiver) = oneshot::channel();
819        self.control
820            .send((channel, sender))
821            .await
822            .map_err(|_| Error::NetworkClosed)?;
823        receiver.await.map_err(|_| Error::NetworkClosed)?
824    }
825}
826
827// A unidirectional link between two peers.
828// Messages can be sent over the link with a given latency, jitter, and success rate.
829struct Link {
830    sampler: Normal<f64>,
831    success_rate: f64,
832    // Messages with their receive time for ordered delivery
833    inbox: mpsc::UnboundedSender<(Channel, Bytes, SystemTime)>,
834}
835
836/// Buffered payload waiting for earlier messages on the same link to complete.
837impl Link {
838    #[allow(clippy::too_many_arguments)]
839    fn new<E: Spawner + RNetwork + Clock + Metrics, P: PublicKey>(
840        context: &mut E,
841        dialer: P,
842        receiver: P,
843        socket: SocketAddr,
844        sampler: Normal<f64>,
845        success_rate: f64,
846        max_size: usize,
847        received_messages: Family<metrics::Message, Counter>,
848    ) -> Self {
849        // Spawn a task that will wait for messages to be sent to the link and then send them
850        // over the network.
851        let (inbox, mut outbox) = mpsc::unbounded::<(Channel, Bytes, SystemTime)>();
852        context.with_label("link").spawn(move |context| async move {
853            // Dial the peer and handshake by sending it the dialer's public key
854            let (mut sink, _) = context.dial(socket).await.unwrap();
855            if let Err(err) = send_frame(&mut sink, &dialer, max_size).await {
856                error!(?err, "failed to send public key to listener");
857                return;
858            }
859
860            // Process messages in order, waiting for their receive time
861            while let Some((channel, message, receive_complete_at)) = outbox.next().await {
862                // Wait until the message should arrive at receiver
863                context.sleep_until(receive_complete_at).await;
864
865                // Send the message
866                let mut data = bytes::BytesMut::with_capacity(Channel::SIZE + message.len());
867                data.extend_from_slice(&channel.to_be_bytes());
868                data.extend_from_slice(&message);
869                let data = data.freeze();
870                send_frame(&mut sink, &data, max_size).await.unwrap();
871
872                // Bump received messages metric
873                received_messages
874                    .get_or_create(&metrics::Message::new(&dialer, &receiver, channel))
875                    .inc();
876            }
877        });
878
879        Self {
880            sampler,
881            success_rate,
882            inbox,
883        }
884    }
885
886    // Send a message over the link with receive timing.
887    fn send(
888        &mut self,
889        channel: Channel,
890        message: Bytes,
891        receive_complete_at: SystemTime,
892    ) -> Result<(), Error> {
893        self.inbox
894            .unbounded_send((channel, message, receive_complete_at))
895            .map_err(|_| Error::NetworkClosed)?;
896        Ok(())
897    }
898}
899
/// Unit tests for the simulated network, executed on the deterministic runtime.
///
/// NOTE: outcomes (including the timing assertions) depend on the deterministic
/// scheduler, so statement and spawn order inside each test matters.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Manager, Receiver as _, Recipients, Sender as _};
    use bytes::Bytes;
    use commonware_cryptography::{ed25519, PrivateKeyExt as _, Signer as _};
    use commonware_runtime::{deterministic, Runner as _};
    /// Maximum message size (1 MiB) used by every test `Config`.
    const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

    /// Registering the same channel twice for a peer fails with
    /// `Error::ChannelAlreadyRegistered`, and adding the same link twice fails with
    /// `Error::LinkExists`.
    #[test]
    fn test_register_and_link() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = Config {
                max_size: MAX_MESSAGE_SIZE,
                disconnect_on_block: true,
                tracked_peer_sets: Some(3),
            };
            // Drive the network by spawning its run loop on a "network"-labeled context
            let network_context = context.with_label("network");
            let (network, mut oracle) = Network::new(network_context.clone(), cfg);
            network_context.spawn(|_| network.run());

            // Create two public keys
            let pk1 = ed25519::PrivateKey::from_seed(1).public_key();
            let pk2 = ed25519::PrivateKey::from_seed(2).public_key();

            // Register the peer set
            oracle.update(0, [pk1.clone(), pk2.clone()].into()).await;
            // Register channels 0 and 1 for each peer
            let mut control = oracle.control(pk1.clone());
            control.register(0).await.unwrap();
            control.register(1).await.unwrap();
            let mut control = oracle.control(pk2.clone());
            control.register(0).await.unwrap();
            control.register(1).await.unwrap();

            // Expect error when registering again
            assert!(matches!(
                control.register(1).await,
                Err(Error::ChannelAlreadyRegistered(_))
            ));

            // Add link
            let link = ingress::Link {
                latency: Duration::from_millis(2),
                jitter: Duration::from_millis(1),
                success_rate: 0.9,
            };
            oracle
                .add_link(pk1.clone(), pk2.clone(), link.clone())
                .await
                .unwrap();

            // Expect error when adding link again
            assert!(matches!(
                oracle.add_link(pk1, pk2, link).await,
                Err(Error::LinkExists)
            ));
        });
    }

    /// `get_next_socket` hands out consecutive addresses, carrying port overflow
    /// (past 65535) into the IPv4 address.
    #[test]
    fn test_get_next_socket() {
        let cfg = Config {
            max_size: MAX_MESSAGE_SIZE,
            disconnect_on_block: true,
            tracked_peer_sets: None,
        };
        let runner = deterministic::Runner::default();

        runner.start(|context| async move {
            type PublicKey = ed25519::PublicKey;
            // The network is constructed but never run; only address allocation is tested
            let (mut network, _) =
                Network::<deterministic::Context, PublicKey>::new(context.clone(), cfg);

            // Test that the next socket address is incremented correctly
            let mut original = network.next_addr;
            let next = network.get_next_socket();
            assert_eq!(next, original);
            let next = network.get_next_socket();
            // `set_port(1)` encodes the assumption that `next_addr` starts at port 0
            original.set_port(1);
            assert_eq!(next, original);

            // Test that the port number overflows correctly
            // (255.0.255.255:65535 is followed by 255.1.0.0:0)
            let max_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 0, 255, 255)), 65535);
            network.next_addr = max_addr;
            let next = network.get_next_socket();
            assert_eq!(next, max_addr);
            let next = network.get_next_socket();
            assert_eq!(
                next,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 1, 0, 0)), 0)
            );
        });
    }

    /// A burst of messages from one sender to one recipient over a bandwidth-limited,
    /// zero-latency, always-successful link is received in send order.
    #[test]
    fn test_fifo_burst_same_recipient() {
        let cfg = Config {
            max_size: MAX_MESSAGE_SIZE,
            disconnect_on_block: true,
            tracked_peer_sets: Some(3),
        };
        let runner = deterministic::Runner::default();

        runner.start(|context| async move {
            let (network, mut oracle) = Network::new(context.with_label("network"), cfg);
            let network_handle = network.start();

            let sender_pk = ed25519::PrivateKey::from_seed(10).public_key();
            let recipient_pk = ed25519::PrivateKey::from_seed(11).public_key();

            oracle
                .update(0, [sender_pk.clone(), recipient_pk.clone()].into())
                .await;
            let (mut sender, _sender_recv) =
                oracle.control(sender_pk.clone()).register(0).await.unwrap();
            let (_sender2, mut receiver) = oracle
                .control(recipient_pk.clone())
                .register(0)
                .await
                .unwrap();

            // Limit the sender's first bandwidth slot and the recipient's second
            // (presumably egress and ingress, in bytes/s — confirm against `limit_bandwidth`)
            oracle
                .limit_bandwidth(sender_pk.clone(), Some(5_000), None)
                .await
                .unwrap();
            oracle
                .limit_bandwidth(recipient_pk.clone(), None, Some(5_000))
                .await
                .unwrap();

            oracle
                .add_link(
                    sender_pk.clone(),
                    recipient_pk.clone(),
                    ingress::Link {
                        latency: Duration::from_millis(0),
                        jitter: Duration::from_millis(0),
                        success_rate: 1.0,
                    },
                )
                .await
                .unwrap();

            // Send COUNT distinct 64-byte messages back-to-back, recording the send order
            const COUNT: usize = 50;
            let mut expected = Vec::with_capacity(COUNT);
            for i in 0..COUNT {
                let msg = Bytes::from(vec![i as u8; 64]);
                sender
                    .send(Recipients::One(recipient_pk.clone()), msg.clone(), false)
                    .await
                    .unwrap();
                expected.push(msg);
            }

            // Every message must arrive, in exactly the order it was sent
            for expected_msg in expected {
                let (_pk, bytes) = receiver.recv().await.unwrap();
                assert_eq!(bytes, expected_msg);
            }

            drop(oracle);
            drop(sender);
            network_handle.abort();
        });
    }

    /// Broadcasting one large message to two recipients is throttled by the sender's
    /// bandwidth: each copy arrives no earlier than 20s after sending, and the two
    /// arrivals are within 1s of each other.
    #[test]
    fn test_broadcast_respects_transmit_latency() {
        let cfg = Config {
            max_size: MAX_MESSAGE_SIZE,
            disconnect_on_block: true,
            tracked_peer_sets: Some(3),
        };
        let runner = deterministic::Runner::default();

        runner.start(|context| async move {
            let (network, mut oracle) = Network::new(context.with_label("network"), cfg);
            let network_handle = network.start();

            let sender_pk = ed25519::PrivateKey::from_seed(42).public_key();
            let recipient_a = ed25519::PrivateKey::from_seed(43).public_key();
            let recipient_b = ed25519::PrivateKey::from_seed(44).public_key();

            oracle
                .update(
                    0,
                    [sender_pk.clone(), recipient_a.clone(), recipient_b.clone()].into(),
                )
                .await;
            let (mut sender, _recv_sender) =
                oracle.control(sender_pk.clone()).register(0).await.unwrap();
            let (_sender2, mut recv_a) = oracle
                .control(recipient_a.clone())
                .register(0)
                .await
                .unwrap();
            let (_sender3, mut recv_b) = oracle
                .control(recipient_b.clone())
                .register(0)
                .await
                .unwrap();

            // Sender limited in its first bandwidth slot, recipients in their second
            // (presumably egress / ingress — confirm against `limit_bandwidth`)
            oracle
                .limit_bandwidth(sender_pk.clone(), Some(1_000), None)
                .await
                .unwrap();
            oracle
                .limit_bandwidth(recipient_a.clone(), None, Some(1_000))
                .await
                .unwrap();
            oracle
                .limit_bandwidth(recipient_b.clone(), None, Some(1_000))
                .await
                .unwrap();

            let link = ingress::Link {
                latency: Duration::from_millis(0),
                jitter: Duration::from_millis(0),
                success_rate: 1.0,
            };
            oracle
                .add_link(sender_pk.clone(), recipient_a.clone(), link.clone())
                .await
                .unwrap();
            oracle
                .add_link(sender_pk.clone(), recipient_b.clone(), link)
                .await
                .unwrap();

            // 2 copies x 10_000 bytes through a 1_000-per-second sender limit gives
            // the ~20s lower bound asserted below
            let big_msg = Bytes::from(vec![7u8; 10_000]);
            let start = context.current();
            sender
                .send(Recipients::All, big_msg.clone(), false)
                .await
                .unwrap();

            let (_pk, received_a) = recv_a.recv().await.unwrap();
            assert_eq!(received_a, big_msg);
            let elapsed_a = context.current().duration_since(start).unwrap();
            assert!(elapsed_a >= Duration::from_secs(20));

            let (_pk, received_b) = recv_b.recv().await.unwrap();
            assert_eq!(received_b, big_msg);
            let elapsed_b = context.current().duration_since(start).unwrap();
            assert!(elapsed_b >= Duration::from_secs(20));

            // Because bandwidth is shared, the two messages should take about the same time
            assert!(elapsed_a.abs_diff(elapsed_b) <= Duration::from_secs(1));

            drop(oracle);
            drop(sender);
            network_handle.abort();
        });
    }
}