1use super::{
4 bandwidth,
5 ingress::{self, Oracle},
6 metrics, Error,
7};
8use crate::{Channel, Message, Recipients};
9use bytes::Bytes;
10use commonware_codec::{DecodeExt, FixedSize};
11use commonware_cryptography::PublicKey;
12use commonware_macros::select;
13use commonware_runtime::{Clock, Handle, Listener as _, Metrics, Network as RNetwork, Spawner};
14use commonware_stream::utils::codec::{recv_frame, send_frame};
15use futures::{
16 channel::{mpsc, oneshot},
17 SinkExt, StreamExt,
18};
19use prometheus_client::metrics::{counter::Counter, family::Family};
20use rand::Rng;
21use rand_distr::{Distribution, Normal};
22use std::{
23 collections::{BTreeMap, HashMap, HashSet},
24 net::{IpAddr, Ipv4Addr, SocketAddr},
25 time::{Duration, SystemTime},
26};
27use tracing::{error, trace};
28
/// A message submitted by a [`Sender`] and waiting to be routed by the [`Network`].
///
/// Layout: `(channel, origin, recipients, payload, reply)`, where `reply` is answered with
/// the recipients the network transmitted the payload to.
type Task<P> = (Channel, P, Recipients<P>, Bytes, oneshot::Sender<Vec<P>>);
31
/// Configuration for the simulated [`Network`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Maximum size (in bytes) of a message payload that a [`Sender`] accepts.
    pub max_size: usize,
}
37
/// Simulated network that routes messages between registered peers over configurable,
/// directed links.
///
/// Configured through the [`Oracle`] returned by [`Network::new`] and driven by
/// [`Network::start`].
pub struct Network<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> {
    /// Runtime context used to spawn tasks, read the clock, and draw randomness.
    context: E,

    /// Maximum payload size (in bytes) accepted by a [`Sender`].
    max_size: usize,

    /// Socket address that will be assigned to the next newly registered peer.
    next_addr: SocketAddr,

    /// Requests issued through the [`Oracle`].
    ingress: mpsc::UnboundedReceiver<ingress::Message<P>>,

    /// Cloned into every [`Sender`] so it can submit outgoing tasks.
    sender: mpsc::UnboundedSender<Task<P>>,
    /// Outgoing tasks submitted by [`Sender`]s.
    receiver: mpsc::UnboundedReceiver<Task<P>>,

    /// Directed links keyed by `(sender, receiver)`.
    links: HashMap<(P, P), Link>,

    /// Registered peers keyed by public key.
    peers: BTreeMap<P, Peer<P>>,

    /// Blocked `(from, to)` pairs; `handle_task` drops messages if either direction is blocked.
    blocks: HashSet<(P, P)>,

