1use std::{pin::Pin, sync::Arc, time::Duration};
8
9use dashmap::DashSet;
10use mqtt_format::v3::{
11 identifier::MPacketIdentifier,
12 packet::{
13 MConnack, MConnect, MPacket, MPingreq, MPuback, MPubcomp, MPublish, MPubrec, MPubrel,
14 MSubscribe,
15 },
16 qos::MQualityOfService,
17 strings::MString,
18 subscription_request::{MSubscriptionRequest, MSubscriptionRequests},
19 will::MLastWill,
20};
21use tokio::{
22 io::{DuplexStream, ReadHalf, WriteHalf},
23 net::{TcpStream, ToSocketAddrs},
24 sync::Mutex,
25};
26use tokio_util::sync::CancellationToken;
27use tracing::trace;
28
29use crate::packet_stream::{NoOPAck, PacketStreamBuilder};
30use crate::{error::MqttError, mqtt_stream::MqttStream};
31
32pub struct MqttClient {
33 session_present: bool,
34 client_receiver: Mutex<Option<ReadHalf<MqttStream>>>,
35 client_sender: Arc<Mutex<Option<WriteHalf<MqttStream>>>>,
36 received_packets: DashSet<u16>,
37 keep_alive_duration: u16,
38}
39
40impl std::fmt::Debug for MqttClient {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("MqttClient")
43 .field("session_present", &self.session_present)
44 .field("keep_alive_duration", &self.keep_alive_duration)
45 .finish_non_exhaustive()
46 }
47}
48
49impl MqttClient {
50 async fn do_v3_connect(
51 packet: MPacket<'_>,
52 stream: MqttStream,
53 keep_alive_duration: u16,
54 ) -> Result<MqttClient, MqttError> {
55 let (mut read_half, mut write_half) = tokio::io::split(stream);
56
57 crate::write_packet(&mut write_half, packet).await?;
58
59 let maybe_connect = crate::read_one_packet(&mut read_half).await?;
60
61 let session_present = match maybe_connect.get_packet() {
62 MPacket::Connack(MConnack {
63 session_present,
64 connect_return_code,
65 }) => match connect_return_code {
66 mqtt_format::v3::connect_return::MConnectReturnCode::Accepted => session_present,
67 code => return Err(MqttError::ConnectionRejected(*code)),
68 },
69 _ => return Err(MqttError::InvalidConnectionResponse),
70 };
71
72 Ok(MqttClient {
73 session_present: *session_present,
74 client_receiver: Mutex::new(Some(read_half)),
75 client_sender: Arc::new(Mutex::new(Some(write_half))),
76 keep_alive_duration,
77 received_packets: DashSet::new(),
78 })
79 }
80
81 pub async fn connect_v3_duplex(
82 duplex: DuplexStream,
83 connection_params: MqttConnectionParams<'_>,
84 ) -> Result<MqttClient, MqttError> {
85 tracing::debug!("Connecting via duplex");
86 let packet = connection_params.to_packet();
87
88 MqttClient::do_v3_connect(
89 packet,
90 MqttStream::MemoryDuplex(duplex),
91 connection_params.keep_alive,
92 )
93 .await
94 }
95
96 pub async fn connect_v3_unsecured_tcp<Addr: ToSocketAddrs>(
97 addr: Addr,
98 connection_params: MqttConnectionParams<'_>,
99 ) -> Result<MqttClient, MqttError> {
100 let stream = TcpStream::connect(addr).await?;
101
102 tracing::debug!("Connected via TCP to {}", stream.peer_addr()?);
103
104 let packet = connection_params.to_packet();
105
106 trace!(?packet, "Connecting");
107
108 MqttClient::do_v3_connect(
109 packet,
110 MqttStream::UnsecuredTcp(stream),
111 connection_params.keep_alive,
112 )
113 .await
114 }
115
116 pub fn heartbeat(
122 &self,
123 cancel_token: Option<CancellationToken>,
124 ) -> impl std::future::Future<Output = Result<(), MqttError>> {
125 let keep_alive_duration = self.keep_alive_duration;
126 let sender = self.client_sender.clone();
127 let cancel_token = cancel_token.unwrap_or_else(CancellationToken::new);
128 async move {
129 loop {
130 tokio::select! {
131 _ = tokio::time::sleep(Duration::from_secs(
132 ((keep_alive_duration as u64 * 100) / 80).max(2),
133 )) => {
134 let mut mutex = sender.lock().await;
135
136 let mut client_stream = match mutex.as_mut() {
137 Some(cs) => cs,
138 None => return Err(MqttError::ConnectionClosed),
139 };
140 trace!("Sending heartbeat");
141
142 let packet = MPingreq;
143
144 crate::write_packet(&mut client_stream, packet).await?;
145 },
146
147 _ = cancel_token.cancelled() => break Ok(()),
148 }
149 }
150 }
151 }
152
153 pub(crate) async fn acknowledge_packet<W: tokio::io::AsyncWrite + Unpin>(
154 mut writer: W,
155 packet: &MPacket<'_>,
156 ) -> Result<(), MqttError> {
157 match packet {
158 MPacket::Publish(MPublish {
159 qos: MQualityOfService::AtMostOnce,
160 ..
161 }) => {}
162 MPacket::Publish(MPublish {
163 id: Some(id),
164 qos: qos @ MQualityOfService::AtLeastOnce,
165 ..
166 }) => {
167 trace!(?id, ?qos, "Acknowledging publish");
168
169 let packet = MPuback { id: *id };
170
171 crate::write_packet(&mut writer, packet).await?;
172
173 trace!(?id, "Acknowledged publish");
174 }
175 MPacket::Publish(MPublish {
176 id: Some(id),
177 qos: qos @ MQualityOfService::ExactlyOnce,
178 ..
179 }) => {
180 trace!(?id, ?qos, "Acknowledging publish");
181
182 let packet = MPubrec { id: *id };
183
184 crate::write_packet(&mut writer, packet).await?;
185
186 trace!(?id, "Acknowledged publish");
187 }
188 MPacket::Pubrel(MPubrel { id }) => {
189 trace!(?id, "Acknowledging pubrel");
190
191 let packet = MPubcomp { id: *id };
192
193 crate::write_packet(&mut writer, packet).await?;
194
195 trace!(?id, "Acknowledged publish");
196 }
197 _ => panic!("Tried to acknowledge a non-publish packet"),
198 };
199
200 Ok(())
201 }
202
203 pub fn build_packet_stream(&self) -> PacketStreamBuilder<'_, NoOPAck> {
204 PacketStreamBuilder::<NoOPAck>::new(self)
205 }
206
207 pub async fn subscribe(
208 &self,
209 subscription_requests: &[MSubscriptionRequest<'_>],
210 ) -> Result<(), MqttError> {
211 let mut mutex = match self.client_sender.try_lock() {
212 Ok(guard) => guard,
213 Err(_) => return Err(MqttError::AlreadyListening),
214 };
215
216 let stream = match mutex.as_mut() {
217 Some(cs) => cs,
218 None => return Err(MqttError::ConnectionClosed),
219 };
220
221 let mut requests = vec![];
222 for req in subscription_requests {
223 req.write_to(&mut Pin::new(&mut requests)).await?;
224 }
225
226 let packet = MSubscribe {
227 id: MPacketIdentifier(2),
228 subscriptions: MSubscriptionRequests {
229 count: subscription_requests.len(),
230 data: &requests,
231 },
232 };
233
234 crate::write_packet(stream, packet).await?;
235
236 Ok(())
237 }
238
239 pub fn session_present_at_connection(&self) -> bool {
245 self.session_present
246 }
247
248 pub(crate) fn received_packets(&self) -> &DashSet<u16> {
249 &self.received_packets
250 }
251
252 pub(crate) fn client_sender(&self) -> &Mutex<Option<WriteHalf<MqttStream>>> {
253 self.client_sender.as_ref()
254 }
255
256 pub(crate) fn client_receiver(&self) -> &Mutex<Option<ReadHalf<MqttStream>>> {
257 &self.client_receiver
258 }
259}
260
261pub struct MqttConnectionParams<'conn> {
262 pub clean_session: bool,
263 pub will: Option<MLastWill<'conn>>,
264 pub username: Option<MString<'conn>>,
265 pub password: Option<&'conn [u8]>,
266 pub keep_alive: u16,
267 pub client_id: MString<'conn>,
268}
269
270impl<'a> MqttConnectionParams<'a> {
271 fn to_packet(&self) -> MPacket<'a> {
272 MConnect {
273 protocol_name: MString { value: "MQTT" },
274 protocol_level: 4,
275 clean_session: self.clean_session,
276 will: self.will,
277 username: self.username,
278 password: self.password,
279 keep_alive: self.keep_alive,
280 client_id: self.client_id,
281 }
282 .into()
283 }
284}
285
#[cfg(test)]
mod tests {
    //! Compile-time checks on the client's auto traits.
    use static_assertions::assert_impl_all;

    // Fails the build if a future field makes the client !Send or !Sync.
    assert_impl_all!(crate::client::MqttClient: Send, Sync);
}