1use super::{
4 ingress::{self, Oracle},
5 metrics, Error,
6};
7use crate::{Channel, Message, Recipients};
8use bytes::Bytes;
9use commonware_codec::{DecodeExt, FixedSize};
10use commonware_cryptography::PublicKey;
11use commonware_macros::select;
12use commonware_runtime::{Clock, Handle, Listener as _, Metrics, Network as RNetwork, Spawner};
13use commonware_stream::utils::codec::{recv_frame, send_frame};
14use futures::{
15 channel::{mpsc, oneshot},
16 SinkExt, StreamExt,
17};
18use prometheus_client::metrics::{counter::Counter, family::Family};
19use rand::Rng;
20use rand_distr::{Distribution, Normal};
21use std::{
22 collections::{BTreeMap, HashMap, HashSet},
23 net::{IpAddr, Ipv4Addr, SocketAddr},
24 time::Duration,
25};
26use tracing::{error, trace};
27
/// Unit of work handed from a [`Sender`] to the [`Network`] event loop: the channel the
/// message belongs to, the origin's public key, the requested recipients, the payload, and a
/// oneshot on which the list of recipients the message was dispatched to is reported back.
type Task<P> = (Channel, P, Recipients<P>, Bytes, oneshot::Sender<Vec<P>>);
30
/// Configuration for a [`Network`].
pub struct Config {
    /// Largest message payload, in bytes, that a [`Sender`] accepts; larger messages are
    /// rejected with `Error::MessageTooLarge`.
    pub max_size: usize,
}
36
/// Event loop that owns all peers and links and routes every message sent between them.
pub struct Network<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> {
    /// Runtime context used to spawn tasks, draw randomness, and register metrics.
    context: E,

    /// Largest message payload, in bytes, taken from [`Config::max_size`].
    max_size: usize,

    /// Socket address that will be handed to the next newly created peer.
    next_addr: SocketAddr,

    /// Control requests issued through the [`Oracle`].
    ingress: mpsc::UnboundedReceiver<ingress::Message<P>>,

    /// Producer side of the task queue; a clone is given to every [`Sender`].
    sender: mpsc::UnboundedSender<Task<P>>,
    /// Consumer side of the task queue, drained by the event loop.
    receiver: mpsc::UnboundedReceiver<Task<P>>,

    /// Directional links keyed by `(sender, receiver)`.
    links: HashMap<(P, P), Link>,

    /// Registered peers keyed by public key.
    peers: BTreeMap<P, Peer<P>>,

    /// Blocked `(from, to)` pairs; a message is dropped if either direction of its
    /// origin/recipient pair is present.
    blocks: HashSet<(P, P)>,

