use bytes::{Bytes, BytesMut};
use clasp_core::{codec, Message, SetMessage, SignalType, Value};
use dashmap::DashMap;
use mqttbytes::v4::{
ConnAck, ConnectReturnCode, Packet, PingResp, PubAck, Publish, SubAck, SubscribeReasonCode,
};
use mqttbytes::QoS;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::error::{Result, RouterError};
use crate::session::{Session, SessionId};
use crate::state::RouterState;
use crate::subscription::{Subscription, SubscriptionManager};
use clasp_core::security::{TokenValidator, ValidationResult};
#[cfg(feature = "mqtts")]
use tokio_rustls::TlsAcceptor;
/// Configuration for the MQTT bridge server.
///
/// Deserialized with serde; fields carrying `#[serde(default…)]` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MqttServerConfig {
    /// Socket address the TCP listener binds to, e.g. `"0.0.0.0:1883"`.
    pub bind_addr: String,
    /// CLASP address prefix under which MQTT topics are mapped (default `"/mqtt"`).
    #[serde(default = "default_namespace")]
    pub namespace: String,
    /// When `true`, a CONNECT must carry a password accepted by the configured
    /// [`TokenValidator`]; without a validator every connection is refused.
    #[serde(default)]
    pub require_auth: bool,
    /// Optional certificate/key paths for MQTT over TLS.
    #[serde(default)]
    pub tls: Option<TlsConfig>,
    /// Maximum number of established MQTT sessions; `0` means unlimited.
    #[serde(default)]
    pub max_clients: usize,
    /// Idle time in seconds after which the cleanup task drops a session (default 300).
    #[serde(default = "default_session_timeout")]
    pub session_timeout_secs: u64,
}

/// Serde default for [`MqttServerConfig::namespace`].
fn default_namespace() -> String {
    "/mqtt".to_string()
}

/// Serde default for [`MqttServerConfig::session_timeout_secs`], in seconds.
fn default_session_timeout() -> u64 {
    300
}

impl Default for MqttServerConfig {
    /// Listens on all interfaces at port 1883 with no auth, no TLS and no client limit.
    fn default() -> Self {
        Self {
            bind_addr: "0.0.0.0:1883".to_string(),
            // Reuse the serde defaults so programmatic and deserialized configs agree.
            namespace: default_namespace(),
            require_auth: false,
            tls: None,
            max_clients: 0,
            session_timeout_secs: default_session_timeout(),
        }
    }
}
/// File paths of the certificate and private key to use for MQTT over TLS.
// NOTE(review): key/cert format (presumably PEM) is not visible here — confirm
// against whatever loads these files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Path to the server certificate (chain).
    pub cert_path: String,
    /// Path to the private key matching `cert_path`.
    pub key_path: String,
}
/// Per-connection MQTT state; stored in `mqtt_sessions` keyed by client id.
struct MqttSession {
    /// CLASP session created for this client at CONNECT time.
    clasp_session_id: SessionId,
    /// Client identifier taken from the CONNECT packet.
    client_id: String,
    /// Remote address of the TCP connection.
    peer_addr: SocketAddr,
    // NOTE(review): never written in the visible code, so it is always empty;
    // presumably intended to map MQTT topic filters to subscription ids for UNSUBSCRIBE.
    mqtt_subscriptions: HashMap<String, u32>,
    /// Counter backing `next_subscription_id`; starts at 1.
    next_sub_id: AtomicU32,
    /// When the session was created or last received data (updated by `touch`).
    last_activity: RwLock<Instant>,
    /// Sender half of the connection's outbound packet channel.
    sender: mpsc::Sender<Bytes>,
}
impl MqttSession {
fn new(
clasp_session_id: SessionId,
client_id: String,
peer_addr: SocketAddr,
sender: mpsc::Sender<Bytes>,
) -> Self {
Self {
clasp_session_id,
client_id,
peer_addr,
mqtt_subscriptions: HashMap::new(),
next_sub_id: AtomicU32::new(1),
last_activity: RwLock::new(Instant::now()),
sender,
}
}
fn touch(&self) {
*self.last_activity.write() = Instant::now();
}
fn idle_duration(&self) -> Duration {
self.last_activity.read().elapsed()
}
fn next_subscription_id(&self) -> u32 {
self.next_sub_id.fetch_add(1, Ordering::Relaxed)
}
}
/// MQTT 3.1.1 front-end (via `mqttbytes::v4`) that bridges MQTT clients onto the CLASP router.
pub struct MqttServerAdapter {
    /// Listener, auth and timeout settings.
    config: MqttServerConfig,
    /// CLASP session table; each MQTT connection registers a `Session` here.
    sessions: Arc<DashMap<SessionId, Arc<Session>>>,
    /// Subscription registry that MQTT SUBSCRIBEs are added to.
    subscriptions: Arc<SubscriptionManager>,
    /// Router state that MQTT PUBLISHes are applied to as SETs.
    state: Arc<RouterState>,
    /// Established MQTT sessions keyed by client id.
    mqtt_sessions: Arc<DashMap<String, Arc<MqttSession>>>,
    /// Set by `serve`, cleared by `stop`; polled by the accept loop, the cleanup
    /// task and every connection task.
    running: Arc<RwLock<bool>>,
    /// Validates CONNECT passwords when `config.require_auth` is set.
    validator: Option<Arc<dyn TokenValidator>>,
    // NOTE(review): initialised to `None` and never read in the visible code.
    #[cfg(feature = "mqtts")]
    tls_acceptor: Option<TlsAcceptor>,
}
impl MqttServerAdapter {
    /// Creates an adapter over the router's shared session table, subscription
    /// manager and state. Nothing is bound until [`serve`](Self::serve) runs.
    pub fn new(
        config: MqttServerConfig,
        sessions: Arc<DashMap<SessionId, Arc<Session>>>,
        subscriptions: Arc<SubscriptionManager>,
        state: Arc<RouterState>,
    ) -> Self {
        Self {
            config,
            sessions,
            subscriptions,
            state,
            mqtt_sessions: Arc::new(DashMap::new()),
            running: Arc::new(RwLock::new(false)),
            validator: None,
            #[cfg(feature = "mqtts")]
            tls_acceptor: None,
        }
    }

