use std::sync::Arc;
use std::marker::PhantomData;
use std::collections::HashMap;
use std::io::{Cursor, Result, ErrorKind, Error};
use std::time::Duration;
use futures::future::{FutureExt, LocalBoxFuture};
use mio::Token;
use fnv::FnvBuildHasher;
use httparse::Request;
use mqtt311::{MqttWrite, MqttRead, Protocol, ConnectReturnCode, Packet, Connect,
Connack, QoS, Publish, PacketIdentifier, SubscribeTopic, SubscribeReturnCodes,
Subscribe, Suback, Unsubscribe, TopicPath};
use log::warn;
use pi_atom::Atom;
use tcp::{Socket, SocketEvent, SocketHandle,
tls_connect::TlsSocket,
utils::Ready};
use ws::{connect::WsSocket,
utils::{ChildProtocol, WsFrameType, WsSession}};
use crate::{server::MqttBrokerProtocol,
broker::{Retain, MqttBroker},
session::{MqttConnect, MqttSession, QosZeroSession},
utils::BrokerSession};
/// MQTT 3.1 / 3.1.1 broker protocol served as a websocket sub-protocol over a TLS socket.
///
/// Fields:
/// - `is_strict`: when set, the websocket handshake must offer `protocol_name` as its first
///   sub-protocol, and `decode_protocol` skips the first 4 bytes of every message.
/// - `protocol_name`: websocket sub-protocol name (stored lower-cased by `with_name`).
/// - `broker_name`: broker name returned by `get_broker_name`.
/// - `qos`: highest QoS accepted for PUBLISH and the QoS reported in SUBACK success codes.
/// - `broker`: the topic/session broker used by every packet handler.
#[derive(Clone)]
pub struct WssMqtt311 {
is_strict: bool, protocol_name: String, broker_name: String, qos: QoS, broker: MqttBroker<TlsSocket>, }
// SAFETY: NOTE(review): no justification is visible in this file. These impls assert that
// every field, including `MqttBroker<TlsSocket>`, may be sent to and shared between threads.
// Confirm that `MqttBroker` and `TlsSocket` are actually thread-safe, or remove these impls.
unsafe impl Send for WssMqtt311 {}
unsafe impl Sync for WssMqtt311 {}
impl ChildProtocol<TlsSocket> for WssMqtt311 {
    /// Returns the websocket sub-protocol name this handler serves.
    fn protocol_name(&self) -> &str {
        self.protocol_name.as_str()
    }

    /// Validates the sub-protocols offered by the client during the websocket handshake.
    ///
    /// A non-strict handler accepts every handshake. A strict handler requires the client to
    /// offer at least one sub-protocol, and the first one must equal `protocol_name`.
    ///
    /// # Errors
    /// `ErrorKind::Interrupted` when a strict handshake offers no sub-protocol or a different
    /// first one.
    fn handshake_protocol(&self,
                          _handle: SocketHandle<TlsSocket>,
                          _request: &Request,
                          protocols: &Vec<&str>) -> Result<()> {
        if !self.is_strict {
            return Ok(());
        }

        match protocols.first() {
            None => Err(Error::new(ErrorKind::Interrupted, "Handshake failed by strict")),
            Some(&first) if first != self.protocol_name.as_str() => {
                Err(Error::new(ErrorKind::Interrupted, "Invalid handshake protocol"))
            },
            Some(_) => Ok(()),
        }
    }

    /// Returns whether this handler was built in strict mode.
    fn is_strict(&self) -> bool {
        self.is_strict
    }

