ntex_mqtt/v5/
server.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, DispatcherConfig, IoBoxed};
4use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
5use ntex_util::time::{timeout_checked, Millis, Seconds};
6
7use crate::error::{HandshakeError, MqttError, ProtocolError};
8use crate::{service, types::QoS, InFlightService};
9
10use super::codec::{self as mqtt, Decoded, Encoded, Packet};
11use super::control::{Control, ControlAck};
12use super::default::{DefaultControlService, DefaultPublishService};
13use super::handshake::{Handshake, HandshakeAck};
14use super::publish::{Publish, PublishAck};
15use super::shared::{MqttShared, MqttSinkPool};
16use super::{dispatcher::factory, MqttSink, Session};
17
18/// Mqtt Server
19pub struct MqttServer<St, C, Cn, P, M = Identity> {
20    handshake: C,
21    srv_control: Cn,
22    srv_publish: P,
23    middleware: M,
24    max_qos: QoS,
25    max_size: u32,
26    max_receive: u16,
27    max_topic_alias: u16,
28    min_chunk_size: u32,
29    handle_qos_after_disconnect: Option<QoS>,
30    connect_timeout: Seconds,
31    config: DispatcherConfig,
32    pub(super) pool: Rc<MqttSinkPool>,
33    _t: PhantomData<St>,
34}
35
36impl<St, C>
37    MqttServer<
38        St,
39        C,
40        DefaultControlService<St, C::Error>,
41        DefaultPublishService<St, C::Error>,
42        InFlightService,
43    >
44where
45    C: ServiceFactory<Handshake, Response = HandshakeAck<St>>,
46    C::Error: fmt::Debug,
47{
48    /// Create server factory and provide handshake service
49    pub fn new<F>(handshake: F) -> Self
50    where
51        F: IntoServiceFactory<C, Handshake>,
52    {
53        let config = DispatcherConfig::default();
54        config.set_disconnect_timeout(Seconds(3));
55
56        MqttServer {
57            config,
58            handshake: handshake.into_factory(),
59            srv_control: DefaultControlService::default(),
60            srv_publish: DefaultPublishService::default(),
61            middleware: InFlightService::new(0, 65535),
62            max_qos: QoS::AtLeastOnce,
63            max_size: 0,
64            max_receive: 15,
65            max_topic_alias: 32,
66            min_chunk_size: 32 * 1024,
67            handle_qos_after_disconnect: None,
68            connect_timeout: Seconds::ZERO,
69            pool: Rc::new(MqttSinkPool::default()),
70            _t: PhantomData,
71        }
72    }
73}
74
75impl<St, C, Cn, P> MqttServer<St, C, Cn, P, InFlightService> {
76    /// Total size of received in-flight messages.
77    ///
78    /// By default total in-flight size is set to 64Kb
79    pub fn max_receive_size(mut self, val: usize) -> Self {
80        self.middleware = self.middleware.max_receive_size(val);
81        self
82    }
83}
84
85impl<St, C, Cn, P, M> MqttServer<St, C, Cn, P, M>
86where
87    St: 'static,
88    C: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
89    C::Error: fmt::Debug,
90    Cn: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
91    P: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
92{
93    /// Set client timeout for first `Connect` frame.
94    ///
95    /// Defines a timeout for reading `Connect` frame. If a client does not transmit
96    /// the entire frame within this time, the connection is terminated with
97    /// Mqtt::Handshake(HandshakeError::Timeout) error.
98    ///
99    /// By default, connect timeout is disabled.
100    pub fn connect_timeout(mut self, timeout: Seconds) -> Self {
101        self.connect_timeout = timeout;
102        self
103    }
104
105    /// Set server connection disconnect timeout.
106    ///
107    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
108    /// within this time, the connection get dropped.
109    ///
110    /// To disable timeout set value to 0.
111    ///
112    /// By default disconnect timeout is set to 3 seconds.
113    pub fn disconnect_timeout(self, val: Seconds) -> Self {
114        self.config.set_disconnect_timeout(val);
115        self
116    }
117
118    /// Set read rate parameters for single frame.
119    ///
120    /// Set read timeout, max timeout and rate for reading payload. If the client
121    /// sends `rate` amount of data within `timeout` period of time, extend timeout by `timeout` seconds.
122    /// But no more than `max_timeout` timeout.
123    ///
124    /// By default frame read rate is disabled.
125    pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
126        self.config.set_frame_read_rate(timeout, max_timeout, rate);
127        self
128    }
129
130    /// Set max inbound frame size.
131    ///
132    /// If max size is set to `0`, size is unlimited.
133    /// By default max size is set to `0`
134    pub fn max_size(mut self, size: u32) -> Self {
135        self.max_size = size;
136        self
137    }
138
139    /// Set `receive max`
140    ///
141    /// Number of in-flight publish packets. By default receive max is set to 15 packets.
142    /// To disable timeout set value to 0.
143    pub fn max_receive(mut self, val: u16) -> Self {
144        self.max_receive = val;
145        self
146    }
147
148    /// Number of topic aliases.
149    ///
150    /// By default value is set to 32
151    pub fn max_topic_alias(mut self, val: u16) -> Self {
152        self.max_topic_alias = val;
153        self
154    }
155
156    /// Set server max qos setting.
157    ///
158    /// By default max qos is not set.
159    pub fn max_qos(mut self, qos: QoS) -> Self {
160        self.max_qos = qos;
161        self
162    }
163
164    /// Set min payload chunk size.
165    ///
166    /// If the minimum size is set to `0`, incoming payload chunks
167    /// will be processed immediately. Otherwise, the codec will
168    /// accumulate chunks until the total size reaches the specified minimum.
169    /// By default min size is set to `0`
170    pub fn min_chunk_size(mut self, size: u32) -> Self {
171        self.min_chunk_size = size;
172        self
173    }
174
175    /// Handle max received QoS messages after client disconnect.
176    ///
177    /// By default, messages received before dispatched to the publish service will be dropped if
178    /// the client disconnect immediately.
179    ///
180    /// If this option is set to `Some(QoS::AtMostOnce)`, only the QoS 0 messages received will
181    /// always be handled by the server's publish service no matter if the client is disconnected
182    /// or not.
183    ///
184    /// If this option is set to `Some(QoS::AtLeastOnce)`, only the QoS 0 and QoS 1 messages
185    /// received will always be handled by the server's publish service no matter if the client
186    /// is disconnected or not. The QoS 2 messages will be dropped if the client is disconnected
187    /// before the server dispatches them to the publish service.
188    ///
189    /// If this option is set to `Some(QoS::ExactlyOnce)`, all the messages received will always
190    /// be handled by the server's publish service no matter if the client is disconnected or not.
191    ///
192    /// By default handle-qos-after-disconnect is set to `None`
193    pub fn handle_qos_after_disconnect(mut self, max_qos: Option<QoS>) -> Self {
194        self.handle_qos_after_disconnect = max_qos;
195        self
196    }
197
198    /// Remove all middlewares
199    pub fn reset_middlewares(self) -> MqttServer<St, C, Cn, P, Identity> {
200        MqttServer {
201            middleware: Identity,
202            config: self.config,
203            handshake: self.handshake,
204            srv_publish: self.srv_publish,
205            srv_control: self.srv_control,
206            max_size: self.max_size,
207            max_receive: self.max_receive,
208            max_topic_alias: self.max_topic_alias,
209            min_chunk_size: self.min_chunk_size,
210            max_qos: self.max_qos,
211            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
212            connect_timeout: self.connect_timeout,
213            pool: self.pool,
214            _t: PhantomData,
215        }
216    }
217
218    /// Registers middleware, in the form of a middleware component (type),
219    /// that runs during inbound and/or outbound processing in the request
220    /// lifecycle (request -> response), modifying request/response as
221    /// necessary, across all requests managed by the *Server*.
222    ///
223    /// Use middleware when you need to read or modify *every* request or
224    /// response in some way.
225    pub fn middleware<U>(self, mw: U) -> MqttServer<St, C, Cn, P, Stack<M, U>> {
226        MqttServer {
227            middleware: Stack::new(self.middleware, mw),
228            config: self.config,
229            handshake: self.handshake,
230            srv_publish: self.srv_publish,
231            srv_control: self.srv_control,
232            max_size: self.max_size,
233            max_receive: self.max_receive,
234            max_topic_alias: self.max_topic_alias,
235            max_qos: self.max_qos,
236            min_chunk_size: self.min_chunk_size,
237            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
238            connect_timeout: self.connect_timeout,
239            pool: self.pool,
240            _t: PhantomData,
241        }
242    }
243
244    /// Service to handle control packets
245    ///
246    /// All control packets are processed sequentially, max number of buffered
247    /// control packets is 16.
248    pub fn control<F, Srv>(self, service: F) -> MqttServer<St, C, Srv, P, M>
249    where
250        F: IntoServiceFactory<Srv, Control<C::Error>, Session<St>>,
251        Srv: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
252        C::Error: From<Srv::Error> + From<Srv::InitError>,
253    {
254        MqttServer {
255            config: self.config,
256            handshake: self.handshake,
257            srv_publish: self.srv_publish,
258            srv_control: service.into_factory(),
259            middleware: self.middleware,
260            max_size: self.max_size,
261            max_receive: self.max_receive,
262            max_topic_alias: self.max_topic_alias,
263            max_qos: self.max_qos,
264            min_chunk_size: self.min_chunk_size,
265            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
266            connect_timeout: self.connect_timeout,
267            pool: self.pool,
268            _t: PhantomData,
269        }
270    }
271
272    /// Set service to handle publish packets and create mqtt server factory
273    pub fn publish<F, Srv>(self, publish: F) -> MqttServer<St, C, Cn, Srv, M>
274    where
275        F: IntoServiceFactory<Srv, Publish, Session<St>>,
276        C::Error: From<Srv::Error> + From<Srv::InitError>,
277        Srv: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
278        Srv::Error: fmt::Debug,
279        PublishAck: TryFrom<Srv::Error, Error = C::Error>,
280    {
281        MqttServer {
282            config: self.config,
283            handshake: self.handshake,
284            srv_publish: publish.into_factory(),
285            srv_control: self.srv_control,
286            middleware: self.middleware,
287            max_size: self.max_size,
288            max_receive: self.max_receive,
289            max_topic_alias: self.max_topic_alias,
290            max_qos: self.max_qos,
291            min_chunk_size: self.min_chunk_size,
292            handle_qos_after_disconnect: self.handle_qos_after_disconnect,
293            connect_timeout: self.connect_timeout,
294            pool: self.pool,
295            _t: PhantomData,
296        }
297    }
298}
299
300impl<St, C, Cn, P, M> MqttServer<St, C, Cn, P, M>
301where
302    St: 'static,
303    C: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
304    C::Error: From<Cn::Error>
305        + From<Cn::InitError>
306        + From<P::Error>
307        + From<P::InitError>
308        + fmt::Debug,
309    Cn: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
310    P: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
311    P::Error: fmt::Debug,
312    PublishAck: TryFrom<P::Error, Error = C::Error>,
313{
314    /// Finish server configuration and create mqtt server factory
315    pub fn finish(
316        self,
317    ) -> service::MqttServer<
318        Session<St>,
319        impl ServiceFactory<
320            IoBoxed,
321            Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
322            Error = MqttError<C::Error>,
323            InitError = C::InitError,
324        >,
325        impl ServiceFactory<
326            DispatchItem<Rc<MqttShared>>,
327            Session<St>,
328            Response = Option<mqtt::Encoded>,
329            Error = MqttError<C::Error>,
330            InitError = MqttError<C::Error>,
331        >,
332        M,
333        Rc<MqttShared>,
334    > {
335        service::MqttServer::new(
336            HandshakeFactory {
337                factory: self.handshake,
338                max_size: self.max_size,
339                max_receive: self.max_receive,
340                max_topic_alias: self.max_topic_alias,
341                max_qos: self.max_qos,
342                min_chunk_size: self.min_chunk_size,
343                connect_timeout: self.connect_timeout.into(),
344                pool: self.pool,
345                _t: PhantomData,
346            },
347            factory(self.srv_publish, self.srv_control, self.handle_qos_after_disconnect),
348            self.middleware,
349            self.config,
350        )
351    }
352}
353
354struct HandshakeFactory<St, H> {
355    factory: H,
356    max_size: u32,
357    max_receive: u16,
358    max_topic_alias: u16,
359    max_qos: QoS,
360    min_chunk_size: u32,
361    connect_timeout: Millis,
362    pool: Rc<MqttSinkPool>,
363    _t: PhantomData<St>,
364}
365
366impl<St, H> ServiceFactory<IoBoxed> for HandshakeFactory<St, H>
367where
368    H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
369    H::Error: fmt::Debug,
370{
371    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
372    type Error = MqttError<H::Error>;
373
374    type Service = HandshakeService<St, H::Service>;
375    type InitError = H::InitError;
376
377    async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
378        Ok(HandshakeService {
379            service: self.factory.create(()).await?,
380            max_size: self.max_size,
381            max_receive: self.max_receive,
382            max_topic_alias: self.max_topic_alias,
383            max_qos: self.max_qos,
384            min_chunk_size: self.min_chunk_size,
385            pool: self.pool.clone(),
386            connect_timeout: self.connect_timeout,
387            _t: PhantomData,
388        })
389    }
390}
391
392struct HandshakeService<St, H> {
393    service: H,
394    max_size: u32,
395    max_receive: u16,
396    max_topic_alias: u16,
397    max_qos: QoS,
398    min_chunk_size: u32,
399    connect_timeout: Millis,
400    pool: Rc<MqttSinkPool>,
401    _t: PhantomData<St>,
402}
403
404impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
405where
406    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
407    H::Error: fmt::Debug,
408{
409    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
410    type Error = MqttError<H::Error>;
411
412    ntex_service::forward_ready!(service, MqttError::Service);
413    ntex_service::forward_shutdown!(service);
414
415    async fn call(
416        &self,
417        io: IoBoxed,
418        ctx: ServiceCtx<'_, Self>,
419    ) -> Result<Self::Response, Self::Error> {
420        log::trace!("Starting mqtt v5 handshake");
421
422        let codec = mqtt::Codec::default();
423        codec.set_max_inbound_size(self.max_size);
424        codec.set_min_chunk_size(self.min_chunk_size);
425
426        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, self.pool.clone()));
427        shared.set_max_qos(self.max_qos);
428        shared.set_receive_max(self.max_receive);
429        shared.set_topic_alias_max(self.max_topic_alias);
430
431        // read first packet
432        let packet = timeout_checked(self.connect_timeout, io.recv(&shared.codec))
433            .await
434            .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
435            .map_err(|err| {
436                log::trace!("Error is received during mqtt handshake: {:?}", err);
437                MqttError::Handshake(HandshakeError::from(err))
438            })?
439            .ok_or_else(|| {
440                log::trace!("Server mqtt is disconnected during handshake");
441                MqttError::Handshake(HandshakeError::Disconnected(None))
442            })?;
443
444        match packet {
445            Decoded::Packet(Packet::Connect(connect), size) => {
446                // set max outbound (encoder) packet size
447                if let Some(size) = connect.max_packet_size {
448                    shared.codec.set_max_outbound_size(size.get());
449                }
450                let keep_alive = connect.keep_alive;
451                let peer_receive_max =
452                    connect.receive_max.map(|v| v.get()).unwrap_or(16) as usize;
453
454                // authenticate mqtt connection
455                let mut ack = ctx
456                    .call(&self.service, Handshake::new(connect, size, io, shared))
457                    .await
458                    .map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))?;
459
460                match ack.session {
461                    Some(session) => {
462                        log::trace!("Sending: {:#?}", ack.packet);
463                        let shared = ack.shared;
464
465                        shared.set_max_qos(ack.packet.max_qos);
466                        shared.set_receive_max(ack.packet.receive_max.get());
467                        shared.set_topic_alias_max(ack.packet.topic_alias_max);
468                        shared
469                            .codec
470                            .set_max_inbound_size(ack.packet.max_packet_size.unwrap_or(0));
471                        shared.codec.set_retain_available(ack.packet.retain_available);
472                        shared.codec.set_sub_ids_available(
473                            ack.packet.subscription_identifiers_available,
474                        );
475                        if ack.packet.server_keepalive_sec.is_none()
476                            && (keep_alive > ack.keepalive)
477                        {
478                            ack.packet.server_keepalive_sec = Some(ack.keepalive);
479                        }
480                        shared.set_cap(peer_receive_max);
481
482                        ack.io.encode(
483                            Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
484                            &shared.codec,
485                        )?;
486
487                        Ok((
488                            ack.io,
489                            shared.clone(),
490                            Session::new(session, MqttSink::new(shared)),
491                            Seconds(ack.keepalive),
492                        ))
493                    }
494                    None => {
495                        log::trace!("Failed to complete handshake: {:#?}", ack.packet);
496
497                        ack.io.encode(
498                            Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
499                            &ack.shared.codec,
500                        )?;
501                        let _ = ack.io.shutdown().await;
502                        Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
503                    }
504                }
505            }
506            Decoded::Packet(packet, _) => {
507                log::info!(
508                    "MQTT-3.1.0-1: Expected CONNECT packet, received {}",
509                    packet.packet_type()
510                );
511                Err(MqttError::Handshake(HandshakeError::Protocol(
512                    ProtocolError::unexpected_packet(
513                        packet.packet_type(),
514                        "Expected CONNECT packet [MQTT-3.1.0-1]",
515                    ),
516                )))
517            }
518            Decoded::Publish(..) => {
519                log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
520                Err(MqttError::Handshake(HandshakeError::Protocol(
521                    ProtocolError::unexpected_packet(
522                        crate::types::packet_type::PUBLISH_START,
523                        "Expected CONNECT packet [MQTT-3.1.0-1]",
524                    ),
525                )))
526            }
527            Decoded::PayloadChunk(..) => unreachable!(),
528        }
529    }
530}