    /// Sets the validator used to check CONNECT passwords when `require_auth` is enabled.
    pub fn with_validator(mut self, validator: Arc<dyn TokenValidator>) -> Self {
        self.validator = Some(validator);
        self
    }

    /// Binds `config.bind_addr` and accepts clients until [`stop`](Self::stop) is called.
    ///
    /// Each accepted connection runs on its own task. `running` is only re-checked
    /// after `accept` returns, so the loop exits on the next connection or accept
    /// error after `stop`.
    ///
    /// # Errors
    /// Returns `RouterError::Transport` if the listener cannot be bound.
    pub async fn serve(&self) -> Result<()> {
        // Errors such as file-descriptor exhaustion usually persist; pausing keeps
        // the accept loop from spinning at full speed while they do.
        const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);

        let listener = TcpListener::bind(&self.config.bind_addr)
            .await
            .map_err(|e| RouterError::Transport(e.into()))?;
        info!("MQTT server listening on {}", self.config.bind_addr);

        // Nothing in this accept loop performs a TLS handshake; make the plaintext
        // fallback visible instead of silently ignoring the setting.
        if self.config.tls.is_some() {
            warn!(
                "MQTT TLS is configured but not applied by this listener; connections on {} are plaintext",
                self.config.bind_addr
            );
        }

        *self.running.write() = true;
        self.start_cleanup_task();

