1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::anyhow;
6use futures::SinkExt;
7use futures::StreamExt;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10
11use rmqtt_codec::error::{DecodeError, SendPacketError};
12use rmqtt_codec::v3::Codec as CodecV3;
13use rmqtt_codec::v5::Codec as CodecV5;
14use rmqtt_codec::version::{ProtocolVersion, VersionCodec};
15use rmqtt_codec::{MqttCodec, MqttPacket};
16
17use crate::error::MqttError;
18use crate::{Builder, Result};
19
20pub struct Dispatcher<Io> {
24 pub(crate) io: Framed<Io, MqttCodec>,
26 pub remote_addr: SocketAddr,
28 pub cfg: Arc<Builder>,
30}
31
32impl<Io> Dispatcher<Io>
33where
34 Io: AsyncRead + AsyncWrite + Unpin,
35{
36 pub(crate) fn new(io: Io, remote_addr: SocketAddr, cfg: Arc<Builder>) -> Self {
38 Dispatcher { io: Framed::new(io, MqttCodec::Version(VersionCodec)), remote_addr, cfg }
39 }
40
41 #[inline]
43 pub async fn mqtt(mut self) -> Result<MqttStream<Io>> {
44 Ok(match self.probe_version().await? {
45 ProtocolVersion::MQTT3 => {
46 MqttStream::V3(v3::MqttStream { io: self.io, remote_addr: self.remote_addr, cfg: self.cfg })
47 }
48 ProtocolVersion::MQTT5 => {
49 MqttStream::V5(v5::MqttStream { io: self.io, remote_addr: self.remote_addr, cfg: self.cfg })
50 }
51 })
52 }
53
54 #[inline]
56 async fn probe_version(&mut self) -> Result<ProtocolVersion> {
57 let Some(Ok((MqttPacket::Version(ver), _))) = self.io.next().await else {
58 return Err(anyhow!(DecodeError::InvalidProtocol));
59 };
60
61 let codec = match ver {
62 ProtocolVersion::MQTT3 => MqttCodec::V3(CodecV3::new(self.cfg.max_packet_size)),
63 ProtocolVersion::MQTT5 => {
64 MqttCodec::V5(CodecV5::new(self.cfg.max_packet_size, self.cfg.max_packet_size))
65 }
66 };
67
68 *self.io.codec_mut() = codec;
69 Ok(ver)
70 }
71}
72
73pub enum MqttStream<Io> {
75 V3(v3::MqttStream<Io>),
77 V5(v5::MqttStream<Io>),
79}
80
81pub mod v3 {
82
83 use std::net::SocketAddr;
84 use std::num::NonZeroU16;
85 use std::pin::Pin;
86 use std::sync::Arc;
87 use std::task::{Context, Poll};
88 use std::time::Duration;
89
90 use futures::StreamExt;
91 use tokio::io::{AsyncRead, AsyncWrite};
92 use tokio_util::codec::Framed;
93
94 use rmqtt_codec::error::DecodeError;
95 use rmqtt_codec::types::Publish;
96 use rmqtt_codec::v3::{Connect, ConnectAckReason, Packet as PacketV3, Packet};
97 use rmqtt_codec::{MqttCodec, MqttPacket};
98
99 use crate::error::MqttError;
100 use crate::{Builder, Error, Result};
101
102 pub struct MqttStream<Io> {
104 pub io: Framed<Io, MqttCodec>,
106 pub remote_addr: SocketAddr,
108 pub cfg: Arc<Builder>,
110 }
111
112 impl<Io> MqttStream<Io>
136 where
137 Io: AsyncRead + AsyncWrite + Unpin,
138 {
139 #[inline]
141 pub async fn send_disconnect(&mut self) -> Result<()> {
142 self.send(PacketV3::Disconnect).await?;
143 self.flush().await
144 }
145
146 #[inline]
148 pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
149 self.send(PacketV3::Publish(publish)).await
150 }
151
152 #[inline]
154 pub async fn send_publish_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
155 self.send(PacketV3::PublishAck { packet_id }).await
156 }
157
158 #[inline]
160 pub async fn send_publish_received(&mut self, packet_id: NonZeroU16) -> Result<()> {
161 self.send(PacketV3::PublishReceived { packet_id }).await
162 }
163
164 #[inline]
166 pub async fn send_publish_release(&mut self, packet_id: NonZeroU16) -> Result<()> {
167 self.send(PacketV3::PublishRelease { packet_id }).await
168 }
169
170 #[inline]
172 pub async fn send_publish_complete(&mut self, packet_id: NonZeroU16) -> Result<()> {
173 self.send(PacketV3::PublishComplete { packet_id }).await
174 }
175
176 #[inline]
178 pub async fn send_subscribe_ack(
179 &mut self,
180 packet_id: NonZeroU16,
181 status: Vec<rmqtt_codec::v3::SubscribeReturnCode>,
182 ) -> Result<()> {
183 self.send(PacketV3::SubscribeAck { packet_id, status }).await
184 }
185
186 #[inline]
188 pub async fn send_unsubscribe_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
189 self.send(PacketV3::UnsubscribeAck { packet_id }).await
190 }
191
192 #[inline]
194 pub async fn send_connect(&mut self, connect: rmqtt_codec::v3::Connect) -> Result<()> {
195 self.send(PacketV3::Connect(Box::new(connect))).await
196 }
197
198 #[inline]
200 pub async fn send_connect_ack(
201 &mut self,
202 return_code: ConnectAckReason,
203 session_present: bool,
204 ) -> Result<()> {
205 self.send(PacketV3::ConnectAck(rmqtt_codec::v3::ConnectAck { session_present, return_code }))
206 .await
207 }
208
209 #[inline]
211 pub async fn send_ping_request(&mut self) -> Result<()> {
212 self.send(PacketV3::PingRequest {}).await
213 }
214
215 #[inline]
217 pub async fn send_ping_response(&mut self) -> Result<()> {
218 self.send(PacketV3::PingResponse {}).await
219 }
220
221 #[inline]
223 pub async fn send(&mut self, packet: rmqtt_codec::v3::Packet) -> Result<()> {
224 super::send(&mut self.io, MqttPacket::V3(packet), self.cfg.send_timeout).await
225 }
226
227 #[inline]
229 pub async fn flush(&mut self) -> Result<()> {
230 super::flush(&mut self.io, self.cfg.send_timeout).await
231 }
232
233 #[inline]
235 pub async fn close(&mut self) -> Result<()> {
236 super::close(&mut self.io, self.cfg.send_timeout).await
237 }
238
239 #[inline]
241 pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v3::Packet>> {
242 match tokio::time::timeout(tm, self.next()).await {
243 Ok(Some(Ok(msg))) => Ok(Some(msg)),
244 Ok(Some(Err(e))) => Err(e),
245 Ok(None) => Ok(None),
246 Err(_) => Err(MqttError::ReadTimeout.into()),
247 }
248 }
249
250 #[inline]
252 pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
253 let connect = match self.recv(tm).await {
254 Ok(Some(Packet::Connect(connect))) => connect,
255 Err(e) => {
256 return Err(e);
257 }
258 _ => {
259 return Err(MqttError::InvalidProtocol.into());
260 }
261 };
262 Ok(connect)
263 }
264 }
265
266 impl<Io> futures::Stream for MqttStream<Io>
267 where
268 Io: AsyncRead + Unpin,
269 {
270 type Item = Result<rmqtt_codec::v3::Packet>;
271
272 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
273 let next = Pin::new(&mut self.io).poll_next(cx);
274 Poll::Ready(match futures::ready!(next) {
275 Some(Ok((MqttPacket::V3(packet), _))) => Some(Ok(packet)),
276 Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
277 Some(Err(e)) => Some(Err(Error::from(e))),
278 None => None,
279 })
280 }
281 }
282}
283
284pub mod v5 {
285 use std::net::SocketAddr;
286 use std::pin::Pin;
287 use std::sync::Arc;
288 use std::task::{Context, Poll};
289 use std::time::Duration;
290
291 use futures::StreamExt;
292 use tokio::io::{AsyncRead, AsyncWrite};
293 use tokio_util::codec::Framed;
294
295 use rmqtt_codec::error::DecodeError;
296 use rmqtt_codec::types::Publish;
297 use rmqtt_codec::v5::{Auth, Connect, Disconnect, Packet as PacketV5, Packet};
298 use rmqtt_codec::{MqttCodec, MqttPacket};
299
300 use crate::error::MqttError;
301 use crate::{Builder, Error, Result};
302
303 pub struct MqttStream<Io> {
305 pub io: Framed<Io, MqttCodec>,
307 pub remote_addr: SocketAddr,
309 pub cfg: Arc<Builder>,
311 }
312
313 impl<Io> MqttStream<Io>
338 where
339 Io: AsyncRead + AsyncWrite + Unpin,
340 {
341 #[inline]
343 pub async fn send_disconnect(&mut self, disc: Disconnect) -> Result<()> {
344 self.send(PacketV5::Disconnect(disc)).await?;
345 self.flush().await?;
346 tokio::time::sleep(Duration::from_millis(500)).await;
347 Ok(())
348 }
349
350 #[inline]
352 pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
353 self.send(PacketV5::Publish(publish)).await
354 }
355
356 #[inline]
358 pub async fn send_publish_ack(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
359 self.send(PacketV5::PublishAck(ack)).await
360 }
361
362 #[inline]
364 pub async fn send_publish_received(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
365 self.send(PacketV5::PublishReceived(ack)).await
366 }
367
368 #[inline]
370 pub async fn send_publish_release(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
371 self.send(PacketV5::PublishRelease(ack2)).await
372 }
373
374 #[inline]
376 pub async fn send_publish_complete(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
377 self.send(PacketV5::PublishComplete(ack2)).await
378 }
379
380 #[inline]
382 pub async fn send_subscribe_ack(&mut self, ack: rmqtt_codec::v5::SubscribeAck) -> Result<()> {
383 self.send(PacketV5::SubscribeAck(ack)).await
384 }
385
386 #[inline]
388 pub async fn send_unsubscribe_ack(&mut self, unack: rmqtt_codec::v5::UnsubscribeAck) -> Result<()> {
389 self.send(PacketV5::UnsubscribeAck(unack)).await
390 }
391
392 #[inline]
394 pub async fn send_connect(&mut self, connect: rmqtt_codec::v5::Connect) -> Result<()> {
395 self.send(PacketV5::Connect(Box::new(connect))).await
396 }
397
398 #[inline]
400 pub async fn send_connect_ack(&mut self, ack: rmqtt_codec::v5::ConnectAck) -> Result<()> {
401 self.send(PacketV5::ConnectAck(Box::new(ack))).await
402 }
403
404 #[inline]
406 pub async fn send_ping_request(&mut self) -> Result<()> {
407 self.send(PacketV5::PingRequest {}).await
408 }
409
410 #[inline]
412 pub async fn send_ping_response(&mut self) -> Result<()> {
413 self.send(PacketV5::PingResponse {}).await
414 }
415
416 #[inline]
418 pub async fn send_auth(&mut self, auth: Auth) -> Result<()> {
419 self.send(PacketV5::Auth(auth)).await
420 }
421
422 #[inline]
424 pub async fn send(&mut self, packet: rmqtt_codec::v5::Packet) -> Result<()> {
425 super::send(&mut self.io, MqttPacket::V5(packet), self.cfg.send_timeout).await
426 }
427
428 #[inline]
430 pub async fn flush(&mut self) -> Result<()> {
431 super::flush(&mut self.io, self.cfg.send_timeout).await
432 }
433
434 #[inline]
436 pub async fn close(&mut self) -> Result<()> {
437 super::close(&mut self.io, self.cfg.send_timeout).await
438 }
439
440 #[inline]
442 pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v5::Packet>> {
443 match tokio::time::timeout(tm, self.next()).await {
444 Ok(Some(Ok(msg))) => Ok(Some(msg)),
445 Ok(Some(Err(e))) => Err(e),
446 Ok(None) => Ok(None),
447 Err(_) => Err(MqttError::ReadTimeout.into()),
448 }
449 }
450
451 #[inline]
453 pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
454 let connect = match self.recv(tm).await {
455 Ok(Some(Packet::Connect(connect))) => connect,
456 Err(e) => {
457 return Err(e);
458 }
459 _ => {
460 return Err(MqttError::InvalidProtocol.into());
461 }
462 };
463 Ok(connect)
464 }
465 }
466
467 impl<Io> futures::Stream for MqttStream<Io>
468 where
469 Io: AsyncRead + Unpin,
470 {
471 type Item = Result<rmqtt_codec::v5::Packet>;
472
473 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
474 let next = Pin::new(&mut self.io).poll_next(cx);
475 Poll::Ready(match futures::ready!(next) {
476 Some(Ok((MqttPacket::V5(packet), _))) => Some(Ok(packet)),
477 Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
478 Some(Err(e)) => Some(Err(Error::from(e))),
479 None => None,
480 })
481 }
482 }
483}
484
485#[inline]
486async fn send<Io>(io: &mut Framed<Io, MqttCodec>, packet: MqttPacket, send_timeout: Duration) -> Result<()>
487where
488 Io: AsyncWrite + Unpin,
489{
490 if send_timeout.is_zero() {
491 io.send(packet).await?;
492 Ok(())
493 } else {
494 match tokio::time::timeout(send_timeout, io.send(packet)).await {
495 Ok(Ok(())) => Ok(()),
496 Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
497 Err(_) => Err(MqttError::WriteTimeout),
498 }?;
499 Ok(())
500 }
501}
502
503#[inline]
504async fn flush<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
505where
506 Io: AsyncWrite + Unpin,
507{
508 if send_timeout.is_zero() {
509 io.flush().await?;
510 Ok(())
511 } else {
512 match tokio::time::timeout(send_timeout, io.flush()).await {
513 Ok(Ok(())) => Ok(()),
514 Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
515 Err(_) => Err(MqttError::FlushTimeout),
516 }?;
517 Ok(())
518 }
519}
520
521#[inline]
522async fn close<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
523where
524 Io: AsyncWrite + Unpin,
525{
526 if send_timeout.is_zero() {
527 io.close().await?;
528 Ok(())
529 } else {
530 match tokio::time::timeout(send_timeout, io.close()).await {
531 Ok(Ok(())) => Ok(()),
532 Ok(Err(e)) => Err(MqttError::Encode(e)),
533 Err(_) => Err(MqttError::CloseTimeout),
534 }?;
535 Ok(())
536 }
537}