    /// Messages written out by a link, labelled by origin, recipient, and channel.
    received_messages: Family<metrics::Message, Counter>,
    /// Messages sent over an existing link, labelled by origin, recipient, and channel.
    sent_messages: Family<metrics::Message, Counter>,
}
71
72impl<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> Network<E, P> {
73 pub fn new(context: E, cfg: Config) -> (Self, Oracle<P>) {
78 let (sender, receiver) = mpsc::unbounded();
79 let (oracle_sender, oracle_receiver) = mpsc::unbounded();
80 let sent_messages = Family::<metrics::Message, Counter>::default();
81 let received_messages = Family::<metrics::Message, Counter>::default();
82 context.register("messages_sent", "messages sent", sent_messages.clone());
83 context.register(
84 "messages_received",
85 "messages received",
86 received_messages.clone(),
87 );
88
89 let next_addr = SocketAddr::new(
91 IpAddr::V4(Ipv4Addr::from_bits(context.clone().next_u32())),
92 0,
93 );
94 (
95 Self {
96 context,
97 max_size: cfg.max_size,
98 next_addr,
99 ingress: oracle_receiver,
100 sender,
101 receiver,
102 links: HashMap::new(),
103 peers: BTreeMap::new(),
104 blocks: HashSet::new(),
105 received_messages,
106 sent_messages,
107 },
108 Oracle::new(oracle_sender.clone()),
109 )
110 }
111
112 fn get_next_socket(&mut self) -> SocketAddr {
117 let result = self.next_addr;
118
119 match self.next_addr.port().checked_add(1) {
122 Some(port) => {
123 self.next_addr.set_port(port);
124 }
125 None => {
126 let ip = match self.next_addr.ip() {
127 IpAddr::V4(ipv4) => ipv4,
128 _ => unreachable!(),
129 };
130 let next_ip = Ipv4Addr::to_bits(ip).wrapping_add(1);
131 self.next_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(next_ip)), 0);
132 }
133 }
134
135 result
136 }
137
138 async fn handle_ingress(&mut self, message: ingress::Message<P>) {
142 fn send_result<T: std::fmt::Debug>(
146 result: oneshot::Sender<Result<T, Error>>,
147 value: Result<T, Error>,
148 ) {
149 let success = value.is_ok();
150 if let Err(e) = result.send(value) {
151 error!(?e, "failed to send result to oracle (ok = {})", success);
152 }
153 }
154
155 match message {
156 ingress::Message::Register {
157 public_key,
158 channel,
159 result,
160 } => {
161 if !self.peers.contains_key(&public_key) {
163 let peer = Peer::new(
164 &mut self.context.clone(),
165 public_key.clone(),
166 self.get_next_socket(),
167 usize::MAX,
168 usize::MAX,
169 self.max_size,
170 );
171 self.peers.insert(public_key.clone(), peer);
172 }
173
174 let peer = self.peers.get_mut(&public_key).unwrap();
176 let receiver = match peer.register(channel).await {
177 Ok(receiver) => Receiver { receiver },
178 Err(err) => return send_result(result, Err(err)),
179 };
180
181 let sender = Sender::new(
183 self.context.clone(),
184 public_key,
185 channel,
186 self.max_size,
187 self.sender.clone(),
188 );
189 send_result(result, Ok((sender, receiver)))
190 }
191 ingress::Message::SetBandwidth {
192 public_key,
193 egress_bps,
194 ingress_bps,
195 result,
196 } => match self.peers.get_mut(&public_key) {
197 Some(peer) => {
198 peer.set_bandwidth(egress_bps, ingress_bps);
199 send_result(result, Ok(()));
200 }
201 None => send_result(result, Err(Error::PeerMissing)),
202 },
203 ingress::Message::AddLink {
204 sender,
205 receiver,
206 sampler,
207 success_rate,
208 result,
209 } => {
210 if !self.peers.contains_key(&sender) {
212 return send_result(result, Err(Error::PeerMissing));
213 }
214 let peer = match self.peers.get(&receiver) {
215 Some(peer) => peer,
216 None => return send_result(result, Err(Error::PeerMissing)),
217 };
218
219 let key = (sender.clone(), receiver.clone());
221 if self.links.contains_key(&key) {
222 return send_result(result, Err(Error::LinkExists));
223 }
224
225 let link = Link::new(
226 &mut self.context,
227 sender,
228 receiver,
229 peer.socket,
230 sampler,
231 success_rate,
232 self.max_size,
233 self.received_messages.clone(),
234 );
235 self.links.insert(key, link);
236 send_result(result, Ok(()))
237 }
238 ingress::Message::RemoveLink {
239 sender,
240 receiver,
241 result,
242 } => {
243 match self.links.remove(&(sender, receiver)) {
244 Some(_) => (),
245 None => return send_result(result, Err(Error::LinkMissing)),
246 }
247 send_result(result, Ok(()))
248 }
249 ingress::Message::Block { from, to } => {
250 self.blocks.insert((from, to));
251 }
252 ingress::Message::Blocked { result } => {
253 send_result(result, Ok(self.blocks.iter().cloned().collect()))
254 }
255 }
256 }
257}
258
impl<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> Network<E, P> {
    /// Reserves bandwidth to move `data_size` bytes from `sender` to `receiver`, starting at
    /// `now`, and returns the time at which the transmission completes.
    ///
    /// The sender's egress schedule is always consulted and charged. The receiver's ingress
    /// schedule is only consulted and charged when `should_deliver` is true and the two peers
    /// differ.
    ///
    /// # Panics
    ///
    /// Panics if `sender` (or, when its schedule is consulted, `receiver`) is not a
    /// registered peer.
    fn schedule_transmission(
        &mut self,
        sender: &P,
        receiver: &P,
        data_size: usize,
        now: SystemTime,
        should_deliver: bool,
    ) -> SystemTime {
        // Prune the sender's egress schedule at `now` and read its usage (see `bandwidth`).
        let sender_used = {
            let sender_peer = self.peers.get_mut(sender).expect("sender not found");
            sender_peer.egress.prune_and_get_usage(now)
        };

        // Only a message that will actually be delivered consumes receiver ingress.
        let receiver_used = if should_deliver && sender != receiver {
            let receiver_peer = self.peers.get_mut(receiver).expect("receiver not found");
            Some(receiver_peer.ingress.prune_and_get_usage(now))
        } else {
            None
        };

        // The mutable borrows above have ended, so both schedules can now be borrowed
        // immutably at the same time for the calculation.
        let sender_schedule = {
            let sender = self.peers.get(sender).expect("sender not found");
            (&sender.egress, sender_used)
        };

        let receiver_schedule = if let Some(used) = receiver_used {
            let receiver_peer = self.peers.get(receiver).expect("receiver not found");
            Some((&receiver_peer.ingress, used))
        } else {
            None
        };

        let (reservations, completion_time) =
            bandwidth::calculate_reservations(data_size, now, sender_schedule, receiver_schedule);

        // Commit the same reservations to the sender's egress and, if it was consulted,
        // the receiver's ingress.
        if !reservations.is_empty() {
            let sender_peer = self.peers.get_mut(sender).expect("sender not found");
            for reservation in &reservations {
                sender_peer.egress.add_reservation(
                    reservation.start,
                    reservation.end,
                    reservation.bandwidth,
                );
            }

            if receiver_used.is_some() {
                let receiver_peer = self.peers.get_mut(receiver).expect("receiver not found");
                for reservation in &reservations {
                    receiver_peer.ingress.add_reservation(
                        reservation.start,
                        reservation.end,
                        reservation.bandwidth,
                    );
                }
            }
        }

        completion_time
    }

