actix_mqtt/
client.rs

1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::rc::Rc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use actix_codec::{AsyncRead, AsyncWrite};
8use actix_ioframe as ioframe;
9use actix_service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory};
10use bytes::Bytes;
11use bytestring::ByteString;
12use futures::future::{FutureExt, LocalBoxFuture};
13use futures::{Sink, SinkExt, Stream, StreamExt};
14use mqtt_codec as mqtt;
15
16use crate::cell::Cell;
17use crate::default::{SubsNotImplemented, UnsubsNotImplemented};
18use crate::dispatcher::{dispatcher, MqttState};
19use crate::error::MqttError;
20use crate::publish::Publish;
21use crate::sink::MqttSink;
22use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
23
24/// Mqtt client
25#[derive(Clone)]
26pub struct Client<Io, St> {
27    client_id: ByteString,
28    clean_session: bool,
29    protocol: mqtt::Protocol,
30    keep_alive: u16,
31    last_will: Option<mqtt::LastWill>,
32    username: Option<ByteString>,
33    password: Option<Bytes>,
34    inflight: usize,
35    _t: PhantomData<(Io, St)>,
36}
37
38impl<Io, St> Client<Io, St>
39where
40    St: 'static,
41{
42    /// Create new client and provide client id
43    pub fn new(client_id: ByteString) -> Self {
44        Client {
45            client_id,
46            clean_session: true,
47            protocol: mqtt::Protocol::default(),
48            keep_alive: 30,
49            last_will: None,
50            username: None,
51            password: None,
52            inflight: 15,
53            _t: PhantomData,
54        }
55    }
56
57    /// Mqtt protocol version
58    pub fn protocol(mut self, val: mqtt::Protocol) -> Self {
59        self.protocol = val;
60        self
61    }
62
63    /// The handling of the Session state.
64    pub fn clean_session(mut self, val: bool) -> Self {
65        self.clean_session = val;
66        self
67    }
68
69    /// A time interval measured in seconds.
70    ///
71    /// keep-alive is set to 30 seconds by default.
72    pub fn keep_alive(mut self, val: u16) -> Self {
73        self.keep_alive = val;
74        self
75    }
76
77    /// Will Message be stored on the Server and associated with the Network Connection.
78    ///
79    /// by default last will value is not set
80    pub fn last_will(mut self, val: mqtt::LastWill) -> Self {
81        self.last_will = Some(val);
82        self
83    }
84
85    /// Username can be used by the Server for authentication and authorization.
86    pub fn username(mut self, val: ByteString) -> Self {
87        self.username = Some(val);
88        self
89    }
90
91    /// Password can be used by the Server for authentication and authorization.
92    pub fn password(mut self, val: Bytes) -> Self {
93        self.password = Some(val);
94        self
95    }
96
97    /// Number of in-flight concurrent messages.
98    ///
99    /// in-flight is set to 15 messages
100    pub fn inflight(mut self, val: usize) -> Self {
101        self.inflight = val;
102        self
103    }
104
105    /// Set state service
106    ///
107    /// State service verifies connect ack packet and construct connection state.
108    pub fn state<C, F>(self, state: F) -> ServiceBuilder<Io, St, C>
109    where
110        F: IntoService<C>,
111        Io: AsyncRead + AsyncWrite,
112        C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>>,
113        C::Error: 'static,
114    {
115        ServiceBuilder {
116            state: Cell::new(state.into_service()),
117            packet: mqtt::Connect {
118                client_id: self.client_id,
119                clean_session: self.clean_session,
120                protocol: self.protocol,
121                keep_alive: self.keep_alive,
122                last_will: self.last_will,
123                username: self.username,
124                password: self.password,
125            },
126            subscribe: Rc::new(boxed::factory(SubsNotImplemented::default())),
127            unsubscribe: Rc::new(boxed::factory(UnsubsNotImplemented::default())),
128            disconnect: None,
129            keep_alive: self.keep_alive.into(),
130            inflight: self.inflight,
131            _t: PhantomData,
132        }
133    }
134}
135
136pub struct ServiceBuilder<Io, St, C: Service> {
137    state: Cell<C>,
138    packet: mqtt::Connect,
139    subscribe: Rc<
140        boxed::BoxServiceFactory<
141            St,
142            Subscribe<St>,
143            SubscribeResult,
144            MqttError<C::Error>,
145            MqttError<C::Error>,
146        >,
147    >,
148    unsubscribe: Rc<
149        boxed::BoxServiceFactory<
150            St,
151            Unsubscribe<St>,
152            (),
153            MqttError<C::Error>,
154            MqttError<C::Error>,
155        >,
156    >,
157    disconnect: Option<Cell<boxed::BoxService<St, (), MqttError<C::Error>>>>,
158    keep_alive: u64,
159    inflight: usize,
160
161    _t: PhantomData<(Io, St, C)>,
162}
163
164impl<Io, St, C> ServiceBuilder<Io, St, C>
165where
166    St: Clone + 'static,
167    Io: AsyncRead + AsyncWrite + 'static,
168    C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>> + 'static,
169    C::Error: 'static,
170{
171    /// Service to execute on disconnect
172    pub fn disconnect<UF, U>(mut self, srv: UF) -> Self
173    where
174        UF: IntoService<U>,
175        U: Service<Request = St, Response = (), Error = C::Error> + 'static,
176    {
177        self.disconnect = Some(Cell::new(boxed::service(
178            srv.into_service().map_err(MqttError::Service),
179        )));
180        self
181    }
182
183    pub fn finish<F, T>(
184        self,
185        service: F,
186    ) -> impl Service<Request = Io, Response = (), Error = MqttError<C::Error>>
187    where
188        F: IntoServiceFactory<T>,
189        T: ServiceFactory<
190                Config = St,
191                Request = Publish<St>,
192                Response = (),
193                Error = C::Error,
194                InitError = C::Error,
195            > + 'static,
196    {
197        ioframe::Builder::new()
198            .service(ConnectService {
199                connect: self.state,
200                packet: self.packet,
201                keep_alive: self.keep_alive,
202                inflight: self.inflight,
203                _t: PhantomData,
204            })
205            .finish(dispatcher(
206                service
207                    .into_factory()
208                    .map_err(MqttError::Service)
209                    .map_init_err(MqttError::Service),
210                self.subscribe,
211                self.unsubscribe,
212            ))
213            .map_err(|e| match e {
214                ioframe::ServiceError::Service(e) => e,
215                ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e),
216                ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e),
217            })
218    }
219}
220
221struct ConnectService<Io, St, C> {
222    connect: Cell<C>,
223    packet: mqtt::Connect,
224    keep_alive: u64,
225    inflight: usize,
226    _t: PhantomData<(Io, St)>,
227}
228
229impl<Io, St, C> Service for ConnectService<Io, St, C>
230where
231    St: 'static,
232    Io: AsyncRead + AsyncWrite + 'static,
233    C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>> + 'static,
234    C::Error: 'static,
235{
236    type Request = ioframe::Connect<Io, mqtt::Codec>;
237    type Response = ioframe::ConnectResult<Io, MqttState<St>, mqtt::Codec>;
238    type Error = MqttError<C::Error>;
239    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
240
241    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
242        self.connect
243            .get_mut()
244            .poll_ready(cx)
245            .map_err(MqttError::Service)
246    }
247
248    fn call(&mut self, req: Self::Request) -> Self::Future {
249        let mut srv = self.connect.clone();
250        let packet = self.packet.clone();
251        let keep_alive = Duration::from_secs(self.keep_alive as u64);
252        let inflight = self.inflight;
253
254        // send Connect packet
255        async move {
256            let mut framed = req.codec(mqtt::Codec::new());
257            framed
258                .send(mqtt::Packet::Connect(packet))
259                .await
260                .map_err(MqttError::Protocol)?;
261
262            let packet = framed
263                .next()
264                .await
265                .ok_or(MqttError::Disconnected)
266                .and_then(|res| res.map_err(MqttError::Protocol))?;
267
268            match packet {
269                mqtt::Packet::ConnectAck {
270                    session_present,
271                    return_code,
272                } => {
273                    let sink = MqttSink::new(framed.sink().clone());
274                    let ack = ConnectAck {
275                        sink,
276                        session_present,
277                        return_code,
278                        keep_alive,
279                        inflight,
280                        io: framed,
281                    };
282                    Ok(srv
283                        .get_mut()
284                        .call(ack)
285                        .await
286                        .map_err(MqttError::Service)
287                        .map(|ack| ack.io.state(ack.state))?)
288                }
289                p => Err(MqttError::Unexpected(p, "Expected CONNECT-ACK packet")),
290            }
291        }
292        .boxed_local()
293    }
294}
295
296pub struct ConnectAck<Io> {
297    io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
298    sink: MqttSink,
299    session_present: bool,
300    return_code: mqtt::ConnectCode,
301    keep_alive: Duration,
302    inflight: usize,
303}
304
305impl<Io> ConnectAck<Io> {
306    #[inline]
307    /// Indicates whether there is already stored Session state
308    pub fn session_present(&self) -> bool {
309        self.session_present
310    }
311
312    #[inline]
313    /// Connect return code
314    pub fn return_code(&self) -> mqtt::ConnectCode {
315        self.return_code
316    }
317
318    #[inline]
319    /// Mqtt client sink object
320    pub fn sink(&self) -> &MqttSink {
321        &self.sink
322    }
323
324    #[inline]
325    /// Set connection state and create result object
326    pub fn state<St>(self, state: St) -> ConnectAckResult<Io, St> {
327        ConnectAckResult {
328            io: self.io,
329            state: MqttState::new(state, self.sink, self.keep_alive, self.inflight),
330        }
331    }
332}
333
334impl<Io> Stream for ConnectAck<Io>
335where
336    Io: AsyncRead + AsyncWrite + Unpin,
337{
338    type Item = Result<mqtt::Packet, mqtt::ParseError>;
339
340    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
341        Pin::new(&mut self.io).poll_next(cx)
342    }
343}
344
345impl<Io> Sink<mqtt::Packet> for ConnectAck<Io>
346where
347    Io: AsyncRead + AsyncWrite + Unpin,
348{
349    type Error = mqtt::ParseError;
350
351    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
352        Pin::new(&mut self.io).poll_ready(cx)
353    }
354
355    fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> {
356        Pin::new(&mut self.io).start_send(item)
357    }
358
359    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
360        Pin::new(&mut self.io).poll_flush(cx)
361    }
362
363    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
364        Pin::new(&mut self.io).poll_close(cx)
365    }
366}
367
368#[pin_project::pin_project]
369pub struct ConnectAckResult<Io, St> {
370    state: MqttState<St>,
371    io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
372}
373
374impl<Io, St> Stream for ConnectAckResult<Io, St>
375where
376    Io: AsyncRead + AsyncWrite + Unpin,
377{
378    type Item = Result<mqtt::Packet, mqtt::ParseError>;
379
380    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
381        Pin::new(&mut self.io).poll_next(cx)
382    }
383}
384
385impl<Io, St> Sink<mqtt::Packet> for ConnectAckResult<Io, St>
386where
387    Io: AsyncRead + AsyncWrite + Unpin,
388{
389    type Error = mqtt::ParseError;
390
391    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
392        Pin::new(&mut self.io).poll_ready(cx)
393    }
394
395    fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> {
396        Pin::new(&mut self.io).start_send(item)
397    }
398
399    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
400        Pin::new(&mut self.io).poll_flush(cx)
401    }
402
403    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
404        Pin::new(&mut self.io).poll_close(cx)
405    }
406}