    /// Decodes the next websocket message as one MQTT packet and dispatches it.
    ///
    /// In strict mode the first 4 bytes of the message are skipped before decoding
    /// (NOTE(review): the meaning of that prefix is not visible in this file). CONNECT,
    /// PUBLISH, SUBSCRIBE, UNSUBSCRIBE, PINGREQ and DISCONNECT are handled. Every other packet
    /// type is ignored.
    ///
    /// # Errors
    /// The returned future resolves to `ErrorKind::Other` in two cases:
    /// - the message is shorter than the strict-mode prefix;
    /// - the message does not decode as an MQTT packet.
    ///
    /// Otherwise it resolves to the error of the packet handler.
    fn decode_protocol(&self,
                       connect: WsSocket<TlsSocket>,
                       context: &mut WsSession) -> LocalBoxFuture<'static, Result<()>> {
        let ws_mqtt = self.clone();
        let packet: Vec<u8> = context.pop_msg();

        // `get` instead of `&packet[4..]`: a short message sent by the peer must produce an
        // error, not a panic.
        let offset: usize = if self.is_strict { 4 } else { 0 };
        let mut packet_slice = match packet.get(offset..) {
            Some(slice) => slice,
            None => {
                return async move {
                    Err(Error::new(ErrorKind::Other,
                                   format!("Websocket decode child protocol failed, protocol: {:?}, reason: packet too short",
                                           ws_mqtt.protocol_name)))
                }.boxed_local();
            },
        };

        let decoded = match packet_slice.read_packet() {
            Ok(decoded) => decoded,
            Err(e) => {
                return async move {
                    Err(Error::new(ErrorKind::Other,
                                   format!("Websocket decode child protocol failed, protocol: {:?}, reason: {:?}",
                                           ws_mqtt.protocol_name,
                                           e)))
                }.boxed_local();
            },
        };

