Skip to main content

ntex_mqtt/
server.rs

1use std::{fmt, io, marker, task::Context};
2
3use ntex_codec::{Decoder, Encoder};
4use ntex_dispatcher::DispatchItem;
5use ntex_io::{Filter, Io, IoBoxed};
6use ntex_service::{Middleware, Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
7use ntex_util::future::{Either, join, select};
8use ntex_util::time::{Deadline, Seconds};
9
10use crate::version::{ProtocolVersion, VersionCodec};
11use crate::{MqttServiceConfig, error::HandshakeError, error::MqttError, service};
12
/// Protocol-dispatching MQTT server factory.
///
/// Holds one service factory per supported protocol version: `svc_v3` is used
/// for MQTT v3 connections and `svc_v5` for MQTT v5 connections. `Err` and
/// `InitErr` only pin the error types and are carried through `PhantomData`.
pub struct MqttServer<V3, V5, Err, InitErr> {
    /// Factory for the MQTT v3 handler.
    svc_v3: V3,
    /// Factory for the MQTT v5 handler.
    svc_v5: V5,
    _t: marker::PhantomData<(Err, InitErr)>,
}

impl<V3, V5, Err, InitErr> fmt::Debug for MqttServer<V3, V5, Err, InitErr> {
    /// Prints only the type name, so the inner factories need not be `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("MqttServer")
    }
}
25
26impl<Err, InitErr>
27    MqttServer<
28        DefaultProtocolServer<Err, InitErr>,
29        DefaultProtocolServer<Err, InitErr>,
30        Err,
31        InitErr,
32    >
33{
34    /// Create mqtt server
35    pub fn new() -> Self {
36        MqttServer {
37            svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
38            svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
39            _t: marker::PhantomData,
40        }
41    }
42}
43
44impl<Err, InitErr> Default
45    for MqttServer<
46        DefaultProtocolServer<Err, InitErr>,
47        DefaultProtocolServer<Err, InitErr>,
48        Err,
49        InitErr,
50    >
51{
52    fn default() -> Self {
53        MqttServer::new()
54    }
55}
56
57impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
58where
59    Err: fmt::Debug,
60    V3: ServiceFactory<
61            IoBoxed,
62            SharedCfg,
63            Response = (),
64            Error = MqttError<Err>,
65            InitError = InitErr,
66        >,
67    V5: ServiceFactory<
68            IoBoxed,
69            SharedCfg,
70            Response = (),
71            Error = MqttError<Err>,
72            InitError = InitErr,
73        >,
74{
75    /// Service to handle v3 protocol
76    pub fn v3<St, H, P, M, Codec>(
77        self,
78        service: service::MqttServer<St, H, P, M, Codec>,
79    ) -> MqttServer<
80        impl ServiceFactory<
81            IoBoxed,
82            SharedCfg,
83            Response = (),
84            Error = MqttError<Err>,
85            InitError = InitErr,
86        >,
87        V5,
88        Err,
89        InitErr,
90    >
91    where
92        St: 'static,
93        H: ServiceFactory<
94                IoBoxed,
95                SharedCfg,
96                Response = (IoBoxed, Codec, St, Seconds),
97                Error = MqttError<Err>,
98                InitError = InitErr,
99            > + 'static,
100        P: ServiceFactory<
101                DispatchItem<Codec>,
102                (SharedCfg, St),
103                Response = Option<<Codec as Encoder>::Item>,
104                Error = MqttError<Err>,
105                InitError = MqttError<Err>,
106            > + 'static,
107        M: Middleware<P::Service, SharedCfg>,
108        M::Service: Service<
109                DispatchItem<Codec>,
110                Response = Option<<Codec as Encoder>::Item>,
111                Error = MqttError<Err>,
112            > + 'static,
113        Codec: Encoder + Decoder + Clone + 'static,
114    {
115        MqttServer { svc_v3: service, svc_v5: self.svc_v5, _t: marker::PhantomData }
116    }
117
118    /// Service to handle v5 protocol
119    pub fn v5<St, H, P, M, Codec>(
120        self,
121        service: service::MqttServer<St, H, P, M, Codec>,
122    ) -> MqttServer<
123        V3,
124        impl ServiceFactory<
125            IoBoxed,
126            SharedCfg,
127            Response = (),
128            Error = MqttError<Err>,
129            InitError = InitErr,
130        >,
131        Err,
132        InitErr,
133    >
134    where
135        St: 'static,
136        H: ServiceFactory<
137                IoBoxed,
138                SharedCfg,
139                Response = (IoBoxed, Codec, St, Seconds),
140                Error = MqttError<Err>,
141                InitError = InitErr,
142            > + 'static,
143        P: ServiceFactory<
144                DispatchItem<Codec>,
145                (SharedCfg, St),
146                Response = Option<<Codec as Encoder>::Item>,
147                Error = MqttError<Err>,
148                InitError = MqttError<Err>,
149            > + 'static,
150        M: Middleware<P::Service, SharedCfg>,
151        M::Service: Service<
152                DispatchItem<Codec>,
153                Response = Option<<Codec as Encoder>::Item>,
154                Error = MqttError<Err>,
155            > + 'static,
156        Codec: Encoder + Decoder + Clone + 'static,
157    {
158        MqttServer { svc_v3: self.svc_v3, svc_v5: service, _t: marker::PhantomData }
159    }
160}
161
162impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
163where
164    V3: ServiceFactory<
165            IoBoxed,
166            SharedCfg,
167            Response = (),
168            Error = MqttError<Err>,
169            InitError = InitErr,
170        >,
171    V5: ServiceFactory<
172            IoBoxed,
173            SharedCfg,
174            Response = (),
175            Error = MqttError<Err>,
176            InitError = InitErr,
177        >,
178{
179    async fn create_service(
180        &self,
181        cfg: SharedCfg,
182    ) -> Result<MqttServerImpl<V3::Service, V5::Service, Err>, InitErr> {
183        let (v3, v5) =
184            join(self.svc_v3.create(cfg.clone()), self.svc_v5.create(cfg.clone())).await;
185        let v3 = v3?;
186        let v5 = v5?;
187        Ok(MqttServerImpl { handlers: (v3, v5), cfg: cfg.get(), _t: marker::PhantomData })
188    }
189}
190
191impl<V3, V5, Err, InitErr> ServiceFactory<IoBoxed, SharedCfg>
192    for MqttServer<V3, V5, Err, InitErr>
193where
194    V3: ServiceFactory<
195            IoBoxed,
196            SharedCfg,
197            Response = (),
198            Error = MqttError<Err>,
199            InitError = InitErr,
200        > + 'static,
201    V5: ServiceFactory<
202            IoBoxed,
203            SharedCfg,
204            Response = (),
205            Error = MqttError<Err>,
206            InitError = InitErr,
207        > + 'static,
208    Err: 'static,
209    InitErr: 'static,
210{
211    type Response = ();
212    type Error = MqttError<Err>;
213    type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
214    type InitError = InitErr;
215
216    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
217        self.create_service(cfg).await
218    }
219}
220
221impl<F, V3, V5, Err, InitErr> ServiceFactory<Io<F>, SharedCfg>
222    for MqttServer<V3, V5, Err, InitErr>
223where
224    F: Filter,
225    V3: ServiceFactory<
226            IoBoxed,
227            SharedCfg,
228            Response = (),
229            Error = MqttError<Err>,
230            InitError = InitErr,
231        > + 'static,
232    V5: ServiceFactory<
233            IoBoxed,
234            SharedCfg,
235            Response = (),
236            Error = MqttError<Err>,
237            InitError = InitErr,
238        > + 'static,
239    Err: 'static,
240    InitErr: 'static,
241{
242    type Response = ();
243    type Error = MqttError<Err>;
244    type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
245    type InitError = InitErr;
246
247    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
248        self.create_service(cfg).await
249    }
250}
251
252/// Mqtt Server
253pub struct MqttServerImpl<V3, V5, Err> {
254    handlers: (V3, V5),
255    cfg: Cfg<MqttServiceConfig>,
256    _t: marker::PhantomData<Err>,
257}
258
259impl<V3, V5, Err> fmt::Debug for MqttServerImpl<V3, V5, Err> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        f.debug_struct("MqttServerImpl").finish()
262    }
263}
264
265impl<V3, V5, Err> Service<IoBoxed> for MqttServerImpl<V3, V5, Err>
266where
267    V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
268    V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
269{
270    type Response = ();
271    type Error = MqttError<Err>;
272
273    #[inline]
274    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
275        let (ready1, ready2) =
276            join(ctx.ready(&self.handlers.0), ctx.ready(&self.handlers.1)).await;
277        ready1?;
278        ready2
279    }
280
281    #[inline]
282    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
283        self.handlers.0.poll(cx)?;
284        self.handlers.1.poll(cx)
285    }
286
287    #[inline]
288    async fn shutdown(&self) {
289        self.handlers.0.shutdown().await;
290        self.handlers.1.shutdown().await;
291    }
292
293    #[inline]
294    async fn call(
295        &self,
296        io: IoBoxed,
297        ctx: ServiceCtx<'_, Self>,
298    ) -> Result<Self::Response, Self::Error> {
299        // try to read Version, buffer may already contain info
300        let res = io
301            .decode(&VersionCodec)
302            .map_err(|e| MqttError::Handshake(HandshakeError::Protocol(e.into())))?;
303        if let Some(ver) = res {
304            match ver {
305                ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
306                ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
307            }
308        } else {
309            let fut = async {
310                match io.recv(&VersionCodec).await {
311                    Ok(ver) => Ok(ver),
312                    Err(Either::Left(e)) => {
313                        Err(MqttError::Handshake(HandshakeError::Protocol(e.into())))
314                    }
315                    Err(Either::Right(e)) => {
316                        Err(MqttError::Handshake(HandshakeError::Disconnected(Some(e))))
317                    }
318                }
319            };
320
321            match select(&mut Deadline::new(self.cfg.protocol_version_timeout), fut).await {
322                Either::Left(()) => Err(MqttError::Handshake(HandshakeError::Timeout)),
323                Either::Right(Ok(Some(ver))) => match ver {
324                    ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
325                    ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
326                },
327                Either::Right(Ok(None)) => {
328                    Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
329                }
330                Either::Right(Err(e)) => Err(e),
331            }
332        }
333    }
334}
335
336impl<F, V3, V5, Err> Service<Io<F>> for MqttServerImpl<V3, V5, Err>
337where
338    F: Filter,
339    V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
340    V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
341{
342    type Response = ();
343    type Error = MqttError<Err>;
344
345    #[inline]
346    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
347        Service::<IoBoxed>::ready(self, ctx).await
348    }
349
350    #[inline]
351    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
352        Service::<IoBoxed>::poll(self, cx)
353    }
354
355    #[inline]
356    async fn shutdown(&self) {
357        Service::<IoBoxed>::shutdown(self).await;
358    }
359
360    #[inline]
361    async fn call(
362        &self,
363        io: Io<F>,
364        ctx: ServiceCtx<'_, Self>,
365    ) -> Result<Self::Response, Self::Error> {
366        Service::<IoBoxed>::call(self, IoBoxed::from(io), ctx).await
367    }
368}
369
370pub struct DefaultProtocolServer<Err, InitErr> {
371    ver: ProtocolVersion,
372    _t: marker::PhantomData<(Err, InitErr)>,
373}
374
375impl<Err, InitErr> fmt::Debug for DefaultProtocolServer<Err, InitErr> {
376    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377        f.debug_struct("DefaultProtocolServer").field("ver", &self.ver).finish()
378    }
379}
380
381impl<Err, InitErr> DefaultProtocolServer<Err, InitErr> {
382    fn new(ver: ProtocolVersion) -> Self {
383        Self { ver, _t: marker::PhantomData }
384    }
385}
386
387impl<Err, InitErr> ServiceFactory<IoBoxed, SharedCfg> for DefaultProtocolServer<Err, InitErr> {
388    type Response = ();
389    type Error = MqttError<Err>;
390    type Service = DefaultProtocolServer<Err, InitErr>;
391    type InitError = InitErr;
392
393    async fn create(&self, _: SharedCfg) -> Result<Self::Service, Self::InitError> {
394        Ok(DefaultProtocolServer { ver: self.ver, _t: marker::PhantomData })
395    }
396}
397
398impl<Err, InitErr> Service<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
399    type Response = ();
400    type Error = MqttError<Err>;
401
402    async fn call(
403        &self,
404        _: IoBoxed,
405        _: ServiceCtx<'_, Self>,
406    ) -> Result<Self::Response, Self::Error> {
407        Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::other(
408            format!("Protocol is not supported: {:?}", self.ver),
409        )))))
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// `Debug` output of a default-constructed server names the type.
    #[test]
    fn test_debug() {
        // `V3`/`V5` are inferred as `DefaultProtocolServer` from the only `Default` impl.
        let rendered = format!("{:?}", MqttServer::<_, _, (), ()>::default());
        assert!(rendered.contains("MqttServer"));
    }
}