1use super::{
4 ingress::{self, Oracle},
5 metrics, Error,
6};
7use crate::{Channel, Message, Recipients};
8use bytes::Bytes;
9use commonware_codec::{DecodeExt, FixedSize};
10use commonware_macros::select;
11use commonware_runtime::{Clock, Handle, Listener as _, Metrics, Network as RNetwork, Spawner};
12use commonware_stream::utils::codec::{recv_frame, send_frame};
13use commonware_utils::Array;
14use futures::{
15 channel::{mpsc, oneshot},
16 SinkExt, StreamExt,
17};
18use prometheus_client::metrics::{counter::Counter, family::Family};
19use rand::Rng;
20use rand_distr::{Distribution, Normal};
21use std::{
22 collections::{BTreeMap, HashMap},
23 net::{IpAddr, Ipv4Addr, SocketAddr},
24 time::Duration,
25};
26use tracing::{error, trace};
27
/// A unit of work submitted to the network by a [`Sender`]: the channel to send
/// on, the origin peer, the requested recipients, the payload, and a oneshot
/// used to report back which recipients the message was dispatched to.
type Task<P> = (Channel, P, Recipients<P>, Bytes, oneshot::Sender<Vec<P>>);
30
/// Configuration for a [`Network`].
pub struct Config {
    /// Maximum size (in bytes) of a message payload; larger messages are
    /// rejected by [`Sender`] with `Error::MessageTooLarge`.
    pub max_size: usize,
}
36
/// Routes messages between registered peers over explicitly configured links.
pub struct Network<E: RNetwork + Spawner + Rng + Clock + Metrics, P: Array> {
    /// Runtime context used to spawn tasks, draw random samples and sleep.
    context: E,

    /// Maximum size (in bytes) of a message payload accepted from senders.
    max_size: usize,

    /// Socket address that the next call to `get_next_socket` will hand out.
    next_addr: SocketAddr,

    /// Control requests (register, add link, remove link) issued through the oracle.
    ingress: mpsc::UnboundedReceiver<ingress::Message<P>>,

    /// Task submission handle; a clone is given to every [`Sender`] created on registration.
    sender: mpsc::UnboundedSender<Task<P>>,
    /// Tasks submitted by senders, drained by the `run` loop.
    receiver: mpsc::UnboundedReceiver<Task<P>>,

    /// Directed links keyed by `(sender, receiver)` public keys.
    links: HashMap<(P, P), Link>,

    /// Registered peers keyed by public key.
    peers: BTreeMap<P, Peer<P>>,

