use crate::{dump::*, mqtt::*, pubsub::*, session::*, Conf, FOREVER};
use futures::{future::{abortable, AbortHandle},
lock::Mutex,
prelude::*};
use log::*;
use rand::{seq::SliceRandom, thread_rng};
use std::{collections::HashMap,
io::{Error, ErrorKind},
sync::Arc,
time::{Duration, Instant}};
use tokio::{net::{tcp::{ReadHalf, WriteHalf},
TcpStream},
spawn,
sync::{mpsc::{channel, Receiver, Sender},
oneshot},
time::delay_until};
use tokio_util::codec::{FramedRead, FramedWrite};
/// Identifier assigned to each accepted connection.
pub type ConnId = usize;
/// Guard owning an [`AbortHandle`]: dropping the guard aborts the associated future.
struct AbortOnDrop(pub AbortHandle);
impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        let AbortOnDrop(handle) = self;
        handle.abort();
    }
}
/// Handle used to post [`Msg`]s to one connection's message queue, tagged with its
/// [`ConnId`].
#[derive(Clone)]
pub struct Addr(Sender<Msg>, pub(crate) ConnId);
impl Addr {
    /// Queue `msg` for this connection; if the receiving side is gone, log a warning instead
    /// of failing.
    pub(crate) async fn send(&self, msg: Msg) {
        // Sending needs a mutable sender, so operate on a fresh clone of the channel handle.
        match self.0.clone().send(msg).await {
            Ok(_) => {},
            Err(e) => warn!("Trying to send to disconnected Addr {:?}", e),
        }
    }
    /// Wait until `deadline`, then deliver `msg` to `addr`.
    async fn send_at_async(addr: Addr, deadline: Instant, msg: Msg) {
        trace!("send_at {:?} {:?} {:?}", deadline, addr, msg);
        delay_until(deadline.into()).await;
        addr.send(msg).await;
    }
    /// Spawn a detached task that delivers `msg` at `deadline`.
    fn send_at(&self, deadline: Instant, msg: Msg) {
        let task = Self::send_at_async(self.clone(), deadline, msg);
        spawn(task.map(drop));
    }
    /// Like [`Addr::send_at`], but the pending delivery is aborted when the returned guard
    /// is dropped.
    #[must_use]
    fn send_at_abort(&self, deadline: Instant, msg: Msg) -> AbortOnDrop {
        let delayed = Self::send_at_async(self.clone(), deadline, msg);
        let (task, handle) = abortable(delayed);
        spawn(task.map(drop));
        AbortOnDrop(handle)
    }
}
impl PartialEq for Addr {
fn eq(&self, other: &Self) -> bool {
self.1 == other.1
}
}
impl Eq for Addr {}
impl std::fmt::Debug for Addr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Addr(_, {})", self.1)
}
}
/// Messages processed one at a time by a connection's message loop (`Client::handle_msgs`).
#[derive(Debug)]
pub(crate) enum Msg {
/// Packet decoded from the client's socket by `handle_net`; handled by `handle_pkt_in`.
PktIn(Packet),
/// Packet to dump and write to the client's socket; handled by `handle_pkt_out`.
PktOut(Packet),
/// Publication to forward to this client at the given QoS; handled by `handle_publish`.
Publish(QoS, Publish),
/// Trigger a check of the session's unacknowledged QoS1 publishes (`handle_check_qos`).
CheckQos,
/// This connection is replaced by connection `ConnId`; the session data is handed back
/// through the oneshot sender (`handle_replaced`).
Replaced(ConnId, oneshot::Sender<SessionData>),
/// Server-initiated disconnect carrying a human-readable reason (`handle_disconnect`).
Disconnect(String),
}
/// Per-client session state, loaded on CONNECT and handed over on replacement.
#[derive(Debug, Default)]
pub(crate) struct SessionData {
    /// Number of connections that have used this session so far.
    cons: usize,
    /// Packet id most recently returned by `next_pid`, if any.
    prev_pid: Option<Pid>,
    /// Subscribed topic filters and their QoS.
    subs: HashMap<String, QoS>,
    /// Outgoing QoS1 publishes awaiting a Puback: resend deadline and the packet itself.
    qos1: HashMap<Pid, (Instant, Packet)>,
}
impl SessionData {
    /// Return the packet id after the previous one, skipping ids still awaiting a Puback,
    /// and remember it as the new previous id.
    fn next_pid(&mut self) -> Pid {
        let mut candidate = self.prev_pid.map_or_else(Pid::new, |prev| prev + 1);
        loop {
            if !self.qos1.contains_key(&candidate) {
                break;
            }
            candidate = candidate + 1;
        }
        self.prev_pid = Some(candidate);
        candidate
    }
}
/// State of one client connection, created and driven by `Client::start`.
pub(crate) struct Client<'s> {
/// Connection id; also selects per-connection settings from the `Conf` lists.
pub id: ConnId,
/// Client id from CONNECT, or a random one assigned by `check_credentials`.
pub name: String,
/// Handle to this connection's own message queue.
pub addr: Addr,
/// Set when a CONNECT is received; cleared on disconnect or replacement.
conn: bool,
/// Framed packet encoder over the socket's write half.
writer: FramedWrite<WriteHalf<'s>, Codec>,
/// Packet dump destinations.
dumps: Dump,
/// Configured ack timeouts; `.0` is the one used for MQTT311 and MQIsdp clients.
ack_timeouts_conf: (Option<Duration>, Option<Duration>),
/// Active timeout after which an unacked QoS1 publish is due for resend (`FOREVER` if unset).
ack_timeout: Duration,
/// Delay before sending Puback / Suback replies.
ack_delay: Duration,
/// Refuse non-conforming client ids instead of only warning about them.
strict: bool,
/// Prefix every client id must start with.
idprefix: String,
/// Required credentials checked by `check_credentials`, if configured.
userpass: Option<String>,
/// Shared subscription registry.
subs: Arc<Mutex<Subs>>,
/// Shared session store.
sessions: Arc<Mutex<Sessions>>,
/// Session loaded on CONNECT; `None` before that and after it has been handed over.
session: Option<SessionData>,
/// Session expiry taken from `Conf` for this connection (NOTE(review): presumably read by
/// the session store; not used in this file section).
pub sess_expire: Option<Duration>,
/// Pending `CheckQos` timer; dropping the guard aborts it.
qos1_check: Option<AbortOnDrop>,
/// Number of received packets after which the server disconnects the client.
max_pkt: usize,
/// Optional delay before that disconnect is delivered.
max_pkt_delay: Option<Duration>,
/// Packets received so far.
count_pkt: usize,
}
impl Client<'_> {
/// Serve one accepted TCP connection until it terminates.
///
/// Per-connection settings (`max_pkt`, `sess_expire`, `max_time`) are picked from the `conf`
/// lists by `id` modulo each list's length, and `{c}` in dump file names is replaced by `id`.
/// The network reader (`handle_net`) and the message loop (`handle_msgs`) then run
/// concurrently; whichever finishes first ends the connection. Finally the client's
/// subscriptions are removed and its session, if any, is returned to the session store.
pub async fn start(id: usize,
                   mut socket: TcpStream,
                   subs: Arc<Mutex<Subs>>,
                   sessions: Arc<Mutex<Sessions>>,
                   dumps: Dump,
                   conf: Conf) {
    info!("C{}: Connection from {:?}", id, socket);
    let (read, write) = socket.split();
    let (sx, rx) = channel::<Msg>(10);
    let max_pkt = conf.max_pkt[id % conf.max_pkt.len()].unwrap_or(std::usize::MAX);
    let sess_expire = conf.sess_expire[id % conf.sess_expire.len()];
    let mut client = Client { id,
                              name: String::from(""),
                              addr: Addr(sx.clone(), id),
                              conn: false,
                              writer: FramedWrite::new(write, Codec(id)),
                              dumps,
                              ack_timeouts_conf: conf.ack_timeouts,
                              ack_timeout: conf.ack_timeouts.0.unwrap_or(FOREVER),
                              ack_delay: conf.ack_delay,
                              strict: conf.strict,
                              idprefix: conf.idprefix.clone(),
                              userpass: conf.userpass.clone(),
                              subs,
                              sessions,
                              session: None,
                              sess_expire,
                              qos1_check: None,
                              max_pkt,
                              max_pkt_delay: conf.max_pkt_delay,
                              count_pkt: 0 };
    // Keep the timer guard alive for the whole connection: it is aborted when we return,
    // instead of firing later at an already-closed connection.
    let _max_time = conf.max_time[id % conf.max_time.len()].map(|m| {
        client.addr.send_at_abort(Instant::now() + m,
                                  Msg::Disconnect(format!("max time {:?}", m)))
    });
    for s in conf.dump_files {
        let s = s.replace("{c}", &id.to_string());
        match client.dumps.register(&s) {
            Ok(_) => debug!("C{}: Dump to {}", id, s),
            Err(e) => error!("C{}: Cannot dump to {}: {}", id, s, e),
        }
    }
    let f1 = Self::handle_net(read, sx, client.id);
    let f2 = Self::handle_msgs(&mut client, rx);
    let res = futures::select!(r = f1.fuse() => r, r = f2.fuse() => r);
    warn!("C{}: Terminating: {:?}", id, res);
    // Take the two shared locks one after the other, never nested, so this cleanup cannot
    // deadlock against code that acquires them in the opposite order.
    client.subs.lock().await.del_all(&client);
    if let Some(sess) = client.session.take() {
        client.sessions.lock().await.close(&client, sess);
    }
}
/// Decode packets from the socket's read half and forward each to the message loop as
/// `Msg::PktIn`.
///
/// Returns `Ok("Connection closed")` at end of stream. Returns an error on the first decode or
/// IO failure, or when the message loop's receiver is gone.
async fn handle_net(read: ReadHalf<'_>,
                    mut sx: Sender<Msg>,
                    id: ConnId)
                    -> Result<&'static str, Error> {
    let mut packets = FramedRead::new(read, Codec(id));
    loop {
        let pkt = match packets.next().await {
            None => return Ok("Connection closed"),
            Some(decoded) => decoded?,
        };
        if let Err(e) = sx.send(Msg::PktIn(pkt)).await {
            return Err(Error::new(ErrorKind::Other, format!("while sending to self: {}", e)));
        }
    }
}
/// Dispatch each message from `receiver` to the matching handler on `client`, in order.
///
/// The first handler error ends the loop and is returned. The loop returns
/// `Ok("No more messages")` once every sender has been dropped.
async fn handle_msgs(client: &mut Client<'_>,
                     mut receiver: Receiver<Msg>)
                     -> Result<&'static str, Error> {
    loop {
        let msg = match receiver.next().await {
            Some(m) => m,
            None => break,
        };
        let outcome = match msg {
            Msg::PktIn(pkt) => client.handle_pkt_in(pkt).await,
            Msg::PktOut(pkt) => client.handle_pkt_out(pkt).await,
            Msg::Publish(qos, publ) => client.handle_publish(qos, publ).await,
            Msg::CheckQos => client.handle_check_qos(Instant::now()).await,
            Msg::Replaced(conn, chan) => client.handle_replaced(conn, chan),
            Msg::Disconnect(reason) => client.handle_disconnect(reason),
        };
        outcome?;
    }
    Ok("No more messages")
}
async fn handle_pkt_in(&mut self, pkt: Packet) -> Result<(), Error> {
info!("C{}: receive Packet::{:?}", self.id, pkt);
self.dumps.dump(self.id, &self.name, "C", &pkt).await;
self.count_pkt += 1;
match (pkt, self.conn) {
(Packet::Connect(c), false) => {
self.conn = true;
self.ack_timeout = match c.protocol {
Protocol::MQTT311 => self.ack_timeouts_conf.0.unwrap_or(FOREVER),
Protocol::MQIsdp => self.ack_timeouts_conf.0.unwrap_or(FOREVER),
};
self.name = c.client_id.clone();
if let Err((code, desc)) = self.check_credentials(&c) {
self.addr.send(Msg::PktOut(connack(false, code))).await;
return Err(Error::new(ErrorKind::ConnectionAborted, desc));
}
let mut sm = self.sessions.lock().await;
let mut sess = sm.open(&self, c.clean_session).await;
let isold = sess.cons > 0;
debug!("C{}: loaded {} session {:?}",
self.id,
if isold { "old" } else { "new" },
sess);
sess.cons += 1;
let mut subs = self.subs.lock().await;
for (topic, qos) in sess.subs.iter() {
subs.add(&topic, *qos, self.id, self.addr.clone());
}
self.session = Some(sess);
self.addr.send(Msg::CheckQos).await;
self.addr.send(Msg::PktOut(connack(isold, ConnectReturnCode::Accepted))).await;
},
(Packet::Disconnect, true) => {
self.conn = false;
return Err(Error::new(ErrorKind::ConnectionAborted, "Disconnect"));
},
(Packet::Pingreq, true) => self.addr.send(Msg::PktOut(pingresp())).await,
(Packet::Puback(pid), true) => {
let sess = self.session.as_mut().expect("unwrap session");
if sess.qos1.remove(&pid).is_none() {
return Err(Error::new(ErrorKind::InvalidData,
format!("Puback {:?} unexpected", pid)));
}
},
(Packet::Publish(p), true) => {
if let Some(subs) = self.subs.lock().await.get(&p.topic_name) {
for s in subs.values() {
s.addr.send(Msg::Publish(s.qos, p.clone())).await;
}
}
match p.qospid {
QosPid::AtMostOnce => (),
QosPid::AtLeastOnce(pid) => {
let d = Instant::now() + self.ack_delay;
self.addr.send_at(d, Msg::PktOut(puback(pid)));
},
QosPid::ExactlyOnce(_) => panic!("ExactlyOnce not supported yet"),
}
},
(Packet::Subscribe(Subscribe { pid, topics }), true) => {
let mut subs = self.subs.lock().await;
let sess = self.session.as_mut().expect("unwrap session");
let mut codes = Vec::new();
for SubscribeTopic { topic_path, qos } in topics {
assert_ne!(QoS::ExactlyOnce, qos, "ExactlyOnce not supported yet");
subs.add(&topic_path, qos, self.id, self.addr.clone());
sess.subs.insert(topic_path.clone(), qos);
codes.push(SubscribeReturnCodes::Success(qos));
}
let d = Instant::now() + self.ack_delay;
self.addr.send_at(d, Msg::PktOut(suback(pid, codes)));
},
(other, _) => {
return Err(Error::new(ErrorKind::InvalidData, format!("Unhandled {:?}", other)))
},
}
if self.count_pkt >= self.max_pkt {
let reason = format!("max packets {:?} {:?}", self.max_pkt, self.max_pkt_delay);
match self.max_pkt_delay {
Some(d) => self.addr.send_at(Instant::now() + d, Msg::Disconnect(reason)),
None => self.addr.send(Msg::Disconnect(reason)).await,
}
}
Ok(())
}
/// Log and dump `pkt` as server-sent ("S"), then write it to the socket and flush.
async fn handle_pkt_out(&mut self, pkt: Packet) -> Result<(), Error> {
    info!("C{}: send Packet::{:?}", self.id, pkt);
    self.dumps.dump(self.id, &self.name, "S", &pkt).await;
    self.writer.send(pkt).await?;
    match self.writer.flush().await {
        Ok(()) => Ok(()),
        Err(e) => Err(e.into()),
    }
}
/// Send a publication to this client at the subscription's `qos`.
///
/// For QoS1 a fresh packet id is allocated. The packet is recorded in the session's `qos1`
/// table with deadline `now + ack_timeout`, and a `CheckQos` timer is armed if none is pending
/// and the timeout is finite.
///
/// # Panics
/// If no session is open, if `qos` is ExactlyOnce (not supported yet), or if the chosen
/// packet id is already in use.
async fn handle_publish(&mut self, qos: QoS, p: Publish) -> Result<(), Error> {
    let sess = self.session.as_mut().expect("unwrap session");
    let (qospid, tracked) = match qos {
        QoS::AtMostOnce => (QosPid::AtMostOnce, None),
        QoS::AtLeastOnce => {
            let pid = sess.next_pid();
            (QosPid::AtLeastOnce(pid), Some(pid))
        },
        QoS::ExactlyOnce => panic!("ExactlyOnce not supported yet"),
    };
    let pkt = publish(false, qospid, false, p.topic_name, p.payload);
    if let Some(pid) = tracked {
        let deadline = Instant::now() + self.ack_timeout;
        debug!("C{}: waiting for {:?} + {:?}@{:?}", self.id, sess.qos1, pid, deadline);
        let prev = sess.qos1.insert(pid, (deadline, pkt.clone()));
        assert!(prev.is_none(), "C{}: Server error: reusing {:?} {:?}", self.id, pid, prev);
        let need_timer = self.qos1_check.is_none() && self.ack_timeout < FOREVER;
        if need_timer {
            self.qos1_check = Some(self.addr.send_at_abort(deadline, Msg::CheckQos));
        }
    }
    self.handle_pkt_out(pkt).await
}
/// Resend the session's QoS1 publishes whose ack deadline is at or before `reftime`, then
/// re-arm the `CheckQos` timer.
///
/// Resent packets stay in the `qos1` table, because a Puback is still expected for them.
/// Their deadline moves to `reftime + ack_timeout`. The next check is scheduled at the
/// earliest remaining deadline. When nothing is pending, the timer guard is cleared so that
/// `handle_publish` arms a fresh timer for the next QoS1 message.
///
/// # Panics
/// If no session is open.
async fn handle_check_qos(&mut self, reftime: Instant) -> Result<(), Error> {
    let sess = self.session.as_mut().expect("unwrap session");
    trace!("C{}: check Qos acks {:?}", self.id, sess.qos1);
    let id = self.id;
    let addr = self.addr.clone();
    let retry_deadline = reftime + self.ack_timeout;
    for (pid, (deadline, pkt)) in sess.qos1.iter_mut() {
        // Only packets whose deadline has passed are due; packets still within their
        // timeout are left alone.
        if *deadline <= reftime {
            warn!("C{}: Timeout receiving ack {:?}, resending packet", id, pid);
            // Queue through our own channel rather than writing directly, so resends issued
            // at CONNECT time go out after the CONNACK that is queued behind this check.
            addr.send(Msg::PktOut(pkt.clone())).await;
            *deadline = retry_deadline;
        }
    }
    let earliest = sess.qos1.values().map(|(deadline, _)| *deadline).min();
    // `None` when nothing is pending, so a stale guard never blocks re-arming the timer.
    self.qos1_check = earliest.map(|deadline| addr.send_at_abort(deadline, Msg::CheckQos));
    Ok(())
}
/// Hand this connection's session to the replacing connection `conn` through `chan`, then
/// end this connection with a `ConnectionReset` error.
///
/// A closed `chan` (the other side stopped waiting) is only traced.
///
/// # Panics
/// If no session is open.
fn handle_replaced(&mut self,
                   conn: ConnId,
                   chan: oneshot::Sender<SessionData>)
                   -> Result<(), Error> {
    info!("C{}: replaced by connection {}", self.id, conn);
    self.conn = false;
    let sess = self.session.take().unwrap();
    if chan.send(sess).is_err() {
        trace!("C{}: C{} didn't wait for the session", self.id, conn);
    }
    Err(Error::new(ErrorKind::ConnectionReset, "Replaced"))
}
/// Server-initiated disconnect: mark the client as disconnected and end the connection with
/// a `ConnectionReset` error carrying `reason`.
fn handle_disconnect(&mut self, reason: String) -> Result<(), Error> {
    info!("C{}: Disconnect by server: {:?}", self.id, reason);
    self.conn = false;
    let err = Error::new(ErrorKind::ConnectionReset, reason);
    Err(err)
}
fn check_credentials(&mut self,
con: &Connect)
-> Result<(), (ConnectReturnCode, &'static str)> {
let allow = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
if self.name.len() > 23 || !self.name.chars().all(|c| allow.contains(c)) {
if self.strict {
return Err((ConnectReturnCode::RefusedIdentifierRejected,
"Client_id too long or bad charset [MQTT-3.1.3-8]"));
}
warn!("C{}: Servers MAY reject {:?} [MQTT-3.1.3-5/MQTT-3.1.3-6]", self.id, self.name);
}
if self.name.is_empty() {
if !con.clean_session {
return Err((ConnectReturnCode::RefusedIdentifierRejected,
"Empty client_id with session [MQTT-3.1.3-8]"));
}
let mut rng = thread_rng();
for _ in 0..20 {
self.name.push(*allow.as_bytes().choose(&mut rng).unwrap() as char);
}
info!("C{}: Unamed client, assigned random name {:?}", self.id, self.name);
}
if con.password.is_some() && con.username.is_none() {
return Err((ConnectReturnCode::BadUsernamePassword,
"Password without a username [MQTT-3.1.2-22]"));
}
if let Some(ref req_up) = self.userpass {
let con_up = format!("{}:{:?}",
con.username.as_ref().unwrap_or(&String::new()),
con.password.as_ref().unwrap_or(&Vec::new()));
if &con_up != req_up {
return Err((ConnectReturnCode::BadUsernamePassword,
"Bad username/password [MQTT-3.1.3.4/3.1.3.5]"));
}
}
if !self.name.starts_with(&self.idprefix) {
return Err((ConnectReturnCode::NotAuthorized, "Not Authorised [MQTT-5.4.2]"));
}
Ok(())
}
}