use std::sync::Arc;
use std::io::{Cursor, Result, ErrorKind, Error};
use std::time::Duration;
use futures::future::{FutureExt, LocalBoxFuture};
use mqtt311::{MqttWrite, MqttRead, ConnectReturnCode, Packet, Connect,
Connack, QoS, Publish, SubscribeReturnCodes,
Subscribe, Suback, Unsubscribe, TopicPath, Error as MqttError};
use quinn_proto::{StreamId, Dir};
use bytes::Buf;
use log::error;
use quic::{AsyncService, SocketHandle, SocketEvent,
connect::QuicSocket,
utils::QuicSocketReady};
use crate::{server::MqttBrokerProtocol,
quic_broker::{Retain, MqttBroker},
quic_session::{MqttConnect, MqttSession, QosZeroSession},
utils::QuicBrokerSession};
/// MQTT 3.1 / 3.1.1 protocol handler that runs on top of QUIC sockets.
///
/// A clone of this value is captured by every socket callback and packet handler.
#[derive(Clone)]
pub struct QuicMqtt311 {
    /// Name of this broker instance, returned by `get_broker_name`.
    broker_name: String,
    /// Highest QoS a client may use on PUBLISH; higher values are rejected.
    qos: QoS,
    /// Broker state: sessions, subscriptions, retained messages and the
    /// optional listener / service hooks.
    broker: MqttBroker,
}