    /// Fans a message submitted by a [`Sender`] out to its recipients.
    ///
    /// A recipient is skipped in any of these cases:
    /// - it is the origin itself;
    /// - either peer has blocked the other;
    /// - there is no link from origin to recipient;
    /// - the origin has zero egress bandwidth.
    ///
    /// Otherwise, bandwidth is reserved. Delivery happens only if the link's `success_rate`
    /// draw succeeds and the recipient has non-zero ingress bandwidth. A delivered message is
    /// queued on the link to arrive at transmission completion plus a sampled latency.
    ///
    /// `reply` is answered with every recipient the message was transmitted to, including
    /// ones whose delivery was dropped. The reply is sent only after all those transmissions
    /// have completed.
    fn handle_task(&mut self, task: Task<P>) {
        // Resolve the recipient set
        let (channel, origin, recipients, message, reply) = task;
        let recipients = match recipients {
            Recipients::All => self.peers.keys().cloned().collect(),
            Recipients::Some(keys) => keys,
            Recipients::One(key) => vec![key],
        };

        // `sent` collects transmitted-to recipients. Each transmission signals
        // `acquired_sender` once it completes, so the notifier below knows when to reply.
        let mut sent = Vec::new();
        let (acquired_sender, mut acquired_receiver) = mpsc::channel(recipients.len());
        for recipient in recipients {
            // Never send to self
            if recipient == origin {
                trace!(?recipient, reason = "self", "dropping message",);
                continue;
            }

            // A block in either direction drops the message
            let o_r = (origin.clone(), recipient.clone());
            let r_o = (recipient.clone(), origin.clone());
            if self.blocks.contains(&o_r) || self.blocks.contains(&r_o) {
                trace!(?origin, ?recipient, reason = "blocked", "dropping message");
                continue;
            }

            // A directed link from origin to recipient is required
            let link = match self.links.get(&o_r) {
                Some(link) => link,
                None => {
                    trace!(?origin, ?recipient, reason = "no link", "dropping message",);
                    continue;
                }
            };

            // NOTE(review): counted as sent as soon as a link exists, i.e. before the
            // bandwidth check below, so recipients skipped for zero sender bandwidth are
            // still counted in `messages_sent`.
            self.sent_messages
                .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
                .inc();

            // Both peers exist here: links are only created between registered peers.
            // The `gen_bool` draw happens even if the sender has no bandwidth, and before the
            // latency sample below. Reordering the draws presumably changes seeded runs.
            let (sender_has_bandwidth, should_deliver) = {
                let sender_peer = self.peers.get(&origin).expect("sender must exist");
                let receiver_peer = self.peers.get(&recipient).expect("receiver must exist");

                let sender_has_bandwidth = sender_peer.egress.bps > 0;
                let receiver_has_bandwidth = receiver_peer.ingress.bps > 0;

                let should_deliver = self.context.gen_bool(link.success_rate);

                (
                    sender_has_bandwidth,
                    should_deliver && receiver_has_bandwidth,
                )
            };

            if !sender_has_bandwidth {
                trace!(
                    ?origin,
                    ?recipient,
                    "sender has zero bandwidth, skipping recipient"
                );
                continue;
            }

            // Latency is sampled in milliseconds. The float-to-int `as` cast saturates, so
            // negative samples become zero latency.
            let latency = Duration::from_millis(link.sampler.sample(&mut self.context) as u64);
            let now = self.context.current();

            let transmission_complete_at =
                self.schedule_transmission(&origin, &recipient, message.len(), now, should_deliver);

            if should_deliver {
                // Re-fetch mutably: the shared `link` borrow ended before `schedule_transmission`.
                let link = self.links.get_mut(&o_r).unwrap();

                // The message arrives after it has been fully transmitted plus link latency.
                let receive_complete_at = transmission_complete_at + latency;

                // NOTE(review): bandwidth reserved above is not released on this error path.
                if let Err(err) = link.send(channel, message.clone(), receive_complete_at) {
                    error!(?origin, ?recipient, ?err, "failed to send");
                    continue;
                }
            }

            let transmission_duration = transmission_complete_at
                .duration_since(now)
                .unwrap_or(Duration::ZERO);
            trace!(
                ?origin,
                ?recipient,
                transmission_duration_ms = transmission_duration.as_millis(),
                latency_ms = latency.as_millis(),
                delivered = should_deliver,
                "sending message",
            );

            // Signal the notifier once the sender has finished transmitting (not when the
            // recipient receives it).
            self.context.with_label("sender-timing").spawn({
                let recipient = recipient.clone();
                let mut acquired_sender = acquired_sender.clone();
                move |context| async move {
                    context.sleep_until(transmission_complete_at).await;

                    acquired_sender.send(()).await.unwrap();

                    // Also reached when the recipient has zero ingress bandwidth.
                    if !should_deliver {
                        trace!(
                            ?recipient,
                            reason = "random link failure",
                            "dropping message",
                        );
                    }
                }
            });

            sent.push(recipient);
        }

        // Reply with the transmitted-to recipients once every transmission has completed
        self.context
            .clone()
            .with_label("notifier")
            .spawn(|_| async move {
                for _ in 0..sent.len() {
                    acquired_receiver.next().await.unwrap();
                }

                if let Err(err) = reply.send(sent) {
                    error!(?err, "failed to send ack");
                }
            });
    }

