cloudmqtt/
client.rs

1//
2//   This Source Code Form is subject to the terms of the Mozilla Public
3//   License, v. 2.0. If a copy of the MPL was not distributed with this
4//   file, You can obtain one at http://mozilla.org/MPL/2.0/.
5//
6
7use std::{pin::Pin, sync::Arc, time::Duration};
8
9use dashmap::DashSet;
10use mqtt_format::v3::{
11    identifier::MPacketIdentifier,
12    packet::{
13        MConnack, MConnect, MPacket, MPingreq, MPuback, MPubcomp, MPublish, MPubrec, MPubrel,
14        MSubscribe,
15    },
16    qos::MQualityOfService,
17    strings::MString,
18    subscription_request::{MSubscriptionRequest, MSubscriptionRequests},
19    will::MLastWill,
20};
21use tokio::{
22    io::{DuplexStream, ReadHalf, WriteHalf},
23    net::{TcpStream, ToSocketAddrs},
24    sync::Mutex,
25};
26use tokio_util::sync::CancellationToken;
27use tracing::trace;
28
29use crate::packet_stream::{NoOPAck, PacketStreamBuilder};
30use crate::{error::MqttError, mqtt_stream::MqttStream};
31
32pub struct MqttClient {
33    session_present: bool,
34    client_receiver: Mutex<Option<ReadHalf<MqttStream>>>,
35    client_sender: Arc<Mutex<Option<WriteHalf<MqttStream>>>>,
36    received_packets: DashSet<u16>,
37    keep_alive_duration: u16,
38}
39
40impl std::fmt::Debug for MqttClient {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("MqttClient")
43            .field("session_present", &self.session_present)
44            .field("keep_alive_duration", &self.keep_alive_duration)
45            .finish_non_exhaustive()
46    }
47}
48
49impl MqttClient {
50    async fn do_v3_connect(
51        packet: MPacket<'_>,
52        stream: MqttStream,
53        keep_alive_duration: u16,
54    ) -> Result<MqttClient, MqttError> {
55        let (mut read_half, mut write_half) = tokio::io::split(stream);
56
57        crate::write_packet(&mut write_half, packet).await?;
58
59        let maybe_connect = crate::read_one_packet(&mut read_half).await?;
60
61        let session_present = match maybe_connect.get_packet() {
62            MPacket::Connack(MConnack {
63                session_present,
64                connect_return_code,
65            }) => match connect_return_code {
66                mqtt_format::v3::connect_return::MConnectReturnCode::Accepted => session_present,
67                code => return Err(MqttError::ConnectionRejected(*code)),
68            },
69            _ => return Err(MqttError::InvalidConnectionResponse),
70        };
71
72        Ok(MqttClient {
73            session_present: *session_present,
74            client_receiver: Mutex::new(Some(read_half)),
75            client_sender: Arc::new(Mutex::new(Some(write_half))),
76            keep_alive_duration,
77            received_packets: DashSet::new(),
78        })
79    }
80
81    pub async fn connect_v3_duplex(
82        duplex: DuplexStream,
83        connection_params: MqttConnectionParams<'_>,
84    ) -> Result<MqttClient, MqttError> {
85        tracing::debug!("Connecting via duplex");
86        let packet = connection_params.to_packet();
87
88        MqttClient::do_v3_connect(
89            packet,
90            MqttStream::MemoryDuplex(duplex),
91            connection_params.keep_alive,
92        )
93        .await
94    }
95
96    pub async fn connect_v3_unsecured_tcp<Addr: ToSocketAddrs>(
97        addr: Addr,
98        connection_params: MqttConnectionParams<'_>,
99    ) -> Result<MqttClient, MqttError> {
100        let stream = TcpStream::connect(addr).await?;
101
102        tracing::debug!("Connected via TCP to {}", stream.peer_addr()?);
103
104        let packet = connection_params.to_packet();
105
106        trace!(?packet, "Connecting");
107
108        MqttClient::do_v3_connect(
109            packet,
110            MqttStream::UnsecuredTcp(stream),
111            connection_params.keep_alive,
112        )
113        .await
114    }
115
116    /// Run a heartbeat for the client
117    ///
118    /// # Return
119    ///
120    /// Returns Ok(()) only if the `cancel_token` was cancelled, otherwise does not return.
121    pub fn heartbeat(
122        &self,
123        cancel_token: Option<CancellationToken>,
124    ) -> impl std::future::Future<Output = Result<(), MqttError>> {
125        let keep_alive_duration = self.keep_alive_duration;
126        let sender = self.client_sender.clone();
127        let cancel_token = cancel_token.unwrap_or_else(CancellationToken::new);
128        async move {
129            loop {
130                tokio::select! {
131                    _ = tokio::time::sleep(Duration::from_secs(
132                        ((keep_alive_duration as u64 * 100) / 80).max(2),
133                    )) => {
134                        let mut mutex = sender.lock().await;
135
136                        let mut client_stream = match mutex.as_mut() {
137                            Some(cs) => cs,
138                            None => return Err(MqttError::ConnectionClosed),
139                        };
140                        trace!("Sending heartbeat");
141
142                        let packet = MPingreq;
143
144                        crate::write_packet(&mut client_stream, packet).await?;
145                    },
146
147                    _ = cancel_token.cancelled() => break Ok(()),
148                }
149            }
150        }
151    }
152
153    pub(crate) async fn acknowledge_packet<W: tokio::io::AsyncWrite + Unpin>(
154        mut writer: W,
155        packet: &MPacket<'_>,
156    ) -> Result<(), MqttError> {
157        match packet {
158            MPacket::Publish(MPublish {
159                qos: MQualityOfService::AtMostOnce,
160                ..
161            }) => {}
162            MPacket::Publish(MPublish {
163                id: Some(id),
164                qos: qos @ MQualityOfService::AtLeastOnce,
165                ..
166            }) => {
167                trace!(?id, ?qos, "Acknowledging publish");
168
169                let packet = MPuback { id: *id };
170
171                crate::write_packet(&mut writer, packet).await?;
172
173                trace!(?id, "Acknowledged publish");
174            }
175            MPacket::Publish(MPublish {
176                id: Some(id),
177                qos: qos @ MQualityOfService::ExactlyOnce,
178                ..
179            }) => {
180                trace!(?id, ?qos, "Acknowledging publish");
181
182                let packet = MPubrec { id: *id };
183
184                crate::write_packet(&mut writer, packet).await?;
185
186                trace!(?id, "Acknowledged publish");
187            }
188            MPacket::Pubrel(MPubrel { id }) => {
189                trace!(?id, "Acknowledging pubrel");
190
191                let packet = MPubcomp { id: *id };
192
193                crate::write_packet(&mut writer, packet).await?;
194
195                trace!(?id, "Acknowledged publish");
196            }
197            _ => panic!("Tried to acknowledge a non-publish packet"),
198        };
199
200        Ok(())
201    }
202
203    pub fn build_packet_stream(&self) -> PacketStreamBuilder<'_, NoOPAck> {
204        PacketStreamBuilder::<NoOPAck>::new(self)
205    }
206
207    pub async fn subscribe(
208        &self,
209        subscription_requests: &[MSubscriptionRequest<'_>],
210    ) -> Result<(), MqttError> {
211        let mut mutex = match self.client_sender.try_lock() {
212            Ok(guard) => guard,
213            Err(_) => return Err(MqttError::AlreadyListening),
214        };
215
216        let stream = match mutex.as_mut() {
217            Some(cs) => cs,
218            None => return Err(MqttError::ConnectionClosed),
219        };
220
221        let mut requests = vec![];
222        for req in subscription_requests {
223            req.write_to(&mut Pin::new(&mut requests)).await?;
224        }
225
226        let packet = MSubscribe {
227            id: MPacketIdentifier(2),
228            subscriptions: MSubscriptionRequests {
229                count: subscription_requests.len(),
230                data: &requests,
231            },
232        };
233
234        crate::write_packet(stream, packet).await?;
235
236        Ok(())
237    }
238
239    /// Checks whether a session was present upon connecting
240    ///
241    /// Note: This only reflects the presence of the session on connection.
242    /// Later subscriptions or other commands that change the session do not
243    /// update this value.
244    pub fn session_present_at_connection(&self) -> bool {
245        self.session_present
246    }
247
248    pub(crate) fn received_packets(&self) -> &DashSet<u16> {
249        &self.received_packets
250    }
251
252    pub(crate) fn client_sender(&self) -> &Mutex<Option<WriteHalf<MqttStream>>> {
253        self.client_sender.as_ref()
254    }
255
256    pub(crate) fn client_receiver(&self) -> &Mutex<Option<ReadHalf<MqttStream>>> {
257        &self.client_receiver
258    }
259}
260
261pub struct MqttConnectionParams<'conn> {
262    pub clean_session: bool,
263    pub will: Option<MLastWill<'conn>>,
264    pub username: Option<MString<'conn>>,
265    pub password: Option<&'conn [u8]>,
266    pub keep_alive: u16,
267    pub client_id: MString<'conn>,
268}
269
270impl<'a> MqttConnectionParams<'a> {
271    fn to_packet(&self) -> MPacket<'a> {
272        MConnect {
273            protocol_name: MString { value: "MQTT" },
274            protocol_level: 4,
275            clean_session: self.clean_session,
276            will: self.will,
277            username: self.username,
278            password: self.password,
279            keep_alive: self.keep_alive,
280            client_id: self.client_id,
281        }
282        .into()
283    }
284}
285
#[cfg(test)]
mod tests {
    use static_assertions::assert_impl_all;

    use super::MqttClient;

    // Compile-time guarantee that a client can be shared between tasks and threads.
    assert_impl_all!(MqttClient: Send, Sync);
}