rmqtt_net/
stream.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::anyhow;
6use futures::SinkExt;
7use futures::StreamExt;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10
11use rmqtt_codec::error::{DecodeError, SendPacketError};
12use rmqtt_codec::v3::Codec as CodecV3;
13use rmqtt_codec::v5::Codec as CodecV5;
14use rmqtt_codec::version::{ProtocolVersion, VersionCodec};
15use rmqtt_codec::{MqttCodec, MqttPacket};
16
17use crate::error::MqttError;
18use crate::CertInfo;
19use crate::{Builder, Result};
20
/// MQTT protocol dispatcher handling version negotiation.
///
/// Wraps a connection and manages initial protocol detection:
/// [`Dispatcher::mqtt`] reads the first frame to determine the protocol
/// version, then hands the connection over to a version-specific
/// [`MqttStream`].
pub struct Dispatcher<Io> {
    /// Framed IO layer with MQTT codec (initially the version-probing codec)
    pub(crate) io: Framed<Io, MqttCodec>,
    /// Remote client's network address
    pub remote_addr: SocketAddr,
    /// Shared configuration builder
    pub cfg: Arc<Builder>,
    /// TLS certificate information for this connection, if any.
    ///
    /// Forwarded to the version-specific stream only when the `tls` feature is enabled.
    pub cert_info: Option<CertInfo>,
}
34
35impl<Io> Dispatcher<Io>
36where
37    Io: AsyncRead + AsyncWrite + Unpin,
38{
39    /// Creates a new Dispatcher instance
40    pub(crate) fn new(
41        io: Io,
42        remote_addr: SocketAddr,
43        cert_info: Option<CertInfo>,
44        cfg: Arc<Builder>,
45    ) -> Self {
46        Dispatcher { io: Framed::new(io, MqttCodec::Version(VersionCodec)), remote_addr, cfg, cert_info }
47    }
48
49    /// Negotiates protocol version and returns appropriate stream
50    #[inline]
51    pub async fn mqtt(mut self) -> Result<MqttStream<Io>> {
52        Ok(match self.probe_version().await? {
53            ProtocolVersion::MQTT3 => MqttStream::V3(v3::MqttStream {
54                io: self.io,
55                remote_addr: self.remote_addr,
56                cfg: self.cfg,
57                #[cfg(feature = "tls")]
58                cert_info: self.cert_info,
59            }),
60            ProtocolVersion::MQTT5 => MqttStream::V5(v5::MqttStream {
61                io: self.io,
62                remote_addr: self.remote_addr,
63                cfg: self.cfg,
64                #[cfg(feature = "tls")]
65                cert_info: self.cert_info,
66            }),
67        })
68    }
69
70    /// Detects protocol version from initial handshake
71    #[inline]
72    async fn probe_version(&mut self) -> Result<ProtocolVersion> {
73        let Some(Ok((MqttPacket::Version(ver), _))) = self.io.next().await else {
74            return Err(anyhow!(DecodeError::InvalidProtocol));
75        };
76
77        let codec = match ver {
78            ProtocolVersion::MQTT3 => MqttCodec::V3(CodecV3::new(self.cfg.max_packet_size)),
79            ProtocolVersion::MQTT5 => {
80                MqttCodec::V5(CodecV5::new(self.cfg.max_packet_size, self.cfg.max_packet_size))
81            }
82        };
83
84        *self.io.codec_mut() = codec;
85        Ok(ver)
86    }
87}
88
/// Version-specific MQTT protocol streams.
///
/// Produced by [`Dispatcher::mqtt`] once the protocol version of the
/// connection has been detected.
pub enum MqttStream<Io> {
    /// MQTT v3.1.1 implementation
    V3(v3::MqttStream<Io>),
    /// MQTT v5.0 implementation
    V5(v5::MqttStream<Io>),
}
96
97pub mod v3 {
98
99    use std::net::SocketAddr;
100    use std::num::NonZeroU16;
101    use std::pin::Pin;
102    use std::sync::Arc;
103    use std::task::{Context, Poll};
104    use std::time::Duration;
105
106    use futures::StreamExt;
107    use tokio::io::{AsyncRead, AsyncWrite};
108    use tokio_util::codec::Framed;
109
110    use rmqtt_codec::error::DecodeError;
111    use rmqtt_codec::types::Publish;
112    use rmqtt_codec::v3::{Connect, ConnectAckReason, Packet as PacketV3, Packet};
113    use rmqtt_codec::{MqttCodec, MqttPacket};
114
115    use crate::error::MqttError;
116    use crate::{Builder, Error, Result};
117
118    #[cfg(feature = "tls")]
119    use crate::CertInfo;
120
121    /// MQTT v3.1.1 protocol stream implementation
122    pub struct MqttStream<Io> {
123        /// Framed IO layer with MQTT codec
124        pub io: Framed<Io, MqttCodec>,
125        /// Remote client's network address
126        pub remote_addr: SocketAddr,
127        /// Shared configuration builder
128        pub cfg: Arc<Builder>,
129        #[cfg(feature = "tls")]
130        /// TLS certificate information (if available)
131        pub cert_info: Option<CertInfo>,
132    }
133
134    /// # Examples
135    /// ```
136    /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
137    /// use std::sync::Arc;
138    /// use tokio::net::TcpStream;
139    /// use tokio_util::codec::Framed;
140    /// use rmqtt_codec::{MqttCodec, types::Publish};
141    /// use rmqtt_net::{Builder,v3};
142    ///
143    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
144    /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1883);
145    /// let stream = TcpStream::connect(addr).await?;
146    /// let mut mqtt_stream = v3::MqttStream {
147    ///     io: Framed::new(stream, MqttCodec::V3(Default::default())),
148    ///     remote_addr: addr,
149    ///     cfg: Arc::new(Builder::default()),
150    ///     cert_info: None,
151    /// };
152    ///
153    /// // Send a PING request
154    /// mqtt_stream.send_ping_request().await?;
155    /// # Ok(())
156    /// # }
157    /// ```
158    impl<Io> MqttStream<Io>
159    where
160        Io: AsyncRead + AsyncWrite + Unpin,
161    {
162        /// Sends DISCONNECT packet and flushes buffers
163        #[inline]
164        pub async fn send_disconnect(&mut self) -> Result<()> {
165            self.send(PacketV3::Disconnect).await?;
166            self.flush().await
167        }
168
169        /// Publishes a message to the broker
170        #[inline]
171        pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
172            self.send(PacketV3::Publish(publish)).await
173        }
174
175        /// Acknowledges a received publish (QoS 1)
176        #[inline]
177        pub async fn send_publish_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
178            self.send(PacketV3::PublishAck { packet_id }).await
179        }
180
181        /// Confirms receipt of a publish (QoS 2 step 1)
182        #[inline]
183        pub async fn send_publish_received(&mut self, packet_id: NonZeroU16) -> Result<()> {
184            self.send(PacketV3::PublishReceived { packet_id }).await
185        }
186
187        /// Releases a stored publish (QoS 2 step 2)
188        #[inline]
189        pub async fn send_publish_release(&mut self, packet_id: NonZeroU16) -> Result<()> {
190            self.send(PacketV3::PublishRelease { packet_id }).await
191        }
192
193        /// Confirms publish completion (QoS 2 step 3)
194        #[inline]
195        pub async fn send_publish_complete(&mut self, packet_id: NonZeroU16) -> Result<()> {
196            self.send(PacketV3::PublishComplete { packet_id }).await
197        }
198
199        /// Acknowledges a subscription request
200        #[inline]
201        pub async fn send_subscribe_ack(
202            &mut self,
203            packet_id: NonZeroU16,
204            status: Vec<rmqtt_codec::v3::SubscribeReturnCode>,
205        ) -> Result<()> {
206            self.send(PacketV3::SubscribeAck { packet_id, status }).await
207        }
208
209        /// Acknowledges an unsubscribe request
210        #[inline]
211        pub async fn send_unsubscribe_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
212            self.send(PacketV3::UnsubscribeAck { packet_id }).await
213        }
214
215        /// Initiates connection to the broker
216        #[inline]
217        pub async fn send_connect(&mut self, connect: rmqtt_codec::v3::Connect) -> Result<()> {
218            self.send(PacketV3::Connect(Box::new(connect))).await
219        }
220
221        /// Responds to connection request
222        #[inline]
223        pub async fn send_connect_ack(
224            &mut self,
225            return_code: ConnectAckReason,
226            session_present: bool,
227        ) -> Result<()> {
228            self.send(PacketV3::ConnectAck(rmqtt_codec::v3::ConnectAck { session_present, return_code }))
229                .await
230        }
231
232        /// Sends keep-alive ping request
233        #[inline]
234        pub async fn send_ping_request(&mut self) -> Result<()> {
235            self.send(PacketV3::PingRequest {}).await
236        }
237
238        /// Responds to ping request
239        #[inline]
240        pub async fn send_ping_response(&mut self) -> Result<()> {
241            self.send(PacketV3::PingResponse {}).await
242        }
243
244        /// Generic packet sending method
245        #[inline]
246        pub async fn send(&mut self, packet: rmqtt_codec::v3::Packet) -> Result<()> {
247            super::send(&mut self.io, MqttPacket::V3(packet), self.cfg.send_timeout).await
248        }
249
250        /// Flushes write buffers
251        #[inline]
252        pub async fn flush(&mut self) -> Result<()> {
253            super::flush(&mut self.io, self.cfg.send_timeout).await
254        }
255
256        /// Closes the connection gracefully
257        #[inline]
258        pub async fn close(&mut self) -> Result<()> {
259            super::close(&mut self.io, self.cfg.send_timeout).await
260        }
261
262        /// Receives next packet with timeout
263        #[inline]
264        pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v3::Packet>> {
265            match tokio::time::timeout(tm, self.next()).await {
266                Ok(Some(Ok(msg))) => Ok(Some(msg)),
267                Ok(Some(Err(e))) => Err(e),
268                Ok(None) => Ok(None),
269                Err(_) => Err(MqttError::ReadTimeout.into()),
270            }
271        }
272
273        /// Waits for CONNECT packet with timeout
274        #[inline]
275        pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
276            let connect = match self.recv(tm).await {
277                Ok(Some(Packet::Connect(mut connect))) => {
278                    #[cfg(feature = "tls")]
279                    {
280                        if self.cfg.cert_cn_as_username {
281                            if let Some(cert) = &self.cert_info {
282                                if let Some(cn) = &cert.common_name {
283                                    connect.username = Some(cn.clone().into());
284                                }
285                            }
286                        }
287                    }
288                    connect
289                }
290                Err(e) => {
291                    return Err(e);
292                }
293                _ => {
294                    return Err(MqttError::InvalidProtocol.into());
295                }
296            };
297            Ok(connect)
298        }
299    }
300
301    impl<Io> futures::Stream for MqttStream<Io>
302    where
303        Io: AsyncRead + Unpin,
304    {
305        type Item = Result<rmqtt_codec::v3::Packet>;
306
307        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
308            let next = Pin::new(&mut self.io).poll_next(cx);
309            Poll::Ready(match futures::ready!(next) {
310                Some(Ok((MqttPacket::V3(packet), _))) => Some(Ok(packet)),
311                Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
312                Some(Err(e)) => Some(Err(Error::from(e))),
313                None => None,
314            })
315        }
316    }
317}
318
319pub mod v5 {
320    use std::net::SocketAddr;
321    use std::pin::Pin;
322    use std::sync::Arc;
323    use std::task::{Context, Poll};
324    use std::time::Duration;
325
326    use futures::StreamExt;
327    use tokio::io::{AsyncRead, AsyncWrite};
328    use tokio_util::codec::Framed;
329
330    use rmqtt_codec::error::DecodeError;
331    use rmqtt_codec::types::Publish;
332    use rmqtt_codec::v5::{Auth, Connect, Disconnect, Packet as PacketV5, Packet};
333    use rmqtt_codec::{MqttCodec, MqttPacket};
334
335    use crate::error::MqttError;
336    use crate::{Builder, CertInfo, Error, Result};
337
338    /// MQTT v5.0 protocol stream implementation
339    pub struct MqttStream<Io> {
340        /// Framed IO layer with MQTT codec
341        pub io: Framed<Io, MqttCodec>,
342        /// Remote client's network address
343        pub remote_addr: SocketAddr,
344        /// Shared configuration builder
345        pub cfg: Arc<Builder>,
346        #[cfg(feature = "tls")]
347        /// TLS certificate information (if available)
348        pub cert_info: Option<CertInfo>,
349    }
350
351    /// # Examples
352    /// ```
353    /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
354    /// use std::sync::Arc;
355    /// use tokio::net::TcpStream;
356    /// use tokio_util::codec::Framed;
357    /// use rmqtt_codec::{MqttCodec, types::Publish};
358    /// use rmqtt_net::{Builder,v5};
359    /// use rmqtt_codec::v5::Connect;
360    ///
361    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
362    /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1883);
363    /// let stream = TcpStream::connect(addr).await?;
364    /// let mut mqtt_stream = v5::MqttStream {
365    ///     io: Framed::new(stream, MqttCodec::V5(Default::default())),
366    ///     remote_addr: addr,
367    ///     cfg: Arc::new(Builder::default()),
368    ///     cert_info: None
369    /// };
370    ///
371    /// // Send authentication packet
372    /// mqtt_stream.send_auth(rmqtt_codec::v5::Auth::default()).await?;
373    /// # Ok(())
374    /// # }
375    /// ```
376    impl<Io> MqttStream<Io>
377    where
378        Io: AsyncRead + AsyncWrite + Unpin,
379    {
380        /// Sends DISCONNECT packet with reason code
381        #[inline]
382        pub async fn send_disconnect(&mut self, disc: Disconnect) -> Result<()> {
383            self.send(PacketV5::Disconnect(disc)).await?;
384            self.flush().await?;
385            tokio::time::sleep(Duration::from_millis(500)).await;
386            Ok(())
387        }
388
389        /// Publishes a message to the broker
390        #[inline]
391        pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
392            self.send(PacketV5::Publish(publish)).await
393        }
394
395        /// Acknowledges a received publish (QoS 1)
396        #[inline]
397        pub async fn send_publish_ack(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
398            self.send(PacketV5::PublishAck(ack)).await
399        }
400
401        /// Confirms receipt of a publish (QoS 2 step 1)
402        #[inline]
403        pub async fn send_publish_received(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
404            self.send(PacketV5::PublishReceived(ack)).await
405        }
406
407        /// Releases a stored publish (QoS 2 step 2)
408        #[inline]
409        pub async fn send_publish_release(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
410            self.send(PacketV5::PublishRelease(ack2)).await
411        }
412
413        /// Confirms publish completion (QoS 2 step 3)
414        #[inline]
415        pub async fn send_publish_complete(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
416            self.send(PacketV5::PublishComplete(ack2)).await
417        }
418
419        /// Acknowledges a subscription request
420        #[inline]
421        pub async fn send_subscribe_ack(&mut self, ack: rmqtt_codec::v5::SubscribeAck) -> Result<()> {
422            self.send(PacketV5::SubscribeAck(ack)).await
423        }
424
425        /// Acknowledges an unsubscribe request
426        #[inline]
427        pub async fn send_unsubscribe_ack(&mut self, unack: rmqtt_codec::v5::UnsubscribeAck) -> Result<()> {
428            self.send(PacketV5::UnsubscribeAck(unack)).await
429        }
430
431        /// Initiates connection to the broker
432        #[inline]
433        pub async fn send_connect(&mut self, connect: rmqtt_codec::v5::Connect) -> Result<()> {
434            self.send(PacketV5::Connect(Box::new(connect))).await
435        }
436
437        /// Responds to connection request
438        #[inline]
439        pub async fn send_connect_ack(&mut self, ack: rmqtt_codec::v5::ConnectAck) -> Result<()> {
440            self.send(PacketV5::ConnectAck(Box::new(ack))).await
441        }
442
443        /// Sends keep-alive ping request
444        #[inline]
445        pub async fn send_ping_request(&mut self) -> Result<()> {
446            self.send(PacketV5::PingRequest {}).await
447        }
448
449        /// Responds to ping request
450        #[inline]
451        pub async fn send_ping_response(&mut self) -> Result<()> {
452            self.send(PacketV5::PingResponse {}).await
453        }
454
455        /// Sends authentication exchange packet
456        #[inline]
457        pub async fn send_auth(&mut self, auth: Auth) -> Result<()> {
458            self.send(PacketV5::Auth(auth)).await
459        }
460
461        /// Generic packet sending method
462        #[inline]
463        pub async fn send(&mut self, packet: rmqtt_codec::v5::Packet) -> Result<()> {
464            super::send(&mut self.io, MqttPacket::V5(packet), self.cfg.send_timeout).await
465        }
466
467        /// Flushes write buffers
468        #[inline]
469        pub async fn flush(&mut self) -> Result<()> {
470            super::flush(&mut self.io, self.cfg.send_timeout).await
471        }
472
473        /// Closes the connection gracefully
474        #[inline]
475        pub async fn close(&mut self) -> Result<()> {
476            super::close(&mut self.io, self.cfg.send_timeout).await
477        }
478
479        /// Receives next packet with timeout
480        #[inline]
481        pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v5::Packet>> {
482            match tokio::time::timeout(tm, self.next()).await {
483                Ok(Some(Ok(msg))) => Ok(Some(msg)),
484                Ok(Some(Err(e))) => Err(e),
485                Ok(None) => Ok(None),
486                Err(_) => Err(MqttError::ReadTimeout.into()),
487            }
488        }
489
490        /// Waits for CONNECT packet with timeout
491        #[inline]
492        pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
493            let connect = match self.recv(tm).await {
494                Ok(Some(Packet::Connect(connect))) => connect,
495                Err(e) => {
496                    return Err(e);
497                }
498                _ => {
499                    return Err(MqttError::InvalidProtocol.into());
500                }
501            };
502            Ok(connect)
503        }
504    }
505
506    impl<Io> futures::Stream for MqttStream<Io>
507    where
508        Io: AsyncRead + Unpin,
509    {
510        type Item = Result<rmqtt_codec::v5::Packet>;
511
512        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
513            let next = Pin::new(&mut self.io).poll_next(cx);
514            Poll::Ready(match futures::ready!(next) {
515                Some(Ok((MqttPacket::V5(packet), _))) => Some(Ok(packet)),
516                Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
517                Some(Err(e)) => Some(Err(Error::from(e))),
518                None => None,
519            })
520        }
521    }
522}
523
524#[inline]
525async fn send<Io>(io: &mut Framed<Io, MqttCodec>, packet: MqttPacket, send_timeout: Duration) -> Result<()>
526where
527    Io: AsyncWrite + Unpin,
528{
529    if send_timeout.is_zero() {
530        io.send(packet).await?;
531        Ok(())
532    } else {
533        match tokio::time::timeout(send_timeout, io.send(packet)).await {
534            Ok(Ok(())) => Ok(()),
535            Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
536            Err(_) => Err(MqttError::WriteTimeout),
537        }?;
538        Ok(())
539    }
540}
541
542#[inline]
543async fn flush<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
544where
545    Io: AsyncWrite + Unpin,
546{
547    if send_timeout.is_zero() {
548        io.flush().await?;
549        Ok(())
550    } else {
551        match tokio::time::timeout(send_timeout, io.flush()).await {
552            Ok(Ok(())) => Ok(()),
553            Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
554            Err(_) => Err(MqttError::FlushTimeout),
555        }?;
556        Ok(())
557    }
558}
559
560#[inline]
561async fn close<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
562where
563    Io: AsyncWrite + Unpin,
564{
565    if send_timeout.is_zero() {
566        io.close().await?;
567        Ok(())
568    } else {
569        match tokio::time::timeout(send_timeout, io.close()).await {
570            Ok(Ok(())) => Ok(()),
571            Ok(Err(e)) => Err(MqttError::Encode(e)),
572            Err(_) => Err(MqttError::CloseTimeout),
573        }?;
574        Ok(())
575    }
576}