// asteroid_mq_sdk/endpoint.rs

1use std::{
2    collections::HashSet,
3    ops::{Deref, DerefMut},
4    sync::Weak,
5};
6
7use asteroid_mq_model::{EndpointAddr, Interest, Message, TopicCode};
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9use tracing::Instrument;
10
11use crate::{
12    node::{ClientNodeError, ClientNodeInner},
13    ClientNode,
14};
15
// CliEp -> NodeProxy -> Ep
/// Client-side handle to an endpoint registered on a topic through a [`ClientNode`].
///
/// Messages delivered to this endpoint arrive on `message_rx`; read them with
/// [`ClientEndpoint::next_message`]. When the handle is dropped while its node is still
/// alive, the `Drop` impl spawns a best-effort task that takes the endpoint offline.
#[derive(Debug)]
pub struct ClientEndpoint {
    /// Address of this endpoint.
    pub(crate) addr: EndpointAddr,
    /// Topic this endpoint belongs to.
    pub(crate) topic_code: TopicCode,
    /// Local copy of the endpoint's interests; `update_interests` replaces it only after
    /// the node accepted the new set.
    pub(crate) interests: HashSet<Interest>,
    /// Non-owning reference to the node. `respawn` sets it to `Weak::new()` to detach an
    /// endpoint so that its `Drop` does not take it offline a second time.
    pub(crate) node: Weak<ClientNodeInner>,
    /// Receiving half of this endpoint's message channel.
    pub(crate) message_rx: UnboundedReceiver<Message>,
}
/// A [`Message`] received by a [`ClientEndpoint`], bundled with what is needed to
/// acknowledge it.
///
/// Dereferences to the inner [`Message`]; the `ack_*` methods send acknowledgements
/// through the node that delivered it.
#[derive(Debug, Clone)]
pub struct ClientReceivedMessage {
    /// Address of the endpoint that received the message.
    ep_addr: EndpointAddr,
    /// Topic of the receiving endpoint.
    topic_code: TopicCode,
    /// Node used to send acknowledgements; it may already have been dropped.
    node: Weak<ClientNodeInner>,
    /// The received message itself.
    message: Message,
}
32
33impl Deref for ClientReceivedMessage {
34    type Target = Message;
35    fn deref(&self) -> &Self::Target {
36        &self.message
37    }
38}
39
40impl DerefMut for ClientReceivedMessage {
41    fn deref_mut(&mut self) -> &mut Self::Target {
42        &mut self.message
43    }
44}
45
46impl ClientReceivedMessage {
47    pub fn into_inner(self) -> Message {
48        self.message
49    }
50    pub async fn ack_failed(&self) -> Result<(), ClientNodeError> {
51        let Some(node) = self.node.upgrade() else {
52            return Err(ClientNodeError::disconnected("ack_failed"));
53        };
54        let ack = self
55            .message
56            .header
57            .ack_failed(self.topic_code.clone(), self.ep_addr);
58        node.send_single_ack(ack).await
59    }
60    pub async fn ack_processed(&self) -> Result<(), ClientNodeError> {
61        let Some(node) = self.node.upgrade() else {
62            return Err(ClientNodeError::disconnected("ack_processed"));
63        };
64        let ack = self
65            .message
66            .header
67            .ack_processed(self.topic_code.clone(), self.ep_addr);
68        node.send_single_ack(ack).await
69    }
70    pub async fn ack_received(&self) -> Result<(), ClientNodeError> {
71        let Some(node) = self.node.upgrade() else {
72            return Err(ClientNodeError::disconnected("ack_received"));
73        };
74        let ack = self
75            .message
76            .header
77            .ack_received(self.topic_code.clone(), self.ep_addr);
78        node.send_single_ack(ack).await
79    }
80}
81
82impl ClientEndpoint {
83    pub fn node(&self) -> Option<ClientNode> {
84        self.node.upgrade().map(|inner| ClientNode { inner })
85    }
86    pub fn interests(&self) -> &HashSet<Interest> {
87        &self.interests
88    }
89    pub async fn modify_interests(
90        &mut self,
91        modify: impl FnOnce(&mut HashSet<Interest>),
92    ) -> Result<(), ClientNodeError> {
93        let mut new_interest = self.interests.clone();
94        modify(&mut new_interest);
95        self.update_interests(new_interest).await
96    }
97    pub async fn update_interests(
98        &mut self,
99        interests: impl IntoIterator<Item = Interest>,
100    ) -> Result<(), ClientNodeError> {
101        let Some(node) = self.node.upgrade() else {
102            return Err(ClientNodeError::disconnected("update_interests"));
103        };
104        let interests_vec: Vec<_> = interests.into_iter().collect();
105        let interests_set = interests_vec.iter().cloned().collect::<HashSet<_>>();
106        node.send_ep_interests(self.topic_code.clone(), self.addr, interests_vec)
107            .await?;
108        self.interests = interests_set;
109        Ok(())
110    }
111    pub async fn next_message(&mut self) -> Option<ClientReceivedMessage> {
112        self.message_rx
113            .recv()
114            .await
115            .map(|message| ClientReceivedMessage {
116                ep_addr: self.addr,
117                topic_code: self.topic_code.clone(),
118                node: self.node.clone(),
119                message,
120            })
121    }
122
123    pub async fn respawn(&mut self) -> Result<(), ClientNodeError> {
124        let Some(node_inner) = self.node.upgrade() else {
125            return Err(ClientNodeError::disconnected("respawn"));
126        };
127
128        // offline old point
129        let ep_offline_result = node_inner
130            .clone()
131            .send_ep_offline(self.topic_code.clone(), self.addr)
132            .await;
133
134        tracing::info!(?ep_offline_result);
135        // remove old tx
136        node_inner.endpoint_map.write().await.remove(&self.addr);
137        // detach old node, so we won't offline twice in drop
138        self.node = Weak::new();
139        let mut new_ep = node_inner
140            .into_client_node()
141            .create_endpoint(self.topic_code.clone(), self.interests.clone())
142            .await?;
143        std::mem::swap(&mut new_ep, self);
144        tracing::debug!(ep = ?self.addr, "respawn");
145        Ok(())
146    }
147
148    pub async fn next_message_and_auto_respawn(
149        &mut self,
150    ) -> Result<ClientReceivedMessage, ClientNodeError> {
151        loop {
152            let message = self.next_message().await;
153            let message = match message {
154                Some(message) => message,
155                None => {
156                    let addr = self.addr;
157                    tracing::info!(?addr, "respawn endpoint");
158                    self.respawn().await.inspect_err(|e| {
159                        tracing::info!(?addr, error = ?e, "fail to respawn endpoint");
160                    })?;
161                    continue;
162                }
163            };
164            return Ok(message);
165        }
166    }
167}
168
169impl Drop for ClientEndpoint {
170    fn drop(&mut self) {
171        let Some(node) = self.node.upgrade() else {
172            return;
173        };
174        let topic_code = self.topic_code.clone();
175        let endpoint = self.addr;
176        let task = async move {
177            let _ = node.send_ep_offline(topic_code, endpoint).await;
178            node.endpoint_map.write().await.remove(&endpoint);
179        }
180        .instrument(
181            tracing::info_span!("ep_offline", topic_code = %self.topic_code, endpoint = ?endpoint),
182        );
183        tokio::spawn(task);
184    }
185}
186
/// Sender/receiver pair for one endpoint's message channel.
///
/// The receiver can be moved out with [`EndpointMailbox::take_rx`]; the sender stays in
/// the mailbox.
#[derive(Debug)]
pub(crate) struct EndpointMailbox {
    /// Sending half of the channel.
    pub(crate) message_tx: UnboundedSender<Message>,
    /// Receiving half; `None` once `take_rx` has moved it out.
    pub(crate) message_rx: Option<UnboundedReceiver<Message>>,
}
192
193impl EndpointMailbox {
194    pub fn new() -> Self {
195        let (message_tx, message_rx) = tokio::sync::mpsc::unbounded_channel();
196        Self {
197            message_tx,
198            message_rx: Some(message_rx),
199        }
200    }
201
202    pub fn take_rx(&mut self) -> Option<UnboundedReceiver<Message>> {
203        self.message_rx.take()
204    }
205}