1use std::{
2 collections::HashMap,
3 error::Error,
4 io::{Read, Write},
5 net::TcpStream,
6 sync::{
7 mpsc::{channel, Receiver},
8 Arc, RwLock,
9 },
10 thread::{spawn, JoinHandle},
11};
12
13#[cfg(unix)]
14use std::os::unix::net::UnixStream;
15
16use crate::{
17 error::AmqError,
18 message::{
19 Message, MsgStatus, ReqMsgAuthorizer, ReqMsgConsumeAck, ReqMsgConsumeAckMulti,
20 ReqMsgConsumerTopic, ReqMsgPublish, ReqMsgSubscriber, ReqMsgUnconsumerTopic,
21 ReqMsgUnsubscriber, RespMsgConsume, RespMsgSubscribe,
22 },
23 Config,
24};
25
/// Handler invoked with the client's shared state and the payload of a received message.
///
/// Kept behind an `Arc` so the same handler can be shared with, and called from,
/// the background receive thread.
type OnRecvFn<S> = Arc<dyn Send + Sync + Fn(Arc<S>, Vec<u8>)>;
27
28pub struct Client<T>
85where
86 T: Send + Sync + 'static,
87{
88 config: Config,
89 state: Arc<T>,
90 stream: Option<Stream>,
91 recv_thread: Option<JoinHandle<()>>,
92 on_subscribes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
93 on_consumes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
94}
95
/// Transport behind an established connection.
enum Stream {
    /// Connection made over TCP.
    TcpStream(std::net::TcpStream),
    /// Connection made over a Unix domain socket (Unix targets only).
    #[cfg(unix)]
    UnixStream(std::os::unix::net::UnixStream),
}
101
102impl<T> Client<T>
103where
104 T: Send + Sync + 'static,
105{
106 pub fn new(config: Config, state: Arc<T>) -> Self {
108 Self {
109 config,
110 state,
111 stream: None,
112 recv_thread: None,
113 on_subscribes: Arc::new(RwLock::new(HashMap::new())),
114 on_consumes: Arc::new(RwLock::new(HashMap::new())),
115 }
116 }
117
118 pub fn subscribe<F>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
126 where
127 F: Fn(Arc<T>, Vec<u8>) + Send + Sync + 'static,
128 {
129 let handler: OnRecvFn<T> = Arc::new(f);
130 self.on_subscribes
131 .write()
132 .unwrap()
133 .insert(topic.to_string(), handler);
134 let message = Message::ReqSubscribeTopic(ReqMsgSubscriber {
135 topic: topic.to_string(),
136 });
137 self.send(message)?;
138 Ok(())
139 }
140
141 pub fn unsubscribe(&mut self, topic: &str) -> Result<(), AmqError> {
146 let message = Message::ReqUnsubscribeTopic(ReqMsgUnsubscriber {
147 topic: topic.to_string(),
148 });
149 self.send(message)?;
150 self.on_subscribes.write().unwrap().remove(topic);
151 Ok(())
152 }
153
154 pub fn publish(&mut self, topic: &str, content: Vec<u8>) -> Result<(), AmqError> {
160 let message = Message::ReqPublish(ReqMsgPublish {
161 topic: topic.to_string(),
162 message: content,
163 });
164 self.send(message)?;
165 Ok(())
166 }
167
168 pub fn consume<F>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
176 where
177 F: Fn(Arc<T>, Vec<u8>) + Send + Sync + 'static,
178 {
179 let handler: OnRecvFn<T> = Arc::new(f);
180 self.on_consumes
181 .write()
182 .unwrap()
183 .insert(topic.to_string(), handler);
184 let message = Message::ReqConsumerTopic(ReqMsgConsumerTopic {
185 topic: topic.to_string(),
186 });
187 self.send(message)?;
188 Ok(())
189 }
190
191 pub fn unconsume(&mut self, topic: &str) -> Result<(), AmqError> {
196 let message = Message::ReqUnconsumerTopic(ReqMsgUnconsumerTopic {
197 topic: topic.to_string(),
198 });
199 self.send(message)?;
200 self.on_consumes.write().unwrap().remove(topic);
201 Ok(())
202 }
203
204 pub fn ack(&mut self, message_id: u64) -> Result<(), AmqError> {
210 let message = Message::ReqConsumeAck(ReqMsgConsumeAck { id: message_id });
211 self.send(message)?;
212 Ok(())
213 }
214
215 pub fn ack_multi(&mut self, message_ids: Vec<u64>) -> Result<(), AmqError> {
221 let message = Message::ReqConsumeAckMulti(ReqMsgConsumeAckMulti { ids: message_ids });
222 self.send(message)?;
223 Ok(())
224 }
225
226 pub fn connect(&mut self) -> Result<Receiver<AmqError>, Box<dyn Error>> {
228 if self.config.path.is_empty() {
229 let addr = self.config.get_address();
230 let stream = TcpStream::connect(addr)?;
231 self.tcp_conn(stream)
232 } else {
233 #[cfg(unix)]
234 {
235 let stream = UnixStream::connect(&self.config.path)?;
236 self.unix_conn(stream)
237 }
238 #[cfg(not(unix))]
239 {
240 Err(Box::new(AmqError::UnsupportedPlatform))
241 }
242 }
243 }
244
245 pub fn tcp_conn(
246 &mut self,
247 mut stream: TcpStream,
248 ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
249 stream.set_nodelay(true)?;
250 stream.set_nonblocking(false)?;
251
252 let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
253 access_key: self.config.access_key.clone(),
254 access_secret: self.config.access_secret.clone(),
255 });
256 let message = &serde_json::to_vec(&msg)?;
257 let len_data = (message.len() as u32).to_be_bytes();
258 stream.write_all(&len_data)?;
259 stream.write_all(&message)?;
260
261 let mut len_buf = [0u8; 4];
263 stream.read_exact(&mut len_buf)?;
264 let len = u32::from_be_bytes(len_buf) as usize;
265 let mut buf = vec![0u8; len];
267 let _ = stream.read(&mut buf)?;
268 match Message::deserialize(&buf)? {
269 Message::RespAuthorizer(resp) => {
270 if resp.status != MsgStatus::Success {
271 return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
272 }
273 }
274 _ => {
275 return Err(Box::new(AmqError::AuthorizationError(
276 "Invalid response".to_string(),
277 )));
278 }
279 }
280
281 let reader_stream = stream.try_clone()?;
282 self.stream = Some(Stream::TcpStream(stream));
283
284 let (tx, rx) = channel::<AmqError>();
285
286 let on_subscribes = self.on_subscribes.clone();
287 let on_consumes = self.on_consumes.clone();
288 let state = Arc::clone(&self.state);
289 let recv_thread = spawn(move || {
290 let mut reader = reader_stream;
291 loop {
292 let mut len_buf = [0u8; 4];
294 let len = match reader.read_exact(&mut len_buf) {
295 Ok(_) => u32::from_be_bytes(len_buf) as usize,
296 Err(e) => {
297 tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
298 break;
299 }
300 };
301 let mut buf = vec![0u8; len];
303 match reader.read(&mut buf) {
304 Ok(0) => {
305 tx.send(AmqError::TcpServerClosed).unwrap();
306 break;
307 }
308 Ok(_) => {
309 let msg = match Message::deserialize(&buf) {
310 Ok(m) => m,
311 Err(_) => {
312 continue;
313 }
314 };
315
316 match &msg {
317 Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
318 if let Some(cb) = on_subscribes.read().unwrap().get(topic) {
319 cb(Arc::clone(&state), message.clone());
320 }
321 }
322 Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
323 if let Some(cb) = on_consumes.read().unwrap().get(topic) {
324 cb(Arc::clone(&state), message.clone());
325 }
326 }
327 _ => {}
328 }
329 }
330 Err(e) => {
331 tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
332 break;
333 }
334 }
335 }
336 });
337
338 self.recv_thread = Some(recv_thread);
339
340 Ok(rx)
341 }
342
343 #[cfg(unix)]
344 pub fn unix_conn(
345 &mut self,
346 mut stream: UnixStream,
347 ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
348 stream.set_nonblocking(false)?;
349
350 let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
351 access_key: self.config.access_key.clone(),
352 access_secret: self.config.access_secret.clone(),
353 });
354 let message = &serde_json::to_vec(&msg)?;
355 let len_data = (message.len() as u32).to_be_bytes();
356 stream.write_all(&len_data)?;
357 stream.write_all(&message)?;
358
359 let mut len_buf = [0u8; 4];
361 stream.read_exact(&mut len_buf)?;
362 let len = u32::from_be_bytes(len_buf) as usize;
363 let mut buf = vec![0u8; len];
365 let _ = stream.read(&mut buf)?;
366 match Message::deserialize(&buf)? {
367 Message::RespAuthorizer(resp) => {
368 if resp.status != MsgStatus::Success {
369 return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
370 }
371 }
372 _ => {
373 return Err(Box::new(AmqError::AuthorizationError(
374 "Invalid response".to_string(),
375 )));
376 }
377 }
378
379 let reader_stream = stream.try_clone()?;
380 self.stream = Some(Stream::UnixStream(stream));
381
382 let (tx, rx) = channel::<AmqError>();
383
384 let on_subscribes = self.on_subscribes.clone();
385 let on_consumes = self.on_consumes.clone();
386 let state = Arc::clone(&self.state);
387 let recv_thread = spawn(move || {
388 let mut reader = reader_stream;
389 loop {
390 let mut len_buf = [0u8; 4];
392 let len = match reader.read_exact(&mut len_buf) {
393 Ok(_) => u32::from_be_bytes(len_buf) as usize,
394 Err(e) => {
395 tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
396 break;
397 }
398 };
399 let mut buf = vec![0u8; len];
401 match reader.read(&mut buf) {
402 Ok(0) => {
403 tx.send(AmqError::TcpServerClosed).unwrap();
404 break;
405 }
406 Ok(_) => {
407 let msg = match Message::deserialize(&buf) {
408 Ok(m) => m,
409 Err(_) => {
410 continue;
411 }
412 };
413
414 match &msg {
415 Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
416 if let Some(cb) = on_subscribes.read().unwrap().get(topic) {
417 cb(Arc::clone(&state), message.clone());
418 }
419 }
420 Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
421 if let Some(cb) = on_consumes.read().unwrap().get(topic) {
422 cb(Arc::clone(&state), message.clone());
423 }
424 }
425 _ => {}
426 }
427 }
428 Err(e) => {
429 tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
430 break;
431 }
432 }
433 }
434 });
435
436 self.recv_thread = Some(recv_thread);
437
438 Ok(rx)
439 }
440
441 pub fn shutdown(&mut self) {
443 match &mut self.stream {
444 Some(Stream::TcpStream(stream)) => {
445 let _ = stream.shutdown(std::net::Shutdown::Both);
446 }
447 #[cfg(unix)]
448 Some(Stream::UnixStream(stream)) => {
449 let _ = stream.shutdown(std::net::Shutdown::Both);
450 }
451 None => {}
452 }
453
454 self.stream = None;
455 }
456
457 fn send(&mut self, msg: Message) -> Result<(), AmqError> {
458 match &mut self.stream {
459 Some(Stream::TcpStream(stream)) => {
460 let data = serde_json::to_vec(&msg).unwrap();
461 let len = (data.len() as u32).to_be_bytes();
462 stream
463 .write_all(&len)
464 .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
465 stream
466 .write_all(&data)
467 .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
468 Ok(())
469 }
470 #[cfg(unix)]
471 Some(Stream::UnixStream(stream)) => {
472 let data = serde_json::to_vec(&msg).unwrap();
473 let len = (data.len() as u32).to_be_bytes();
474 stream
475 .write_all(&len)
476 .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
477 stream
478 .write_all(&data)
479 .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
480 Ok(())
481 }
482 None => Err(AmqError::TcpSendError("not connected".to_string())),
483 }
484 }
485}