1use std::{fmt, marker::PhantomData, rc::Rc};
2
3use ntex_io::{DispatchItem, IoBoxed};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack};
6use ntex_util::time::{Seconds, timeout_checked};
7
8use crate::error::{HandshakeError, MqttError, ProtocolError};
9use crate::{MqttServiceConfig, service};
10
11use super::control::{Control, ControlAck};
12use super::default::{DefaultControlService, InFlightService};
13use super::handshake::{Handshake, HandshakeAck};
14use super::shared::{MqttShared, MqttSinkPool};
15use super::{MqttSink, Publish, Session, codec as mqtt, dispatcher::factory};
16
17pub struct MqttServer<St, H, C, M = Identity> {
45 handshake: H,
46 control: C,
47 middleware: M,
48 pub(super) pool: Rc<MqttSinkPool>,
49 _t: PhantomData<St>,
50}
51
52impl<St, H> MqttServer<St, H, DefaultControlService<St, H::Error>, InFlightService>
53where
54 St: 'static,
55 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
56 H::Error: fmt::Debug,
57{
58 pub fn new<F>(handshake: F) -> Self
60 where
61 F: IntoServiceFactory<H, Handshake, SharedCfg>,
62 {
63 MqttServer {
64 handshake: handshake.into_factory(),
65 control: DefaultControlService::default(),
66 middleware: InFlightService,
67 pool: Default::default(),
68 _t: PhantomData,
69 }
70 }
71}
72
73impl<St, H, C, M> MqttServer<St, H, C, M>
74where
75 St: 'static,
76 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
77 C: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
78 H::Error: From<C::Error> + From<C::InitError> + fmt::Debug,
79{
80 pub fn control<F, Srv>(self, service: F) -> MqttServer<St, H, Srv, M>
85 where
86 F: IntoServiceFactory<Srv, Control<H::Error>, Session<St>>,
87 Srv: ServiceFactory<Control<H::Error>, Session<St>, Response = ControlAck> + 'static,
88 H::Error: From<Srv::Error> + From<Srv::InitError>,
89 {
90 MqttServer {
91 handshake: self.handshake,
92 control: service.into_factory(),
93 middleware: self.middleware,
94 pool: self.pool,
95 _t: PhantomData,
96 }
97 }
98
99 pub fn middleware<U>(self, mw: U) -> MqttServer<St, H, C, Stack<M, U>> {
107 MqttServer {
108 middleware: Stack::new(self.middleware, mw),
109 handshake: self.handshake,
110 control: self.control,
111 pool: self.pool,
112 _t: PhantomData,
113 }
114 }
115
116 pub fn replace_middlewares<U>(self, mw: U) -> MqttServer<St, H, C, U> {
118 MqttServer {
119 middleware: mw,
120 handshake: self.handshake,
121 control: self.control,
122 pool: self.pool,
123 _t: PhantomData,
124 }
125 }
126
127 pub fn publish<F, Srv>(
129 self,
130 publish: F,
131 ) -> service::MqttServer<
132 Session<St>,
133 impl ServiceFactory<
134 IoBoxed,
135 SharedCfg,
136 Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds),
137 Error = MqttError<H::Error>,
138 InitError = H::InitError,
139 >,
140 impl ServiceFactory<
141 DispatchItem<Rc<MqttShared>>,
142 (SharedCfg, Session<St>),
143 Response = Option<mqtt::Encoded>,
144 Error = MqttError<H::Error>,
145 InitError = MqttError<H::Error>,
146 >,
147 M,
148 Rc<MqttShared>,
149 >
150 where
151 H::Error: From<C::Error>
152 + From<C::InitError>
153 + From<Srv::Error>
154 + From<Srv::InitError>
155 + fmt::Debug,
156 F: IntoServiceFactory<Srv, Publish, Session<St>>,
157 Srv: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
158 H::Error: From<Srv::Error> + From<Srv::InitError> + fmt::Debug,
159 {
160 service::MqttServer::new(
161 HandshakeFactory {
162 factory: self.handshake,
163 pool: self.pool.clone(),
164 _t: PhantomData,
165 },
166 factory(publish.into_factory(), self.control),
167 self.middleware,
168 )
169 }
170}
171
172struct HandshakeFactory<St, H> {
173 factory: H,
174 pool: Rc<MqttSinkPool>,
175 _t: PhantomData<St>,
176}
177
178impl<St, H> ServiceFactory<IoBoxed, SharedCfg> for HandshakeFactory<St, H>
179where
180 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
181 H::Error: fmt::Debug,
182{
183 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
184 type Error = MqttError<H::Error>;
185
186 type Service = HandshakeService<St, H::Service>;
187 type InitError = H::InitError;
188
189 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
190 Ok(HandshakeService {
191 cfg: cfg.get(),
192 pool: self.pool.clone(),
193 service: self.factory.create(cfg).await?,
194 _t: PhantomData,
195 })
196 }
197}
198
199struct HandshakeService<St, H> {
200 service: H,
201 cfg: Cfg<MqttServiceConfig>,
202 pool: Rc<MqttSinkPool>,
203 _t: PhantomData<St>,
204}
205
206impl<St, H> Service<IoBoxed> for HandshakeService<St, H>
207where
208 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
209 H::Error: fmt::Debug,
210{
211 type Response = (IoBoxed, Rc<MqttShared>, Session<St>, Seconds);
212 type Error = MqttError<H::Error>;
213
214 ntex_service::forward_ready!(service, MqttError::Service);
215 ntex_service::forward_poll!(service, MqttError::Service);
216 ntex_service::forward_shutdown!(service);
217
218 async fn call(
219 &self,
220 io: IoBoxed,
221 ctx: ServiceCtx<'_, Self>,
222 ) -> Result<Self::Response, Self::Error> {
223 log::trace!("Starting mqtt v3 handshake");
224
225 let codec = mqtt::Codec::default();
226 codec.set_max_size(self.cfg.max_size);
227 codec.set_min_chunk_size(self.cfg.min_chunk_size);
228 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone()));
229
230 let packet = timeout_checked(self.cfg.connect_timeout, io.recv(&shared.codec))
232 .await
233 .map_err(|_| MqttError::Handshake(HandshakeError::Timeout))?
234 .map_err(|err| {
235 log::trace!("Error is received during mqtt handshake: {:?}", err);
236 MqttError::Handshake(HandshakeError::from(err))
237 })?
238 .ok_or_else(|| {
239 log::trace!("Server mqtt is disconnected during handshake");
240 MqttError::Handshake(HandshakeError::Disconnected(None))
241 })?;
242
243 match packet {
244 mqtt::Decoded::Packet(mqtt::Packet::Connect(connect), size) => {
245 let ack = ctx
247 .call(&self.service, Handshake::new(connect, size, io, shared))
248 .await
249 .map_err(MqttError::Service)?;
250
251 match ack.session {
252 Some(session) => {
253 let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
254 session_present: ack.session_present,
255 return_code: mqtt::ConnectAckReason::ConnectionAccepted,
256 });
257
258 log::trace!("Sending success handshake ack: {:#?}", pkt);
259
260 ack.shared.set_cap(ack.max_send.unwrap_or(self.cfg.max_send) as usize);
261 if let Some(max_packet_size) = ack.max_packet_size {
262 ack.shared.codec.set_max_size(max_packet_size.get());
263 }
264 ack.io.encode(mqtt::Encoded::Packet(pkt), &ack.shared.codec)?;
265 Ok((
266 ack.io,
267 ack.shared.clone(),
268 Session::new(session, MqttSink::new(ack.shared)),
269 ack.keepalive,
270 ))
271 }
272 None => {
273 let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck {
274 session_present: false,
275 return_code: ack.return_code,
276 });
277
278 log::trace!("Sending failed handshake ack: {:#?}", pkt);
279 ack.io.encode(mqtt::Encoded::Packet(pkt), &ack.shared.codec)?;
280 let _ = ack.io.shutdown().await;
281
282 Err(MqttError::Handshake(HandshakeError::Disconnected(None)))
283 }
284 }
285 }
286 mqtt::Decoded::Packet(packet, _) => {
287 log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet);
288 Err(MqttError::Handshake(HandshakeError::Protocol(
289 ProtocolError::unexpected_packet(
290 packet.packet_type(),
291 "MQTT-3.1.0-1: Expected CONNECT packet",
292 ),
293 )))
294 }
295 mqtt::Decoded::Publish(..) => {
296 log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received PUBLISH");
297 Err(MqttError::Handshake(HandshakeError::Protocol(
298 ProtocolError::unexpected_packet(
299 crate::types::packet_type::PUBLISH_START,
300 "Expected CONNECT packet [MQTT-3.1.0-1]",
301 ),
302 )))
303 }
304 mqtt::Decoded::PayloadChunk(..) => unreachable!(),
305 }
306 }
307}