1use std::future::Future;
2use std::marker::PhantomData;
3use std::rc::Rc;
4use std::time::Duration;
5
6use actix_codec::{AsyncRead, AsyncWrite};
7use actix_ioframe as ioframe;
8use actix_service::{apply, apply_fn, boxed, fn_factory, pipeline_factory, unit_config};
9use actix_service::{IntoServiceFactory, Service, ServiceFactory};
10use actix_utils::timeout::{Timeout, TimeoutError};
11use futures::{FutureExt, SinkExt, StreamExt};
12use mqtt_codec as mqtt;
13
14use crate::cell::Cell;
15use crate::connect::{Connect, ConnectAck};
16use crate::default::{SubsNotImplemented, UnsubsNotImplemented};
17use crate::dispatcher::{dispatcher, MqttState};
18use crate::error::MqttError;
19use crate::publish::Publish;
20use crate::sink::MqttSink;
21use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
22
23pub struct MqttServer<Io, St, C: ServiceFactory, U> {
25 connect: C,
26 subscribe: boxed::BoxServiceFactory<
27 St,
28 Subscribe<St>,
29 SubscribeResult,
30 MqttError<C::Error>,
31 MqttError<C::Error>,
32 >,
33 unsubscribe: boxed::BoxServiceFactory<
34 St,
35 Unsubscribe<St>,
36 (),
37 MqttError<C::Error>,
38 MqttError<C::Error>,
39 >,
40 disconnect: U,
41 max_size: usize,
42 inflight: usize,
43 handshake_timeout: u64,
44 _t: PhantomData<(Io, St)>,
45}
46
/// No-op disconnect callback installed by `MqttServer::new`.
///
/// Both the session state and the flag it is called with are ignored.
fn default_disconnect<St>(_session: St, _flag: bool) {}
48
49impl<Io, St, C> MqttServer<Io, St, C, ()>
50where
51 St: 'static,
52 C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>
53 + 'static,
54{
55 pub fn new<F>(connect: F) -> MqttServer<Io, St, C, impl Fn(St, bool)>
57 where
58 F: IntoServiceFactory<C>,
59 {
60 MqttServer {
61 connect: connect.into_factory(),
62 subscribe: boxed::factory(
63 pipeline_factory(SubsNotImplemented::default())
64 .map_err(MqttError::Service)
65 .map_init_err(MqttError::Service),
66 ),
67 unsubscribe: boxed::factory(
68 pipeline_factory(UnsubsNotImplemented::default())
69 .map_err(MqttError::Service)
70 .map_init_err(MqttError::Service),
71 ),
72 max_size: 0,
73 inflight: 15,
74 disconnect: default_disconnect,
75 handshake_timeout: 0,
76 _t: PhantomData,
77 }
78 }
79}
80
81impl<Io, St, C, U> MqttServer<Io, St, C, U>
82where
83 St: Clone + 'static,
84 U: Fn(St, bool) + 'static,
85 C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>
86 + 'static,
87{
88 pub fn handshake_timeout(mut self, timeout: u64) -> Self {
93 self.handshake_timeout = timeout;
94 self
95 }
96
97 pub fn max_size(mut self, size: usize) -> Self {
102 self.max_size = size;
103 self
104 }
105
106 pub fn inflight(mut self, val: usize) -> Self {
110 self.inflight = val;
111 self
112 }
113
114 pub fn subscribe<F, Srv>(mut self, subscribe: F) -> Self
116 where
117 F: IntoServiceFactory<Srv>,
118 Srv: ServiceFactory<Config = St, Request = Subscribe<St>, Response = SubscribeResult>
119 + 'static,
120 C::Error: From<Srv::Error> + From<Srv::InitError>,
121 {
122 self.subscribe = boxed::factory(
123 subscribe
124 .into_factory()
125 .map_err(|e| MqttError::Service(e.into()))
126 .map_init_err(|e| MqttError::Service(e.into())),
127 );
128 self
129 }
130
131 pub fn unsubscribe<F, Srv>(mut self, unsubscribe: F) -> Self
133 where
134 F: IntoServiceFactory<Srv>,
135 Srv: ServiceFactory<Config = St, Request = Unsubscribe<St>, Response = ()> + 'static,
136 C::Error: From<Srv::Error> + From<Srv::InitError>,
137 {
138 self.unsubscribe = boxed::factory(
139 unsubscribe
140 .into_factory()
141 .map_err(|e| MqttError::Service(e.into()))
142 .map_init_err(|e| MqttError::Service(e.into())),
143 );
144 self
145 }
146
147 pub fn disconnect<F, Out>(self, disconnect: F) -> MqttServer<Io, St, C, impl Fn(St, bool)>
151 where
152 F: Fn(St, bool) -> Out,
153 Out: Future + 'static,
154 {
155 MqttServer {
156 connect: self.connect,
157 subscribe: self.subscribe,
158 unsubscribe: self.unsubscribe,
159 max_size: self.max_size,
160 inflight: self.inflight,
161 handshake_timeout: self.handshake_timeout,
162 disconnect: move |st: St, err| {
163 let fut = disconnect(st, err);
164 actix_rt::spawn(fut.map(|_| ()));
165 },
166 _t: PhantomData,
167 }
168 }
169
170 pub fn finish<F, P>(
172 self,
173 publish: F,
174 ) -> impl ServiceFactory<Config = (), Request = Io, Response = (), Error = MqttError<C::Error>>
175 where
176 Io: AsyncRead + AsyncWrite + 'static,
177 F: IntoServiceFactory<P>,
178 P: ServiceFactory<Config = St, Request = Publish<St>, Response = ()> + 'static,
179 C::Error: From<P::Error> + From<P::InitError>,
180 {
181 let connect = self.connect;
182 let max_size = self.max_size;
183 let handshake_timeout = self.handshake_timeout;
184 let disconnect = self.disconnect;
185 let publish = boxed::factory(
186 publish
187 .into_factory()
188 .map_err(|e| MqttError::Service(e.into()))
189 .map_init_err(|e| MqttError::Service(e.into())),
190 );
191
192 unit_config(
193 ioframe::Builder::new()
194 .factory(connect_service_factory(
195 connect,
196 max_size,
197 self.inflight,
198 handshake_timeout,
199 ))
200 .disconnect(move |cfg, err| disconnect(cfg.session().clone(), err))
201 .finish(dispatcher(
202 publish,
203 Rc::new(self.subscribe),
204 Rc::new(self.unsubscribe),
205 ))
206 .map_err(|e| match e {
207 ioframe::ServiceError::Service(e) => e,
208 ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e),
209 ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e),
210 }),
211 )
212 }
213}
214
215fn connect_service_factory<Io, St, C>(
216 factory: C,
217 max_size: usize,
218 inflight: usize,
219 handshake_timeout: u64,
220) -> impl ServiceFactory<
221 Config = (),
222 Request = ioframe::Connect<Io, mqtt::Codec>,
223 Response = ioframe::ConnectResult<Io, MqttState<St>, mqtt::Codec>,
224 Error = MqttError<C::Error>,
225>
226where
227 Io: AsyncRead + AsyncWrite,
228 C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>,
229{
230 apply(
231 Timeout::new(Duration::from_millis(handshake_timeout)),
232 fn_factory(move || {
233 let fut = factory.new_service(());
234
235 async move {
236 let service = Cell::new(fut.await?);
237
238 Ok::<_, C::InitError>(apply_fn(
239 service.map_err(MqttError::Service),
240 move |conn: ioframe::Connect<Io, mqtt::Codec>, service| {
241 let mut srv = service.clone();
242 let mut framed = conn.codec(mqtt::Codec::new().max_size(max_size));
243
244 async move {
245 let packet = framed
247 .next()
248 .await
249 .ok_or(MqttError::Disconnected)
250 .and_then(|res| res.map_err(|e| MqttError::Protocol(e)))?;
251
252 match packet {
253 mqtt::Packet::Connect(connect) => {
254 let sink = MqttSink::new(framed.sink().clone());
255
256 let mut ack = srv
258 .call(Connect::new(
259 connect,
260 framed,
261 sink.clone(),
262 inflight,
263 ))
264 .await?;
265
266 match ack.session {
267 Some(session) => {
268 log::trace!(
269 "Sending: {:#?}",
270 mqtt::Packet::ConnectAck {
271 session_present: ack.session_present,
272 return_code:
273 mqtt::ConnectCode::ConnectionAccepted,
274 }
275 );
276 ack.io
277 .send(mqtt::Packet::ConnectAck {
278 session_present: ack.session_present,
279 return_code:
280 mqtt::ConnectCode::ConnectionAccepted,
281 })
282 .await?;
283
284 Ok(ack.io.state(MqttState::new(
285 session,
286 sink,
287 ack.keep_alive,
288 ack.inflight,
289 )))
290 }
291 None => {
292 log::trace!(
293 "Sending: {:#?}",
294 mqtt::Packet::ConnectAck {
295 session_present: false,
296 return_code: ack.return_code,
297 }
298 );
299
300 ack.io
301 .send(mqtt::Packet::ConnectAck {
302 session_present: false,
303 return_code: ack.return_code,
304 })
305 .await?;
306 Err(MqttError::Disconnected)
307 }
308 }
309 }
310 packet => {
311 log::info!(
312 "MQTT-3.1.0-1: Expected CONNECT packet, received {}",
313 packet.packet_type()
314 );
315 Err(MqttError::Unexpected(
316 packet,
317 "MQTT-3.1.0-1: Expected CONNECT packet",
318 ))
319 }
320 }
321 }
322 },
323 ))
324 }
325 }),
326 )
327 .map_err(|e| match e {
328 TimeoutError::Service(e) => e,
329 TimeoutError::Timeout => MqttError::HandshakeTimeout,
330 })
331}