amq/client/
client_sync.rs

1use std::{
2    collections::HashMap,
3    error::Error,
4    io::{Read, Write},
5    net::TcpStream,
6    sync::{
7        mpsc::{channel, Receiver},
8        Arc, RwLock,
9    },
10    thread::{spawn, JoinHandle},
11};
12
13#[cfg(unix)]
14use std::os::unix::net::UnixStream;
15
16use crate::{
17    error::AmqError,
18    message::{
19        Message, MsgStatus, ReqMsgAuthorizer, ReqMsgConsumeAck, ReqMsgConsumeAckMulti,
20        ReqMsgConsumerTopic, ReqMsgPublish, ReqMsgSubscriber, ReqMsgUnconsumerTopic,
21        ReqMsgUnsubscriber, RespMsgConsume, RespMsgSubscribe,
22    },
23    Config,
24};
25
/// Callback run on the receive thread for a delivered message.
///
/// It is called with a clone of the client's shared state and the raw message payload.
type OnRecvFn<T> = Arc<dyn Fn(Arc<T>, Vec<u8>) + Sync + Send>;
27
28/// # Sync Client
29///
30/// ```rust
31/// loop {
32///     let config = Config::new().unwrap();
33///     let state = Arc::new(Mutex::new(0));
34///
35///     let mut client = SyncClient::new(config, state.clone());
36///
37///     let rx = match client.connect() {
38///         Ok(rx) => rx,
39///         Err(e) => {
40///             println!("Connection error: {}", e);
41///             sleep(Duration::from_secs(1));
42///             continue;
43///         }
44///     };
45///
46///     client.subscribe("topic", |state, msg| {
47///         let mut state = state.lock().unwrap();
48///         *state += 1;
49///         println!("Received message: {:?}", msg);
50///     })?;
51///
52///     let should_exit;
53///     loop {
54///         let result = rx.try_recv();
55///         match result {
56///             Ok(AmqError::TcpServerClosed) => {
57///                 should_exit = true;
58///                 break;
59///             }
60///             Ok(e) => {
61///                 println!("{}", e);
62///                 should_exit = false;
63///                 break;
64///             }
65///             Err(TryRecvError::Empty) => {
66///                 sleep(Duration::from_secs(1));
67///                 let _ = client.publish("topic", "Hello, world!".as_bytes().to_vec());
68///             }
69///             Err(e) => {
70///                 println!("Receive signal error: {:?}", e);
71///                 should_exit = false;
72///                 break;
73///             }
74///         }
75///     }
76///
77///     if should_exit {
78///         break;
79///     }
80///
81///     println!("Reconnecting...");
82/// }
83/// ```
84pub struct Client<T>
85where
86    T: Send + Sync + 'static,
87{
88    config: Config,
89    state: Arc<T>,
90    stream: Option<Stream>,
91    recv_thread: Option<JoinHandle<()>>,
92    on_subscribes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
93    on_consumes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
94}
95
/// Socket owned by a connected client, with one variant per supported transport.
enum Stream {
    /// Connection established over TCP.
    TcpStream(std::net::TcpStream),
    /// Connection established over a Unix domain socket (Unix platforms only).
    #[cfg(unix)]
    UnixStream(std::os::unix::net::UnixStream),
}
101
102impl<T> Client<T>
103where
104    T: Send + Sync + 'static,
105{
106    /// # Create a new client.
107    pub fn new(config: Config, state: Arc<T>) -> Self {
108        Self {
109            config,
110            state,
111            stream: None,
112            recv_thread: None,
113            on_subscribes: Arc::new(RwLock::new(HashMap::new())),
114            on_consumes: Arc::new(RwLock::new(HashMap::new())),
115        }
116    }
117
118    /// # Subscribe a topic.
119    ///
120    /// ```rust
121    /// client.subscribe("topic", |msg| {
122    ///     println!("Received message: {:?}", msg);
123    /// }).unwrap();
124    /// ```
125    pub fn subscribe<F>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
126    where
127        F: Fn(Arc<T>, Vec<u8>) + Send + Sync + 'static,
128    {
129        let handler: OnRecvFn<T> = Arc::new(f);
130        self.on_subscribes
131            .write()
132            .unwrap()
133            .insert(topic.to_string(), handler);
134        let message = Message::ReqSubscribeTopic(ReqMsgSubscriber {
135            topic: topic.to_string(),
136        });
137        self.send(message)?;
138        Ok(())
139    }
140
141    /// # Unsubscribe a topic.
142    /// ```rust
143    /// client.unsubscribe("topic").unwrap();
144    /// ```
145    pub fn unsubscribe(&mut self, topic: &str) -> Result<(), AmqError> {
146        let message = Message::ReqUnsubscribeTopic(ReqMsgUnsubscriber {
147            topic: topic.to_string(),
148        });
149        self.send(message)?;
150        self.on_subscribes.write().unwrap().remove(topic);
151        Ok(())
152    }
153
154    /// # Publish a message.
155    ///
156    /// ```rust
157    /// client.publish("topic", "Hello, world!".as_bytes().to_vec()).unwrap();
158    /// ```
159    pub fn publish(&mut self, topic: &str, content: Vec<u8>) -> Result<(), AmqError> {
160        let message = Message::ReqPublish(ReqMsgPublish {
161            topic: topic.to_string(),
162            message: content,
163        });
164        self.send(message)?;
165        Ok(())
166    }
167
168    /// # Add a consume callback for a topic.
169    ///
170    /// ```rust
171    /// client.consume("topic", |msg| {
172    ///     println!("Received message: {:?}", msg);
173    /// }).unwrap();
174    /// ```
175    pub fn consume<F>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
176    where
177        F: Fn(Arc<T>, Vec<u8>) + Send + Sync + 'static,
178    {
179        let handler: OnRecvFn<T> = Arc::new(f);
180        self.on_consumes
181            .write()
182            .unwrap()
183            .insert(topic.to_string(), handler);
184        let message = Message::ReqConsumerTopic(ReqMsgConsumerTopic {
185            topic: topic.to_string(),
186        });
187        self.send(message)?;
188        Ok(())
189    }
190
191    /// # Remove a consume callback for a topic.
192    /// ```rust
193    /// client.unconsume("topic").unwrap();
194    /// ```
195    pub fn unconsume(&mut self, topic: &str) -> Result<(), AmqError> {
196        let message = Message::ReqUnconsumerTopic(ReqMsgUnconsumerTopic {
197            topic: topic.to_string(),
198        });
199        self.send(message)?;
200        self.on_consumes.write().unwrap().remove(topic);
201        Ok(())
202    }
203
204    /// # Acknowledge a message.
205    ///
206    /// ```rust
207    /// client.ack(message_id).unwrap();
208    /// ```
209    pub fn ack(&mut self, message_id: u64) -> Result<(), AmqError> {
210        let message = Message::ReqConsumeAck(ReqMsgConsumeAck { id: message_id });
211        self.send(message)?;
212        Ok(())
213    }
214
215    /// # Acknowledge multiple messages.
216    ///
217    /// ```rust
218    /// client.ack_multi(vec![message_id1, message_id2]).unwrap();
219    /// ```
220    pub fn ack_multi(&mut self, message_ids: Vec<u64>) -> Result<(), AmqError> {
221        let message = Message::ReqConsumeAckMulti(ReqMsgConsumeAckMulti { ids: message_ids });
222        self.send(message)?;
223        Ok(())
224    }
225
226    /// # Connect to the server and start the receive task.
227    pub fn connect(&mut self) -> Result<Receiver<AmqError>, Box<dyn Error>> {
228        if self.config.path.is_empty() {
229            let addr = self.config.get_address();
230            let stream = TcpStream::connect(addr)?;
231            self.tcp_conn(stream)
232        } else {
233            #[cfg(unix)]
234            {
235                let stream = UnixStream::connect(&self.config.path)?;
236                self.unix_conn(stream)
237            }
238            #[cfg(not(unix))]
239            {
240                Err(Box::new(AmqError::UnsupportedPlatform))
241            }
242        }
243    }
244
245    pub fn tcp_conn(
246        &mut self,
247        mut stream: TcpStream,
248    ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
249        stream.set_nodelay(true)?;
250        stream.set_nonblocking(false)?;
251
252        let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
253            access_key: self.config.access_key.clone(),
254            access_secret: self.config.access_secret.clone(),
255        });
256        let message = &serde_json::to_vec(&msg)?;
257        let len_data = (message.len() as u32).to_be_bytes();
258        stream.write_all(&len_data)?;
259        stream.write_all(&message)?;
260
261        // Read mesage header (4 bytes)
262        let mut len_buf = [0u8; 4];
263        stream.read_exact(&mut len_buf)?;
264        let len = u32::from_be_bytes(len_buf) as usize;
265        // Read message body (len bytes)
266        let mut buf = vec![0u8; len];
267        let _ = stream.read(&mut buf)?;
268        match Message::deserialize(&buf)? {
269            Message::RespAuthorizer(resp) => {
270                if resp.status != MsgStatus::Success {
271                    return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
272                }
273            }
274            _ => {
275                return Err(Box::new(AmqError::AuthorizationError(
276                    "Invalid response".to_string(),
277                )));
278            }
279        }
280
281        let reader_stream = stream.try_clone()?;
282        self.stream = Some(Stream::TcpStream(stream));
283
284        let (tx, rx) = channel::<AmqError>();
285
286        let on_subscribes = self.on_subscribes.clone();
287        let on_consumes = self.on_consumes.clone();
288        let state = Arc::clone(&self.state);
289        let recv_thread = spawn(move || {
290            let mut reader = reader_stream;
291            loop {
292                // Read mesage header (4 bytes)
293                let mut len_buf = [0u8; 4];
294                let len = match reader.read_exact(&mut len_buf) {
295                    Ok(_) => u32::from_be_bytes(len_buf) as usize,
296                    Err(e) => {
297                        tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
298                        break;
299                    }
300                };
301                // Read message body (len bytes)
302                let mut buf = vec![0u8; len];
303                match reader.read(&mut buf) {
304                    Ok(0) => {
305                        tx.send(AmqError::TcpServerClosed).unwrap();
306                        break;
307                    }
308                    Ok(_) => {
309                        let msg = match Message::deserialize(&buf) {
310                            Ok(m) => m,
311                            Err(_) => {
312                                continue;
313                            }
314                        };
315
316                        match &msg {
317                            Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
318                                if let Some(cb) = on_subscribes.read().unwrap().get(topic) {
319                                    cb(Arc::clone(&state), message.clone());
320                                }
321                            }
322                            Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
323                                if let Some(cb) = on_consumes.read().unwrap().get(topic) {
324                                    cb(Arc::clone(&state), message.clone());
325                                }
326                            }
327                            _ => {}
328                        }
329                    }
330                    Err(e) => {
331                        tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
332                        break;
333                    }
334                }
335            }
336        });
337
338        self.recv_thread = Some(recv_thread);
339
340        Ok(rx)
341    }
342
343    #[cfg(unix)]
344    pub fn unix_conn(
345        &mut self,
346        mut stream: UnixStream,
347    ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
348        stream.set_nonblocking(false)?;
349
350        let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
351            access_key: self.config.access_key.clone(),
352            access_secret: self.config.access_secret.clone(),
353        });
354        let message = &serde_json::to_vec(&msg)?;
355        let len_data = (message.len() as u32).to_be_bytes();
356        stream.write_all(&len_data)?;
357        stream.write_all(&message)?;
358
359        // Read mesage header (4 bytes)
360        let mut len_buf = [0u8; 4];
361        stream.read_exact(&mut len_buf)?;
362        let len = u32::from_be_bytes(len_buf) as usize;
363        // Read message body (len bytes)
364        let mut buf = vec![0u8; len];
365        let _ = stream.read(&mut buf)?;
366        match Message::deserialize(&buf)? {
367            Message::RespAuthorizer(resp) => {
368                if resp.status != MsgStatus::Success {
369                    return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
370                }
371            }
372            _ => {
373                return Err(Box::new(AmqError::AuthorizationError(
374                    "Invalid response".to_string(),
375                )));
376            }
377        }
378
379        let reader_stream = stream.try_clone()?;
380        self.stream = Some(Stream::UnixStream(stream));
381
382        let (tx, rx) = channel::<AmqError>();
383
384        let on_subscribes = self.on_subscribes.clone();
385        let on_consumes = self.on_consumes.clone();
386        let state = Arc::clone(&self.state);
387        let recv_thread = spawn(move || {
388            let mut reader = reader_stream;
389            loop {
390                // Read mesage header (4 bytes)
391                let mut len_buf = [0u8; 4];
392                let len = match reader.read_exact(&mut len_buf) {
393                    Ok(_) => u32::from_be_bytes(len_buf) as usize,
394                    Err(e) => {
395                        tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
396                        break;
397                    }
398                };
399                // Read message body (len bytes)
400                let mut buf = vec![0u8; len];
401                match reader.read(&mut buf) {
402                    Ok(0) => {
403                        tx.send(AmqError::TcpServerClosed).unwrap();
404                        break;
405                    }
406                    Ok(_) => {
407                        let msg = match Message::deserialize(&buf) {
408                            Ok(m) => m,
409                            Err(_) => {
410                                continue;
411                            }
412                        };
413
414                        match &msg {
415                            Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
416                                if let Some(cb) = on_subscribes.read().unwrap().get(topic) {
417                                    cb(Arc::clone(&state), message.clone());
418                                }
419                            }
420                            Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
421                                if let Some(cb) = on_consumes.read().unwrap().get(topic) {
422                                    cb(Arc::clone(&state), message.clone());
423                                }
424                            }
425                            _ => {}
426                        }
427                    }
428                    Err(e) => {
429                        tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
430                        break;
431                    }
432                }
433            }
434        });
435
436        self.recv_thread = Some(recv_thread);
437
438        Ok(rx)
439    }
440
441    /// # Shutdown the client.
442    pub fn shutdown(&mut self) {
443        match &mut self.stream {
444            Some(Stream::TcpStream(stream)) => {
445                let _ = stream.shutdown(std::net::Shutdown::Both);
446            }
447            #[cfg(unix)]
448            Some(Stream::UnixStream(stream)) => {
449                let _ = stream.shutdown(std::net::Shutdown::Both);
450            }
451            None => {}
452        }
453
454        self.stream = None;
455    }
456
457    fn send(&mut self, msg: Message) -> Result<(), AmqError> {
458        match &mut self.stream {
459            Some(Stream::TcpStream(stream)) => {
460                let data = serde_json::to_vec(&msg).unwrap();
461                let len = (data.len() as u32).to_be_bytes();
462                stream
463                    .write_all(&len)
464                    .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
465                stream
466                    .write_all(&data)
467                    .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
468                Ok(())
469            }
470            #[cfg(unix)]
471            Some(Stream::UnixStream(stream)) => {
472                let data = serde_json::to_vec(&msg).unwrap();
473                let len = (data.len() as u32).to_be_bytes();
474                stream
475                    .write_all(&len)
476                    .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
477                stream
478                    .write_all(&data)
479                    .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
480                Ok(())
481            }
482            None => Err(AmqError::TcpSendError("not connected".to_string())),
483        }
484    }
485}