    /// Counter incremented after a message has been handed to a link for delivery.
    received_messages: Family<metrics::Message, Counter>,
    /// Counter incremented for every message dispatched over an existing link
    /// (before the delivery outcome is known).
    sent_messages: Family<metrics::Message, Counter>,
}
67
68impl<E: RNetwork + Spawner + Rng + Clock + Metrics, P: Array> Network<E, P> {
69 pub fn new(context: E, cfg: Config) -> (Self, Oracle<P>) {
74 let (sender, receiver) = mpsc::unbounded();
75 let (oracle_sender, oracle_receiver) = mpsc::unbounded();
76 let sent_messages = Family::<metrics::Message, Counter>::default();
77 let received_messages = Family::<metrics::Message, Counter>::default();
78 context.register("messages_sent", "messages sent", sent_messages.clone());
79 context.register(
80 "messages_received",
81 "messages received",
82 received_messages.clone(),
83 );
84
85 let next_addr = SocketAddr::new(
87 IpAddr::V4(Ipv4Addr::from_bits(context.clone().next_u32())),
88 0,
89 );
90 (
91 Self {
92 context,
93 max_size: cfg.max_size,
94 next_addr,
95 ingress: oracle_receiver,
96 sender,
97 receiver,
98 links: HashMap::new(),
99 peers: BTreeMap::new(),
100 received_messages,
101 sent_messages,
102 },
103 Oracle::new(oracle_sender),
104 )
105 }
106
107 fn get_next_socket(&mut self) -> SocketAddr {
112 let result = self.next_addr;
113
114 match self.next_addr.port().checked_add(1) {
117 Some(port) => {
118 self.next_addr.set_port(port);
119 }
120 None => {
121 let ip = match self.next_addr.ip() {
122 IpAddr::V4(ipv4) => ipv4,
123 _ => unreachable!(),
124 };
125 let next_ip = Ipv4Addr::to_bits(ip).wrapping_add(1);
126 self.next_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(next_ip)), 0);
127 }
128 }
129
130 result
131 }
132
133 async fn handle_ingress(&mut self, message: ingress::Message<P>) {
137 fn send_result<T: std::fmt::Debug>(
141 result: oneshot::Sender<Result<T, Error>>,
142 value: Result<T, Error>,
143 ) {
144 let success = value.is_ok();
145 if let Err(e) = result.send(value) {
146 error!(?e, "failed to send result to oracle (ok = {})", success);
147 }
148 }
149
150 match message {
151 ingress::Message::Register {
152 public_key,
153 channel,
154 result,
155 } => {
156 if !self.peers.contains_key(&public_key) {
158 let peer = Peer::new(
159 &mut self.context.clone(),
160 public_key.clone(),
161 self.get_next_socket(),
162 self.max_size,
163 );
164 self.peers.insert(public_key.clone(), peer);
165 }
166
167 let peer = self.peers.get_mut(&public_key).unwrap();
169 let receiver = match peer.register(channel).await {
170 Ok(receiver) => Receiver { receiver },
171 Err(err) => return send_result(result, Err(err)),
172 };
173
174 let sender = Sender::new(
176 self.context.clone(),
177 public_key,
178 channel,
179 self.max_size,
180 self.sender.clone(),
181 );
182 send_result(result, Ok((sender, receiver)))
183 }
184 ingress::Message::AddLink {
185 sender,
186 receiver,
187 sampler,
188 success_rate,
189 result,
190 } => {
191 if !self.peers.contains_key(&sender) {
193 return send_result(result, Err(Error::PeerMissing));
194 }
195 let peer = match self.peers.get(&receiver) {
196 Some(peer) => peer,
197 None => return send_result(result, Err(Error::PeerMissing)),
198 };
199
200 let key = (sender.clone(), receiver);
202 if self.links.contains_key(&key) {
203 return send_result(result, Err(Error::LinkExists));
204 }
205
206 let link = Link::new(
207 &mut self.context,
208 sender,
209 peer.socket,
210 sampler,
211 success_rate,
212 self.max_size,
213 );
214 self.links.insert(key, link);
215 send_result(result, Ok(()))
216 }
217 ingress::Message::RemoveLink {
218 sender,
219 receiver,
220 result,
221 } => {
222 match self.links.remove(&(sender, receiver)) {
223 Some(_) => (),
224 None => return send_result(result, Err(Error::LinkMissing)),
225 }
226 send_result(result, Ok(()))
227 }
228 }
229 }
230
    /// Dispatch one message from `origin` to each requested recipient.
    ///
    /// Recipients equal to the origin, or for which no `(origin, recipient)` link
    /// exists, are skipped. For every remaining recipient a "messenger" task is
    /// spawned that sleeps for a sampled delay and then either drops the message
    /// (random link failure) or hands it to the link. A "notifier" task replies
    /// on the task's oneshot with the list of recipients a messenger was spawned
    /// for, once each of those messengers has started running.
    fn handle_task(&mut self, task: Task<P>) {
        let (channel, origin, recipients, message, reply) = task;

        // Expand the recipient selector into a concrete list of public keys.
        let recipients = match recipients {
            Recipients::All => self.peers.keys().cloned().collect(),
            Recipients::Some(keys) => keys,
            Recipients::One(key) => vec![key],
        };

        let mut sent = Vec::new();
        // Each messenger signals on this channel as soon as it starts; the
        // notifier waits for one signal per dispatched recipient.
        let (acquired_sender, mut acquired_receiver) = mpsc::channel(recipients.len());
        for recipient in recipients {
            // A peer never sends to itself.
            if recipient == origin {
                trace!(?recipient, reason = "self", "dropping message",);
                continue;
            }

            // Only peers connected by a directed link can exchange messages.
            let mut link = match self
                .links
                .get(&(origin.clone(), recipient.clone()))
                .cloned()
            {
                Some(link) => link,
                None => {
                    trace!(?origin, ?recipient, reason = "no link", "dropping message",);
                    continue;
                }
            };

            // Counted as sent regardless of whether delivery later succeeds.
            self.sent_messages
                .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
                .inc();

            // Both random draws happen here, on the network's own context, in
            // recipient order (not inside the spawned task).
            let delay = link.sampler.sample(&mut self.context);
            let should_deliver = self.context.gen_bool(link.success_rate);
            trace!(?origin, ?recipient, ?delay, "sending message",);

            self.context.with_label("messenger").spawn({
                let message = message.clone();
                let recipient = recipient.clone();
                let origin = origin.clone();
                let mut acquired_sender = acquired_sender.clone();
                let received_messages = self.received_messages.clone();
                move |context| async move {
                    // Tell the notifier that this messenger is running.
                    acquired_sender.send(()).await.unwrap();

                    // The sample is interpreted as milliseconds. A float-to-int
                    // `as` cast saturates, so a negative sample becomes a
                    // zero-length sleep.
                    context.sleep(Duration::from_millis(delay as u64)).await;

                    // The delay is applied even to messages that end up dropped.
                    if !should_deliver {
                        trace!(
                            ?recipient,
                            reason = "random link failure",
                            "dropping message",
                        );
                        return;
                    }

                    if let Err(err) = link.send(channel, message).await {
                        error!(?origin, ?recipient, ?err, "failed to send",);
                        return;
                    }

                    // Counted once the message has been handed to the link.
                    received_messages
                        .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
                        .inc();
                }
            });
            sent.push(recipient);
        }

        self.context
            .clone()
            .with_label("notifier")
            .spawn(|_| async move {
                // Wait until every messenger spawned above has started.
                for _ in 0..sent.len() {
                    acquired_receiver.next().await.unwrap();
                }

                // The requester may have stopped waiting; that is not fatal.
                if let Err(err) = reply.send(sent) {
                    error!(?err, "failed to send ack");
                }
            });
    }