        async move {
            match decoded {
                Packet::Connect(packet) => accept(ws_mqtt, connect, packet).await,
                Packet::Publish(packet) => publish(ws_mqtt, connect, packet).await,
                Packet::Subscribe(packet) => subscribe(ws_mqtt, connect, packet).await,
                Packet::Unsubscribe(packet) => unsubscribe(ws_mqtt, connect, packet).await,
                Packet::Pingreq => ping_req(connect),
                Packet::Disconnect => disconnect(connect),
                _ => Ok(()),
            }
        }.boxed_local()
    }

    /// Releases the MQTT state attached to a closing connection.
    ///
    /// Takes the `BrokerSession` stored by `accept` out of the websocket context. If the
    /// broker still holds a session for that client id, then:
    /// - a clean session is unsubscribed from all topics and removed from the broker;
    /// - the session is marked as not accepted;
    /// - if a broker listener is registered, its `closed` future is returned.
    ///
    /// In every other case an empty future is returned.
    fn close_protocol(&self,
                      connect: WsSocket<TlsSocket>,
                      mut context: WsSession,
                      reason: Result<()>) -> LocalBoxFuture<'static, ()> {
        match context.get_context_mut().remove::<BrokerSession>() {
            Err(e) => {
                warn!("Free Context Failed of Mqtt Close, token: {:?}, remote: {:?}, local: {:?}, reason: {:?}",
                      connect.get_token(),
                      connect.get_remote(),
                      connect.get_local(),
                      e);
            },
            Ok(opt) => {
                if let Some(context) = opt {
                    let client_id = context.get_client_id();
                    if let Some(session) = self.broker.get_session(client_id) {
                        if session.is_clean() {
                            // A clean session is not kept once its connection is gone.
                            self.broker.unsubscribe_all(&session);
                            self.broker.remove_session(client_id);
                        }
                        session.set_accept(false);
                        if let Some(listener) = self.broker.get_listener() {
                            return listener
                                .0
                                .closed(MqttBrokerProtocol::WssMqtt311(Arc::new(self.clone())),
                                        session,
                                        context,
                                        reason);
                        }
                    }
                }
            },
        }

        async move {}.boxed_local()
    }

    /// Handles an expired connection timer.
    ///
    /// The timer is presumably the one armed by `update_timeout`. This handler sends a
    /// DISCONNECT packet to the client and logs the timeout, including the client id carried
    /// in `event`.
    ///
    /// NOTE(review): MQTT 3.1.1 defines DISCONNECT as client-to-server only. This handler also
    /// does not close the connection itself; presumably the caller does. Confirm both.
    ///
    /// # Errors
    /// `ErrorKind::BrokenPipe` if the DISCONNECT packet cannot be sent.
    fn protocol_timeout(&self,
                        connect: WsSocket<TlsSocket>,
                        _context: &mut WsSession,
                        mut event: SocketEvent) -> LocalBoxFuture<'static, Result<()>> {
        async move {
            if let Err(e) = send_packet(&connect, &Packet::Disconnect) {
                return Err(Error::new(ErrorKind::BrokenPipe,
                                      format!("Mqtt send failed by disconnect, reason: {:?}",
                                              e)));
            }

            warn!("Mqtt Session Timeout, token: {:?}, local: {:?}, remote: {:?}, client_id: {:?}",
                  connect.get_token(),
                  connect.get_local(),
                  connect.get_remote(),
                  event.remove::<String>());
            Ok(())
        }.boxed_local()
    }
}
/// (Re)arms the connection timer at 1.5 times the MQTT keep-alive.
///
/// `keep_alive` is the CONNECT keep-alive (seconds, per MQTT). The timer value passed to
/// `set_timeout` is presumably in milliseconds (TODO confirm). The timer event carries
/// `client_id` as a `String` so the timeout handler can report it.
///
/// NOTE(review): a keep-alive of 0 means "disabled" in MQTT, but here it arms a timer of 0.
/// Confirm how `set_timeout` treats 0.
fn update_timeout(connect: &WsSocket<TlsSocket>,
                  client_id: String,
                  keep_alive: u16) {
    let timeout_ms = usize::from(keep_alive) * 1500;

    let mut timeout_event = SocketEvent::empty();
    timeout_event.set::<String>(client_id);

    connect.set_timeout(timeout_ms, timeout_event);
}
/// Encodes `packet` and sends it on `connect` as a single `WS_MSG_TYPE` websocket frame.
///
/// # Errors
/// - `ErrorKind::InvalidData` if the packet cannot be encoded.
/// - Otherwise, whatever `WsSocket::send` returns.
fn send_packet(connect: &WsSocket<TlsSocket>,
               packet: &Packet) -> Result<()> {
    let mut encoded = Cursor::new(Vec::<u8>::new());
    encoded
        .write_packet(packet)
        .map_err(|e| Error::new(ErrorKind::InvalidData,
                                format!("Mqtt send failed, reason: {:?}",
                                        e)))?;

    connect.send(WssMqtt311::WS_MSG_TYPE, encoded.into_inner())
}
/// Encodes `packet` once and broadcasts the resulting frame to `connects`.
///
/// The function returns `Ok(())` early in two cases:
/// - `connects` is empty, in which case nothing is encoded;
/// - every connection is already closed, in which case nothing is sent.
///
/// Otherwise the whole slice is passed to `WsSocket::broadcast`, closed connections included.
/// NOTE(review): this presumes `broadcast` tolerates closed sockets; confirm.
///
/// # Errors
/// - `ErrorKind::InvalidData` if the packet cannot be encoded.
/// - Otherwise, whatever `WsSocket::broadcast` returns.
pub fn broadcast_packet(connects: &[WsSocket<TlsSocket>],
                        packet: &Packet) -> Result<()> {
    if connects.is_empty() {
        return Ok(());
    }

    let mut encoded: Cursor<Vec<u8>> = Cursor::new(Vec::new());
    encoded
        .write_packet(packet)
        .map_err(|e| Error::new(ErrorKind::InvalidData,
                                format!("Mqtt send failed, reason: {:?}",
                                        e)))?;

    let has_open_connect = connects.iter().any(|connect| !connect.is_closed());
    if !has_open_connect {
        return Ok(());
    }

    WsSocket::<TlsSocket>::broadcast(connects,
                                     WssMqtt311::WS_MSG_TYPE,
                                     encoded.into_inner())
}
/// Handles an MQTT CONNECT packet received on `connect`.
///
/// Steps, in order:
/// 1. Reject the connect if a session for the same client id is already accepted.
/// 2. Refuse protocol levels other than MQTT 3.1 / 3.1.1.
/// 3. Store a `BrokerSession` in the websocket session context.
/// 4. (Re)create the broker session via `reset_session`.
/// 5. Ask the broker listener, if any, to authorize the connect.
/// 6. Answer with a CONNACK.
///
/// Session Present is only reported for an accepted connect that resumes an existing session
/// without the clean-session flag. A refusing CONNACK always reports 0 (MQTT-3.2.2-4).
///
/// # Errors
/// - `ErrorKind::AlreadyExists` if the client id is bound to an accepted session. No CONNACK
///   is sent in this case.
/// - `ErrorKind::ConnectionAborted` after a `RefusedProtocolVersion` CONNACK.
/// - `ErrorKind::PermissionDenied` after a `NotAuthorized` CONNACK.
/// - `ErrorKind::BrokenPipe` if a CONNACK cannot be sent.
async fn accept(protocol: WssMqtt311,
                connect: WsSocket<TlsSocket>,
                packet: Connect) -> Result<()> {
    let client_id = packet.client_id.clone();
    let is_exist_session = protocol.is_exist(&client_id);
    if is_exist_session {
        if let Some(session) = protocol.broker.get_session(&client_id) {
            if session.is_accepted() {
                return Err(Error::new(ErrorKind::AlreadyExists,
                                      "Mqtt connect failed, reason: connect already exist"));
            }
        }
    }

    // Existing session state is only resumed when the client asks to keep it and it exists.
    let session_present = !packet.clean_session && is_exist_session;

    let level = packet.protocol.level();
    if level != WssMqtt311::MQTT31 && level != WssMqtt311::MQTT311 {
        // MQTT-3.2.2-4: a CONNACK with a non-zero return code must set Session Present to 0.
        let ack = Connack {
            session_present: false,
            code: ConnectReturnCode::RefusedProtocolVersion,
        };
        if let Err(e) = send_packet(&connect, &Packet::Connack(ack)) {
            return Err(Error::new(ErrorKind::BrokenPipe,
                                  format!("Mqtt send failed by connect, reason: {:?}",
                                          e)));
        }
        return Err(Error::new(ErrorKind::ConnectionAborted,
                              "Mqtt connect failed, reason: invalid protocol version"));
    }

    if let Some(mut handle) = connect.get_session() {
        if let Some(ws_session) = handle.as_mut() {
            ws_session
                .get_context_mut()
                .set::<BrokerSession>(
                    BrokerSession::new(
                        client_id.clone(),
                        packet.keep_alive,
                        packet.clean_session,
                        packet.username.clone(),
                        packet.password.clone()));
        }
    }

    let mqtt_connect = reset_session(&protocol, connect.clone(), client_id, packet);

    // Without a registered listener every well-formed connect is accepted.
    let authorized = match protocol.broker.get_listener() {
        Some(listener) => listener
            .0
            .connected(MqttBrokerProtocol::WssMqtt311(Arc::new(protocol)),
                       mqtt_connect)
            .await
            .is_ok(),
        None => true,
    };

    let code = if authorized {
        ConnectReturnCode::Accepted
    } else {
        ConnectReturnCode::NotAuthorized
    };
    let ack = Connack {
        // MQTT-3.2.2-4: a refused connect never reports a present session.
        session_present: authorized && session_present,
        code,
    };
    if let Err(e) = send_packet(&connect, &Packet::Connack(ack)) {
        return Err(Error::new(ErrorKind::BrokenPipe,
                              format!("Mqtt send failed by connect, reason: {:?}",
                                      e)));
    }

    if !authorized {
        // MQTT-3.2.2-5: the connection must be closed after a non-zero CONNACK. Returning an
        // error mirrors the protocol-version refusal above, and the caller presumably closes
        // the connection on error.
        //
        // Returning Ok would leave the BrokerSession context installed. `get_client_context`
        // would then keep succeeding, and the client could still publish and subscribe.
        return Err(Error::new(ErrorKind::PermissionDenied,
                              "Mqtt connect failed, reason: not authorized"));
    }

    Ok(())
}
fn reset_session(protocol: &WssMqtt311,
connect: WsSocket<TlsSocket>,
client_id: String,
packet: Connect) -> Arc<QosZeroSession<TlsSocket>> {
let mut session = QosZeroSession::with_client_id(client_id.clone());
session.bind_connect(connect.clone());
session.set_accept(true);
session.set_clean(packet.clean_session);
if let Some(w) = &packet.last_will {
session.set_will(w.topic.clone(), w.message.clone(), w.qos.to_u8(), w.retain);
}
session.set_user_pwd(packet.username, packet.password);
session.set_keep_alive(packet.keep_alive);
update_timeout(&connect, packet.client_id, packet.keep_alive);
protocol.broker.insert_session(client_id, session)
}
/// Reads the client id and keep-alive that `accept` stored on this connection.
///
/// # Errors
/// `ErrorKind::ConnectionRefused` when the connection has no websocket session, or its
/// context holds no `BrokerSession` (no CONNECT has been accepted on it). The first case used
/// to panic through `unwrap`, on a path driven by network input.
#[inline(always)]
fn get_client_context(connect: &WsSocket<TlsSocket>) -> Result<(String, u16)> {
    let ws_session = match connect.get_session() {
        Some(ws_session) => ws_session,
        None => {
            return Err(Error::new(ErrorKind::ConnectionRefused,
                                  "Mqtt subscribe failed, reason: invalid connect"));
        },
    };

    // The lookup is a statement followed by an explicit `return`. This drops the scrutinee's
    // temporaries before `ws_session` goes out of scope.
    if let Some(handle) = ws_session
        .as_ref()
        .get_context()
        .get::<BrokerSession>() {
        let h = handle.as_ref();
        return Ok((h.get_client_id().clone(), h.get_keep_alive()));
    }

    Err(Error::new(ErrorKind::ConnectionRefused,
                   "Mqtt subscribe failed, reason: invalid connect"))
}
async fn publish(protocol: WssMqtt311,
connect: WsSocket<TlsSocket>,
mut packet: Publish) -> Result<()> {
let (client_id, keep_alive) = match get_client_context(&connect) {
Err(e) => {
return Err(e);
}
Ok(r) => {
r
}
};
update_timeout(&connect, client_id.clone(), keep_alive);
let qos = packet.qos.to_u8();
if qos > protocol.get_qos().to_u8() {
return Err(Error::new(ErrorKind::InvalidData,
"Mqtt publish failed, reason: invalid qos"));
}
let topic_path = if let Ok(path) = TopicPath::from_str(packet.topic_name.as_str()) {
path
} else {
return Err(Error::new(ErrorKind::InvalidData,
format!("Mqtt publish failed, topic: {:?}, reason: parse topic error",
packet.topic_name)));
};
if topic_path.wildcards || topic_path.path.is_empty() {
return Err(Error::new(ErrorKind::InvalidData,
"Mqtt publish failed, reason: invalid topic"));
}
packet.dup = false;
let mut retain = None;
if packet.retain {
retain = Some(packet.clone());
}
if let Some(service) = protocol.broker.get_service() {
if let Some(session) = protocol.broker.get_session(&client_id) {
let is_passive_receive = session.is_passive_receive();
let result = service
.0
.publish(MqttBrokerProtocol::WssMqtt311(Arc::new(protocol)),
session.clone(),
packet.topic_name.clone(),
packet.payload).await;
if is_passive_receive {
if let Some(hibernate) = session.hibernate(Ready::Writable) {
return hibernate.await;
} else {
return Ok(());
}
}
return result;
}
}
let mut is_public = true;
if let Some(sessions) = protocol
.broker
.subscribed(is_public, &topic_path.path, qos, retain) {
let mut connects: Vec<WsSocket<TlsSocket>> = Vec::with_capacity(sessions.len());
for session in sessions {
if let Some(connect) = session.get_connect() {
connects.push(connect.clone());
}
};
if let Err(e) = broadcast_packet(&connects[..], &Packet::Publish(packet)) {
return Err(Error::new(ErrorKind::BrokenPipe,
format!("Mqtt broadcast failed by publish, reason: {:?}",
e)));
}
}
Ok(())
}
async fn subscribe(protocol: WssMqtt311,
connect: WsSocket<TlsSocket>,
packet: Subscribe) -> Result<()> {
let (client_id, keep_alive) = match get_client_context(&connect) {
Err(e) => {
return Err(e);
}
Ok(r) => {
r
}
};
update_timeout(&connect, client_id.clone(), keep_alive);
let pkid = packet.pkid;
let topics = packet.topics;
let mut retains: Option<Retain> = None;
let mut return_codes = Vec::with_capacity(topics.len()); if let Some(service) = protocol.broker.get_service() {
if let Some(session) = protocol.broker.get_session(&client_id) {
let mut sub_topics = Vec::with_capacity(topics.len());
for topic in &topics {
sub_topics.push((topic.topic_path.clone(), topic.qos.to_u8()));
}
if let Err(_) = service
.0
.subscribe(MqttBrokerProtocol::WssMqtt311(
Arc::new(protocol.clone())),
session,
sub_topics).await {
for _ in topics {
return_codes.push(SubscribeReturnCodes::Failure);
}
} else {
let qos = protocol.get_qos();
for _ in topics {
return_codes.push(SubscribeReturnCodes::Success(qos));
}
}
}
} else {
if let Some(session) = protocol.broker.get_session(&client_id) {
let qos = protocol.get_qos();
let qos_val = qos.to_u8();
for topic in topics {
if qos_val < topic.qos.to_u8() {
return_codes.push(SubscribeReturnCodes::Failure);
} else {
return_codes.push(SubscribeReturnCodes::Success(qos));
}
retains = protocol
.broker
.subscribe(session.clone(),
topic.topic_path);
}
}
}
let ack = Suback {
pkid,
return_codes,
};
if let Err(e) = send_packet(&connect, &Packet::Suback(ack)) {
return Err(Error::new(ErrorKind::BrokenPipe,
format!("Mqtt send failed by subscribe, reason: {:?}",
e)));
}
if let Some(r) = retains {
match r {
Retain::Single(p) => {
send_packet(&connect, &Packet::Publish(p));
},
Retain::Mutil(ps) => {
for p in ps {
send_packet(&connect, &Packet::Publish(p));
}
},
}
}
Ok(())
}
/// Handles an MQTT UNSUBSCRIBE packet from the client bound to `connect`.
///
/// Processing, in order:
/// 1. Refresh the connection timer.
/// 2. Remove the listed topic filters: through the broker service when one is registered,
///    otherwise directly on the broker.
/// 3. Answer with UNSUBACK.
///
/// # Errors
/// - The error from `get_client_context`.
/// - `ErrorKind::BrokenPipe` when the service rejects the unsubscribe. No UNSUBACK is sent in
///   this case.
/// - `ErrorKind::BrokenPipe` when UNSUBACK cannot be sent.
async fn unsubscribe(protocol: WssMqtt311,
                     connect: WsSocket<TlsSocket>,
                     packet: Unsubscribe) -> Result<()> {
    let (client_id, keep_alive) = get_client_context(&connect)?;
    update_timeout(&connect, client_id.clone(), keep_alive);

    let pkid = packet.pkid;
    let topics = packet.topics;
    if let Some(service) = protocol.broker.get_service() {
        if let Some(session) = protocol.broker.get_session(&client_id) {
            // `topics` is not needed afterwards on this branch, so it is moved, not cloned.
            service
                .0
                .unsubscribe(MqttBrokerProtocol::WssMqtt311(Arc::new(protocol)),
                             session,
                             topics)
                .await
                .map_err(|e| Error::new(ErrorKind::BrokenPipe,
                                        format!("Mqtt unsubscribe failed, reason: {:?}",
                                                e)))?;
        }
    } else if let Some(session) = protocol.broker.get_session(&client_id) {
        for topic in topics {
            protocol.broker.unsubscribe(&session, topic);
        }
    }

    send_packet(&connect, &Packet::Unsuback(pkid))
        .map_err(|e| Error::new(ErrorKind::BrokenPipe,
                                format!("Mqtt send failed by unsubscribe, reason: {:?}",
                                        e)))?;
    Ok(())
}
/// Handles an MQTT PINGREQ: re-arms the connection timer and replies with PINGRESP.
///
/// # Errors
/// - The error from `get_client_context`.
/// - `ErrorKind::BrokenPipe` if PINGRESP cannot be sent.
fn ping_req(connect: WsSocket<TlsSocket>) -> Result<()> {
    let (client_id, keep_alive) = get_client_context(&connect)?;
    update_timeout(&connect, client_id, keep_alive);

    send_packet(&connect, &Packet::Pingresp)
        .map_err(|e| Error::new(ErrorKind::BrokenPipe,
                                format!("Mqtt send failed by ping, reason: {:?}",
                                        e)))
}
/// Handles an MQTT DISCONNECT by closing the connection with a successful close reason.
fn disconnect(connect: WsSocket<TlsSocket>) -> Result<()> {
    let normal_close: Result<()> = Ok(());
    connect.close(normal_close)
}
impl WssMqtt311 {
    /// Websocket frame type used for every MQTT packet sent by this protocol.
    pub const WS_MSG_TYPE: WsFrameType = WsFrameType::Binary;
    /// CONNECT protocol level of MQTT 3.1.
    pub const MQTT31: u8 = 0x3;
    /// CONNECT protocol level of MQTT 3.1.1.
    pub const MQTT311: u8 = 0x4;
    /// NOTE(review): not read anywhere in this file. Presumably the highest QoS this
    /// implementation supports, since it only builds `QosZeroSession`s; confirm with other users.
    pub const MAX_QOS: u8 = 0;