    /// Spawns the network's event loop on its own context and returns its handle.
    pub fn start(mut self) -> Handle<()> {
        self.context.spawn_ref()(self.run())
    }

    /// Processes oracle requests and outgoing tasks until one of their channels closes.
    ///
    /// The network itself owns `self.sender`, so the task channel cannot report closed while
    /// this loop runs. In practice the loop ends when the oracle channel closes.
    async fn run(mut self) {
        loop {
            select! {
                message = self.ingress.next() => {
                    let message = match message {
                        Some(message) => message,
                        None => break,
                    };
                    self.handle_ingress(message).await;
                },
                task = self.receiver.next() => {
                    let task = match task {
                        Some(task) => task,
                        None => break,
                    };
                    self.handle_task(task);
                }
            }
        }
    }
}
507
/// Sends messages from one peer (`me`) on one channel into the simulated [`Network`].
#[derive(Clone, Debug)]
pub struct Sender<P: PublicKey> {
    /// Public key of the sending peer.
    me: P,
    /// Channel every message from this sender is tagged with.
    channel: Channel,
    /// Maximum payload size (in bytes) accepted by `send`.
    max_size: usize,
    /// Queue used for messages sent with `priority = true`.
    high: mpsc::UnboundedSender<Task<P>>,
    /// Queue used for messages sent with `priority = false`.
    low: mpsc::UnboundedSender<Task<P>>,
}
517
518impl<P: PublicKey> Sender<P> {
519 fn new(
520 context: impl Spawner + Metrics,
521 me: P,
522 channel: Channel,
523 max_size: usize,
524 mut sender: mpsc::UnboundedSender<Task<P>>,
525 ) -> Self {
526 let (high, mut high_receiver) = mpsc::unbounded();
528 let (low, mut low_receiver) = mpsc::unbounded();
529 context.with_label("sender").spawn(move |_| async move {
530 loop {
531 let task;
533 select! {
534 high_task = high_receiver.next() => {
535 task = match high_task {
536 Some(task) => task,
537 None => break,
538 };
539 },
540 low_task = low_receiver.next() => {
541 task = match low_task {
542 Some(task) => task,
543 None => break,
544 };
545 }
546 }
547
548 if let Err(err) = sender.send(task).await {
550 error!(?err, channel, "failed to send task");
551 }
552 }
553 });
554
555 Self {
557 me,
558 channel,
559 max_size,
560 high,
561 low,
562 }
563 }
564}
565
566impl<P: PublicKey> crate::Sender for Sender<P> {
567 type Error = Error;
568 type PublicKey = P;
569
570 async fn send(
571 &mut self,
572 recipients: Recipients<P>,
573 message: Bytes,
574 priority: bool,
575 ) -> Result<Vec<P>, Error> {
576 if message.len() > self.max_size {
578 return Err(Error::MessageTooLarge(message.len()));
579 }
580
581 let (sender, receiver) = oneshot::channel();
583 let mut channel = if priority { &self.high } else { &self.low };
584 channel
585 .send((self.channel, self.me.clone(), recipients, message, sender))
586 .await
587 .map_err(|_| Error::NetworkClosed)?;
588 receiver.await.map_err(|_| Error::NetworkClosed)
589 }
590}
591
/// Receiving end of a peer's per-channel mailbox.
type MessageReceiver<P> = mpsc::UnboundedReceiver<Message<P>>;
/// Outcome of registering a channel with a peer's router.
type MessageReceiverResult<P> = Result<MessageReceiver<P>, Error>;
594
/// Receives the messages routed to one registered channel of a peer.
#[derive(Debug)]
pub struct Receiver<P: PublicKey> {
    /// Mailbox filled by the peer's router task.
    receiver: MessageReceiver<P>,
}
600
601impl<P: PublicKey> crate::Receiver for Receiver<P> {
602 type Error = Error;
603 type PublicKey = P;
604
605 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
606 self.receiver.next().await.ok_or(Error::NetworkClosed)
607 }
608}
609
/// State the network keeps for each registered peer.
struct Peer<P: PublicKey> {
    /// Address the peer's listener binds to; links dial this address.
    socket: SocketAddr,

