#![forbid(unsafe_code)]
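//! Reliable message delivery on top of an arbitrary session transport: outbound
//! messages are persisted to a `Backend`, sent over an
//! `UnderlyingSessionTransport`, and retried until the remote peer ACKs them;
//! inbound messages are persisted, handed to a `LocalDelivery` sink, and ACKed
//! back to the sender.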

use async_trait::async_trait;
use futures::StreamExt;
use itertools::Itertools;
use local_delivery::LocalDelivery;
use message_tracker::MessageTracker;
use num::traits::NumOps;
use num::Num;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::time::{sleep, Duration};

pub mod local_delivery;
pub(crate) mod message_tracker;

#[cfg(feature = "testing")]
pub mod testing;

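// Fallback intervals at which the outbound and inbound processing loops wake up
// when no explicit poll signal arrives over the internal channels.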
const OUTBOUND_POLL: Duration = Duration::from_millis(200);
const INBOUND_POLL: Duration = Duration::from_millis(200);

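/// Metadata every message must expose: a source peer, a destination peer, a
/// per-peer message id used for ordering and acknowledgement, and an opaque
/// `Contents` payload, plus a constructor for building messages from parts.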
#[async_trait]
pub trait MessageMetadata: Debug + Send + Sync + 'static {
    type PeerId: Default
        + Display
        + Debug
        + Hash
        + Eq
        + Copy
        + Ord
        + Serialize
        + DeserializeOwned
        + Send
        + Sync
        + 'static;
    type MessageId: Num
        + NumOps
        + Eq
        + Default
        + PartialEq
        + Display
        + Debug
        + Hash
        + Ord
        + PartialOrd
        + Copy
        + Serialize
        + DeserializeOwned
        + Send
        + Sync
        + 'static;

    type Contents: Send + Sync + 'static;

    fn source_id(&self) -> Self::PeerId;
    fn destination_id(&self) -> Self::PeerId;
    fn message_id(&self) -> Self::MessageId;
    fn contents(&self) -> &Self::Contents;
    fn construct_from_parts(
        source_id: Self::PeerId,
        destination_id: Self::PeerId,
        message_id: Self::MessageId,
        contents: impl Into<Self::Contents>,
    ) -> Self;
}

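/// The raw transport this layer runs on top of: it yields inbound payloads,
/// pushes outbound payloads, reports the currently connected peers, and knows
/// the local peer id.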
#[async_trait]
pub trait UnderlyingSessionTransport {
    type Message: MessageMetadata + Send + Sync + 'static;

    async fn next_message(&self) -> Option<Payload<Self::Message>>;
    async fn send_message(
        &self,
        message: Payload<Self::Message>,
    ) -> Result<(), NetworkError<Payload<Self::Message>>>;
    async fn connected_peers(&self) -> Vec<<Self::Message as MessageMetadata>::PeerId>;
    fn local_id(&self) -> <Self::Message as MessageMetadata>::PeerId;
}

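/// Wire-level payload exchanged between peers: an application `Message`, an
/// `Ack` confirming receipt of a specific message id, or a `Poll` prompting the
/// receiving peer to re-drive its outbound queue.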
#[derive(Debug, Serialize, Deserialize)]
pub enum Payload<M: MessageMetadata> {
    Ack {
        from_id: M::PeerId,
        to_id: M::PeerId,
        message_id: M::MessageId,
    },
    Message(M),
    Poll {
        from_id: M::PeerId,
        to_id: M::PeerId,
    },
}

impl<M: MessageMetadata> Payload<M> {
    pub fn source_id(&self) -> M::PeerId {
        match self {
            Payload::Ack { from_id, .. } => *from_id,
            Payload::Message(msg) => msg.source_id(),
            Payload::Poll { from_id, .. } => *from_id,
        }
    }

    pub fn destination_id(&self) -> M::PeerId {
        match self {
            Payload::Ack { to_id, .. } => *to_id,
            Payload::Message(msg) => msg.destination_id(),
            Payload::Poll { to_id, .. } => *to_id,
        }
    }

    pub fn message_id(&self) -> Option<M::MessageId> {
        match self {
            Payload::Ack { message_id, .. } => Some(*message_id),
            Payload::Message(msg) => Some(msg.message_id()),
            Payload::Poll { .. } => None,
        }
    }
}

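/// Errors surfaced to callers of the delivery layer; `T` carries the message or
/// payload that failed to send where applicable.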
#[derive(Debug)]
pub enum NetworkError<T> {
    SendFailed { reason: String, message: T },
    ConnectionError(String),
    BackendError(BackendError<T>),
    ShutdownFailed(String),
    SystemShutdown,
}

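/// Errors reported by a [`Backend`] implementation.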
#[derive(Debug)]
pub enum BackendError<T> {
    StorageError(String),
    SendFailed { reason: String, message: T },
    NotFound,
}

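/// Errors that can arise when handing an inbound message to the local consumer.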
#[derive(Debug, Copy, Clone)]
pub enum DeliveryError {
    NoReceiver,
    ChannelClosed,
    BadInput,
}

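/// Durable storage for in-flight messages. Outbound messages remain stored
/// until the destination ACKs them; inbound messages remain stored until they
/// have been handed off to local delivery. `store_value`/`load_value` expose a
/// small key-value area for auxiliary bookkeeping (e.g. tracker state).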
#[async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait Backend<M: MessageMetadata> {
    async fn store_outbound(&self, message: M) -> Result<(), BackendError<M>>;
    async fn store_inbound(&self, message: M) -> Result<(), BackendError<M>>;
    async fn clear_message_inbound(
        &self,
        peer_id: M::PeerId,
        message_id: M::MessageId,
    ) -> Result<(), BackendError<M>>;
    async fn clear_message_outbound(
        &self,
        peer_id: M::PeerId,
        message_id: M::MessageId,
    ) -> Result<(), BackendError<M>>;
    async fn get_pending_outbound(&self) -> Result<Vec<M>, BackendError<M>>;
    async fn get_pending_inbound(&self) -> Result<Vec<M>, BackendError<M>>;
    async fn store_value(&self, key: &str, value: &[u8]) -> Result<(), BackendError<M>>;
    async fn load_value(&self, key: &str) -> Result<Option<Vec<u8>>, BackendError<M>>;
}

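// Cap for in-memory bookkeeping maps; presumably consumed by the
// `message_tracker` submodule.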
const MAX_MAP_SIZE: usize = 1000;

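/// The delivery engine: ties a persistent [`Backend`], a [`LocalDelivery`] sink,
/// and an [`UnderlyingSessionTransport`] together, and runs background tasks
/// that retry outbound messages until they are ACKed and hand inbound messages
/// to the sink without re-delivering duplicates.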
pub struct ILM<M, B, L, N>
where
    M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
    B: Backend<M> + Send + Sync + 'static,
    L: LocalDelivery<M> + Send + Sync + 'static,
    N: UnderlyingSessionTransport<Message = M> + Send + Sync + 'static,
{
    backend: Arc<B>,
    local_delivery: Arc<Mutex<Option<L>>>,
    network: Arc<N>,
    is_running: Arc<AtomicBool>,
    is_shutting_down: Arc<AtomicBool>,
    tracker: Arc<MessageTracker<M, B>>,
    poll_inbound_tx: tokio::sync::mpsc::UnboundedSender<()>,
    poll_outbound_tx: tokio::sync::mpsc::UnboundedSender<()>,
    known_peers: Arc<Mutex<Vec<M::PeerId>>>,
}

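// If this is the last handle to the shared running flag, fire one final
// best-effort poll signal at the outbound loop.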
impl<M, B, L, N> Drop for ILM<M, B, L, N>
where
    M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
    B: Backend<M> + Send + Sync + 'static,
    L: LocalDelivery<M> + Send + Sync + 'static,
    N: UnderlyingSessionTransport<Message = M> + Send + Sync + 'static,
{
    fn drop(&mut self) {
        if Arc::strong_count(&self.is_running) == 1 {
            let _ = self.poll_outbound_tx.send(());
        }
    }
}

impl<M, B, L, N> ILM<M, B, L, N>
where
    M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
    B: Backend<M> + Send + Sync + 'static,
    L: LocalDelivery<M> + Send + Sync + 'static,
    N: UnderlyingSessionTransport<Message = M> + Send + Sync + 'static,
{
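    /// Creates the delivery system and spawns its background tasks (outbound
    /// retry loop, inbound delivery loop, transport reader, and peer poller).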
    pub async fn new(backend: B, local_delivery: L, network: N) -> Result<Self, BackendError<M>> {
        let (poll_inbound_tx, poll_inbound_rx) = tokio::sync::mpsc::unbounded_channel();
        let (poll_outbound_tx, poll_outbound_rx) = tokio::sync::mpsc::unbounded_channel();

        let backend = Arc::new(backend);
        let this = Self {
            backend: backend.clone(),
            local_delivery: Arc::new(Mutex::new(Some(local_delivery))),
            network: Arc::new(network),
            is_running: Arc::new(AtomicBool::new(true)),
            is_shutting_down: Arc::new(AtomicBool::new(false)),
            tracker: Arc::new(MessageTracker::new(backend).await?),
            poll_inbound_tx,
            poll_outbound_tx,
            known_peers: Arc::new(Mutex::new(Vec::new())),
        };

        this.spawn_background_tasks(poll_inbound_rx, poll_outbound_rx);

        Ok(this)
    }

    fn clone_internal(&self) -> Self {
        Self {
            backend: self.backend.clone(),
            local_delivery: self.local_delivery.clone(),
            network: self.network.clone(),
            is_running: self.is_running.clone(),
            is_shutting_down: self.is_shutting_down.clone(),
            tracker: self.tracker.clone(),
            poll_inbound_tx: self.poll_inbound_tx.clone(),
            poll_outbound_tx: self.poll_outbound_tx.clone(),
            known_peers: self.known_peers.clone(),
        }
    }

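    // Spawns one tokio task that drives four concurrent loops: outbound retries,
    // inbound delivery, reading from the transport, and periodic peer polling.
    // When any loop exits, all of them are torn down, tracker state is synced to
    // the backend, the running flag is cleared, and local delivery is dropped.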
    fn spawn_background_tasks(
        &self,
        mut poll_inbound_rx: tokio::sync::mpsc::UnboundedReceiver<()>,
        mut poll_outbound_rx: tokio::sync::mpsc::UnboundedReceiver<()>,
    ) {
        let this = self.clone_internal();

        let background_task = async move {
            let this = &this;

            let outbound_handle = async move {
                loop {
                    if !this.can_run() {
                        break;
                    }

                    tokio::select! {
                        res0 = poll_outbound_rx.recv() => {
                            if res0.is_none() {
                                log::warn!(target: "ism", "Poll outbound channel closed");
                                return;
                            }
                        },
                        _res1 = sleep(OUTBOUND_POLL) => {},
                    }

                    this.process_outbound().await;
                }
            };

            let inbound_handle = async move {
                loop {
                    if !this.can_run() {
                        break;
                    }

                    tokio::select! {
                        biased;
                        res0 = poll_inbound_rx.recv() => {
                            if res0.is_none() {
                                log::warn!(target: "ism", "Poll inbound channel closed");
                            }
                        },
                        _res1 = sleep(INBOUND_POLL) => {},
                    }

                    this.process_inbound().await;
                }
            };

            let network_io_handle = async move {
                loop {
                    if !this.can_run() {
                        break;
                    }

                    this.process_next_network_message().await;
                }
            };

            let peer_polling_handle = async move {
                loop {
                    if !this.can_run() {
                        break;
                    }

                    this.poll_peers().await;

                    sleep(Duration::from_secs(5)).await;
                }
            };

            tokio::select! {
                _ = outbound_handle => {
                    log::error!(target: "ism", "Outbound processing task prematurely ended");
                },
                _ = inbound_handle => {
                    log::error!(target: "ism", "Inbound processing task prematurely ended");
                },
                _ = network_io_handle => {
                    log::error!(target: "ism", "Network IO task prematurely ended");
                },
                _ = peer_polling_handle => {
                    log::error!(target: "ism", "Peer polling task prematurely ended");
                },
            }

            if let Err(err) = this.tracker.sync_backend().await {
                log::error!(target: "ism", "Failed to sync tracker state to backend on shutdown hook: {err:?}");
            }

            log::warn!(target: "ism", "Message system has shut down");

            this.toggle_off();
            drop(this.local_delivery.lock().await.take());
        };

        drop(tokio::spawn(background_task));
    }

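    // Diffs the transport's current peer set against the last observed one and
    // sends a `Poll` to every newly connected peer so it re-drives its outbound
    // queue toward us.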
    async fn poll_peers(&self) {
        let connected_peers_now = self.get_connected_peers().await;
        let mut current_peers_lock = self.known_peers.lock().await;
        let connected_peers_previous = current_peers_lock
            .iter()
            .copied()
            .sorted()
            .collect::<Vec<_>>();
        if connected_peers_now != connected_peers_previous {
            log::info!(target: "ism", "Connected peers changed to {connected_peers_now:?}, sending poll for refresh in state");

            for peer_id in connected_peers_now
                .iter()
                .filter(|id| !connected_peers_previous.contains(id))
            {
                if let Err(e) = self
                    .send_message_internal(Payload::Poll {
                        from_id: self.network.local_id(),
                        to_id: *peer_id,
                    })
                    .await
                {
                    log::error!(target: "ism", "Failed to send poll to new peer: {:?}", e);
                    break;
                }
            }

            *current_peers_lock = connected_peers_now;
        }
    }

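    // Drains the backend's pending outbound queue. Messages are grouped by
    // destination; for each connected peer they are considered in message-id
    // order while the tracker's `can_send` allows, stopping after the first
    // successful send. Everything else stays queued until ACKs clear it.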
    async fn process_outbound(&self) {
        let pending_messages = match self.backend.get_pending_outbound().await {
            Ok(messages) => messages,
            Err(e) => {
                log::error!(target: "ism", "Failed to get pending outbound messages: {:?}", e);
                return;
            }
        };

        let mut grouped_messages: HashMap<M::PeerId, Vec<M>> = HashMap::new();
        for msg in pending_messages {
            grouped_messages
                .entry(msg.destination_id())
                .or_default()
                .push(msg);
        }

        let connected_peers = &self.network.connected_peers().await;
        futures::stream::iter(grouped_messages).for_each_concurrent(None, |(peer_id, messages)| {
            async move {
                if !connected_peers.contains(&peer_id) {
                    log::warn!(target: "ism", "Peer {peer_id} is not connected, skipping message until later");
                    return;
                }

                let messages = messages
                    .into_iter()
                    .sorted_by_key(|r| r.message_id())
                    .unique_by(|r| r.message_id())
                    .collect::<Vec<_>>();

                'peer: for msg in messages {
                    let message_id = msg.message_id();
                    if self.tracker.can_send(&peer_id, &message_id) {
                        log::trace!(target: "ism", "[CAN SEND] message: {:?}", msg);
                        if let Err(e) = self.send_message_internal(Payload::Message(msg)).await {
                            log::error!(target: "ism", "Failed to send message: {:?}", e);
                        } else {
                            if let Err(err) = self.tracker.mark_sent(peer_id, message_id).await {
                                log::error!(target: "ism", "Failed to mark message as sent: {err:?}");
                            }
                            break 'peer;
                        }
                    } else {
                        log::trace!(target: "ism", "[CANNOT SEND] message: {:?}", msg);
                        break;
                    }
                }
            }
        }).await
    }

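    // Drains the backend's pending inbound queue: each message is handed to the
    // local delivery sink (skipping ids already marked delivered), ACKed back to
    // its sender, and cleared from the backend.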
    async fn process_inbound(&self) {
        let pending_messages = match self.backend.get_pending_inbound().await {
            Ok(messages) => messages,
            Err(e) => {
                log::error!(target: "ism", "Failed to get pending inbound messages: {:?}", e);
                return;
            }
        };

        let pending_messages: Vec<M> = pending_messages
            .into_iter()
            .sorted_by_key(|r| r.message_id())
            .unique_by(|r| r.message_id())
            .collect();

        log::trace!(target: "ism", "~~~Processing inbound messages: {:?}", pending_messages);
        if let Some(delivery) = self.local_delivery.lock().await.as_ref() {
            for message in pending_messages {
                if self
                    .tracker
                    .has_delivered
                    .contains(&(message.source_id(), message.message_id()))
                {
                    log::warn!(target: "ism", "Skipping already delivered message: {:?}", message);
                    if let Err(e) = self
                        .backend
                        .clear_message_inbound(message.source_id(), message.message_id())
                        .await
                    {
                        log::error!(target: "ism", "Failed to clear delivered message: {e:?}");
                    }
                    continue;
                }

                match delivery.deliver(message.clone()).await {
                    Ok(()) => {
                        log::trace!(target: "ism", "Successfully delivered message: {message:?}");
                        self.tracker
                            .has_delivered
                            .insert((message.source_id(), message.message_id()));
                        if let Err(e) = self
                            .send_message_internal(self.create_ack_message(&message))
                            .await
                        {
                            log::error!(target: "ism", "Failed to send ACK: {e:?}");
                        }

                        if let Err(e) = self
                            .backend
                            .clear_message_inbound(message.source_id(), message.message_id())
                            .await
                        {
                            log::error!(target: "ism", "Failed to clear delivered message: {e:?}");
                        }
                    }
                    Err(e) => {
                        log::error!(target: "ism", "Failed to deliver message {message:?}: {e:?}");
                    }
                }
            }
        } else {
            log::warn!(target: "ism", "Unable to deliver messages since local delivery has been dropped");
        }
    }

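    // Handles the next payload from the transport: a Poll wakes the outbound
    // loop, an ACK clears the acknowledged message from the outbound store, and
    // a Message is recorded with the tracker, persisted inbound, and re-ACKed if
    // it turns out to be a duplicate.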
    async fn process_next_network_message(&self) {
        if let Some(message) = self.network.next_message().await {
            match message {
                Payload::Poll { .. } => {
                    if self.poll_outbound_tx.send(()).is_err() {
                        log::warn!(target: "ism", "Failed to send poll signal for outbound messages");
                    }
                }

                Payload::Ack {
                    from_id,
                    message_id,
                    to_id,
                } => {
                    if to_id != self.network.local_id() {
                        log::warn!(target: "ism", "Received ACK for another peer");
                        return;
                    }

                    if let Err(err) = self.tracker.update_ack(from_id, message_id).await {
                        log::error!(target: "ism", "Failed to update tracker with ACK: {err:?}");
                    }

                    log::trace!(target: "ism", "Received ACK from peer {from_id}, message # {message_id}");
                    if let Err(e) = self
                        .backend
                        .clear_message_outbound(from_id, message_id)
                        .await
                    {
                        log::error!(target: "ism", "Failed to clear ACKed message: {e:?}");
                    }

                    if self.poll_outbound_tx.send(()).is_err() {
                        log::warn!(target: "ism", "Failed to send poll signal for outbound messages");
                    }
                }
                Payload::Message(msg) => {
                    if msg.destination_id() != self.network.local_id() {
                        log::warn!(target: "ism", "Received message for another peer");
                        return;
                    }

                    // A message already sitting in the pending inbound queue is a
                    // retransmit; ACK it immediately instead of storing it again.
                    if let Ok(msgs) = self.backend.get_pending_inbound().await {
                        if msgs.iter().any(|m| {
                            m.message_id() == msg.message_id() && m.source_id() == msg.source_id()
                        }) {
                            log::warn!(target: "ism", "Received duplicate message, sending ACK");
                            if let Err(e) = self
                                .send_message_internal(self.create_ack_message(&msg))
                                .await
                            {
                                log::error!(target: "ism", "Failed to send ACK for duplicate message: {e:?}");
                            }
                            return;
                        }
                    }

                    match self
                        .tracker
                        .mark_received(msg.source_id(), msg.message_id())
                        .await
                    {
                        Ok(true) => {
                            if let Err(e) = self.backend.store_inbound(msg).await {
                                log::error!(target: "ism", "Failed to store inbound message: {e:?}");
                            }

                            if self.poll_inbound_tx.send(()).is_err() {
                                log::warn!(target: "ism", "Failed to send poll signal for inbound messages");
                            }
                        }
                        Ok(false) => {
                            if let Err(e) = self
                                .send_message_internal(self.create_ack_message(&msg))
                                .await
                            {
                                log::error!(target: "ism", "Failed to send ACK for duplicate message: {e:?}");
                            }
                        }
                        Err(e) => {
                            log::error!(target: "ism", "Failed to mark message as received: {e:?}");
                        }
                    }
                }
            }
        }
    }

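    /// Queues `contents` for delivery to peer `to`: the tracker assigns the next
    /// message id for that peer, the message is persisted outbound, and the
    /// background loop sends and retries it until the peer ACKs it.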
    pub async fn send_to(
        &self,
        to: M::PeerId,
        contents: impl Into<M::Contents>,
    ) -> Result<(), NetworkError<M>> {
        let my_id = self.network.local_id();
        let next_id_for_this_peer_conn = self
            .tracker
            .get_next_id(to)
            .await
            .map_err(NetworkError::BackendError)?;
        let message = M::construct_from_parts(my_id, to, next_id_for_this_peer_conn, contents);
        self.send_raw_message(message).await
    }

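    /// Queues a fully constructed message. Rejects messages whose source id is
    /// not the local peer or whose destination is the local peer, and returns
    /// `SystemShutdown` once the system has stopped running.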
    pub async fn send_raw_message(&self, message: M) -> Result<(), NetworkError<M>> {
        if message.source_id() != self.network.local_id() {
            return Err(NetworkError::SendFailed {
                reason: "Source ID does not match network peer ID".into(),
                message,
            });
        }

        if message.destination_id() == self.network.local_id() {
            return Err(NetworkError::SendFailed {
                reason: "Cannot send message to self".into(),
                message,
            });
        }

        if self.can_run() {
            self.backend
                .store_outbound(message)
                .await
                .map_err(|err| match err {
                    BackendError::SendFailed { reason, message } => {
                        NetworkError::SendFailed { reason, message }
                    }
                    err => NetworkError::BackendError(err),
                })?;

            self.poll_outbound_tx
                .send(())
                .map_err(|_| NetworkError::SystemShutdown)?;
            Ok(())
        } else {
            Err(NetworkError::SystemShutdown)
        }
    }

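    /// Builds an ACK payload addressed back to the sender of `original_message`.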
    fn create_ack_message(&self, original_message: &M) -> Payload<M> {
        Payload::Ack {
            from_id: original_message.destination_id(),
            to_id: original_message.source_id(),
            message_id: original_message.message_id(),
        }
    }

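    /// Gracefully shuts down: waits up to `timeout` for both the pending
    /// outbound and pending inbound queues to drain, then clears the running
    /// flag so the background tasks exit. Repeat calls return `Ok(())`
    /// immediately.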
    pub async fn shutdown(&self, timeout: Duration) -> Result<(), NetworkError<M>> {
        if self.is_shutting_down.fetch_or(true, Ordering::SeqCst) {
            return Ok(());
        }

        tokio::time::timeout(timeout, async {
            let pending_outbound_task = async move {
                while !self
                    .backend
                    .get_pending_outbound()
                    .await
                    .map_err(NetworkError::BackendError)?
                    .is_empty()
                {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }

                Ok(())
            };

            let pending_inbound_task = async move {
                while !self
                    .backend
                    .get_pending_inbound()
                    .await
                    .map_err(NetworkError::BackendError)?
                    .is_empty()
                {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }

                Ok(())
            };

            tokio::try_join!(pending_outbound_task, pending_inbound_task)?;

            Ok::<_, NetworkError<M>>(())
        })
        .await
        .map_err(|err| NetworkError::ShutdownFailed(err.to_string()))??;

        self.toggle_off();

        Ok(())
    }

    pub async fn get_connected_peers(&self) -> Vec<M::PeerId> {
        self.network
            .connected_peers()
            .await
            .into_iter()
            .sorted()
            .collect::<Vec<_>>()
    }

    pub fn local_id(&self) -> M::PeerId {
        self.network.local_id()
    }

    fn can_run(&self) -> bool {
        self.is_running.load(Ordering::Relaxed)
    }

    fn toggle_off(&self) {
        self.is_running.store(false, Ordering::SeqCst);
    }

    async fn send_message_internal(
        &self,
        message: Payload<M>,
    ) -> Result<(), NetworkError<Payload<M>>> {
        let res = self.network.send_message(message).await;

        // Log transport failures; the error is still returned to the caller.
        if let Err(err) = &res {
            log::error!(target: "ism", "Failed to send payload over the transport: {err:?}");
        }

        res
    }
}