        while *self.running.read() {
            match listener.accept().await {
                Ok((stream, peer_addr)) => {
                    if self.config.max_clients > 0
                        && self.mqtt_sessions.len() >= self.config.max_clients
                    {
                        warn!(
                            "Rejecting MQTT connection from {}: max clients reached",
                            peer_addr
                        );
                        // Dropping `stream` closes the socket.
                        continue;
                    }
                    info!("MQTT connection from {}", peer_addr);
                    self.spawn_connection_handler(stream, peer_addr);
                }
                Err(e) => {
                    error!("MQTT accept error: {}", e);
                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                }
            }
        }
        Ok(())
    }

    /// Spawns a task that every 30 s drops MQTT sessions idle for longer than
    /// `session_timeout_secs`, together with their CLASP session and subscriptions.
    /// The task exits at the first check after `running` is cleared.
    fn start_cleanup_task(&self) {
        let mqtt_sessions = Arc::clone(&self.mqtt_sessions);
        let clasp_sessions = Arc::clone(&self.sessions);
        let subscriptions = Arc::clone(&self.subscriptions);
        let running = Arc::clone(&self.running);
        let timeout = Duration::from_secs(self.config.session_timeout_secs);
        tokio::spawn(async move {
            let check_interval = Duration::from_secs(30);
            loop {
                tokio::time::sleep(check_interval).await;
                if !*running.read() {
                    break;
                }
                let timed_out: Vec<String> = mqtt_sessions
                    .iter()
                    .filter(|entry| entry.value().idle_duration() > timeout)
                    .map(|entry| entry.key().clone())
                    .collect();
                for client_id in timed_out {
                    // Re-check idleness atomically with the removal: since the scan the
                    // client may have sent data, or reconnected under the same id.
                    let removed = mqtt_sessions
                        .remove_if(&client_id, |_, session| session.idle_duration() > timeout);
                    if let Some((_, mqtt_session)) = removed {
                        info!(
                            "MQTT session {} timed out after {:?}",
                            client_id,
                            mqtt_session.idle_duration()
                        );
                        clasp_sessions.remove(&mqtt_session.clasp_session_id);
                        subscriptions.remove_session(&mqtt_session.clasp_session_id);
                        // NOTE(review): the connection task is not signalled, so its
                        // socket stays open until the client goes away.
                    }
                }
            }
        });
    }

    /// Runs `handle_mqtt_connection` for one accepted socket on a new task;
    /// its terminating error, if any, is logged at debug level.
    fn spawn_connection_handler(&self, stream: TcpStream, peer_addr: SocketAddr) {
        let config = self.config.clone();
        let sessions = Arc::clone(&self.sessions);
        let subscriptions = Arc::clone(&self.subscriptions);
        let state = Arc::clone(&self.state);
        let mqtt_sessions = Arc::clone(&self.mqtt_sessions);
        let running = Arc::clone(&self.running);
        let validator = self.validator.clone();
        tokio::spawn(async move {
            if let Err(e) = handle_mqtt_connection(
                stream,
                peer_addr,
                config,
                sessions,
                subscriptions,
                state,
                mqtt_sessions,
                running,
                validator,
            )
            .await
            {
                debug!("MQTT connection {} ended: {}", peer_addr, e);
            }
        });
    }

    /// Clears the `running` flag; the accept loop, cleanup task and connection
    /// tasks observe it at their next check.
    pub fn stop(&self) {
        *self.running.write() = false;
    }

    /// Number of currently registered MQTT sessions.
    pub fn client_count(&self) -> usize {
        self.mqtt_sessions.len()
    }
}
async fn handle_mqtt_connection(
mut stream: TcpStream,
peer_addr: SocketAddr,
config: MqttServerConfig,
clasp_sessions: Arc<DashMap<SessionId, Arc<Session>>>,
subscriptions: Arc<SubscriptionManager>,
state: Arc<RouterState>,
mqtt_sessions: Arc<DashMap<String, Arc<MqttSession>>>,
running: Arc<RwLock<bool>>,
validator: Option<Arc<dyn TokenValidator>>,
) -> Result<()> {
let mut read_buf = BytesMut::with_capacity(4096);
let (tx, mut rx) = mpsc::channel::<Bytes>(100);
let connect = loop {
if !*running.read() {
return Ok(());
}
let n = stream.read_buf(&mut read_buf).await?;
if n == 0 {
return Err(RouterError::Protocol("Connection closed".into()));
}
match mqttbytes::v4::read(&mut read_buf, 65535) {
Ok(Packet::Connect(connect)) => break connect,
Ok(other) => {
warn!("Expected CONNECT, got {:?}", other);
return Err(RouterError::Protocol("Expected CONNECT packet".into()));
}
Err(mqttbytes::Error::InsufficientBytes(_)) => {
continue;
}
Err(e) => {
return Err(RouterError::Protocol(format!("MQTT parse error: {}", e)));
}
}
};
let client_id = connect.client_id.clone();
info!("MQTT CONNECT from {} (client_id: {})", peer_addr, client_id);
if config.require_auth {
let login = match &connect.login {
Some(login) => login,
None => {
let connack = ConnAck {
session_present: false,
code: ConnectReturnCode::BadUserNamePassword,
};
let mut buf = BytesMut::new();
connack.write(&mut buf)?;
stream.write_all(&buf).await?;
return Err(RouterError::Auth("Authentication required".into()));
}
};
let token = &login.password;
if let Some(ref validator) = validator {
match validator.validate(token) {
ValidationResult::Valid(_token_info) => {
debug!("MQTT client {} authenticated successfully", client_id);
}
ValidationResult::Invalid(reason) => {
warn!("MQTT auth failed for {}: {}", client_id, reason);
let connack = ConnAck {
session_present: false,
code: ConnectReturnCode::BadUserNamePassword,
};
let mut buf = BytesMut::new();
connack.write(&mut buf)?;
stream.write_all(&buf).await?;
return Err(RouterError::Auth(format!("Invalid token: {}", reason)));
}
ValidationResult::NotMyToken => {
warn!(
"MQTT auth failed for {}: unrecognized token format",
client_id
);
let connack = ConnAck {
session_present: false,
code: ConnectReturnCode::BadUserNamePassword,
};
let mut buf = BytesMut::new();
connack.write(&mut buf)?;
stream.write_all(&buf).await?;
return Err(RouterError::Auth("Unrecognized token format".into()));
}
ValidationResult::Expired => {
warn!("MQTT auth failed for {}: token expired", client_id);
let connack = ConnAck {
session_present: false,
code: ConnectReturnCode::NotAuthorized,
};
let mut buf = BytesMut::new();
connack.write(&mut buf)?;
stream.write_all(&buf).await?;
return Err(RouterError::Auth("Token expired".into()));
}
}
} else {
warn!("MQTT require_auth enabled but no validator configured - rejecting connection");
let connack = ConnAck {
session_present: false,
code: ConnectReturnCode::ServiceUnavailable,
};
let mut buf = BytesMut::new();
connack.write(&mut buf)?;
stream.write_all(&buf).await?;
return Err(RouterError::Auth("No token validator configured".into()));
}
}
let mqtt_sender = MqttTransportSender::new(tx.clone());
let clasp_session = Arc::new(Session::new(
Arc::new(mqtt_sender),
format!("mqtt:{}", client_id),
vec!["mqtt".to_string()],
));
let clasp_session_id = clasp_session.id.clone();
clasp_sessions.insert(clasp_session_id.clone(), clasp_session);
let mqtt_session = Arc::new(MqttSession::new(
clasp_session_id.clone(),
client_id.clone(),
peer_addr,
tx,
));
mqtt_sessions.insert(client_id.clone(), Arc::clone(&mqtt_session));
let connack = ConnAck {
session_present: false,
code: ConnectReturnCode::Success,
};
let mut buf = BytesMut::new();
connack.write(&mut buf)?;
stream.write_all(&buf).await?;
info!(
"MQTT session established: {} -> {}",
client_id, clasp_session_id
);
loop {
if !*running.read() {
break;
}
tokio::select! {
result = stream.read_buf(&mut read_buf) => {
match result {
Ok(0) => {
info!("MQTT client {} disconnected", client_id);
break;
}
Ok(_) => {
mqtt_session.touch();
loop {
match mqttbytes::v4::read(&mut read_buf, 65535) {
Ok(packet) => {
if let Err(e) = handle_mqtt_packet(
&packet,
&mqtt_session,
&config,
&subscriptions,
&state,
&clasp_sessions,
&mut stream,
).await {
warn!("Error handling MQTT packet: {}", e);
}
if matches!(packet, Packet::Disconnect) {
info!("MQTT client {} sent DISCONNECT", client_id);
break;
}
}
Err(mqttbytes::Error::InsufficientBytes(_)) => {
break;
}
Err(e) => {
warn!("MQTT parse error: {}", e);
break;
}
}
}
}
Err(e) => {
error!("MQTT read error: {}", e);
break;
}
}
}
Some(data) = rx.recv() => {
if let Err(e) = stream.write_all(&data).await {
error!("MQTT write error: {}", e);
break;
}
}
}
}
mqtt_sessions.remove(&client_id);
clasp_sessions.remove(&clasp_session_id);
subscriptions.remove_session(&clasp_session_id);
info!("MQTT session {} cleaned up", client_id);
Ok(())
}
/// Handles one decoded MQTT packet for an established session and writes any
/// direct reply (SUBACK, PUBACK, PUBREC, PUBCOMP, UNSUBACK, PINGRESP) to `stream`.
///
/// - SUBSCRIBE: each filter becomes a CLASP subscription under `config.namespace`.
///   Current matching state is replayed to the client as QoS 0 PUBLISHes.
/// - PUBLISH: the payload is applied as a CLASP SET and forwarded to other matching
///   CLASP subscribers. QoS 1/2 publishes are always acknowledged.
///
/// # Errors
/// Returns an error if encoding a reply or writing to `stream` fails.
async fn handle_mqtt_packet(
    packet: &Packet,
    mqtt_session: &Arc<MqttSession>,
    config: &MqttServerConfig,
    subscriptions: &Arc<SubscriptionManager>,
    state: &Arc<RouterState>,
    clasp_sessions: &Arc<DashMap<SessionId, Arc<Session>>>,
    stream: &mut TcpStream,
) -> Result<()> {
    match packet {
        Packet::Subscribe(subscribe) => {
            debug!(
                "MQTT SUBSCRIBE from {}: {:?}",
                mqtt_session.client_id, subscribe.filters
            );
            let mut return_codes = Vec::with_capacity(subscribe.filters.len());
            for filter in &subscribe.filters {
                let topic_filter = &filter.path;
                let clasp_pattern = mqtt_topic_to_clasp_pattern(&config.namespace, topic_filter);
                let sub_id = mqtt_session.next_subscription_id();
                match Subscription::new(
                    sub_id,
                    mqtt_session.clasp_session_id.clone(),
                    &clasp_pattern,
                    vec![],
                    Default::default(),
                ) {
                    Ok(subscription) => {
                        subscriptions.add(subscription);
                        // NOTE(review): the requested QoS is granted as-is, but every
                        // PUBLISH this bridge sends to clients is QoS 0.
                        return_codes.push(SubscribeReasonCode::Success(filter.qos));
                        debug!(
                            "MQTT subscription {} -> CLASP pattern {}",
                            topic_filter, clasp_pattern
                        );
                        // Replay the current value of every matching parameter so the
                        // subscriber starts from the existing state.
                        let snapshot = state.snapshot(&clasp_pattern);
                        for param in snapshot.params {
                            let mqtt_topic =
                                clasp_address_to_mqtt_topic(&config.namespace, &param.address);
                            let payload = value_to_mqtt_payload(&param.value);
                            let publish = Publish::new(&mqtt_topic, QoS::AtMostOnce, payload);
                            let mut buf = BytesMut::new();
                            publish.write(&mut buf)?;
                            stream.write_all(&buf).await?;
                        }
                    }
                    Err(e) => {
                        warn!("Invalid MQTT subscription pattern: {}", e);
                        return_codes.push(SubscribeReasonCode::Failure);
                    }
                }
            }
            let suback = SubAck {
                pkid: subscribe.pkid,
                return_codes,
            };
            let mut buf = BytesMut::new();
            suback.write(&mut buf)?;
            stream.write_all(&buf).await?;
        }
        Packet::Publish(publish) => {
            debug!(
                "MQTT PUBLISH from {}: {} ({} bytes)",
                mqtt_session.client_id,
                publish.topic,
                publish.payload.len()
            );
            let clasp_address = mqtt_topic_to_clasp_address(&config.namespace, &publish.topic);
            let set_msg = SetMessage {
                address: clasp_address.clone(),
                value: mqtt_payload_to_value(&publish.payload),
                revision: None,
                lock: false,
                unlock: false,
                ttl: None,
            };
            match state.apply_set(&set_msg, &mqtt_session.clasp_session_id) {
                Ok(revision) => {
                    let subscribers =
                        subscriptions.find_subscribers(&clasp_address, Some(SignalType::Param));
                    let mut updated_set = set_msg.clone();
                    updated_set.revision = Some(revision);
                    if let Ok(bytes) = codec::encode(&Message::Set(updated_set)) {
                        for sub_session_id in subscribers {
                            // Don't echo the update back to the publishing client.
                            if sub_session_id != mqtt_session.clasp_session_id {
                                if let Some(sub_session) = clasp_sessions.get(&sub_session_id) {
                                    // Best effort: a failed try_send is ignored.
                                    let _ = sub_session.try_send(bytes.clone());
                                }
                            }
                        }
                    }
                }
                Err(e) => {
                    warn!(
                        "MQTT PUBLISH from {} to {} was not applied: {:?}",
                        mqtt_session.client_id, publish.topic, e
                    );
                }
            }
            // Acknowledge whatever the outcome. MQTT 3.1.1 acks carry no error code,
            // and withholding one leaves the message in the client's in-flight window.
            match publish.qos {
                QoS::AtMostOnce => {}
                QoS::AtLeastOnce => {
                    let puback = PubAck { pkid: publish.pkid };
                    let mut buf = BytesMut::new();
                    puback.write(&mut buf)?;
                    stream.write_all(&buf).await?;
                }
                QoS::ExactlyOnce => {
                    // First half of the QoS 2 handshake. The message has already been
                    // applied, so PUBREL only needs a PUBCOMP. A re-sent duplicate is
                    // applied again — presumably harmless for last-writer-wins SETs.
                    let pubrec = mqttbytes::v4::PubRec { pkid: publish.pkid };
                    let mut buf = BytesMut::new();
                    pubrec.write(&mut buf)?;
                    stream.write_all(&buf).await?;
                }
            }
        }
        Packet::PubRel(pubrel) => {
            // Completes the QoS 2 exchange started by PUBREC above.
            let pubcomp = mqttbytes::v4::PubComp { pkid: pubrel.pkid };
            let mut buf = BytesMut::new();
            pubcomp.write(&mut buf)?;
            stream.write_all(&buf).await?;
        }
        Packet::Unsubscribe(unsubscribe) => {
            debug!(
                "MQTT UNSUBSCRIBE from {}: {:?}",
                mqtt_session.client_id, unsubscribe.topics
            );
            // NOTE(review): only the UNSUBACK is sent; the matching CLASP
            // subscriptions are not removed, so deliveries continue until the
            // session ends.
            let unsuback = mqttbytes::v4::UnsubAck {
                pkid: unsubscribe.pkid,
            };
            let mut buf = BytesMut::new();
            unsuback.write(&mut buf)?;
            stream.write_all(&buf).await?;
        }
        Packet::PingReq => {
            debug!("MQTT PINGREQ from {}", mqtt_session.client_id);
            let pingresp = PingResp;
            let mut buf = BytesMut::new();
            pingresp.write(&mut buf)?;
            stream.write_all(&buf).await?;
        }
        Packet::Disconnect => {
            info!("MQTT DISCONNECT from {}", mqtt_session.client_id);
        }
        other => {
            debug!("Unhandled MQTT packet: {:?}", other);
        }
    }
    Ok(())
}
/// Converts an MQTT topic filter into a CLASP address pattern under `namespace`.
///
/// MQTT's single-level wildcard `+` becomes `*` and the multi-level wildcard `#`
/// becomes `**`. The `/` level separator is common to both schemes and is kept,
/// e.g. `("/mqtt", "sensors/+/temp")` → `"/mqtt/sensors/*/temp"`.
fn mqtt_topic_to_clasp_pattern(namespace: &str, topic: &str) -> String {
    // One extra byte per `#` is the worst case; reserve a little slack up front.
    let mut pattern = String::with_capacity(namespace.len() + topic.len() + 2);
    pattern.push_str(namespace);
    pattern.push('/');
    for ch in topic.chars() {
        match ch {
            '+' => pattern.push('*'),
            '#' => pattern.push_str("**"),
            other => pattern.push(other),
        }
    }
    pattern
}
/// Maps a concrete MQTT topic to the CLASP address `"{namespace}/{topic}"`.
fn mqtt_topic_to_clasp_address(namespace: &str, topic: &str) -> String {
    let mut address = String::with_capacity(namespace.len() + 1 + topic.len());
    address.push_str(namespace);
    address.push('/');
    address.push_str(topic);
    address
}
/// Maps a CLASP address back to an MQTT topic by removing `namespace` and any
/// leading slashes.
///
/// The namespace is stripped only when it ends at a level boundary. With
/// namespace `/mqtt`, the address `/mqttx/a` maps to `mqttx/a`, not `x/a`.
/// Addresses outside the namespace keep their full path minus leading slashes.
fn clasp_address_to_mqtt_topic(namespace: &str, address: &str) -> String {
    let relative = match address.strip_prefix(namespace) {
        Some(rest) if rest.is_empty() || rest.starts_with('/') || namespace.ends_with('/') => {
            rest
        }
        _ => address,
    };
    relative.trim_start_matches('/').to_string()
}
/// Interprets an MQTT payload as a CLASP [`Value`].
///
/// UTF-8 payloads are tried in order:
/// 1. as JSON (covering `null`, booleans, numbers, strings, arrays and objects);
/// 2. as an `f64`, which accepts spellings JSON rejects, such as `NaN` or `+1.5`;
/// 3. otherwise the payload is kept as a string.
///
/// Non-UTF-8 payloads become [`Value::Bytes`].
fn mqtt_payload_to_value(payload: &[u8]) -> Value {
    let text = match std::str::from_utf8(payload) {
        Ok(text) => text,
        Err(_) => return Value::Bytes(payload.to_vec()),
    };
    if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
        return json_to_clasp_value(json);
    }
    // `true`/`false` are valid JSON and were handled above, so no separate
    // boolean check is needed here.
    match text.parse::<f64>() {
        Ok(f) => Value::Float(f),
        Err(_) => Value::String(text.to_string()),
    }
}
fn json_to_clasp_value(json: serde_json::Value) -> Value {
match json {
serde_json::Value::Null => Value::Null,
serde_json::Value::Bool(b) => Value::Bool(b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Value::Int(i)
} else if let Some(f) = n.as_f64() {
Value::Float(f)
} else {
Value::Null
}
}
serde_json::Value::String(s) => Value::String(s),
serde_json::Value::Array(arr) => {
Value::Array(arr.into_iter().map(json_to_clasp_value).collect())
}
serde_json::Value::Object(obj) => {
let map: HashMap<String, Value> = obj
.into_iter()
.map(|(k, v)| (k, json_to_clasp_value(v)))
.collect();
Value::Map(map)
}
}
}
/// Encodes a CLASP [`Value`] as an MQTT payload.
///
/// - Scalars use their plain-text form: `null`, `true`/`false`, decimal numbers.
/// - Strings are sent as UTF-8 and bytes verbatim.
/// - Arrays and maps are sent as JSON, falling back to `null` if serialization fails.
fn value_to_mqtt_payload(value: &Value) -> Vec<u8> {
    match value {
        Value::Bytes(raw) => raw.clone(),
        Value::String(s) => Vec::from(s.as_bytes()),
        Value::Null => b"null".to_vec(),
        Value::Bool(true) => b"true".to_vec(),
        Value::Bool(false) => b"false".to_vec(),
        Value::Int(n) => n.to_string().into_bytes(),
        Value::Float(x) => x.to_string().into_bytes(),
        Value::Array(_) | Value::Map(_) => match serde_json::to_vec(value) {
            Ok(json) => json,
            Err(_) => b"null".to_vec(),
        },
    }
}
/// CLASP `TransportSender` that delivers to one MQTT client.
///
/// Outbound CLASP frames are translated to MQTT PUBLISH packets and pushed onto
/// the connection's outbound channel, whose receiver the connection task drains
/// into the socket.
struct MqttTransportSender {
    /// Sender half of the connection's outbound packet channel.
    tx: mpsc::Sender<Bytes>,
}
impl MqttTransportSender {
    /// Wraps the connection's outbound channel.
    fn new(tx: mpsc::Sender<Bytes>) -> Self {
        Self { tx }
    }
}
/// Decodes an outbound CLASP frame and re-encodes it as an MQTT PUBLISH packet.
///
/// Returns `None` for frames that fail to decode or have no MQTT form; the
/// sender silently drops those.
fn encode_for_mqtt(data: &Bytes) -> Option<Bytes> {
    let (msg, _) = codec::decode(data).ok()?;
    clasp_to_mqtt_publish(&msg)
}

