1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, IoBoxed};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
6use ntex_util::time::{Seconds, timeout_checked};
7
8use crate::error::{HandshakeError, MqttError, ProtocolError};
9use crate::{MqttServiceConfig, service};
10
11use super::codec::{self as mqtt, Decoded, Encoded, Packet};
12use super::control::{Control, ControlAck};
13use super::default::{DefaultControlService, InFlightService};
14use super::handshake::{Handshake, HandshakeAck};
15use super::publish::{Publish, PublishAck};
16use super::shared::{MqttShared, MqttSinkPool};
17use super::{MqttSink, Session, dispatcher::factory};
18
19pub struct MqttServer<St, C, Cn, M = Identity> {
21 handshake: C,
22 control: Cn,
23 middleware: M,
24 pub(super) pool: Rc<MqttSinkPool>,
25 _t: PhantomData<St>,
26}
27
28impl<St, C> MqttServer<St, C, DefaultControlService<St, C::Error>, InFlightService>
29where
30 C: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>>,
31 C::Error: fmt::Debug,
32{
33 pub fn new<F>(handshake: F) -> Self
35 where
36 F: IntoServiceFactory<C, Handshake, SharedCfg>,
37 {
38 MqttServer {
39 handshake: handshake.into_factory(),
40 control: DefaultControlService::default(),
41 middleware: InFlightService,
42 pool: Rc::new(MqttSinkPool::default()),
43 _t: PhantomData,
44 }
45 }
46}
47
48impl<St, C, Cn, M> MqttServer<St, C, Cn, M>
49where
50 St: 'static,
51 C: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
52 C::Error: fmt::Debug,
53 Cn: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
54{
55 pub fn reset_middlewares(self) -> MqttServer<St, C, Cn, Identity> {
57 MqttServer {
58 middleware: Identity,
59 control: self.control,
60 handshake: self.handshake,
61 pool: self.pool,
62 _t: PhantomData,
63 }
64 }
65
66 pub fn middleware<U>(self, mw: U) -> MqttServer<St, C, Cn, Stack<M, U>> {
74 MqttServer {
75 middleware: Stack::new(self.middleware, mw),
76 handshake: self.handshake,
77 control: self.control,
78 pool: self.pool,
79 _t: PhantomData,
80 }
81 }
82
83 pub fn control<F, Srv>(self, service: F) -> MqttServer<St, C, Srv, M>
88 where
89 F: IntoServiceFactory<Srv, Control<C::Error>, Session<St>>,
90 Srv: ServiceFactory<Control<C::Error>, Session<St>, Response = ControlAck> + 'static,
91 C::Error: From<Srv::Error> + From<Srv::InitError>,
92 {
93 MqttServer {
94 handshake: self.handshake,
95 control: service.into_factory(),
96 middleware: self.middleware,
97 pool: self.pool,
98 _t: PhantomData,
99 }
100 }
101
102 pub fn publish<F, Srv>(
105 self,
106 publish: F,
107 ) -> service::MqttServer<
108 Session<St>,
109 impl ServiceFactory<
110 IoBoxed,
111 SharedCfg,
112 Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
113 Error = MqttError<C::Error>,
114 InitError = C::InitError,
115 >,
116 impl ServiceFactory<
117 DispatchItem<Rc<MqttShared>>,
118 (SharedCfg, Session<St>),
119 Response = Option<mqtt::Encoded>,
120 Error = MqttError<C::Error>,
121 InitError = MqttError<C::Error>,
122 >,
123 M,
124 Rc<MqttShared>,
125 >
126 where
127 F: IntoServiceFactory<Srv, Publish, Session<St>>,
128 C::Error:
129 From<Cn::Error> + From<Cn::InitError> + From<Srv::Error> + From<Srv::InitError>,
130 Srv: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
131 Srv::Error: fmt::Debug,
132 PublishAck: TryFrom<Srv::Error, Error = C::Error>,
133 {
134 service::MqttServer::new(
135 HandshakeFactory { factory: self.handshake, pool: self.pool, _t: PhantomData },
136 factory(publish.into_factory(), self.control),
137 self.middleware,
138 )
139 }
140}
141
142struct HandshakeFactory<St, H> {
143 factory: H,
144 pool: Rc<MqttSinkPool>,
145 _t: PhantomData<St>,
146}
147
148impl<St, H> ServiceFactory<IoBoxed, SharedCfg> for HandshakeFactory<St, H>
149where
150 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
151 H::Error: fmt::Debug,
152{
153 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
154 type Error = MqttError<H::Error>;
155
156 type Service = HandshakeService<St, H::Service>;
157 type InitError = H::InitError;
158
159 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
160 Ok(HandshakeService {
161 cfg: cfg.get(),
162 service: self.factory.create(cfg).await?,
163 pool: self.pool.clone(),
164 _t: PhantomData,
165 })
166 }
167}
168
169struct HandshakeService<St, H> {
170 service: H,
171 cfg: Cfg<MqttServiceConfig>,
172 pool: Rc<MqttSinkPool>,
173 _t: PhantomData<St>,
174}
175
176impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
177where
178 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
179 H::Error: fmt::Debug,
180{
181 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
182 type Error = MqttError<H::Error>;
183
184 ntex_service::forward_ready!(service, MqttError::Service);
185 ntex_service::forward_poll!(service, MqttError::Service);
186 ntex_service::forward_shutdown!(service);
187
188 async fn call(
189 &self,
190 io: IoBoxed,
191 ctx: ServiceCtx<'_, Self>,
192 ) -> Result<Self::Response, Self::Error> {
193 log::trace!("Starting mqtt v5 handshake");
194
195 let codec = mqtt::Codec::default();
196 codec.set_max_inbound_size(self.cfg.max_size);
197 codec.set_min_chunk_size(self.cfg.min_chunk_size);
198
199 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, self.pool.clone()));
200 shared.set_max_qos(self.cfg.max_qos);
201 shared.set_receive_max(self.cfg.max_receive);
202 shared.set_topic_alias_max(self.cfg.max_topic_alias);
203
204 let packet = timeout_checked(self.cfg.connect_timeout, io.recv(&shared.codec))
206 .await
207 .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
208 .map_err(|err| {
209 log::trace!("Error is received during mqtt handshake: {:?}", err);
210 MqttError::Handshake(HandshakeError::from(err))
211 })?
212 .ok_or_else(|| {
213 log::trace!("Server mqtt is disconnected during handshake");
214 MqttError::Handshake(HandshakeError::Disconnected(None))
215 })?;
216
217 match packet {
218 Decoded::Packet(Packet::Connect(connect), size) => {
219 if let Some(size) = connect.max_packet_size {
221 shared.codec.set_max_outbound_size(size.get());
222 }
223 let keep_alive = connect.keep_alive;
224 let peer_receive_max =
225 connect.receive_max.map(|v| v.get()).unwrap_or(16) as usize;
226
227 let mut ack = ctx
229 .call(&self.service, Handshake::new(connect, size, io, shared))
230 .await
231 .map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))?;
232
233 match ack.session {
234 Some(session) => {
235 log::trace!("Sending: {:#?}", ack.packet);
236 let shared = ack.shared;
237
238 shared.set_max_qos(ack.packet.max_qos);
239 shared.set_receive_max(ack.packet.receive_max.get());
240 shared.set_topic_alias_max(ack.packet.topic_alias_max);
241 shared
242 .codec
243 .set_max_inbound_size(ack.packet.max_packet_size.unwrap_or(0));
244 shared.codec.set_retain_available(ack.packet.retain_available);
245 shared.codec.set_sub_ids_available(
246 ack.packet.subscription_identifiers_available,
247 );
248 if ack.packet.server_keepalive_sec.is_none()
249 && (keep_alive > ack.keepalive)
250 {
251 ack.packet.server_keepalive_sec = Some(ack.keepalive);
252 }
253 shared.set_cap(peer_receive_max);
254
255 ack.io.encode(
256 Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
257 &shared.codec,
258 )?;
259
260 Ok((
261 ack.io,
262 shared.clone(),
263 Session::new(session, MqttSink::new(shared)),
264 Seconds(ack.keepalive),
265 ))
266 }
267 None => {
268 log::trace!("Failed to complete handshake: {:#?}", ack.packet);
269
270 ack.io.encode(
271 Encoded::Packet(Packet::ConnectAck(Box::new(ack.packet))),
272 &ack.shared.codec,
273 )?;
274 let _ = ack.io.shutdown().await;
275 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
276 }
277 }
278 }
279 Decoded::Packet(packet, _) => {
280 log::info!(
281 "MQTT-3.1.0-1: Expected CONNECT packet, received {}",
282 packet.packet_type()
283 );
284 Err(MqttError::Handshake(HandshakeError::Protocol(
285 ProtocolError::unexpected_packet(
286 packet.packet_type(),
287 "Expected CONNECT packet [MQTT-3.1.0-1]",
288 ),
289 )))
290 }
291 Decoded::Publish(..) => {
292 log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
293 Err(MqttError::Handshake(HandshakeError::Protocol(
294 ProtocolError::unexpected_packet(
295 crate::types::packet_type::PUBLISH_START,
296 "Expected CONNECT packet [MQTT-3.1.0-1]",
297 ),
298 )))
299 }
300 Decoded::PayloadChunk(..) => unreachable!(),
301 }
302 }
303}