// amq/client/client_async.rs

1use std::{collections::HashMap, error::Error, future::Future, pin::Pin, sync::Arc};
2
3use tokio::{
4    io::{AsyncReadExt as _, AsyncWriteExt as _},
5    net::{
6        tcp::{OwnedReadHalf as TcpOHR, OwnedWriteHalf as TcpOWH},
7        TcpStream,
8    },
9    sync::{
10        oneshot::{channel, Receiver},
11        RwLock,
12    },
13    task::{spawn, JoinHandle},
14};
15
16#[cfg(unix)]
17use tokio::net::{
18    unix::{OwnedReadHalf as UnixOHR, OwnedWriteHalf as UnixOWH},
19    UnixStream,
20};
21
22use crate::{
23    error::AmqError,
24    message::{
25        Message, MsgStatus, ReqMsgAuthorizer, ReqMsgConsumeAck, ReqMsgConsumeAckMulti,
26        ReqMsgConsumerTopic, ReqMsgPublish, ReqMsgSubscriber, ReqMsgUnconsumerTopic,
27        ReqMsgUnsubscriber, RespMsgConsume, RespMsgSubscribe,
28    },
29    Config,
30};
31
/// Boxed, type-erased future produced by a receive callback.
type RecvFuture = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Type-erased subscribe/consume callback.
///
/// It is called with the client's shared state and the raw payload bytes of a received
/// message. The returned future is awaited by the receive task.
type OnRecvFn<T> = Arc<dyn Fn(Arc<T>, Vec<u8>) -> RecvFuture + Send + Sync>;
34
35/// # Async Client
36///
37/// ```rust
38/// loop {
39///     let config = Config::new().unwrap();
40///     let state = Arc::new(Mutex::new(0));
41///
42///     let mut client = AsyncClient::new(config, state.clone());
43///
44///     let rx = match client.connect().await {
45///         Ok(rx) => rx,
46///         Err(e) => {
47///             println!("Connection error: {}", e);
48///             sleep(Duration::from_secs(1)).await;
49///             continue;
50///         }
51///     };
52///
53///     client
54///         .subscribe("topic", |state, msg| async move {
55///             let mut state = state.lock().await;
56///             *state += 1;
57///             println!("Received message: {} {:?}", state, msg);
58///         })
59///         .await?;
60///
61///     let should_exit = select! {
62///         // Receive signal when connection closed
63///         result = rx => {
64///             client.shutdown().await;
65///
66///             match result {
67///                 // Server closed connection
68///                 Ok(AmqError::TcpServerClosed) => {
69///                     true
70///                 }
71///                 // Other error
72///                 Ok(e) => {
73///                     println!("{}", e);
74///                     false
75///                 }
76///                 Err(e) => {
77///                     println!("Receive signal error: {:?}", e);
78///                     false
79///                 }
80///             }
81///         }
82///
83///         // Send message every 1s
84///         _ = async {
85///             loop {
86///                 sleep(Duration::from_secs(1)).await;
87///                 let _ = client.publish("topic", "Hello, world!".as_bytes().to_vec()).await;
88///             }
89///         } => {
90///             false // Exit loop on error, and reconnect
91///         }
92///     };
93///
94///     if should_exit {
95///         break;
96///     }
97///
98///     println!("Reconnecting...");
99/// }
100/// ```
101pub struct Client<T>
102where
103    T: Send + Sync + 'static,
104{
105    config: Config,
106    state: Arc<T>,
107    stream: Option<Owh>,
108    recv_task: Option<JoinHandle<()>>,
109    on_subscribes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
110    on_consumes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
111}
112
113enum Owh {
114    Tcp(TcpOWH),
115    #[cfg(unix)]
116    Unix(UnixOWH),
117}
118
119impl<T> Client<T>
120where
121    T: Send + Sync + 'static,
122{
123    /// # Create a new async client. (tokio)
124    pub fn new(config: Config, state: Arc<T>) -> Self {
125        Self {
126            config,
127            state,
128            stream: None,
129            recv_task: None,
130            on_subscribes: Arc::new(RwLock::new(HashMap::new())),
131            on_consumes: Arc::new(RwLock::new(HashMap::new())),
132        }
133    }
134
135    /// # Subscribe a topic.
136    ///
137    /// ```rust
138    /// client.subscribe("topic", |msg| async move {
139    ///     println!("Received message: {:?}", msg);
140    /// }).await.unwrap();
141    /// ```
142    pub async fn subscribe<F, Fut>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
143    where
144        F: Fn(Arc<T>, Vec<u8>) -> Fut + Send + Sync + 'static,
145        Fut: Future<Output = ()> + Send + 'static,
146    {
147        let handler: OnRecvFn<T> =
148            Arc::new(move |state: Arc<T>, data: Vec<u8>| Box::pin(f(state, data)));
149        self.on_subscribes
150            .write()
151            .await
152            .insert(topic.to_string(), handler);
153        let message = Message::ReqSubscribeTopic(ReqMsgSubscriber {
154            topic: topic.to_string(),
155        });
156        self.send(message).await?;
157        Ok(())
158    }
159
160    /// # Unsubscribe a topic.
161    ///
162    /// ```rust
163    /// client.unsubscribe("topic").await.unwrap();
164    /// ```
165    pub async fn unsubscribe(&mut self, topic: &str) -> Result<(), AmqError> {
166        let message = Message::ReqUnsubscribeTopic(ReqMsgUnsubscriber {
167            topic: topic.to_string(),
168        });
169        self.send(message).await?;
170        self.on_subscribes.write().await.remove(topic);
171        Ok(())
172    }
173
174    /// # Publish a message.
175    ///
176    /// ```rust
177    /// client.publish("topic", "Hello, world!".as_bytes().to_vec()).await.unwrap();
178    /// ```
179    pub async fn publish(&mut self, topic: &str, content: Vec<u8>) -> Result<(), AmqError> {
180        let message = Message::ReqPublish(ReqMsgPublish {
181            topic: topic.to_string(),
182            message: content,
183        });
184        self.send(message).await?;
185        Ok(())
186    }
187
188    /// # Add a consume callback for a topic.
189    ///
190    /// ```rust
191    /// client.consume("topic", |msg| async move {
192    ///     println!("Received message: {:?}", msg);
193    /// }).await.unwrap();
194    /// ```
195    pub async fn consume<F, Fut>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
196    where
197        F: Fn(Arc<T>, Vec<u8>) -> Fut + Send + Sync + 'static,
198        Fut: Future<Output = ()> + Send + 'static,
199    {
200        let handler: OnRecvFn<T> =
201            Arc::new(move |state: Arc<T>, data: Vec<u8>| Box::pin(f(state, data)));
202        self.on_consumes
203            .write()
204            .await
205            .insert(topic.to_string(), handler);
206        let message = Message::ReqConsumerTopic(ReqMsgConsumerTopic {
207            topic: topic.to_string(),
208        });
209        self.send(message).await?;
210        Ok(())
211    }
212
213    /// # Remove a consume callback for a topic.
214    /// ```rust
215    /// client.unconsume("topic").await.unwrap();
216    /// ```
217    pub async fn unconsume(&mut self, topic: &str) -> Result<(), AmqError> {
218        let message = Message::ReqUnconsumerTopic(ReqMsgUnconsumerTopic {
219            topic: topic.to_string(),
220        });
221        self.send(message).await?;
222        self.on_consumes.write().await.remove(topic);
223        Ok(())
224    }
225
226    /// # Acknowledge a message.
227    ///
228    /// ```rust
229    /// client.ack(message_id).await.unwrap();
230    /// ```
231    pub async fn ack(&mut self, message_id: u64) -> Result<(), AmqError> {
232        let message = Message::ReqConsumeAck(ReqMsgConsumeAck { id: message_id });
233        self.send(message).await?;
234        Ok(())
235    }
236
237    /// # Acknowledge multiple messages.
238    ///
239    /// ```rust
240    /// client.ack_multi(vec![message_id1, message_id2]).await.unwrap();
241    /// ```
242    pub async fn ack_multi(&mut self, message_ids: Vec<u64>) -> Result<(), AmqError> {
243        let message = Message::ReqConsumeAckMulti(ReqMsgConsumeAckMulti { ids: message_ids });
244        self.send(message).await?;
245        Ok(())
246    }
247
248    /// # Connect to the server and start the receive task.
249    pub async fn connect(&mut self) -> Result<Receiver<AmqError>, Box<dyn Error>> {
250        if self.config.path.is_empty() {
251            let addr = self.config.get_address();
252            let stream = TcpStream::connect(addr).await?;
253            let (reader, writer) = stream.into_split();
254            self.tcp_conn(reader, writer).await
255        } else {
256            #[cfg(unix)]
257            {
258                let stream = UnixStream::connect(&self.config.path).await?;
259                let (reader, writer) = stream.into_split();
260                self.unix_conn(reader, writer).await
261            }
262            #[cfg(not(unix))]
263            {
264                Err(Box::new(AmqError::UnsupportedPlatform))
265            }
266        }
267    }
268
269    async fn tcp_conn(
270        &mut self,
271        mut reader: TcpOHR,
272        mut writer: TcpOWH,
273    ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
274        let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
275            access_key: self.config.access_key.clone(),
276            access_secret: self.config.access_secret.clone(),
277        });
278        let message = &serde_json::to_vec(&msg)?;
279        writer.write_u32(message.len() as u32).await?;
280        writer.write_all(&message).await?;
281
282        // Read mesage header (4 bytes)
283        let len = reader.read_u32().await?;
284        // Read message body (len bytes)
285        let mut buf = vec![0; len as usize];
286        let _ = reader.read(&mut buf).await?;
287        match Message::deserialize(&buf)? {
288            Message::RespAuthorizer(resp) => {
289                if resp.status != MsgStatus::Success {
290                    return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
291                }
292            }
293            _ => {
294                return Err(Box::new(AmqError::AuthorizationError(
295                    "Invalid response".to_string(),
296                )));
297            }
298        }
299
300        self.stream = Some(Owh::Tcp(writer));
301
302        let (tx, rx) = channel::<AmqError>();
303
304        let on_subscribes = self.on_subscribes.clone();
305        let on_consumes = self.on_consumes.clone();
306        let state = Arc::clone(&self.state);
307        let recv_task = spawn(async move {
308            loop {
309                // Read mesage header (4 bytes)
310                let len = match reader.read_u32().await {
311                    Ok(l) => l as usize,
312                    Err(e) => {
313                        tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
314                        break;
315                    }
316                };
317                // Read message body (len bytes)
318                let mut buf = vec![0; len];
319                match reader.read(&mut buf).await {
320                    Ok(0) => {
321                        tx.send(AmqError::TcpServerClosed).unwrap();
322                        break;
323                    }
324                    Ok(_) => {
325                        let msg = match Message::deserialize(&buf) {
326                            Ok(m) => m,
327                            Err(_) => {
328                                continue;
329                            }
330                        };
331                        match &msg {
332                            Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
333                                if let Some(cb) = on_subscribes.read().await.get(topic) {
334                                    cb(Arc::clone(&state), message.clone()).await;
335                                }
336                            }
337                            Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
338                                if let Some(cb) = on_consumes.read().await.get(topic) {
339                                    cb(Arc::clone(&state), message.clone()).await;
340                                }
341                            }
342                            _ => {}
343                        }
344                    }
345                    Err(e) => {
346                        tx.send(AmqError::TcpServerError(e.to_string())).unwrap();
347                        break;
348                    }
349                }
350            }
351        });
352
353        self.recv_task = Some(recv_task);
354
355        Ok(rx)
356    }
357
358    #[cfg(unix)]
359    async fn unix_conn(
360        &mut self,
361        mut reader: UnixOHR,
362        mut writer: UnixOWH,
363    ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
364        let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
365            access_key: self.config.access_key.clone(),
366            access_secret: self.config.access_secret.clone(),
367        });
368        let message = &serde_json::to_vec(&msg)?;
369        writer.write_u32(message.len() as u32).await?;
370        writer.write_all(&message).await?;
371
372        // Read mesage header (4 bytes)
373        let len = reader.read_u32().await?;
374        // Read message body (len bytes)
375        let mut buf = vec![0; len as usize];
376        let _ = reader.read(&mut buf).await?;
377        match Message::deserialize(&buf)? {
378            Message::RespAuthorizer(resp) => {
379                if resp.status != MsgStatus::Success {
380                    return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
381                }
382            }
383            _ => {
384                return Err(Box::new(AmqError::AuthorizationError(
385                    "Invalid response".to_string(),
386                )));
387            }
388        }
389
390        self.stream = Some(Owh::Unix(writer));
391
392        let (tx, rx) = channel::<AmqError>();
393
394        let on_subscribes = self.on_subscribes.clone();
395        let on_consumes = self.on_consumes.clone();
396        let state = Arc::clone(&self.state);
397        let recv_task = spawn(async move {
398            loop {
399                // Read mesage header (4 bytes)
400                let len = match reader.read_u32().await {
401                    Ok(l) => l as usize,
402                    Err(e) => {
403                        tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
404                        break;
405                    }
406                };
407                // Read message body (len bytes)
408                let mut buf = vec![0; len];
409                match reader.read(&mut buf).await {
410                    Ok(0) => {
411                        tx.send(AmqError::TcpServerClosed).unwrap();
412                        break;
413                    }
414                    Ok(_) => {
415                        let msg = match Message::deserialize(&buf) {
416                            Ok(m) => m,
417                            Err(_) => {
418                                continue;
419                            }
420                        };
421                        match &msg {
422                            Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
423                                if let Some(cb) = on_subscribes.read().await.get(topic) {
424                                    cb(Arc::clone(&state), message.clone()).await;
425                                }
426                            }
427                            Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
428                                if let Some(cb) = on_consumes.read().await.get(topic) {
429                                    cb(Arc::clone(&state), message.clone()).await;
430                                }
431                            }
432                            _ => {}
433                        }
434                    }
435                    Err(e) => {
436                        tx.send(AmqError::TcpServerError(e.to_string())).unwrap();
437                        break;
438                    }
439                }
440            }
441        });
442
443        self.recv_task = Some(recv_task);
444
445        Ok(rx)
446    }
447
448    /// # Shutdown the client.
449    pub async fn shutdown(&mut self) {
450        if let Some(task) = self.recv_task.take() {
451            task.abort();
452        }
453
454        self.stream = None;
455    }
456
457    async fn send(&mut self, msg: Message) -> Result<(), AmqError> {
458        match &mut self.stream {
459            Some(Owh::Tcp(writer)) => {
460                let message = &serde_json::to_vec(&msg)
461                    .map_err(|e| AmqError::TcpSendDataError(e.to_string()))?;
462                let _ = writer.write_u32(message.len() as u32).await;
463                writer
464                    .write_all(&message)
465                    .await
466                    .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
467            }
468            #[cfg(unix)]
469            Some(Owh::Unix(writer)) => {
470                let message = &serde_json::to_vec(&msg)
471                    .map_err(|e| AmqError::TcpSendDataError(e.to_string()))?;
472                let _ = writer.write_u32(message.len() as u32).await;
473                writer
474                    .write_all(&message)
475                    .await
476                    .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
477            }
478            None => {
479                return Err(AmqError::TcpSendError("not connected".to_string()));
480            }
481        }
482        Ok(())
483    }
484}