1use std::{
2 collections::{HashMap, HashSet},
3 sync::{atomic::AtomicU32, Arc},
4};
5
6use crate::endpoint::{ClientEndpoint, EndpointMailbox};
7pub use crate::error::*;
8use asteroid_mq_model::{
9 connection::{EdgeConnectionErrorKind, EdgeNodeConnection},
10 EdgeEndpointOffline, EdgeEndpointOnline, EdgeError, EdgeMessage, EdgePayload, EdgePush,
11 EdgeRequest, EdgeRequestEnum, EdgeResponseEnum, EndpointAddr, EndpointInterest, Interest,
12 Message, MessageAck, MessageStateUpdate, SetState, TopicCode, WaitAckSuccess,
13};
14use futures_util::{SinkExt, StreamExt};
15use tokio::sync::{oneshot, RwLock};
16use tokio_util::sync::CancellationToken;
17use tracing::Instrument;
18type Responder = oneshot::Sender<Result<EdgeResponseEnum, EdgeError>>;
19type ResponseHandle = oneshot::Receiver<Result<EdgeResponseEnum, EdgeError>>;
20type ResponsePool = Arc<RwLock<HashMap<u32, Responder>>>;
21type EndpointMailboxMap = Arc<RwLock<HashMap<EndpointAddr, EndpointMailbox>>>;
22
23pub struct MessageAckHandle {
24 response_handle: ResponseHandle,
25}
26
27impl MessageAckHandle {
28 pub async fn wait(self) -> Result<WaitAckSuccess, ClientNodeError> {
29 let response = ClientNodeInner::wait_handle(self.response_handle).await?;
30 if let EdgeResponseEnum::SendMessage(edge_result) = response {
31 let wait_ack = edge_result.into_std()?;
32 Ok(wait_ack)
33 } else {
34 Err(ClientNodeError::unexpected_response(response))
35 }
36 }
37}
38
39#[derive(Debug, Clone)]
40pub struct ClientNode {
41 pub(crate) inner: Arc<ClientNodeInner>,
42}
43
44impl ClientNode {
45 pub async fn ack(&self, ack: MessageAck) -> Result<(), ClientNodeError> {
46 self.inner.send_single_ack(ack).await
47 }
48 pub async fn send_message(
49 &self,
50 message: EdgeMessage,
51 ) -> Result<MessageAckHandle, ClientNodeError> {
52 self.inner.send_message(message).await
53 }
54 pub async fn send_message_and_wait(
55 &self,
56 message: EdgeMessage,
57 ) -> Result<WaitAckSuccess, ClientNodeError> {
58 let handle = self.inner.send_message(message).await?;
59 handle.wait().await
60 }
61 pub async fn create_endpoint(
62 &self,
63 topic_code: TopicCode,
64 interests: impl IntoIterator<Item = Interest>,
65 ) -> Result<ClientEndpoint, ClientNodeError> {
66 let interests = interests.into_iter().collect::<HashSet<_>>();
67 let addr = self
68 .inner
69 .send_ep_online(EdgeEndpointOnline {
70 topic_code: topic_code.clone(),
71 interests: interests.iter().cloned().collect(),
72 })
73 .await?;
74 let message_rx = self
75 .inner
76 .ensure_mailbox_and_take_rx(addr)
77 .await
78 .expect("conflict endpoint addr should not happen");
79 Ok(ClientEndpoint {
80 addr,
81 topic_code,
82 interests,
83 node: Arc::downgrade(&self.inner),
84 message_rx,
85 })
86 }
87 pub async fn connect<C>(connect: C) -> Result<Self, ClientNodeError>
88 where
89 C: EdgeNodeConnection,
90 {
91 let inner = ClientNodeInner::connect(connect).await?;
92 Ok(ClientNode {
93 inner: Arc::new(inner),
94 })
95 }
96}
97
98#[derive(Debug)]
99pub(crate) struct ClientNodeInner {
100 pub(crate) sender: tokio::sync::mpsc::UnboundedSender<(EdgeRequest, Responder)>,
101 pub(crate) endpoint_map: EndpointMailboxMap,
102 pub(crate) seq: AtomicU32,
103 pub(crate) cancellation_token: CancellationToken,
104 pub(crate) _rx_handle: tokio::task::JoinHandle<()>,
105 pub(crate) _tx_handle: tokio::task::JoinHandle<()>,
106}
107
108impl ClientNodeInner {
109 pub fn into_client_node(self: Arc<ClientNodeInner>) -> ClientNode {
110 ClientNode {
111 inner: self.clone(),
112 }
113 }
114 async fn ensure_mailbox_and_take_rx(
115 &self,
116 addr: EndpointAddr,
117 ) -> Option<tokio::sync::mpsc::UnboundedReceiver<Message>> {
118 self.endpoint_map
119 .write()
120 .await
121 .entry(addr)
122 .or_insert_with(EndpointMailbox::new)
123 .take_rx()
124 }
125 async fn send_message(
126 &self,
127 message: EdgeMessage,
128 ) -> Result<MessageAckHandle, ClientNodeError> {
129 let response_handle = self
130 .send_request(EdgeRequestEnum::SendMessage(message))
131 .await?;
132 Ok(MessageAckHandle { response_handle })
133 }
134 pub(crate) async fn send_ep_online(
135 &self,
136 request: EdgeEndpointOnline,
137 ) -> Result<EndpointAddr, ClientNodeError> {
138 let response = self
139 .send_request_and_wait(EdgeRequestEnum::EndpointOnline(request))
140 .await?;
141 if let EdgeResponseEnum::EndpointOnline(ep_addr) = response {
142 Ok(ep_addr)
143 } else {
144 Err(ClientNodeError::unexpected_response(response))
145 }
146 }
147 pub(crate) async fn send_ep_offline(
148 &self,
149 topic_code: TopicCode,
150 endpoint: EndpointAddr,
151 ) -> Result<(), ClientNodeError> {
152 let response = self
153 .send_request_and_wait(EdgeRequestEnum::EndpointOffline(EdgeEndpointOffline {
154 endpoint,
155 topic_code,
156 }))
157 .await?;
158 if let EdgeResponseEnum::EndpointOffline = response {
159 Ok(())
160 } else {
161 Err(ClientNodeError::unexpected_response(response))
162 }
163 }
164 pub(crate) async fn send_ep_interests(
165 &self,
166 topic_code: TopicCode,
167 endpoint: EndpointAddr,
168 interests: Vec<Interest>,
169 ) -> Result<(), ClientNodeError> {
170 let response = self
171 .send_request_and_wait(EdgeRequestEnum::EndpointInterest(EndpointInterest {
172 topic_code,
173 endpoint,
174 interests,
175 }))
176 .await?;
177 if let EdgeResponseEnum::EndpointInterest = response {
178 Ok(())
179 } else {
180 Err(ClientNodeError::unexpected_response(response))
181 }
182 }
183 pub(crate) async fn send_single_ack(
184 &self,
185 MessageAck {
186 ack_to,
187 topic_code,
188 from,
189 kind,
190 }: MessageAck,
191 ) -> Result<(), ClientNodeError> {
192 let response = self
193 .send_request_and_wait(EdgeRequestEnum::SetState(SetState {
194 topic: topic_code,
195 update: MessageStateUpdate {
196 message_id: ack_to,
197 status: HashMap::from_iter([(from, kind)]),
198 },
199 }))
200 .await?;
201 if let EdgeResponseEnum::SetState = response {
202 Ok(())
203 } else {
204 Err(ClientNodeError::unexpected_response(response))
205 }
206 }
207 async fn send_request(
208 &self,
209 request: EdgeRequestEnum,
210 ) -> Result<ResponseHandle, ClientNodeError> {
211 let (responder, response_handle) = oneshot::channel();
212 let request = EdgeRequest {
213 seq_id: self.seq.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
214 request,
215 };
216 self.sender
217 .send((request, responder))
218 .map_err(|e| ClientNodeError {
219 kind: ClientErrorKind::NoConnection(e.0 .0.request),
220 })?;
221 Ok(response_handle)
222 }
227 async fn wait_handle(
228 response_handle: ResponseHandle,
229 ) -> Result<EdgeResponseEnum, ClientNodeError> {
230 let response = response_handle
231 .await
232 .map_err(|_| ClientNodeError::disconnected("wait handle"))??;
233 Ok(response)
234 }
235 async fn send_request_and_wait(
236 &self,
237 request: EdgeRequestEnum,
238 ) -> Result<EdgeResponseEnum, ClientNodeError> {
239 let response_handle = self.send_request(request).await?;
240 Self::wait_handle(response_handle).await
241 }
242 pub(crate) async fn connect<C>(connection: C) -> Result<ClientNodeInner, ClientNodeError>
243 where
244 C: EdgeNodeConnection,
245 {
246 let response_pool: ResponsePool = Default::default();
247
248 let endpoint_map = EndpointMailboxMap::default();
249 let ct = CancellationToken::new();
250 let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel::<(
251 EdgeRequest,
252 oneshot::Sender<Result<EdgeResponseEnum, EdgeError>>,
253 )>();
254
255 let (mut sink, mut stream) = connection.split();
256 let tx_ct = ct.child_token();
257 let tx_task = {
258 let response_pool = response_pool.clone();
259 async move {
260 loop {
261 let (request, responder) = tokio::select! {
262 _ = tx_ct.cancelled() => {
263 tracing::debug!("task cancelled");
264 break
265 }
266 message = request_rx.recv() => {
267 match message {
268 Some(message) => {
269 message
270 }
271 None => {
272 tracing::warn!("tx dropped");
273 break;
274 }
275 }
276 }
277 };
278 let seq_id = request.seq_id;
279 response_pool.write().await.insert(seq_id, responder);
280 let send_result = sink.send(EdgePayload::Request(request)).await;
282 if let Err(e) = send_result {
283 tracing::error!("failed to send message: {:?}", e);
284 break;
285 }
286 }
287 }
288 .instrument(tracing::info_span!("client_node_tx"))
289 };
290
291 let rx_ct = ct.child_token();
292 let rx_task = {
293 let endpoints_map = endpoint_map.clone();
294 async move {
295 loop {
296 let edge_pld = tokio::select! {
297 _ = rx_ct.cancelled() => {
298 endpoints_map.write().await.clear();
299 tracing::debug!("task cancelled");
300 break
301 }
302 received = stream.next() => {
303 match received {
304 Some(Ok(edge_pld)) => {
305 edge_pld
306 }
307 Some(Err(e)) => {
308 match e.kind {
309 EdgeConnectionErrorKind::Reconnect | EdgeConnectionErrorKind::Closed => {
310 response_pool.write().await.clear();
312 endpoints_map.write().await.clear();
314 if matches!(e.kind,EdgeConnectionErrorKind::Closed) {
315 tracing::info!("connection closed");
316 break;
317 }
318 },
319 _ => {
320 tracing::error!("failed to receive message: {:?}", e);
321 }
322 }
323 continue
324 }
325 None => {
326 tracing::debug!("stream closed");
327 break;
328 }
329 }
330 }
331 };
332
333 match edge_pld {
334 EdgePayload::Push(EdgePush::Message { endpoints, message }) => {
335 tracing::trace!(endpoints = ?endpoints, ?message, "received message");
336 let mut wg = endpoints_map.write().await;
337 for ep in endpoints {
338 if let Some(mailbox) = wg.get(&ep) {
339 let send_result = mailbox.message_tx.send(message.clone());
340 if send_result.is_err() {
341 tracing::warn!(addr=?ep, "target endpoint is dropped")
342 }
343 } else {
344 let mailbox = EndpointMailbox::new();
345 mailbox
346 .message_tx
347 .send(message.clone())
348 .expect("a brand new channel must have the receiver");
349 wg.insert(ep, mailbox);
350 tracing::warn!(addr=?ep, "target endpoint not found")
351 }
352 }
353 drop(wg);
354 }
355 EdgePayload::Response(edge_response) => {
356 let seq_id = edge_response.seq_id;
357 if let Some(responder) = response_pool.write().await.remove(&seq_id) {
358 let _ = responder.send(edge_response.result.into_std());
360 } else {
361 tracing::error!(seq_id, "response handle not found");
362 }
363 }
364 EdgePayload::Error(e) => {
365 tracing::error!(?e, "received error");
366 }
367 _ => {}
368 }
369 }
370 }
371 .instrument(tracing::info_span!("client_node_rx"))
372 };
373 let tx_handle = tokio::spawn(tx_task);
374 let rx_handle = tokio::spawn(rx_task);
375 Ok(ClientNodeInner {
376 sender: request_tx,
377 seq: Default::default(),
378 endpoint_map,
379 cancellation_token: ct,
380 _rx_handle: rx_handle,
381 _tx_handle: tx_handle,
382 })
383 }
384}
385
386impl Drop for ClientNodeInner {
387 fn drop(&mut self) {
388 self.cancellation_token.cancel();
389 }
390}