1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::anyhow;
6use futures::SinkExt;
7use futures::StreamExt;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10
11use rmqtt_codec::error::{DecodeError, SendPacketError};
12use rmqtt_codec::v3::Codec as CodecV3;
13use rmqtt_codec::v5::Codec as CodecV5;
14use rmqtt_codec::version::{ProtocolVersion, VersionCodec};
15use rmqtt_codec::{MqttCodec, MqttPacket};
16
17use crate::error::MqttError;
18use crate::CertInfo;
19use crate::{Builder, Result};
20
21pub struct Dispatcher<Io> {
25 pub(crate) io: Framed<Io, MqttCodec>,
27 pub remote_addr: SocketAddr,
29 pub cfg: Arc<Builder>,
31
32 pub cert_info: Option<CertInfo>,
33}
34
35impl<Io> Dispatcher<Io>
36where
37 Io: AsyncRead + AsyncWrite + Unpin,
38{
39 pub(crate) fn new(
41 io: Io,
42 remote_addr: SocketAddr,
43 cert_info: Option<CertInfo>,
44 cfg: Arc<Builder>,
45 ) -> Self {
46 Dispatcher { io: Framed::new(io, MqttCodec::Version(VersionCodec)), remote_addr, cfg, cert_info }
47 }
48
49 #[inline]
51 pub async fn mqtt(mut self) -> Result<MqttStream<Io>> {
52 Ok(match self.probe_version().await? {
53 ProtocolVersion::MQTT3 => MqttStream::V3(v3::MqttStream {
54 io: self.io,
55 remote_addr: self.remote_addr,
56 cfg: self.cfg,
57 #[cfg(feature = "tls")]
58 cert_info: self.cert_info,
59 }),
60 ProtocolVersion::MQTT5 => MqttStream::V5(v5::MqttStream {
61 io: self.io,
62 remote_addr: self.remote_addr,
63 cfg: self.cfg,
64 #[cfg(feature = "tls")]
65 cert_info: self.cert_info,
66 }),
67 })
68 }
69
70 #[inline]
72 async fn probe_version(&mut self) -> Result<ProtocolVersion> {
73 let Some(Ok((MqttPacket::Version(ver), _))) = self.io.next().await else {
74 return Err(anyhow!(DecodeError::InvalidProtocol));
75 };
76
77 let codec = match ver {
78 ProtocolVersion::MQTT3 => MqttCodec::V3(CodecV3::new(self.cfg.max_packet_size)),
79 ProtocolVersion::MQTT5 => {
80 MqttCodec::V5(CodecV5::new(self.cfg.max_packet_size, self.cfg.max_packet_size))
81 }
82 };
83
84 *self.io.codec_mut() = codec;
85 Ok(ver)
86 }
87}
88
89pub enum MqttStream<Io> {
91 V3(v3::MqttStream<Io>),
93 V5(v5::MqttStream<Io>),
95}
96
97pub mod v3 {
98
99 use std::net::SocketAddr;
100 use std::num::NonZeroU16;
101 use std::pin::Pin;
102 use std::sync::Arc;
103 use std::task::{Context, Poll};
104 use std::time::Duration;
105
106 use futures::StreamExt;
107 use tokio::io::{AsyncRead, AsyncWrite};
108 use tokio_util::codec::Framed;
109
110 use rmqtt_codec::error::DecodeError;
111 use rmqtt_codec::types::Publish;
112 use rmqtt_codec::v3::{Connect, ConnectAckReason, Packet as PacketV3, Packet};
113 use rmqtt_codec::{MqttCodec, MqttPacket};
114
115 use crate::error::MqttError;
116 use crate::{Builder, Error, Result};
117
118 #[cfg(feature = "tls")]
119 use crate::CertInfo;
120
121 pub struct MqttStream<Io> {
123 pub io: Framed<Io, MqttCodec>,
125 pub remote_addr: SocketAddr,
127 pub cfg: Arc<Builder>,
129 #[cfg(feature = "tls")]
130 pub cert_info: Option<CertInfo>,
132 }
133
134 impl<Io> MqttStream<Io>
159 where
160 Io: AsyncRead + AsyncWrite + Unpin,
161 {
162 #[inline]
164 pub async fn send_disconnect(&mut self) -> Result<()> {
165 self.send(PacketV3::Disconnect).await?;
166 self.flush().await
167 }
168
169 #[inline]
171 pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
172 self.send(PacketV3::Publish(publish)).await
173 }
174
175 #[inline]
177 pub async fn send_publish_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
178 self.send(PacketV3::PublishAck { packet_id }).await
179 }
180
181 #[inline]
183 pub async fn send_publish_received(&mut self, packet_id: NonZeroU16) -> Result<()> {
184 self.send(PacketV3::PublishReceived { packet_id }).await
185 }
186
187 #[inline]
189 pub async fn send_publish_release(&mut self, packet_id: NonZeroU16) -> Result<()> {
190 self.send(PacketV3::PublishRelease { packet_id }).await
191 }
192
193 #[inline]
195 pub async fn send_publish_complete(&mut self, packet_id: NonZeroU16) -> Result<()> {
196 self.send(PacketV3::PublishComplete { packet_id }).await
197 }
198
199 #[inline]
201 pub async fn send_subscribe_ack(
202 &mut self,
203 packet_id: NonZeroU16,
204 status: Vec<rmqtt_codec::v3::SubscribeReturnCode>,
205 ) -> Result<()> {
206 self.send(PacketV3::SubscribeAck { packet_id, status }).await
207 }
208
209 #[inline]
211 pub async fn send_unsubscribe_ack(&mut self, packet_id: NonZeroU16) -> Result<()> {
212 self.send(PacketV3::UnsubscribeAck { packet_id }).await
213 }
214
215 #[inline]
217 pub async fn send_connect(&mut self, connect: rmqtt_codec::v3::Connect) -> Result<()> {
218 self.send(PacketV3::Connect(Box::new(connect))).await
219 }
220
221 #[inline]
223 pub async fn send_connect_ack(
224 &mut self,
225 return_code: ConnectAckReason,
226 session_present: bool,
227 ) -> Result<()> {
228 self.send(PacketV3::ConnectAck(rmqtt_codec::v3::ConnectAck { session_present, return_code }))
229 .await
230 }
231
232 #[inline]
234 pub async fn send_ping_request(&mut self) -> Result<()> {
235 self.send(PacketV3::PingRequest {}).await
236 }
237
238 #[inline]
240 pub async fn send_ping_response(&mut self) -> Result<()> {
241 self.send(PacketV3::PingResponse {}).await
242 }
243
244 #[inline]
246 pub async fn send(&mut self, packet: rmqtt_codec::v3::Packet) -> Result<()> {
247 super::send(&mut self.io, MqttPacket::V3(packet), self.cfg.send_timeout).await
248 }
249
250 #[inline]
252 pub async fn flush(&mut self) -> Result<()> {
253 super::flush(&mut self.io, self.cfg.send_timeout).await
254 }
255
256 #[inline]
258 pub async fn close(&mut self) -> Result<()> {
259 super::close(&mut self.io, self.cfg.send_timeout).await
260 }
261
262 #[inline]
264 pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v3::Packet>> {
265 match tokio::time::timeout(tm, self.next()).await {
266 Ok(Some(Ok(msg))) => Ok(Some(msg)),
267 Ok(Some(Err(e))) => Err(e),
268 Ok(None) => Ok(None),
269 Err(_) => Err(MqttError::ReadTimeout.into()),
270 }
271 }
272
273 #[inline]
275 pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
276 let connect = match self.recv(tm).await {
277 Ok(Some(Packet::Connect(mut connect))) => {
278 #[cfg(feature = "tls")]
279 {
280 if self.cfg.cert_cn_as_username {
281 if let Some(cert) = &self.cert_info {
282 if let Some(cn) = &cert.common_name {
283 connect.username = Some(cn.clone().into());
284 }
285 }
286 }
287 }
288 connect
289 }
290 Err(e) => {
291 return Err(e);
292 }
293 _ => {
294 return Err(MqttError::InvalidProtocol.into());
295 }
296 };
297 Ok(connect)
298 }
299 }
300
301 impl<Io> futures::Stream for MqttStream<Io>
302 where
303 Io: AsyncRead + Unpin,
304 {
305 type Item = Result<rmqtt_codec::v3::Packet>;
306
307 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
308 let next = Pin::new(&mut self.io).poll_next(cx);
309 Poll::Ready(match futures::ready!(next) {
310 Some(Ok((MqttPacket::V3(packet), _))) => Some(Ok(packet)),
311 Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
312 Some(Err(e)) => Some(Err(Error::from(e))),
313 None => None,
314 })
315 }
316 }
317}
318
319pub mod v5 {
320 use std::net::SocketAddr;
321 use std::pin::Pin;
322 use std::sync::Arc;
323 use std::task::{Context, Poll};
324 use std::time::Duration;
325
326 use futures::StreamExt;
327 use tokio::io::{AsyncRead, AsyncWrite};
328 use tokio_util::codec::Framed;
329
330 use rmqtt_codec::error::DecodeError;
331 use rmqtt_codec::types::Publish;
332 use rmqtt_codec::v5::{Auth, Connect, Disconnect, Packet as PacketV5, Packet};
333 use rmqtt_codec::{MqttCodec, MqttPacket};
334
335 use crate::error::MqttError;
336 use crate::{Builder, CertInfo, Error, Result};
337
338 pub struct MqttStream<Io> {
340 pub io: Framed<Io, MqttCodec>,
342 pub remote_addr: SocketAddr,
344 pub cfg: Arc<Builder>,
346 #[cfg(feature = "tls")]
347 pub cert_info: Option<CertInfo>,
349 }
350
351 impl<Io> MqttStream<Io>
377 where
378 Io: AsyncRead + AsyncWrite + Unpin,
379 {
380 #[inline]
382 pub async fn send_disconnect(&mut self, disc: Disconnect) -> Result<()> {
383 self.send(PacketV5::Disconnect(disc)).await?;
384 self.flush().await?;
385 tokio::time::sleep(Duration::from_millis(500)).await;
386 Ok(())
387 }
388
389 #[inline]
391 pub async fn send_publish(&mut self, publish: Box<Publish>) -> Result<()> {
392 self.send(PacketV5::Publish(publish)).await
393 }
394
395 #[inline]
397 pub async fn send_publish_ack(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
398 self.send(PacketV5::PublishAck(ack)).await
399 }
400
401 #[inline]
403 pub async fn send_publish_received(&mut self, ack: rmqtt_codec::v5::PublishAck) -> Result<()> {
404 self.send(PacketV5::PublishReceived(ack)).await
405 }
406
407 #[inline]
409 pub async fn send_publish_release(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
410 self.send(PacketV5::PublishRelease(ack2)).await
411 }
412
413 #[inline]
415 pub async fn send_publish_complete(&mut self, ack2: rmqtt_codec::v5::PublishAck2) -> Result<()> {
416 self.send(PacketV5::PublishComplete(ack2)).await
417 }
418
419 #[inline]
421 pub async fn send_subscribe_ack(&mut self, ack: rmqtt_codec::v5::SubscribeAck) -> Result<()> {
422 self.send(PacketV5::SubscribeAck(ack)).await
423 }
424
425 #[inline]
427 pub async fn send_unsubscribe_ack(&mut self, unack: rmqtt_codec::v5::UnsubscribeAck) -> Result<()> {
428 self.send(PacketV5::UnsubscribeAck(unack)).await
429 }
430
431 #[inline]
433 pub async fn send_connect(&mut self, connect: rmqtt_codec::v5::Connect) -> Result<()> {
434 self.send(PacketV5::Connect(Box::new(connect))).await
435 }
436
437 #[inline]
439 pub async fn send_connect_ack(&mut self, ack: rmqtt_codec::v5::ConnectAck) -> Result<()> {
440 self.send(PacketV5::ConnectAck(Box::new(ack))).await
441 }
442
443 #[inline]
445 pub async fn send_ping_request(&mut self) -> Result<()> {
446 self.send(PacketV5::PingRequest {}).await
447 }
448
449 #[inline]
451 pub async fn send_ping_response(&mut self) -> Result<()> {
452 self.send(PacketV5::PingResponse {}).await
453 }
454
455 #[inline]
457 pub async fn send_auth(&mut self, auth: Auth) -> Result<()> {
458 self.send(PacketV5::Auth(auth)).await
459 }
460
461 #[inline]
463 pub async fn send(&mut self, packet: rmqtt_codec::v5::Packet) -> Result<()> {
464 super::send(&mut self.io, MqttPacket::V5(packet), self.cfg.send_timeout).await
465 }
466
467 #[inline]
469 pub async fn flush(&mut self) -> Result<()> {
470 super::flush(&mut self.io, self.cfg.send_timeout).await
471 }
472
473 #[inline]
475 pub async fn close(&mut self) -> Result<()> {
476 super::close(&mut self.io, self.cfg.send_timeout).await
477 }
478
479 #[inline]
481 pub async fn recv(&mut self, tm: Duration) -> Result<Option<rmqtt_codec::v5::Packet>> {
482 match tokio::time::timeout(tm, self.next()).await {
483 Ok(Some(Ok(msg))) => Ok(Some(msg)),
484 Ok(Some(Err(e))) => Err(e),
485 Ok(None) => Ok(None),
486 Err(_) => Err(MqttError::ReadTimeout.into()),
487 }
488 }
489
490 #[inline]
492 pub async fn recv_connect(&mut self, tm: Duration) -> Result<Box<Connect>> {
493 let connect = match self.recv(tm).await {
494 Ok(Some(Packet::Connect(connect))) => connect,
495 Err(e) => {
496 return Err(e);
497 }
498 _ => {
499 return Err(MqttError::InvalidProtocol.into());
500 }
501 };
502 Ok(connect)
503 }
504 }
505
506 impl<Io> futures::Stream for MqttStream<Io>
507 where
508 Io: AsyncRead + Unpin,
509 {
510 type Item = Result<rmqtt_codec::v5::Packet>;
511
512 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
513 let next = Pin::new(&mut self.io).poll_next(cx);
514 Poll::Ready(match futures::ready!(next) {
515 Some(Ok((MqttPacket::V5(packet), _))) => Some(Ok(packet)),
516 Some(Ok(_)) => Some(Err(MqttError::Decode(DecodeError::MalformedPacket).into())),
517 Some(Err(e)) => Some(Err(Error::from(e))),
518 None => None,
519 })
520 }
521 }
522}
523
524#[inline]
525async fn send<Io>(io: &mut Framed<Io, MqttCodec>, packet: MqttPacket, send_timeout: Duration) -> Result<()>
526where
527 Io: AsyncWrite + Unpin,
528{
529 if send_timeout.is_zero() {
530 io.send(packet).await?;
531 Ok(())
532 } else {
533 match tokio::time::timeout(send_timeout, io.send(packet)).await {
534 Ok(Ok(())) => Ok(()),
535 Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
536 Err(_) => Err(MqttError::WriteTimeout),
537 }?;
538 Ok(())
539 }
540}
541
542#[inline]
543async fn flush<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
544where
545 Io: AsyncWrite + Unpin,
546{
547 if send_timeout.is_zero() {
548 io.flush().await?;
549 Ok(())
550 } else {
551 match tokio::time::timeout(send_timeout, io.flush()).await {
552 Ok(Ok(())) => Ok(()),
553 Ok(Err(e)) => Err(MqttError::SendPacket(SendPacketError::Encode(e))),
554 Err(_) => Err(MqttError::FlushTimeout),
555 }?;
556 Ok(())
557 }
558}
559
560#[inline]
561async fn close<Io>(io: &mut Framed<Io, MqttCodec>, send_timeout: Duration) -> Result<()>
562where
563 Io: AsyncWrite + Unpin,
564{
565 if send_timeout.is_zero() {
566 io.close().await?;
567 Ok(())
568 } else {
569 match tokio::time::timeout(send_timeout, io.close()).await {
570 Ok(Ok(())) => Ok(()),
571 Ok(Err(e)) => Err(MqttError::Encode(e)),
572 Err(_) => Err(MqttError::CloseTimeout),
573 }?;
574 Ok(())
575 }
576}