1use std::{collections::HashMap, error::Error, future::Future, pin::Pin, sync::Arc};
2
3use tokio::{
4 io::{AsyncReadExt as _, AsyncWriteExt as _},
5 net::{
6 tcp::{OwnedReadHalf as TcpOHR, OwnedWriteHalf as TcpOWH},
7 TcpStream,
8 },
9 sync::{
10 oneshot::{channel, Receiver},
11 RwLock,
12 },
13 task::{spawn, JoinHandle},
14};
15
16#[cfg(unix)]
17use tokio::net::{
18 unix::{OwnedReadHalf as UnixOHR, OwnedWriteHalf as UnixOWH},
19 UnixStream,
20};
21
22use crate::{
23 error::AmqError,
24 message::{
25 Message, MsgStatus, ReqMsgAuthorizer, ReqMsgConsumeAck, ReqMsgConsumeAckMulti,
26 ReqMsgConsumerTopic, ReqMsgPublish, ReqMsgSubscriber, ReqMsgUnconsumerTopic,
27 ReqMsgUnsubscriber, RespMsgConsume, RespMsgSubscribe,
28 },
29 Config,
30};
31
/// Boxed, `Send` future returned by a message handler; resolves to `()`.
type OnRecvFuture = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Type-erased message handler.
///
/// Invoked with the client's shared state and the raw payload bytes of a
/// delivery. Stored behind an `Arc` so it can be cloned out of a handler table
/// and shared between tasks (`Send + Sync`).
type OnRecvFn<T> = Arc<dyn Fn(Arc<T>, Vec<u8>) -> OnRecvFuture + Send + Sync>;
34
35pub struct Client<T>
102where
103 T: Send + Sync + 'static,
104{
105 config: Config,
106 state: Arc<T>,
107 stream: Option<Owh>,
108 recv_task: Option<JoinHandle<()>>,
109 on_subscribes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
110 on_consumes: Arc<RwLock<HashMap<String, OnRecvFn<T>>>>,
111}
112
113enum Owh {
114 Tcp(TcpOWH),
115 #[cfg(unix)]
116 Unix(UnixOWH),
117}
118
119impl<T> Client<T>
120where
121 T: Send + Sync + 'static,
122{
123 pub fn new(config: Config, state: Arc<T>) -> Self {
125 Self {
126 config,
127 state,
128 stream: None,
129 recv_task: None,
130 on_subscribes: Arc::new(RwLock::new(HashMap::new())),
131 on_consumes: Arc::new(RwLock::new(HashMap::new())),
132 }
133 }
134
135 pub async fn subscribe<F, Fut>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
143 where
144 F: Fn(Arc<T>, Vec<u8>) -> Fut + Send + Sync + 'static,
145 Fut: Future<Output = ()> + Send + 'static,
146 {
147 let handler: OnRecvFn<T> =
148 Arc::new(move |state: Arc<T>, data: Vec<u8>| Box::pin(f(state, data)));
149 self.on_subscribes
150 .write()
151 .await
152 .insert(topic.to_string(), handler);
153 let message = Message::ReqSubscribeTopic(ReqMsgSubscriber {
154 topic: topic.to_string(),
155 });
156 self.send(message).await?;
157 Ok(())
158 }
159
160 pub async fn unsubscribe(&mut self, topic: &str) -> Result<(), AmqError> {
166 let message = Message::ReqUnsubscribeTopic(ReqMsgUnsubscriber {
167 topic: topic.to_string(),
168 });
169 self.send(message).await?;
170 self.on_subscribes.write().await.remove(topic);
171 Ok(())
172 }
173
174 pub async fn publish(&mut self, topic: &str, content: Vec<u8>) -> Result<(), AmqError> {
180 let message = Message::ReqPublish(ReqMsgPublish {
181 topic: topic.to_string(),
182 message: content,
183 });
184 self.send(message).await?;
185 Ok(())
186 }
187
188 pub async fn consume<F, Fut>(&mut self, topic: &str, f: F) -> Result<(), AmqError>
196 where
197 F: Fn(Arc<T>, Vec<u8>) -> Fut + Send + Sync + 'static,
198 Fut: Future<Output = ()> + Send + 'static,
199 {
200 let handler: OnRecvFn<T> =
201 Arc::new(move |state: Arc<T>, data: Vec<u8>| Box::pin(f(state, data)));
202 self.on_consumes
203 .write()
204 .await
205 .insert(topic.to_string(), handler);
206 let message = Message::ReqConsumerTopic(ReqMsgConsumerTopic {
207 topic: topic.to_string(),
208 });
209 self.send(message).await?;
210 Ok(())
211 }
212
213 pub async fn unconsume(&mut self, topic: &str) -> Result<(), AmqError> {
218 let message = Message::ReqUnconsumerTopic(ReqMsgUnconsumerTopic {
219 topic: topic.to_string(),
220 });
221 self.send(message).await?;
222 self.on_consumes.write().await.remove(topic);
223 Ok(())
224 }
225
226 pub async fn ack(&mut self, message_id: u64) -> Result<(), AmqError> {
232 let message = Message::ReqConsumeAck(ReqMsgConsumeAck { id: message_id });
233 self.send(message).await?;
234 Ok(())
235 }
236
237 pub async fn ack_multi(&mut self, message_ids: Vec<u64>) -> Result<(), AmqError> {
243 let message = Message::ReqConsumeAckMulti(ReqMsgConsumeAckMulti { ids: message_ids });
244 self.send(message).await?;
245 Ok(())
246 }
247
248 pub async fn connect(&mut self) -> Result<Receiver<AmqError>, Box<dyn Error>> {
250 if self.config.path.is_empty() {
251 let addr = self.config.get_address();
252 let stream = TcpStream::connect(addr).await?;
253 let (reader, writer) = stream.into_split();
254 self.tcp_conn(reader, writer).await
255 } else {
256 #[cfg(unix)]
257 {
258 let stream = UnixStream::connect(&self.config.path).await?;
259 let (reader, writer) = stream.into_split();
260 self.unix_conn(reader, writer).await
261 }
262 #[cfg(not(unix))]
263 {
264 Err(Box::new(AmqError::UnsupportedPlatform))
265 }
266 }
267 }
268
269 async fn tcp_conn(
270 &mut self,
271 mut reader: TcpOHR,
272 mut writer: TcpOWH,
273 ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
274 let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
275 access_key: self.config.access_key.clone(),
276 access_secret: self.config.access_secret.clone(),
277 });
278 let message = &serde_json::to_vec(&msg)?;
279 writer.write_u32(message.len() as u32).await?;
280 writer.write_all(&message).await?;
281
282 let len = reader.read_u32().await?;
284 let mut buf = vec![0; len as usize];
286 let _ = reader.read(&mut buf).await?;
287 match Message::deserialize(&buf)? {
288 Message::RespAuthorizer(resp) => {
289 if resp.status != MsgStatus::Success {
290 return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
291 }
292 }
293 _ => {
294 return Err(Box::new(AmqError::AuthorizationError(
295 "Invalid response".to_string(),
296 )));
297 }
298 }
299
300 self.stream = Some(Owh::Tcp(writer));
301
302 let (tx, rx) = channel::<AmqError>();
303
304 let on_subscribes = self.on_subscribes.clone();
305 let on_consumes = self.on_consumes.clone();
306 let state = Arc::clone(&self.state);
307 let recv_task = spawn(async move {
308 loop {
309 let len = match reader.read_u32().await {
311 Ok(l) => l as usize,
312 Err(e) => {
313 tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
314 break;
315 }
316 };
317 let mut buf = vec![0; len];
319 match reader.read(&mut buf).await {
320 Ok(0) => {
321 tx.send(AmqError::TcpServerClosed).unwrap();
322 break;
323 }
324 Ok(_) => {
325 let msg = match Message::deserialize(&buf) {
326 Ok(m) => m,
327 Err(_) => {
328 continue;
329 }
330 };
331 match &msg {
332 Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
333 if let Some(cb) = on_subscribes.read().await.get(topic) {
334 cb(Arc::clone(&state), message.clone()).await;
335 }
336 }
337 Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
338 if let Some(cb) = on_consumes.read().await.get(topic) {
339 cb(Arc::clone(&state), message.clone()).await;
340 }
341 }
342 _ => {}
343 }
344 }
345 Err(e) => {
346 tx.send(AmqError::TcpServerError(e.to_string())).unwrap();
347 break;
348 }
349 }
350 }
351 });
352
353 self.recv_task = Some(recv_task);
354
355 Ok(rx)
356 }
357
358 #[cfg(unix)]
359 async fn unix_conn(
360 &mut self,
361 mut reader: UnixOHR,
362 mut writer: UnixOWH,
363 ) -> Result<Receiver<AmqError>, Box<dyn Error>> {
364 let msg = Message::ReqAuthorizer(ReqMsgAuthorizer {
365 access_key: self.config.access_key.clone(),
366 access_secret: self.config.access_secret.clone(),
367 });
368 let message = &serde_json::to_vec(&msg)?;
369 writer.write_u32(message.len() as u32).await?;
370 writer.write_all(&message).await?;
371
372 let len = reader.read_u32().await?;
374 let mut buf = vec![0; len as usize];
376 let _ = reader.read(&mut buf).await?;
377 match Message::deserialize(&buf)? {
378 Message::RespAuthorizer(resp) => {
379 if resp.status != MsgStatus::Success {
380 return Err(Box::new(AmqError::AuthorizationError(resp.msg)));
381 }
382 }
383 _ => {
384 return Err(Box::new(AmqError::AuthorizationError(
385 "Invalid response".to_string(),
386 )));
387 }
388 }
389
390 self.stream = Some(Owh::Unix(writer));
391
392 let (tx, rx) = channel::<AmqError>();
393
394 let on_subscribes = self.on_subscribes.clone();
395 let on_consumes = self.on_consumes.clone();
396 let state = Arc::clone(&self.state);
397 let recv_task = spawn(async move {
398 loop {
399 let len = match reader.read_u32().await {
401 Ok(l) => l as usize,
402 Err(e) => {
403 tx.send(AmqError::TcpReceiveError(e.to_string())).unwrap();
404 break;
405 }
406 };
407 let mut buf = vec![0; len];
409 match reader.read(&mut buf).await {
410 Ok(0) => {
411 tx.send(AmqError::TcpServerClosed).unwrap();
412 break;
413 }
414 Ok(_) => {
415 let msg = match Message::deserialize(&buf) {
416 Ok(m) => m,
417 Err(_) => {
418 continue;
419 }
420 };
421 match &msg {
422 Message::RespSubscribe(RespMsgSubscribe { topic, message, .. }) => {
423 if let Some(cb) = on_subscribes.read().await.get(topic) {
424 cb(Arc::clone(&state), message.clone()).await;
425 }
426 }
427 Message::RespConsume(RespMsgConsume { topic, message, .. }) => {
428 if let Some(cb) = on_consumes.read().await.get(topic) {
429 cb(Arc::clone(&state), message.clone()).await;
430 }
431 }
432 _ => {}
433 }
434 }
435 Err(e) => {
436 tx.send(AmqError::TcpServerError(e.to_string())).unwrap();
437 break;
438 }
439 }
440 }
441 });
442
443 self.recv_task = Some(recv_task);
444
445 Ok(rx)
446 }
447
448 pub async fn shutdown(&mut self) {
450 if let Some(task) = self.recv_task.take() {
451 task.abort();
452 }
453
454 self.stream = None;
455 }
456
457 async fn send(&mut self, msg: Message) -> Result<(), AmqError> {
458 match &mut self.stream {
459 Some(Owh::Tcp(writer)) => {
460 let message = &serde_json::to_vec(&msg)
461 .map_err(|e| AmqError::TcpSendDataError(e.to_string()))?;
462 let _ = writer.write_u32(message.len() as u32).await;
463 writer
464 .write_all(&message)
465 .await
466 .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
467 }
468 #[cfg(unix)]
469 Some(Owh::Unix(writer)) => {
470 let message = &serde_json::to_vec(&msg)
471 .map_err(|e| AmqError::TcpSendDataError(e.to_string()))?;
472 let _ = writer.write_u32(message.len() as u32).await;
473 writer
474 .write_all(&message)
475 .await
476 .map_err(|e| AmqError::TcpReceiveError(e.to_string()))?;
477 }
478 None => {
479 return Err(AmqError::TcpSendError("not connected".to_string()));
480 }
481 }
482 Ok(())
483 }
484}