    /// Channel registrations sent to the peer's router task.
    control: mpsc::UnboundedSender<(Channel, oneshot::Sender<MessageReceiverResult<P>>)>,

    /// Outbound bandwidth schedule.
    egress: bandwidth::Schedule,
    /// Inbound bandwidth schedule.
    ingress: bandwidth::Schedule,
}
624
625impl<P: PublicKey> Peer<P> {
626 fn new<E: Spawner + RNetwork + Metrics + Clock>(
631 context: &mut E,
632 public_key: P,
633 socket: SocketAddr,
634 egress_bps: usize,
635 ingress_bps: usize,
636 max_size: usize,
637 ) -> Self {
638 let (control_sender, mut control_receiver) = mpsc::unbounded();
641
642 let (inbox_sender, mut inbox_receiver) = mpsc::unbounded();
645
646 context.with_label("router").spawn(|_| async move {
648 let mut mailboxes = HashMap::new();
650
651 loop {
653 select! {
654 control = control_receiver.next() => {
656 let (channel, result): (Channel, oneshot::Sender<MessageReceiverResult<P>>) = match control {
658 Some(control) => control,
659 None => break,
660 };
661
662 if mailboxes.contains_key(&channel) {
664 result.send(Err(Error::ChannelAlreadyRegistered(channel))).unwrap();
665 continue;
666 }
667
668 let (sender, receiver) = mpsc::unbounded();
670 mailboxes.insert(channel, sender);
671 result.send(Ok(receiver)).unwrap();
672 },
673
674 inbox = inbox_receiver.next() => {
676 let (channel, message) = match inbox {
678 Some(message) => message,
679 None => break,
680 };
681
682 match mailboxes.get_mut(&channel) {
684 Some(mailbox) => {
685 if let Err(err) = mailbox.send(message).await {
686 error!(?err, "failed to send message to mailbox");
687 }
688 }
689 None => {
690 trace!(
691 recipient = ?public_key,
692 channel,
693 reason = "missing channel",
694 "dropping message",
695 );
696 }
697 }
698 },
699 }
700 }
701 });
702
703 context.with_label("listener").spawn({
705 let inbox_sender = inbox_sender.clone();
706 move |context| async move {
707 let mut listener = context.bind(socket).await.unwrap();
709
710 while let Ok((_, _, mut stream)) = listener.accept().await {
712 context.with_label("receiver").spawn({
714 let mut inbox_sender = inbox_sender.clone();
715 move |_| async move {
716 let dialer = match recv_frame(&mut stream, max_size).await {
718 Ok(data) => data,
719 Err(_) => {
720 error!("failed to receive public key from dialer");
721 return;
722 }
723 };
724 let Ok(dialer) = P::decode(dialer.as_ref()) else {
725 error!("received public key is invalid");
726 return;
727 };
728
729 while let Ok(data) = recv_frame(&mut stream, max_size).await {
731 let channel = Channel::from_be_bytes(
732 data[..Channel::SIZE].try_into().unwrap(),
733 );
734 let message = data.slice(Channel::SIZE..);
735 if let Err(err) = inbox_sender
736 .send((channel, (dialer.clone(), message)))
737 .await
738 {
739 error!(?err, "failed to send message to mailbox");
740 break;
741 }
742 }
743 }
744 });
745 }
746 }
747 });
748
749 Self {
751 socket,
752 control: control_sender,
753 egress: bandwidth::Schedule::new(egress_bps),
754 ingress: bandwidth::Schedule::new(ingress_bps),
755 }
756 }
757
758 async fn register(&mut self, channel: Channel) -> MessageReceiverResult<P> {
763 let (sender, receiver) = oneshot::channel();
764 self.control
765 .send((channel, sender))
766 .await
767 .map_err(|_| Error::NetworkClosed)?;
768 receiver.await.map_err(|_| Error::NetworkClosed)?
769 }
770
771 fn set_bandwidth(&mut self, egress_bps: usize, ingress_bps: usize) {
777 self.egress.bps = egress_bps;
778 self.ingress.bps = ingress_bps;
779 }
780}
781
/// A directed connection from one peer to another.
#[derive(Clone)]
struct Link {
    /// Latency distribution, sampled in milliseconds for each message.
    sampler: Normal<f64>,
    /// Probability passed to `gen_bool` when deciding whether a message is delivered.
    success_rate: f64,
    /// Queue of `(channel, payload, arrival time)` consumed by the link's background task.
    inbox: mpsc::UnboundedSender<(Channel, Bytes, SystemTime)>,
}
791
792impl Link {
793 #[allow(clippy::too_many_arguments)]
794 fn new<E: Spawner + RNetwork + Clock + Metrics, P: PublicKey>(
795 context: &mut E,
796 dialer: P,
797 receiver: P,
798 socket: SocketAddr,
799 sampler: Normal<f64>,
800 success_rate: f64,
801 max_size: usize,
802 received_messages: Family<metrics::Message, Counter>,
803 ) -> Self {
804 let (inbox, mut outbox) = mpsc::unbounded();
805 let result = Self {
806 sampler,
807 success_rate,
808 inbox,
809 };
810
811 context
814 .clone()
815 .with_label("link")
816 .spawn(move |context| async move {
817 let (mut sink, _) = context.dial(socket).await.unwrap();
819 if let Err(err) = send_frame(&mut sink, &dialer, max_size).await {
820 error!(?err, "failed to send public key to listener");
821 return;
822 }
823
824 while let Some((channel, message, receive_complete_at)) = outbox.next().await {
826 context.sleep_until(receive_complete_at).await;
828
829 let mut data = bytes::BytesMut::with_capacity(Channel::SIZE + message.len());
831 data.extend_from_slice(&channel.to_be_bytes());
832 data.extend_from_slice(&message);
833 let data = data.freeze();
834 send_frame(&mut sink, &data, max_size).await.unwrap();
835
836 received_messages
838 .get_or_create(&metrics::Message::new(&dialer, &receiver, channel))
839 .inc();
840 }
841 });
842
843 result
844 }
845
846 fn send(
848 &mut self,
849 channel: Channel,
850 message: Bytes,
851 receive_complete_at: SystemTime,
852 ) -> Result<(), Error> {
853 self.inbox
854 .unbounded_send((channel, message, receive_complete_at))
855 .map_err(|_| Error::NetworkClosed)?;
856 Ok(())
857 }
858}
859
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::{ed25519, PrivateKeyExt as _, Signer as _};
    use commonware_runtime::{deterministic, Runner};

    const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

    /// Registering channels and adding links succeeds once; duplicates are rejected.
    #[test]
    fn test_register_and_link() {
        deterministic::Runner::default().start(|context| async move {
            let network_context = context.with_label("network");
            let (network, mut oracle) = Network::new(
                network_context.clone(),
                Config {
                    max_size: MAX_MESSAGE_SIZE,
                },
            );
            network_context.spawn(|_| network.run());

            let pk1 = ed25519::PrivateKey::from_seed(1).public_key();
            let pk2 = ed25519::PrivateKey::from_seed(2).public_key();

            // Each (peer, channel) pair can be registered once...
            for (pk, channel) in [(&pk1, 0), (&pk1, 1), (&pk2, 0), (&pk2, 1)] {
                oracle.register(pk.clone(), channel).await.unwrap();
            }

            // ...and registering the same pair again is rejected.
            let duplicate = oracle.register(pk1.clone(), 1).await;
            assert!(matches!(
                duplicate,
                Err(Error::ChannelAlreadyRegistered(_))
            ));

            let link = ingress::Link {
                latency: Duration::from_millis(2),
                jitter: Duration::from_millis(1),
                success_rate: 0.9,
            };
            oracle
                .add_link(pk1.clone(), pk2.clone(), link.clone())
                .await
                .unwrap();

            // The same directed link cannot be added twice.
            let duplicate = oracle.add_link(pk1, pk2, link).await;
            assert!(matches!(duplicate, Err(Error::LinkExists)));
        });
    }

    /// Sockets advance port by port and roll over to the next IPv4 address.
    #[test]
    fn test_get_next_socket() {
        deterministic::Runner::default().start(|context| async move {
            let cfg = Config {
                max_size: MAX_MESSAGE_SIZE,
            };
            let (mut network, _) =
                Network::<deterministic::Context, ed25519::PublicKey>::new(context.clone(), cfg);

            // Consecutive calls hand out consecutive ports on the same address.
            let start = network.next_addr;
            assert_eq!(network.get_next_socket(), start);
            let mut expected = start;
            expected.set_port(1);
            assert_eq!(network.get_next_socket(), expected);

            // Exhausting the port range moves to the next address at port 0.
            let last = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 0, 255, 255)), 65535);
            network.next_addr = last;
            assert_eq!(network.get_next_socket(), last);
            assert_eq!(
                network.get_next_socket(),
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 1, 0, 0)), 0)
            );
        });
    }
}