// SAFETY: NOTE(review): no justification for these impls is visible here. They
// assert that every field (notably `MqttBroker` and whatever it holds) is safe to
// move and share across threads — confirm before relying on it.
unsafe impl Send for QuicMqtt311 {}
unsafe impl Sync for QuicMqtt311 {}
/// Socket-event callbacks that connect the QUIC transport to the MQTT packet
/// handlers (`accept`, `publish`, `subscribe`, `unsubscribe`, `ping_req`, `disconnect`).
impl AsyncService for QuicMqtt311 {
/// Called when a QUIC connection attempt completes.
///
/// On failure the error is logged and nothing else happens. On success the main
/// stream is marked `Readable`; presumably this is what schedules the first
/// `handle_readed` call — TODO confirm against the `quic` crate.
fn handle_connected(&self,
handle: SocketHandle,
result: Result<()>) -> LocalBoxFuture<'static, ()> {
async move {
if let Err(e) = result {
error!("Connect failed for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
e);
return;
}
// Panics if the connection has no main stream.
handle.set_ready(handle.get_main_stream_id().unwrap().clone(),
QuicSocketReady::Readable); }.boxed_local()
}
/// Extra (non-main) streams are not used by this protocol; this is a no-op.
fn handle_opened_expanding_stream(&self,
_handle: SocketHandle,
_stream_id: StreamId,
_stream_type: Dir,
_result: Result<()>) -> LocalBoxFuture<'static, ()> {
async move {
}.boxed_local()
}
/// Reads and dispatches MQTT packets from the connection's main stream.
///
/// Loops: waits for buffered bytes if none are pending, decodes one packet,
/// consumes its bytes, and dispatches it to the matching handler. A handler
/// error or an unrecoverable decode error is logged and the connection is closed.
/// `stream_id` is used only in log messages; data is always read from the main
/// stream.
fn handle_readed(&self,
handle: SocketHandle,
stream_id: StreamId,
result: Result<usize>) -> LocalBoxFuture<'static, ()> {
let quic_mqtt = self.clone();
async move {
if let Err(e) = result {
error!("Read failed for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, stream_id: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
stream_id,
e);
return;
}
loop {
let mut ready_len = 0;
// `None` means there is no read buffer for the main stream: stop reading.
let remaining = if let Some(len) = handle.read_buffer_remaining(handle.get_main_stream_id().unwrap()) {
len
} else {
return;
};
if remaining == 0 {
// Nothing buffered: `read_ready` either reports a length directly (`Err(len)`)
// or yields a future that resolves once data is available.
ready_len = match handle.read_ready(handle.get_main_stream_id().unwrap(), 0) {
Err(len) => len,
Ok(value) => {
value.await
},
};
if ready_len == 0 {
return;
}
}
// The read-buffer lock guard is a temporary of this `let` initializer, so (with
// edition-2021 temporary scoping) it stays locked until the statement ends,
// including the `handle.close(..)` call on the error path below.
// NOTE(review): only `buf.chunk()` (the first contiguous segment) is parsed; if the
// buffer can be non-contiguous, a packet spanning segments would look incomplete.
let packet = if let Some(buf) = handle.get_read_buffer(handle.get_main_stream_id().unwrap()).as_ref().unwrap().lock().as_mut() {
let mut bin: &[u8] = buf.chunk();
match bin.read_packet_with_len() {
// Treated as "packet not fully buffered yet": stop this pass and leave the
// bytes in place. NOTE(review): relies on `handle_readed` being invoked again
// when more data arrives — confirm with the `quic` crate.
Err(MqttError::PayloadSizeIncorrect) | Err(MqttError::PayloadRequired) | Err(MqttError::MalformedRemainingLength) => {
return;
},
Err(e) => {
error!("Parse packet failed for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, stream_id: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
stream_id,
e);
handle.close(1000, Err(Error::new(ErrorKind::InvalidInput, format!("{:?}", e))));
return;
},
Ok((readed_packet, readed_len)) => {
// `bin` is only a shared slice into `buf`; dropping a reference is a no-op.
drop(bin);
// Consume exactly the bytes of the decoded packet.
let _ = buf.copy_to_bytes(readed_len);
readed_packet
},
}
} else {
return;
};
let result = match packet {
Packet::Connect(packet) => {
accept(quic_mqtt.clone(), handle.clone(), packet).await
},
Packet::Publish(packet) => {
publish(quic_mqtt.clone(), handle.clone(), packet).await
},
Packet::Subscribe(packet) => {
subscribe(quic_mqtt.clone(), handle.clone(), packet).await
},
Packet::Unsubscribe(packet) => {
unsubscribe(quic_mqtt.clone(), handle.clone(), packet).await
},
Packet::Pingreq => {
ping_req(handle.clone())
},
Packet::Disconnect => {
disconnect(handle.clone())
},
// Any other packet type (already consumed from the buffer) ends this read pass
// without processing further buffered packets.
// NOTE(review): confirm this is intended; `continue` would keep processing.
_ => {
return;
},
};
if let Err(e) = result {
error!("Handle packet failed for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, stream_id: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
stream_id,
e);
handle.close(1000, Err(Error::new(ErrorKind::InvalidData, format!("{:?}", e))));
return;
}
}
}.boxed_local()
}
/// Write completions need no follow-up; this is a no-op.
fn handle_writed(&self,
_handle: SocketHandle,
_stream_id: StreamId,
_result: Result<()>) -> LocalBoxFuture<'static, ()> {
async move {
}.boxed_local()
}
/// Releases per-connection MQTT state when the socket closes.
///
/// Removes the `QuicBrokerSession` context from the handle. If the broker still has
/// a session for that client id:
/// - a clean session is unsubscribed from everything and removed from the broker;
/// - the session is marked as no longer accepted;
/// - the listener's `closed` future is returned, if a listener is registered.
///
/// Otherwise an empty future is returned. `stream_id` and `code` are unused.
fn handle_closed(&self,
handle: SocketHandle,
stream_id: Option<StreamId>,
code: u32,
result: Result<()>) -> LocalBoxFuture<'static, ()> {
match handle.remove_session::<QuicBrokerSession>() {
Err(e) => {
error!("Free context failed for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
e);
},
Ok(opt) => {
if let Some(context) = opt {
let client_id = context.get_client_id();
if let Some(session) = self.broker.get_session(client_id) {
if session.is_clean() {
self.broker.unsubscribe_all(&session); self.broker.remove_session(client_id); }
session.set_accept(false);
if let Some(listener) = self.broker.get_listener() {
return listener
.0
.closed(MqttBrokerProtocol::QuicMqtt311(Arc::new(self.clone())),
session,
context,
result);
}
}
}
},
}
async move {
}.boxed_local()
}
/// Keep-alive timeout for a connection (armed by `update_timeout`).
///
/// Sends a DISCONNECT packet to the peer and logs the timeout. On `Ok`, the event's
/// `String` payload is the client id stored by `update_timeout`.
/// NOTE(review): in MQTT 3.1.1 DISCONNECT is a client-to-server packet only.
/// This handler does not close the socket itself — TODO confirm the framework
/// closes timed-out connections.
fn handle_timeouted(&self,
handle: SocketHandle,
result: Result<SocketEvent>) -> LocalBoxFuture<'static, ()> {
if let Err(e) = send_packet(&handle, &Packet::Disconnect) {
error!("Mqtt send closed failed for mqtt by quic, uid: {:?}, local: {:?}, remote: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_local(),
handle.get_remote(),
e);
}
match result {
Err(e) => {
error!("Mqtt session timeout for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, reason: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
e);
},
Ok(mut event) => {
error!("Mqtt session timeout for mqtt by quic, uid: {:?}, remote: {:?}, local: {:?}, client_id: {:?}",
handle.get_uid(),
handle.get_remote(),
handle.get_local(),
event.remove::<String>());
},
}
async move {
}.boxed_local()
}
}
/// (Re)arms the keep-alive timer of `connect`.
///
/// MQTT 3.1.1 §3.1.2.10: a non-zero Keep Alive (seconds) means the server should
/// drop the client if nothing arrives within 1.5 × Keep Alive. The timeout is
/// therefore `keep_alive * 1500` milliseconds. The client id travels in the
/// timeout event so `handle_timeouted` can log it.
///
/// A Keep Alive of 0 turns the mechanism off, so no timer is armed. Arming a
/// 0 ms timer would mean "time out immediately" rather than "never".
fn update_timeout(connect: &SocketHandle,
                  client_id: String,
                  keep_alive: u16) {
    if keep_alive == 0 {
        return;
    }

    let mut event = SocketEvent::empty();
    event.set::<String>(client_id);
    // u16 * 1500 fits comfortably in usize even on 32-bit targets.
    connect.set_timeout(usize::from(keep_alive) * 1500, event);
}
/// Encodes `packet` and queues it for writing on the main stream of `connect`.
///
/// # Errors
/// - `InvalidData` if the packet cannot be encoded.
/// - `NotConnected` if the handle has no main stream. The original `unwrap()`
///   would have panicked here even though callers already handle `Err`.
/// - Otherwise, whatever `write_ready` returns.
fn send_packet(connect: &SocketHandle,
               packet: &Packet) -> Result<()> {
    let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
    if let Err(e) = buf.write_packet(packet) {
        return Err(Error::new(ErrorKind::InvalidData,
                              format!("Mqtt send failed, reason: {:?}",
                                      e)));
    }

    let stream_id = match connect.get_main_stream_id() {
        Some(id) => id.clone(),
        None => {
            return Err(Error::new(ErrorKind::NotConnected,
                                  "Mqtt send failed, reason: main stream not found"));
        },
    };
    connect.write_ready(stream_id, buf.into_inner())
}
/// Encodes `packet` once and broadcasts the bytes to every still-open handle in
/// `connects`.
///
/// Closed handles are skipped. The previous loop only used `is_closed` to decide
/// *whether* to broadcast, then sent to the whole slice, closed handles included.
/// Returns `Ok(())` without sending when `connects` is empty or every handle is
/// closed.
///
/// # Errors
/// `InvalidData` if the packet cannot be encoded; otherwise whatever
/// `SocketHandle::broadcast` returns.
pub fn broadcast_packet(connects: &[SocketHandle],
                        packet: &Packet) -> Result<()> {
    if connects.is_empty() {
        return Ok(());
    }

    let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
    if let Err(e) = buf.write_packet(packet) {
        return Err(Error::new(ErrorKind::InvalidData,
                              format!("Mqtt send failed, reason: {:?}",
                                      e)));
    }

    let opened: Vec<SocketHandle> = connects
        .iter()
        .filter(|connect| !connect.is_closed())
        .cloned()
        .collect();
    if opened.is_empty() {
        return Ok(());
    }

    SocketHandle::broadcast(&opened[..], buf.into_inner())
}
/// Handles an MQTT CONNECT packet and answers with a CONNACK.
///
/// Steps:
/// - Reject a connection whose client id already has an accepted session.
/// - Refuse protocol levels other than MQTT 3.1 / 3.1.1.
/// - Record the per-connection context and (re)create the broker session.
/// - Let the optional listener authorize the client.
///
/// Spec conformance (MQTT 3.1.1):
/// - [MQTT-3.2.2-4]: a CONNACK with a non-zero return code always carries
///   Session Present = 0.
/// - [MQTT-3.2.2-5]: a NotAuthorized client must be disconnected. Returning
///   `Err` makes the read loop in `handle_readed` close the connection, as the
///   refused-protocol path already does.
///
/// # Errors
/// - `AlreadyExists` if the client id is already connected.
/// - `ConnectionAborted` for an unsupported protocol level.
/// - `PermissionDenied` if the listener refuses the client.
/// - `BrokenPipe` if the CONNACK cannot be written.
async fn accept(protocol: QuicMqtt311,
                connect: SocketHandle,
                packet: Connect) -> Result<()> {
    let clean_session = packet.clean_session;
    let client_id = packet.client_id.clone();
    let is_exist_session = protocol.is_exist(&client_id);
    if is_exist_session {
        if let Some(session) = protocol.broker.get_session(&client_id) {
            if session.is_accepted() {
                return Err(Error::new(ErrorKind::AlreadyExists,
                                      "Mqtt connect failed, reason: connect already exist"));
            }
        }
    }

    // Session Present is only set when resuming an existing persistent session.
    let session_present = !clean_session && is_exist_session;

    let level = packet.protocol.level();
    if (level != QuicMqtt311::MQTT31) && (level != QuicMqtt311::MQTT311) {
        let ack = Connack {
            // Non-zero return code => Session Present MUST be 0.
            session_present: false,
            code: ConnectReturnCode::RefusedProtocolVersion,
        };
        if let Err(e) = send_packet(&connect, &Packet::Connack(ack)) {
            return Err(Error::new(ErrorKind::BrokenPipe,
                                  format!("Mqtt send failed by connect, reason: {:?}",
                                          e)));
        }
        return Err(Error::new(ErrorKind::ConnectionAborted,
                              "Mqtt connect failed, reason: invalid protocol version"));
    }

    connect.set_session(QuicBrokerSession::new(
        client_id.clone(),
        packet.keep_alive,
        packet.clean_session,
        packet.username.clone(),
        packet.password.clone())
    );
    let mqtt_connect = reset_session(&protocol,
                                     connect.clone(),
                                     client_id,
                                     packet);

    let mut authorized = true;
    if let Some(listener) = protocol.broker.get_listener() {
        if let Err(_e) = listener
            .0
            .connected(MqttBrokerProtocol::QuicMqtt311(Arc::new(protocol)),
                       mqtt_connect).await {
            authorized = false;
        }
    }

    let code = if authorized {
        ConnectReturnCode::Accepted
    } else {
        ConnectReturnCode::NotAuthorized
    };
    let ack = Connack {
        // Non-zero return code => Session Present MUST be 0.
        session_present: authorized && session_present,
        code,
    };
    if let Err(e) = send_packet(&connect, &Packet::Connack(ack)) {
        return Err(Error::new(ErrorKind::BrokenPipe,
                              format!("Mqtt send failed by connect, reason: {:?}",
                                      e)));
    }

    if !authorized {
        // The server MUST close the connection after a refused CONNACK.
        return Err(Error::new(ErrorKind::PermissionDenied,
                              "Mqtt connect failed, reason: not authorized"));
    }
    Ok(())
}
fn reset_session(protocol: &QuicMqtt311,
connect: SocketHandle,
client_id: String,
packet: Connect) -> Arc<QosZeroSession> {
let mut session = QosZeroSession::with_client_id(client_id.clone());
session.bind_connect(connect.clone());
session.set_accept(true);
session.set_clean(packet.clean_session);
if let Some(w) = &packet.last_will {
session.set_will(w.topic.clone(), w.message.clone(), w.qos.to_u8(), w.retain);
}
session.set_user_pwd(packet.username, packet.password);
session.set_keep_alive(packet.keep_alive);
update_timeout(&connect, packet.client_id, packet.keep_alive);
protocol.broker.insert_session(client_id, session)
}
/// Returns `(client_id, keep_alive)` from the connection's `QuicBrokerSession`.
///
/// This context is installed by `accept` when CONNECT is processed. Publish,
/// subscribe, unsubscribe and ping all depend on it.
///
/// # Errors
/// `ConnectionRefused` when no context is attached, i.e. the client has not
/// completed CONNECT. The message no longer names "subscribe", since every
/// request type shares this check.
#[inline]
fn get_client_context(connect: &SocketHandle) -> Result<(String, u16)> {
    if let Some(handle) = connect.get_session::<QuicBrokerSession>() {
        let context = handle.as_ref();
        Ok((context.get_client_id().clone(), context.get_keep_alive()))
    } else {
        Err(Error::new(ErrorKind::ConnectionRefused,
                       "Mqtt request failed, reason: invalid connect"))
    }
}
async fn publish(protocol: QuicMqtt311,
connect: SocketHandle,
mut packet: Publish) -> Result<()> {
let (client_id, keep_alive) = match get_client_context(&connect) {
Err(e) => {
return Err(e);
}
Ok(r) => {
r
}
};
update_timeout(&connect, client_id.clone(), keep_alive);
let qos = packet.qos.to_u8();
if qos > protocol.get_qos().to_u8() {
return Err(Error::new(ErrorKind::InvalidData,
"Mqtt publish failed, reason: invalid qos"));
}
let topic_path = if let Ok(path) = TopicPath::from_str(packet.topic_name.as_str()) {
path
} else {
return Err(Error::new(ErrorKind::InvalidData,
format!("Mqtt publish failed, topic: {:?}, reason: parse topic error",
packet.topic_name)));
};
if topic_path.wildcards || topic_path.path.is_empty() {
return Err(Error::new(ErrorKind::InvalidData,
"Mqtt publish failed, reason: invalid topic"));
}
packet.dup = false;
let mut retain = None;
if packet.retain {
retain = Some(packet.clone());
}
if let Some(service) = protocol.broker.get_service() {
if let Some(session) = protocol.broker.get_session(&client_id) {
let is_passive_receive = session.is_passive_receive();
let result = service
.0
.publish(MqttBrokerProtocol::QuicMqtt311(Arc::new(protocol)),
session.clone(),
packet.topic_name.clone(),
packet.payload).await;
if is_passive_receive {
if let Some(hibernate) = connect
.hibernate(connect.get_main_stream_id().unwrap().clone(),
QuicSocketReady::Writable) {
return hibernate.await;
} else {
return Ok(());
}
}
return result;
}
}
let is_public = true;
if let Some(sessions) = protocol.broker.subscribed(is_public,
&topic_path.path,
qos,
retain) {
let mut connects: Vec<SocketHandle> = Vec::with_capacity(sessions.len());
for session in sessions {
if let Some(connect) = session.get_connect() {
connects.push(connect.clone());
}
};
if let Err(e) = broadcast_packet(&connects[..],
&Packet::Publish(packet)) {
return Err(Error::new(ErrorKind::BrokenPipe,
format!("Mqtt broadcast failed by publish, reason: {:?}",
e)));
}
}
Ok(())
}
async fn subscribe(protocol: QuicMqtt311,
connect: SocketHandle,
packet: Subscribe) -> Result<()> {
let (client_id, keep_alive) = match get_client_context(&connect) {
Err(e) => {
return Err(e);
}
Ok(r) => {
r
}
};
update_timeout(&connect, client_id.clone(), keep_alive);
let pkid = packet.pkid;
let topics = packet.topics;
let mut retains: Option<Retain> = None;
let mut return_codes = Vec::with_capacity(topics.len()); if let Some(service) = protocol.broker.get_service() {
if let Some(session) = protocol.broker.get_session(&client_id) {
let mut sub_topics = Vec::with_capacity(topics.len());
for topic in &topics {
sub_topics.push((topic.topic_path.clone(), topic.qos.to_u8()));
}
if let Err(_) = service
.0
.subscribe(MqttBrokerProtocol::QuicMqtt311(Arc::new(protocol.clone())),
session,
sub_topics).await {
for _ in topics {
return_codes.push(SubscribeReturnCodes::Failure);
}
} else {
let qos = protocol.get_qos();
for _ in topics {
return_codes.push(SubscribeReturnCodes::Success(qos));
}
}
}
} else {
if let Some(session) = protocol.broker.get_session(&client_id) {
let qos = protocol.get_qos();
let qos_val = qos.to_u8();
for topic in topics {
if qos_val < topic.qos.to_u8() {
return_codes.push(SubscribeReturnCodes::Failure);
} else {
return_codes.push(SubscribeReturnCodes::Success(qos));
}
retains = protocol.broker.subscribe(session.clone(), topic.topic_path);
}
}
}
let ack = Suback {
pkid,
return_codes,
};
if let Err(e) = send_packet(&connect, &Packet::Suback(ack)) {
return Err(Error::new(ErrorKind::BrokenPipe,
format!("Mqtt send failed by subscribe, reason: {:?}",
e)));
}
if let Some(r) = retains {
match r {
Retain::Single(p) => {
let _ = send_packet(&connect, &Packet::Publish(p));
},
Retain::Mutil(ps) => {
for p in ps {
let _ = send_packet(&connect, &Packet::Publish(p));
}
},
}
}
Ok(())
}
async fn unsubscribe(protocol: QuicMqtt311,
connect: SocketHandle,
packet: Unsubscribe) -> Result<()> {
let (client_id, keep_alive) = match get_client_context(&connect) {
Err(e) => {
return Err(e);
}
Ok(r) => {
r
}
};
update_timeout(&connect, client_id.clone(), keep_alive);
let pkid = packet.pkid;
let topics = packet.topics;
if let Some(service) = protocol.broker.get_service() {
if let Some(session) = protocol.broker.get_session(&client_id) {
if let Err(e) = service.0.unsubscribe(MqttBrokerProtocol::QuicMqtt311(Arc::new(protocol)),
session,
topics.clone()).await {
return Err(Error::new(ErrorKind::BrokenPipe,
format!("Mqtt unsubscribe failed, reason: {:?}",
e)));
}
}
} else {
if let Some(session) = protocol.broker.get_session(&client_id) {
for topic in topics {
protocol.broker.unsubscribe(&session, topic);
}
}
}
if let Err(e) = send_packet(&connect, &Packet::Unsuback(pkid)) {
return Err(Error::new(ErrorKind::BrokenPipe,
format!("Mqtt send failed by unsubscribe, reason: {:?}",
e)));
}
Ok(())
}
/// Answers an MQTT PINGREQ with a PINGRESP after refreshing the keep-alive timer.
///
/// # Errors
/// Fails if the connection has no MQTT context (CONNECT not processed yet) or the
/// PINGRESP cannot be written.
fn ping_req(connect: SocketHandle) -> Result<()> {
    let (client_id, keep_alive) = get_client_context(&connect)?;
    update_timeout(&connect, client_id, keep_alive);

    send_packet(&connect, &Packet::Pingresp).map_err(|reason| {
        Error::new(ErrorKind::BrokenPipe,
                   format!("Mqtt send failed by ping, reason: {:?}", reason))
    })
}
/// Handles a client DISCONNECT by closing the connection with code 0 and a
/// successful close reason; returns whatever `close` returns.
fn disconnect(connect: SocketHandle) -> Result<()> {
    let reason = Ok(());
    connect.close(0, reason)
}
impl QuicMqtt311 {
    /// Protocol-level byte of MQTT 3.1 in a CONNECT packet.
    pub const MQTT31: u8 = 0x3;
    /// Protocol-level byte of MQTT 3.1.1 in a CONNECT packet.
    pub const MQTT311: u8 = 0x4;
    /// Maximum QoS level. NOTE(review): not referenced by `with_name`, which
    /// accepts any QoS that `QoS::from_u8` accepts.
    pub const MAX_QOS: u8 = 0;

