1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::rc::Rc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use actix_codec::{AsyncRead, AsyncWrite};
8use actix_ioframe as ioframe;
9use actix_service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory};
10use bytes::Bytes;
11use bytestring::ByteString;
12use futures::future::{FutureExt, LocalBoxFuture};
13use futures::{Sink, SinkExt, Stream, StreamExt};
14use mqtt_codec as mqtt;
15
16use crate::cell::Cell;
17use crate::default::{SubsNotImplemented, UnsubsNotImplemented};
18use crate::dispatcher::{dispatcher, MqttState};
19use crate::error::MqttError;
20use crate::publish::Publish;
21use crate::sink::MqttSink;
22use crate::subs::{Subscribe, SubscribeResult, Unsubscribe};
23
24#[derive(Clone)]
26pub struct Client<Io, St> {
27 client_id: ByteString,
28 clean_session: bool,
29 protocol: mqtt::Protocol,
30 keep_alive: u16,
31 last_will: Option<mqtt::LastWill>,
32 username: Option<ByteString>,
33 password: Option<Bytes>,
34 inflight: usize,
35 _t: PhantomData<(Io, St)>,
36}
37
38impl<Io, St> Client<Io, St>
39where
40 St: 'static,
41{
42 pub fn new(client_id: ByteString) -> Self {
44 Client {
45 client_id,
46 clean_session: true,
47 protocol: mqtt::Protocol::default(),
48 keep_alive: 30,
49 last_will: None,
50 username: None,
51 password: None,
52 inflight: 15,
53 _t: PhantomData,
54 }
55 }
56
57 pub fn protocol(mut self, val: mqtt::Protocol) -> Self {
59 self.protocol = val;
60 self
61 }
62
63 pub fn clean_session(mut self, val: bool) -> Self {
65 self.clean_session = val;
66 self
67 }
68
69 pub fn keep_alive(mut self, val: u16) -> Self {
73 self.keep_alive = val;
74 self
75 }
76
77 pub fn last_will(mut self, val: mqtt::LastWill) -> Self {
81 self.last_will = Some(val);
82 self
83 }
84
85 pub fn username(mut self, val: ByteString) -> Self {
87 self.username = Some(val);
88 self
89 }
90
91 pub fn password(mut self, val: Bytes) -> Self {
93 self.password = Some(val);
94 self
95 }
96
97 pub fn inflight(mut self, val: usize) -> Self {
101 self.inflight = val;
102 self
103 }
104
105 pub fn state<C, F>(self, state: F) -> ServiceBuilder<Io, St, C>
109 where
110 F: IntoService<C>,
111 Io: AsyncRead + AsyncWrite,
112 C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>>,
113 C::Error: 'static,
114 {
115 ServiceBuilder {
116 state: Cell::new(state.into_service()),
117 packet: mqtt::Connect {
118 client_id: self.client_id,
119 clean_session: self.clean_session,
120 protocol: self.protocol,
121 keep_alive: self.keep_alive,
122 last_will: self.last_will,
123 username: self.username,
124 password: self.password,
125 },
126 subscribe: Rc::new(boxed::factory(SubsNotImplemented::default())),
127 unsubscribe: Rc::new(boxed::factory(UnsubsNotImplemented::default())),
128 disconnect: None,
129 keep_alive: self.keep_alive.into(),
130 inflight: self.inflight,
131 _t: PhantomData,
132 }
133 }
134}
135
136pub struct ServiceBuilder<Io, St, C: Service> {
137 state: Cell<C>,
138 packet: mqtt::Connect,
139 subscribe: Rc<
140 boxed::BoxServiceFactory<
141 St,
142 Subscribe<St>,
143 SubscribeResult,
144 MqttError<C::Error>,
145 MqttError<C::Error>,
146 >,
147 >,
148 unsubscribe: Rc<
149 boxed::BoxServiceFactory<
150 St,
151 Unsubscribe<St>,
152 (),
153 MqttError<C::Error>,
154 MqttError<C::Error>,
155 >,
156 >,
157 disconnect: Option<Cell<boxed::BoxService<St, (), MqttError<C::Error>>>>,
158 keep_alive: u64,
159 inflight: usize,
160
161 _t: PhantomData<(Io, St, C)>,
162}
163
164impl<Io, St, C> ServiceBuilder<Io, St, C>
165where
166 St: Clone + 'static,
167 Io: AsyncRead + AsyncWrite + 'static,
168 C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>> + 'static,
169 C::Error: 'static,
170{
171 pub fn disconnect<UF, U>(mut self, srv: UF) -> Self
173 where
174 UF: IntoService<U>,
175 U: Service<Request = St, Response = (), Error = C::Error> + 'static,
176 {
177 self.disconnect = Some(Cell::new(boxed::service(
178 srv.into_service().map_err(MqttError::Service),
179 )));
180 self
181 }
182
183 pub fn finish<F, T>(
184 self,
185 service: F,
186 ) -> impl Service<Request = Io, Response = (), Error = MqttError<C::Error>>
187 where
188 F: IntoServiceFactory<T>,
189 T: ServiceFactory<
190 Config = St,
191 Request = Publish<St>,
192 Response = (),
193 Error = C::Error,
194 InitError = C::Error,
195 > + 'static,
196 {
197 ioframe::Builder::new()
198 .service(ConnectService {
199 connect: self.state,
200 packet: self.packet,
201 keep_alive: self.keep_alive,
202 inflight: self.inflight,
203 _t: PhantomData,
204 })
205 .finish(dispatcher(
206 service
207 .into_factory()
208 .map_err(MqttError::Service)
209 .map_init_err(MqttError::Service),
210 self.subscribe,
211 self.unsubscribe,
212 ))
213 .map_err(|e| match e {
214 ioframe::ServiceError::Service(e) => e,
215 ioframe::ServiceError::Encoder(e) => MqttError::Protocol(e),
216 ioframe::ServiceError::Decoder(e) => MqttError::Protocol(e),
217 })
218 }
219}
220
221struct ConnectService<Io, St, C> {
222 connect: Cell<C>,
223 packet: mqtt::Connect,
224 keep_alive: u64,
225 inflight: usize,
226 _t: PhantomData<(Io, St)>,
227}
228
229impl<Io, St, C> Service for ConnectService<Io, St, C>
230where
231 St: 'static,
232 Io: AsyncRead + AsyncWrite + 'static,
233 C: Service<Request = ConnectAck<Io>, Response = ConnectAckResult<Io, St>> + 'static,
234 C::Error: 'static,
235{
236 type Request = ioframe::Connect<Io, mqtt::Codec>;
237 type Response = ioframe::ConnectResult<Io, MqttState<St>, mqtt::Codec>;
238 type Error = MqttError<C::Error>;
239 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
240
241 fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
242 self.connect
243 .get_mut()
244 .poll_ready(cx)
245 .map_err(MqttError::Service)
246 }
247
248 fn call(&mut self, req: Self::Request) -> Self::Future {
249 let mut srv = self.connect.clone();
250 let packet = self.packet.clone();
251 let keep_alive = Duration::from_secs(self.keep_alive as u64);
252 let inflight = self.inflight;
253
254 async move {
256 let mut framed = req.codec(mqtt::Codec::new());
257 framed
258 .send(mqtt::Packet::Connect(packet))
259 .await
260 .map_err(MqttError::Protocol)?;
261
262 let packet = framed
263 .next()
264 .await
265 .ok_or(MqttError::Disconnected)
266 .and_then(|res| res.map_err(MqttError::Protocol))?;
267
268 match packet {
269 mqtt::Packet::ConnectAck {
270 session_present,
271 return_code,
272 } => {
273 let sink = MqttSink::new(framed.sink().clone());
274 let ack = ConnectAck {
275 sink,
276 session_present,
277 return_code,
278 keep_alive,
279 inflight,
280 io: framed,
281 };
282 Ok(srv
283 .get_mut()
284 .call(ack)
285 .await
286 .map_err(MqttError::Service)
287 .map(|ack| ack.io.state(ack.state))?)
288 }
289 p => Err(MqttError::Unexpected(p, "Expected CONNECT-ACK packet")),
290 }
291 }
292 .boxed_local()
293 }
294}
295
296pub struct ConnectAck<Io> {
297 io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
298 sink: MqttSink,
299 session_present: bool,
300 return_code: mqtt::ConnectCode,
301 keep_alive: Duration,
302 inflight: usize,
303}
304
305impl<Io> ConnectAck<Io> {
306 #[inline]
307 pub fn session_present(&self) -> bool {
309 self.session_present
310 }
311
312 #[inline]
313 pub fn return_code(&self) -> mqtt::ConnectCode {
315 self.return_code
316 }
317
318 #[inline]
319 pub fn sink(&self) -> &MqttSink {
321 &self.sink
322 }
323
324 #[inline]
325 pub fn state<St>(self, state: St) -> ConnectAckResult<Io, St> {
327 ConnectAckResult {
328 io: self.io,
329 state: MqttState::new(state, self.sink, self.keep_alive, self.inflight),
330 }
331 }
332}
333
334impl<Io> Stream for ConnectAck<Io>
335where
336 Io: AsyncRead + AsyncWrite + Unpin,
337{
338 type Item = Result<mqtt::Packet, mqtt::ParseError>;
339
340 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
341 Pin::new(&mut self.io).poll_next(cx)
342 }
343}
344
345impl<Io> Sink<mqtt::Packet> for ConnectAck<Io>
346where
347 Io: AsyncRead + AsyncWrite + Unpin,
348{
349 type Error = mqtt::ParseError;
350
351 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
352 Pin::new(&mut self.io).poll_ready(cx)
353 }
354
355 fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> {
356 Pin::new(&mut self.io).start_send(item)
357 }
358
359 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
360 Pin::new(&mut self.io).poll_flush(cx)
361 }
362
363 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
364 Pin::new(&mut self.io).poll_close(cx)
365 }
366}
367
368#[pin_project::pin_project]
369pub struct ConnectAckResult<Io, St> {
370 state: MqttState<St>,
371 io: ioframe::ConnectResult<Io, (), mqtt::Codec>,
372}
373
374impl<Io, St> Stream for ConnectAckResult<Io, St>
375where
376 Io: AsyncRead + AsyncWrite + Unpin,
377{
378 type Item = Result<mqtt::Packet, mqtt::ParseError>;
379
380 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
381 Pin::new(&mut self.io).poll_next(cx)
382 }
383}
384
385impl<Io, St> Sink<mqtt::Packet> for ConnectAckResult<Io, St>
386where
387 Io: AsyncRead + AsyncWrite + Unpin,
388{
389 type Error = mqtt::ParseError;
390
391 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
392 Pin::new(&mut self.io).poll_ready(cx)
393 }
394
395 fn start_send(mut self: Pin<&mut Self>, item: mqtt::Packet) -> Result<(), Self::Error> {
396 Pin::new(&mut self.io).start_send(item)
397 }
398
399 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
400 Pin::new(&mut self.io).poll_flush(cx)
401 }
402
403 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
404 Pin::new(&mut self.io).poll_close(cx)
405 }
406}