use crate::broker::auth::EnhancedAuthStatus;
use crate::broker::storage::{ClientSession, DynamicStorage, StorageBackend};
use crate::error::{MqttError, Result};
use crate::packet::auth::AuthPacket;
use crate::packet::connack::ConnAckPacket;
use crate::packet::connect::ConnectPacket;
use crate::packet::Packet;
use crate::protocol::v5::reason_codes::ReasonCode;
use crate::time::Duration;
use crate::transport::PacketIo;
use crate::types::ProtocolVersion;
use crate::QoS;
use std::sync::Arc;
use tracing::{debug, info, trace, warn};
use crate::broker::router::RoutableMessage;
use super::{AuthState, ClientHandler, PendingConnect};
enum AuthOutcome {
Authenticated(Box<ConnectPacket>),
ConnectDeferred,
Failed(MqttError),
}
impl ClientHandler {
pub(super) async fn validate_protocol_version(&mut self, protocol_version: u8) -> Result<()> {
match protocol_version {
4 | 5 => {
self.protocol_version = protocol_version;
debug!(
protocol_version,
addr = %self.client_addr,
"Client using MQTT v{}",
if protocol_version == 5 { "5.0" } else { "3.1.1" }
);
Ok(())
}
_ => {
info!(
protocol_version,
addr = %self.client_addr,
"Rejecting connection: unsupported protocol version"
);
let connack = ConnAckPacket::new(false, ReasonCode::UnsupportedProtocolVersion);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
Err(MqttError::UnsupportedProtocolVersion)
}
}
}
pub(super) async fn handle_connect(&mut self, mut connect: ConnectPacket) -> Result<()> {
debug!(
client_id = %connect.client_id,
addr = %self.client_addr,
protocol_version = connect.protocol_version,
clean_start = connect.clean_start,
keep_alive = connect.keep_alive,
"Processing CONNECT packet"
);
self.validate_protocol_version(connect.protocol_version)
.await?;
if let Some(redirect) = self.check_load_balancer_redirect(&connect).await? {
return redirect;
}
self.request_problem_information = connect
.properties
.get_request_problem_information()
.unwrap_or(true);
self.request_response_information = connect
.properties
.get_request_response_information()
.unwrap_or(false);
let assigned_client_id = Self::assign_client_id_if_empty(&mut connect);
self.validate_client_id(&connect).await?;
let connect = match self
.handle_authentication(connect, assigned_client_id.clone())
.await?
{
AuthOutcome::Authenticated(connect) => *connect,
AuthOutcome::ConnectDeferred => return Ok(()),
AuthOutcome::Failed(err) => return Err(err),
};
self.validate_will_capabilities(&connect).await?;
self.client_id = Some(connect.client_id.clone());
self.keep_alive = Duration::from_secs(u64::from(connect.keep_alive));
self.client_receive_maximum = connect.properties.get_receive_maximum().unwrap_or(65535);
self.client_max_packet_size = connect.properties.get_maximum_packet_size();
debug!(
client_id = %connect.client_id,
receive_maximum = self.client_receive_maximum,
max_packet_size = ?self.client_max_packet_size,
"Client connection limits"
);
#[cfg(feature = "opentelemetry")]
let session_present = {
use tracing::Instrument;
let span = tracing::info_span!(
"mqtt.session",
mqtt.client_id = %connect.client_id,
);
self.handle_session(&connect).instrument(span).await?
};
#[cfg(not(feature = "opentelemetry"))]
let session_present = self.handle_session(&connect).await?;
let mut connack = if self.protocol_version == 4 {
ConnAckPacket::new_v311(session_present, ReasonCode::Success)
} else {
ConnAckPacket::new(session_present, ReasonCode::Success)
};
if self.protocol_version == 5 {
self.build_connack_properties(&mut connack, assigned_client_id.as_ref());
}
debug!(
client_id = %connect.client_id,
session_present = session_present,
assigned_client_id = ?assigned_client_id,
"Sending CONNACK"
);
trace!("CONNACK properties: {:?}", connack.properties);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
debug!("CONNACK sent successfully");
if session_present {
self.deliver_queued_messages(&connect.client_id).await?;
self.resend_inflight_messages().await?;
}
Ok(())
}
async fn check_load_balancer_redirect(
&mut self,
connect: &ConnectPacket,
) -> Result<Option<Result<()>>> {
let Some(ref lb) = self.config.load_balancer else {
return Ok(None);
};
let generated;
let client_id = if connect.client_id.is_empty() {
use std::sync::atomic::{AtomicU64, Ordering};
static LB_COUNTER: AtomicU64 = AtomicU64::new(0);
generated = format!("auto-{}", LB_COUNTER.fetch_add(1, Ordering::Relaxed));
&generated
} else {
&connect.client_id
};
if let Some(backend) = lb.select_backend(client_id) {
let backend = backend.to_string();
info!(
client_id = %client_id,
backend = %backend,
addr = %self.client_addr,
"Redirecting client to backend"
);
let connack = ConnAckPacket::new(false, ReasonCode::UseAnotherServer)
.with_server_reference(backend);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Ok(Some(Err(MqttError::UseAnotherServer)));
}
Ok(None)
}
fn assign_client_id_if_empty(connect: &mut ConnectPacket) -> Option<String> {
if connect.client_id.is_empty() {
use std::sync::atomic::{AtomicU32, Ordering};
static COUNTER: AtomicU32 = AtomicU32::new(0);
let generated_id = format!("auto-{}", COUNTER.fetch_add(1, Ordering::SeqCst));
debug!("Generated client ID '{}' for empty client ID", generated_id);
connect.client_id.clone_from(&generated_id);
Some(generated_id)
} else {
None
}
}
async fn validate_client_id(&mut self, connect: &ConnectPacket) -> Result<()> {
if !crate::is_path_safe_client_id(&connect.client_id) {
warn!(
client_id = %connect.client_id,
addr = %self.client_addr,
"Rejecting connection: invalid client identifier"
);
let connack = ConnAckPacket::new(false, ReasonCode::ClientIdentifierNotValid);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Err(MqttError::InvalidClientId(connect.client_id.clone()));
}
if connect.client_id.starts_with("cert:") && self.transport.client_cert_info().is_none() {
warn!(
client_id = %connect.client_id,
transport = self.transport.transport_type(),
addr = %self.client_addr,
"Rejecting cert: client ID on connection without verified client certificate"
);
let connack = ConnAckPacket::new(false, ReasonCode::NotAuthorized);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Err(MqttError::AuthenticationFailed);
}
Ok(())
}
async fn handle_authentication(
&mut self,
connect: ConnectPacket,
assigned_client_id: Option<String>,
) -> Result<AuthOutcome> {
let auth_method_prop = connect.properties.get_authentication_method();
let auth_data_prop = connect.properties.get_authentication_data();
if let Some(method) = auth_method_prop {
self.auth_method = Some(method.clone());
if self.auth_provider.supports_enhanced_auth() {
self.client_id = Some(connect.client_id.clone());
let result = self
.auth_provider
.authenticate_enhanced(method, auth_data_prop, &connect.client_id)
.await?;
match result.status {
EnhancedAuthStatus::Success => {
self.auth_state = AuthState::Completed;
self.user_id = result.user_id;
}
EnhancedAuthStatus::Continue => {
self.auth_state = AuthState::InProgress;
self.keep_alive = Duration::from_secs(u64::from(connect.keep_alive));
let auth_packet = AuthPacket::continue_authentication(
result.auth_method,
result.auth_data,
)?;
self.transport
.write_packet(Packet::Auth(auth_packet))
.await?;
self.pending_connect = Some(PendingConnect {
connect,
assigned_client_id,
});
return Ok(AuthOutcome::ConnectDeferred);
}
EnhancedAuthStatus::Failed => {
let mut connack = self.new_connack(false, result.reason_code);
if self.protocol_version == 5 && self.request_problem_information {
if let Some(reason) = result.reason_string {
connack.properties.set_reason_string(reason);
}
}
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Ok(AuthOutcome::Failed(MqttError::AuthenticationFailed));
}
}
} else {
let mut connack = self.new_connack(false, ReasonCode::BadAuthenticationMethod);
if self.protocol_version == 5 && self.request_problem_information {
connack.properties.set_reason_string(
"Server does not support enhanced authentication".to_string(),
);
}
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Ok(AuthOutcome::Failed(MqttError::AuthenticationFailed));
}
} else {
let auth_result = self
.auth_provider
.authenticate(&connect, self.client_addr)
.await?;
if !auth_result.authenticated {
debug!(
client_id = %connect.client_id,
reason = ?auth_result.reason_code,
"Authentication failed"
);
let mut connack = self.new_connack(false, auth_result.reason_code);
if self.protocol_version == 5 && self.request_problem_information {
connack
.properties
.set_reason_string("Authentication failed".to_string());
}
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Ok(AuthOutcome::Failed(MqttError::AuthenticationFailed));
}
self.user_id = auth_result.user_id;
}
Ok(AuthOutcome::Authenticated(Box::new(connect)))
}
fn new_connack(&self, session_present: bool, reason_code: ReasonCode) -> ConnAckPacket {
if self.protocol_version == 4 {
ConnAckPacket::new_v311(session_present, reason_code)
} else {
ConnAckPacket::new(session_present, reason_code)
}
}
async fn validate_will_capabilities(&mut self, connect: &ConnectPacket) -> Result<()> {
if let Some(ref will) = connect.will {
if (will.qos as u8) > self.config.maximum_qos {
info!(
client_id = %connect.client_id,
will_qos = will.qos as u8,
maximum_qos = self.config.maximum_qos,
"Rejecting connection: Will QoS exceeds server maximum"
);
let connack = self.new_connack(false, ReasonCode::QoSNotSupported);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Err(MqttError::ProtocolError(
"Will QoS exceeds server maximum".into(),
));
}
if will.retain && !self.config.retain_available {
info!(
client_id = %connect.client_id,
"Rejecting connection: Will Retain not supported"
);
let connack = self.new_connack(false, ReasonCode::RetainNotSupported);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Err(MqttError::ProtocolError("Retain not supported".into()));
}
}
Ok(())
}
fn build_connack_properties(
&self,
connack: &mut ConnAckPacket,
assigned_client_id: Option<&String>,
) {
if let Some(assigned_id) = assigned_client_id {
debug!("Setting assigned client ID in CONNACK: {}", assigned_id);
connack
.properties
.set_assigned_client_identifier(assigned_id.clone());
}
connack
.properties
.set_topic_alias_maximum(self.config.topic_alias_maximum);
connack
.properties
.set_retain_available(self.config.retain_available);
connack.properties.set_maximum_packet_size(
u32::try_from(self.config.max_packet_size).unwrap_or(u32::MAX),
);
connack
.properties
.set_wildcard_subscription_available(self.config.wildcard_subscription_available);
connack
.properties
.set_subscription_identifier_available(self.config.subscription_identifier_available);
connack
.properties
.set_shared_subscription_available(self.config.shared_subscription_available);
if self.config.maximum_qos < 2 {
connack.properties.set_maximum_qos(self.config.maximum_qos);
}
if let Some(recv_max) = self.config.server_receive_maximum {
connack.properties.set_receive_maximum(recv_max);
}
if let Some(keep_alive) = self.config.server_keep_alive {
connack
.properties
.set_server_keep_alive(u16::try_from(keep_alive.as_secs()).unwrap_or(u16::MAX));
}
if self.request_response_information {
if let Some(ref response_info) = self.config.response_information {
connack
.properties
.set_response_information(response_info.clone());
}
}
}
pub(super) async fn handle_session(&mut self, connect: &ConnectPacket) -> Result<bool> {
let mut session_present = false;
if let Some(storage) = self.storage.clone() {
let existing_session = storage.get_session(&connect.client_id).await?;
if connect.clean_start || existing_session.is_none() {
self.create_new_session(connect, &storage).await?;
} else if let Some(session) = existing_session {
session_present = true;
self.restore_existing_session(connect, session, &storage)
.await?;
}
}
Ok(session_present)
}
async fn create_new_session(
&mut self,
connect: &ConnectPacket,
storage: &Arc<DynamicStorage>,
) -> Result<()> {
let session_expiry = connect.properties.get_session_expiry_interval();
let will_message = connect.will.clone();
if let Some(ref will) = will_message {
debug!(
"Will message present with delay: {:?}",
will.properties.will_delay_interval
);
}
let mut session = ClientSession::new_with_will(
connect.client_id.clone(),
true,
session_expiry,
will_message,
);
session.receive_maximum = self.client_receive_maximum;
session.user_id.clone_from(&self.user_id);
debug!(
"Created new session with will_delay_interval: {:?}",
session.will_delay_interval
);
storage.store_session(session.clone()).await?;
storage
.remove_all_inflight_messages(&connect.client_id)
.await?;
self.session = Some(session);
Ok(())
}
async fn restore_existing_session(
&mut self,
connect: &ConnectPacket,
mut session: ClientSession,
storage: &Arc<DynamicStorage>,
) -> Result<()> {
if session.user_id.as_deref() != self.user_id.as_deref() {
warn!(
client_id = %connect.client_id,
session_user = ?session.user_id,
current_user = ?self.user_id,
"Session user mismatch, rejecting connection"
);
let connack = ConnAckPacket::new(false, ReasonCode::NotAuthorized);
self.transport
.write_packet(Packet::ConnAck(connack))
.await?;
return Err(MqttError::AuthenticationFailed);
}
let mut unauthorized_filters = Vec::new();
for (topic_filter, stored) in &session.subscriptions {
let authorized = self
.auth_provider
.authorize_subscribe(&connect.client_id, self.user_id.as_deref(), topic_filter)
.await;
if !authorized {
warn!(
client_id = %connect.client_id,
topic_filter = %topic_filter,
"Dropping subscription on session restore: no longer authorized"
);
unauthorized_filters.push(topic_filter.clone());
continue;
}
self.router
.subscribe(
connect.client_id.clone(),
topic_filter.clone(),
stored.qos,
stored.subscription_id,
stored.no_local,
stored.retain_as_published,
stored.retain_handling,
ProtocolVersion::try_from(self.protocol_version).unwrap_or_default(),
stored.change_only,
None,
)
.await?;
}
for filter in &unauthorized_filters {
session.subscriptions.remove(filter);
}
self.router
.load_change_only_state(&connect.client_id, session.change_only_state.clone())
.await;
session.will_message.clone_from(&connect.will);
session.will_delay_interval = connect
.will
.as_ref()
.and_then(|w| w.properties.will_delay_interval);
session.receive_maximum = self.client_receive_maximum;
session.user_id.clone_from(&self.user_id);
session.touch();
storage.store_session(session.clone()).await?;
self.session = Some(session);
Ok(())
}
pub(super) async fn deliver_queued_messages(&mut self, client_id: &str) -> Result<()> {
let queued_messages = if let Some(ref storage) = self.storage {
let messages = storage.get_queued_messages(client_id).await?;
storage.remove_queued_messages(client_id).await?;
messages
} else {
Vec::new()
};
if !queued_messages.is_empty() {
info!(
"Delivering {} queued messages to {}",
queued_messages.len(),
client_id
);
for msg in queued_messages {
let mut publish = msg.to_publish_packet();
if publish.qos != QoS::AtMostOnce && publish.packet_id.is_none() {
publish.packet_id = Some(self.next_packet_id());
}
let routable = RoutableMessage {
publish,
target_flow: None,
};
if let Err(e) = self.publish_tx.try_send(routable) {
warn!("Failed to deliver queued message to {}: {:?}", client_id, e);
}
}
}
Ok(())
}
}