// asteroid_mq_sdk/node.rs

use std::{
    collections::{HashMap, HashSet},
    sync::{atomic::AtomicU32, Arc},
};

use crate::endpoint::{ClientEndpoint, EndpointMailbox};
pub use crate::error::*;
use asteroid_mq_model::{
    connection::{EdgeConnectionErrorKind, EdgeNodeConnection},
    EdgeEndpointOffline, EdgeEndpointOnline, EdgeError, EdgeMessage, EdgePayload, EdgePush,
    EdgeRequest, EdgeRequestEnum, EdgeResponseEnum, EndpointAddr, EndpointInterest, Interest,
    Message, MessageAck, MessageStateUpdate, SetState, TopicCode, WaitAckSuccess,
};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{oneshot, RwLock};
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
18type Responder = oneshot::Sender<Result<EdgeResponseEnum, EdgeError>>;
19type ResponseHandle = oneshot::Receiver<Result<EdgeResponseEnum, EdgeError>>;
20type ResponsePool = Arc<RwLock<HashMap<u32, Responder>>>;
21type EndpointMailboxMap = Arc<RwLock<HashMap<EndpointAddr, EndpointMailbox>>>;
22
23pub struct MessageAckHandle {
24    response_handle: ResponseHandle,
25}
26
27impl MessageAckHandle {
28    pub async fn wait(self) -> Result<WaitAckSuccess, ClientNodeError> {
29        let response = ClientNodeInner::wait_handle(self.response_handle).await?;
30        if let EdgeResponseEnum::SendMessage(edge_result) = response {
31            let wait_ack = edge_result.into_std()?;
32            Ok(wait_ack)
33        } else {
34            Err(ClientNodeError::unexpected_response(response))
35        }
36    }
37}
38
39#[derive(Debug, Clone)]
40pub struct ClientNode {
41    pub(crate) inner: Arc<ClientNodeInner>,
42}
43
44impl ClientNode {
45    pub async fn ack(&self, ack: MessageAck) -> Result<(), ClientNodeError> {
46        self.inner.send_single_ack(ack).await
47    }
48    pub async fn send_message(
49        &self,
50        message: EdgeMessage,
51    ) -> Result<MessageAckHandle, ClientNodeError> {
52        self.inner.send_message(message).await
53    }
54    pub async fn send_message_and_wait(
55        &self,
56        message: EdgeMessage,
57    ) -> Result<WaitAckSuccess, ClientNodeError> {
58        let handle = self.inner.send_message(message).await?;
59        handle.wait().await
60    }
61    pub async fn create_endpoint(
62        &self,
63        topic_code: TopicCode,
64        interests: impl IntoIterator<Item = Interest>,
65    ) -> Result<ClientEndpoint, ClientNodeError> {
66        let interests = interests.into_iter().collect::<HashSet<_>>();
67        let addr = self
68            .inner
69            .send_ep_online(EdgeEndpointOnline {
70                topic_code: topic_code.clone(),
71                interests: interests.iter().cloned().collect(),
72            })
73            .await?;
74        let message_rx = self
75            .inner
76            .ensure_mailbox_and_take_rx(addr)
77            .await
78            .expect("conflict endpoint addr should not happen");
79        Ok(ClientEndpoint {
80            addr,
81            topic_code,
82            interests,
83            node: Arc::downgrade(&self.inner),
84            message_rx,
85        })
86    }
87    pub async fn connect<C>(connect: C) -> Result<Self, ClientNodeError>
88    where
89        C: EdgeNodeConnection,
90    {
91        let inner = ClientNodeInner::connect(connect).await?;
92        Ok(ClientNode {
93            inner: Arc::new(inner),
94        })
95    }
96}
97
98#[derive(Debug)]
99pub(crate) struct ClientNodeInner {
100    pub(crate) sender: tokio::sync::mpsc::UnboundedSender<(EdgeRequest, Responder)>,
101    pub(crate) endpoint_map: EndpointMailboxMap,
102    pub(crate) seq: AtomicU32,
103    pub(crate) cancellation_token: CancellationToken,
104    pub(crate) _rx_handle: tokio::task::JoinHandle<()>,
105    pub(crate) _tx_handle: tokio::task::JoinHandle<()>,
106}
107
108impl ClientNodeInner {
109    pub fn into_client_node(self: Arc<ClientNodeInner>) -> ClientNode {
110        ClientNode {
111            inner: self.clone(),
112        }
113    }
114    async fn ensure_mailbox_and_take_rx(
115        &self,
116        addr: EndpointAddr,
117    ) -> Option<tokio::sync::mpsc::UnboundedReceiver<Message>> {
118        self.endpoint_map
119            .write()
120            .await
121            .entry(addr)
122            .or_insert_with(EndpointMailbox::new)
123            .take_rx()
124    }
125    async fn send_message(
126        &self,
127        message: EdgeMessage,
128    ) -> Result<MessageAckHandle, ClientNodeError> {
129        let response_handle = self
130            .send_request(EdgeRequestEnum::SendMessage(message))
131            .await?;
132        Ok(MessageAckHandle { response_handle })
133    }
134    pub(crate) async fn send_ep_online(
135        &self,
136        request: EdgeEndpointOnline,
137    ) -> Result<EndpointAddr, ClientNodeError> {
138        let response = self
139            .send_request_and_wait(EdgeRequestEnum::EndpointOnline(request))
140            .await?;
141        if let EdgeResponseEnum::EndpointOnline(ep_addr) = response {
142            Ok(ep_addr)
143        } else {
144            Err(ClientNodeError::unexpected_response(response))
145        }
146    }
147    pub(crate) async fn send_ep_offline(
148        &self,
149        topic_code: TopicCode,
150        endpoint: EndpointAddr,
151    ) -> Result<(), ClientNodeError> {
152        let response = self
153            .send_request_and_wait(EdgeRequestEnum::EndpointOffline(EdgeEndpointOffline {
154                endpoint,
155                topic_code,
156            }))
157            .await?;
158        if let EdgeResponseEnum::EndpointOffline = response {
159            Ok(())
160        } else {
161            Err(ClientNodeError::unexpected_response(response))
162        }
163    }
164    pub(crate) async fn send_ep_interests(
165        &self,
166        topic_code: TopicCode,
167        endpoint: EndpointAddr,
168        interests: Vec<Interest>,
169    ) -> Result<(), ClientNodeError> {
170        let response = self
171            .send_request_and_wait(EdgeRequestEnum::EndpointInterest(EndpointInterest {
172                topic_code,
173                endpoint,
174                interests,
175            }))
176            .await?;
177        if let EdgeResponseEnum::EndpointInterest = response {
178            Ok(())
179        } else {
180            Err(ClientNodeError::unexpected_response(response))
181        }
182    }
183    pub(crate) async fn send_single_ack(
184        &self,
185        MessageAck {
186            ack_to,
187            topic_code,
188            from,
189            kind,
190        }: MessageAck,
191    ) -> Result<(), ClientNodeError> {
192        let response = self
193            .send_request_and_wait(EdgeRequestEnum::SetState(SetState {
194                topic: topic_code,
195                update: MessageStateUpdate {
196                    message_id: ack_to,
197                    status: HashMap::from_iter([(from, kind)]),
198                },
199            }))
200            .await?;
201        if let EdgeResponseEnum::SetState = response {
202            Ok(())
203        } else {
204            Err(ClientNodeError::unexpected_response(response))
205        }
206    }
207    async fn send_request(
208        &self,
209        request: EdgeRequestEnum,
210    ) -> Result<ResponseHandle, ClientNodeError> {
211        let (responder, response_handle) = oneshot::channel();
212        let request = EdgeRequest {
213            seq_id: self.seq.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
214            request,
215        };
216        self.sender
217            .send((request, responder))
218            .map_err(|e| ClientNodeError {
219                kind: ClientErrorKind::NoConnection(e.0 .0.request),
220            })?;
221        Ok(response_handle)
222        // let response = rx.await.map_err(|_| ClientNodeError {
223        //     kind: ClientErrorKind::Disconnected,
224        // })??;
225        // Ok(response)
226    }
227    async fn wait_handle(
228        response_handle: ResponseHandle,
229    ) -> Result<EdgeResponseEnum, ClientNodeError> {
230        let response = response_handle
231            .await
232            .map_err(|_| ClientNodeError::disconnected("wait handle"))??;
233        Ok(response)
234    }
235    async fn send_request_and_wait(
236        &self,
237        request: EdgeRequestEnum,
238    ) -> Result<EdgeResponseEnum, ClientNodeError> {
239        let response_handle = self.send_request(request).await?;
240        Self::wait_handle(response_handle).await
241    }
242    pub(crate) async fn connect<C>(connection: C) -> Result<ClientNodeInner, ClientNodeError>
243    where
244        C: EdgeNodeConnection,
245    {
246        let response_pool: ResponsePool = Default::default();
247
248        let endpoint_map = EndpointMailboxMap::default();
249        let ct = CancellationToken::new();
250        let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel::<(
251            EdgeRequest,
252            oneshot::Sender<Result<EdgeResponseEnum, EdgeError>>,
253        )>();
254
255        let (mut sink, mut stream) = connection.split();
256        let tx_ct = ct.child_token();
257        let tx_task = {
258            let response_pool = response_pool.clone();
259            async move {
260                loop {
261                    let (request, responder) = tokio::select! {
262                        _ = tx_ct.cancelled() => {
263                            tracing::debug!("task cancelled");
264                            break
265                        }
266                        message = request_rx.recv() => {
267                            match message {
268                                Some(message) => {
269                                    message
270                                }
271                                None => {
272                                    tracing::warn!("tx dropped");
273                                    break;
274                                }
275                            }
276                        }
277                    };
278                    let seq_id = request.seq_id;
279                    response_pool.write().await.insert(seq_id, responder);
280                    // tracing::warn!(seq_id, "[debug] request do send");
281                    let send_result = sink.send(EdgePayload::Request(request)).await;
282                    if let Err(e) = send_result {
283                        tracing::error!("failed to send message: {:?}", e);
284                        break;
285                    }
286                }
287            }
288            .instrument(tracing::info_span!("client_node_tx"))
289        };
290
291        let rx_ct = ct.child_token();
292        let rx_task = {
293            let endpoints_map = endpoint_map.clone();
294            async move {
295                loop {
296                    let edge_pld = tokio::select! {
297                        _ = rx_ct.cancelled() => {
298                            endpoints_map.write().await.clear();
299                            tracing::debug!("task cancelled");
300                            break
301                        }
302                        received = stream.next() => {
303                            match received {
304                                Some(Ok(edge_pld)) => {
305                                    edge_pld
306                                }
307                                Some(Err(e)) => {
308                                    match e.kind {
309                                        EdgeConnectionErrorKind::Reconnect | EdgeConnectionErrorKind::Closed  => {
310                                            // clear the response pool
311                                            response_pool.write().await.clear();
312                                            // clear ep map, so the endpoint will know there is going to have a new connection
313                                            endpoints_map.write().await.clear();
314                                            if matches!(e.kind,EdgeConnectionErrorKind::Closed) {
315                                                tracing::info!("connection closed");
316                                                break;
317                                            }
318                                        },
319                                        _ => {
320                                            tracing::error!("failed to receive message: {:?}", e);
321                                        }
322                                    }
323                                    continue
324                                }
325                                None => {
326                                    tracing::debug!("stream closed");
327                                    break;
328                                }
329                            }
330                        }
331                    };
332
333                    match edge_pld {
334                        EdgePayload::Push(EdgePush::Message { endpoints, message }) => {
335                            tracing::trace!(endpoints = ?endpoints, ?message, "received message");
336                            let mut wg = endpoints_map.write().await;
337                            for ep in endpoints {
338                                if let Some(mailbox) = wg.get(&ep) {
339                                    let send_result = mailbox.message_tx.send(message.clone());
340                                    if send_result.is_err() {
341                                        tracing::warn!(addr=?ep, "target endpoint is dropped")
342                                    }
343                                } else {
344                                    let mailbox = EndpointMailbox::new();
345                                    mailbox
346                                        .message_tx
347                                        .send(message.clone())
348                                        .expect("a brand new channel must have the receiver");
349                                    wg.insert(ep, mailbox);
350                                    tracing::warn!(addr=?ep, "target endpoint not found")
351                                }
352                            }
353                            drop(wg);
354                        }
355                        EdgePayload::Response(edge_response) => {
356                            let seq_id = edge_response.seq_id;
357                            if let Some(responder) = response_pool.write().await.remove(&seq_id) {
358                                // tracing::warn!(?edge_response, "[debug] received response from server");
359                                let _ = responder.send(edge_response.result.into_std());
360                            } else {
361                                tracing::error!(seq_id, "response handle not found");
362                            }
363                        }
364                        EdgePayload::Error(e) => {
365                            tracing::error!(?e, "received error");
366                        }
367                        _ => {}
368                    }
369                }
370            }
371            .instrument(tracing::info_span!("client_node_rx"))
372        };
373        let tx_handle = tokio::spawn(tx_task);
374        let rx_handle = tokio::spawn(rx_task);
375        Ok(ClientNodeInner {
376            sender: request_tx,
377            seq: Default::default(),
378            endpoint_map,
379            cancellation_token: ct,
380            _rx_handle: rx_handle,
381            _tx_handle: tx_handle,
382        })
383    }
384}
385
386impl Drop for ClientNodeInner {
387    fn drop(&mut self) {
388        self.cancellation_token.cancel();
389    }
390}