mqttest 0.2.0

An MQTT server designed for unit-testing MQTT clients.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
use crate::{dump::*, mqtt::*, pubsub::*, session::*, Conf, FOREVER};
use futures::{future::{abortable, AbortHandle},
              lock::Mutex,
              prelude::*};
use log::*;
use rand::{seq::SliceRandom, thread_rng};
use std::{collections::HashMap,
          io::{Error, ErrorKind},
          sync::Arc,
          time::{Duration, Instant}};
use tokio::{net::{tcp::{ReadHalf, WriteHalf},
                  TcpStream},
            spawn,
            sync::{mpsc::{channel, Receiver, Sender},
                   oneshot},
            time::delay_until};
use tokio_util::codec::{FramedRead, FramedWrite};


/// Numeric identifier given to each incoming connection. It appears in log prefixes and is the
/// identity used when comparing client addresses.
pub type ConnId = usize;

/// Guard owning a `futures::future::AbortHandle`. The associated abortable future is aborted as
/// soon as the guard goes out of scope.
struct AbortOnDrop(pub AbortHandle);
impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        let AbortOnDrop(handle) = self;
        handle.abort();
    }
}

/// Cloneable mailbox address of a `Client`, used to push `Msg`s into that client's event loop.
///
/// Two addresses are equal when they carry the same connection id; the channel itself is not
/// compared.
#[derive(Clone)]
pub struct Addr(Sender<Msg>, pub(crate) ConnId);
impl Addr {
    /// Deliver `msg` to this address, waiting for room in the (bounded) channel. Failures
    /// (the receiving client is gone) are logged rather than returned.
    pub(crate) async fn send(&self, msg: Msg) {
        let mut sender = self.0.clone();
        match sender.send(msg).await {
            Ok(_) => (),
            Err(e) => warn!("Trying to send to disconnected Addr {:?}", e),
        }
    }

    /// Sleep until `deadline`, then deliver `msg` to `addr`.
    async fn send_at_async(addr: Addr, deadline: Instant, msg: Msg) {
        trace!("send_at {:?} {:?} {:?}", deadline, addr, msg);
        delay_until(deadline.into()).await;
        addr.send(msg).await;
    }

    /// Fire-and-forget: spawn a task that delivers `msg` to this address at `deadline`.
    fn send_at(&self, deadline: Instant, msg: Msg) {
        let delayed = Self::send_at_async(self.clone(), deadline, msg);
        spawn(delayed.map(drop));
    }

    /// Like `send_at`, but returns a guard. Dropping the guard cancels the delivery if it has not
    /// happened yet.
    #[must_use]
    fn send_at_abort(&self, deadline: Instant, msg: Msg) -> AbortOnDrop {
        let (task, handle) = abortable(Self::send_at_async(self.clone(), deadline, msg));
        spawn(task.map(drop));
        AbortOnDrop(handle)
    }
}
impl PartialEq for Addr {
    fn eq(&self, other: &Self) -> bool {
        let Addr(_, mine) = self;
        let Addr(_, theirs) = other;
        mine == theirs
    }
}
impl Eq for Addr {}
impl std::fmt::Debug for Addr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Addr(_, {})", self.1)
    }
}

/// Messages processed by a `Client`'s event loop.
#[derive(Debug)]
pub(crate) enum Msg {
    /// A packet decoded from the client's socket.
    PktIn(Packet),
    /// A packet to encode and write to the client's socket.
    PktOut(Packet),
    /// A publication to deliver to this client, at the given subscription QoS.
    Publish(QoS, Publish),
    /// Resend the QoS 1 packets whose ack is overdue.
    CheckQos,
    /// Another connection takes over this client's session. The `SessionData` is handed over
    /// through the oneshot sender.
    Replaced(ConnId, oneshot::Sender<SessionData>),
    /// Server-initiated disconnection, with a human-readable reason.
    Disconnect(String),
}