338
    /// Consume the network and run its event loop (`run`) as a task spawned
    /// through the network's own context, returning the task's handle.
    pub fn start(mut self) -> Handle<()> {
        self.context.spawn_ref()(self.run())
    }
346
347 async fn run(mut self) {
348 loop {
349 select! {
350 message = self.ingress.next() => {
351 let message = match message {
353 Some(message) => message,
354 None => break,
355 };
356 self.handle_ingress(message).await;
357 },
358 task = self.receiver.next() => {
359 let task = match task {
361 Some(task) => task,
362 None => break,
363 };
364 self.handle_task(task);
365 }
366 }
367 }
368 }
369}
370
/// Handle used by a registered peer to send messages on one channel.
#[derive(Clone, Debug)]
pub struct Sender<P: Array> {
    /// Public key of the peer this sender belongs to (the origin of every message).
    me: P,
    /// Channel every message from this sender is sent on.
    channel: Channel,
    /// Maximum accepted payload size in bytes.
    max_size: usize,
    /// Queue used for messages sent with `priority == true`.
    high: mpsc::UnboundedSender<Task<P>>,
    /// Queue used for messages sent with `priority == false`.
    low: mpsc::UnboundedSender<Task<P>>,
}
380
381impl<P: Array> Sender<P> {
382 fn new(
383 context: impl Spawner + Metrics,
384 me: P,
385 channel: Channel,
386 max_size: usize,
387 mut sender: mpsc::UnboundedSender<Task<P>>,
388 ) -> Self {
389 let (high, mut high_receiver) = mpsc::unbounded();
391 let (low, mut low_receiver) = mpsc::unbounded();
392 context.with_label("sender").spawn(move |_| async move {
393 loop {
394 let task;
396 select! {
397 high_task = high_receiver.next() => {
398 task = match high_task {
399 Some(task) => task,
400 None => break,
401 };
402 },
403 low_task = low_receiver.next() => {
404 task = match low_task {
405 Some(task) => task,
406 None => break,
407 };
408 }
409 }
410
411 if let Err(err) = sender.send(task).await {
413 error!(?err, channel, "failed to send task");
414 }
415 }
416 });
417
418 Self {
420 me,
421 channel,
422 max_size,
423 high,
424 low,
425 }
426 }
427}
428
429impl<P: Array> crate::Sender for Sender<P> {
430 type Error = Error;
431 type PublicKey = P;
432
433 async fn send(
434 &mut self,
435 recipients: Recipients<P>,
436 message: Bytes,
437 priority: bool,
438 ) -> Result<Vec<P>, Error> {
439 if message.len() > self.max_size {
441 return Err(Error::MessageTooLarge(message.len()));
442 }
443
444 let (sender, receiver) = oneshot::channel();
446 let mut channel = if priority { &self.high } else { &self.low };
447 channel
448 .send((self.channel, self.me.clone(), recipients, message, sender))
449 .await
450 .map_err(|_| Error::NetworkClosed)?;
451 receiver.await.map_err(|_| Error::NetworkClosed)
452 }
453}
454
/// Stream of messages delivered to one registered channel of a peer.
type MessageReceiver<P> = mpsc::UnboundedReceiver<Message<P>>;
/// Outcome of registering a channel on a peer: the channel's message stream,
/// or the reason the registration was refused.
type MessageReceiverResult<P> = Result<MessageReceiver<P>, Error>;
457
/// Handle used by a registered peer to receive messages on one channel.
#[derive(Debug)]
pub struct Receiver<P: Array> {
    /// Stream of messages routed to this channel by the peer's router task.
    receiver: MessageReceiver<P>,
}
463
464impl<P: Array> crate::Receiver for Receiver<P> {
465 type Error = Error;
466 type PublicKey = P;
467
468 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
469 self.receiver.next().await.ok_or(Error::NetworkClosed)
470 }
471}
472
/// A registered peer: the socket it listens on plus a handle to its router task.
struct Peer<P: Array> {
    /// Address the peer's listener task is bound to; links dial this address.
    socket: SocketAddr,

