1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, DispatcherConfig, IoBoxed};
4use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
5use ntex_util::time::{timeout_checked, Millis, Seconds};
6
7use crate::error::{HandshakeError, MqttError, ProtocolError};
8use crate::{service, types::QoS, InFlightService};
9
10use super::codec::{self as mqtt, Decoded, Encoded, Packet};
11use super::control::{Control, ControlAck};
12use super::default::{DefaultControlService, DefaultPublishService};
13use super::handshake::{Handshake, HandshakeAck};
14use super::publish::{Publish, PublishAck};
15use super::shared::{MqttShared, MqttSinkPool};
16use super::{dispatcher::factory, MqttSink, Session};
17
/// Builder for an MQTT v5 server.
///
/// Gathers the handshake, control and publish service factories together with
/// a middleware and per-connection limits; `finish()` converts the collected
/// settings into a `service::MqttServer`.
///
/// Type parameters: `St` is the session state produced by the handshake
/// (`HandshakeAck<St>`), `C` the handshake factory, `Cn` the control factory,
/// `P` the publish factory and `M` the middleware (`Identity` unless stated).
pub struct MqttServer<St, C, Cn, P, M = Identity> {
    /// Factory for the service that answers the initial `Handshake`.
    handshake: C,
    /// Factory for the service handling `Control` messages.
    srv_control: Cn,
    /// Factory for the service handling `Publish` messages.
    srv_publish: P,
    /// Middleware handed to `service::MqttServer::new` by `finish()`.
    middleware: M,
    /// QoS limit stored on each connection's `MqttShared` before the handshake.
    max_qos: QoS,
    /// Inbound size limit given to the codec (`set_max_inbound_size`).
    max_size: u32,
    /// Receive maximum stored on each connection's `MqttShared`.
    max_receive: u16,
    /// Topic alias maximum stored on each connection's `MqttShared`.
    max_topic_alias: u16,
    /// Minimum chunk size given to the codec (`set_min_chunk_size`).
    min_chunk_size: u32,
    /// Passed unchanged to the dispatcher factory by `finish()`.
    handle_qos_after_disconnect: Option<QoS>,
    /// Time allowed for receiving the client's first packet.
    connect_timeout: Seconds,
    /// Dispatcher configuration (disconnect timeout, frame read rate).
    config: DispatcherConfig,
    /// Sink pool; a clone of this `Rc` is given to every connection.
    pub(super) pool: Rc<MqttSinkPool>,
    /// Ties the otherwise unused `St` parameter to the struct.
    _t: PhantomData<St>,
}
35
36impl<St, C>
37 MqttServer<
38 St,
39 C,
40 DefaultControlService<St, C::Error>,
41 DefaultPublishService<St, C::Error>,
42 InFlightService,
43 >
44where
45 C: ServiceFactory<Handshake, Response = HandshakeAck<St>>,
46 C::Error: fmt::Debug,
47{
48 pub fn new<F>(handshake: F) -> Self
50 where
51 F: IntoServiceFactory<C, Handshake>,
52 {
53 let config = DispatcherConfig::default();
54 config.set_disconnect_timeout(Seconds(3));
55
56 MqttServer {
57 config,
58 handshake: handshake.into_factory(),
59 srv_control: DefaultControlService::default(),
60 srv_publish: DefaultPublishService::default(),
61 middleware: InFlightService::new(0, 65535),
62 max_qos: QoS::AtLeastOnce,
63 max_size: 0,
64 max_receive: 15,
65 max_topic_alias: 32,
66 min_chunk_size: 32 * 1024,
67 handle_qos_after_disconnect: None,
68 connect_timeout: Seconds::ZERO,
69 pool: Rc::new(MqttSinkPool::default()),
70 _t: PhantomData,
71 }
72 }
73}
74
75impl<St, C, Cn, P> MqttServer<St, C, Cn, P, InFlightService> {
76 pub fn max_receive_size(mut self, val: usize) -> Self {
80 self.middleware = self.middleware.max_receive_size(val);
81 self
82 }
83}
84
85impl<St, C, Cn, P, M> MqttServer<St, C, Cn, P, M>
86where
87 St: 'static,
88 C: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
89 C::Error: fmt::Debug,
90 Cn: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
91 P: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
92{
93 pub fn connect_timeout(mut self, timeout: Seconds) -> Self {
101 self.connect_timeout = timeout;
102 self
103 }
104
105 pub fn disconnect_timeout(self, val: Seconds) -> Self {
114 self.config.set_disconnect_timeout(val);
115 self
116 }
117
118 pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
126 self.config.set_frame_read_rate(timeout, max_timeout, rate);
127 self
128 }
129
130 pub fn max_size(mut self, size: u32) -> Self {
135 self.max_size = size;
136 self
137 }
138
139 pub fn max_receive(mut self, val: u16) -> Self {
144 self.max_receive = val;
145 self
146 }
147
148 pub fn max_topic_alias(mut self, val: u16) -> Self {
152 self.max_topic_alias = val;
153 self
154 }
155
156 pub fn max_qos(mut self, qos: QoS) -> Self {
160 self.max_qos = qos;
161 self
162 }
163
164 pub fn min_chunk_size(mut self, size: u32) -> Self {
171 self.min_chunk_size = size;
172 self
173 }
174
175 pub fn handle_qos_after_disconnect(mut self, max_qos: Option<QoS>) -> Self {
194 self.handle_qos_after_disconnect = max_qos;
195 self
196 }
197
198 pub fn reset_middlewares(self) -> MqttServer<St, C, Cn, P, Identity> {
200 MqttServer {
201 middleware: Identity,
202 config: self.config,
203 handshake: self.handshake,
204 srv_publish: self.srv_publish,
205 srv_control: self.srv_control,
206 max_size: self.max_size,
207 max_receive: self.max_receive,
208 max_topic_alias: self.max_topic_alias,
209 min_chunk_size: self.min_chunk_size,
210 max_qos: self.max_qos,
211 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
212 connect_timeout: self.connect_timeout,
213 pool: self.pool,
214 _t: PhantomData,
215 }
216 }
217
218 pub fn middleware<U>(self, mw: U) -> MqttServer<St, C, Cn, P, Stack<M, U>> {
226 MqttServer {
227 middleware: Stack::new(self.middleware, mw),
228 config: self.config,
229 handshake: self.handshake,
230 srv_publish: self.srv_publish,
231 srv_control: self.srv_control,
232 max_size: self.max_size,
233 max_receive: self.max_receive,
234 max_topic_alias: self.max_topic_alias,
235 max_qos: self.max_qos,
236 min_chunk_size: self.min_chunk_size,
237 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
238 connect_timeout: self.connect_timeout,
239 pool: self.pool,
240 _t: PhantomData,
241 }
242 }
243
244 pub fn control<F, Srv>(self, service: F) -> MqttServer<St, C, Srv, P, M>
249 where
250 F: IntoServiceFactory<Srv, Control<C::Error>, Session<St>>,
251 Srv: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
252 C::Error: From<Srv::Error> + From<Srv::InitError>,
253 {
254 MqttServer {
255 config: self.config,
256 handshake: self.handshake,
257 srv_publish: self.srv_publish,
258 srv_control: service.into_factory(),
259 middleware: self.middleware,
260 max_size: self.max_size,
261 max_receive: self.max_receive,
262 max_topic_alias: self.max_topic_alias,
263 max_qos: self.max_qos,
264 min_chunk_size: self.min_chunk_size,
265 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
266 connect_timeout: self.connect_timeout,
267 pool: self.pool,
268 _t: PhantomData,
269 }
270 }
271
272 pub fn publish<F, Srv>(self, publish: F) -> MqttServer<St, C, Cn, Srv, M>
274 where
275 F: IntoServiceFactory<Srv, Publish, Session<St>>,
276 C::Error: From<Srv::Error> + From<Srv::InitError>,
277 Srv: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
278 Srv::Error: fmt::Debug,
279 PublishAck: TryFrom<Srv::Error, Error = C::Error>,
280 {
281 MqttServer {
282 config: self.config,
283 handshake: self.handshake,
284 srv_publish: publish.into_factory(),
285 srv_control: self.srv_control,
286 middleware: self.middleware,
287 max_size: self.max_size,
288 max_receive: self.max_receive,
289 max_topic_alias: self.max_topic_alias,
290 max_qos: self.max_qos,
291 min_chunk_size: self.min_chunk_size,
292 handle_qos_after_disconnect: self.handle_qos_after_disconnect,
293 connect_timeout: self.connect_timeout,
294 pool: self.pool,
295 _t: PhantomData,
296 }
297 }
298}
299
300impl<St, C, Cn, P, M> MqttServer<St, C, Cn, P, M>
301where
302 St: 'static,
303 C: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
304 C::Error: From<Cn::Error>
305 + From<Cn::InitError>
306 + From<P::Error>
307 + From<P::InitError>
308 + fmt::Debug,
309 Cn: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
310 P: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
311 P::Error: fmt::Debug,
312 PublishAck: TryFrom<P::Error, Error = C::Error>,
313{
314 pub fn finish(
316 self,
317 ) -> service::MqttServer<
318 Session<St>,
319 impl ServiceFactory<
320 IoBoxed,
321 Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
322 Error = MqttError<C::Error>,
323 InitError = C::InitError,
324 >,
325 impl ServiceFactory<
326 DispatchItem<Rc<MqttShared>>,
327 Session<St>,
328 Response = Option<mqtt::Encoded>,
329 Error = MqttError<C::Error>,
330 InitError = MqttError<C::Error>,
331 >,
332 M,
333 Rc<MqttShared>,
334 > {
335 service::MqttServer::new(
336 HandshakeFactory {
337 factory: self.handshake,
338 max_size: self.max_size,
339 max_receive: self.max_receive,
340 max_topic_alias: self.max_topic_alias,
341 max_qos: self.max_qos,
342 min_chunk_size: self.min_chunk_size,
343 connect_timeout: self.connect_timeout.into(),
344 pool: self.pool,
345 _t: PhantomData,
346 },
347 factory(self.srv_publish, self.srv_control, self.handle_qos_after_disconnect),
348 self.middleware,
349 self.config,
350 )
351 }
352}
353
/// `ServiceFactory` that produces a `HandshakeService` on every `create` call.
///
/// Holds the wrapped handshake factory plus the connection limits that
/// `MqttServer::finish()` copies out of the builder.
struct HandshakeFactory<St, H> {
    /// Wrapped handshake service factory.
    factory: H,
    /// Inbound size limit given to the codec (`set_max_inbound_size`).
    max_size: u32,
    /// Receive maximum stored on each connection's `MqttShared`.
    max_receive: u16,
    /// Topic alias maximum stored on each connection's `MqttShared`.
    max_topic_alias: u16,
    /// QoS limit stored on each connection's `MqttShared`.
    max_qos: QoS,
    /// Minimum chunk size given to the codec (`set_min_chunk_size`).
    min_chunk_size: u32,
    /// Time allowed for the first packet; converted from `Seconds` in `finish()`.
    connect_timeout: Millis,
    /// Sink pool; each created service receives a clone of this `Rc`.
    pool: Rc<MqttSinkPool>,
    /// Ties the otherwise unused `St` parameter to the struct.
    _t: PhantomData<St>,
}
365
366impl<St, H> ServiceFactory<IoBoxed> for HandshakeFactory<St, H>
367where
368 H: ServiceFactory<Handshake, Response = HandshakeAck<St>> + 'static,
369 H::Error: fmt::Debug,
370{
371 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
372 type Error = MqttError<H::Error>;
373
374 type Service = HandshakeService<St, H::Service>;
375 type InitError = H::InitError;
376
377 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
378 Ok(HandshakeService {
379 service: self.factory.create(()).await?,
380 max_size: self.max_size,
381 max_receive: self.max_receive,
382 max_topic_alias: self.max_topic_alias,
383 max_qos: self.max_qos,
384 min_chunk_size: self.min_chunk_size,
385 pool: self.pool.clone(),
386 connect_timeout: self.connect_timeout,
387 _t: PhantomData,
388 })
389 }
390}
391
/// Service that performs the MQTT v5 handshake on a freshly accepted
/// connection (`IoBoxed`) by delegating the CONNECT packet to `service`.
struct HandshakeService<St, H> {
    /// Wrapped handshake service that receives the `Handshake` request.
    service: H,
    /// Inbound size limit given to the codec before the first read.
    max_size: u32,
    /// Receive maximum stored on `MqttShared` before the handshake.
    max_receive: u16,
    /// Topic alias maximum stored on `MqttShared` before the handshake.
    max_topic_alias: u16,
    /// QoS limit stored on `MqttShared` before the handshake.
    max_qos: QoS,
    /// Minimum chunk size given to the codec before the first read.
    min_chunk_size: u32,
    /// Time allowed for receiving the first packet.
    connect_timeout: Millis,
    /// Sink pool; a clone of this `Rc` is handed to each connection's `MqttShared`.
    pool: Rc<MqttSinkPool>,
    /// Ties the otherwise unused `St` parameter to the struct.
    _t: PhantomData<St>,
}
403
404impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
405where
406 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
407 H::Error: fmt::Debug,
408{
409 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
410 type Error = MqttError<H::Error>;
411
412 ntex_service::forward_ready!(service, MqttError::Service);
413 ntex_service::forward_shutdown!(service);
414
415 async fn call(
416 &self,
417 io: IoBoxed,
418 ctx: ServiceCtx<'_, Self>,
419 ) -> Result<Self::Response, Self::Error> {
420 log::trace!("Starting mqtt v5 handshake");
421
422 let codec = mqtt::Codec::default();
423 codec.set_max_inbound_size(self.max_size);
424 codec.set_min_chunk_size(self.min_chunk_size);
425
426 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, self.pool.clone()));
427 shared.set_max_qos(self.max_qos);
428 shared.set_receive_max(self.max_receive);
429 shared.set_topic_alias_max(self.max_topic_alias);
430
431 let packet = timeout_checked(self.connect_timeout, io.recv(&shared.codec))
433 .await
434 .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
435 .map_err(|err| {
436 log::trace!("Error is received during mqtt handshake: {:?}", err);
437 MqttError::Handshake(HandshakeError::from(err))
438 })?
439 .ok_or_else(|| {
440 log::trace!("Server mqtt is disconnected during handshake");
441 MqttError::Handshake(HandshakeError::Disconnected(None))
442 })?;
443
444 match packet {
445 Decoded::Packet(Packet::Connect(connect), size) => {
446 if let Some(size) = connect.max_packet_size {
448 shared.codec.set_max_outbound_size(size.get());
449 }
450 let keep_alive = connect.keep_alive;
451 let peer_receive_max =
452 connect.receive_max.map(|v| v.get()).unwrap_or(16) as usize;
453
454 let mut ack = ctx
456 .call(&self.service, Handshake::new(connect, size, io, shared))
457 .await
458 .map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))?;
459
460 match ack.session {
461 Some(session) => {
462 log::trace!("Sending: {:#?}", ack.packet);
463 let shared = ack.shared;
464
465 shared.set_max_qos(ack.packet.max_qos);
466 shared.set_receive_max(ack.packet.receive_max.get());
467 shared.set_topic_alias_max(ack.packet.topic_alias_max);
468 shared
469 .codec
470 .set_max_inbound_size(ack.packet.max_packet_size.unwrap_or(0));
471 shared.codec.set_retain_available(ack.packet.retain_available);
472 shared.codec.set_sub_ids_available(
473 ack.packet.subscription_identifiers_available,
474 );
475 if ack.packet.server_keepalive_sec.is_none()
476 && (keep_alive > ack.keepalive)
477 {
478 ack.packet.server_keepalive_sec = Some(ack.keepalive);
479 }
480 shared.set_cap(peer_receive_max);
481
482 ack.io.encode(
483 Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
484 &shared.codec,
485 )?;
486
487 Ok((
488 ack.io,
489 shared.clone(),
490 Session::new(session, MqttSink::new(shared)),
491 Seconds(ack.keepalive),
492 ))
493 }
494 None => {
495 log::trace!("Failed to complete handshake: {:#?}", ack.packet);
496
497 ack.io.encode(
498 Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
499 &ack.shared.codec,
500 )?;
501 let _ = ack.io.shutdown().await;
502 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
503 }
504 }
505 }
506 Decoded::Packet(packet, _) => {
507 log::info!(
508 "MQTT-3.1.0-1: Expected CONNECT packet, received {}",
509 packet.packet_type()
510 );
511 Err(MqttError::Handshake(HandshakeError::Protocol(
512 ProtocolError::unexpected_packet(
513 packet.packet_type(),
514 "Expected CONNECT packet [MQTT-3.1.0-1]",
515 ),
516 )))
517 }
518 Decoded::Publish(..) => {
519 log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
520 Err(MqttError::Handshake(HandshakeError::Protocol(
521 ProtocolError::unexpected_packet(
522 crate::types::packet_type::PUBLISH_START,
523 "Expected CONNECT packet [MQTT-3.1.0-1]",
524 ),
525 )))
526 }
527 Decoded::PayloadChunk(..) => unreachable!(),
528 }
529 }
530}