#![deny(missing_docs)]
pub mod handler;
mod message;
mod state;
mod subscriptions;
use std::{sync::Arc, time::Duration};
use dashmap::DashMap;
use mqtt_format::v3::{
connect_return::MConnectReturnCode,
packet::{
MConnack, MConnect, MDisconnect, MPacket, MPingreq, MPingresp, MPuback, MPubcomp, MPublish,
MPubrec, MPubrel, MSuback, MSubscribe,
},
qos::MQualityOfService,
strings::MString,
subscription_acks::MSubscriptionAcks,
will::MLastWill,
};
use tokio::{
io::{AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
net::{TcpListener, ToSocketAddrs},
sync::broadcast::Sender as BroadcastSender,
sync::Mutex,
};
use tracing::{debug, error, info, trace, warn};
use crate::{error::MqttError, mqtt_stream::MqttStream, PacketIOError};
use subscriptions::{ClientInformation, SubscriptionManager};
use self::{
handler::{
AllowAllLogins, AllowAllSubscriptions, LoginError, LoginHandler, SubscriptionHandler,
},
message::MqttMessage,
state::ClientState,
subscriptions::TopicFilter,
};
/// Identifier a client presented in its CONNECT packet.
///
/// The server keys its per-client session state by this value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientId(String);

impl ClientId {
    /// Wrap an already-owned string as a client identifier.
    // No caller in this module (presumably used by tests); the allow keeps other
    // builds free of dead-code warnings.
    #[allow(dead_code)]
    pub(crate) fn new(id: String) -> Self {
        Self(id)
    }

    /// Borrow the identifier as a string slice.
    pub fn get(&self) -> &str {
        self.0.as_str()
    }
}
impl<'message> TryFrom<MString<'message>> for ClientId {
type Error = ClientError;
fn try_from(ms: MString<'message>) -> Result<Self, Self::Error> {
Ok(ClientId(ms.to_string()))
}
}
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
#[error("An error occured during the handling of a packet")]
Packet(#[from] PacketIOError),
#[error("An authentication was rejected")]
Authentication(#[from] LoginError),
}
/// The two halves of a connected client's stream, each behind its own lock.
///
/// Keeping the halves separate lets reading and writing proceed independently.
#[derive(Debug)]
pub(crate) struct ClientConnection {
    /// Read half; the client's read loop holds this lock for as long as it runs.
    reader: Mutex<ReadHalf<MqttStream>>,
    /// Write half; locked for each outgoing packet.
    writer: Mutex<WriteHalf<MqttStream>>,
}
/// Where new client connections come from.
#[derive(Debug)]
enum ClientSource {
    /// Plain, unencrypted TCP connections accepted from a listening socket.
    UnsecuredTcp(TcpListener),
    /// In-memory connections handed over through a channel.
    // Never constructed in this module (presumably by tests); the allow silences the
    // resulting dead-code warning.
    #[allow(dead_code)]
    Duplex(tokio::sync::mpsc::Receiver<DuplexStream>),
}

impl ClientSource {
    /// Wait for the next connection and wrap it in an [`MqttStream`].
    ///
    /// # Errors
    ///
    /// Fails if the TCP listener's `accept` fails. For a duplex source, fails with
    /// [`MqttError::DuplexSourceClosed`] when the channel yields no further stream.
    async fn accept(&mut self) -> Result<MqttStream, MqttError> {
        let stream = match self {
            ClientSource::UnsecuredTcp(listener) => {
                let (socket, _peer_addr) = listener.accept().await?;
                MqttStream::UnsecuredTcp(socket)
            }
            ClientSource::Duplex(incoming) => {
                let duplex = incoming
                    .recv()
                    .await
                    .ok_or(MqttError::DuplexSourceClosed)?;
                MqttStream::MemoryDuplex(duplex)
            }
        };
        Ok(stream)
    }
}
/// An MQTT v3 server.
///
/// Cloning is cheap: every clone shares the same underlying state.
pub struct MqttServer<LoginH, SubH> {
    inner: Arc<InnerServer<LoginH, SubH>>,
}

// Written by hand because `#[derive(Clone)]` would require `LoginH: Clone` and
// `SubH: Clone`, although only the `Arc` is cloned.
impl<LoginH, SubH> Clone for MqttServer<LoginH, SubH> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}
/// State shared by every clone of an [`MqttServer`].
struct InnerServer<LoginH, SubH> {
    /// Session state of every known client, keyed by client id.
    clients: Arc<DashMap<ClientId, ClientState>>,
    /// Source of new connections; `accept_new_clients` holds this lock while it runs.
    client_source: Mutex<ClientSource>,
    /// Decides whether a connecting client may log in.
    auth_handler: LoginH,
    /// Carries every published message to `subscribe_to_message` listeners.
    extra_listeners: BroadcastSender<MqttMessage>,
    /// Routes published messages to subscribed clients.
    subscription_manager: Arc<SubscriptionManager<SubH>>,
}
impl MqttServer<AllowAllLogins, AllowAllSubscriptions> {
    /// Bind a plain (unencrypted) TCP listener on `addr` and build a server around it.
    ///
    /// The server starts out with [`AllowAllLogins`] and [`AllowAllSubscriptions`].
    /// Replace them with [`MqttServer::with_login_handler`] and
    /// [`MqttServer::with_subscription_handler`] before accepting clients.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the TCP listener fails.
    pub async fn serve_v3_unsecured_tcp<Addr: ToSocketAddrs>(
        addr: Addr,
    ) -> Result<Self, MqttError> {
        let listener = TcpListener::bind(addr).await?;
        // 50 = buffer size of the broadcast channel feeding `subscribe_to_message`.
        let (extra_listeners, _) = tokio::sync::broadcast::channel(50);
        let shared = InnerServer {
            clients: Arc::new(DashMap::new()),
            client_source: Mutex::new(ClientSource::UnsecuredTcp(listener)),
            auth_handler: AllowAllLogins,
            extra_listeners,
            subscription_manager: Arc::new(SubscriptionManager::new()),
        };
        Ok(MqttServer {
            inner: Arc::new(shared),
        })
    }
}
impl<LH: LoginHandler, SH: SubscriptionHandler> MqttServer<LH, SH> {
pub fn with_login_handler<NLH: LoginHandler>(
self,
new_login_handler: NLH,
) -> MqttServer<NLH, SH> {
let inner = Arc::try_unwrap(self.inner)
.unwrap_or_else(|_| panic!("Called after started listening"));
MqttServer {
inner: Arc::new(InnerServer {
clients: inner.clients,
client_source: inner.client_source,
auth_handler: new_login_handler,
extra_listeners: inner.extra_listeners,
subscription_manager: inner.subscription_manager,
}),
}
}
pub fn with_subscription_handler<NSH: SubscriptionHandler>(
self,
new_subscription_handler: NSH,
) -> MqttServer<LH, NSH> {
let inner = Arc::try_unwrap(self.inner)
.unwrap_or_else(|_| panic!("Called after started listening"));
MqttServer {
inner: Arc::new(InnerServer {
clients: inner.clients,
client_source: inner.client_source,
auth_handler: inner.auth_handler,
extra_listeners: inner.extra_listeners,
subscription_manager: Arc::new({
let manager = Arc::try_unwrap(inner.subscription_manager);
manager
.unwrap_or_else(|_| panic!("Called after started listening"))
.with_subscription_handler(new_subscription_handler)
}),
}),
}
}
pub async fn accept_new_clients(&self) -> Result<(), MqttError> {
let mut client_source = self
.inner
.client_source
.try_lock()
.map_err(|_| MqttError::AlreadyListening)?;
loop {
let server: MqttServer<LH, SH> = self.clone();
let client = client_source.accept().await?;
tokio::spawn(async move {
if let Err(client_error) = server.accept_client(client).await {
tracing::error!("Client error: {}", client_error)
}
});
}
}
pub fn subscribe_to_message<
Fut: std::future::Future<Output = ()>,
CB: FnMut(MqttMessage) -> Fut + 'static,
>(
&self,
topic_paths: Vec<String>,
mut callback: CB,
) -> impl std::future::Future<Output = Result<(), MqttError>> + 'static {
let mut listener = self.inner.extra_listeners.subscribe();
let topics = topic_paths
.into_iter()
.map(TopicFilter::parse_from)
.collect::<Vec<_>>();
async move {
loop {
let message = listener.recv().await;
match message {
Ok(message) => {
if topics.iter().any(|topic| {
let msg_topic = TopicFilter::parse_from(message.topic().to_string());
let mut i = 0;
loop {
match (topic.get(i), msg_topic.get(i)) {
(None, None) => break true,
(None, Some(_)) => break false,
(Some(_), None) => break false,
(Some(TopicFilter::MultiWildcard), Some(_)) => break true,
(Some(TopicFilter::SingleWildcard), Some(_)) => (),
(Some(left), Some(right)) => {
if left != right {
break false;
}
}
}
i += 1;
}
}) {
callback(message).await;
}
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
warn!("Subscriber lagged by {count} values")
}
}
}
Ok(())
}
}
async fn accept_client(&self, mut client: MqttStream) -> Result<(), ClientError> {
async fn send_connack(
session_present: bool,
connect_return_code: MConnectReturnCode,
client: &mut MqttStream,
) -> Result<(), ClientError> {
let conn_ack = MConnack {
session_present,
connect_return_code,
};
crate::write_packet(client, conn_ack).await?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn connect_client<'message, LH: LoginHandler, SubH: SubscriptionHandler>(
server: &MqttServer<LH, SubH>,
mut client: MqttStream,
_protocol_name: MString<'message>,
_protocol_level: u8,
clean_session: bool,
will: Option<MLastWill<'message>>,
username: Option<MString<'message>>,
password: Option<&'message [u8]>,
keep_alive: u16,
client_id: MString<'message>,
) -> Result<(), ClientError> {
let empty_client_id = client_id.is_empty();
let client_id = ClientId::try_from(client_id)?;
if empty_client_id && !clean_session {
if let Err(e) =
send_connack(false, MConnectReturnCode::IdentifierRejected, &mut client).await
{
debug!("Client could not shut down cleanly: {e}");
}
return Err(ClientError::Packet(PacketIOError::InvalidParsedPacket));
}
let session_present = if clean_session {
let _ = server.inner.clients.remove(&client_id);
false
} else {
server.inner.clients.contains_key(&client_id)
};
let client_id = Arc::new(client_id);
if let Err(err) = server
.inner
.auth_handler
.allow_login(client_id.clone(), username.as_deref(), password)
.await
{
send_connack(session_present, err.as_rejection_code(), &mut client).await?;
return Err(ClientError::Authentication(err));
}
send_connack(session_present, MConnectReturnCode::Accepted, &mut client).await?;
debug!(?client_id, "Accepted new connection");
let (client_reader, client_writer) = tokio::io::split(client);
let client_connection = Arc::new(ClientConnection {
reader: Mutex::new(client_reader),
writer: Mutex::new(client_writer),
});
{
let client_state = server
.inner
.clients
.entry((*client_id).clone())
.or_insert_with(ClientState::default);
client_state
.set_new_connection(client_connection.clone())
.await;
}
let mut last_will: Option<MqttMessage> = will
.as_ref()
.map(|will| MqttMessage::from_last_will(will, client_id.clone()));
let published_packets = server.inner.subscription_manager.clone();
let (published_packets_send, mut published_packets_rec) =
tokio::sync::mpsc::unbounded_channel::<MqttMessage>();
let send_loop = {
let publisher_client_id = client_id.clone();
let clients = server.inner.clients.clone();
tokio::spawn(async move {
loop {
match published_packets_rec.recv().await {
Some(packet) => {
if packet.author_id() == &*publisher_client_id {
trace!(?packet, "Skipping sending message to oneself");
continue;
}
let Some(client_state) = clients.get(&publisher_client_id) else {
debug!(?publisher_client_id, "Associated state no longer exists");
break;
};
client_state.send_message(packet).await;
}
None => {
debug!(
?publisher_client_id,
"No more senders, stopping sending cycle"
);
break;
}
}
}
})
};
let read_loop = {
let keep_alive = keep_alive;
let subscription_manager = server.inner.subscription_manager.clone();
let client_id = client_id.clone();
let clients = server.inner.clients.clone();
let extra_listener = server.inner.extra_listeners.clone();
tokio::spawn(async move {
let client_id = client_id;
let client_connection = client_connection;
let mut reader = client_connection.reader.lock().await;
let keep_alive_duration = Duration::from_secs((keep_alive as u64 * 150) / 100);
let subscription_manager = subscription_manager;
loop {
let packet = tokio::select! {
packet = crate::read_one_packet(&mut *reader) => {
match packet {
Ok(packet) => packet,
Err(e) => {
debug!("Could not read the next client packet: {e}");
break;
}
}
},
_timeout = tokio::time::sleep(keep_alive_duration) => {
debug!("Client timed out");
break;
}
};
match packet.get_packet() {
MPacket::Publish(MPublish {
dup: _,
qos,
retain,
topic_name,
id,
payload,
}) => {
let message = MqttMessage::new(
client_id.clone(),
payload.to_vec(),
topic_name.to_string(),
*retain,
*qos,
);
let _ = extra_listener.send(message.clone());
subscription_manager.route_message(message).await;
if *qos == MQualityOfService::AtLeastOnce {
let packet = MPuback { id: id.unwrap() };
let mut writer = client_connection.writer.lock().await;
crate::write_packet(&mut *writer, packet).await?;
}
if *qos == MQualityOfService::ExactlyOnce {
let Some(client_state) = clients.get(&client_id) else {
debug!(?client_id, "Associated state no longer exists");
break;
};
if let Err(_err) =
client_state.save_qos_exactly_once(id.unwrap())
{
debug!("Encountered an error while handling a PUBACK");
break;
}
let packet = MPubrec { id: id.unwrap() };
let mut writer = client_connection.writer.lock().await;
crate::write_packet(&mut *writer, packet).await?;
}
}
MPacket::Puback(ack @ MPuback { id }) => {
trace!(?client_id, ?ack, "Received puback");
let Some(client_state) = clients.get(&client_id) else {
debug!(?client_id, "Associated state no longer exists");
break;
};
if let Err(_err) = client_state.receive_puback(*id) {
debug!("Encountered an error while handling a PUBACK");
break;
}
}
MPacket::Pubrec(ack @ MPubrec { id }) => {
trace!(?client_id, ?ack, "Received pubrec");
let Some(client_state) = clients.get(&client_id) else {
debug!(?client_id, "Associated state no longer exists");
break;
};
if let Err(_err) = client_state.receive_pubrec(*id) {
debug!("Encountered an error while handling a PUBACK");
break;
}
trace!(?client_id, "Received PUBREC, responding with PUBREL");
let packet = MPubrel { id: *id };
let mut writer = client_connection.writer.lock().await;
crate::write_packet(&mut *writer, packet).await?;
trace!("Done responding to PUBREC with PUBREL");
}
MPacket::Pubrel(ack @ MPubrel { id }) => {
trace!(?client_id, ?ack, "Received pubrel");
let Some(client_state) = clients.get(&client_id) else {
debug!(?client_id, "Associated state no longer exists");
break;
};
if let Err(_err) = client_state.receive_pubrel(*id) {
debug!("Encountered an error while handling a PUBREL");
break;
}
let packet = MPubcomp { id: *id };
let mut writer = client_connection.writer.lock().await;
crate::write_packet(&mut *writer, packet).await?;
trace!("Done responding to PUBREL with PUBCOMP");
}
MPacket::Pubcomp(ack @ MPubcomp { id }) => {
trace!(?client_id, ?ack, "Received pubcomp");
let Some(client_state) = clients.get(&client_id) else {
debug!(?client_id, "Associated state no longer exists");
break;
};
if let Err(_err) = client_state.receive_pubcomp(*id) {
debug!("Encountered an error while handling a PUBCOMP");
break;
}
}
MPacket::Disconnect(MDisconnect) => {
last_will.take();
debug!("Client disconnected gracefully");
break;
}
MPacket::Subscribe(MSubscribe { id, subscriptions }) => {
let subscription_acks = subscription_manager
.subscribe(
Arc::new(ClientInformation {
client_id: client_id.clone(),
client_sender: published_packets_send.clone(),
}),
*subscriptions,
)
.await;
trace!(?client_id, "Received SUBSCRIBE, responding with SUBACK");
let packet = MSuback {
id: *id,
subscription_acks: MSubscriptionAcks {
acks: &subscription_acks,
},
};
let mut writer = client_connection.writer.lock().await;
crate::write_packet(&mut *writer, packet).await?;
}
MPacket::Pingreq(MPingreq) => {
trace!(
?client_id,
"Received ping request, responding with ping response"
);
let packet = MPingresp;
let mut writer = client_connection.writer.lock().await;
crate::write_packet(&mut *writer, packet).await?;
}
packet => info!("Received packet: {packet:?}, not handling it"),
}
}
if let Some(will) = last_will {
debug!(?will, "Sending out will");
let _ = published_packets.route_message(will);
}
if let Err(e) = client_connection.writer.lock().await.shutdown().await {
debug!("Client could not shut down cleanly: {e}");
}
Ok::<(), ClientError>(())
})
};
let (send_err, read_err) = tokio::join!(send_loop, read_loop);
match send_err {
Ok(_) => (),
Err(join_error) => error!(
"Send loop of client {} had an unexpected error: {join_error}",
&client_id.0
),
}
match read_err {
Ok(_) => (),
Err(join_error) => error!(
"Read loop of client {} had an unexpected error: {join_error}",
&client_id.0
),
}
Ok(())
}
trace!("Accepting new client");
let packet = crate::read_one_packet(&mut client).await?;
if let MPacket::Connect(MConnect {
client_id,
clean_session,
protocol_name,
protocol_level,
will,
username,
password,
keep_alive,
}) = packet.get_packet()
{
trace!(?client_id, "Connecting client");
connect_client(
self,
client,
*protocol_name,
*protocol_level,
*clean_session,
*will,
*username,
*password,
*keep_alive,
*client_id,
)
.await?;
} else {
if let Err(e) = client.shutdown().await {
debug!("Client could not shut down cleanly: {e}");
}
}
Ok(())
}
}