actix_mqtt/
server.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::rc::Rc;
4use std::time::Duration;
5
6use actix_codec::{AsyncRead, AsyncWrite};
7use actix_ioframe as ioframe;
8use actix_service::{apply, apply_fn, boxed, fn_factory, pipeline_factory, unit_config};
9use actix_service::{IntoServiceFactory, Service, ServiceFactory};
10use actix_utils::timeout::{Timeout, TimeoutError};
11use futures::{FutureExt, SinkExt, StreamExt};
12use mqtt_codec as mqtt;
13
14use crate::cell::Cell;
15use crate::connect::{Connect, ConnectAck};
16use crate::default::{SubsNotImplemented, UnsubsNotImplemented};
17use crate::dispatcher::{dispatcher, MqttState};
18use crate::error::MqttError;
19use crate::publish::Publish;
20use crate::sink::MqttSink;
21use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
22
23/// Mqtt Server
24pub struct MqttServer<Io, St, C: ServiceFactory, U> {
25    connect: C,
26    subscribe: boxed::BoxServiceFactory<
27        St,
28        Subscribe<St>,
29        SubscribeResult,
30        MqttError<C::Error>,
31        MqttError<C::Error>,
32    >,
33    unsubscribe: boxed::BoxServiceFactory<
34        St,
35        Unsubscribe<St>,
36        (),
37        MqttError<C::Error>,
38        MqttError<C::Error>,
39    >,
40    disconnect: U,
41    max_size: usize,
42    inflight: usize,
43    handshake_timeout: u64,
44    _t: PhantomData<(Io, St)>,
45}
46
/// Default `disconnect` callback: ignores both the session state and the error flag.
fn default_disconnect<St>(_session: St, _is_error: bool) {}
48
49impl<Io, St, C> MqttServer<Io, St, C, ()>
50where
51    St: 'static,
52    C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>
53        + 'static,
54{
55    /// Create server factory and provide connect service
56    pub fn new<F>(connect: F) -> MqttServer<Io, St, C, impl Fn(St, bool)>
57    where
58        F: IntoServiceFactory<C>,
59    {
60        MqttServer {
61            connect: connect.into_factory(),
62            subscribe: boxed::factory(
63                pipeline_factory(SubsNotImplemented::default())
64                    .map_err(MqttError::Service)
65                    .map_init_err(MqttError::Service),
66            ),
67            unsubscribe: boxed::factory(
68                pipeline_factory(UnsubsNotImplemented::default())
69                    .map_err(MqttError::Service)
70                    .map_init_err(MqttError::Service),
71            ),
72            max_size: 0,
73            inflight: 15,
74            disconnect: default_disconnect,
75            handshake_timeout: 0,
76            _t: PhantomData,
77        }
78    }
79}
80
81impl<Io, St, C, U> MqttServer<Io, St, C, U>
82where
83    St: Clone + 'static,
84    U: Fn(St, bool) + 'static,
85    C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>
86        + 'static,
87{
88    /// Set handshake timeout in millis.
89    ///
90    /// Handshake includes `connect` packet and response `connect-ack`.
91    /// By default handshake timeuot is disabled.
92    pub fn handshake_timeout(mut self, timeout: u64) -> Self {
93        self.handshake_timeout = timeout;
94        self
95    }
96
97    /// Set max inbound frame size.
98    ///
99    /// If max size is set to `0`, size is unlimited.
100    /// By default max size is set to `0`
101    pub fn max_size(mut self, size: usize) -> Self {
102        self.max_size = size;
103        self
104    }
105
106    /// Number of in-flight concurrent messages.
107    ///
108    /// in-flight is set to 15 messages
109    pub fn inflight(mut self, val: usize) -> Self {
110        self.inflight = val;
111        self
112    }
113
114    /// Service to execute for subscribe packet
115    pub fn subscribe<F, Srv>(mut self, subscribe: F) -> Self
116    where
117        F: IntoServiceFactory<Srv>,
118        Srv: ServiceFactory<Config = St, Request = Subscribe<St>, Response = SubscribeResult>
119            + 'static,
120        C::Error: From<Srv::Error> + From<Srv::InitError>,
121    {
122        self.subscribe = boxed::factory(
123            subscribe
124                .into_factory()
125                .map_err(|e| MqttError::Service(e.into()))
126                .map_init_err(|e| MqttError::Service(e.into())),
127        );
128        self
129    }
130
131    /// Service to execute for unsubscribe packet
132    pub fn unsubscribe<F, Srv>(mut self, unsubscribe: F) -> Self
133    where
134        F: IntoServiceFactory<Srv>,
135        Srv: ServiceFactory<Config = St, Request = Unsubscribe<St>, Response = ()> + 'static,
136        C::Error: From<Srv::Error> + From<Srv::InitError>,
137    {
138        self.unsubscribe = boxed::factory(
139            unsubscribe
140                .into_factory()
141                .map_err(|e| MqttError::Service(e.into()))
142                .map_init_err(|e| MqttError::Service(e.into())),
143        );
144        self
145    }
146
147    /// Callback to execute on disconnect
148    ///
149    /// Second parameter indicates error occured during disconnect.
150    pub fn disconnect<F, Out>(self, disconnect: F) -> MqttServer<Io, St, C, impl Fn(St, bool)>
151    where
152        F: Fn(St, bool) -> Out,
153        Out: Future + 'static,
154    {
155        MqttServer {
156            connect: self.connect,
157            subscribe: self.subscribe,
158            unsubscribe: self.unsubscribe,
159            max_size: self.max_size,
160            inflight: self.inflight,
161            handshake_timeout: self.handshake_timeout,
162            disconnect: move |st: St, err| {
163                let fut = disconnect(st, err);
164                actix_rt::spawn(fut.map(|_| ()));
165            },
166            _t: PhantomData,
167        }
168    }
169
170    /// Set service to execute for publish packet and create service factory
171    pub fn finish<F, P>(
172        self,
173        publish: F,
174    ) -> impl ServiceFactory<Config = (), Request = Io, Response = (), Error = MqttError<C::Error>>
175    where
176        Io: AsyncRead + AsyncWrite + 'static,
177        F: IntoServiceFactory<P>,
178        P: ServiceFactory<Config = St, Request = Publish<St>, Response = ()> + 'static,
179        C::Error: From<P::Error> + From<P::InitError>,
180    {
181        let connect = self.connect;
182        let max_size = self.max_size;
183        let handshake_timeout = self.handshake_timeout;
184        let disconnect = self.disconnect;
185        let publish = boxed::factory(
186            publish
187                .into_factory()
188                .map_err(|e| MqttError::Service(e.into()))
189                .map_init_err(|e| MqttError::Service(e.into())),
190        );
191
192        unit_config(
193            ioframe::Builder::new()
194                .factory(connect_service_factory(
195                    connect,
196                    max_size,
197                    self.inflight,
198                    handshake_timeout,
199                ))
200                .disconnect(move |cfg, err| disconnect(cfg.session().clone(), err))
201                .finish(dispatcher(
202                    publish,
203                    Rc::new(self.subscribe),
204                    Rc::new(self.unsubscribe),
205                ))
206                .map_err(|e| match e {
207                    ioframe::ServiceError::Service(e) => e,
208                    ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e),
209                    ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e),
210                }),
211        )
212    }
213}
214
215fn connect_service_factory<Io, St, C>(
216    factory: C,
217    max_size: usize,
218    inflight: usize,
219    handshake_timeout: u64,
220) -> impl ServiceFactory<
221    Config = (),
222    Request = ioframe::Connect<Io, mqtt::Codec>,
223    Response = ioframe::ConnectResult<Io, MqttState<St>, mqtt::Codec>,
224    Error = MqttError<C::Error>,
225>
226where
227    Io: AsyncRead + AsyncWrite,
228    C: ServiceFactory<Config = (), Request = Connect<Io>, Response = ConnectAck<Io, St>>,
229{
230    apply(
231        Timeout::new(Duration::from_millis(handshake_timeout)),
232        fn_factory(move || {
233            let fut = factory.new_service(());
234
235            async move {
236                let service = Cell::new(fut.await?);
237
238                Ok::<_, C::InitError>(apply_fn(
239                    service.map_err(MqttError::Service),
240                    move |conn: ioframe::Connect<Io, mqtt::Codec>, service| {
241                        let mut srv = service.clone();
242                        let mut framed = conn.codec(mqtt::Codec::new().max_size(max_size));
243
244                        async move {
245                            // read first packet
246                            let packet = framed
247                                .next()
248                                .await
249                                .ok_or(MqttError::Disconnected)
250                                .and_then(|res| res.map_err(|e| MqttError::Protocol(e)))?;
251
252                            match packet {
253                                mqtt::Packet::Connect(connect) => {
254                                    let sink = MqttSink::new(framed.sink().clone());
255
256                                    // authenticate mqtt connection
257                                    let mut ack = srv
258                                        .call(Connect::new(
259                                            connect,
260                                            framed,
261                                            sink.clone(),
262                                            inflight,
263                                        ))
264                                        .await?;
265
266                                    match ack.session {
267                                        Some(session) => {
268                                            log::trace!(
269                                                "Sending: {:#?}",
270                                                mqtt::Packet::ConnectAck {
271                                                    session_present: ack.session_present,
272                                                    return_code:
273                                                        mqtt::ConnectCode::ConnectionAccepted,
274                                                }
275                                            );
276                                            ack.io
277                                                .send(mqtt::Packet::ConnectAck {
278                                                    session_present: ack.session_present,
279                                                    return_code:
280                                                        mqtt::ConnectCode::ConnectionAccepted,
281                                                })
282                                                .await?;
283
284                                            Ok(ack.io.state(MqttState::new(
285                                                session,
286                                                sink,
287                                                ack.keep_alive,
288                                                ack.inflight,
289                                            )))
290                                        }
291                                        None => {
292                                            log::trace!(
293                                                "Sending: {:#?}",
294                                                mqtt::Packet::ConnectAck {
295                                                    session_present: false,
296                                                    return_code: ack.return_code,
297                                                }
298                                            );
299
300                                            ack.io
301                                                .send(mqtt::Packet::ConnectAck {
302                                                    session_present: false,
303                                                    return_code: ack.return_code,
304                                                })
305                                                .await?;
306                                            Err(MqttError::Disconnected)
307                                        }
308                                    }
309                                }
310                                packet => {
311                                    log::info!(
312                                        "MQTT-3.1.0-1: Expected CONNECT packet, received {}",
313                                        packet.packet_type()
314                                    );
315                                    Err(MqttError::Unexpected(
316                                        packet,
317                                        "MQTT-3.1.0-1: Expected CONNECT packet",
318                                    ))
319                                }
320                            }
321                        }
322                    },
323                ))
324            }
325        }),
326    )
327    .map_err(|e| match e {
328        TimeoutError::Service(e) => e,
329        TimeoutError::Timeout => MqttError::HandshakeTimeout,
330    })
331}