1use super::{
4 ingress::{self, Oracle},
5 metrics, Error,
6};
7use crate::{Channel, Message, Recipients};
8use bytes::Bytes;
9use commonware_macros::select;
10use commonware_runtime::{
11 deterministic::{Listener, Sink, Stream},
12 Clock, Handle, Listener as _, Metrics, Network as RNetwork, Spawner,
13};
14use commonware_stream::utils::codec::{recv_frame, send_frame};
15use commonware_utils::{Array, SizedSerialize};
16use futures::{
17 channel::{mpsc, oneshot},
18 SinkExt, StreamExt,
19};
20use prometheus_client::metrics::{counter::Counter, family::Family};
21use rand::Rng;
22use rand_distr::{Distribution, Normal};
23use std::{
24 collections::{BTreeMap, HashMap},
25 net::{IpAddr, Ipv4Addr, SocketAddr},
26 time::Duration,
27};
28use tracing::{error, trace};
29
30type Task<P> = (Channel, P, Recipients<P>, Bytes, oneshot::Sender<Vec<P>>);
32
/// Configuration for the simulated [`Network`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Maximum size (in bytes) of a message a [`Sender`] will accept.
    pub max_size: usize,
}
38
39pub struct Network<E: RNetwork<Listener, Sink, Stream> + Spawner + Rng + Clock + Metrics, P: Array>
41{
42 context: E,
43
44 max_size: usize,
46
47 next_addr: SocketAddr,
50
51 ingress: mpsc::UnboundedReceiver<ingress::Message<P>>,
53
54 sender: mpsc::UnboundedSender<Task<P>>,
58 receiver: mpsc::UnboundedReceiver<Task<P>>,
59
60 links: HashMap<(P, P), Link>,
62
63 peers: BTreeMap<P, Peer<P>>,
65
66 received_messages: Family<metrics::Message, Counter>,
68 sent_messages: Family<metrics::Message, Counter>,
69}
70
71impl<E: RNetwork<Listener, Sink, Stream> + Spawner + Rng + Clock + Metrics, P: Array>
72 Network<E, P>
73{
74 pub fn new(context: E, cfg: Config) -> (Self, Oracle<P>) {
79 let (sender, receiver) = mpsc::unbounded();
80 let (oracle_sender, oracle_receiver) = mpsc::unbounded();
81 let sent_messages = Family::<metrics::Message, Counter>::default();
82 let received_messages = Family::<metrics::Message, Counter>::default();
83 context.register("messages_sent", "messages sent", sent_messages.clone());
84 context.register(
85 "messages_received",
86 "messages received",
87 received_messages.clone(),
88 );
89
90 let next_addr = SocketAddr::new(
92 IpAddr::V4(Ipv4Addr::from_bits(context.clone().next_u32())),
93 0,
94 );
95 (
96 Self {
97 context,
98 max_size: cfg.max_size,
99 next_addr,
100 ingress: oracle_receiver,
101 sender,
102 receiver,
103 links: HashMap::new(),
104 peers: BTreeMap::new(),
105 received_messages,
106 sent_messages,
107 },
108 Oracle::new(oracle_sender),
109 )
110 }
111
112 fn get_next_socket(&mut self) -> SocketAddr {
117 let result = self.next_addr;
118
119 match self.next_addr.port().checked_add(1) {
122 Some(port) => {
123 self.next_addr.set_port(port);
124 }
125 None => {
126 let ip = match self.next_addr.ip() {
127 IpAddr::V4(ipv4) => ipv4,
128 _ => unreachable!(),
129 };
130 let next_ip = Ipv4Addr::to_bits(ip).wrapping_add(1);
131 self.next_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(next_ip)), 0);
132 }
133 }
134
135 result
136 }
137
138 async fn handle_ingress(&mut self, message: ingress::Message<P>) {
142 fn send_result<T: std::fmt::Debug>(
146 result: oneshot::Sender<Result<T, Error>>,
147 value: Result<T, Error>,
148 ) {
149 let success = value.is_ok();
150 if let Err(e) = result.send(value) {
151 error!(?e, "failed to send result to oracle (ok = {})", success);
152 }
153 }
154
155 match message {
156 ingress::Message::Register {
157 public_key,
158 channel,
159 result,
160 } => {
161 if !self.peers.contains_key(&public_key) {
163 let peer = Peer::new(
164 &mut self.context.clone(),
165 public_key.clone(),
166 self.get_next_socket(),
167 self.max_size,
168 );
169 self.peers.insert(public_key.clone(), peer);
170 }
171
172 let peer = self.peers.get_mut(&public_key).unwrap();
174 let receiver = match peer.register(channel).await {
175 Ok(receiver) => Receiver { receiver },
176 Err(err) => return send_result(result, Err(err)),
177 };
178
179 let sender = Sender::new(
181 self.context.clone(),
182 public_key,
183 channel,
184 self.max_size,
185 self.sender.clone(),
186 );
187 send_result(result, Ok((sender, receiver)))
188 }
189 ingress::Message::AddLink {
190 sender,
191 receiver,
192 sampler,
193 success_rate,
194 result,
195 } => {
196 if !self.peers.contains_key(&sender) {
198 return send_result(result, Err(Error::PeerMissing));
199 }
200 let peer = match self.peers.get(&receiver) {
201 Some(peer) => peer,
202 None => return send_result(result, Err(Error::PeerMissing)),
203 };
204
205 let key = (sender.clone(), receiver);
207 if self.links.contains_key(&key) {
208 return send_result(result, Err(Error::LinkExists));
209 }
210
211 let link = Link::new(
212 &mut self.context,
213 sender,
214 peer.socket,
215 sampler,
216 success_rate,
217 self.max_size,
218 );
219 self.links.insert(key, link);
220 send_result(result, Ok(()))
221 }
222 ingress::Message::RemoveLink {
223 sender,
224 receiver,
225 result,
226 } => {
227 match self.links.remove(&(sender, receiver)) {
228 Some(_) => (),
229 None => return send_result(result, Err(Error::LinkMissing)),
230 }
231 send_result(result, Ok(()))
232 }
233 }
234 }
235
236 fn handle_task(&mut self, task: Task<P>) {
241 let (channel, origin, recipients, message, reply) = task;
243 let recipients = match recipients {
244 Recipients::All => self.peers.keys().cloned().collect(),
245 Recipients::Some(keys) => keys,
246 Recipients::One(key) => vec![key],
247 };
248
249 let mut sent = Vec::new();
251 let (acquired_sender, mut acquired_receiver) = mpsc::channel(recipients.len());
252 for recipient in recipients {
253 if recipient == origin {
255 trace!(?recipient, reason = "self", "dropping message",);
256 continue;
257 }
258
259 let mut link = match self
261 .links
262 .get(&(origin.clone(), recipient.clone()))
263 .cloned()
264 {
265 Some(link) => link,
266 None => {
267 trace!(?origin, ?recipient, reason = "no link", "dropping message",);
268 continue;
269 }
270 };
271
272 self.sent_messages
275 .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
276 .inc();
277
278 let delay = link.sampler.sample(&mut self.context);
280 let should_deliver = self.context.gen_bool(link.success_rate);
281 trace!(?origin, ?recipient, ?delay, "sending message",);
282
283 self.context.with_label("messenger").spawn({
285 let message = message.clone();
286 let recipient = recipient.clone();
287 let origin = origin.clone();
288 let mut acquired_sender = acquired_sender.clone();
289 let received_messages = self.received_messages.clone();
290 move |context| async move {
291 acquired_sender.send(()).await.unwrap();
293
294 context.sleep(Duration::from_millis(delay as u64)).await;
299
300 if !should_deliver {
302 trace!(
303 ?recipient,
304 reason = "random link failure",
305 "dropping message",
306 );
307 return;
308 }
309
310 if let Err(err) = link.send(channel, message).await {
312 error!(?origin, ?recipient, ?err, "failed to send",);
314 return;
315 }
316
317 received_messages
319 .get_or_create(&metrics::Message::new(&origin, &recipient, channel))
320 .inc();
321 }
322 });
323 sent.push(recipient);
324 }
325
326 self.context
328 .clone()
329 .with_label("notifier")
330 .spawn(|_| async move {
331 for _ in 0..sent.len() {
333 acquired_receiver.next().await.unwrap();
334 }
335
336 if let Err(err) = reply.send(sent) {
338 error!(?err, "failed to send ack");
340 }
341 });
342 }
343
344 pub fn start(mut self) -> Handle<()> {
349 self.context.spawn_ref()(self.run())
350 }
351
352 async fn run(mut self) {
353 loop {
354 select! {
355 message = self.ingress.next() => {
356 let message = match message {
358 Some(message) => message,
359 None => break,
360 };
361 self.handle_ingress(message).await;
362 },
363 task = self.receiver.next() => {
364 let task = match task {
366 Some(task) => task,
367 None => break,
368 };
369 self.handle_task(task);
370 }
371 }
372 }
373 }
374}
375
376#[derive(Clone, Debug)]
378pub struct Sender<P: Array> {
379 me: P,
380 channel: Channel,
381 max_size: usize,
382 high: mpsc::UnboundedSender<Task<P>>,
383 low: mpsc::UnboundedSender<Task<P>>,
384}
385
386impl<P: Array> Sender<P> {
387 fn new(
388 context: impl Spawner + Metrics,
389 me: P,
390 channel: Channel,
391 max_size: usize,
392 mut sender: mpsc::UnboundedSender<Task<P>>,
393 ) -> Self {
394 let (high, mut high_receiver) = mpsc::unbounded();
396 let (low, mut low_receiver) = mpsc::unbounded();
397 context.with_label("sender").spawn(move |_| async move {
398 loop {
399 let task;
401 select! {
402 high_task = high_receiver.next() => {
403 task = match high_task {
404 Some(task) => task,
405 None => break,
406 };
407 },
408 low_task = low_receiver.next() => {
409 task = match low_task {
410 Some(task) => task,
411 None => break,
412 };
413 }
414 }
415
416 if let Err(err) = sender.send(task).await {
418 error!(?err, channel, "failed to send task");
419 }
420 }
421 });
422
423 Self {
425 me,
426 channel,
427 max_size,
428 high,
429 low,
430 }
431 }
432}
433
434impl<P: Array> crate::Sender for Sender<P> {
435 type Error = Error;
436 type PublicKey = P;
437
438 async fn send(
439 &mut self,
440 recipients: Recipients<P>,
441 message: Bytes,
442 priority: bool,
443 ) -> Result<Vec<P>, Error> {
444 if message.len() > self.max_size {
446 return Err(Error::MessageTooLarge(message.len()));
447 }
448
449 let (sender, receiver) = oneshot::channel();
451 let mut channel = if priority { &self.high } else { &self.low };
452 channel
453 .send((self.channel, self.me.clone(), recipients, message, sender))
454 .await
455 .map_err(|_| Error::NetworkClosed)?;
456 receiver.await.map_err(|_| Error::NetworkClosed)
457 }
458}
459
460type MessageReceiver<P> = mpsc::UnboundedReceiver<Message<P>>;
461type MessageReceiverResult<P> = Result<MessageReceiver<P>, Error>;
462
463#[derive(Debug)]
465pub struct Receiver<P: Array> {
466 receiver: MessageReceiver<P>,
467}
468
469impl<P: Array> crate::Receiver for Receiver<P> {
470 type Error = Error;
471 type PublicKey = P;
472
473 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Error> {
474 self.receiver.next().await.ok_or(Error::NetworkClosed)
475 }
476}
477
478struct Peer<P: Array> {
482 socket: SocketAddr,
484
485 control: mpsc::UnboundedSender<(Channel, oneshot::Sender<MessageReceiverResult<P>>)>,
487}
488
489impl<P: Array> Peer<P> {
490 fn new<E: Spawner + RNetwork<Listener, Sink, Stream> + Metrics>(
495 context: &mut E,
496 public_key: P,
497 socket: SocketAddr,
498 max_size: usize,
499 ) -> Self {
500 let (control_sender, mut control_receiver) = mpsc::unbounded();
503
504 let (inbox_sender, mut inbox_receiver) = mpsc::unbounded();
507
508 context.with_label("router").spawn(|_| async move {
510 let mut mailboxes = HashMap::new();
512
513 loop {
515 select! {
516 control = control_receiver.next() => {
518 let (channel, result): (Channel, oneshot::Sender<MessageReceiverResult<P>>) = match control {
520 Some(control) => control,
521 None => break,
522 };
523
524 if mailboxes.contains_key(&channel) {
526 result.send(Err(Error::ChannelAlreadyRegistered(channel))).unwrap();
527 continue;
528 }
529
530 let (sender, receiver) = mpsc::unbounded();
532 mailboxes.insert(channel, sender);
533 result.send(Ok(receiver)).unwrap();
534 },
535
536 inbox = inbox_receiver.next() => {
538 let (channel, message) = match inbox {
540 Some(message) => message,
541 None => break,
542 };
543
544 match mailboxes.get_mut(&channel) {
546 Some(mailbox) => {
547 if let Err(err) = mailbox.send(message).await {
548 error!(?err, "failed to send message to mailbox");
549 }
550 }
551 None => {
552 trace!(
553 recipient = ?public_key,
554 channel,
555 reason = "missing channel",
556 "dropping message",
557 );
558 }
559 }
560 },
561 }
562 }
563 });
564
565 context.with_label("listener").spawn({
567 let inbox_sender = inbox_sender.clone();
568 move |context| async move {
569 let mut listener = context.bind(socket).await.unwrap();
571
572 while let Ok((_, _, mut stream)) = listener.accept().await {
574 context.with_label("receiver").spawn({
576 let mut inbox_sender = inbox_sender.clone();
577 move |_| async move {
578 let dialer = match recv_frame(&mut stream, max_size).await {
580 Ok(data) => data,
581 Err(_) => {
582 error!("failed to receive public key from dialer");
583 return;
584 }
585 };
586 let Ok(dialer) = P::try_from(dialer.as_ref()) else {
587 error!("received public key is invalid");
588 return;
589 };
590
591 while let Ok(data) = recv_frame(&mut stream, max_size).await {
593 let channel = Channel::from_be_bytes(
594 data[..Channel::SERIALIZED_LEN].try_into().unwrap(),
595 );
596 let message = data.slice(Channel::SERIALIZED_LEN..);
597 if let Err(err) = inbox_sender
598 .send((channel, (dialer.clone(), message)))
599 .await
600 {
601 error!(?err, "failed to send message to mailbox");
602 break;
603 }
604 }
605 }
606 });
607 }
608 }
609 });
610
611 Self {
613 socket,
614 control: control_sender,
615 }
616 }
617
618 async fn register(&mut self, channel: Channel) -> MessageReceiverResult<P> {
623 let (sender, receiver) = oneshot::channel();
624 self.control
625 .send((channel, sender))
626 .await
627 .map_err(|_| Error::NetworkClosed)?;
628 receiver.await.map_err(|_| Error::NetworkClosed)?
629 }
630}
631
632#[derive(Clone)]
635struct Link {
636 sampler: Normal<f64>,
637 success_rate: f64,
638 inbox: mpsc::UnboundedSender<(Channel, Bytes)>,
639}
640
641impl Link {
642 fn new<E: Spawner + RNetwork<Listener, Sink, Stream> + Metrics, P: Array>(
643 context: &mut E,
644 dialer: P,
645 socket: SocketAddr,
646 sampler: Normal<f64>,
647 success_rate: f64,
648 max_size: usize,
649 ) -> Self {
650 let (inbox, mut outbox) = mpsc::unbounded();
651 let result = Self {
652 sampler,
653 success_rate,
654 inbox,
655 };
656
657 context
660 .clone()
661 .with_label("link")
662 .spawn(move |context| async move {
663 let (mut sink, _) = context.dial(socket).await.unwrap();
665 if let Err(err) = send_frame(&mut sink, &dialer, max_size).await {
666 error!(?err, "failed to send public key to dialee");
667 return;
668 }
669
670 while let Some((channel, message)) = outbox.next().await {
672 let mut data =
673 bytes::BytesMut::with_capacity(Channel::SERIALIZED_LEN + message.len());
674 data.extend_from_slice(&channel.to_be_bytes());
675 data.extend_from_slice(&message);
676 let data = data.freeze();
677 send_frame(&mut sink, &data, max_size).await.unwrap();
678 }
679 });
680
681 result
682 }
683
684 async fn send(&mut self, channel: Channel, message: Bytes) -> Result<(), Error> {
686 self.inbox
687 .send((channel, message))
688 .await
689 .map_err(|_| Error::NetworkClosed)?;
690 Ok(())
691 }
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::{Ed25519, Scheme};
    use commonware_runtime::{
        deterministic::{Context, Executor},
        Runner,
    };

    const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

    /// Network configuration shared by every test.
    fn config() -> Config {
        Config {
            max_size: MAX_MESSAGE_SIZE,
        }
    }

    #[test]
    fn test_register_and_link() {
        let (executor, context, _) = Executor::default();
        executor.start(async move {
            let network_context = context.with_label("network");
            let (network, mut oracle) = Network::new(network_context.clone(), config());
            network_context.spawn(|_| network.run());

            let alice = Ed25519::from_seed(1).public_key();
            let bob = Ed25519::from_seed(2).public_key();

            // Every peer may register several distinct channels...
            for peer in [&alice, &bob] {
                for channel in [0, 1] {
                    oracle.register(peer.clone(), channel).await.unwrap();
                }
            }

            // ...but registering the same channel twice is rejected.
            let duplicate_channel = oracle.register(alice.clone(), 1).await;
            assert!(matches!(
                duplicate_channel,
                Err(Error::ChannelAlreadyRegistered(_))
            ));

            // A link can be added once; adding it again is rejected.
            let link = ingress::Link {
                latency: 2.0,
                jitter: 1.0,
                success_rate: 0.9,
            };
            oracle
                .add_link(alice.clone(), bob.clone(), link.clone())
                .await
                .unwrap();
            let duplicate_link = oracle.add_link(alice, bob, link).await;
            assert!(matches!(duplicate_link, Err(Error::LinkExists)));
        });
    }

    #[test]
    fn test_get_next_socket() {
        type PublicKey = <Ed25519 as Scheme>::PublicKey;
        let (_, context, _) = Executor::default();
        let (mut network, _) = Network::<Context, PublicKey>::new(context.clone(), config());

        // Ports advance one at a time from the randomly chosen starting address.
        let mut expected = network.next_addr;
        assert_eq!(network.get_next_socket(), expected);
        expected.set_port(1);
        assert_eq!(network.get_next_socket(), expected);

        // Exhausting the port range moves on to the next IP address with port 0.
        let last = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 0, 255, 255)), 65535);
        network.next_addr = last;
        assert_eq!(network.get_next_socket(), last);
        assert_eq!(
            network.get_next_socket(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 1, 0, 0)), 0)
        );
    }
}