// ntex_mqtt/v3/server.rs

1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, IoBoxed};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
6use ntex_util::time::{Seconds, timeout_checked};
7
8use crate::error::{HandshakeError, MqttError, ProtocolError};
9use crate::{MqttServiceConfig, service};
10
11use super::control::{Control, ControlAck};
12use super::default::{DefaultControlService, InFlightService};
13use super::handshake::{Handshake, HandshakeAck};
14use super::shared::{MqttShared, MqttSinkPool};
15use super::{MqttSink, Publish, Session, codec as mqtt, dispatcher::factory};
16
17/// Mqtt v3.1.1 server
18///
19/// `St` - connection state
20/// `H` - handshake service
21/// `C` - service for handling control messages
22/// `P` - service for handling publish
23///
24/// Every mqtt connection is handled in several steps. First step is handshake. Server calls
25/// handshake service with `Handshake` message, during this step service can authenticate connect
26/// packet, it must return instance of connection state `St`.
27///
28/// Handshake service could be expressed as simple function:
29///
30/// ```rust,ignore
31/// use ntex_mqtt::v3::{Handshake, HandshakeAck};
32///
33/// async fn handshake(hnd: Handshake) -> Result<HandshakeAkc<MyState>, MyError> {
34///     Ok(hnd.ack(MyState::new(), false))
35/// }
36/// ```
37///
38/// During next stage, control and publish services get constructed,
39/// both factories receive `Session<St>` state object as an argument. Publish service
40/// handles `Publish` packet. On success, server server sends `PublishAck` packet to
41/// the client, in case of error connection get closed. Control service receives all
42/// other packets, like `Subscribe`, `Unsubscribe` etc. Also control service receives
43/// errors from publish service and connection disconnect.
44pub struct MqttServer<St, H, C, M = Identity> {
45    handshake: H,
46    control: C,
47    middleware: M,
48    pub(super) pool: Rc<MqttSinkPool>,
49    _t: PhantomData<St>,
50}
51
52impl<St, H> MqttServer<St, H, DefaultControlService<St, H::Error>, InFlightService>
53where
54    St: 'static,
55    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
56    H::Error: fmt::Debug,
57{
58    /// Create server factory and provide handshake service
59    pub fn new<F>(handshake: F) -> Self
60    where
61        F: IntoServiceFactory<H, Handshake, SharedCfg>,
62    {
63        MqttServer {
64            handshake: handshake.into_factory(),
65            control: DefaultControlService::default(),
66            middleware: InFlightService,
67            pool: Default::default(),
68            _t: PhantomData,
69        }
70    }
71}
72
73impl<St, H, C, M> MqttServer<St, H, C, M>
74where
75    St: 'static,
76    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
77    C: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
78    H::Error: From<C::Error> + From<C::InitError> + fmt::Debug,
79{
80    /// Service to handle control packets
81    ///
82    /// All control packets are processed sequentially, max number of buffered
83    /// control packets is 16.
84    pub fn control<F, Srv>(self, service: F) -> MqttServer<St, H, Srv, M>
85    where
86        F: IntoServiceFactory<Srv, Control<H::Error>, Session<St>>,
87        Srv: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
88        H::Error: From<Srv::Error> + From<Srv::InitError>,
89    {
90        MqttServer {
91            handshake: self.handshake,
92            control: service.into_factory(),
93            middleware: self.middleware,
94            pool: self.pool,
95            _t: PhantomData,
96        }
97    }
98
99    /// Registers middleware, in the form of a middleware component (type),
100    /// that runs during inbound and/or outbound processing in the request
101    /// lifecycle (request -> response), modifying request/response as
102    /// necessary, across all requests managed by the *Server*.
103    ///
104    /// Use middleware when you need to read or modify *every* request or
105    /// response in some way.
106    pub fn middleware<U>(self, mw: U) -> MqttServer<St, H, C, Stack<M, U>> {
107        MqttServer {
108            middleware: Stack::new(self.middleware, mw),
109            handshake: self.handshake,
110            control: self.control,
111            pool: self.pool,
112            _t: PhantomData,
113        }
114    }
115
116    /// Replace middlewares
117    pub fn replace_middlewares<U>(self, mw: U) -> MqttServer<St, H, C, U> {
118        MqttServer {
119            middleware: mw,
120            handshake: self.handshake,
121            control: self.control,
122            pool: self.pool,
123            _t: PhantomData,
124        }
125    }
126
127    /// Set service to handle publish packets and create mqtt server factory
128    pub fn publish<F, Srv>(
129        self,
130        publish: F,
131    ) -> service::MqttServer<
132        Session<St>,
133        impl ServiceFactory<
134            IoBoxed,
135            SharedCfg,
136            Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
137            Error = MqttError<H::Error>,
138            InitError = H::InitError,
139        >,
140        impl ServiceFactory<
141            DispatchItem<Rc<MqttShared>>,
142            (SharedCfg, Session<St>),
143            Response = Option<mqtt::Encoded>,
144            Error = MqttError<H::Error>,
145            InitError = MqttError<H::Error>,
146        >,
147        M,
148        Rc<MqttShared>,
149    >
150    where
151        H::Error: From<C::Error>
152            + From<C::InitError>
153            + From<Srv::Error>
154            + From<Srv::InitError>
155            + fmt::Debug,
156        F: IntoServiceFactory<Srv, Publish, Session<St>>,
157        Srv: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
158        H::Error: From<Srv::Error> + From<Srv::InitError> + fmt::Debug,
159    {
160        service::MqttServer::new(
161            HandshakeFactory {
162                factory: self.handshake,
163                pool: self.pool.clone(),
164                _t: PhantomData,
165            },
166            factory(publish.into_factory(), self.control),
167            self.middleware,
168        )
169    }
170}
171
172struct HandshakeFactory<St, H> {
173    factory: H,
174    pool: Rc<MqttSinkPool>,
175    _t: PhantomData<St>,
176}
177
178impl<St, H> ServiceFactory<IoBoxed, SharedCfg> for HandshakeFactory<St, H>
179where
180    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
181    H::Error: fmt::Debug,
182{
183    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
184    type Error = MqttError<H::Error>;
185
186    type Service = HandshakeService<St, H::Service>;
187    type InitError = H::InitError;
188
189    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
190        Ok(HandshakeService {
191            cfg: cfg.get(),
192            pool: self.pool.clone(),
193            service: self.factory.create(cfg).await?,
194            _t: PhantomData,
195        })
196    }
197}
198
199struct HandshakeService<St, H> {
200    service: H,
201    cfg: Cfg<MqttServiceConfig>,
202    pool: Rc<MqttSinkPool>,
203    _t: PhantomData<St>,
204}
205
206impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
207where
208    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
209    H::Error: fmt::Debug,
210{
211    type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
212    type Error = MqttError<H::Error>;
213
214    ntex_service::forward_ready!(service, MqttError::Service);
215    ntex_service::forward_poll!(service, MqttError::Service);
216    ntex_service::forward_shutdown!(service);
217
218    async fn call(
219        &self,
220        io: IoBoxed,
221        ctx: ServiceCtx<'_, Self>,
222    ) -> Result<Self::Response, Self::Error> {
223        log::trace!("Starting mqtt v3 handshake");
224
225        let codec = mqtt::Codec::default();
226        codec.set_max_size(self.cfg.max_size);
227        codec.set_min_chunk_size(self.cfg.min_chunk_size);
228        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone()));
229
230        // read first packet
231        let packet = timeout_checked(self.cfg.connect_timeout, io.recv(&shared.codec))
232            .await
233            .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
234            .map_err(|err| {
235                log::trace!("Error is received during mqtt handshake: {:?}", err);
236                MqttError::Handshake(HandshakeError::from(err))
237            })?
238            .ok_or_else(|| {
239                log::trace!("Server mqtt is disconnected during handshake");
240                MqttError::Handshake(HandshakeError::Disconnected(None))
241            })?;
242
243        match packet {
244            mqtt::Decoded::Packet(mqtt::Packet::Connect(connect), size) => {
245                // authenticate mqtt connection
246                let ack = ctx
247                    .call(&self.service, Handshake::new(connect, size, io, shared))
248                    .await
249                    .map_err(MqttError::Service)?;
250
251                match ack.session {
252                    Some(session) => {
253                        let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
254                            session_present: ack.session_present,
255                            return_code: mqtt::ConnectAckReason::ConnectionAccepted,
256                        });
257
258                        log::trace!("Sending success handshake ack: {:#?}", pkt);
259
260                        ack.shared.set_cap(ack.max_send.unwrap_or(self.cfg.max_send) as usize);
261                        if let Some(max_packet_size) = ack.max_packet_size {
262                            ack.shared.codec.set_max_size(max_packet_size.get());
263                        }
264                        ack.io.encode(mqtt::Encoded::Packet(pkt), &ack.shared.codec)?;
265                        Ok((
266                            ack.io,
267                            ack.shared.clone(),
268                            Session::new(session, MqttSink::new(ack.shared)),
269                            ack.keepalive,
270                        ))
271                    }
272                    None => {
273                        let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
274                            session_present: false,
275                            return_code: ack.return_code,
276                        });
277
278                        log::trace!("Sending failed handshake ack: {:#?}", pkt);
279                        ack.io.encode(mqtt::Encoded::Packet(pkt), &ack.shared.codec)?;
280                        let _ = ack.io.shutdown().await;
281
282                        Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
283                    }
284                }
285            }
286            mqtt::Decoded::Packet(packet, _) => {
287                log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet);
288                Err(MqttError::Handshake(HandshakeError::Protocol(
289                    ProtocolError::unexpected_packet(
290                        packet.packet_type(),
291                        "MQTT-3.1.0-1: Expected CONNECT packet",
292                    ),
293                )))
294            }
295            mqtt::Decoded::Publish(..) => {
296                log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
297                Err(MqttError::Handshake(HandshakeError::Protocol(
298                    ProtocolError::unexpected_packet(
299                        crate::types::packet_type::PUBLISH_START,
300                        "Expected CONNECT packet [MQTT-3.1.0-1]",
301                    ),
302                )))
303            }
304            mqtt::Decoded::PayloadChunk(..) => unreachable!(),
305        }
306    }
307}