use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use futures::FutureExt;
use itertools::Itertools;
use crate::context::ServerContext;
use crate::net::MqttStream;
use crate::net::{Listener, ListenerType, Result};
use crate::types::ListenerId;
use crate::{v3, v5};
/// Collects listeners for an `MqttServer` before it is built.
///
/// Obtained from `MqttServer::new`; listeners are added with `listener` or
/// `listener_by_id`, and `build` produces the server handle.
pub struct MqttServerBuilder {
    /// Server context; each registered listener's config is also inserted into its `listen_cfgs`.
    scx: ServerContext,
    /// Registered listeners in registration order, each paired with its unique id.
    listeners: Vec<(ListenerId, Listener)>,
}
impl MqttServerBuilder {
    /// Creates a builder around `scx` with no listeners registered yet.
    fn new(scx: ServerContext) -> Self {
        Self { scx, listeners: Vec::new() }
    }

    /// Registers `listen`, using the port of its configured local address as its id.
    ///
    /// A port of `0` (dynamically assigned) triggers a warning: every such listener
    /// would receive the id `0`, so `listener_by_id` with an explicit id is preferable.
    ///
    /// # Panics
    ///
    /// Panics if a listener with the same id has already been registered.
    pub fn listener(self, listen: Listener) -> Self {
        let port = listen.cfg.laddr.port();
        if port == 0 {
            log::warn!(
                "As the listener port is dynamically assigned, it is advisable to use `listener_by_id(mut self, listen: Listener, unique_id: u16)` and explicitly provide a unique_id."
            );
        }
        self.listener_by_id(listen, port)
    }

    /// Registers `listen` under the caller-chosen `unique_id`.
    ///
    /// A clone of the listener's config is stored in the context's `listen_cfgs` map
    /// under `unique_id`, and the listener is queued to be started by the server.
    ///
    /// # Panics
    ///
    /// Panics if `unique_id` is already present in `listen_cfgs`.
    pub fn listener_by_id(mut self, listen: Listener, unique_id: ListenerId) -> Self {
        use dashmap::mapref::entry::Entry;

        match self.scx.listen_cfgs.entry(unique_id) {
            Entry::Occupied(existing) => panic!("unique_id already exists: {}", existing.key()),
            Entry::Vacant(slot) => {
                slot.insert(listen.cfg.clone());
            }
        }
        self.listeners.push((unique_id, listen));
        self
    }

