1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::anyhow;
6use futures::SinkExt;
7use futures::StreamExt;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10
11use rmqtt_codec::error::{DecodeError, SendPacketError};
12use rmqtt_codec::v3::Codec as CodecV3;
13use rmqtt_codec::v5::Codec as CodecV5;
14use rmqtt_codec::version::{ProtocolVersion, VersionCodec};
15use rmqtt_codec::{MqttCodec, MqttPacket};
16
17use crate::error::MqttError;
18use crate::{Builder, Result};
19
/// Wraps an I/O object whose MQTT protocol version has not been determined yet.
///
/// `Dispatcher::new` installs the version-probing codec and `Dispatcher::mqtt`
/// turns the dispatcher into a version-specific [`MqttStream`].
pub struct Dispatcher<Io> {
    /// Framed transport; its codec is replaced once the protocol version is known.
    pub(crate) io: Framed<Io, MqttCodec>,
    /// Socket address handed in at construction (presumably the remote peer; confirm with callers).
    pub remote_addr: SocketAddr,
    /// Shared configuration; `max_packet_size` is read from it when the codec is selected.
    pub cfg: Arc<Builder>,
}
31
32impl<Io> Dispatcher<Io>
33where
34    Io: AsyncRead + AsyncWrite + Unpin,
35{
36    pub(crate) fn new(io: Io, remote_addr: SocketAddr, cfg: Arc<Builder>) -> Self {
38        Dispatcher { io: Framed::new(io, MqttCodec::Version(VersionCodec)), remote_addr, cfg }
39    }
40
41    #[inline]
43    pub async fn mqtt(mut self) -> Result<MqttStream<Io>> {
44        Ok(match self.probe_version().await? {
45            ProtocolVersion::MQTT3 => {
46                MqttStream::V3(v3::MqttStream { io: self.io, remote_addr: self.remote_addr, cfg: self.cfg })
47            }
48            ProtocolVersion::MQTT5 => {
49                MqttStream::V5(v5::MqttStream { io: self.io, remote_addr: self.remote_addr, cfg: self.cfg })
50            }
51        })
52    }
53
54    #[inline]
56    async fn probe_version(&mut self) -> Result<ProtocolVersion> {
57        let Some(Ok((MqttPacket::Version(ver), _))) = self.io.next().await else {
58            return Err(anyhow!(DecodeError::InvalidProtocol));
59        };
60
61        let codec = match ver {
62            ProtocolVersion::MQTT3 => MqttCodec::V3(CodecV3::new(self.cfg.max_packet_size)),
63            ProtocolVersion::MQTT5 => {
64                MqttCodec::V5(CodecV5::new(self.cfg.max_packet_size, self.cfg.max_packet_size))
65            }
66        };
67
68        *self.io.codec_mut() = codec;
69        Ok(ver)
70    }
71}
72
/// Version-specific stream produced by `Dispatcher::mqtt` once the protocol version is known.
pub enum MqttStream<Io> {
    /// Stream built when the probed version is `ProtocolVersion::MQTT3`.
    V3(v3::MqttStream<Io>),
    /// Stream built when the probed version is `ProtocolVersion::MQTT5`.
    V5(v5::MqttStream<Io>),
}
80
81pub mod v3 {
82
83    use std::net::SocketAddr;
84    use std::num::NonZeroU16;
85    use std::pin::Pin;
86    use std::sync::Arc;
87    use std::task::{Context, Poll};
88    use std::time::Duration;
89
90    use futures::StreamExt;
91    use tokio::io::{AsyncRead, AsyncWrite};
92    use tokio_util::codec::Framed;
93
94    use rmqtt_codec::error::DecodeError;
95    use rmqtt_codec::types::Publish;
96    use rmqtt_codec::v3::{Connect, ConnectAckReason, Packet as PacketV3, Packet};
97    use rmqtt_codec::{MqttCodec, MqttPacket};
98
99    use crate::error::MqttError;
100    use crate::{Builder, Error, Result};
101
    /// MQTT v3 connection: a framed transport together with its address and configuration.
    pub struct MqttStream<Io> {
        /// Framed transport used for all reads and writes on this stream.
        pub io: Framed<Io, MqttCodec>,
        /// Socket address associated with this connection (presumably the remote peer; confirm with callers).
        pub remote_addr: SocketAddr,
        /// Shared configuration; `send_timeout` is read from it for send, flush and close.
        pub cfg: Arc<Builder>,
    }
111
112    impl<Io> MqttStream<Io>
136    where
137        Io: AsyncRead + AsyncWrite + Unpin,
138    {
139        #[inline]
141        pub async fn send_disconnect(&mut self) -> Result<()> {
142            self.send(PacketV3::Disconnect).await?;
143            self.flush().await
144        }
145
146        #[inline]
148        pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
149            self.send(PacketV3::Publish(publish)).await
150        }
151
152        #[inline]
154        pub async fn send_publish_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
155            self.send(PacketV3::PublishAck { packet_id }).await
156        }
157
158        #[inline]
160        pub async fn send_publish_received(&mut self, packet_id: NonZeroU16) -> Result<()> {
161            self.send(PacketV3::PublishReceived { packet_id }).await
162        }
163
164        #[inline]
166        pub async fn send_publish_release(&mut self, packet_id: NonZeroU16) -> Result<()> {
167            self.send(PacketV3::PublishRelease { packet_id }).await
168        }
169
170        #[inline]
172        pub async fn send_publish_complete(&mut self, packet_id: NonZeroU16) -> Result<()> {
173            self.send(PacketV3::PublishComplete { packet_id }).await
174        }
175
176        #[inline]
178        pub async fn send_subscribe_ack(
179            &mut self,
180            packet_id: NonZeroU16,
181            status: Vec<rmqtt_codec::v3::SubscribeReturnCode>,
182        ) -> Result<()> {
183            self.send(PacketV3::SubscribeAck { packet_id, status }).await
184        }
185
186        #[inline]
188        pub async fn send_unsubscribe_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
189            self.send(PacketV3::UnsubscribeAck { packet_id }).await
190        }
191
192        #[inline]
194        pub async fn send_connect(&mut self, connect: rmqtt_codec::v3::Connect) -> Result<()> {
195            self.send(PacketV3::Connect(Box::new(connect))).await
196        }
197
198        #[inline]
200        pub async fn send_connect_ack(
201            &mut self,
202            return_code: ConnectAckReason,
203            session_present: bool,
204        ) -> Result<()> {
205            self.send(PacketV3::ConnectAck(rmqtt_codec::v3::ConnectAck { session_present, return_code }))
206                .await
207        }
208
209        #[inline]
211        pub async fn send_ping_request(&mut self) -> Result<()> {
212            self.send(PacketV3::PingRequest {}).await
213        }
214
215        #[inline]
217        pub async fn send_ping_response(&mut self) -> Result<()> {
218            self.send(PacketV3::PingResponse {}).await
219        }
220
221        #[inline]
223        pub async fn send(&mut self, packet: rmqtt_codec::v3::Packet) -> Result<()> {
224            super::send(&mut self.io, MqttPacket::V3(packet), self.cfg.send_timeout).await
225        }
226
227        #[inline]
229        pub async fn flush(&mut self) -> Result<()> {
230            super::flush(&mut self.io, self.cfg.send_timeout).await
231        }
232
233        #[inline]
235        pub async fn close(&mut self) -> Result<()> {
236            super::close(&mut self.io, self.cfg.send_timeout).await
237        }
238
239        #[inline]
241        pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v3::Packet>> {
242            match tokio::time::timeout(tm, self.next()).await {
243                Ok(Some(Ok(msg))) => Ok(Some(msg)),
244                Ok(Some(Err(e))) => Err(e),
245                Ok(None) => Ok(None),
246                Err(_) => Err(MqttError::ReadTimeout.into()),
247            }
248        }
249
250        #[inline]
252        pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
253            let connect = match self.recv(tm).await {
254                Ok(Some(Packet::Connect(connect))) => connect,
255                Err(e) => {
256                    return Err(e);
257                }
258                _ => {
259                    return Err(MqttError::InvalidProtocol.into());
260                }
261            };
262            Ok(connect)
263        }
264    }
265
266    impl<Io> futures::Stream for MqttStream<Io>
267    where
268        Io: AsyncRead + Unpin,
269    {
270        type Item = Result<rmqtt_codec::v3::Packet>;
271
272        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
273            let next = Pin::new(&mut self.io).poll_next(cx);
274            Poll::Ready(match futures::ready!(next) {
275                Some(Ok((MqttPacket::V3(packet), _))) => Some(Ok(packet)),
276                Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
277                Some(Err(e)) => Some(Err(Error::from(e))),
278                None => None,
279            })
280        }
281    }
282}
283
284pub mod v5 {
285    use std::net::SocketAddr;
286    use std::pin::Pin;
287    use std::sync::Arc;
288    use std::task::{Context, Poll};
289    use std::time::Duration;
290
291    use futures::StreamExt;
292    use tokio::io::{AsyncRead, AsyncWrite};
293    use tokio_util::codec::Framed;
294
295    use rmqtt_codec::error::DecodeError;
296    use rmqtt_codec::types::Publish;
297    use rmqtt_codec::v5::{Auth, Connect, Disconnect, Packet as PacketV5, Packet};
298    use rmqtt_codec::{MqttCodec, MqttPacket};
299
300    use crate::error::MqttError;
301    use crate::{Builder, Error, Result};
302
    /// MQTT v5 connection: a framed transport together with its address and configuration.
    pub struct MqttStream<Io> {
        /// Framed transport used for all reads and writes on this stream.
        pub io: Framed<Io, MqttCodec>,
        /// Socket address associated with this connection (presumably the remote peer; confirm with callers).
        pub remote_addr: SocketAddr,
        /// Shared configuration; `send_timeout` is read from it for send, flush and close.
        pub cfg: Arc<Builder>,
    }
312
313    impl<Io> MqttStream<Io>
338    where
339        Io: AsyncRead + AsyncWrite + Unpin,
340    {
341        #[inline]
343        pub async fn send_disconnect(&mut self, disc: Disconnect) -> Result<()> {
344            self.send(PacketV5::Disconnect(disc)).await?;
345            self.flush().await?;
346            tokio::time::sleep(Duration::from_millis(500)).await;
347            Ok(())
348        }
349
350        #[inline]
352        pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
353            self.send(PacketV5::Publish(publish)).await
354        }
355
356        #[inline]
358        pub async fn send_publish_ack(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
359            self.send(PacketV5::PublishAck(ack)).await
360        }
361
362        #[inline]
364        pub async fn send_publish_received(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
365            self.send(PacketV5::PublishReceived(ack)).await
366        }
367
368        #[inline]
370        pub async fn send_publish_release(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
371            self.send(PacketV5::PublishRelease(ack2)).await
372        }
373
374        #[inline]
376        pub async fn send_publish_complete(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
377            self.send(PacketV5::PublishComplete(ack2)).await
378        }
379
380        #[inline]
382        pub async fn send_subscribe_ack(&mut self, ack: rmqtt_codec::v5::SubscribeAck) -> Result<()> {
383            self.send(PacketV5::SubscribeAck(ack)).await
384        }
385
386        #[inline]
388        pub async fn send_unsubscribe_ack(&mut self, unack: rmqtt_codec::v5::UnsubscribeAck) -> Result<()> {
389            self.send(PacketV5::UnsubscribeAck(unack)).await
390        }
391
392        #[inline]
394        pub async fn send_connect(&mut self, connect: rmqtt_codec::v5::Connect) -> Result<()> {
395            self.send(PacketV5::Connect(Box::new(connect))).await
396        }
397
398        #[inline]
400        pub async fn send_connect_ack(&mut self, ack: rmqtt_codec::v5::ConnectAck) -> Result<()> {
401            self.send(PacketV5::ConnectAck(Box::new(ack))).await
402        }
403
404        #[inline]
406        pub async fn send_ping_request(&mut self) -> Result<()> {
407            self.send(PacketV5::PingRequest {}).await
408        }
409
410        #[inline]
412        pub async fn send_ping_response(&mut self) -> Result<()> {
413            self.send(PacketV5::PingResponse {}).await
414        }
415
416        #[inline]
418        pub async fn send_auth(&mut self, auth: Auth) -> Result<()> {
419            self.send(PacketV5::Auth(auth)).await
420        }
421
422        #[inline]
424        pub async fn send(&mut self, packet: rmqtt_codec::v5::Packet) -> Result<()> {
425            super::send(&mut self.io, MqttPacket::V5(packet), self.cfg.send_timeout).await
426        }
427
428        #[inline]
430        pub async fn flush(&mut self) -> Result<()> {
431            super::flush(&mut self.io, self.cfg.send_timeout).await
432        }
433
434        #[inline]
436        pub async fn close(&mut self) -> Result<()> {
437            super::close(&mut self.io, self.cfg.send_timeout).await
438        }
439
440        #[inline]
442        pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v5::Packet>> {
443            match tokio::time::timeout(tm, self.next()).await {
444                Ok(Some(Ok(msg))) => Ok(Some(msg)),
445                Ok(Some(Err(e))) => Err(e),
446                Ok(None) => Ok(None),
447                Err(_) => Err(MqttError::ReadTimeout.into()),
448            }
449        }
450
451        #[inline]
453        pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
454            let connect = match self.recv(tm).await {
455                Ok(Some(Packet::Connect(connect))) => connect,
456                Err(e) => {
457                    return Err(e);
458                }
459                _ => {
460                    return Err(MqttError::InvalidProtocol.into());
461                }
462            };
463            Ok(connect)
464        }
465    }
466
467    impl<Io> futures::Stream for MqttStream<Io>
468    where
469        Io: AsyncRead + Unpin,
470    {
471        type Item = Result<rmqtt_codec::v5::Packet>;
472
473        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
474            let next = Pin::new(&mut self.io).poll_next(cx);
475            Poll::Ready(match futures::ready!(next) {
476                Some(Ok((MqttPacket::V5(packet), _))) => Some(Ok(packet)),
477                Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
478                Some(Err(e)) => Some(Err(Error::from(e))),
479                None => None,
480            })
481        }
482    }
483}
484
485#[inline]
486async fn send<Io>(io: &mut Framed<Io, MqttCodec>, packet: MqttPacket, send_timeout: Duration) -> Result<()>
487where
488    Io: AsyncWrite + Unpin,
489{
490    if send_timeout.is_zero() {
491        io.send(packet).await?;
492        Ok(())
493    } else {
494        match tokio::time::timeout(send_timeout, io.send(packet)).await {
495            Ok(Ok(())) => Ok(()),
496            Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
497            Err(_) => Err(MqttError::WriteTimeout),
498        }?;
499        Ok(())
500    }
501}
502
503#[inline]
504async fn flush<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
505where
506    Io: AsyncWrite + Unpin,
507{
508    if send_timeout.is_zero() {
509        io.flush().await?;
510        Ok(())
511    } else {
512        match tokio::time::timeout(send_timeout, io.flush()).await {
513            Ok(Ok(())) => Ok(()),
514            Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
515            Err(_) => Err(MqttError::FlushTimeout),
516        }?;
517        Ok(())
518    }
519}
520
521#[inline]
522async fn close<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
523where
524    Io: AsyncWrite + Unpin,
525{
526    if send_timeout.is_zero() {
527        io.close().await?;
528        Ok(())
529    } else {
530        match tokio::time::timeout(send_timeout, io.close()).await {
531            Ok(Ok(())) => Ok(()),
532            Ok(Err(e)) => Err(MqttError::Encode(e)),
533            Err(_) => Err(MqttError::CloseTimeout),
534        }?;
535        Ok(())
536    }
537}