// ntex_mqtt/v5/server.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, IoBoxed};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
6use ntex_util::time::{Seconds, timeout_checked};
7
8use crate::error::{HandshakeError, MqttError, ProtocolError};
9use crate::{MqttServiceConfig, service};
10
11use super::codec::{self as mqtt, Decoded, Encoded, Packet};
12use super::control::{Control, ControlAck};
13use super::default::{DefaultControlService, InFlightService};
14use super::handshake::{Handshake, HandshakeAck};
15use super::publish::{Publish, PublishAck};
16use super::shared::{MqttShared, MqttSinkPool};
17use super::{MqttSink, Session, dispatcher::factory};
18
19/// Mqtt Server
20pub struct MqttServer<St, C, Cn, M = Identity> {
21    handshake: C,
22    control: Cn,
23    middleware: M,
24    pub(super) pool: Rc<MqttSinkPool>,
25    _t: PhantomData<St>,
26}
27
28impl<St, C> MqttServer<St, C, DefaultControlService<St, C::Error>, InFlightService>
29where
30    C: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>>,
31    C::Error: fmt::Debug,
32{
33    /// Create server factory and provide handshake service
34    pub fn new<F>(handshake: F) -> Self
35    where
36        F: IntoServiceFactory<C, Handshake, SharedCfg>,
37    {
38        MqttServer {
39            handshake: handshake.into_factory(),
40            control: DefaultControlService::default(),
41            middleware: InFlightService,
42            pool: Rc::new(MqttSinkPool::default()),
43            _t: PhantomData,
44        }
45    }
46}
47
48impl<St, C, Cn, M> MqttServer<St, C, Cn, M>
49where
50    St: 'static,
51    C: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
52    C::Error: fmt::Debug,
53    Cn: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
54{
55    /// Remove all middlewares
56    pub fn reset_middlewares(self) -> MqttServer<St, C, Cn, Identity> {
57        MqttServer {
58            middleware: Identity,
59            control: self.control,
60            handshake: self.handshake,
61            pool: self.pool,
62            _t: PhantomData,
63        }
64    }
65
66    /// Registers middleware, in the form of a middleware component (type),
67    /// that runs during inbound and/or outbound processing in the request
68    /// lifecycle (request -> response), modifying request/response as
69    /// necessary, across all requests managed by the *Server*.
70    ///
71    /// Use middleware when you need to read or modify *every* request or
72    /// response in some way.
73    pub fn middleware<U>(self, mw: U) -> MqttServer<St, C, Cn, Stack<M, U>> {
74        MqttServer {
75            middleware: Stack::new(self.middleware, mw),
76            handshake: self.handshake,
77            control: self.control,
78            pool: self.pool,
79            _t: PhantomData,
80        }
81    }
82
83    /// Service to handle control packets
84    ///
85    /// All control packets are processed sequentially, max number of buffered
86    /// control packets is 16.
87    pub fn control<F, Srv>(self, service: F) -> MqttServer<St, C, Srv, M>
88    where
89        F: IntoServiceFactory<Srv, Control<C::Error>, Session<St>>,
90        Srv: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
91        C::Error: From<Srv::Error> + From<Srv::InitError>,
92    {
93        MqttServer {
94            handshake: self.handshake,
95            control: service.into_factory(),
96            middleware: self.middleware,
97            pool: self.pool,
98            _t: PhantomData,
99        }
100    }
101
102    /// Set service to handle publish packets and create mqtt server factory
103    /// and create mqtt server factory
104    pub fn publish<F, Srv>(
105        self,
106        publish: F,
107    ) -> service::MqttServer<
108        Session<St>,
109        impl ServiceFactory<
110            IoBoxed,
111            SharedCfg,
112            Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
113            Error = MqttError<C::Error>,
114            InitError = C::InitError,
115        >,
116        impl ServiceFactory<
117            DispatchItem<Rc<MqttShared>>,
118            (SharedCfg, Session<St>),
119            Response = Option<mqtt::Encoded>,
120            Error = MqttError<C::Error>,
121            InitError = MqttError<C::Error>,
122        >,
123        M,
124        Rc<MqttShared>,
125    >
126    where
127        F: IntoServiceFactory<Srv, Publish, Session<St>>,
128        C::Error:
129            From<Cn::Error> + From<Cn::InitError> + From<Srv::Error> + From<Srv::InitError>,
130        Srv: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
131        Srv::Error: fmt::Debug,
132        PublishAck: TryFrom<Srv::Error, Error = C::Error>,
133    {
134        service::MqttServer::new(
135            HandshakeFactory { factory: self.handshake, pool: self.pool, _t: PhantomData },
136            factory(publish.into_factory(), self.control),
137            self.middleware,
138        )
139    }
140}
141
142struct HandshakeFactory<St, H> {
143    factory: H,
144    pool: Rc<MqttSinkPool>,
145    _t: PhantomData<St>,
146}
147
148impl<St, H> ServiceFactory<IoBoxed, SharedCfg> for HandshakeFactory<St, H>
149where
150    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
151    H::Error: fmt::Debug,
152{
153    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
154    type Error = MqttError<H::Error>;
155
156    type Service = HandshakeService<St, H::Service>;
157    type InitError = H::InitError;
158
159    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
160        Ok(HandshakeService {
161            cfg: cfg.get(),
162            service: self.factory.create(cfg).await?,
163            pool: self.pool.clone(),
164            _t: PhantomData,
165        })
166    }
167}
168
169struct HandshakeService<St, H> {
170    service: H,
171    cfg: Cfg<MqttServiceConfig>,
172    pool: Rc<MqttSinkPool>,
173    _t: PhantomData<St>,
174}
175
176impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
177where
178    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
179    H::Error: fmt::Debug,
180{
181    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
182    type Error = MqttError<H::Error>;
183
184    ntex_service::forward_ready!(service, MqttError::Service);
185    ntex_service::forward_poll!(service, MqttError::Service);
186    ntex_service::forward_shutdown!(service);
187
188    async fn call(
189        &self,
190        io: IoBoxed,
191        ctx: ServiceCtx<'_, Self>,
192    ) -> Result<Self::Response, Self::Error> {
193        log::trace!("Starting mqtt v5 handshake");
194
195        let codec = mqtt::Codec::default();
196        codec.set_max_inbound_size(self.cfg.max_size);
197        codec.set_min_chunk_size(self.cfg.min_chunk_size);
198
199        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, self.pool.clone()));
200        shared.set_max_qos(self.cfg.max_qos);
201        shared.set_receive_max(self.cfg.max_receive);
202        shared.set_topic_alias_max(self.cfg.max_topic_alias);
203
204        // read first packet
205        let packet = timeout_checked(self.cfg.connect_timeout, io.recv(&shared.codec))
206            .await
207            .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
208            .map_err(|err| {
209                log::trace!("Error is received during mqtt handshake: {:?}", err);
210                MqttError::Handshake(HandshakeError::from(err))
211            })?
212            .ok_or_else(|| {
213                log::trace!("Server mqtt is disconnected during handshake");
214                MqttError::Handshake(HandshakeError::Disconnected(None))
215            })?;
216
217        match packet {
218            Decoded::Packet(Packet::Connect(connect), size) => {
219                // set max outbound (encoder) packet size
220                if let Some(size) = connect.max_packet_size {
221                    shared.codec.set_max_outbound_size(size.get());
222                }
223                let keep_alive = connect.keep_alive;
224                let peer_receive_max =
225                    connect.receive_max.map(|v| v.get()).unwrap_or(16) as usize;
226
227                // authenticate mqtt connection
228                let mut ack = ctx
229                    .call(&self.service, Handshake::new(connect, size, io, shared))
230                    .await
231                    .map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))?;
232
233                match ack.session {
234                    Some(session) => {
235                        log::trace!("Sending: {:#?}", ack.packet);
236                        let shared = ack.shared;
237
238                        shared.set_max_qos(ack.packet.max_qos);
239                        shared.set_receive_max(ack.packet.receive_max.get());
240                        shared.set_topic_alias_max(ack.packet.topic_alias_max);
241                        shared
242                            .codec
243                            .set_max_inbound_size(ack.packet.max_packet_size.unwrap_or(0));
244                        shared.codec.set_retain_available(ack.packet.retain_available);
245                        shared.codec.set_sub_ids_available(
246                            ack.packet.subscription_identifiers_available,
247                        );
248                        if ack.packet.server_keepalive_sec.is_none()
249                            && (keep_alive > ack.keepalive)
250                        {
251                            ack.packet.server_keepalive_sec = Some(ack.keepalive);
252                        }
253                        shared.set_cap(peer_receive_max);
254
255                        ack.io.encode(
256                            Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
257                            &shared.codec,
258                        )?;
259
260                        Ok((
261                            ack.io,
262                            shared.clone(),
263                            Session::new(session, MqttSink::new(shared)),
264                            Seconds(ack.keepalive),
265                        ))
266                    }
267                    None => {
268                        log::trace!("Failed to complete handshake: {:#?}", ack.packet);
269
270                        ack.io.encode(
271                            Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
272                            &ack.shared.codec,
273                        )?;
274                        let _ = ack.io.shutdown().await;
275                        Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
276                    }
277                }
278            }
279            Decoded::Packet(packet, _) => {
280                log::info!(
281                    "MQTT-3.1.0-1: Expected CONNECT packet, received {}",
282                    packet.packet_type()
283                );
284                Err(MqttError::Handshake(HandshakeError::Protocol(
285                    ProtocolError::unexpected_packet(
286                        packet.packet_type(),
287                        "Expected CONNECT packet [MQTT-3.1.0-1]",
288                    ),
289                )))
290            }
291            Decoded::Publish(..) => {
292                log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
293                Err(MqttError::Handshake(HandshakeError::Protocol(
294                    ProtocolError::unexpected_packet(
295                        crate::types::packet_type::PUBLISH_START,
296                        "Expected CONNECT packet [MQTT-3.1.0-1]",
297                    ),
298                )))
299            }
300            Decoded::PayloadChunk(..) => unreachable!(),
301        }
302    }
303}