rmqtt_net/
stream.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::anyhow;
6use futures::SinkExt;
7use futures::StreamExt;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10
11use rmqtt_codec::error::{DecodeError, SendPacketError};
12use rmqtt_codec::v3::Codec as CodecV3;
13use rmqtt_codec::v5::Codec as CodecV5;
14use rmqtt_codec::version::{ProtocolVersion, VersionCodec};
15use rmqtt_codec::{MqttCodec, MqttPacket};
16
17use crate::error::MqttError;
18use crate::CertInfo;
19use crate::{Builder, Result};
20
21/// MQTT protocol dispatcher handling version negotiation
22///
23/// Manages initial protocol detection and creates version-specific streams
24pub struct Dispatcher<Io> {
25    /// Framed IO layer with MQTT codec
26    pub(crate) io: Framed<Io, MqttCodec>,
27    /// Remote client's network address
28    pub remote_addr: SocketAddr,
29    /// Shared configuration builder
30    pub cfg: Arc<Builder>,
31
32    pub cert_info: Option<CertInfo>,
33}
34
35impl<Io> Dispatcher<Io>
36where
37    Io: AsyncRead + AsyncWrite + Unpin,
38{
39    /// Creates a new Dispatcher instance
40    pub(crate) fn new(
41        io: Io,
42        remote_addr: SocketAddr,
43        cert_info: Option<CertInfo>,
44        cfg: Arc<Builder>,
45    ) -> Self {
46        Dispatcher { io: Framed::new(io, MqttCodec::Version(VersionCodec)), remote_addr, cfg, cert_info }
47    }
48
49    /// Negotiates protocol version and returns appropriate stream
50    #[inline]
51    pub async fn mqtt(mut self) -> Result<MqttStream<Io>> {
52        Ok(match self.probe_version().await? {
53            ProtocolVersion::MQTT3 => MqttStream::V3(v3::MqttStream {
54                io: self.io,
55                remote_addr: self.remote_addr,
56                cfg: self.cfg,
57                #[cfg(feature = "tls")]
58                cert_info: self.cert_info,
59            }),
60            ProtocolVersion::MQTT5 => MqttStream::V5(v5::MqttStream {
61                io: self.io,
62                remote_addr: self.remote_addr,
63                cfg: self.cfg,
64                #[cfg(feature = "tls")]
65                cert_info: self.cert_info,
66            }),
67        })
68    }
69
70    /// Detects protocol version from initial handshake
71    #[inline]
72    async fn probe_version(&mut self) -> Result<ProtocolVersion> {
73        let Some(Ok((MqttPacket::Version(ver), _))) = self.io.next().await else {
74            return Err(anyhow!(DecodeError::InvalidProtocol));
75        };
76
77        let codec = match ver {
78            ProtocolVersion::MQTT3 => MqttCodec::V3(CodecV3::new(self.cfg.max_packet_size)),
79            ProtocolVersion::MQTT5 => {
80                MqttCodec::V5(CodecV5::new(self.cfg.max_packet_size, self.cfg.max_packet_size))
81            }
82        };
83
84        *self.io.codec_mut() = codec;
85        Ok(ver)
86    }
87}
88
89/// Version-specific MQTT protocol streams
90pub enum MqttStream<Io> {
91    /// MQTT v3.1.1 implementation
92    V3(v3::MqttStream<Io>),
93    /// MQTT v5.0 implementation
94    V5(v5::MqttStream<Io>),
95}
96
97pub mod v3 {
98
99    use std::net::SocketAddr;
100    use std::num::NonZeroU16;
101    use std::pin::Pin;
102    use std::sync::Arc;
103    use std::task::{Context, Poll};
104    use std::time::Duration;
105
106    use futures::StreamExt;
107    use tokio::io::{AsyncRead, AsyncWrite};
108    use tokio_util::codec::Framed;
109
110    use rmqtt_codec::error::DecodeError;
111    use rmqtt_codec::types::Publish;
112    use rmqtt_codec::v3::{Connect, ConnectAckReason, Packet as PacketV3, Packet};
113    use rmqtt_codec::{MqttCodec, MqttPacket};
114
115    use crate::error::MqttError;
116    use crate::{Builder, Error, Result};
117
118    #[cfg(feature = "tls")]
119    use crate::CertInfo;
120
121    /// MQTT v3.1.1 protocol stream implementation
122    pub struct MqttStream<Io> {
123        /// Framed IO layer with MQTT codec
124        pub io: Framed<Io, MqttCodec>,
125        /// Remote client's network address
126        pub remote_addr: SocketAddr,
127        /// Shared configuration builder
128        pub cfg: Arc<Builder>,
129        #[cfg(feature = "tls")]
130        /// TLS certificate information (if available)
131        pub cert_info: Option<CertInfo>,
132    }
133
134    /// # Examples
135    /// ```
136    /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
137    /// use std::sync::Arc;
138    /// use tokio::net::TcpStream;
139    /// use tokio_util::codec::Framed;
140    /// use rmqtt_codec::{MqttCodec, types::Publish};
141    /// use rmqtt_net::{Builder,v3};
142    ///
143    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
144    /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1883);
145    /// let stream = TcpStream::connect(addr).await?;
146    /// let mut mqtt_stream = v3::MqttStream {
147    ///     io: Framed::new(stream, MqttCodec::V3(Default::default())),
148    ///     remote_addr: addr,
149    ///     cfg: Arc::new(Builder::default()),
150    ///     cert_info: None,
151    /// };
152    ///
153    /// // Send a PING request
154    /// mqtt_stream.send_ping_request().await?;
155    /// # Ok(())
156    /// # }
157    /// ```
158    impl<Io> MqttStream<Io>
159    where
160        Io: AsyncRead + AsyncWrite + Unpin,
161    {
162        /// Sends DISCONNECT packet and flushes buffers
163        #[inline]
164        pub async fn send_disconnect(&mut self) -> Result<()> {
165            self.send(PacketV3::Disconnect).await?;
166            self.flush().await
167        }
168
169        /// Publishes a message to the broker
170        #[inline]
171        pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
172            self.send(PacketV3::Publish(publish)).await
173        }
174
175        /// Acknowledges a received publish (QoS 1)
176        #[inline]
177        pub async fn send_publish_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
178            self.send(PacketV3::PublishAck { packet_id }).await
179        }
180
181        /// Confirms receipt of a publish (QoS 2 step 1)
182        #[inline]
183        pub async fn send_publish_received(&mut self, packet_id: NonZeroU16) -> Result<()> {
184            self.send(PacketV3::PublishReceived { packet_id }).await
185        }
186
187        /// Releases a stored publish (QoS 2 step 2)
188        #[inline]
189        pub async fn send_publish_release(&mut self, packet_id: NonZeroU16) -> Result<()> {
190            self.send(PacketV3::PublishRelease { packet_id }).await
191        }
192
193        /// Confirms publish completion (QoS 2 step 3)
194        #[inline]
195        pub async fn send_publish_complete(&mut self, packet_id: NonZeroU16) -> Result<()> {
196            self.send(PacketV3::PublishComplete { packet_id }).await
197        }
198
199        /// Acknowledges a subscription request
200        #[inline]
201        pub async fn send_subscribe_ack(
202            &mut self,
203            packet_id: NonZeroU16,
204            status: Vec<rmqtt_codec::v3::SubscribeReturnCode>,
205        ) -> Result<()> {
206            self.send(PacketV3::SubscribeAck { packet_id, status }).await
207        }
208
209        /// Acknowledges an unsubscribe request
210        #[inline]
211        pub async fn send_unsubscribe_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
212            self.send(PacketV3::UnsubscribeAck { packet_id }).await
213        }
214
215        /// Initiates connection to the broker
216        #[inline]
217        pub async fn send_connect(&mut self, connect: rmqtt_codec::v3::Connect) -> Result<()> {
218            self.send(PacketV3::Connect(Box::new(connect))).await
219        }
220
221        /// Responds to connection request
222        #[inline]
223        pub async fn send_connect_ack(
224            &mut self,
225            return_code: ConnectAckReason,
226            session_present: bool,
227        ) -> Result<()> {
228            self.send(PacketV3::ConnectAck(rmqtt_codec::v3::ConnectAck { session_present, return_code }))
229                .await
230        }
231
232        /// Sends keep-alive ping request
233        #[inline]
234        pub async fn send_ping_request(&mut self) -> Result<()> {
235            self.send(PacketV3::PingRequest {}).await
236        }
237
238        /// Responds to ping request
239        #[inline]
240        pub async fn send_ping_response(&mut self) -> Result<()> {
241            self.send(PacketV3::PingResponse {}).await
242        }
243
244        /// Generic packet sending method
245        #[inline]
246        pub async fn send(&mut self, packet: rmqtt_codec::v3::Packet) -> Result<()> {
247            super::send(&mut self.io, MqttPacket::V3(packet), self.cfg.send_timeout).await
248        }
249
250        /// Flushes write buffers
251        #[inline]
252        pub async fn flush(&mut self) -> Result<()> {
253            super::flush(&mut self.io, self.cfg.send_timeout).await
254        }
255
256        /// Closes the connection gracefully
257        #[inline]
258        pub async fn close(&mut self) -> Result<()> {
259            super::close(&mut self.io, self.cfg.send_timeout).await
260        }
261
262        /// Receives next packet with timeout
263        #[inline]
264        pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v3::Packet>> {
265            match tokio::time::timeout(tm, self.next()).await {
266                Ok(Some(Ok(msg))) => Ok(Some(msg)),
267                Ok(Some(Err(e))) => Err(e),
268                Ok(None) => Ok(None),
269                Err(_) => Err(MqttError::ReadTimeout.into()),
270            }
271        }
272
273        /// Waits for CONNECT packet with timeout
274        #[inline]
275        pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
276            let connect = match self.recv(tm).await {
277                #[allow(unused_mut)]
278                Ok(Some(Packet::Connect(mut connect))) => {
279                    #[cfg(feature = "tls")]
280                    {
281                        if self.cfg.cert_cn_as_username {
282                            if let Some(cert) = &self.cert_info {
283                                if let Some(cn) = &cert.common_name {
284                                    connect.username = Some(cn.clone().into());
285                                }
286                            }
287                        }
288                    }
289                    connect
290                }
291                Err(e) => {
292                    return Err(e);
293                }
294                _ => {
295                    return Err(MqttError::InvalidProtocol.into());
296                }
297            };
298            Ok(connect)
299        }
300    }
301
302    impl<Io> futures::Stream for MqttStream<Io>
303    where
304        Io: AsyncRead + Unpin,
305    {
306        type Item = Result<rmqtt_codec::v3::Packet>;
307
308        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309            let next = Pin::new(&mut self.io).poll_next(cx);
310            Poll::Ready(match futures::ready!(next) {
311                Some(Ok((MqttPacket::V3(packet), _))) => Some(Ok(packet)),
312                Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
313                Some(Err(e)) => Some(Err(Error::from(e))),
314                None => None,
315            })
316        }
317    }
318}
319
320pub mod v5 {
321    use std::net::SocketAddr;
322    use std::pin::Pin;
323    use std::sync::Arc;
324    use std::task::{Context, Poll};
325    use std::time::Duration;
326
327    use futures::StreamExt;
328    use tokio::io::{AsyncRead, AsyncWrite};
329    use tokio_util::codec::Framed;
330
331    use rmqtt_codec::error::DecodeError;
332    use rmqtt_codec::types::Publish;
333    use rmqtt_codec::v5::{Auth, Connect, Disconnect, Packet as PacketV5, Packet};
334    use rmqtt_codec::{MqttCodec, MqttPacket};
335
336    use crate::error::MqttError;
337    #[cfg(feature = "tls")]
338    use crate::CertInfo;
339    use crate::{Builder, Error, Result};
340
341    /// MQTT v5.0 protocol stream implementation
342    pub struct MqttStream<Io> {
343        /// Framed IO layer with MQTT codec
344        pub io: Framed<Io, MqttCodec>,
345        /// Remote client's network address
346        pub remote_addr: SocketAddr,
347        /// Shared configuration builder
348        pub cfg: Arc<Builder>,
349        #[cfg(feature = "tls")]
350        /// TLS certificate information (if available)
351        pub cert_info: Option<CertInfo>,
352    }
353
354    /// # Examples
355    /// ```
356    /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
357    /// use std::sync::Arc;
358    /// use tokio::net::TcpStream;
359    /// use tokio_util::codec::Framed;
360    /// use rmqtt_codec::{MqttCodec, types::Publish};
361    /// use rmqtt_net::{Builder,v5};
362    /// use rmqtt_codec::v5::Connect;
363    ///
364    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
365    /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1883);
366    /// let stream = TcpStream::connect(addr).await?;
367    /// let mut mqtt_stream = v5::MqttStream {
368    ///     io: Framed::new(stream, MqttCodec::V5(Default::default())),
369    ///     remote_addr: addr,
370    ///     cfg: Arc::new(Builder::default()),
371    ///     cert_info: None
372    /// };
373    ///
374    /// // Send authentication packet
375    /// mqtt_stream.send_auth(rmqtt_codec::v5::Auth::default()).await?;
376    /// # Ok(())
377    /// # }
378    /// ```
379    impl<Io> MqttStream<Io>
380    where
381        Io: AsyncRead + AsyncWrite + Unpin,
382    {
383        /// Sends DISCONNECT packet with reason code
384        #[inline]
385        pub async fn send_disconnect(&mut self, disc: Disconnect) -> Result<()> {
386            self.send(PacketV5::Disconnect(disc)).await?;
387            self.flush().await?;
388            tokio::time::sleep(Duration::from_millis(500)).await;
389            Ok(())
390        }
391
392        /// Publishes a message to the broker
393        #[inline]
394        pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
395            self.send(PacketV5::Publish(publish)).await
396        }
397
398        /// Acknowledges a received publish (QoS 1)
399        #[inline]
400        pub async fn send_publish_ack(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
401            self.send(PacketV5::PublishAck(ack)).await
402        }
403
404        /// Confirms receipt of a publish (QoS 2 step 1)
405        #[inline]
406        pub async fn send_publish_received(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
407            self.send(PacketV5::PublishReceived(ack)).await
408        }
409
410        /// Releases a stored publish (QoS 2 step 2)
411        #[inline]
412        pub async fn send_publish_release(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
413            self.send(PacketV5::PublishRelease(ack2)).await
414        }
415
416        /// Confirms publish completion (QoS 2 step 3)
417        #[inline]
418        pub async fn send_publish_complete(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
419            self.send(PacketV5::PublishComplete(ack2)).await
420        }
421
422        /// Acknowledges a subscription request
423        #[inline]
424        pub async fn send_subscribe_ack(&mut self, ack: rmqtt_codec::v5::SubscribeAck) -> Result<()> {
425            self.send(PacketV5::SubscribeAck(ack)).await
426        }
427
428        /// Acknowledges an unsubscribe request
429        #[inline]
430        pub async fn send_unsubscribe_ack(&mut self, unack: rmqtt_codec::v5::UnsubscribeAck) -> Result<()> {
431            self.send(PacketV5::UnsubscribeAck(unack)).await
432        }
433
434        /// Initiates connection to the broker
435        #[inline]
436        pub async fn send_connect(&mut self, connect: rmqtt_codec::v5::Connect) -> Result<()> {
437            self.send(PacketV5::Connect(Box::new(connect))).await
438        }
439
440        /// Responds to connection request
441        #[inline]
442        pub async fn send_connect_ack(&mut self, ack: rmqtt_codec::v5::ConnectAck) -> Result<()> {
443            self.send(PacketV5::ConnectAck(Box::new(ack))).await
444        }
445
446        /// Sends keep-alive ping request
447        #[inline]
448        pub async fn send_ping_request(&mut self) -> Result<()> {
449            self.send(PacketV5::PingRequest {}).await
450        }
451
452        /// Responds to ping request
453        #[inline]
454        pub async fn send_ping_response(&mut self) -> Result<()> {
455            self.send(PacketV5::PingResponse {}).await
456        }
457
458        /// Sends authentication exchange packet
459        #[inline]
460        pub async fn send_auth(&mut self, auth: Auth) -> Result<()> {
461            self.send(PacketV5::Auth(auth)).await
462        }
463
464        /// Generic packet sending method
465        #[inline]
466        pub async fn send(&mut self, packet: rmqtt_codec::v5::Packet) -> Result<()> {
467            super::send(&mut self.io, MqttPacket::V5(packet), self.cfg.send_timeout).await
468        }
469
470        /// Flushes write buffers
471        #[inline]
472        pub async fn flush(&mut self) -> Result<()> {
473            super::flush(&mut self.io, self.cfg.send_timeout).await
474        }
475
476        /// Closes the connection gracefully
477        #[inline]
478        pub async fn close(&mut self) -> Result<()> {
479            super::close(&mut self.io, self.cfg.send_timeout).await
480        }
481
482        /// Receives next packet with timeout
483        #[inline]
484        pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v5::Packet>> {
485            match tokio::time::timeout(tm, self.next()).await {
486                Ok(Some(Ok(msg))) => Ok(Some(msg)),
487                Ok(Some(Err(e))) => Err(e),
488                Ok(None) => Ok(None),
489                Err(_) => Err(MqttError::ReadTimeout.into()),
490            }
491        }
492
493        /// Waits for CONNECT packet with timeout
494        #[inline]
495        pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
496            let connect = match self.recv(tm).await {
497                Ok(Some(Packet::Connect(connect))) => connect,
498                Err(e) => {
499                    return Err(e);
500                }
501                _ => {
502                    return Err(MqttError::InvalidProtocol.into());
503                }
504            };
505            Ok(connect)
506        }
507    }
508
509    impl<Io> futures::Stream for MqttStream<Io>
510    where
511        Io: AsyncRead + Unpin,
512    {
513        type Item = Result<rmqtt_codec::v5::Packet>;
514
515        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
516            let next = Pin::new(&mut self.io).poll_next(cx);
517            Poll::Ready(match futures::ready!(next) {
518                Some(Ok((MqttPacket::V5(packet), _))) => Some(Ok(packet)),
519                Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
520                Some(Err(e)) => Some(Err(Error::from(e))),
521                None => None,
522            })
523        }
524    }
525}
526
527#[inline]
528async fn send<Io>(io: &mut Framed<Io, MqttCodec>, packet: MqttPacket, send_timeout: Duration) -> Result<()>
529where
530    Io: AsyncWrite + Unpin,
531{
532    if send_timeout.is_zero() {
533        io.send(packet).await?;
534        Ok(())
535    } else {
536        match tokio::time::timeout(send_timeout, io.send(packet)).await {
537            Ok(Ok(())) => Ok(()),
538            Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
539            Err(_) => Err(MqttError::WriteTimeout),
540        }?;
541        Ok(())
542    }
543}
544
545#[inline]
546async fn flush<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
547where
548    Io: AsyncWrite + Unpin,
549{
550    if send_timeout.is_zero() {
551        io.flush().await?;
552        Ok(())
553    } else {
554        match tokio::time::timeout(send_timeout, io.flush()).await {
555            Ok(Ok(())) => Ok(()),
556            Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
557            Err(_) => Err(MqttError::FlushTimeout),
558        }?;
559        Ok(())
560    }
561}
562
563#[inline]
564async fn close<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
565where
566    Io: AsyncWrite + Unpin,
567{
568    if send_timeout.is_zero() {
569        io.close().await?;
570        Ok(())
571    } else {
572        match tokio::time::timeout(send_timeout, io.close()).await {
573            Ok(Ok(())) => Ok(()),
574            Ok(Err(e)) => Err(MqttError::Encode(e)),
575            Err(_) => Err(MqttError::CloseTimeout),
576        }?;
577        Ok(())
578    }
579}