    /// Counter of messages successfully handed to a link, labelled per origin/recipient/channel.
    received_messages: Family<metrics::Message, Counter>,
    /// Counter of messages dispatched towards a link, labelled per origin/recipient/channel.
    sent_messages: Family<metrics::Message, Counter>,
}
70
71impl<E: RNetwork + Spawner + Rng + Clock + Metrics, P: PublicKey> Network<E, P> {
72 pub fn new(context: E, cfg: Config) -> (Self, Oracle<P>) {
77 let (sender, receiver) = mpsc::unbounded();
78 let (oracle_sender, oracle_receiver) = mpsc::unbounded();
79 let sent_messages = Family::<metrics::Message, Counter>::default();
80 let received_messages = Family::<metrics::Message, Counter>::default();
81 context.register("messages_sent", "messages sent", sent_messages.clone());
82 context.register(
83 "messages_received",
84 "messages received",
85 received_messages.clone(),
86 );
87
88 let next_addr = SocketAddr::new(
90 IpAddr::V4(Ipv4Addr::from_bits(context.clone().next_u32())),
91 0,
92 );
93 (
94 Self {
95 context,
96 max_size: cfg.max_size,
97 next_addr,
98 ingress: oracle_receiver,
99 sender,
100 receiver,
101 links: HashMap::new(),
102 peers: BTreeMap::new(),
103 blocks: HashSet::new(),
104 received_messages,
105 sent_messages,
106 },
107 Oracle::new(oracle_sender.clone()),
108 )
109 }
110
111 fn get_next_socket(&mut self) -> SocketAddr {
116 let result = self.next_addr;
117
118 match self.next_addr.port().checked_add(1) {
121 Some(port) => {
122 self.next_addr.set_port(port);
123 }
124 None => {
125 let ip = match self.next_addr.ip() {
126 IpAddr::V4(ipv4) => ipv4,
127 _ => unreachable!(),
128 };
129 let next_ip = Ipv4Addr::to_bits(ip).wrapping_add(1);
130 self.next_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(next_ip)), 0);
131 }
132 }
133
134 result
135 }
136
137 async fn handle_ingress(&mut self, message: ingress::Message<P>) {
141 fn send_result<T: std::fmt::Debug>(
145 result: oneshot::Sender<Result<T, Error>>,
146 value: Result<T, Error>,
147 ) {
148 let success = value.is_ok();
149 if let Err(e) = result.send(value) {
150 error!(?e, "failed to send result to oracle (ok = {})", success);
151 }
152 }
153
154 match message {
155 ingress::Message::Register {
156 public_key,
157 channel,
158 result,
159 } => {
160 if !self.peers.contains_key(&public_key) {
162 let peer = Peer::new(
163 &mut self.context.clone(),
164 public_key.clone(),
165 self.get_next_socket(),
166 self.max_size,
167 );
168 self.peers.insert(public_key.clone(), peer);
169 }
170
171 let peer = self.peers.get_mut(&public_key).unwrap();
173 let receiver = match peer.register(channel).await {
174 Ok(receiver) => Receiver { receiver },
175 Err(err) => return send_result(result, Err(err)),
176 };
177
178 let sender = Sender::new(
180 self.context.clone(),
181 public_key,
182 channel,
183 self.max_size,
184 self.sender.clone(),
185 );
186 send_result(result, Ok((sender, receiver)))
187 }
188 ingress::Message::AddLink {
189 sender,
190 receiver,
191 sampler,
192 success_rate,
193 result,
194 } => {
195 if !self.peers.contains_key(&sender) {
197 return send_result(result, Err(Error::PeerMissing));
198 }
199 let peer = match self.peers.get(&receiver) {
200 Some(peer) => peer,
201 None => return send_result(result, Err(Error::PeerMissing)),
202 };
203
204 let key = (sender.clone(), receiver);
206 if self.links.contains_key(&key) {
207 return send_result(result, Err(Error::LinkExists));
208 }
209
210 let link = Link::new(
211 &mut self.context,
212 sender,
213 peer.socket,
214 sampler,
215 success_rate,
216 self.max_size,
217 );
218 self.links.insert(key, link);
219 send_result(result, Ok(()))
220 }
221 ingress::Message::RemoveLink {
222 sender,
223 receiver,
224 result,
225 } => {
226 match self.links.remove(&(sender, receiver)) {
227 Some(_) => (),
228 None => return send_result(result, Err(Error::LinkMissing)),
229 }
230 send_result(result, Ok(()))
231 }
232 ingress::Message::Block { from, to } => {
233 self.blocks.insert((from, to));
234 }
235 ingress::Message::Blocked { result } => {
236 send_result(result, Ok(self.blocks.iter().cloned().collect()))
237 }
238 }
239 }
240
241 fn handle_task(&mut self, task: Task<P>) {
246 let (channel, origin, recipients, message, reply) = task;
248 let recipients = match recipients {
249 Recipients::All => self.peers.keys().cloned().collect(),
250 Recipients::Some(keys) => keys,
251 Recipients::One(key) => vec![key],
252 };
253
254 let mut sent = Vec::new();
256 let (acquired_sender, mut acquired_receiver) = mpsc::channel(recipients.len());
257 for recipient in recipients {
258 if recipient == origin {
260 trace!(?recipient, reason = "self", "dropping message",);
261 continue;
262 }
263
264 let o_r = (origin.clone(), recipient.clone());
266 let r_o = (recipient.clone(), origin.clone());
267 if self.blocks.contains(&o_r) || self.blocks.contains(&r_o) {
268 trace!(?origin, ?recipient, reason = "blocked", "dropping message");
269 continue;
270 }
271
272 let mut link = match self.links.get(&o_r).cloned() {
274 Some(link) => link,
275 None => {
276 trace!(?origin, ?recipient, reason = "no link", "dropping message",);
277 continue;
278 }
279 };
280
281 self.sent_messages
284 .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
285 .inc();
286
287 let delay = link.sampler.sample(&mut self.context);
289 let should_deliver = self.context.gen_bool(link.success_rate);
290 trace!(?origin, ?recipient, ?delay, "sending message",);
291
292 self.context.with_label("messenger").spawn({
294 let message = message.clone();
295 let recipient = recipient.clone();
296 let origin = origin.clone();
297 let mut acquired_sender = acquired_sender.clone();
298 let received_messages = self.received_messages.clone();
299 move |context| async move {
300 acquired_sender.send(()).await.unwrap();
302
303 context.sleep(Duration::from_millis(delay as u64)).await;
308
309 if !should_deliver {
311 trace!(
312 ?recipient,
313 reason = "random link failure",
314 "dropping message",
315 );
316 return;
317 }
318
319 if let Err(err) = link.send(channel, message).await {
321 error!(?origin, ?recipient, ?err, "failed to send",);
323 return;
324 }
325
326 received_messages
328 .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
329 .inc();
330 }
331 });
332 sent.push(recipient);
333 }
334
335 self.context
337 .clone()
338 .with_label("notifier")
339 .spawn(|_| async move {
340 for _ in 0..sent.len() {
342 acquired_receiver.next().await.unwrap();
343 }
344
345 if let Err(err) = reply.send(sent) {
347 error!(?err, "failed to send ack");
349 }
350 });
351 }
352
353 pub fn start(mut self) -> Handle<()> {
358 self.context.spawn_ref()(self.run())
359 }
360
361 async fn run(mut self) {
362 loop {
363 select! {
364 message = self.ingress.next() => {
365 let message = match message {
367 Some(message) => message,
368 None => break,
369 };
370 self.handle_ingress(message).await;
371 },
372 task = self.receiver.next() => {
373 let task = match task {
375 Some(task) => task,
376 None => break,
377 };
378 self.handle_task(task);
379 }
380 }
381 }
382 }
383}
384
/// Handle with which a peer submits messages on a single channel.
#[derive(Clone, Debug)]
pub struct Sender<P: PublicKey> {
    /// Public key recorded as the origin of every submitted task.
    me: P,
    /// Channel that every message from this sender is tagged with.
    channel: Channel,
    /// Largest payload, in bytes, that `send` accepts.
    max_size: usize,
    /// Queue used when `send` is called with `priority == true`.
    high: mpsc::UnboundedSender<Task<P>>,
    /// Queue used when `send` is called with `priority == false`.
    low: mpsc::UnboundedSender<Task<P>>,
}
394
395impl<P: PublicKey> Sender<P> {
396 fn new(
397 context: impl Spawner + Metrics,
398 me: P,
399 channel: Channel,
400 max_size: usize,
401 mut sender: mpsc::UnboundedSender<Task<P>>,
402 ) -> Self {
403 let (high, mut high_receiver) = mpsc::unbounded();
405 let (low, mut low_receiver) = mpsc::unbounded();
406 context.with_label("sender").spawn(move |_| async move {
407 loop {
408 let task;
410 select! {
411 high_task = high_receiver.next() => {
412 task = match high_task {
413 Some(task) => task,
414 None => break,
415 };
416 },
417 low_task = low_receiver.next() => {
418 task = match low_task {
419 Some(task) => task,
420 None => break,
421 };
422 }
423 }
424
425 if let Err(err) = sender.send(task).await {
427 error!(?err, channel, "failed to send task");
428 }
429 }
430 });
431
432 Self {
434 me,
435 channel,
436 max_size,
437 high,
438 low,
439 }
440 }
441}
442
443impl<P: PublicKey> crate::Sender for Sender<P> {
444 type Error = Error;
445 type PublicKey = P;
446
447 async fn send(
448 &mut self,
449 recipients: Recipients<P>,
450 message: Bytes,
451 priority: bool,
452 ) -> Result<Vec<P>, Error> {
453 if message.len() > self.max_size {
455 return Err(Error::MessageTooLarge(message.len()));
456 }
457
458 let (sender, receiver) = oneshot::channel();
460 let mut channel = if priority { &self.high } else { &self.low };
461 channel
462 .send((self.channel, self.me.clone(), recipients, message, sender))
463 .await
464 .map_err(|_| Error::NetworkClosed)?;
465 receiver.await.map_err(|_| Error::NetworkClosed)
466 }
467}
468
/// Receiving half of a per-channel mailbox carrying [`Message`]s addressed to one peer.
type MessageReceiver<P> = mpsc::UnboundedReceiver<Message<P>>;
/// Outcome of asking a peer's router to register a channel.
type MessageReceiverResult<P> = Result<MessageReceiver<P>, Error>;
471
/// Handle with which a peer receives the messages delivered on a single channel.
#[derive(Debug)]
pub struct Receiver<P: PublicKey> {
    /// Mailbox the peer's router pushes incoming messages into.
    receiver: MessageReceiver<P>,
}
477
478impl<P: PublicKey> crate::Receiver for Receiver<P> {
479 type Error = Error;
480 type PublicKey = P;
481
482 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
483 self.receiver.next().await.ok_or(Error::NetworkClosed)
484 }
485}
486
/// A registered participant: a socket it listens on plus a handle to its router task.
struct Peer<P: PublicKey> {
    /// Address the peer's listener is bound to; links dial this address.
    socket: SocketAddr,

