use std::{io, sync::Arc};
use futures::{SinkExt as _, StreamExt as _};
use mqtt_codec_kit::v4::packet::{
DisconnectPacket, MqttDecoder, MqttEncoder, PingrespPacket, VariablePacket, VariablePacketError,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc,
};
use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
use crate::{
protocols::v4::publish::deliver_publish,
server::state::GlobalState,
types::{error::Error, outgoing::Outgoing},
};
use super::{
connect::{handle_connect, handle_disconnect, handle_offline},
publish::{handle_puback, handle_pubcomp, handle_publish, handle_pubrec, handle_pubrel},
subscribe::{handle_subscribe, handle_unsubscribe},
};
/// Pumps decoded MQTT packets from the client connection into `msg_tx`.
///
/// Stops (without returning an error) in three situations:
/// * the frame stream is exhausted — logged at info level as "client closed";
/// * the decoder reports an error — logged at warn level;
/// * the receiving half of `msg_tx` has been dropped — logged at error level.
async fn read_from_client<T, D>(mut reader: FramedRead<T, D>, msg_tx: mpsc::Sender<VariablePacket>)
where
    T: AsyncRead + Unpin,
    D: Decoder<Item = VariablePacket, Error = VariablePacketError>,
{
    while let Some(frame) = reader.next().await {
        let packet = match frame {
            Ok(decoded) => decoded,
            Err(e) => {
                log::warn!("read from client: {}", e);
                return;
            }
        };
        log::debug!("read from client: {:?}", packet);
        // A send error means the consumer side is gone; nothing left to do.
        if let Err(err) = msg_tx.send(packet).await {
            log::error!("receiver closed: {}", err);
            return;
        }
    }
    // Only reached when the stream itself ended.
    log::info!("client closed");
}
async fn write_to_client<T, E>(
mut writer: FramedWrite<T, E>,
mut msg_rx: mpsc::Receiver<VariablePacket>,
global: Arc<GlobalState>,
) -> Result<(), Error>
where
T: AsyncWrite + Unpin,
E: Encoder<VariablePacket, Error = io::Error>,
{
let packet = match msg_rx.recv().await {
Some(VariablePacket::ConnectPacket(packet)) => packet,
_ => {
log::debug!("first packet is not CONNECT packet");
return Err(Error::InvalidConnectPacket);
}
};
let (mut session, mut outgoing_rx) =
match handle_connect(&mut writer, &packet, global.clone()).await {
Ok(r) => r,
Err(err) => {
log::error!(
"handle client#{} connect: {err}",
packet.client_identifier()
);
return Err(err);
}
};
let mut take_over = true;
loop {
tokio::select! {
packet = msg_rx.recv() => {
match packet {
Some(packet) => {
session.renew_last_packet_at();
let resp = match packet {
VariablePacket::PingreqPacket(_packet) => PingrespPacket::new().into(),
VariablePacket::PublishPacket(packet) => {
match handle_publish(&mut session, &packet, global.clone()).await {
Ok(Some(resp)) => resp,
Ok(None) => continue,
Err(err) => {
log::error!("handle publish message failed: {}", err);
break;
}
}
}
VariablePacket::PubrelPacket(packet) => {
handle_pubrel(&mut session, global.clone(), packet.packet_identifier())
.await
.into()
}
VariablePacket::PubackPacket(packet) => {
handle_puback(&mut session, packet.packet_identifier());
continue;
}
VariablePacket::PubrecPacket(packet) => {
handle_pubrec(&mut session, packet.packet_identifier()).into()
}
VariablePacket::SubscribePacket(packet) => {
let packets = handle_subscribe(&mut session, &packet, &global);
for packet in packets {
if let Err(err) = writer.send(packet).await {
log::error!(
"write subscribe ack to client#{} : {}",
&session.client_identifier(),
err
);
break;
}
}
continue;
}
VariablePacket::PubcompPacket(packet) => {
handle_pubcomp(&mut session, packet.packet_identifier());
continue;
}
VariablePacket::UnsubscribePacket(packet) => {
handle_unsubscribe(&mut session, &packet, &global).into()
}
VariablePacket::DisconnectPacket(_packet) => {
handle_disconnect(&mut session).await;
break;
}
_ => {
log::debug!("unsupported packet: {:?}", packet);
break;
}
};
if let Err(err) = writer.send(resp).await {
log::error!("write to client#{} : {}", &session.client_identifier(), err);
break;
}
}
None => {
log::warn!("incoming receive channel closed");
break;
}
}
}
outgoing = outgoing_rx.recv() => {
match outgoing {
Some(outgoing) => {
let resp = match outgoing {
Outgoing::Publish(qos, msg) => deliver_publish(&mut session, qos, msg).into(),
Outgoing::Online(sender) => {
global.remove_client(session.client_id(), session.subscribes().keys());
if let Err(err) = sender.send((&mut session).into()).await {
log::debug!(
"client#{} send session state : {}",
session.client_identifier(),
err,
);
}
if let Err(err) = writer.send(DisconnectPacket::new().into()).await {
log::debug!(
"client#{} write disconnect packet : {}",
session.client_identifier(),
err,
);
}
take_over = false;
break;
}
Outgoing::Kick(reason) => {
log::info!(
"client#{} kicked out: {}",
session.client_identifier(),
reason
);
global.remove_client(session.client_id(), session.subscribes().keys());
if let Err(err) = writer.send(DisconnectPacket::new().into()).await {
log::error!(
"send client#{} disconnect: {}",
session.client_identifier(),
err
);
}
take_over = false;
break;
}
};
if let Err(err) = writer.send(resp).await {
log::error!(
"write outgoing to client#{} : {}",
&session.client_identifier(),
err
);
break;
}
}
None => {
log::warn!("outgoing receive channel closed");
break;
}
}
}
}
}
tokio::spawn(handle_offline(
session,
outgoing_rx,
global.clone(),
take_over,
));
Ok(())
}
pub async fn read_write_loop<R, W>(reader: R, writer: W, global: Arc<GlobalState>)
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let frame_reader = FramedRead::new(reader, MqttDecoder::new());
let frame_writer = FramedWrite::new(writer, MqttEncoder::new());
let (msg_tx, msg_rx) = mpsc::channel(8);
let mut read_task = tokio::spawn(async move {
read_from_client(frame_reader, msg_tx).await;
});
let mut write_task = tokio::spawn(async move {
if let Err(err) = write_to_client(frame_writer, msg_rx, global.clone()).await {
log::error!("write to client: {err}");
}
});
if tokio::try_join!(&mut read_task, &mut write_task).is_err() {
log::error!("read_task/write_task terminated");
read_task.abort();
};
}