1use std::{fmt, io, marker, task::Context};
2
3use ntex_codec::{Decoder, Encoder};
4use ntex_io::{DispatchItem, Filter, Io, IoBoxed};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{Middleware2, Service, ServiceCtx, ServiceFactory};
7use ntex_util::future::{Either, join, select};
8use ntex_util::time::{Deadline, Seconds};
9
10use crate::version::{ProtocolVersion, VersionCodec};
11use crate::{MqttServiceConfig, error::HandshakeError, error::MqttError, service};
12
/// Protocol-version selector factory.
///
/// Holds one handler factory for MQTT v3 connections (`svc_v3`) and one for
/// MQTT v5 connections (`svc_v5`). The service it builds (`MqttServerImpl`)
/// reads the protocol version from an incoming connection and forwards the io
/// to the matching handler.
pub struct MqttServer<V3, V5, Err, InitErr> {
    /// Factory for the handler that serves MQTT v3 connections.
    svc_v3: V3,
    /// Factory for the handler that serves MQTT v5 connections.
    svc_v5: V5,
    /// Ties the service/init error types to this struct without storing them.
    _t: marker::PhantomData<(Err, InitErr)>,
}
19
20impl<Err, InitErr>
21 MqttServer<
22 DefaultProtocolServer<Err, InitErr>,
23 DefaultProtocolServer<Err, InitErr>,
24 Err,
25 InitErr,
26 >
27{
28 pub fn new() -> Self {
30 MqttServer {
31 svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
32 svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
33 _t: marker::PhantomData,
34 }
35 }
36}
37
38impl<Err, InitErr> Default
39 for MqttServer<
40 DefaultProtocolServer<Err, InitErr>,
41 DefaultProtocolServer<Err, InitErr>,
42 Err,
43 InitErr,
44 >
45{
46 fn default() -> Self {
47 MqttServer::new()
48 }
49}
50
/// Builder methods that install the per-protocol connection handlers.
impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
where
    // NOTE(review): `fmt::Debug` is not used directly in this block; presumably
    // it is required by the `ServiceFactory` impl of `service::MqttServer` —
    // confirm before removing.
    Err: fmt::Debug,
    V3: ServiceFactory<
        IoBoxed,
        SharedCfg,
        Response = (),
        Error = MqttError<Err>,
        InitError = InitErr,
    >,
    V5: ServiceFactory<
        IoBoxed,
        SharedCfg,
        Response = (),
        Error = MqttError<Err>,
        InitError = InitErr,
    >,
{
    /// Installs `service` as the handler for MQTT v3 connections.
    ///
    /// Consumes `self` and returns a new selector whose v3 slot holds
    /// `service`. The current v5 factory is carried over unchanged and the
    /// previous v3 factory is dropped.
    ///
    /// The v3 slot of the result is an opaque `impl ServiceFactory`. This
    /// relies on `service::MqttServer<St, H, P, M, Codec>` implementing
    /// `ServiceFactory<IoBoxed, SharedCfg>` under the bounds below; that impl
    /// lives in the `service` module, not in this file.
    pub fn v3<St, H, P, M, Codec>(
        self,
        service: service::MqttServer<St, H, P, M, Codec>,
    ) -> MqttServer<
        impl ServiceFactory<
            IoBoxed,
            SharedCfg,
            Response = (),
            Error = MqttError<Err>,
            InitError = InitErr,
        >,
        V5,
        Err,
        InitErr,
    >
    where
        St: 'static,
        // Factory whose service turns the raw io into `(io, codec, state, Seconds)`;
        // presumably the handshake stage — confirm in `service`.
        H: ServiceFactory<
            IoBoxed,
            SharedCfg,
            Response = (IoBoxed, Codec, St, Seconds),
            Error = MqttError<Err>,
            InitError = InitErr,
        > + 'static,
        // Per-connection factory (built from config + state `St`) whose service
        // handles decoded `DispatchItem<Codec>` items and may return a frame to encode.
        P: ServiceFactory<
            DispatchItem<Codec>,
            (SharedCfg, St),
            Response = Option<<Codec as Encoder>::Item>,
            Error = MqttError<Err>,
            InitError = MqttError<Err>,
        > + 'static,
        // Middleware wrapped around `P::Service`; the wrapped service must keep
        // the same request/response/error types.
        M: Middleware2<P::Service, SharedCfg>,
        M::Service: Service<
            DispatchItem<Codec>,
            Response = Option<<Codec as Encoder>::Item>,
            Error = MqttError<Err>,
        > + 'static,
        Codec: Encoder + Decoder + Clone + 'static,
    {
        MqttServer { svc_v3: service, svc_v5: self.svc_v5, _t: marker::PhantomData }
    }

    /// Installs `service` as the handler for MQTT v5 connections.
    ///
    /// Consumes `self` and returns a new selector whose v5 slot holds
    /// `service`. The current v3 factory is carried over unchanged and the
    /// previous v5 factory is dropped.
    ///
    /// Like `v3`, this relies on `service::MqttServer` implementing
    /// `ServiceFactory<IoBoxed, SharedCfg>` under the bounds below.
    pub fn v5<St, H, P, M, Codec>(
        self,
        service: service::MqttServer<St, H, P, M, Codec>,
    ) -> MqttServer<
        V3,
        impl ServiceFactory<
            IoBoxed,
            SharedCfg,
            Response = (),
            Error = MqttError<Err>,
            InitError = InitErr,
        >,
        Err,
        InitErr,
    >
    where
        St: 'static,
        H: ServiceFactory<
            IoBoxed,
            SharedCfg,
            Response = (IoBoxed, Codec, St, Seconds),
            Error = MqttError<Err>,
            InitError = InitErr,
        > + 'static,
        P: ServiceFactory<
            DispatchItem<Codec>,
            (SharedCfg, St),
            Response = Option<<Codec as Encoder>::Item>,
            Error = MqttError<Err>,
            InitError = MqttError<Err>,
        > + 'static,
        M: Middleware2<P::Service, SharedCfg>,
        M::Service: Service<
            DispatchItem<Codec>,
            Response = Option<<Codec as Encoder>::Item>,
            Error = MqttError<Err>,
        > + 'static,
        Codec: Encoder + Decoder + Clone + 'static,
    {
        MqttServer { svc_v3: self.svc_v3, svc_v5: service, _t: marker::PhantomData }
    }
}
155
156impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
157where
158 V3: ServiceFactory<
159 IoBoxed,
160 SharedCfg,
161 Response = (),
162 Error = MqttError<Err>,
163 InitError = InitErr,
164 >,
165 V5: ServiceFactory<
166 IoBoxed,
167 SharedCfg,
168 Response = (),
169 Error = MqttError<Err>,
170 InitError = InitErr,
171 >,
172{
173 async fn create_service(
174 &self,
175 cfg: SharedCfg,
176 ) -> Result<MqttServerImpl<V3::Service, V5::Service, Err>, InitErr> {
177 let (v3, v5) = join(self.svc_v3.create(cfg), self.svc_v5.create(cfg)).await;
178 let v3 = v3?;
179 let v5 = v5?;
180 Ok(MqttServerImpl { handlers: (v3, v5), cfg: cfg.get(), _t: marker::PhantomData })
181 }
182}
183
184impl<V3, V5, Err, InitErr> ServiceFactory<IoBoxed, SharedCfg>
185 for MqttServer<V3, V5, Err, InitErr>
186where
187 V3: ServiceFactory<
188 IoBoxed,
189 SharedCfg,
190 Response = (),
191 Error = MqttError<Err>,
192 InitError = InitErr,
193 > + 'static,
194 V5: ServiceFactory<
195 IoBoxed,
196 SharedCfg,
197 Response = (),
198 Error = MqttError<Err>,
199 InitError = InitErr,
200 > + 'static,
201 Err: 'static,
202 InitErr: 'static,
203{
204 type Response = ();
205 type Error = MqttError<Err>;
206 type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
207 type InitError = InitErr;
208
209 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
210 self.create_service(cfg).await
211 }
212}
213
214impl<F, V3, V5, Err, InitErr> ServiceFactory<Io<F>, SharedCfg>
215 for MqttServer<V3, V5, Err, InitErr>
216where
217 F: Filter,
218 V3: ServiceFactory<
219 IoBoxed,
220 SharedCfg,
221 Response = (),
222 Error = MqttError<Err>,
223 InitError = InitErr,
224 > + 'static,
225 V5: ServiceFactory<
226 IoBoxed,
227 SharedCfg,
228 Response = (),
229 Error = MqttError<Err>,
230 InitError = InitErr,
231 > + 'static,
232 Err: 'static,
233 InitErr: 'static,
234{
235 type Response = ();
236 type Error = MqttError<Err>;
237 type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
238 type InitError = InitErr;
239
240 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
241 self.create_service(cfg).await
242 }
243}
244
/// Service built by `MqttServer`.
///
/// Detects the MQTT protocol version of an incoming connection and dispatches
/// the io to either the v3 or the v5 handler service.
pub struct MqttServerImpl<V3, V5, Err> {
    /// `(v3 handler, v5 handler)`.
    handlers: (V3, V5),
    /// MQTT service configuration; `protocol_version_timeout` bounds how long
    /// `call` waits for the protocol version to arrive.
    cfg: Cfg<MqttServiceConfig>,
    /// Ties the handler error type to this struct without storing a value.
    _t: marker::PhantomData<Err>,
}
251
252impl<V3, V5, Err> Service<IoBoxed> for MqttServerImpl<V3, V5, Err>
253where
254 V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
255 V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
256{
257 type Response = ();
258 type Error = MqttError<Err>;
259
260 #[inline]
261 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
262 let (ready1, ready2) =
263 join(ctx.ready(&self.handlers.0), ctx.ready(&self.handlers.1)).await;
264 ready1?;
265 ready2
266 }
267
268 #[inline]
269 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
270 self.handlers.0.poll(cx)?;
271 self.handlers.1.poll(cx)
272 }
273
274 #[inline]
275 async fn shutdown(&self) {
276 self.handlers.0.shutdown().await;
277 self.handlers.1.shutdown().await;
278 }
279
280 #[inline]
281 async fn call(
282 &self,
283 io: IoBoxed,
284 ctx: ServiceCtx<'_, Self>,
285 ) -> Result<Self::Response, Self::Error> {
286 let res = io
288 .decode(&VersionCodec)
289 .map_err(|e| MqttError::Handshake(HandshakeError::Protocol(e.into())))?;
290 if let Some(ver) = res {
291 match ver {
292 ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
293 ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
294 }
295 } else {
296 let fut = async {
297 match io.recv(&VersionCodec).await {
298 Ok(ver) => Ok(ver),
299 Err(Either::Left(e)) => {
300 Err(MqttError::Handshake(HandshakeError::Protocol(e.into())))
301 }
302 Err(Either::Right(e)) => {
303 Err(MqttError::Handshake(HandshakeError::Disconnected(Some(e))))
304 }
305 }
306 };
307
308 match select(&mut Deadline::new(self.cfg.protocol_version_timeout), fut).await {
309 Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)),
310 Either::Right(Ok(Some(ver))) => match ver {
311 ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
312 ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
313 },
314 Either::Right(Ok(None)) => {
315 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
316 }
317 Either::Right(Err(e)) => Err(e),
318 }
319 }
320 }
321}
322
323impl<F, V3, V5, Err> Service<Io<F>> for MqttServerImpl<V3, V5, Err>
324where
325 F: Filter,
326 V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
327 V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
328{
329 type Response = ();
330 type Error = MqttError<Err>;
331
332 #[inline]
333 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
334 Service::<IoBoxed>::ready(self, ctx).await
335 }
336
337 #[inline]
338 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
339 Service::<IoBoxed>::poll(self, cx)
340 }
341
342 #[inline]
343 async fn shutdown(&self) {
344 Service::<IoBoxed>::shutdown(self).await
345 }
346
347 #[inline]
348 async fn call(
349 &self,
350 io: Io<F>,
351 ctx: ServiceCtx<'_, Self>,
352 ) -> Result<Self::Response, Self::Error> {
353 Service::<IoBoxed>::call(self, IoBoxed::from(io), ctx).await
354 }
355}
356
/// Placeholder handler installed by `MqttServer::new` for both protocol slots.
///
/// It acts as its own factory (`create` always succeeds and returns a new
/// placeholder for the same version). When called, it rejects the connection
/// with `HandshakeError::Disconnected` carrying an `io::Error` that names the
/// unsupported version.
pub struct DefaultProtocolServer<Err, InitErr> {
    /// Protocol version this placeholder stands in for; reported in the error message.
    ver: ProtocolVersion,
    /// Ties the error types to this struct without storing them.
    _t: marker::PhantomData<(Err, InitErr)>,
}
361
362impl<Err, InitErr> DefaultProtocolServer<Err, InitErr> {
363 fn new(ver: ProtocolVersion) -> Self {
364 Self { ver, _t: marker::PhantomData }
365 }
366}
367
368impl<Err, InitErr> ServiceFactory<IoBoxed, SharedCfg> for DefaultProtocolServer<Err, InitErr> {
369 type Response = ();
370 type Error = MqttError<Err>;
371 type Service = DefaultProtocolServer<Err, InitErr>;
372 type InitError = InitErr;
373
374 async fn create(&self, _: SharedCfg) -> Result<Self::Service, Self::InitError> {
375 Ok(DefaultProtocolServer { ver: self.ver, _t: marker::PhantomData })
376 }
377}
378
379impl<Err, InitErr> Service<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
380 type Response = ();
381 type Error = MqttError<Err>;
382
383 async fn call(
384 &self,
385 _: IoBoxed,
386 _: ServiceCtx<'_, Self>,
387 ) -> Result<Self::Response, Self::Error> {
388 Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::other(
389 format!("Protocol is not supported: {:?}", self.ver),
390 )))))
391 }
392}