/// State that belongs to a client session rather than to a single connection. It is restored
/// when the client connects and kept up to date while it stays connected.
///
/// `Clone` is deliberately not derived: new instances should only be created by
/// `Sessions::open()`. This helps guarantee that at most one `SessionData` exists per client name.
#[derive(Debug, Default)]
pub(crate) struct SessionData {
    /// How many connections have used this session; 0 means the session is brand new.
    cons: usize,
    /// The most recently allocated packet id for this client, if any.
    prev_pid: Option<Pid>,
    /// This session's own subscriptions (topic → QoS), kept apart from the global `Subs` store.
    subs: HashMap<String, QoS>,
    /// QoS 1 packets awaiting an ack, keyed by pid, with the deadline after which they are
    /// resent.
    /// TODO: Resend immediately at reconnection too.
    qos1: HashMap<Pid, (Instant, Packet)>,
}
impl SessionData {
    /// Allocate a packet id that is not awaiting an ack, record it as the latest and return it.
    fn next_pid(&mut self) -> Pid {
        let mut candidate = match self.prev_pid {
            None => Pid::new(),
            Some(prev) => prev + 1,
        };
        loop {
            if !self.qos1.contains_key(&candidate) {
                break;
            }
            candidate = candidate + 1;
        }
        self.prev_pid = Some(candidate);
        candidate
    }
}

