Skip to main content

ankurah_connector_local_process/
lib.rs

1use ankurah_core::policy::PolicyAgent;
2use ankurah_core::storage::StorageEngine;
3use ankurah_proto as proto;
4use async_trait::async_trait;
5use tokio::sync::mpsc;
6
7use ankurah_core::connector::{PeerSender, SendError};
8use ankurah_core::node::{Node, WeakNode};
9
10#[derive(Clone)]
11/// Sender for local process connection
12pub struct LocalProcessSender {
13    sender: mpsc::Sender<proto::NodeMessage>,
14    node_id: proto::EntityId,
15}
16
17#[async_trait]
18impl PeerSender for LocalProcessSender {
19    fn send_message(&self, message: proto::NodeMessage) -> Result<(), SendError> {
20        self.sender.try_send(message).map_err(|_| SendError::ConnectionClosed)?;
21        Ok(())
22    }
23
24    fn recipient_node_id(&self) -> proto::EntityId { self.node_id }
25
26    fn cloned(&self) -> Box<dyn PeerSender> { Box::new(self.clone()) }
27}
28
29/// connector which establishes one sender between each of the two given nodes
30pub struct LocalProcessConnection<SE1, PA1, SE2, PA2>
31where
32    SE1: StorageEngine + Send + Sync + 'static,
33    PA1: PolicyAgent + Send + Sync + 'static,
34    SE2: StorageEngine + Send + Sync + 'static,
35    PA2: PolicyAgent + Send + Sync + 'static,
36{
37    receiver1_task: tokio::task::JoinHandle<()>,
38    receiver2_task: tokio::task::JoinHandle<()>,
39    node1: WeakNode<SE1, PA1>,
40    node2: WeakNode<SE2, PA2>,
41    node1_id: proto::EntityId,
42    node2_id: proto::EntityId,
43}
44
45impl<SE1, PA1, SE2, PA2> LocalProcessConnection<SE1, PA1, SE2, PA2>
46where
47    SE1: StorageEngine + Send + Sync + 'static,
48    PA1: PolicyAgent + Send + Sync + 'static,
49    SE2: StorageEngine + Send + Sync + 'static,
50    PA2: PolicyAgent + Send + Sync + 'static,
51{
52    /// Create a new LocalConnector and establish connection between the nodes
53    pub async fn new(node1: &Node<SE1, PA1>, node2: &Node<SE2, PA2>) -> anyhow::Result<Self> {
54        let (node1_tx, node1_rx) = mpsc::channel(1024);
55        let (node2_tx, node2_rx) = mpsc::channel(1024);
56
57        // we have to register the senders with the nodes
58        node1.register_peer(
59            proto::Presence { node_id: node2.id, durable: node2.durable, system_root: node2.system.root() },
60            Box::new(LocalProcessSender { sender: node2_tx, node_id: node2.id }),
61        );
62        node2.register_peer(
63            proto::Presence { node_id: node1.id, durable: node1.durable, system_root: node1.system.root() },
64            Box::new(LocalProcessSender { sender: node1_tx, node_id: node1.id }),
65        );
66
67        let receiver1_task = Self::setup_receiver(node1.clone(), node1_rx);
68        let receiver2_task = Self::setup_receiver(node2.clone(), node2_rx);
69
70        Ok(Self { node1: node1.weak(), node2: node2.weak(), node1_id: node1.id, node2_id: node2.id, receiver1_task, receiver2_task })
71    }
72
73    fn setup_receiver<SE, PA>(node: Node<SE, PA>, mut rx: mpsc::Receiver<proto::NodeMessage>) -> tokio::task::JoinHandle<()>
74    where
75        SE: StorageEngine + Send + Sync + 'static,
76        PA: PolicyAgent + Send + Sync + 'static,
77    {
78        tokio::spawn(async move {
79            while let Some(message) = rx.recv().await {
80                let node = node.clone();
81                tokio::spawn(async move {
82                    let _ = node.handle_message(message).await;
83                });
84            }
85        })
86    }
87}
88
89impl<SE1, PA1, SE2, PA2> Drop for LocalProcessConnection<SE1, PA1, SE2, PA2>
90where
91    SE1: StorageEngine + Send + Sync + 'static,
92    PA1: PolicyAgent + Send + Sync + 'static,
93    SE2: StorageEngine + Send + Sync + 'static,
94    PA2: PolicyAgent + Send + Sync + 'static,
95{
96    fn drop(&mut self) {
97        self.receiver1_task.abort();
98        self.receiver2_task.abort();
99        if let Some(node1) = self.node1.upgrade() {
100            node1.deregister_peer(self.node2_id);
101        }
102        if let Some(node2) = self.node2.upgrade() {
103            node2.deregister_peer(self.node1_id);
104        }
105    }
106}