    /// Consumes the builder and wraps the collected state in a shareable `MqttServer`.
    pub fn build(self) -> MqttServer {
        let Self { scx, listeners } = self;
        MqttServer { inner: Arc::new(MqttServerInner { scx, listeners }) }
    }
}
/// Handle to a configured MQTT server.
///
/// Clones share a single `MqttServerInner` through an `Arc`; the inner state is
/// reachable via this type's `Deref` implementation.
#[derive(Clone)]
pub struct MqttServer {
    inner: Arc<MqttServerInner>,
}
/// State shared by every clone of an `MqttServer`.
pub struct MqttServerInner {
    /// Server context; a clone of it is passed to each listener loop.
    scx: ServerContext,
    /// Listeners to run, each paired with the id it was registered under.
    listeners: Vec<(ListenerId, Listener)>,
}
/// Exposes the shared `MqttServerInner` directly on the handle.
impl Deref for MqttServer {
    type Target = MqttServerInner;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // `&Arc<MqttServerInner>` coerces to `&MqttServerInner` via `Arc`'s own `Deref`.
        &self.inner
    }
}
impl MqttServer {
    /// Starts configuring a server around `scx`; finish with `MqttServerBuilder::build`.
    // `new` intentionally returns the builder rather than `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(scx: ServerContext) -> MqttServerBuilder {
        MqttServerBuilder::new(scx)
    }

    /// Spawns `run` with `tokio::spawn` and returns immediately.
    ///
    /// If `run` reports an error, the error is logged and the process exits with status 1.
    pub fn start(self) {
        tokio::spawn(async move {
            let outcome = self.run().await;
            if let Err(e) = outcome {
                log::error!("Failed to start the MQTT server! {e}");
                std::process::exit(1);
            }
        });
    }

    /// Invokes the `before_startup` hook, then drives all registered listeners concurrently.
    ///
    /// Each listener is dispatched to the accept loop matching its `ListenerType`. The
    /// accept loops in this file never terminate, so with at least one listener the
    /// returned future does not complete.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())` once every listener future has finished.
    pub async fn run(self) -> Result<()> {
        self.scx.extends.hook_mgr().before_startup().await;

        let accept_loops = self
            .listeners
            .iter()
            .map(|(lid, l)| {
                let scx = self.scx.clone();
                let lid = *lid;
                match l.typ {
                    ListenerType::TCP => listen_tcp(scx, l, lid).boxed(),
                    #[cfg(feature = "tls")]
                    ListenerType::TLS => listen_tls(scx, l, lid).boxed(),
                    #[cfg(feature = "ws")]
                    ListenerType::WS => listen_ws(scx, l, lid).boxed(),
                    #[cfg(all(feature = "tls", feature = "ws"))]
                    ListenerType::WSS => listen_wss(scx, l, lid).boxed(),
                    #[cfg(feature = "quic")]
                    ListenerType::QUIC => listen_quic(scx, l, lid).boxed(),
                }
            })
            .collect_vec();

        futures::future::join_all(accept_loops).await;
        Ok(())
    }
}
/// Accept loop for a plain-TCP listener; never returns.
///
/// Every accepted connection is handled in its own spawned task, which detects the
/// MQTT protocol version via `mqtt()` and dispatches to `v3::process` or
/// `v5::process`. A failed `accept` is logged, followed by a one-second pause
/// before the next attempt.
async fn listen_tcp(scx: ServerContext, l: &Listener, lid: ListenerId) {
    // Pause after a failed `accept` before retrying.
    const ACCEPT_RETRY_DELAY: Duration = Duration::from_millis(1000);

    loop {
        let accept = match l.accept().await {
            Ok(accept) => accept,
            Err(e) => {
                log::info!("TCP listener error: {e:?}");
                tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                continue;
            }
        };
        // The spawned task needs its own owned handle to the context.
        let scx = scx.clone();
        tokio::spawn(async move {
            log::debug!("TCP connection from {}", accept.remote_addr);
            let stream = match accept.tcp() {
                Ok(s) => s,
                Err(e) => {
                    log::warn!("TCP accept error: {e:?}");
                    return;
                }
            };
            // Only one arm runs, so the task-owned `scx` is moved in without another clone.
            match stream.mqtt().await {
                Ok(MqttStream::V3(s)) => {
                    if let Err(e) = v3::process(scx, s, lid).await {
                        log::info!("MQTTv3 processing error: {e:?}");
                    }
                }
                Ok(MqttStream::V5(s)) => {
                    if let Err(e) = v5::process(scx, s, lid).await {
                        log::info!("MQTTv5 processing error: {e:?}");
                    }
                }
                Err(e) => {
                    log::info!("MQTT version detection failed: {e:?}");
                }
            }
        });
    }
}
/// Accept loop for a TLS listener; never returns.
///
/// Every accepted connection is handled in its own spawned task, which completes the
/// TLS handshake via `tls()`, detects the MQTT protocol version via `mqtt()`, and
/// dispatches to `v3::process` or `v5::process`. A failed `accept` is logged,
/// followed by a one-second pause before the next attempt.
#[cfg(feature = "tls")]
async fn listen_tls(scx: ServerContext, l: &Listener, lid: ListenerId) {
    // Pause after a failed `accept` before retrying.
    const ACCEPT_RETRY_DELAY: Duration = Duration::from_millis(1000);

    loop {
        let accept = match l.accept().await {
            Ok(accept) => accept,
            Err(e) => {
                log::info!("TLS listener error: {e:?}");
                tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                continue;
            }
        };
        // The spawned task needs its own owned handle to the context.
        let scx = scx.clone();
        tokio::spawn(async move {
            log::debug!("TLS connection from {}", accept.remote_addr);
            let stream = match accept.tls().await {
                Ok(s) => s,
                Err(e) => {
                    log::warn!("TLS accept error: {e:?}");
                    return;
                }
            };
            // Only one arm runs, so the task-owned `scx` is moved in without another clone.
            match stream.mqtt().await {
                Ok(MqttStream::V3(s)) => {
                    if let Err(e) = v3::process(scx, s, lid).await {
                        log::info!("MQTTv3/TLS processing error: {e:?}");
                    }
                }
                Ok(MqttStream::V5(s)) => {
                    if let Err(e) = v5::process(scx, s, lid).await {
                        log::info!("MQTTv5/TLS processing error: {e:?}");
                    }
                }
                Err(e) => {
                    log::info!("MQTT/TLS version detection failed: {e:?}");
                }
            }
        });
    }
}
/// Accept loop for a WebSocket listener; never returns.
///
/// Every accepted connection is handled in its own spawned task, which performs the
/// WebSocket upgrade via `ws()`, detects the MQTT protocol version via `mqtt()`, and
/// dispatches to `v3::process` or `v5::process`. A failed `accept` is logged,
/// followed by a one-second pause before the next attempt.
#[cfg(feature = "ws")]
async fn listen_ws(scx: ServerContext, l: &Listener, lid: ListenerId) {
    // Pause after a failed `accept` before retrying.
    const ACCEPT_RETRY_DELAY: Duration = Duration::from_millis(1000);

    loop {
        let accept = match l.accept().await {
            Ok(accept) => accept,
            Err(e) => {
                log::info!("WebSocket listener error: {e:?}");
                tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                continue;
            }
        };
        // The spawned task needs its own owned handle to the context.
        let scx = scx.clone();
        tokio::spawn(async move {
            log::debug!("WebSocket connection from {}", accept.remote_addr);
            let stream = match accept.ws().await {
                Ok(s) => s,
                Err(e) => {
                    log::warn!("WebSocket accept error: {e:?}");
                    return;
                }
            };
            // Only one arm runs, so the task-owned `scx` is moved in without another clone.
            match stream.mqtt().await {
                Ok(MqttStream::V3(s)) => {
                    if let Err(e) = v3::process(scx, s, lid).await {
                        log::info!("MQTTv3/WS processing error: {e:?}");
                    }
                }
                Ok(MqttStream::V5(s)) => {
                    if let Err(e) = v5::process(scx, s, lid).await {
                        log::info!("MQTTv5/WS processing error: {e:?}");
                    }
                }
                Err(e) => {
                    log::info!("MQTT/WS version detection failed: {e:?}");
                }
            }
        });
    }
}
/// Accept loop for a secure WebSocket (WSS) listener; never returns.
///
/// Every accepted connection is handled in its own spawned task, which performs the
/// TLS + WebSocket setup via `wss()`, detects the MQTT protocol version via `mqtt()`,
/// and dispatches to `v3::process` or `v5::process`. A failed `accept` is logged,
/// followed by a one-second pause before the next attempt.
#[cfg(all(feature = "tls", feature = "ws"))]
async fn listen_wss(scx: ServerContext, l: &Listener, lid: ListenerId) {
    // Pause after a failed `accept` before retrying.
    const ACCEPT_RETRY_DELAY: Duration = Duration::from_millis(1000);

    loop {
        let accept = match l.accept().await {
            Ok(accept) => accept,
            Err(e) => {
                log::info!("WSS listener error: {e:?}");
                tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                continue;
            }
        };
        // The spawned task needs its own owned handle to the context.
        let scx = scx.clone();
        tokio::spawn(async move {
            log::debug!("WSS connection from {}", accept.remote_addr);
            let stream = match accept.wss().await {
                Ok(s) => s,
                Err(e) => {
                    log::warn!("WSS accept error: {e:?}");
                    return;
                }
            };
            // Only one arm runs, so the task-owned `scx` is moved in without another clone.
            match stream.mqtt().await {
                Ok(MqttStream::V3(s)) => {
                    if let Err(e) = v3::process(scx, s, lid).await {
                        log::info!("MQTTv3/WSS processing error: {e:?}");
                    }
                }
                Ok(MqttStream::V5(s)) => {
                    if let Err(e) = v5::process(scx, s, lid).await {
                        log::info!("MQTTv5/WSS processing error: {e:?}");
                    }
                }
                Err(e) => {
                    log::info!("MQTT/WSS version detection failed: {e:?}");
                }
            }
        });
    }
}
/// Accept loop for a QUIC listener; never returns.
///
/// Incoming connections come from `accept_quic()`. Each one is handled in its own
/// spawned task, which completes setup via `quic()`, detects the MQTT protocol
/// version via `mqtt()`, and dispatches to `v3::process` or `v5::process`. A failed
/// accept is logged, followed by a one-second pause before the next attempt.
#[cfg(feature = "quic")]
async fn listen_quic(scx: ServerContext, l: &Listener, lid: ListenerId) {
    // Pause after a failed accept before retrying.
    const ACCEPT_RETRY_DELAY: Duration = Duration::from_millis(1000);

    loop {
        let accept = match l.accept_quic().await {
            Ok(accept) => accept,
            Err(e) => {
                log::info!("QUIC listener error: {e:?}");
                tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                continue;
            }
        };
        // The spawned task needs its own owned handle to the context.
        let scx = scx.clone();
        tokio::spawn(async move {
            log::debug!("QUIC connection from {}", accept.remote_addr);
            let stream = match accept.quic().await {
                Ok(s) => s,
                Err(e) => {
                    log::warn!("QUIC accept error: {e:?}");
                    return;
                }
            };
            // Only one arm runs, so the task-owned `scx` is moved in without another clone.
            match stream.mqtt().await {
                Ok(MqttStream::V3(s)) => {
                    if let Err(e) = v3::process(scx, s, lid).await {
                        log::info!("MQTTv3/QUIC processing error: {e:?}");
                    }
                }
                Ok(MqttStream::V5(s)) => {
                    if let Err(e) = v5::process(scx, s, lid).await {
                        log::info!("MQTTv5/QUIC processing error: {e:?}");
                    }
                }
                Err(e) => {
                    log::info!("MQTT/QUIC version detection failed: {e:?}");
                }
            }
        });
    }
}