    /// Creates a handler for the websocket sub-protocol `protocol_name`, backed by a new broker.
    ///
    /// `protocol_name` is stored lower-cased. `qos` is validated first, so no broker is created
    /// or started for an invalid value. The broker's unsubscribed-topic expiry loop is started
    /// with a 60 000 ms `Duration`.
    ///
    /// # Panics
    /// Panics if `qos` is rejected by `QoS::from_u8`.
    pub fn with_name(protocol_name: &str,
                     broker_name: &str,
                     qos: u8,
                     is_strict: bool) -> Self {
        let qos = QoS::from_u8(qos).expect("WssMqtt311::with_name: invalid MQTT QoS level");

        let broker = MqttBroker::new();
        broker.startup_expire_unsubscribed_topic_loop(Duration::from_millis(60000));

        WssMqtt311 {
            is_strict,
            protocol_name: protocol_name.to_lowercase(),
            broker_name: broker_name.to_string(),
            qos,
            broker,
        }
    }

    /// Returns whether the broker holds a session for `client_id`.
    #[inline(always)]
    pub fn is_exist(&self, client_id: &str) -> bool {
        self.broker.is_exist_session(&client_id.to_string())
    }

    /// Returns the broker name given to `with_name`.
    pub fn get_broker_name(&self) -> &str {
        self.broker_name.as_str()
    }

    /// Returns the configured QoS: the highest accepted for PUBLISH, and the one granted in
    /// SUBACK success codes.
    #[inline(always)]
    pub fn get_qos(&self) -> QoS {
        self.qos
    }

    /// Returns the underlying broker.
    pub fn get_broker(&self) -> &MqttBroker<TlsSocket> {
        &self.broker
    }
}