1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::anyhow;
6use futures::SinkExt;
7use futures::StreamExt;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10
11use rmqtt_codec::error::{DecodeError, SendPacketError};
12use rmqtt_codec::v3::Codec as CodecV3;
13use rmqtt_codec::v5::Codec as CodecV5;
14use rmqtt_codec::version::{ProtocolVersion, VersionCodec};
15use rmqtt_codec::{MqttCodec, MqttPacket};
16
17use crate::error::MqttError;
18use crate::CertInfo;
19use crate::{Builder, Result};
20
/// Holds a newly accepted connection until its MQTT protocol version is known.
///
/// The transport is wrapped in a [`Framed`] whose codec starts out in version-detection
/// mode (see `Dispatcher::new`); `Dispatcher::mqtt` consumes the dispatcher and hands the
/// fields over to a version-specific stream.
pub struct Dispatcher<Io> {
    /// Framed transport carrying the MQTT codec.
    pub(crate) io: Framed<Io, MqttCodec>,
    /// Socket address recorded for the remote peer.
    pub remote_addr: SocketAddr,
    /// Shared configuration (supplies `max_packet_size` when the codec is selected).
    pub cfg: Arc<Builder>,

    /// Optional certificate information for the peer; forwarded to the version-specific
    /// stream only when the `tls` feature is enabled.
    pub cert_info: Option<CertInfo>,
}
34
35impl<Io> Dispatcher<Io>
36where
37 Io: AsyncRead + AsyncWrite + Unpin,
38{
39 pub(crate) fn new(
41 io: Io,
42 remote_addr: SocketAddr,
43 cert_info: Option<CertInfo>,
44 cfg: Arc<Builder>,
45 ) -> Self {
46 Dispatcher { io: Framed::new(io, MqttCodec::Version(VersionCodec)), remote_addr, cfg, cert_info }
47 }
48
49 #[inline]
51 pub async fn mqtt(mut self) -> Result<MqttStream<Io>> {
52 Ok(match self.probe_version().await? {
53 ProtocolVersion::MQTT3 => MqttStream::V3(v3::MqttStream {
54 io: self.io,
55 remote_addr: self.remote_addr,
56 cfg: self.cfg,
57 #[cfg(feature = "tls")]
58 cert_info: self.cert_info,
59 }),
60 ProtocolVersion::MQTT5 => MqttStream::V5(v5::MqttStream {
61 io: self.io,
62 remote_addr: self.remote_addr,
63 cfg: self.cfg,
64 #[cfg(feature = "tls")]
65 cert_info: self.cert_info,
66 }),
67 })
68 }
69
70 #[inline]
72 async fn probe_version(&mut self) -> Result<ProtocolVersion> {
73 let Some(Ok((MqttPacket::Version(ver), _))) = self.io.next().await else {
74 return Err(anyhow!(DecodeError::InvalidProtocol));
75 };
76
77 let codec = match ver {
78 ProtocolVersion::MQTT3 => MqttCodec::V3(CodecV3::new(self.cfg.max_packet_size)),
79 ProtocolVersion::MQTT5 => {
80 MqttCodec::V5(CodecV5::new(self.cfg.max_packet_size, self.cfg.max_packet_size))
81 }
82 };
83
84 *self.io.codec_mut() = codec;
85 Ok(ver)
86 }
87}
88
/// Version-specific stream produced by `Dispatcher::mqtt`.
pub enum MqttStream<Io> {
    /// Connection whose version probe reported `ProtocolVersion::MQTT3`.
    V3(v3::MqttStream<Io>),
    /// Connection whose version probe reported `ProtocolVersion::MQTT5`.
    V5(v5::MqttStream<Io>),
}
96
97pub mod v3 {
98
99 use std::net::SocketAddr;
100 use std::num::NonZeroU16;
101 use std::pin::Pin;
102 use std::sync::Arc;
103 use std::task::{Context, Poll};
104 use std::time::Duration;
105
106 use futures::StreamExt;
107 use tokio::io::{AsyncRead, AsyncWrite};
108 use tokio_util::codec::Framed;
109
110 use rmqtt_codec::error::DecodeError;
111 use rmqtt_codec::types::Publish;
112 use rmqtt_codec::v3::{Connect, ConnectAckReason, Packet as PacketV3, Packet};
113 use rmqtt_codec::{MqttCodec, MqttPacket};
114
115 use crate::error::MqttError;
116 use crate::{Builder, Error, Result};
117
118 #[cfg(feature = "tls")]
119 use crate::CertInfo;
120
    /// Framed connection that sends and receives MQTT v3 packets.
    pub struct MqttStream<Io> {
        /// Framed transport carrying the MQTT codec.
        pub io: Framed<Io, MqttCodec>,
        /// Socket address recorded for the remote peer.
        pub remote_addr: SocketAddr,
        /// Shared configuration (supplies `send_timeout` and, in `tls` builds,
        /// `cert_cn_as_username`).
        pub cfg: Arc<Builder>,
        /// Optional certificate information; its `common_name` may replace the CONNECT
        /// username in `recv_connect`.
        #[cfg(feature = "tls")]
        pub cert_info: Option<CertInfo>,
    }
133
134 impl<Io> MqttStream<Io>
159 where
160 Io: AsyncRead + AsyncWrite + Unpin,
161 {
162 #[inline]
164 pub async fn send_disconnect(&mut self) -> Result<()> {
165 self.send(PacketV3::Disconnect).await?;
166 self.flush().await
167 }
168
169 #[inline]
171 pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
172 self.send(PacketV3::Publish(publish)).await
173 }
174
175 #[inline]
177 pub async fn send_publish_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
178 self.send(PacketV3::PublishAck { packet_id }).await
179 }
180
181 #[inline]
183 pub async fn send_publish_received(&mut self, packet_id: NonZeroU16) -> Result<()> {
184 self.send(PacketV3::PublishReceived { packet_id }).await
185 }
186
187 #[inline]
189 pub async fn send_publish_release(&mut self, packet_id: NonZeroU16) -> Result<()> {
190 self.send(PacketV3::PublishRelease { packet_id }).await
191 }
192
193 #[inline]
195 pub async fn send_publish_complete(&mut self, packet_id: NonZeroU16) -> Result<()> {
196 self.send(PacketV3::PublishComplete { packet_id }).await
197 }
198
199 #[inline]
201 pub async fn send_subscribe_ack(
202 &mut self,
203 packet_id: NonZeroU16,
204 status: Vec<rmqtt_codec::v3::SubscribeReturnCode>,
205 ) -> Result<()> {
206 self.send(PacketV3::SubscribeAck { packet_id, status }).await
207 }
208
209 #[inline]
211 pub async fn send_unsubscribe_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
212 self.send(PacketV3::UnsubscribeAck { packet_id }).await
213 }
214
215 #[inline]
217 pub async fn send_connect(&mut self, connect: rmqtt_codec::v3::Connect) -> Result<()> {
218 self.send(PacketV3::Connect(Box::new(connect))).await
219 }
220
221 #[inline]
223 pub async fn send_connect_ack(
224 &mut self,
225 return_code: ConnectAckReason,
226 session_present: bool,
227 ) -> Result<()> {
228 self.send(PacketV3::ConnectAck(rmqtt_codec::v3::ConnectAck { session_present, return_code }))
229 .await
230 }
231
232 #[inline]
234 pub async fn send_ping_request(&mut self) -> Result<()> {
235 self.send(PacketV3::PingRequest {}).await
236 }
237
238 #[inline]
240 pub async fn send_ping_response(&mut self) -> Result<()> {
241 self.send(PacketV3::PingResponse {}).await
242 }
243
244 #[inline]
246 pub async fn send(&mut self, packet: rmqtt_codec::v3::Packet) -> Result<()> {
247 super::send(&mut self.io, MqttPacket::V3(packet), self.cfg.send_timeout).await
248 }
249
250 #[inline]
252 pub async fn flush(&mut self) -> Result<()> {
253 super::flush(&mut self.io, self.cfg.send_timeout).await
254 }
255
256 #[inline]
258 pub async fn close(&mut self) -> Result<()> {
259 super::close(&mut self.io, self.cfg.send_timeout).await
260 }
261
262 #[inline]
264 pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v3::Packet>> {
265 match tokio::time::timeout(tm, self.next()).await {
266 Ok(Some(Ok(msg))) => Ok(Some(msg)),
267 Ok(Some(Err(e))) => Err(e),
268 Ok(None) => Ok(None),
269 Err(_) => Err(MqttError::ReadTimeout.into()),
270 }
271 }
272
273 #[inline]
275 pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
276 let connect = match self.recv(tm).await {
277 #[allow(unused_mut)]
278 Ok(Some(Packet::Connect(mut connect))) => {
279 #[cfg(feature = "tls")]
280 {
281 if self.cfg.cert_cn_as_username {
282 if let Some(cert) = &self.cert_info {
283 if let Some(cn) = &cert.common_name {
284 connect.username = Some(cn.clone().into());
285 }
286 }
287 }
288 }
289 connect
290 }
291 Err(e) => {
292 return Err(e);
293 }
294 _ => {
295 return Err(MqttError::InvalidProtocol.into());
296 }
297 };
298 Ok(connect)
299 }
300 }
301
302 impl<Io> futures::Stream for MqttStream<Io>
303 where
304 Io: AsyncRead + Unpin,
305 {
306 type Item = Result<rmqtt_codec::v3::Packet>;
307
308 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 let next = Pin::new(&mut self.io).poll_next(cx);
310 Poll::Ready(match futures::ready!(next) {
311 Some(Ok((MqttPacket::V3(packet), _))) => Some(Ok(packet)),
312 Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
313 Some(Err(e)) => Some(Err(Error::from(e))),
314 None => None,
315 })
316 }
317 }
318}
319
320pub mod v5 {
321 use std::net::SocketAddr;
322 use std::pin::Pin;
323 use std::sync::Arc;
324 use std::task::{Context, Poll};
325 use std::time::Duration;
326
327 use futures::StreamExt;
328 use tokio::io::{AsyncRead, AsyncWrite};
329 use tokio_util::codec::Framed;
330
331 use rmqtt_codec::error::DecodeError;
332 use rmqtt_codec::types::Publish;
333 use rmqtt_codec::v5::{Auth, Connect, Disconnect, Packet as PacketV5, Packet};
334 use rmqtt_codec::{MqttCodec, MqttPacket};
335
336 use crate::error::MqttError;
337 #[cfg(feature = "tls")]
338 use crate::CertInfo;
339 use crate::{Builder, Error, Result};
340
    /// Framed connection that sends and receives MQTT v5 packets.
    pub struct MqttStream<Io> {
        /// Framed transport carrying the MQTT codec.
        pub io: Framed<Io, MqttCodec>,
        /// Socket address recorded for the remote peer.
        pub remote_addr: SocketAddr,
        /// Shared configuration (supplies `send_timeout`).
        pub cfg: Arc<Builder>,
        /// Optional certificate information; only present in `tls` builds.
        #[cfg(feature = "tls")]
        pub cert_info: Option<CertInfo>,
    }
353
354 impl<Io> MqttStream<Io>
380 where
381 Io: AsyncRead + AsyncWrite + Unpin,
382 {
383 #[inline]
385 pub async fn send_disconnect(&mut self, disc: Disconnect) -> Result<()> {
386 self.send(PacketV5::Disconnect(disc)).await?;
387 self.flush().await?;
388 tokio::time::sleep(Duration::from_millis(500)).await;
389 Ok(())
390 }
391
392 #[inline]
394 pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
395 self.send(PacketV5::Publish(publish)).await
396 }
397
398 #[inline]
400 pub async fn send_publish_ack(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
401 self.send(PacketV5::PublishAck(ack)).await
402 }
403
404 #[inline]
406 pub async fn send_publish_received(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
407 self.send(PacketV5::PublishReceived(ack)).await
408 }
409
410 #[inline]
412 pub async fn send_publish_release(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
413 self.send(PacketV5::PublishRelease(ack2)).await
414 }
415
416 #[inline]
418 pub async fn send_publish_complete(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
419 self.send(PacketV5::PublishComplete(ack2)).await
420 }
421
422 #[inline]
424 pub async fn send_subscribe_ack(&mut self, ack: rmqtt_codec::v5::SubscribeAck) -> Result<()> {
425 self.send(PacketV5::SubscribeAck(ack)).await
426 }
427
428 #[inline]
430 pub async fn send_unsubscribe_ack(&mut self, unack: rmqtt_codec::v5::UnsubscribeAck) -> Result<()> {
431 self.send(PacketV5::UnsubscribeAck(unack)).await
432 }
433
434 #[inline]
436 pub async fn send_connect(&mut self, connect: rmqtt_codec::v5::Connect) -> Result<()> {
437 self.send(PacketV5::Connect(Box::new(connect))).await
438 }
439
440 #[inline]
442 pub async fn send_connect_ack(&mut self, ack: rmqtt_codec::v5::ConnectAck) -> Result<()> {
443 self.send(PacketV5::ConnectAck(Box::new(ack))).await
444 }
445
446 #[inline]
448 pub async fn send_ping_request(&mut self) -> Result<()> {
449 self.send(PacketV5::PingRequest {}).await
450 }
451
452 #[inline]
454 pub async fn send_ping_response(&mut self) -> Result<()> {
455 self.send(PacketV5::PingResponse {}).await
456 }
457
458 #[inline]
460 pub async fn send_auth(&mut self, auth: Auth) -> Result<()> {
461 self.send(PacketV5::Auth(auth)).await
462 }
463
464 #[inline]
466 pub async fn send(&mut self, packet: rmqtt_codec::v5::Packet) -> Result<()> {
467 super::send(&mut self.io, MqttPacket::V5(packet), self.cfg.send_timeout).await
468 }
469
470 #[inline]
472 pub async fn flush(&mut self) -> Result<()> {
473 super::flush(&mut self.io, self.cfg.send_timeout).await
474 }
475
476 #[inline]
478 pub async fn close(&mut self) -> Result<()> {
479 super::close(&mut self.io, self.cfg.send_timeout).await
480 }
481
482 #[inline]
484 pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v5::Packet>> {
485 match tokio::time::timeout(tm, self.next()).await {
486 Ok(Some(Ok(msg))) => Ok(Some(msg)),
487 Ok(Some(Err(e))) => Err(e),
488 Ok(None) => Ok(None),
489 Err(_) => Err(MqttError::ReadTimeout.into()),
490 }
491 }
492
493 #[inline]
495 pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
496 let connect = match self.recv(tm).await {
497 Ok(Some(Packet::Connect(connect))) => connect,
498 Err(e) => {
499 return Err(e);
500 }
501 _ => {
502 return Err(MqttError::InvalidProtocol.into());
503 }
504 };
505 Ok(connect)
506 }
507 }
508
509 impl<Io> futures::Stream for MqttStream<Io>
510 where
511 Io: AsyncRead + Unpin,
512 {
513 type Item = Result<rmqtt_codec::v5::Packet>;
514
515 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
516 let next = Pin::new(&mut self.io).poll_next(cx);
517 Poll::Ready(match futures::ready!(next) {
518 Some(Ok((MqttPacket::V5(packet), _))) => Some(Ok(packet)),
519 Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
520 Some(Err(e)) => Some(Err(Error::from(e))),
521 None => None,
522 })
523 }
524 }
525}
526
527#[inline]
528async fn send<Io>(io: &mut Framed<Io, MqttCodec>, packet: MqttPacket, send_timeout: Duration) -> Result<()>
529where
530 Io: AsyncWrite + Unpin,
531{
532 if send_timeout.is_zero() {
533 io.send(packet).await?;
534 Ok(())
535 } else {
536 match tokio::time::timeout(send_timeout, io.send(packet)).await {
537 Ok(Ok(())) => Ok(()),
538 Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
539 Err(_) => Err(MqttError::WriteTimeout),
540 }?;
541 Ok(())
542 }
543}
544
545#[inline]
546async fn flush<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
547where
548 Io: AsyncWrite + Unpin,
549{
550 if send_timeout.is_zero() {
551 io.flush().await?;
552 Ok(())
553 } else {
554 match tokio::time::timeout(send_timeout, io.flush()).await {
555 Ok(Ok(())) => Ok(()),
556 Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
557 Err(_) => Err(MqttError::FlushTimeout),
558 }?;
559 Ok(())
560 }
561}
562
563#[inline]
564async fn close<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
565where
566 Io: AsyncWrite + Unpin,
567{
568 if send_timeout.is_zero() {
569 io.close().await?;
570 Ok(())
571 } else {
572 match tokio::time::timeout(send_timeout, io.close()).await {
573 Ok(Ok(())) => Ok(()),
574 Ok(Err(e)) => Err(MqttError::Encode(e)),
575 Err(_) => Err(MqttError::CloseTimeout),
576 }?;
577 Ok(())
578 }
579}