1use std::{
2 collections::HashMap,
3 io,
4 pin::Pin,
5 sync::{atomic::AtomicU64, Arc},
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9use std::{future::Future, sync::atomic::Ordering};
10
11use futures::{
12 stream::{SplitSink, SplitStream},
13 FutureExt, Stream, StreamExt, TryFutureExt,
14};
15use pin_project::pin_project;
16use rabbitmq_stream_protocol::commands::exchange_command_versions::{
17 ExchangeCommandVersionsRequest, ExchangeCommandVersionsResponse,
18};
19use tokio::io::AsyncRead;
20use tokio::io::AsyncWrite;
21use tokio::io::ReadBuf;
22use tokio::sync::RwLock;
23use tokio::{net::TcpStream, sync::Notify};
24use tokio_rustls::client::TlsStream;
25
26use tokio_util::codec::Framed;
27use tracing::{trace, warn};
28
29use crate::{error::ClientError, RabbitMQStreamResult};
30pub use message::ClientMessage;
31pub use metadata::{Broker, StreamMetadata};
32pub use metrics::MetricsCollector;
33pub use options::{ClientOptions, TlsConfiguration, TlsConfigurationBuilder};
34use rabbitmq_stream_protocol::{
35 commands::{
36 close::{CloseRequest, CloseResponse},
37 consumer_update_request::ConsumerUpdateRequestCommand,
38 create_stream::CreateStreamCommand,
39 create_super_stream::CreateSuperStreamCommand,
40 credit::CreditCommand,
41 declare_publisher::DeclarePublisherCommand,
42 delete::Delete,
43 delete_publisher::DeletePublisherCommand,
44 delete_super_stream::DeleteSuperStreamCommand,
45 generic::GenericResponse,
46 heart_beat::HeartBeatCommand,
47 metadata::MetadataCommand,
48 open::{OpenCommand, OpenResponse},
49 peer_properties::{PeerPropertiesCommand, PeerPropertiesResponse},
50 publish::PublishCommand,
51 query_offset::{QueryOffsetRequest, QueryOffsetResponse},
52 query_publisher_sequence::{QueryPublisherRequest, QueryPublisherResponse},
53 sasl_authenticate::SaslAuthenticateCommand,
54 sasl_handshake::{SaslHandshakeCommand, SaslHandshakeResponse},
55 store_offset::StoreOffset,
56 subscribe::{OffsetSpecification, SubscribeCommand},
57 superstream_partitions::SuperStreamPartitionsRequest,
58 superstream_partitions::SuperStreamPartitionsResponse,
59 superstream_route::SuperStreamRouteRequest,
60 superstream_route::SuperStreamRouteResponse,
61 tune::TunesCommand,
62 unsubscribe::UnSubscribeCommand,
63 },
64 types::PublishedMessage,
65 FromResponse, Request, Response, ResponseCode, ResponseKind,
66};
67
68pub use self::handler::{MessageHandler, MessageResult};
69use self::{
70 channel::{channel, ChannelReceiver, ChannelSender},
71 codec::RabbitMqStreamCodec,
72 dispatcher::Dispatcher,
73 message::BaseMessage,
74};
75
76mod channel;
77mod codec;
78mod dispatcher;
79mod handler;
80mod message;
81mod metadata;
82mod metrics;
83mod options;
84mod task;
85
86#[pin_project(project = StreamProj)]
87#[derive(Debug)]
88pub enum GenericTcpStream {
89 Tcp(#[pin] TcpStream),
90 SecureTcp(#[pin] Box<TlsStream<TcpStream>>),
91}
92
93impl AsyncRead for GenericTcpStream {
94 fn poll_read(
95 self: Pin<&mut Self>,
96 cx: &mut Context<'_>,
97 buf: &mut ReadBuf<'_>,
98 ) -> Poll<io::Result<()>> {
99 match self.project() {
100 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_read(cx, buf),
101 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_read(cx, buf),
102 }
103 }
104}
105
106impl AsyncWrite for GenericTcpStream {
107 fn poll_write(
108 self: Pin<&mut Self>,
109 cx: &mut Context<'_>,
110 buf: &[u8],
111 ) -> Poll<io::Result<usize>> {
112 match self.project() {
113 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_write(cx, buf),
114 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_write(cx, buf),
115 }
116 }
117
118 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119 match self.project() {
120 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_flush(cx),
121 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_flush(cx),
122 }
123 }
124
125 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
126 match self.project() {
127 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_shutdown(cx),
128 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_shutdown(cx),
129 }
130 }
131}
132
133type SinkConnection = SplitSink<Framed<GenericTcpStream, RabbitMqStreamCodec>, Request>;
134type StreamConnection = SplitStream<Framed<GenericTcpStream, RabbitMqStreamCodec>>;
135
136pub struct ClientState {
137 server_properties: HashMap<String, String>,
138 connection_properties: HashMap<String, String>,
139 handler: Option<Arc<dyn MessageHandler>>,
140 heartbeat: u32,
141 max_frame_size: u32,
142 last_heatbeat: Instant,
143 heartbeat_task: Option<task::TaskHandle>,
144 last_received_message: Arc<RwLock<Instant>>,
145}
146
147#[derive(Clone)]
151pub struct Client {
152 dispatcher: Dispatcher<Client>,
153 channel: Arc<ChannelSender<SinkConnection>>,
154 state: Arc<RwLock<ClientState>>,
155 opts: ClientOptions,
156 tune_notifier: Arc<Notify>,
157 publish_sequence: Arc<AtomicU64>,
158 filtering_supported: bool,
159 client_properties: HashMap<String, String>,
160}
161
162impl Client {
163 pub async fn connect(opts: impl Into<ClientOptions>) -> Result<Client, ClientError> {
164 let broker = opts.into();
165
166 let (sender, receiver) = Client::create_connection(&broker).await?;
167
168 let last_received_message = Arc::new(RwLock::new(Instant::now()));
169
170 let dispatcher = Dispatcher::new();
171 let state = ClientState {
172 server_properties: HashMap::new(),
173 connection_properties: HashMap::new(),
174 handler: None,
175 heartbeat: broker.heartbeat,
176 max_frame_size: broker.max_frame_size,
177 last_heatbeat: Instant::now(),
178 heartbeat_task: None,
179 last_received_message: last_received_message.clone(),
180 };
181 let mut client = Client {
182 dispatcher,
183 opts: broker,
184 channel: Arc::new(sender),
185 state: Arc::new(RwLock::new(state)),
186 tune_notifier: Arc::new(Notify::new()),
187 publish_sequence: Arc::new(AtomicU64::new(1)),
188 filtering_supported: false,
189 client_properties: HashMap::new(),
190 };
191
192 const VERSION: &str = env!("CARGO_PKG_VERSION");
193
194 client
195 .client_properties
196 .insert(String::from("product"), String::from("RabbitMQ"));
197 client
198 .client_properties
199 .insert(String::from("version"), String::from(VERSION));
200 client
201 .client_properties
202 .insert(String::from("platform"), String::from("Rust"));
203 client.client_properties.insert(
204 String::from("copyright"),
205 String::from("Copyright (c) 2017-2023 Broadcom. All Rights Reserved. The term Broadcom refers to Broadcom Inc. and/or its subsidiaries."));
206 client.client_properties.insert(
207 String::from("information"),
208 String::from(
209 "Licensed under the Apache 2.0 and MPL 2.0 licenses. See https://www.rabbitmq.com/",
210 ),
211 );
212 client.client_properties.insert(
213 String::from("connection_name"),
214 client.opts.client_provided_name.clone(),
215 );
216
217 client.initialize(receiver).await?;
218
219 let command_versions = client.exchange_command_versions().await?;
220 let (_, max_version) = command_versions.key_version(2);
221 if max_version >= 2 {
222 client.filtering_supported = true
223 }
224 Ok(client)
225 }
226
227 pub async fn server_properties(&self) -> HashMap<String, String> {
229 self.state.read().await.server_properties.clone()
230 }
231
232 pub async fn connection_properties(&self) -> HashMap<String, String> {
234 self.state.read().await.connection_properties.clone()
235 }
236
237 pub async fn set_handler<H: MessageHandler>(&self, handler: H) {
238 let mut state = self.state.write().await;
239
240 state.handler = Some(Arc::new(handler));
241 }
242
243 pub fn is_closed(&self) -> bool {
244 self.channel.is_closed()
245 }
246
247 pub(crate) fn handler_failed_flag(&self) -> Arc<std::sync::atomic::AtomicBool> {
248 self.dispatcher.handler_failed_flag()
249 }
250
251 pub async fn close(&self) -> RabbitMQStreamResult<()> {
252 if self.is_closed() {
253 return Err(ClientError::AlreadyClosed);
254 }
255 let _: CloseResponse = self
256 .send_and_receive(|correlation_id| {
257 CloseRequest::new(correlation_id, ResponseCode::Ok, "Ok".to_owned())
258 })
259 .await?;
260
261 let mut state = self.state.write().await;
262 state.heartbeat_task.take();
264 drop(state);
265
266 self.force_drop_connection().await
267 }
268
269 async fn force_drop_connection(&self) -> RabbitMQStreamResult<()> {
270 self.channel.close().await
271 }
272
273 pub async fn subscribe(
274 &self,
275 subscription_id: u8,
276 stream: &str,
277 offset_specification: OffsetSpecification,
278 credit: u16,
279 properties: HashMap<String, String>,
280 ) -> RabbitMQStreamResult<GenericResponse> {
281 self.send_and_receive(|correlation_id| {
282 SubscribeCommand::new(
283 correlation_id,
284 subscription_id,
285 stream.to_owned(),
286 offset_specification,
287 credit,
288 properties,
289 )
290 })
291 .await
292 }
293
294 pub async fn unsubscribe(&self, subscription_id: u8) -> RabbitMQStreamResult<GenericResponse> {
295 self.send_and_receive(|correlation_id| {
296 UnSubscribeCommand::new(correlation_id, subscription_id)
297 })
298 .await
299 }
300
301 pub async fn partitions(
302 &self,
303 super_stream: String,
304 ) -> RabbitMQStreamResult<SuperStreamPartitionsResponse> {
305 self.send_and_receive(|correlation_id| {
306 SuperStreamPartitionsRequest::new(correlation_id, super_stream)
307 })
308 .await
309 }
310
311 pub async fn route(
312 &self,
313 routing_key: String,
314 super_stream: String,
315 ) -> RabbitMQStreamResult<SuperStreamRouteResponse> {
316 self.send_and_receive(|correlation_id| {
317 SuperStreamRouteRequest::new(correlation_id, routing_key, super_stream)
318 })
319 .await
320 }
321
322 pub async fn create_stream(
323 &self,
324 stream: &str,
325 options: HashMap<String, String>,
326 ) -> RabbitMQStreamResult<GenericResponse> {
327 self.send_and_receive(|correlation_id| {
328 CreateStreamCommand::new(correlation_id, stream.to_owned(), options)
329 })
330 .await
331 }
332
333 pub async fn create_super_stream(
334 &self,
335 super_stream: &str,
336 partitions: Vec<String>,
337 binding_keys: Vec<String>,
338 options: HashMap<String, String>,
339 ) -> RabbitMQStreamResult<GenericResponse> {
340 self.send_and_receive(|correlation_id| {
341 CreateSuperStreamCommand::new(
342 correlation_id,
343 super_stream.to_owned(),
344 partitions,
345 binding_keys,
346 options,
347 )
348 })
349 .await
350 }
351
352 pub async fn delete_stream(&self, stream: &str) -> RabbitMQStreamResult<GenericResponse> {
353 self.send_and_receive(|correlation_id| Delete::new(correlation_id, stream.to_owned()))
354 .await
355 }
356
357 pub async fn delete_super_stream(
358 &self,
359 super_stream: &str,
360 ) -> RabbitMQStreamResult<GenericResponse> {
361 self.send_and_receive(|correlation_id| {
362 DeleteSuperStreamCommand::new(correlation_id, super_stream.to_owned())
363 })
364 .await
365 }
366
367 pub async fn credit(&self, subscription_id: u8, credit: u16) -> RabbitMQStreamResult<()> {
368 self.send(CreditCommand::new(subscription_id, credit)).await
369 }
370
371 pub async fn metadata(
372 &self,
373 streams: Vec<String>,
374 ) -> RabbitMQStreamResult<HashMap<String, StreamMetadata>> {
375 self.send_and_receive(|correlation_id| MetadataCommand::new(correlation_id, streams))
376 .await
377 .map(metadata::from_response)
378 }
379
380 pub async fn store_offset(
381 &self,
382 reference: &str,
383 stream: &str,
384 offset: u64,
385 ) -> RabbitMQStreamResult<()> {
386 self.send(StoreOffset::new(
387 reference.to_owned(),
388 stream.to_owned(),
389 offset,
390 ))
391 .await
392 }
393
394 pub async fn query_offset(&self, reference: String, stream: &str) -> Result<u64, ClientError> {
395 let response = self
396 .send_and_receive::<QueryOffsetResponse, _, _>(|correlation_id| {
397 QueryOffsetRequest::new(correlation_id, reference, stream.to_owned())
398 })
399 .await?;
400
401 if !response.is_ok() {
402 Err(ClientError::RequestError(response.code().clone()))
403 } else {
404 Ok(response.from_response())
405 }
406 }
407
408 pub async fn declare_publisher(
409 &self,
410 publisher_id: u8,
411 publisher_reference: Option<String>,
412 stream: &str,
413 ) -> RabbitMQStreamResult<GenericResponse> {
414 self.send_and_receive(|correlation_id| {
415 DeclarePublisherCommand::new(
416 correlation_id,
417 publisher_id,
418 publisher_reference,
419 stream.to_owned(),
420 )
421 })
422 .await
423 }
424
425 pub async fn delete_publisher(
426 &self,
427 publisher_id: u8,
428 ) -> RabbitMQStreamResult<GenericResponse> {
429 self.send_and_receive(|correlation_id| {
430 DeletePublisherCommand::new(correlation_id, publisher_id)
431 })
432 .await
433 }
434
435 pub async fn publish<T: BaseMessage>(
436 &self,
437 publisher_id: u8,
438 messages: impl Into<Vec<T>>,
439 version: u16,
440 ) -> RabbitMQStreamResult<Vec<u64>> {
441 let messages: Vec<PublishedMessage> = messages
442 .into()
443 .into_iter()
444 .map(|message| {
445 let publishing_id: u64 = message
446 .publishing_id()
447 .unwrap_or_else(|| self.publish_sequence.fetch_add(1, Ordering::Relaxed));
448 let filter_value = message.filter_value();
449 PublishedMessage::new(publishing_id, message.to_message(), filter_value)
450 })
451 .collect();
452 let sequences = messages
453 .iter()
454 .map(rabbitmq_stream_protocol::types::PublishedMessage::publishing_id)
455 .collect();
456 let len = messages.len();
457
458 self.send(PublishCommand::new(publisher_id, messages, version))
460 .await?;
461
462 self.opts.collector.publish(len as u64).await;
463
464 Ok(sequences)
465 }
466
467 pub async fn query_publisher_sequence(
468 &self,
469 reference: &str,
470 stream: &str,
471 ) -> Result<u64, ClientError> {
472 self.send_and_receive::<QueryPublisherResponse, _, _>(|correlation_id| {
473 QueryPublisherRequest::new(correlation_id, reference.to_owned(), stream.to_owned())
474 })
475 .await
476 .map(|sequence| sequence.from_response())
477 }
478
479 pub async fn exchange_command_versions(
480 &self,
481 ) -> RabbitMQStreamResult<ExchangeCommandVersionsResponse> {
482 self.send_and_receive::<ExchangeCommandVersionsResponse, _, _>(|correlation_id| {
483 ExchangeCommandVersionsRequest::new(correlation_id, vec![])
484 })
485 .await
486 }
487
488 pub fn filtering_supported(&self) -> bool {
489 self.filtering_supported
490 }
491
492 pub async fn set_heartbeat(&self, heartbeat: u32) {
494 let mut state = self.state.write().await;
495 state.heartbeat = heartbeat;
496 state.heartbeat_task =
498 self.start_hearbeat_task(heartbeat, state.last_received_message.clone());
499 }
500
501 async fn create_connection(
502 broker: &ClientOptions,
503 ) -> Result<
504 (
505 ChannelSender<SinkConnection>,
506 ChannelReceiver<StreamConnection>,
507 ),
508 ClientError,
509 > {
510 let stream = broker.build_generic_tcp_stream().await?;
511 let stream = Framed::new(stream, RabbitMqStreamCodec {});
512
513 let (sink, stream) = stream.split();
514 let (tx, rx) = channel(sink, stream);
515
516 Ok((tx, rx))
517 }
518
519 async fn initialize<T>(&mut self, receiver: ChannelReceiver<T>) -> Result<(), ClientError>
520 where
521 T: Stream<Item = Result<Response, ClientError>> + Unpin + Send,
522 T: 'static,
523 {
524 self.dispatcher.set_handler(self.clone()).await;
525 self.dispatcher.start(receiver).await;
526
527 self.with_state_lock(self.peer_properties(), move |state, server_properties| {
528 state.server_properties = server_properties;
529 })
530 .await?;
531 self.authenticate().await?;
532
533 self.wait_for_tune_data().await?;
534
535 self.with_state_lock(self.open(), |state, connection_properties| {
536 state.connection_properties = connection_properties;
537 })
538 .await?;
539
540 let mut state = self.state.write().await;
542 state.heartbeat_task =
543 self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
544 drop(state);
545
546 Ok(())
547 }
548
549 async fn with_state_lock<T>(
550 &self,
551 task: impl Future<Output = RabbitMQStreamResult<T>>,
552 mut updater: impl FnMut(&mut ClientState, T),
553 ) -> RabbitMQStreamResult<()> {
554 let result = task.await?;
555
556 let mut state = self.state.write().await;
557
558 updater(&mut state, result);
559
560 Ok(())
561 }
562
563 fn negotiate_value(&self, client: u32, server: u32) -> u32 {
564 match (client, server) {
565 (client, server) if client == 0 || server == 0 => client.max(server),
566 (client, server) => client.min(server),
567 }
568 }
569
570 async fn wait_for_tune_data(&mut self) -> Result<(), ClientError> {
571 self.tune_notifier.notified().await;
572 Ok(())
573 }
574
575 async fn authenticate(&self) -> Result<(), ClientError> {
576 self.sasl_mechanism()
577 .and_then(|mechanisms| self.handle_authentication(mechanisms))
578 .await
579 }
580
581 async fn handle_authentication(&self, _mechanism: Vec<String>) -> Result<(), ClientError> {
582 let auth_data = format!("\u{0000}{}\u{0000}{}", self.opts.user, self.opts.password);
583
584 let response = self
585 .send_and_receive::<GenericResponse, _, _>(|correlation_id| {
586 SaslAuthenticateCommand::new(
587 correlation_id,
588 "PLAIN".to_owned(),
589 auth_data.as_bytes().to_vec(),
590 )
591 })
592 .await?;
593
594 if response.is_ok() {
595 Ok(())
596 } else {
597 Err(ClientError::RequestError(response.code().clone()))
598 }
599 }
600
601 async fn sasl_mechanism(&self) -> Result<Vec<String>, ClientError> {
602 self.send_and_receive::<SaslHandshakeResponse, _, _>(|correlation_id| {
603 SaslHandshakeCommand::new(correlation_id)
604 })
605 .await
606 .map(|handshake| handshake.mechanisms)
607 }
608
609 async fn send_and_receive<T, R, M>(&self, msg_factory: M) -> Result<T, ClientError>
610 where
611 R: Into<Request>,
612 T: FromResponse,
613 M: FnOnce(u32) -> R,
614 {
615 let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel() else {
616 trace!("Connection is closed here");
617 return Err(ClientError::ConnectionClosed);
618 };
619
620 self.channel
621 .send(msg_factory(correlation_id).into())
622 .await?;
623
624 let response = receiver.recv().await.ok_or(ClientError::ConnectionClosed)?;
625
626 self.handle_response::<T>(response).await
627 }
628
629 async fn send<R>(&self, msg: R) -> Result<(), ClientError>
630 where
631 R: Into<Request>,
632 {
633 self.channel.send(msg.into()).await?;
634 Ok(())
635 }
636
637 async fn handle_response<T: FromResponse>(&self, response: Response) -> Result<T, ClientError> {
638 response.get::<T>().ok_or_else(|| {
639 ClientError::CastError(format!(
640 "Cannot cast response to {}",
641 std::any::type_name::<T>()
642 ))
643 })
644 }
645
646 async fn open(&self) -> Result<HashMap<String, String>, ClientError> {
647 self.send_and_receive::<OpenResponse, _, _>(|correlation_id| {
648 OpenCommand::new(correlation_id, self.opts.v_host.clone())
649 })
650 .await
651 .and_then(|open| {
652 if open.is_ok() {
653 Ok(open.connection_properties)
654 } else {
655 Err(ClientError::RequestError(open.code().clone()))
656 }
657 })
658 }
659
660 async fn peer_properties(&self) -> Result<HashMap<String, String>, ClientError> {
661 self.send_and_receive::<PeerPropertiesResponse, _, _>(|correlation_id| {
662 PeerPropertiesCommand::new(correlation_id, self.client_properties.clone())
663 })
664 .await
665 .map(|peer_properties| peer_properties.server_properties)
666 }
667
668 async fn handle_tune_command(&self, tunes: &TunesCommand) {
669 let mut state = self.state.write().await;
670 state.heartbeat = self.negotiate_value(self.opts.heartbeat, tunes.heartbeat);
671 state.max_frame_size = self.negotiate_value(self.opts.max_frame_size, tunes.max_frame_size);
672
673 let heart_beat = state.heartbeat;
674 let max_frame_size = state.max_frame_size;
675
676 trace!(
677 "Handling tune with frame size {} and heartbeat {}",
678 max_frame_size,
679 heart_beat
680 );
681
682 if state.heartbeat_task.take().is_some() {
683 state.heartbeat_task =
685 self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
686 }
687
688 drop(state);
689
690 let _ = self
691 .channel
692 .send(TunesCommand::new(max_frame_size, heart_beat).into())
693 .await;
694
695 self.tune_notifier.notify_one();
696 }
697
698 fn start_hearbeat_task(
699 &self,
700 heartbeat: u32,
701 last_received_message: Arc<RwLock<Instant>>,
702 ) -> Option<task::TaskHandle> {
703 if heartbeat == 0 {
704 return None;
705 }
706 let heartbeat_interval = (heartbeat / 2).max(1);
707 let channel = self.channel.clone();
708
709 let client = self.clone();
710
711 let heartbeat_task: task::TaskHandle = tokio::spawn(async move {
712 let timeout_threashold = u64::from(heartbeat * 4);
713
714 loop {
715 trace!("Sending heartbeat");
716 if channel
717 .send(HeartBeatCommand::default().into())
718 .await
719 .is_err()
720 {
721 break;
722 }
723 tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
724
725 let now = Instant::now();
726 let last_message = last_received_message.read().await;
727 if now.duration_since(*last_message) >= Duration::from_secs(timeout_threashold) {
728 warn!("Heartbeat timeout reached. Force closing connection.");
729 if !client.is_closed() {
730 if let Err(e) = client.close().await {
731 warn!("Error closing client: {}", e);
732 }
733 }
734 break;
735 }
736 }
737
738 warn!("Heartbeat task stopped. Force closing connection");
739 })
740 .into();
741
742 Some(heartbeat_task)
743 }
744
745 async fn handle_heart_beat_command(&self) {
746 trace!("Received heartbeat");
747 let mut state = self.state.write().await;
748 state.last_heatbeat = Instant::now();
749 }
750
751 pub async fn consumer_update(
752 &self,
753 correlation_id: u32,
754 offset_specification: OffsetSpecification,
755 ) -> RabbitMQStreamResult<GenericResponse> {
756 self.send_and_receive(|_| {
757 ConsumerUpdateRequestCommand::new(correlation_id, 1, offset_specification)
758 })
759 .await
760 }
761}
762
763#[async_trait::async_trait]
764impl MessageHandler for Client {
765 async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
766 match &item {
767 Some(Ok(response)) => {
768 {
770 let s = self.state.read().await;
771 let mut last_received_message = s.last_received_message.write().await;
772 *last_received_message = Instant::now();
773 drop(last_received_message);
774 drop(s);
775 }
776
777 match response.kind_ref() {
778 ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await,
779 ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await,
780 _ => {
781 if let Some(handler) = self.state.read().await.handler.as_ref() {
782 let handler = handler.clone();
783
784 tokio::task::spawn(async move {
785 match std::panic::AssertUnwindSafe(handler.handle_message(item))
790 .catch_unwind()
791 .await
792 {
793 Ok(Ok(())) => {}
794 Ok(Err(err)) => {
795 warn!("Message handler returned error: {}", err);
796 }
797 Err(panic) => {
798 tracing::error!("Message handler panicked: {:?}", panic);
799 }
800 }
801 });
802 }
803 }
804 }
805 }
806 Some(Err(err)) => {
807 trace!(?err);
808 if let Some(handler) = self.state.read().await.handler.as_ref() {
809 let handler = handler.clone();
810
811 tokio::task::spawn(async move { handler.handle_message(item).await });
812 }
813 }
814 None => {
815 trace!("Closing client");
816 if let Some(handler) = self.state.read().await.handler.as_ref() {
817 let handler = handler.clone();
818 tokio::task::spawn(async move { handler.handle_message(None).await });
819 }
820 }
821 }
822
823 Ok(())
824 }
825}