asteroid_mq_sdk/
endpoint.rs1use std::{
2 collections::HashSet,
3 ops::{Deref, DerefMut},
4 sync::Weak,
5};
6
7use asteroid_mq_model::{EndpointAddr, Interest, Message, TopicCode};
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9use tracing::Instrument;
10
11use crate::{
12 node::{ClientNodeError, ClientNodeInner},
13 ClientNode,
14};
15
16#[derive(Debug)]
18pub struct ClientEndpoint {
19 pub(crate) addr: EndpointAddr,
20 pub(crate) topic_code: TopicCode,
21 pub(crate) interests: HashSet<Interest>,
22 pub(crate) node: Weak<ClientNodeInner>,
23 pub(crate) message_rx: UnboundedReceiver<Message>,
24}
25#[derive(Debug, Clone)]
26pub struct ClientReceivedMessage {
27 ep_addr: EndpointAddr,
28 topic_code: TopicCode,
29 node: Weak<ClientNodeInner>,
30 message: Message,
31}
32
33impl Deref for ClientReceivedMessage {
34 type Target = Message;
35 fn deref(&self) -> &Self::Target {
36 &self.message
37 }
38}
39
40impl DerefMut for ClientReceivedMessage {
41 fn deref_mut(&mut self) -> &mut Self::Target {
42 &mut self.message
43 }
44}
45
46impl ClientReceivedMessage {
47 pub fn into_inner(self) -> Message {
48 self.message
49 }
50 pub async fn ack_failed(&self) -> Result<(), ClientNodeError> {
51 let Some(node) = self.node.upgrade() else {
52 return Err(ClientNodeError::disconnected("ack_failed"));
53 };
54 let ack = self
55 .message
56 .header
57 .ack_failed(self.topic_code.clone(), self.ep_addr);
58 node.send_single_ack(ack).await
59 }
60 pub async fn ack_processed(&self) -> Result<(), ClientNodeError> {
61 let Some(node) = self.node.upgrade() else {
62 return Err(ClientNodeError::disconnected("ack_processed"));
63 };
64 let ack = self
65 .message
66 .header
67 .ack_processed(self.topic_code.clone(), self.ep_addr);
68 node.send_single_ack(ack).await
69 }
70 pub async fn ack_received(&self) -> Result<(), ClientNodeError> {
71 let Some(node) = self.node.upgrade() else {
72 return Err(ClientNodeError::disconnected("ack_received"));
73 };
74 let ack = self
75 .message
76 .header
77 .ack_received(self.topic_code.clone(), self.ep_addr);
78 node.send_single_ack(ack).await
79 }
80}
81
82impl ClientEndpoint {
83 pub fn node(&self) -> Option<ClientNode> {
84 self.node.upgrade().map(|inner| ClientNode { inner })
85 }
86 pub fn interests(&self) -> &HashSet<Interest> {
87 &self.interests
88 }
89 pub async fn modify_interests(
90 &mut self,
91 modify: impl FnOnce(&mut HashSet<Interest>),
92 ) -> Result<(), ClientNodeError> {
93 let mut new_interest = self.interests.clone();
94 modify(&mut new_interest);
95 self.update_interests(new_interest).await
96 }
97 pub async fn update_interests(
98 &mut self,
99 interests: impl IntoIterator<Item = Interest>,
100 ) -> Result<(), ClientNodeError> {
101 let Some(node) = self.node.upgrade() else {
102 return Err(ClientNodeError::disconnected("update_interests"));
103 };
104 let interests_vec: Vec<_> = interests.into_iter().collect();
105 let interests_set = interests_vec.iter().cloned().collect::<HashSet<_>>();
106 node.send_ep_interests(self.topic_code.clone(), self.addr, interests_vec)
107 .await?;
108 self.interests = interests_set;
109 Ok(())
110 }
111 pub async fn next_message(&mut self) -> Option<ClientReceivedMessage> {
112 self.message_rx
113 .recv()
114 .await
115 .map(|message| ClientReceivedMessage {
116 ep_addr: self.addr,
117 topic_code: self.topic_code.clone(),
118 node: self.node.clone(),
119 message,
120 })
121 }
122
123 pub async fn respawn(&mut self) -> Result<(), ClientNodeError> {
124 let Some(node_inner) = self.node.upgrade() else {
125 return Err(ClientNodeError::disconnected("respawn"));
126 };
127
128 let ep_offline_result = node_inner
130 .clone()
131 .send_ep_offline(self.topic_code.clone(), self.addr)
132 .await;
133
134 tracing::info!(?ep_offline_result);
135 node_inner.endpoint_map.write().await.remove(&self.addr);
137 self.node = Weak::new();
139 let mut new_ep = node_inner
140 .into_client_node()
141 .create_endpoint(self.topic_code.clone(), self.interests.clone())
142 .await?;
143 std::mem::swap(&mut new_ep, self);
144 tracing::debug!(ep = ?self.addr, "respawn");
145 Ok(())
146 }
147
148 pub async fn next_message_and_auto_respawn(
149 &mut self,
150 ) -> Result<ClientReceivedMessage, ClientNodeError> {
151 loop {
152 let message = self.next_message().await;
153 let message = match message {
154 Some(message) => message,
155 None => {
156 let addr = self.addr;
157 tracing::info!(?addr, "respawn endpoint");
158 self.respawn().await.inspect_err(|e| {
159 tracing::info!(?addr, error = ?e, "fail to respawn endpoint");
160 })?;
161 continue;
162 }
163 };
164 return Ok(message);
165 }
166 }
167}
168
169impl Drop for ClientEndpoint {
170 fn drop(&mut self) {
171 let Some(node) = self.node.upgrade() else {
172 return;
173 };
174 let topic_code = self.topic_code.clone();
175 let endpoint = self.addr;
176 let task = async move {
177 let _ = node.send_ep_offline(topic_code, endpoint).await;
178 node.endpoint_map.write().await.remove(&endpoint);
179 }
180 .instrument(
181 tracing::info_span!("ep_offline", topic_code = %self.topic_code, endpoint = ?endpoint),
182 );
183 tokio::spawn(task);
184 }
185}
186
187#[derive(Debug)]
188pub(crate) struct EndpointMailbox {
189 pub(crate) message_tx: UnboundedSender<Message>,
190 pub(crate) message_rx: Option<UnboundedReceiver<Message>>,
191}
192
193impl EndpointMailbox {
194 pub fn new() -> Self {
195 let (message_tx, message_rx) = tokio::sync::mpsc::unbounded_channel();
196 Self {
197 message_tx,
198 message_rx: Some(message_rx),
199 }
200 }
201
202 pub fn take_rx(&mut self) -> Option<UnboundedReceiver<Message>> {
203 self.message_rx.take()
204 }
205}