    /// Creates a protocol handler named `broker_name` with a new `MqttBroker`.
    ///
    /// `qos` is the highest QoS clients may use on PUBLISH. The broker's
    /// expire-unsubscribed-topic loop is started with a 60 s duration
    /// (presumably its sweep interval — TODO confirm in `MqttBroker`).
    ///
    /// # Panics
    /// Panics if `QoS::from_u8` rejects `qos` (valid MQTT levels are 0–2). The
    /// value is validated before the background loop is started, so a bad
    /// configuration does not leave that loop running.
    pub fn with_name(broker_name: &str,
                     qos: u8) -> Self {
        let qos = QoS::from_u8(qos).expect("QuicMqtt311: invalid MQTT QoS level");
        let broker = MqttBroker::new();
        broker.startup_expire_unsubscribed_topic_loop(Duration::from_millis(60000));
        QuicMqtt311 {
            broker_name: broker_name.to_string(),
            qos,
            broker,
        }
    }

    /// Returns `true` if the broker holds a session for `client_id`.
    #[inline]
    pub fn is_exist(&self, client_id: &str) -> bool {
        self.broker.is_exist_session(&client_id.to_string())
    }

    /// Returns the name given to `with_name`.
    pub fn get_broker_name(&self) -> &str {
        self.broker_name.as_str()
    }

    /// Returns the highest QoS accepted on PUBLISH.
    #[inline]
    pub fn get_qos(&self) -> QoS {
        self.qos
    }

    /// Returns the underlying broker.
    pub fn get_broker(&self) -> &MqttBroker {
        &self.broker
    }
}