ntex_mqtt/
server.rs

1use std::{fmt, io, marker, task::Context};
2
3use ntex_codec::{Decoder, Encoder};
4use ntex_io::{DispatchItem, Filter, Io, IoBoxed};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{Middleware2, Service, ServiceCtx, ServiceFactory};
7use ntex_util::future::{Either, join, select};
8use ntex_util::time::{Deadline, Seconds};
9
10use crate::version::{ProtocolVersion, VersionCodec};
11use crate::{MqttServiceConfig, error::HandshakeError, error::MqttError, service};
12
13/// Mqtt Server
14pub struct MqttServer<V3, V5, Err, InitErr> {
15    svc_v3: V3,
16    svc_v5: V5,
17    _t: marker::PhantomData<(Err, InitErr)>,
18}
19
20impl<Err, InitErr>
21    MqttServer<
22        DefaultProtocolServer<Err, InitErr>,
23        DefaultProtocolServer<Err, InitErr>,
24        Err,
25        InitErr,
26    >
27{
28    /// Create mqtt server
29    pub fn new() -> Self {
30        MqttServer {
31            svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
32            svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
33            _t: marker::PhantomData,
34        }
35    }
36}
37
38impl<Err, InitErr> Default
39    for MqttServer<
40        DefaultProtocolServer<Err, InitErr>,
41        DefaultProtocolServer<Err, InitErr>,
42        Err,
43        InitErr,
44    >
45{
46    fn default() -> Self {
47        MqttServer::new()
48    }
49}
50
51impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
52where
53    Err: fmt::Debug,
54    V3: ServiceFactory<
55            IoBoxed,
56            SharedCfg,
57            Response = (),
58            Error = MqttError<Err>,
59            InitError = InitErr,
60        >,
61    V5: ServiceFactory<
62            IoBoxed,
63            SharedCfg,
64            Response = (),
65            Error = MqttError<Err>,
66            InitError = InitErr,
67        >,
68{
69    /// Service to handle v3 protocol
70    pub fn v3<St, H, P, M, Codec>(
71        self,
72        service: service::MqttServer<St, H, P, M, Codec>,
73    ) -> MqttServer<
74        impl ServiceFactory<
75            IoBoxed,
76            SharedCfg,
77            Response = (),
78            Error = MqttError<Err>,
79            InitError = InitErr,
80        >,
81        V5,
82        Err,
83        InitErr,
84    >
85    where
86        St: 'static,
87        H: ServiceFactory<
88                IoBoxed,
89                SharedCfg,
90                Response = (IoBoxed, Codec, St, Seconds),
91                Error = MqttError<Err>,
92                InitError = InitErr,
93            > + 'static,
94        P: ServiceFactory<
95                DispatchItem<Codec>,
96                (SharedCfg, St),
97                Response = Option<<Codec as Encoder>::Item>,
98                Error = MqttError<Err>,
99                InitError = MqttError<Err>,
100            > + 'static,
101        M: Middleware2<P::Service, SharedCfg>,
102        M::Service: Service<
103                DispatchItem<Codec>,
104                Response = Option<<Codec as Encoder>::Item>,
105                Error = MqttError<Err>,
106            > + 'static,
107        Codec: Encoder + Decoder + Clone + 'static,
108    {
109        MqttServer { svc_v3: service, svc_v5: self.svc_v5, _t: marker::PhantomData }
110    }
111
112    /// Service to handle v5 protocol
113    pub fn v5<St, H, P, M, Codec>(
114        self,
115        service: service::MqttServer<St, H, P, M, Codec>,
116    ) -> MqttServer<
117        V3,
118        impl ServiceFactory<
119            IoBoxed,
120            SharedCfg,
121            Response = (),
122            Error = MqttError<Err>,
123            InitError = InitErr,
124        >,
125        Err,
126        InitErr,
127    >
128    where
129        St: 'static,
130        H: ServiceFactory<
131                IoBoxed,
132                SharedCfg,
133                Response = (IoBoxed, Codec, St, Seconds),
134                Error = MqttError<Err>,
135                InitError = InitErr,
136            > + 'static,
137        P: ServiceFactory<
138                DispatchItem<Codec>,
139                (SharedCfg, St),
140                Response = Option<<Codec as Encoder>::Item>,
141                Error = MqttError<Err>,
142                InitError = MqttError<Err>,
143            > + 'static,
144        M: Middleware2<P::Service, SharedCfg>,
145        M::Service: Service<
146                DispatchItem<Codec>,
147                Response = Option<<Codec as Encoder>::Item>,
148                Error = MqttError<Err>,
149            > + 'static,
150        Codec: Encoder + Decoder + Clone + 'static,
151    {
152        MqttServer { svc_v3: self.svc_v3, svc_v5: service, _t: marker::PhantomData }
153    }
154}
155
156impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
157where
158    V3: ServiceFactory<
159            IoBoxed,
160            SharedCfg,
161            Response = (),
162            Error = MqttError<Err>,
163            InitError = InitErr,
164        >,
165    V5: ServiceFactory<
166            IoBoxed,
167            SharedCfg,
168            Response = (),
169            Error = MqttError<Err>,
170            InitError = InitErr,
171        >,
172{
173    async fn create_service(
174        &self,
175        cfg: SharedCfg,
176    ) -> Result<MqttServerImpl<V3::Service, V5::Service, Err>, InitErr> {
177        let (v3, v5) = join(self.svc_v3.create(cfg), self.svc_v5.create(cfg)).await;
178        let v3 = v3?;
179        let v5 = v5?;
180        Ok(MqttServerImpl { handlers: (v3, v5), cfg: cfg.get(), _t: marker::PhantomData })
181    }
182}
183
184impl<V3, V5, Err, InitErr> ServiceFactory<IoBoxed, SharedCfg>
185    for MqttServer<V3, V5, Err, InitErr>
186where
187    V3: ServiceFactory<
188            IoBoxed,
189            SharedCfg,
190            Response = (),
191            Error = MqttError<Err>,
192            InitError = InitErr,
193        > + 'static,
194    V5: ServiceFactory<
195            IoBoxed,
196            SharedCfg,
197            Response = (),
198            Error = MqttError<Err>,
199            InitError = InitErr,
200        > + 'static,
201    Err: 'static,
202    InitErr: 'static,
203{
204    type Response = ();
205    type Error = MqttError<Err>;
206    type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
207    type InitError = InitErr;
208
209    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
210        self.create_service(cfg).await
211    }
212}
213
214impl<F, V3, V5, Err, InitErr> ServiceFactory<Io<F>, SharedCfg>
215    for MqttServer<V3, V5, Err, InitErr>
216where
217    F: Filter,
218    V3: ServiceFactory<
219            IoBoxed,
220            SharedCfg,
221            Response = (),
222            Error = MqttError<Err>,
223            InitError = InitErr,
224        > + 'static,
225    V5: ServiceFactory<
226            IoBoxed,
227            SharedCfg,
228            Response = (),
229            Error = MqttError<Err>,
230            InitError = InitErr,
231        > + 'static,
232    Err: 'static,
233    InitErr: 'static,
234{
235    type Response = ();
236    type Error = MqttError<Err>;
237    type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
238    type InitError = InitErr;
239
240    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
241        self.create_service(cfg).await
242    }
243}
244
245/// Mqtt Server
246pub struct MqttServerImpl<V3, V5, Err> {
247    handlers: (V3, V5),
248    cfg: Cfg<MqttServiceConfig>,
249    _t: marker::PhantomData<Err>,
250}
251
252impl<V3, V5, Err> Service<IoBoxed> for MqttServerImpl<V3, V5, Err>
253where
254    V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
255    V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
256{
257    type Response = ();
258    type Error = MqttError<Err>;
259
260    #[inline]
261    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
262        let (ready1, ready2) =
263            join(ctx.ready(&self.handlers.0), ctx.ready(&self.handlers.1)).await;
264        ready1?;
265        ready2
266    }
267
268    #[inline]
269    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
270        self.handlers.0.poll(cx)?;
271        self.handlers.1.poll(cx)
272    }
273
274    #[inline]
275    async fn shutdown(&self) {
276        self.handlers.0.shutdown().await;
277        self.handlers.1.shutdown().await;
278    }
279
280    #[inline]
281    async fn call(
282        &self,
283        io: IoBoxed,
284        ctx: ServiceCtx<'_, Self>,
285    ) -> Result<Self::Response, Self::Error> {
286        // try to read Version, buffer may already contain info
287        let res = io
288            .decode(&VersionCodec)
289            .map_err(|e| MqttError::Handshake(HandshakeError::Protocol(e.into())))?;
290        if let Some(ver) = res {
291            match ver {
292                ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
293                ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
294            }
295        } else {
296            let fut = async {
297                match io.recv(&VersionCodec).await {
298                    Ok(ver) => Ok(ver),
299                    Err(Either::Left(e)) => {
300                        Err(MqttError::Handshake(HandshakeError::Protocol(e.into())))
301                    }
302                    Err(Either::Right(e)) => {
303                        Err(MqttError::Handshake(HandshakeError::Disconnected(Some(e))))
304                    }
305                }
306            };
307
308            match select(&mut Deadline::new(self.cfg.protocol_version_timeout), fut).await {
309                Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)),
310                Either::Right(Ok(Some(ver))) => match ver {
311                    ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
312                    ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
313                },
314                Either::Right(Ok(None)) => {
315                    Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
316                }
317                Either::Right(Err(e)) => Err(e),
318            }
319        }
320    }
321}
322
323impl<F, V3, V5, Err> Service<Io<F>> for MqttServerImpl<V3, V5, Err>
324where
325    F: Filter,
326    V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
327    V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
328{
329    type Response = ();
330    type Error = MqttError<Err>;
331
332    #[inline]
333    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
334        Service::<IoBoxed>::ready(self, ctx).await
335    }
336
337    #[inline]
338    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
339        Service::<IoBoxed>::poll(self, cx)
340    }
341
342    #[inline]
343    async fn shutdown(&self) {
344        Service::<IoBoxed>::shutdown(self).await
345    }
346
347    #[inline]
348    async fn call(
349        &self,
350        io: Io<F>,
351        ctx: ServiceCtx<'_, Self>,
352    ) -> Result<Self::Response, Self::Error> {
353        Service::<IoBoxed>::call(self, IoBoxed::from(io), ctx).await
354    }
355}
356
357pub struct DefaultProtocolServer<Err, InitErr> {
358    ver: ProtocolVersion,
359    _t: marker::PhantomData<(Err, InitErr)>,
360}
361
362impl<Err, InitErr> DefaultProtocolServer<Err, InitErr> {
363    fn new(ver: ProtocolVersion) -> Self {
364        Self { ver, _t: marker::PhantomData }
365    }
366}
367
368impl<Err, InitErr> ServiceFactory<IoBoxed, SharedCfg> for DefaultProtocolServer<Err, InitErr> {
369    type Response = ();
370    type Error = MqttError<Err>;
371    type Service = DefaultProtocolServer<Err, InitErr>;
372    type InitError = InitErr;
373
374    async fn create(&self, _: SharedCfg) -> Result<Self::Service, Self::InitError> {
375        Ok(DefaultProtocolServer { ver: self.ver, _t: marker::PhantomData })
376    }
377}
378
379impl<Err, InitErr> Service<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
380    type Response = ();
381    type Error = MqttError<Err>;
382
383    async fn call(
384        &self,
385        _: IoBoxed,
386        _: ServiceCtx<'_, Self>,
387    ) -> Result<Self::Response, Self::Error> {
388        Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::other(
389            format!("Protocol is not supported: {:?}", self.ver),
390        )))))
391    }
392}