1use std::{fmt, io, marker, task::Context};
2
3use ntex_codec::{Decoder, Encoder};
4use ntex_dispatcher::DispatchItem;
5use ntex_io::{Filter, Io, IoBoxed};
6use ntex_service::{Middleware, Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
7use ntex_util::future::{Either, join, select};
8use ntex_util::time::{Deadline, Seconds};
9
10use crate::version::{ProtocolVersion, VersionCodec};
11use crate::{MqttServiceConfig, error::HandshakeError, error::MqttError, service};
12
/// Server that reads the MQTT protocol version of each connection and forwards the
/// connection to either the v3 or the v5 protocol service.
///
/// `V3` and `V5` are the service factories for the two protocol versions; `Err` and
/// `InitErr` are the user error and factory-initialisation error types.
pub struct MqttServer<V3, V5, Err, InitErr> {
    /// Factory for the service that handles MQTT v3 connections.
    svc_v3: V3,
    /// Factory for the service that handles MQTT v5 connections.
    svc_v5: V5,
    _t: marker::PhantomData<(Err, InitErr)>,
}

impl<V3, V5, Err, InitErr> fmt::Debug for MqttServer<V3, V5, Err, InitErr> {
    /// Prints only the type name, so the factories do not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("MqttServer")
    }
}
25
26impl<Err, InitErr>
27 MqttServer<
28 DefaultProtocolServer<Err, InitErr>,
29 DefaultProtocolServer<Err, InitErr>,
30 Err,
31 InitErr,
32 >
33{
34 pub fn new() -> Self {
36 MqttServer {
37 svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3),
38 svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5),
39 _t: marker::PhantomData,
40 }
41 }
42}
43
44impl<Err, InitErr> Default
45 for MqttServer<
46 DefaultProtocolServer<Err, InitErr>,
47 DefaultProtocolServer<Err, InitErr>,
48 Err,
49 InitErr,
50 >
51{
52 fn default() -> Self {
53 MqttServer::new()
54 }
55}
56
57impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
58where
59 Err: fmt::Debug,
60 V3: ServiceFactory<
61 IoBoxed,
62 SharedCfg,
63 Response = (),
64 Error = MqttError<Err>,
65 InitError = InitErr,
66 >,
67 V5: ServiceFactory<
68 IoBoxed,
69 SharedCfg,
70 Response = (),
71 Error = MqttError<Err>,
72 InitError = InitErr,
73 >,
74{
75 pub fn v3<St, H, P, M, Codec>(
77 self,
78 service: service::MqttServer<St, H, P, M, Codec>,
79 ) -> MqttServer<
80 impl ServiceFactory<
81 IoBoxed,
82 SharedCfg,
83 Response = (),
84 Error = MqttError<Err>,
85 InitError = InitErr,
86 >,
87 V5,
88 Err,
89 InitErr,
90 >
91 where
92 St: 'static,
93 H: ServiceFactory<
94 IoBoxed,
95 SharedCfg,
96 Response = (IoBoxed, Codec, St, Seconds),
97 Error = MqttError<Err>,
98 InitError = InitErr,
99 > + 'static,
100 P: ServiceFactory<
101 DispatchItem<Codec>,
102 (SharedCfg, St),
103 Response = Option<<Codec as Encoder>::Item>,
104 Error = MqttError<Err>,
105 InitError = MqttError<Err>,
106 > + 'static,
107 M: Middleware<P::Service, SharedCfg>,
108 M::Service: Service<
109 DispatchItem<Codec>,
110 Response = Option<<Codec as Encoder>::Item>,
111 Error = MqttError<Err>,
112 > + 'static,
113 Codec: Encoder + Decoder + Clone + 'static,
114 {
115 MqttServer { svc_v3: service, svc_v5: self.svc_v5, _t: marker::PhantomData }
116 }
117
118 pub fn v5<St, H, P, M, Codec>(
120 self,
121 service: service::MqttServer<St, H, P, M, Codec>,
122 ) -> MqttServer<
123 V3,
124 impl ServiceFactory<
125 IoBoxed,
126 SharedCfg,
127 Response = (),
128 Error = MqttError<Err>,
129 InitError = InitErr,
130 >,
131 Err,
132 InitErr,
133 >
134 where
135 St: 'static,
136 H: ServiceFactory<
137 IoBoxed,
138 SharedCfg,
139 Response = (IoBoxed, Codec, St, Seconds),
140 Error = MqttError<Err>,
141 InitError = InitErr,
142 > + 'static,
143 P: ServiceFactory<
144 DispatchItem<Codec>,
145 (SharedCfg, St),
146 Response = Option<<Codec as Encoder>::Item>,
147 Error = MqttError<Err>,
148 InitError = MqttError<Err>,
149 > + 'static,
150 M: Middleware<P::Service, SharedCfg>,
151 M::Service: Service<
152 DispatchItem<Codec>,
153 Response = Option<<Codec as Encoder>::Item>,
154 Error = MqttError<Err>,
155 > + 'static,
156 Codec: Encoder + Decoder + Clone + 'static,
157 {
158 MqttServer { svc_v3: self.svc_v3, svc_v5: service, _t: marker::PhantomData }
159 }
160}
161
162impl<V3, V5, Err, InitErr> MqttServer<V3, V5, Err, InitErr>
163where
164 V3: ServiceFactory<
165 IoBoxed,
166 SharedCfg,
167 Response = (),
168 Error = MqttError<Err>,
169 InitError = InitErr,
170 >,
171 V5: ServiceFactory<
172 IoBoxed,
173 SharedCfg,
174 Response = (),
175 Error = MqttError<Err>,
176 InitError = InitErr,
177 >,
178{
179 async fn create_service(
180 &self,
181 cfg: SharedCfg,
182 ) -> Result<MqttServerImpl<V3::Service, V5::Service, Err>, InitErr> {
183 let (v3, v5) =
184 join(self.svc_v3.create(cfg.clone()), self.svc_v5.create(cfg.clone())).await;
185 let v3 = v3?;
186 let v5 = v5?;
187 Ok(MqttServerImpl { handlers: (v3, v5), cfg: cfg.get(), _t: marker::PhantomData })
188 }
189}
190
191impl<V3, V5, Err, InitErr> ServiceFactory<IoBoxed, SharedCfg>
192 for MqttServer<V3, V5, Err, InitErr>
193where
194 V3: ServiceFactory<
195 IoBoxed,
196 SharedCfg,
197 Response = (),
198 Error = MqttError<Err>,
199 InitError = InitErr,
200 > + 'static,
201 V5: ServiceFactory<
202 IoBoxed,
203 SharedCfg,
204 Response = (),
205 Error = MqttError<Err>,
206 InitError = InitErr,
207 > + 'static,
208 Err: 'static,
209 InitErr: 'static,
210{
211 type Response = ();
212 type Error = MqttError<Err>;
213 type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
214 type InitError = InitErr;
215
216 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
217 self.create_service(cfg).await
218 }
219}
220
221impl<F, V3, V5, Err, InitErr> ServiceFactory<Io<F>, SharedCfg>
222 for MqttServer<V3, V5, Err, InitErr>
223where
224 F: Filter,
225 V3: ServiceFactory<
226 IoBoxed,
227 SharedCfg,
228 Response = (),
229 Error = MqttError<Err>,
230 InitError = InitErr,
231 > + 'static,
232 V5: ServiceFactory<
233 IoBoxed,
234 SharedCfg,
235 Response = (),
236 Error = MqttError<Err>,
237 InitError = InitErr,
238 > + 'static,
239 Err: 'static,
240 InitErr: 'static,
241{
242 type Response = ();
243 type Error = MqttError<Err>;
244 type Service = MqttServerImpl<V3::Service, V5::Service, Err>;
245 type InitError = InitErr;
246
247 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
248 self.create_service(cfg).await
249 }
250}
251
252pub struct MqttServerImpl<V3, V5, Err> {
254 handlers: (V3, V5),
255 cfg: Cfg<MqttServiceConfig>,
256 _t: marker::PhantomData<Err>,
257}
258
259impl<V3, V5, Err> fmt::Debug for MqttServerImpl<V3, V5, Err> {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_struct("MqttServerImpl").finish()
262 }
263}
264
265impl<V3, V5, Err> Service<IoBoxed> for MqttServerImpl<V3, V5, Err>
266where
267 V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
268 V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
269{
270 type Response = ();
271 type Error = MqttError<Err>;
272
273 #[inline]
274 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
275 let (ready1, ready2) =
276 join(ctx.ready(&self.handlers.0), ctx.ready(&self.handlers.1)).await;
277 ready1?;
278 ready2
279 }
280
281 #[inline]
282 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
283 self.handlers.0.poll(cx)?;
284 self.handlers.1.poll(cx)
285 }
286
287 #[inline]
288 async fn shutdown(&self) {
289 self.handlers.0.shutdown().await;
290 self.handlers.1.shutdown().await;
291 }
292
293 #[inline]
294 async fn call(
295 &self,
296 io: IoBoxed,
297 ctx: ServiceCtx<'_, Self>,
298 ) -> Result<Self::Response, Self::Error> {
299 let res = io
301 .decode(&VersionCodec)
302 .map_err(|e| MqttError::Handshake(HandshakeError::Protocol(e.into())))?;
303 if let Some(ver) = res {
304 match ver {
305 ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
306 ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
307 }
308 } else {
309 let fut = async {
310 match io.recv(&VersionCodec).await {
311 Ok(ver) => Ok(ver),
312 Err(Either::Left(e)) => {
313 Err(MqttError::Handshake(HandshakeError::Protocol(e.into())))
314 }
315 Err(Either::Right(e)) => {
316 Err(MqttError::Handshake(HandshakeError::Disconnected(Some(e))))
317 }
318 }
319 };
320
321 match select(&mut Deadline::new(self.cfg.protocol_version_timeout), fut).await {
322 Either::Left(()) => Err(MqttError::Handshake(HandshakeError::Timeout)),
323 Either::Right(Ok(Some(ver))) => match ver {
324 ProtocolVersion::MQTT3 => ctx.call(&self.handlers.0, io).await,
325 ProtocolVersion::MQTT5 => ctx.call(&self.handlers.1, io).await,
326 },
327 Either::Right(Ok(None)) => {
328 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
329 }
330 Either::Right(Err(e)) => Err(e),
331 }
332 }
333 }
334}
335
336impl<F, V3, V5, Err> Service<Io<F>> for MqttServerImpl<V3, V5, Err>
337where
338 F: Filter,
339 V3: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
340 V5: Service<IoBoxed, Response = (), Error = MqttError<Err>>,
341{
342 type Response = ();
343 type Error = MqttError<Err>;
344
345 #[inline]
346 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
347 Service::<IoBoxed>::ready(self, ctx).await
348 }
349
350 #[inline]
351 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
352 Service::<IoBoxed>::poll(self, cx)
353 }
354
355 #[inline]
356 async fn shutdown(&self) {
357 Service::<IoBoxed>::shutdown(self).await;
358 }
359
360 #[inline]
361 async fn call(
362 &self,
363 io: Io<F>,
364 ctx: ServiceCtx<'_, Self>,
365 ) -> Result<Self::Response, Self::Error> {
366 Service::<IoBoxed>::call(self, IoBoxed::from(io), ctx).await
367 }
368}
369
370pub struct DefaultProtocolServer<Err, InitErr> {
371 ver: ProtocolVersion,
372 _t: marker::PhantomData<(Err, InitErr)>,
373}
374
375impl<Err, InitErr> fmt::Debug for DefaultProtocolServer<Err, InitErr> {
376 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377 f.debug_struct("DefaultProtocolServer").field("ver", &self.ver).finish()
378 }
379}
380
381impl<Err, InitErr> DefaultProtocolServer<Err, InitErr> {
382 fn new(ver: ProtocolVersion) -> Self {
383 Self { ver, _t: marker::PhantomData }
384 }
385}
386
387impl<Err, InitErr> ServiceFactory<IoBoxed, SharedCfg> for DefaultProtocolServer<Err, InitErr> {
388 type Response = ();
389 type Error = MqttError<Err>;
390 type Service = DefaultProtocolServer<Err, InitErr>;
391 type InitError = InitErr;
392
393 async fn create(&self, _: SharedCfg) -> Result<Self::Service, Self::InitError> {
394 Ok(DefaultProtocolServer { ver: self.ver, _t: marker::PhantomData })
395 }
396}
397
398impl<Err, InitErr> Service<IoBoxed> for DefaultProtocolServer<Err, InitErr> {
399 type Response = ();
400 type Error = MqttError<Err>;
401
402 async fn call(
403 &self,
404 _: IoBoxed,
405 _: ServiceCtx<'_, Self>,
406 ) -> Result<Self::Response, Self::Error> {
407 Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::other(
408 format!("Protocol is not supported: {:?}", self.ver),
409 )))))
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_debug() {
        let server = <MqttServer<_, _, (), ()>>::default();
        // Debug prints only the type name, in both compact and pretty form.
        assert_eq!(format!("{server:?}"), "MqttServer");
        assert_eq!(format!("{server:#?}"), "MqttServer");
    }

    #[test]
    fn test_default_protocol_server_debug() {
        let srv = DefaultProtocolServer::<(), ()>::new(ProtocolVersion::MQTT5);
        assert!(format!("{srv:?}").starts_with("DefaultProtocolServer { ver: "));
    }
}