    /// Requests to register a channel, answered by the router with the channel's mailbox.
    control: mpsc::UnboundedSender<(Channel, oneshot::Sender<MessageReceiverResult<P>>)>,
}
497
498impl<P: PublicKey> Peer<P> {
499 fn new<E: Spawner + RNetwork + Metrics>(
504 context: &mut E,
505 public_key: P,
506 socket: SocketAddr,
507 max_size: usize,
508 ) -> Self {
509 let (control_sender, mut control_receiver) = mpsc::unbounded();
512
513 let (inbox_sender, mut inbox_receiver) = mpsc::unbounded();
516
517 context.with_label("router").spawn(|_| async move {
519 let mut mailboxes = HashMap::new();
521
522 loop {
524 select! {
525 control = control_receiver.next() => {
527 let (channel, result): (Channel, oneshot::Sender<MessageReceiverResult<P>>) = match control {
529 Some(control) => control,
530 None => break,
531 };
532
533 if mailboxes.contains_key(&channel) {
535 result.send(Err(Error::ChannelAlreadyRegistered(channel))).unwrap();
536 continue;
537 }
538
539 let (sender, receiver) = mpsc::unbounded();
541 mailboxes.insert(channel, sender);
542 result.send(Ok(receiver)).unwrap();
543 },
544
545 inbox = inbox_receiver.next() => {
547 let (channel, message) = match inbox {
549 Some(message) => message,
550 None => break,
551 };
552
553 match mailboxes.get_mut(&channel) {
555 Some(mailbox) => {
556 if let Err(err) = mailbox.send(message).await {
557 error!(?err, "failed to send message to mailbox");
558 }
559 }
560 None => {
561 trace!(
562 recipient = ?public_key,
563 channel,
564 reason = "missing channel",
565 "dropping message",
566 );
567 }
568 }
569 },
570 }
571 }
572 });
573
574 context.with_label("listener").spawn({
576 let inbox_sender = inbox_sender.clone();
577 move |context| async move {
578 let mut listener = context.bind(socket).await.unwrap();
580
581 while let Ok((_, _, mut stream)) = listener.accept().await {
583 context.with_label("receiver").spawn({
585 let mut inbox_sender = inbox_sender.clone();
586 move |_| async move {
587 let dialer = match recv_frame(&mut stream, max_size).await {
589 Ok(data) => data,
590 Err(_) => {
591 error!("failed to receive public key from dialer");
592 return;
593 }
594 };
595 let Ok(dialer) = P::decode(dialer.as_ref()) else {
596 error!("received public key is invalid");
597 return;
598 };
599
600 while let Ok(data) = recv_frame(&mut stream, max_size).await {
602 let channel = Channel::from_be_bytes(
603 data[..Channel::SIZE].try_into().unwrap(),
604 );
605 let message = data.slice(Channel::SIZE..);
606 if let Err(err) = inbox_sender
607 .send((channel, (dialer.clone(), message)))
608 .await
609 {
610 error!(?err, "failed to send message to mailbox");
611 break;
612 }
613 }
614 }
615 });
616 }
617 }
618 });
619
620 Self {
622 socket,
623 control: control_sender,
624 }
625 }
626
627 async fn register(&mut self, channel: Channel) -> MessageReceiverResult<P> {
632 let (sender, receiver) = oneshot::channel();
633 self.control
634 .send((channel, sender))
635 .await
636 .map_err(|_| Error::NetworkClosed)?;
637 receiver.await.map_err(|_| Error::NetworkClosed)?
638 }
639}
640
/// Directional connection from one peer to another, with its delivery characteristics.
#[derive(Clone)]
struct Link {
    /// Distribution the per-message delay is sampled from; the sample is used as a number
    /// of milliseconds.
    sampler: Normal<f64>,
    /// Probability with which a message sent over this link is delivered.
    success_rate: f64,
    /// Queue of `(channel, payload)` pairs drained by the link's background task.
    inbox: mpsc::UnboundedSender<(Channel, Bytes)>,
}
649
650impl Link {
651 fn new<E: Spawner + RNetwork + Metrics, P: PublicKey>(
652 context: &mut E,
653 dialer: P,
654 socket: SocketAddr,
655 sampler: Normal<f64>,
656 success_rate: f64,
657 max_size: usize,
658 ) -> Self {
659 let (inbox, mut outbox) = mpsc::unbounded();
660 let result = Self {
661 sampler,
662 success_rate,
663 inbox,
664 };
665
666 context
669 .clone()
670 .with_label("link")
671 .spawn(move |context| async move {
672 let (mut sink, _) = context.dial(socket).await.unwrap();
674 if let Err(err) = send_frame(&mut sink, &dialer, max_size).await {
675 error!(?err, "failed to send public key to listener");
676 return;
677 }
678
679 while let Some((channel, message)) = outbox.next().await {
681 let mut data = bytes::BytesMut::with_capacity(Channel::SIZE + message.len());
682 data.extend_from_slice(&channel.to_be_bytes());
683 data.extend_from_slice(&message);
684 let data = data.freeze();
685 send_frame(&mut sink, &data, max_size).await.unwrap();
686 }
687 });
688
689 result
690 }
691
692 async fn send(&mut self, channel: Channel, message: Bytes) -> Result<(), Error> {
694 self.inbox
695 .send((channel, message))
696 .await
697 .map_err(|_| Error::NetworkClosed)?;
698 Ok(())
699 }
700}
701
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::{ed25519, PrivateKeyExt as _, Signer as _};
    use commonware_runtime::{deterministic, Runner};

    const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

    /// Channels and links can each be created once; a repeat is rejected with the
    /// corresponding error.
    #[test]
    fn test_register_and_link() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Run the network on its own labelled context.
            let network_context = context.with_label("network");
            let (network, mut oracle) = Network::new(
                network_context.clone(),
                Config {
                    max_size: MAX_MESSAGE_SIZE,
                },
            );
            network_context.spawn(|_| network.run());

            let pk1 = ed25519::PrivateKey::from_seed(1).public_key();
            let pk2 = ed25519::PrivateKey::from_seed(2).public_key();

            // Two channels per peer register cleanly.
            for (pk, channel) in [(&pk1, 0), (&pk1, 1), (&pk2, 0), (&pk2, 1)] {
                oracle.register(pk.clone(), channel).await.unwrap();
            }

            // Registering an existing channel again is refused.
            let duplicate = oracle.register(pk1.clone(), 1).await;
            assert!(matches!(
                duplicate,
                Err(Error::ChannelAlreadyRegistered(_))
            ));

            // The first link between the pair is accepted.
            let link = ingress::Link {
                latency: 2.0,
                jitter: 1.0,
                success_rate: 0.9,
            };
            oracle
                .add_link(pk1.clone(), pk2.clone(), link.clone())
                .await
                .unwrap();

            // The same directional link cannot be added twice.
            let repeated = oracle.add_link(pk1, pk2, link).await;
            assert!(matches!(repeated, Err(Error::LinkExists)));
        });
    }

    /// Socket allocation increments the port and rolls over to the next IP when the port
    /// space is exhausted.
    #[test]
    fn test_get_next_socket() {
        let runner = deterministic::Runner::default();

        runner.start(|context| async move {
            type PublicKey = ed25519::PublicKey;
            let cfg = Config {
                max_size: MAX_MESSAGE_SIZE,
            };
            let (mut network, _) =
                Network::<deterministic::Context, PublicKey>::new(context.clone(), cfg);

            // The first call returns the starting address, the second the next port.
            let first = network.next_addr;
            assert_eq!(network.get_next_socket(), first);
            let mut second = first;
            second.set_port(1);
            assert_eq!(network.get_next_socket(), second);

            // At the last port of an IP, the following address is port 0 of the next IP.
            let max_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 0, 255, 255)), 65535);
            network.next_addr = max_addr;
            assert_eq!(network.get_next_socket(), max_addr);
            let rolled_over = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 1, 0, 0)), 0);
            assert_eq!(network.get_next_socket(), rolled_over);
        });
    }
}