/// One MQTT connection, written in actor style. A single `Future` owns the struct, receives
/// `Msg`s and handles each one by mutating the struct.
pub(crate) struct Client<'s> {
    /// Connection id, used in `C{id}` log prefixes and as the `Addr` identity.
    pub id: ConnId,
    /// MQTT client id from CONNECT, or a random one assigned to an unnamed client.
    pub name: String,
    /// This client's own mailbox.
    pub addr: Addr,
    /// Set once a CONNECT packet has been received.
    // FIXME: there's more than two states.
    conn: bool,
    /// Sink for outgoing `Packet`s; they are encoded and written to the TCP socket.
    writer: FramedWrite<WriteHalf<'s>, Codec>,
    /// Targets that incoming and outgoing packets are dumped to.
    dumps: Dump,
    /// Ack-timeout settings from the config, as a pair of optional durations.
    ack_timeouts_conf: (Option<Duration>, Option<Duration>),
    /// Unacked QoS 1 packets are resent after this long (`FOREVER` means never).
    ack_timeout: Duration,
    /// How long to wait before sending PUBACK and SUBACK replies.
    // TODO: should be a ring buffer.
    ack_delay: Duration,
    /// Whether to reject, rather than merely warn about, optional protocol behaviours.
    strict: bool,
    /// Prefix every client_id is required to start with.
    idprefix: String,
    /// When set, CONNECT must carry exactly this `username:password`.
    userpass: Option<String>,
    /// Global subscription table shared by all clients.
    subs: Arc<Mutex<Subs>>,
    /// Global session store shared by all clients.
    sessions: Arc<Mutex<Sessions>>,
    /// This client's session; `None` until CONNECT has been accepted.
    session: Option<SessionData>,
    /// Overrides the session expiry time when set.
    pub sess_expire: Option<Duration>,
    /// Guard for the pending timer that will send the next `Msg::CheckQos`.
    qos1_check: Option<AbortOnDrop>,
    /// Disconnect once this many packets have been received.
    max_pkt: usize,
    /// Optional delay before the `max_pkt` disconnection.
    max_pkt_delay: Option<Duration>,
    /// Number of packets received so far.
    count_pkt: usize,
}
impl Client<'_> {
    /// Run a new `Client` on `socket` until the connection ends. The caller executes the
    /// returned future, which represents the whole connection.
    ///
    /// The function works in four stages:
    /// 1. Build the client from `conf`. Per-connection settings are picked from `conf`'s lists at
    ///    index `id` modulo the list length.
    /// 2. Arm the optional `max_time` disconnect timer and register the dump targets.
    /// 3. Read packets from the socket and process `Msg`s concurrently, until either side stops.
    /// 4. Remove the client's subscriptions and hand its session back to `sessions`.
    pub async fn start(id: usize,
                       mut socket: TcpStream,
                       subs: Arc<Mutex<Subs>>,
                       sessions: Arc<Mutex<Sessions>>,
                       dumps: Dump,
                       conf: Conf) {
        info!("C{}: Connection from {:?}", id, socket);
        let (read, write) = socket.split();
        let (sx, rx) = channel::<Msg>(10);
        // NOTE(review): the per-connection conf lists are assumed non-empty (modulo by zero).
        let max_pkt = conf.max_pkt[id % conf.max_pkt.len()].unwrap_or(std::usize::MAX);
        let sess_expire = conf.sess_expire[id % conf.sess_expire.len()];
        let mut client = Client { id,
                                  name: String::from(""),
                                  addr: Addr(sx.clone(), id),
                                  conn: false,
                                  writer: FramedWrite::new(write, Codec(id)),
                                  dumps,
                                  ack_timeouts_conf: conf.ack_timeouts,
                                  ack_timeout: conf.ack_timeouts.0.unwrap_or(FOREVER),
                                  ack_delay: conf.ack_delay,
                                  strict: conf.strict,
                                  idprefix: conf.idprefix.clone(),
                                  userpass: conf.userpass.clone(),
                                  subs,
                                  sessions,
                                  session: None,
                                  sess_expire,
                                  qos1_check: None,
                                  max_pkt,
                                  max_pkt_delay: conf.max_pkt_delay,
                                  count_pkt: 0 };

        // Setup disconnect timer.
        if let Some(m) = conf.max_time[id % conf.max_time.len()] {
            client.addr.send_at(Instant::now() + m, Msg::Disconnect(format!("max time {:?}", m)))
        }

        // Initialize json dump target.
        for s in conf.dump_files {
            let s = s.replace("{c}", &format!("{}", id));
            match client.dumps.register(&s) {
                Ok(_) => debug!("C{}: Dump to {}", id, s),
                Err(e) => error!("C{}: Cannot dump to {}: {}", id, s, e),
            }
        }

        // Handle the Tcp and Msg streams concurrently.
        let f1 = Self::handle_net(read, sx, client.id);
        let f2 = Self::handle_msgs(&mut client, rx);
        let res = futures::select!(r = f1.fuse() => r, r = f2.fuse() => r);

        // One of the streams ended, cleanup.
        warn!("C{}: Terminating: {:?}", id, res);
        // Take the two global locks one at a time. Holding `subs` while waiting for `sessions`
        // (or the reverse) can deadlock against another client locking them in the other order.
        client.subs.lock().await.del_all(&client);
        if let Some(sess) = client.session.take() {
            client.sessions.lock().await.close(&client, sess);
        }
    }

    /// Decode MQTT packets from the socket and forward each one to the client as `Msg::PktIn`.
    ///
    /// Returns `Ok` when the peer closes the connection. Returns `Err` on a decoding or I/O error,
    /// or when the client's channel is closed.
    async fn handle_net(read: ReadHalf<'_>,
                        mut sx: Sender<Msg>,
                        id: ConnId)
                        -> Result<&'static str, Error> {
        let mut frame = FramedRead::new(read, Codec(id));
        while let Some(pkt) = frame.next().await {
            sx.send(Msg::PktIn(pkt?))
              .await
              .map_err(|e| Error::new(ErrorKind::Other, format!("while sending to self: {}", e)))?;
        }
        Ok("Connection closed")
    }

    /// Handle `Msg`s. This is `Client`'s main event loop.
    ///
    /// Stops at the first handler error (which is how handlers request a disconnection), or when
    /// every sender has been dropped.
    async fn handle_msgs(client: &mut Client<'_>,
                         mut receiver: Receiver<Msg>)
                         -> Result<&'static str, Error> {
        while let Some(msg) = receiver.next().await {
            match msg {
                Msg::PktIn(p) => client.handle_pkt_in(p).await?,
                Msg::PktOut(p) => client.handle_pkt_out(p).await?,
                Msg::Publish(q, p) => client.handle_publish(q, p).await?,
                Msg::CheckQos => client.handle_check_qos(Instant::now()).await?,
                Msg::Replaced(i, c) => client.handle_replaced(i, c)?,
                Msg::Disconnect(r) => client.handle_disconnect(r)?,
            }
        }
        Ok("No more messages")
    }

    /// Handle one packet received from the client.
    ///
    /// Returns `Err` when the connection must end. That happens on a refused CONNECT, a client
    /// DISCONNECT, an unexpected PUBACK, an unsupported QoS 2 request, or a packet that is not
    /// valid in the current connection state. Once `max_pkt` packets have been received, a
    /// server-side disconnect is queued, optionally delayed by `max_pkt_delay`.
    async fn handle_pkt_in(&mut self, pkt: Packet) -> Result<(), Error> {
        info!("C{}: receive Packet::{:?}", self.id, pkt);
        self.dumps.dump(self.id, &self.name, "C", &pkt).await;
        self.count_pkt += 1;
        match (pkt, self.conn) {
            // Connection
            (Packet::Connect(c), false) => {
                self.conn = true;
                // Both supported protocol versions currently use the first ack-timeout setting.
                self.ack_timeout = match c.protocol {
                    Protocol::MQTT311 => self.ack_timeouts_conf.0.unwrap_or(FOREVER),
                    Protocol::MQIsdp => self.ack_timeouts_conf.0.unwrap_or(FOREVER),
                };
                // Set and check client name
                self.name = c.client_id.clone();
                if let Err((code, desc)) = self.check_credentials(&c) {
                    // Write the refusal directly: returning `Err` ends `handle_msgs`, so a
                    // CONNACK queued on our own channel would never reach the client.
                    self.handle_pkt_out(connack(false, code)).await?;
                    return Err(Error::new(ErrorKind::ConnectionAborted, desc));
                }
                // Load the session. Release the lock immediately so no global lock is held while
                // taking `subs` or while waiting on our own channel below.
                let mut sm = self.sessions.lock().await;
                let mut sess = sm.open(&self, c.clean_session).await;
                drop(sm);
                let isold = sess.cons > 0;
                debug!("C{}: loaded {} session {:?}",
                       self.id,
                       if isold { "old" } else { "new" },
                       sess);
                sess.cons += 1;
                // Restore the session's subscriptions into the global store.
                {
                    let mut subs = self.subs.lock().await;
                    for (topic, qos) in sess.subs.iter() {
                        subs.add(&topic, *qos, self.id, self.addr.clone());
                    }
                }
                self.session = Some(sess);
                // Handle QoS. CheckQos is queued before the CONNACK, so any resends it queues are
                // written after the CONNACK.
                self.addr.send(Msg::CheckQos).await;
                // Send connack
                self.addr.send(Msg::PktOut(connack(isold, ConnectReturnCode::Accepted))).await;
            },
            // FIXME: Use our own error type, and let this one log as INFO rather than ERROR
            (Packet::Disconnect, true) => {
                self.conn = false;
                return Err(Error::new(ErrorKind::ConnectionAborted, "Disconnect"));
            },
            // Ping request
            (Packet::Pingreq, true) => self.addr.send(Msg::PktOut(pingresp())).await,
            // Puback: forget the pending packet if the ack was expected, die otherwise.
            (Packet::Puback(pid), true) => {
                let sess = self.session.as_mut().expect("unwrap session");
                if sess.qos1.remove(&pid).is_none() {
                    return Err(Error::new(ErrorKind::InvalidData,
                                          format!("Puback {:?} unexpected", pid)));
                }
            },
            // Publish
            (Packet::Publish(p), true) => {
                // Snapshot the recipients first, so the global `subs` lock is not held while
                // awaiting on other clients' (bounded, possibly full) channels.
                let targets: Vec<(Addr, QoS)> = match self.subs.lock().await.get(&p.topic_name) {
                    Some(subs) => subs.values().map(|s| (s.addr.clone(), s.qos)).collect(),
                    None => Vec::new(),
                };
                for (addr, qos) in targets {
                    addr.send(Msg::Publish(qos, p.clone())).await;
                }
                match p.qospid {
                    QosPid::AtMostOnce => (),
                    QosPid::AtLeastOnce(pid) => {
                        let d = Instant::now() + self.ack_delay;
                        self.addr.send_at(d, Msg::PktOut(puback(pid)));
                    },
                    // Return an error instead of panicking, so the connection cleanup still runs.
                    QosPid::ExactlyOnce(_) => {
                        return Err(Error::new(ErrorKind::InvalidData,
                                              "ExactlyOnce not supported yet"))
                    },
                }
            },
            // Subscription request
            (Packet::Subscribe(Subscribe { pid, topics }), true) => {
                let mut subs = self.subs.lock().await;
                let sess = self.session.as_mut().expect("unwrap session");
                let mut codes = Vec::new();
                for SubscribeTopic { topic_path, qos } in topics {
                    // Return an error instead of panicking, so the connection cleanup still runs.
                    if qos == QoS::ExactlyOnce {
                        return Err(Error::new(ErrorKind::InvalidData,
                                              "ExactlyOnce not supported yet"));
                    }
                    subs.add(&topic_path, qos, self.id, self.addr.clone());
                    sess.subs.insert(topic_path.clone(), qos);
                    codes.push(SubscribeReturnCodes::Success(qos));
                }
                let d = Instant::now() + self.ack_delay;
                self.addr.send_at(d, Msg::PktOut(suback(pid, codes)));
            },
            (other, _) => {
                return Err(Error::new(ErrorKind::InvalidData, format!("Unhandled {:?}", other)))
            },
        }
        if self.count_pkt >= self.max_pkt {
            let reason = format!("max packets {:?} {:?}", self.max_pkt, self.max_pkt_delay);
            match self.max_pkt_delay {
                Some(d) => self.addr.send_at(Instant::now() + d, Msg::Disconnect(reason)),
                None => self.addr.send(Msg::Disconnect(reason)).await,
            }
        }
        Ok(())
    }

    /// Dump `pkt`, then encode it and write it to the client's socket.
    async fn handle_pkt_out(&mut self, pkt: Packet) -> Result<(), Error> {
        info!("C{}: send Packet::{:?}", self.id, pkt);
        self.dumps.dump(self.id, &self.name, "S", &pkt).await;
        self.writer.send(pkt).await?;
        self.writer.flush().await.map_err(|e| e.into())
    }

    /// Send `p` to this client as a PUBLISH with the subscription's `qos`.
    ///
    /// For QoS 1, a fresh pid is allocated and the packet is remembered in `session.qos1` until
    /// the client acks it. A `CheckQos` is also scheduled, unless one is already pending or acks
    /// never time out.
    ///
    /// # Panics
    /// If the session is not loaded yet, or if `qos` is `ExactlyOnce`.
    async fn handle_publish(&mut self, qos: QoS, p: Publish) -> Result<(), Error> {
        let sess = self.session.as_mut().expect("unwrap session");
        let qospid = match qos {
            QoS::AtMostOnce => QosPid::AtMostOnce,
            QoS::AtLeastOnce => QosPid::AtLeastOnce(sess.next_pid()),
            QoS::ExactlyOnce => panic!("ExactlyOnce not supported yet"),
        };
        let pkt = publish(false, qospid, false, p.topic_name, p.payload);
        if let QosPid::AtLeastOnce(pid) = qospid {
            // Publish with QoS 1: the client must ack it before `deadline`, otherwise the
            // packet gets resent.
            let deadline = Instant::now() + self.ack_timeout;
            debug!("C{}: waiting for {:?} + {:?}@{:?}", self.id, sess.qos1, pid, deadline);

            // Remember the details so we can accept the ack or resend the pkt.
            let prev = sess.qos1.insert(pid, (deadline, pkt.clone()));
            assert!(prev.is_none(), "C{}: Server error: reusing {:?} {:?}", self.id, pid, prev);

            // Schedule the next check for timed-out acks.
            if self.qos1_check.is_none() && self.ack_timeout < FOREVER {
                self.qos1_check = Some(self.addr.send_at_abort(deadline, Msg::CheckQos));
            }
        }
        self.handle_pkt_out(pkt).await
    }

    /// Resend every QoS 1 packet whose ack deadline is at or before `reftime`, then schedule the
    /// next check.
    ///
    /// Resent packets stay pending with a new deadline of `reftime + ack_timeout`. That way the
    /// client's eventual PUBACK is still accepted, and the packet is resent again if the ack
    /// keeps failing to arrive. `qos1_check` is cleared when nothing is pending, so that
    /// `handle_publish` can schedule a check for the next QoS 1 packet.
    async fn handle_check_qos(&mut self, reftime: Instant) -> Result<(), Error> {
        let sess = self.session.as_mut().expect("unwrap session");
        trace!("C{}: check Qos acks {:?}", self.id, sess.qos1);
        let id = self.id;
        let addr = self.addr.clone();
        let retry = reftime + self.ack_timeout;
        for (pid, (deadline, pkt)) in sess.qos1.iter_mut() {
            if *deadline <= reftime {
                warn!("C{}: Timeout receiving ack {:?}, resending packet", id, pid);
                *deadline = retry;
                addr.send(Msg::PktOut(pkt.clone())).await
            }
        }
        let next = sess.qos1.values().map(|(d, _)| *d).min();
        self.qos1_check = next.map(|d| addr.send_at_abort(d, Msg::CheckQos));
        Ok(())
    }

    /// Hand this client's session over to connection `conn`, which replaces this one, then
    /// request disconnection by returning `Err`.
    fn handle_replaced(&mut self,
                       conn: ConnId,
                       chan: oneshot::Sender<SessionData>)
                       -> Result<(), Error> {
        info!("C{}: replaced by connection {}", self.id, conn);
        self.conn = false;
        chan.send(self.session.take().unwrap()).unwrap_or_else(|_| {
                                                   trace!("C{}: C{} didn't wait for the session",
                                                          self.id,
                                                          conn)
                                               });
        Err(Error::new(ErrorKind::ConnectionReset, "Replaced"))
    }

    /// Server-initiated disconnection: always returns `Err` carrying `reason`.
    fn handle_disconnect(&mut self, reason: String) -> Result<(), Error> {
        info!("C{}: Disconnect by server: {:?}", self.id, reason);
        self.conn = false;
        Err(Error::new(ErrorKind::ConnectionReset, reason))
    }

    /// Validate the CONNECT client identifier and credentials. If the identifier is empty (only
    /// allowed with a clean session), assign a random one.
    ///
    /// The checks run in this order:
    /// - identifier length and charset (rejected only in `strict` mode, otherwise just warned);
    /// - empty identifier with a persistent session;
    /// - password without a username;
    /// - the configured `userpass`, compared byte for byte;
    /// - the configured `idprefix`.
    ///
    /// On failure, returns the CONNACK return code and a description.
    /// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031
    fn check_credentials(&mut self,
                         con: &Connect)
                         -> Result<(), (ConnectReturnCode, &'static str)> {
        let allow = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
        if self.name.len() > 23 || !self.name.chars().all(|c| c.is_ascii_alphanumeric()) {
            if self.strict {
                return Err((ConnectReturnCode::RefusedIdentifierRejected,
                            "Client_id too long or bad charset [MQTT-3.1.3-8]"));
            }
            warn!("C{}: Servers MAY reject {:?} [MQTT-3.1.3-5/MQTT-3.1.3-6]", self.id, self.name);
        }
        if self.name.is_empty() {
            if !con.clean_session {
                return Err((ConnectReturnCode::RefusedIdentifierRejected,
                            "Empty client_id with session [MQTT-3.1.3-8]"));
            }
            let mut rng = thread_rng();
            for _ in 0..20 {
                self.name.push(*allow.as_bytes().choose(&mut rng).unwrap() as char);
            }
            info!("C{}: Unamed client, assigned random name {:?}", self.id, self.name);
        }
        if con.password.is_some() && con.username.is_none() {
            return Err((ConnectReturnCode::BadUsernamePassword,
                        "Password without a username [MQTT-3.1.2-22]"));
        }
        if let Some(ref req_up) = self.userpass {
            // Build `username:password` from raw bytes. The password is binary data; `{:?}`
            // formatting would render it as a list of numbers that can never match.
            let mut con_up = con.username.clone().unwrap_or_default().into_bytes();
            con_up.push(b':');
            con_up.extend_from_slice(con.password.as_ref().map(|p| p.as_slice()).unwrap_or(&[]));
            if con_up.as_slice() != req_up.as_bytes() {
                return Err((ConnectReturnCode::BadUsernamePassword,
                            "Bad username/password [MQTT-3.1.3.4/3.1.3.5]"));
            }
        }
        if !self.name.starts_with(&self.idprefix) {
            return Err((ConnectReturnCode::NotAuthorized, "Not Authorised [MQTT-5.4.2]"));
        }
        Ok(())
    }
}