    /// Control channel to the router task, used to register new channels.
    control: mpsc::UnboundedSender<(Channel, oneshot::Sender<MessageReceiverResult<P>>)>,
}
483
484impl<P: Array> Peer<P> {
485 fn new<E: Spawner + RNetwork + Metrics>(
490 context: &mut E,
491 public_key: P,
492 socket: SocketAddr,
493 max_size: usize,
494 ) -> Self {
495 let (control_sender, mut control_receiver) = mpsc::unbounded();
498
499 let (inbox_sender, mut inbox_receiver) = mpsc::unbounded();
502
503 context.with_label("router").spawn(|_| async move {
505 let mut mailboxes = HashMap::new();
507
508 loop {
510 select! {
511 control = control_receiver.next() => {
513 let (channel, result): (Channel, oneshot::Sender<MessageReceiverResult<P>>) = match control {
515 Some(control) => control,
516 None => break,
517 };
518
519 if mailboxes.contains_key(&channel) {
521 result.send(Err(Error::ChannelAlreadyRegistered(channel))).unwrap();
522 continue;
523 }
524
525 let (sender, receiver) = mpsc::unbounded();
527 mailboxes.insert(channel, sender);
528 result.send(Ok(receiver)).unwrap();
529 },
530
531 inbox = inbox_receiver.next() => {
533 let (channel, message) = match inbox {
535 Some(message) => message,
536 None => break,
537 };
538
539 match mailboxes.get_mut(&channel) {
541 Some(mailbox) => {
542 if let Err(err) = mailbox.send(message).await {
543 error!(?err, "failed to send message to mailbox");
544 }
545 }
546 None => {
547 trace!(
548 recipient = ?public_key,
549 channel,
550 reason = "missing channel",
551 "dropping message",
552 );
553 }
554 }
555 },
556 }
557 }
558 });
559
560 context.with_label("listener").spawn({
562 let inbox_sender = inbox_sender.clone();
563 move |context| async move {
564 let mut listener = context.bind(socket).await.unwrap();
566
567 while let Ok((_, _, mut stream)) = listener.accept().await {
569 context.with_label("receiver").spawn({
571 let mut inbox_sender = inbox_sender.clone();
572 move |_| async move {
573 let dialer = match recv_frame(&mut stream, max_size).await {
575 Ok(data) => data,
576 Err(_) => {
577 error!("failed to receive public key from dialer");
578 return;
579 }
580 };
581 let Ok(dialer) = P::decode(dialer.as_ref()) else {
582 error!("received public key is invalid");
583 return;
584 };
585
586 while let Ok(data) = recv_frame(&mut stream, max_size).await {
588 let channel = Channel::from_be_bytes(
589 data[..Channel::SIZE].try_into().unwrap(),
590 );
591 let message = data.slice(Channel::SIZE..);
592 if let Err(err) = inbox_sender
593 .send((channel, (dialer.clone(), message)))
594 .await
595 {
596 error!(?err, "failed to send message to mailbox");
597 break;
598 }
599 }
600 }
601 });
602 }
603 }
604 });
605
606 Self {
608 socket,
609 control: control_sender,
610 }
611 }
612
613 async fn register(&mut self, channel: Channel) -> MessageReceiverResult<P> {
618 let (sender, receiver) = oneshot::channel();
619 self.control
620 .send((channel, sender))
621 .await
622 .map_err(|_| Error::NetworkClosed)?;
623 receiver.await.map_err(|_| Error::NetworkClosed)?
624 }
625}
626
/// A directed connection from one peer to another.
///
/// Clones share the same underlying connection task through `inbox`.
#[derive(Clone)]
struct Link {
    /// Distribution the per-message delay is sampled from (interpreted as
    /// milliseconds by the network when it sleeps).
    sampler: Normal<f64>,
    /// Probability passed to `gen_bool` to decide whether a message is delivered.
    success_rate: f64,
    /// Queue of `(channel, payload)` pairs consumed by the link's connection task.
    inbox: mpsc::UnboundedSender<(Channel, Bytes)>,
}
635
636impl Link {
637 fn new<E: Spawner + RNetwork + Metrics, P: Array>(
638 context: &mut E,
639 dialer: P,
640 socket: SocketAddr,
641 sampler: Normal<f64>,
642 success_rate: f64,
643 max_size: usize,
644 ) -> Self {
645 let (inbox, mut outbox) = mpsc::unbounded();
646 let result = Self {
647 sampler,
648 success_rate,
649 inbox,
650 };
651
652 context
655 .clone()
656 .with_label("link")
657 .spawn(move |context| async move {
658 let (mut sink, _) = context.dial(socket).await.unwrap();
660 if let Err(err) = send_frame(&mut sink, &dialer, max_size).await {
661 error!(?err, "failed to send public key to listener");
662 return;
663 }
664
665 while let Some((channel, message)) = outbox.next().await {
667 let mut data = bytes::BytesMut::with_capacity(Channel::SIZE + message.len());
668 data.extend_from_slice(&channel.to_be_bytes());
669 data.extend_from_slice(&message);
670 let data = data.freeze();
671 send_frame(&mut sink, &data, max_size).await.unwrap();
672 }
673 });
674
675 result
676 }
677
678 async fn send(&mut self, channel: Channel, message: Bytes) -> Result<(), Error> {
680 self.inbox
681 .send((channel, message))
682 .await
683 .map_err(|_| Error::NetworkClosed)?;
684 Ok(())
685 }
686}
687
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::{Ed25519, Signer, Specification};
    use commonware_runtime::{deterministic, Runner};

    /// Maximum message size used by the tests (1 MiB).
    const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

    /// Registering channels and adding links succeeds once, and repeating the
    /// same registration or link is rejected with the matching error.
    #[test]
    fn test_register_and_link() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Start the network loop.
            let cfg = Config {
                max_size: MAX_MESSAGE_SIZE,
            };
            let network_context = context.with_label("network");
            let (network, mut oracle) = Network::new(network_context.clone(), cfg);
            network_context.spawn(|_| network.run());

            // Two distinct peers.
            let pk1 = Ed25519::from_seed(1).public_key();
            let pk2 = Ed25519::from_seed(2).public_key();

            // Each peer can register several distinct channels.
            oracle.register(pk1.clone(), 0).await.unwrap();
            oracle.register(pk1.clone(), 1).await.unwrap();
            oracle.register(pk2.clone(), 0).await.unwrap();
            oracle.register(pk2.clone(), 1).await.unwrap();

            // Registering the same channel twice on one peer fails.
            assert!(matches!(
                oracle.register(pk1.clone(), 1).await,
                Err(Error::ChannelAlreadyRegistered(_))
            ));

            // Adding a link between two registered peers succeeds.
            let link = ingress::Link {
                latency: 2.0,
                jitter: 1.0,
                success_rate: 0.9,
            };
            oracle
                .add_link(pk1.clone(), pk2.clone(), link.clone())
                .await
                .unwrap();

            // Adding the same directed link again fails.
            assert!(matches!(
                oracle.add_link(pk1, pk2, link).await,
                Err(Error::LinkExists)
            ));
        });
    }

    /// `get_next_socket` returns the current address, then increments the port,
    /// and rolls over to port 0 of the next IPv4 address when ports run out.
    #[test]
    fn test_get_next_socket() {
        let cfg = Config {
            max_size: MAX_MESSAGE_SIZE,
        };
        let runner = deterministic::Runner::default();

        runner.start(|context| async move {
            type PublicKey = <Ed25519 as Specification>::PublicKey;
            let (mut network, _) =
                Network::<deterministic::Context, PublicKey>::new(context.clone(), cfg);

            // The first call returns the initial address; the second returns port 1.
            let mut original = network.next_addr;
            let next = network.get_next_socket();
            assert_eq!(next, original);
            let next = network.get_next_socket();
            original.set_port(1);
            assert_eq!(next, original);

            // At the last port, the following address carries into the next IP
            // (255.0.255.255 -> 255.1.0.0) and restarts at port 0.
            let max_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 0, 255, 255)), 65535);
            network.next_addr = max_addr;
            let next = network.get_next_socket();
            assert_eq!(next, max_addr);
            let next = network.get_next_socket();
            assert_eq!(
                next,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 1, 0, 0)), 0)
            );
        });
    }
}