#[async_trait::async_trait]
impl clasp_transport::TransportSender for MqttTransportSender {
    /// Queues the translated packet, waiting for channel capacity if necessary.
    async fn send(&self, data: Bytes) -> std::result::Result<(), clasp_transport::TransportError> {
        match encode_for_mqtt(&data) {
            Some(packet) => self
                .tx
                .send(packet)
                .await
                .map_err(|e| clasp_transport::TransportError::SendFailed(e.to_string())),
            None => Ok(()),
        }
    }

    /// Queues the translated packet without waiting; fails if the channel is full or closed.
    fn try_send(&self, data: Bytes) -> std::result::Result<(), clasp_transport::TransportError> {
        match encode_for_mqtt(&data) {
            Some(packet) => self
                .tx
                .try_send(packet)
                .map_err(|e| clasp_transport::TransportError::SendFailed(e.to_string())),
            None => Ok(()),
        }
    }

    /// True while the connection task still holds the receiving end.
    fn is_connected(&self) -> bool {
        !self.tx.is_closed()
    }

    /// No-op: the socket is owned and closed by the connection task.
    async fn close(&self) -> std::result::Result<(), clasp_transport::TransportError> {
        Ok(())
    }
}
/// Translates a CLASP message into an encoded QoS 0 MQTT PUBLISH packet.
///
/// Only `Set` messages and value-carrying `Publish` messages have an MQTT form.
/// Everything else, including snapshots, yields `None`, as does a packet that
/// fails to encode.
fn clasp_to_mqtt_publish(msg: &Message) -> Option<Bytes> {
    let (address, value) = match msg {
        Message::Set(set) => (&set.address, &set.value),
        Message::Publish(pub_msg) => (&pub_msg.address, pub_msg.value.as_ref()?),
        _ => return None,
    };
    // NOTE(review): the `/mqtt/` prefix is hard-coded rather than taken from
    // `MqttServerConfig::namespace`. With a custom namespace, clients receive topics
    // that still contain it. Threading the namespace into `MqttTransportSender`
    // would fix this.
    let topic = address
        .strip_prefix("/mqtt/")
        .unwrap_or_else(|| address.trim_start_matches('/'));
    let publish = Publish::new(topic, QoS::AtMostOnce, value_to_mqtt_payload(value));
    let mut buf = BytesMut::new();
    publish.write(&mut buf).ok()?;
    Some(buf.freeze())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// MQTT wildcards map onto CLASP glob segments under the namespace.
    #[test]
    fn test_topic_to_clasp_pattern() {
        let cases = [
            ("sensors/temp", "/mqtt/sensors/temp"),
            ("sensors/+/temp", "/mqtt/sensors/*/temp"),
            ("sensors/#", "/mqtt/sensors/**"),
        ];
        for &(topic, expected) in cases.iter() {
            assert_eq!(mqtt_topic_to_clasp_pattern("/mqtt", topic), expected);
        }
    }

    /// The namespace prefix is removed when mapping back to an MQTT topic.
    #[test]
    fn test_address_to_topic() {
        let topic = clasp_address_to_mqtt_topic("/mqtt", "/mqtt/sensors/temp");
        assert_eq!(topic, "sensors/temp");
    }

    /// Payloads are typed as float, bool, map or string as appropriate.
    #[test]
    fn test_payload_parsing() {
        let float = mqtt_payload_to_value(b"42.5");
        assert!(matches!(float, Value::Float(f) if (f - 42.5).abs() < 0.001));

        let boolean = mqtt_payload_to_value(b"true");
        assert!(matches!(boolean, Value::Bool(true)));

        let object = mqtt_payload_to_value(b"{\"temp\": 25}");
        assert!(matches!(object, Value::Map(_)));

        let text = mqtt_payload_to_value(b"hello world");
        assert!(matches!(text, Value::String(s) if s == "hello world"));
    }
}