Skip to main content

rabbitmq_stream_client/client/
mod.rs

1use std::{
2    collections::HashMap,
3    io,
4    pin::Pin,
5    sync::{atomic::AtomicU64, Arc},
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9use std::{future::Future, sync::atomic::Ordering};
10
11use futures::{
12    stream::{SplitSink, SplitStream},
13    FutureExt, Stream, StreamExt, TryFutureExt,
14};
15use pin_project::pin_project;
16use rabbitmq_stream_protocol::commands::exchange_command_versions::{
17    ExchangeCommandVersionsRequest, ExchangeCommandVersionsResponse,
18};
19use tokio::io::AsyncRead;
20use tokio::io::AsyncWrite;
21use tokio::io::ReadBuf;
22use tokio::sync::RwLock;
23use tokio::{net::TcpStream, sync::Notify};
24use tokio_rustls::client::TlsStream;
25
26use tokio_util::codec::Framed;
27use tracing::{trace, warn};
28
29use crate::{error::ClientError, RabbitMQStreamResult};
30pub use message::ClientMessage;
31pub use metadata::{Broker, StreamMetadata};
32pub use metrics::MetricsCollector;
33pub use options::{ClientOptions, TlsConfiguration, TlsConfigurationBuilder};
34use rabbitmq_stream_protocol::{
35    commands::{
36        close::{CloseRequest, CloseResponse},
37        consumer_update_request::ConsumerUpdateRequestCommand,
38        create_stream::CreateStreamCommand,
39        create_super_stream::CreateSuperStreamCommand,
40        credit::CreditCommand,
41        declare_publisher::DeclarePublisherCommand,
42        delete::Delete,
43        delete_publisher::DeletePublisherCommand,
44        delete_super_stream::DeleteSuperStreamCommand,
45        generic::GenericResponse,
46        heart_beat::HeartBeatCommand,
47        metadata::MetadataCommand,
48        open::{OpenCommand, OpenResponse},
49        peer_properties::{PeerPropertiesCommand, PeerPropertiesResponse},
50        publish::PublishCommand,
51        query_offset::{QueryOffsetRequest, QueryOffsetResponse},
52        query_publisher_sequence::{QueryPublisherRequest, QueryPublisherResponse},
53        sasl_authenticate::SaslAuthenticateCommand,
54        sasl_handshake::{SaslHandshakeCommand, SaslHandshakeResponse},
55        store_offset::StoreOffset,
56        subscribe::{OffsetSpecification, SubscribeCommand},
57        superstream_partitions::SuperStreamPartitionsRequest,
58        superstream_partitions::SuperStreamPartitionsResponse,
59        superstream_route::SuperStreamRouteRequest,
60        superstream_route::SuperStreamRouteResponse,
61        tune::TunesCommand,
62        unsubscribe::UnSubscribeCommand,
63    },
64    types::PublishedMessage,
65    FromResponse, Request, Response, ResponseCode, ResponseKind,
66};
67
68pub use self::handler::{MessageHandler, MessageResult};
69use self::{
70    channel::{channel, ChannelReceiver, ChannelSender},
71    codec::RabbitMqStreamCodec,
72    dispatcher::Dispatcher,
73    message::BaseMessage,
74};
75
76mod channel;
77mod codec;
78mod dispatcher;
79mod handler;
80mod message;
81mod metadata;
82mod metrics;
83mod options;
84mod task;
85
86#[pin_project(project = StreamProj)]
87#[derive(Debug)]
88pub enum GenericTcpStream {
89    Tcp(#[pin] TcpStream),
90    SecureTcp(#[pin] Box<TlsStream<TcpStream>>),
91}
92
93impl AsyncRead for GenericTcpStream {
94    fn poll_read(
95        self: Pin<&mut Self>,
96        cx: &mut Context<'_>,
97        buf: &mut ReadBuf<'_>,
98    ) -> Poll<io::Result<()>> {
99        match self.project() {
100            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_read(cx, buf),
101            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_read(cx, buf),
102        }
103    }
104}
105
106impl AsyncWrite for GenericTcpStream {
107    fn poll_write(
108        self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        buf: &[u8],
111    ) -> Poll<io::Result<usize>> {
112        match self.project() {
113            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_write(cx, buf),
114            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_write(cx, buf),
115        }
116    }
117
118    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119        match self.project() {
120            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_flush(cx),
121            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_flush(cx),
122        }
123    }
124
125    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
126        match self.project() {
127            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_shutdown(cx),
128            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_shutdown(cx),
129        }
130    }
131}
132
133type SinkConnection = SplitSink<Framed<GenericTcpStream, RabbitMqStreamCodec>, Request>;
134type StreamConnection = SplitStream<Framed<GenericTcpStream, RabbitMqStreamCodec>>;
135
136pub struct ClientState {
137    server_properties: HashMap<String, String>,
138    connection_properties: HashMap<String, String>,
139    handler: Option<Arc<dyn MessageHandler>>,
140    heartbeat: u32,
141    max_frame_size: u32,
142    last_heatbeat: Instant,
143    heartbeat_task: Option<task::TaskHandle>,
144    last_received_message: Arc<RwLock<Instant>>,
145}
146
147/// Raw API for taking to RabbitMQ stream
148///
149/// For high level APIs check [`crate::Environment`]
150#[derive(Clone)]
151pub struct Client {
152    dispatcher: Dispatcher<Client>,
153    channel: Arc<ChannelSender<SinkConnection>>,
154    state: Arc<RwLock<ClientState>>,
155    opts: ClientOptions,
156    tune_notifier: Arc<Notify>,
157    publish_sequence: Arc<AtomicU64>,
158    filtering_supported: bool,
159    client_properties: HashMap<String, String>,
160}
161
162impl Client {
163    pub async fn connect(opts: impl Into<ClientOptions>) -> Result<Client, ClientError> {
164        let broker = opts.into();
165
166        let (sender, receiver) = Client::create_connection(&broker).await?;
167
168        let last_received_message = Arc::new(RwLock::new(Instant::now()));
169
170        let dispatcher = Dispatcher::new();
171        let state = ClientState {
172            server_properties: HashMap::new(),
173            connection_properties: HashMap::new(),
174            handler: None,
175            heartbeat: broker.heartbeat,
176            max_frame_size: broker.max_frame_size,
177            last_heatbeat: Instant::now(),
178            heartbeat_task: None,
179            last_received_message: last_received_message.clone(),
180        };
181        let mut client = Client {
182            dispatcher,
183            opts: broker,
184            channel: Arc::new(sender),
185            state: Arc::new(RwLock::new(state)),
186            tune_notifier: Arc::new(Notify::new()),
187            publish_sequence: Arc::new(AtomicU64::new(1)),
188            filtering_supported: false,
189            client_properties: HashMap::new(),
190        };
191
192        const VERSION: &str = env!("CARGO_PKG_VERSION");
193
194        client
195            .client_properties
196            .insert(String::from("product"), String::from("RabbitMQ"));
197        client
198            .client_properties
199            .insert(String::from("version"), String::from(VERSION));
200        client
201            .client_properties
202            .insert(String::from("platform"), String::from("Rust"));
203        client.client_properties.insert(
204            String::from("copyright"),
205            String::from("Copyright (c) 2017-2023 Broadcom. All Rights Reserved. The term Broadcom refers to Broadcom Inc. and/or its subsidiaries."));
206        client.client_properties.insert(
207            String::from("information"),
208            String::from(
209                "Licensed under the Apache 2.0 and MPL 2.0 licenses. See https://www.rabbitmq.com/",
210            ),
211        );
212        client.client_properties.insert(
213            String::from("connection_name"),
214            client.opts.client_provided_name.clone(),
215        );
216
217        client.initialize(receiver).await?;
218
219        let command_versions = client.exchange_command_versions().await?;
220        let (_, max_version) = command_versions.key_version(2);
221        if max_version >= 2 {
222            client.filtering_supported = true
223        }
224        Ok(client)
225    }
226
227    /// Get client's server properties.
228    pub async fn server_properties(&self) -> HashMap<String, String> {
229        self.state.read().await.server_properties.clone()
230    }
231
232    /// Get client's connection properties.
233    pub async fn connection_properties(&self) -> HashMap<String, String> {
234        self.state.read().await.connection_properties.clone()
235    }
236
237    pub async fn set_handler<H: MessageHandler>(&self, handler: H) {
238        let mut state = self.state.write().await;
239
240        state.handler = Some(Arc::new(handler));
241    }
242
243    pub fn is_closed(&self) -> bool {
244        self.channel.is_closed()
245    }
246
247    pub(crate) fn handler_failed_flag(&self) -> Arc<std::sync::atomic::AtomicBool> {
248        self.dispatcher.handler_failed_flag()
249    }
250
251    pub async fn close(&self) -> RabbitMQStreamResult<()> {
252        if self.is_closed() {
253            return Err(ClientError::AlreadyClosed);
254        }
255        let _: CloseResponse = self
256            .send_and_receive(|correlation_id| {
257                CloseRequest::new(correlation_id, ResponseCode::Ok, "Ok".to_owned())
258            })
259            .await?;
260
261        let mut state = self.state.write().await;
262        // This stop the tokio task that performs heartbeats
263        state.heartbeat_task.take();
264        drop(state);
265
266        self.force_drop_connection().await
267    }
268
269    async fn force_drop_connection(&self) -> RabbitMQStreamResult<()> {
270        self.channel.close().await
271    }
272
273    pub async fn subscribe(
274        &self,
275        subscription_id: u8,
276        stream: &str,
277        offset_specification: OffsetSpecification,
278        credit: u16,
279        properties: HashMap<String, String>,
280    ) -> RabbitMQStreamResult<GenericResponse> {
281        self.send_and_receive(|correlation_id| {
282            SubscribeCommand::new(
283                correlation_id,
284                subscription_id,
285                stream.to_owned(),
286                offset_specification,
287                credit,
288                properties,
289            )
290        })
291        .await
292    }
293
294    pub async fn unsubscribe(&self, subscription_id: u8) -> RabbitMQStreamResult<GenericResponse> {
295        self.send_and_receive(|correlation_id| {
296            UnSubscribeCommand::new(correlation_id, subscription_id)
297        })
298        .await
299    }
300
301    pub async fn partitions(
302        &self,
303        super_stream: String,
304    ) -> RabbitMQStreamResult<SuperStreamPartitionsResponse> {
305        self.send_and_receive(|correlation_id| {
306            SuperStreamPartitionsRequest::new(correlation_id, super_stream)
307        })
308        .await
309    }
310
311    pub async fn route(
312        &self,
313        routing_key: String,
314        super_stream: String,
315    ) -> RabbitMQStreamResult<SuperStreamRouteResponse> {
316        self.send_and_receive(|correlation_id| {
317            SuperStreamRouteRequest::new(correlation_id, routing_key, super_stream)
318        })
319        .await
320    }
321
322    pub async fn create_stream(
323        &self,
324        stream: &str,
325        options: HashMap<String, String>,
326    ) -> RabbitMQStreamResult<GenericResponse> {
327        self.send_and_receive(|correlation_id| {
328            CreateStreamCommand::new(correlation_id, stream.to_owned(), options)
329        })
330        .await
331    }
332
333    pub async fn create_super_stream(
334        &self,
335        super_stream: &str,
336        partitions: Vec<String>,
337        binding_keys: Vec<String>,
338        options: HashMap<String, String>,
339    ) -> RabbitMQStreamResult<GenericResponse> {
340        self.send_and_receive(|correlation_id| {
341            CreateSuperStreamCommand::new(
342                correlation_id,
343                super_stream.to_owned(),
344                partitions,
345                binding_keys,
346                options,
347            )
348        })
349        .await
350    }
351
352    pub async fn delete_stream(&self, stream: &str) -> RabbitMQStreamResult<GenericResponse> {
353        self.send_and_receive(|correlation_id| Delete::new(correlation_id, stream.to_owned()))
354            .await
355    }
356
357    pub async fn delete_super_stream(
358        &self,
359        super_stream: &str,
360    ) -> RabbitMQStreamResult<GenericResponse> {
361        self.send_and_receive(|correlation_id| {
362            DeleteSuperStreamCommand::new(correlation_id, super_stream.to_owned())
363        })
364        .await
365    }
366
367    pub async fn credit(&self, subscription_id: u8, credit: u16) -> RabbitMQStreamResult<()> {
368        self.send(CreditCommand::new(subscription_id, credit)).await
369    }
370
371    pub async fn metadata(
372        &self,
373        streams: Vec<String>,
374    ) -> RabbitMQStreamResult<HashMap<String, StreamMetadata>> {
375        self.send_and_receive(|correlation_id| MetadataCommand::new(correlation_id, streams))
376            .await
377            .map(metadata::from_response)
378    }
379
380    pub async fn store_offset(
381        &self,
382        reference: &str,
383        stream: &str,
384        offset: u64,
385    ) -> RabbitMQStreamResult<()> {
386        self.send(StoreOffset::new(
387            reference.to_owned(),
388            stream.to_owned(),
389            offset,
390        ))
391        .await
392    }
393
394    pub async fn query_offset(&self, reference: String, stream: &str) -> Result<u64, ClientError> {
395        let response = self
396            .send_and_receive::<QueryOffsetResponse, _, _>(|correlation_id| {
397                QueryOffsetRequest::new(correlation_id, reference, stream.to_owned())
398            })
399            .await?;
400
401        if !response.is_ok() {
402            Err(ClientError::RequestError(response.code().clone()))
403        } else {
404            Ok(response.from_response())
405        }
406    }
407
408    pub async fn declare_publisher(
409        &self,
410        publisher_id: u8,
411        publisher_reference: Option<String>,
412        stream: &str,
413    ) -> RabbitMQStreamResult<GenericResponse> {
414        self.send_and_receive(|correlation_id| {
415            DeclarePublisherCommand::new(
416                correlation_id,
417                publisher_id,
418                publisher_reference,
419                stream.to_owned(),
420            )
421        })
422        .await
423    }
424
425    pub async fn delete_publisher(
426        &self,
427        publisher_id: u8,
428    ) -> RabbitMQStreamResult<GenericResponse> {
429        self.send_and_receive(|correlation_id| {
430            DeletePublisherCommand::new(correlation_id, publisher_id)
431        })
432        .await
433    }
434
435    pub async fn publish<T: BaseMessage>(
436        &self,
437        publisher_id: u8,
438        messages: impl Into<Vec<T>>,
439        version: u16,
440    ) -> RabbitMQStreamResult<Vec<u64>> {
441        let messages: Vec<PublishedMessage> = messages
442            .into()
443            .into_iter()
444            .map(|message| {
445                let publishing_id: u64 = message
446                    .publishing_id()
447                    .unwrap_or_else(|| self.publish_sequence.fetch_add(1, Ordering::Relaxed));
448                let filter_value = message.filter_value();
449                PublishedMessage::new(publishing_id, message.to_message(), filter_value)
450            })
451            .collect();
452        let sequences = messages
453            .iter()
454            .map(rabbitmq_stream_protocol::types::PublishedMessage::publishing_id)
455            .collect();
456        let len = messages.len();
457
458        // TODO batch publish with max frame size check
459        self.send(PublishCommand::new(publisher_id, messages, version))
460            .await?;
461
462        self.opts.collector.publish(len as u64).await;
463
464        Ok(sequences)
465    }
466
467    pub async fn query_publisher_sequence(
468        &self,
469        reference: &str,
470        stream: &str,
471    ) -> Result<u64, ClientError> {
472        self.send_and_receive::<QueryPublisherResponse, _, _>(|correlation_id| {
473            QueryPublisherRequest::new(correlation_id, reference.to_owned(), stream.to_owned())
474        })
475        .await
476        .map(|sequence| sequence.from_response())
477    }
478
479    pub async fn exchange_command_versions(
480        &self,
481    ) -> RabbitMQStreamResult<ExchangeCommandVersionsResponse> {
482        self.send_and_receive::<ExchangeCommandVersionsResponse, _, _>(|correlation_id| {
483            ExchangeCommandVersionsRequest::new(correlation_id, vec![])
484        })
485        .await
486    }
487
488    pub fn filtering_supported(&self) -> bool {
489        self.filtering_supported
490    }
491
492    /// Don't use this method in production code.
493    pub async fn set_heartbeat(&self, heartbeat: u32) {
494        let mut state = self.state.write().await;
495        state.heartbeat = heartbeat;
496        // Eventually, this drops the previous heartbeat task
497        state.heartbeat_task =
498            self.start_hearbeat_task(heartbeat, state.last_received_message.clone());
499    }
500
501    async fn create_connection(
502        broker: &ClientOptions,
503    ) -> Result<
504        (
505            ChannelSender<SinkConnection>,
506            ChannelReceiver<StreamConnection>,
507        ),
508        ClientError,
509    > {
510        let stream = broker.build_generic_tcp_stream().await?;
511        let stream = Framed::new(stream, RabbitMqStreamCodec {});
512
513        let (sink, stream) = stream.split();
514        let (tx, rx) = channel(sink, stream);
515
516        Ok((tx, rx))
517    }
518
519    async fn initialize<T>(&mut self, receiver: ChannelReceiver<T>) -> Result<(), ClientError>
520    where
521        T: Stream<Item = Result<Response, ClientError>> + Unpin + Send,
522        T: 'static,
523    {
524        self.dispatcher.set_handler(self.clone()).await;
525        self.dispatcher.start(receiver).await;
526
527        self.with_state_lock(self.peer_properties(), move |state, server_properties| {
528            state.server_properties = server_properties;
529        })
530        .await?;
531        self.authenticate().await?;
532
533        self.wait_for_tune_data().await?;
534
535        self.with_state_lock(self.open(), |state, connection_properties| {
536            state.connection_properties = connection_properties;
537        })
538        .await?;
539
540        // Start heartbeat task after connection is established
541        let mut state = self.state.write().await;
542        state.heartbeat_task =
543            self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
544        drop(state);
545
546        Ok(())
547    }
548
549    async fn with_state_lock<T>(
550        &self,
551        task: impl Future<Output = RabbitMQStreamResult<T>>,
552        mut updater: impl FnMut(&mut ClientState, T),
553    ) -> RabbitMQStreamResult<()> {
554        let result = task.await?;
555
556        let mut state = self.state.write().await;
557
558        updater(&mut state, result);
559
560        Ok(())
561    }
562
563    fn negotiate_value(&self, client: u32, server: u32) -> u32 {
564        match (client, server) {
565            (client, server) if client == 0 || server == 0 => client.max(server),
566            (client, server) => client.min(server),
567        }
568    }
569
570    async fn wait_for_tune_data(&mut self) -> Result<(), ClientError> {
571        self.tune_notifier.notified().await;
572        Ok(())
573    }
574
575    async fn authenticate(&self) -> Result<(), ClientError> {
576        self.sasl_mechanism()
577            .and_then(|mechanisms| self.handle_authentication(mechanisms))
578            .await
579    }
580
581    async fn handle_authentication(&self, _mechanism: Vec<String>) -> Result<(), ClientError> {
582        let auth_data = format!("\u{0000}{}\u{0000}{}", self.opts.user, self.opts.password);
583
584        let response = self
585            .send_and_receive::<GenericResponse, _, _>(|correlation_id| {
586                SaslAuthenticateCommand::new(
587                    correlation_id,
588                    "PLAIN".to_owned(),
589                    auth_data.as_bytes().to_vec(),
590                )
591            })
592            .await?;
593
594        if response.is_ok() {
595            Ok(())
596        } else {
597            Err(ClientError::RequestError(response.code().clone()))
598        }
599    }
600
601    async fn sasl_mechanism(&self) -> Result<Vec<String>, ClientError> {
602        self.send_and_receive::<SaslHandshakeResponse, _, _>(|correlation_id| {
603            SaslHandshakeCommand::new(correlation_id)
604        })
605        .await
606        .map(|handshake| handshake.mechanisms)
607    }
608
609    async fn send_and_receive<T, R, M>(&self, msg_factory: M) -> Result<T, ClientError>
610    where
611        R: Into<Request>,
612        T: FromResponse,
613        M: FnOnce(u32) -> R,
614    {
615        let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel() else {
616            trace!("Connection is closed here");
617            return Err(ClientError::ConnectionClosed);
618        };
619
620        self.channel
621            .send(msg_factory(correlation_id).into())
622            .await?;
623
624        let response = receiver.recv().await.ok_or(ClientError::ConnectionClosed)?;
625
626        self.handle_response::<T>(response).await
627    }
628
629    async fn send<R>(&self, msg: R) -> Result<(), ClientError>
630    where
631        R: Into<Request>,
632    {
633        self.channel.send(msg.into()).await?;
634        Ok(())
635    }
636
637    async fn handle_response<T: FromResponse>(&self, response: Response) -> Result<T, ClientError> {
638        response.get::<T>().ok_or_else(|| {
639            ClientError::CastError(format!(
640                "Cannot cast response to {}",
641                std::any::type_name::<T>()
642            ))
643        })
644    }
645
646    async fn open(&self) -> Result<HashMap<String, String>, ClientError> {
647        self.send_and_receive::<OpenResponse, _, _>(|correlation_id| {
648            OpenCommand::new(correlation_id, self.opts.v_host.clone())
649        })
650        .await
651        .and_then(|open| {
652            if open.is_ok() {
653                Ok(open.connection_properties)
654            } else {
655                Err(ClientError::RequestError(open.code().clone()))
656            }
657        })
658    }
659
660    async fn peer_properties(&self) -> Result<HashMap<String, String>, ClientError> {
661        self.send_and_receive::<PeerPropertiesResponse, _, _>(|correlation_id| {
662            PeerPropertiesCommand::new(correlation_id, self.client_properties.clone())
663        })
664        .await
665        .map(|peer_properties| peer_properties.server_properties)
666    }
667
668    async fn handle_tune_command(&self, tunes: &TunesCommand) {
669        let mut state = self.state.write().await;
670        state.heartbeat = self.negotiate_value(self.opts.heartbeat, tunes.heartbeat);
671        state.max_frame_size = self.negotiate_value(self.opts.max_frame_size, tunes.max_frame_size);
672
673        let heart_beat = state.heartbeat;
674        let max_frame_size = state.max_frame_size;
675
676        trace!(
677            "Handling tune with frame size {} and heartbeat {}",
678            max_frame_size,
679            heart_beat
680        );
681
682        if state.heartbeat_task.take().is_some() {
683            // Start heartbeat task after connection is established
684            state.heartbeat_task =
685                self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
686        }
687
688        drop(state);
689
690        let _ = self
691            .channel
692            .send(TunesCommand::new(max_frame_size, heart_beat).into())
693            .await;
694
695        self.tune_notifier.notify_one();
696    }
697
698    fn start_hearbeat_task(
699        &self,
700        heartbeat: u32,
701        last_received_message: Arc<RwLock<Instant>>,
702    ) -> Option<task::TaskHandle> {
703        if heartbeat == 0 {
704            return None;
705        }
706        let heartbeat_interval = (heartbeat / 2).max(1);
707        let channel = self.channel.clone();
708
709        let client = self.clone();
710
711        let heartbeat_task: task::TaskHandle = tokio::spawn(async move {
712            let timeout_threashold = u64::from(heartbeat * 4);
713
714            loop {
715                trace!("Sending heartbeat");
716                if channel
717                    .send(HeartBeatCommand::default().into())
718                    .await
719                    .is_err()
720                {
721                    break;
722                }
723                tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
724
725                let now = Instant::now();
726                let last_message = last_received_message.read().await;
727                if now.duration_since(*last_message) >= Duration::from_secs(timeout_threashold) {
728                    warn!("Heartbeat timeout reached. Force closing connection.");
729                    if !client.is_closed() {
730                        if let Err(e) = client.close().await {
731                            warn!("Error closing client: {}", e);
732                        }
733                    }
734                    break;
735                }
736            }
737
738            warn!("Heartbeat task stopped. Force closing connection");
739        })
740        .into();
741
742        Some(heartbeat_task)
743    }
744
745    async fn handle_heart_beat_command(&self) {
746        trace!("Received heartbeat");
747        let mut state = self.state.write().await;
748        state.last_heatbeat = Instant::now();
749    }
750
751    pub async fn consumer_update(
752        &self,
753        correlation_id: u32,
754        offset_specification: OffsetSpecification,
755    ) -> RabbitMQStreamResult<GenericResponse> {
756        self.send_and_receive(|_| {
757            ConsumerUpdateRequestCommand::new(correlation_id, 1, offset_specification)
758        })
759        .await
760    }
761}
762
763#[async_trait::async_trait]
764impl MessageHandler for Client {
765    async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
766        match &item {
767            Some(Ok(response)) => {
768                // Update last received message time: needed for heartbeat task
769                {
770                    let s = self.state.read().await;
771                    let mut last_received_message = s.last_received_message.write().await;
772                    *last_received_message = Instant::now();
773                    drop(last_received_message);
774                    drop(s);
775                }
776
777                match response.kind_ref() {
778                    ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await,
779                    ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await,
780                    _ => {
781                        if let Some(handler) = self.state.read().await.handler.as_ref() {
782                            let handler = handler.clone();
783
784                            tokio::task::spawn(async move {
785                                // We want to log any panic that happens in the user provided handler.
786                                //
787                                // NB: tokio::task::spawn catches panics and prevents them from crashing the process,
788                                //     but we want to log them for debugging purposes.
789                                match std::panic::AssertUnwindSafe(handler.handle_message(item))
790                                    .catch_unwind()
791                                    .await
792                                {
793                                    Ok(Ok(())) => {}
794                                    Ok(Err(err)) => {
795                                        warn!("Message handler returned error: {}", err);
796                                    }
797                                    Err(panic) => {
798                                        tracing::error!("Message handler panicked: {:?}", panic);
799                                    }
800                                }
801                            });
802                        }
803                    }
804                }
805            }
806            Some(Err(err)) => {
807                trace!(?err);
808                if let Some(handler) = self.state.read().await.handler.as_ref() {
809                    let handler = handler.clone();
810
811                    tokio::task::spawn(async move { handler.handle_message(item).await });
812                }
813            }
814            None => {
815                trace!("Closing client");
816                if let Some(handler) = self.state.read().await.handler.as_ref() {
817                    let handler = handler.clone();
818                    tokio::task::spawn(async move { handler.handle_message(None).await });
819                }
820            }
821        }
822
823        Ok(())
824    }
825}