use crate::callback::CallbackManager;
use crate::error::{MqttError, Result};
use crate::packet::publish::PublishPacket;
use crate::packet::Packet;
use crate::protocol::v5::properties::Properties;
use crate::session::SessionState;
use crate::transport::PacketIo;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::{interval, Duration};
/// Reads packets from the shared transport in a loop and dispatches each one to
/// [`handle_incoming_packet`].
///
/// The task exits (after logging) on the first read error, or on the first error
/// returned by the handler. That includes the `ConnectionError` produced when the
/// server sends DISCONNECT.
///
/// NOTE(review): the transport mutex is held for the entire `read_packet` await.
/// That call presumably blocks until a full packet arrives, so other users of the
/// transport (e.g. `keepalive_task` sending PINGREQ) must wait for inbound
/// traffic. Splitting the transport into independent read/write halves would
/// remove that coupling.
pub async fn packet_reader_task(
    transport: Arc<tokio::sync::Mutex<crate::transport::TransportType>>,
    session: Arc<RwLock<SessionState>>,
    callback_manager: Arc<CallbackManager>,
    protocol_version: u8,
) {
    loop {
        // Bind the read result in its own statement so the MutexGuard temporary is
        // dropped at the `;`. Inside a `match` scrutinee the guard would live until
        // the end of the `match`. `handle_incoming_packet` re-locks the same
        // (non-reentrant) tokio mutex to send PUBACK/PUBREC/PUBREL/PUBCOMP, so that
        // would deadlock.
        let read_result = transport.lock().await.read_packet(protocol_version).await;

        let packet = match read_result {
            Ok(packet) => packet,
            Err(e) => {
                tracing::error!(error = %e, "Error reading packet");
                break;
            }
        };

        if let Err(e) =
            handle_incoming_packet(packet, &transport, &session, &callback_manager).await
        {
            tracing::error!(error = %e, "Error handling packet");
            break;
        }
    }
}
/// Periodically sends PINGREQ over the shared transport to keep the connection alive.
///
/// The first PINGREQ goes out one full `keepalive_interval` after the task starts.
/// Missed ticks are skipped rather than sent in a burst. The task ends (after
/// logging) on the first write error.
///
/// A zero `keepalive_interval` makes the task return immediately. In MQTT a
/// keep-alive of 0 disables the mechanism, and `tokio::time::interval` panics
/// when given a zero period.
pub async fn keepalive_task(
    transport: Arc<tokio::sync::Mutex<crate::transport::TransportType>>,
    keepalive_interval: Duration,
) {
    if keepalive_interval.is_zero() {
        tracing::debug!("Keepalive interval is zero; not sending PINGREQ");
        return;
    }

    let mut ticker = interval(keepalive_interval);
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    // The first tick of a tokio interval completes immediately; consume it so no
    // PINGREQ is sent right at startup.
    ticker.tick().await;

    loop {
        ticker.tick().await;
        if let Err(e) = transport.lock().await.write_packet(Packet::PingReq).await {
            tracing::error!(error = %e, "Error sending PINGREQ");
            break;
        }
    }
}
/// Dispatches one decoded packet from the server to its handler.
///
/// - PUBLISH / PUBACK / PUBREC / PUBREL / PUBCOMP go to the matching `handle_*` helper.
/// - DISCONNECT is logged and turned into `MqttError::ConnectionError`, which makes
///   the reader loop stop.
/// - PINGRESP and every other packet type are accepted without any action.
///
/// # Errors
/// Propagates errors from the per-packet handlers and reports a server DISCONNECT
/// as `MqttError::ConnectionError`.
pub async fn handle_incoming_packet(
    packet: Packet,
    transport: &Arc<tokio::sync::Mutex<crate::transport::TransportType>>,
    session: &Arc<RwLock<SessionState>>,
    callback_manager: &Arc<CallbackManager>,
) -> Result<()> {
    match packet {
        Packet::Disconnect(disconnect) => {
            tracing::info!(reason_code = ?disconnect.reason_code, "Server sent DISCONNECT");
            Err(MqttError::ConnectionError(String::from(
                "Server disconnected",
            )))
        }
        Packet::Publish(inbound) => {
            handle_publish(inbound, transport, session, callback_manager).await
        }
        Packet::PubAck(ack) => handle_puback(ack.packet_id, session).await,
        Packet::PubRec(received) => handle_pubrec(received.packet_id, transport, session).await,
        Packet::PubRel(release) => handle_pubrel(release.packet_id, transport, session).await,
        Packet::PubComp(complete) => handle_pubcomp(complete.packet_id, session).await,
        // PINGRESP only confirms liveness; nothing else is handled here.
        _ => Ok(()),
    }
}
/// Acknowledges an inbound PUBLISH according to its QoS, then routes it to the
/// registered callbacks.
///
/// - QoS 0: no acknowledgement.
/// - QoS 1: writes a PUBACK (reason `Success`) for the packet id.
/// - QoS 2: writes a PUBREC (reason `Success`), then records the id in the
///   session via `store_pubrec`.
///
/// A QoS 1/2 publish that carries no packet id is not acknowledged, but it is
/// still routed.
///
/// # Errors
/// Returns the transport error if writing the acknowledgement fails. In that case
/// the message is not routed and no PUBREC is recorded.
///
/// NOTE(review): a retransmitted QoS 2 PUBLISH (same packet id, before PUBREL) is
/// routed again here. MQTT requires a QoS 2 receiver not to deliver such a
/// duplicate onward. Confirm whether `SessionState` exposes a pending-PUBREC
/// lookup that could gate `route_message`.
async fn handle_publish(
    publish: PublishPacket,
    transport: &Arc<tokio::sync::Mutex<crate::transport::TransportType>>,
    session: &Arc<RwLock<SessionState>>,
    callback_manager: &Arc<CallbackManager>,
) -> Result<()> {
    // Pair the reply with the id to record as "PUBREC sent" (QoS 2 only).
    let reply = match (&publish.qos, publish.packet_id) {
        (crate::QoS::AtLeastOnce, Some(packet_id)) => Some((
            Packet::PubAck(crate::packet::puback::PubAckPacket {
                packet_id,
                reason_code: crate::protocol::v5::reason_codes::ReasonCode::Success,
                properties: Properties::default(),
            }),
            None,
        )),
        (crate::QoS::ExactlyOnce, Some(packet_id)) => Some((
            Packet::PubRec(crate::packet::pubrec::PubRecPacket {
                packet_id,
                reason_code: crate::protocol::v5::reason_codes::ReasonCode::Success,
                properties: Properties::default(),
            }),
            Some(packet_id),
        )),
        // QoS 0, or QoS 1/2 without a packet id: nothing to acknowledge.
        _ => None,
    };

    if let Some((ack_packet, pubrec_id)) = reply {
        transport.lock().await.write_packet(ack_packet).await?;
        if let Some(packet_id) = pubrec_id {
            session.write().await.store_pubrec(packet_id).await;
        }
    }

    route_message(&publish, callback_manager);
    Ok(())
}
/// Hands an inbound PUBLISH to the callback manager for delivery to matching handlers.
///
/// Whatever `dispatch` returns is deliberately discarded, so its outcome is not
/// reported to the caller.
fn route_message(publish: &PublishPacket, callback_manager: &Arc<CallbackManager>) {
    let manager = callback_manager;
    let _ = manager.dispatch(publish);
}
/// Handles a PUBACK from the server by marking the QoS 1 publish `packet_id`
/// as complete in the session. Always returns `Ok(())`.
async fn handle_puback(packet_id: u16, session: &Arc<RwLock<SessionState>>) -> Result<()> {
    let mut state = session.write().await;
    state.complete_publish(packet_id).await;
    Ok(())
}
/// Handles a PUBREC from the server for one of our QoS 2 publishes.
///
/// Writes a PUBREL (reason `Success`) for `packet_id`, then records it in the
/// session via `store_pubrel`.
///
/// # Errors
/// Returns the transport error if the PUBREL cannot be written; the session is
/// not updated in that case.
///
/// NOTE(review): only the packet id reaches this function. A v5 PUBREC that
/// carries a failure reason code (>= 0x80) should not be answered with PUBREL.
async fn handle_pubrec(
    packet_id: u16,
    transport: &Arc<tokio::sync::Mutex<crate::transport::TransportType>>,
    session: &Arc<RwLock<SessionState>>,
) -> Result<()> {
    let release = Packet::PubRel(crate::packet::pubrel::PubRelPacket {
        packet_id,
        reason_code: crate::protocol::v5::reason_codes::ReasonCode::Success,
        properties: Properties::default(),
    });
    {
        let mut io = transport.lock().await;
        io.write_packet(release).await?;
    }
    let mut state = session.write().await;
    state.store_pubrel(packet_id).await;
    Ok(())
}
/// Handles a PUBREL from the server for an inbound QoS 2 message.
///
/// Writes a PUBCOMP (reason `Success`) for `packet_id`, then clears the pending
/// PUBREC state via `complete_pubrec`.
///
/// # Errors
/// Returns the transport error if the PUBCOMP cannot be written; the session is
/// not updated in that case.
async fn handle_pubrel(
    packet_id: u16,
    transport: &Arc<tokio::sync::Mutex<crate::transport::TransportType>>,
    session: &Arc<RwLock<SessionState>>,
) -> Result<()> {
    let complete = Packet::PubComp(crate::packet::pubcomp::PubCompPacket {
        packet_id,
        reason_code: crate::protocol::v5::reason_codes::ReasonCode::Success,
        properties: Properties::default(),
    });
    {
        let mut io = transport.lock().await;
        io.write_packet(complete).await?;
    }
    let mut state = session.write().await;
    state.complete_pubrec(packet_id).await;
    Ok(())
}
/// Handles a PUBCOMP from the server, which ends one of our QoS 2 publishes, by
/// calling `complete_pubrel` for `packet_id`. Always returns `Ok(())`.
async fn handle_pubcomp(packet_id: u16, session: &Arc<RwLock<SessionState>>) -> Result<()> {
    let mut state = session.write().await;
    state.complete_pubrel(packet_id).await;
    Ok(())
}
// Unit tests for the reader/keepalive tasks and the packet handlers in this module.
// NOTE(review): several tests only build values and assert on them without calling
// the function their name refers to; each such test is flagged below.
#[cfg(test)]
mod tests {
use super::*;
use crate::packet::disconnect::DisconnectPacket;
use crate::protocol::v5::properties::Properties;
use crate::protocol::v5::reason_codes::ReasonCode;
use crate::session::SessionConfig;
use crate::test_utils::*;
use crate::transport::mock::{MockBehavior, MockTransport};
use crate::transport::TransportType;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::time::timeout;
/// Builds a fresh shared session for client id "test-client" with the default
/// `SessionConfig`; the third `SessionState::new` argument is `true`
/// (presumably clean start — confirm against `SessionState::new`).
fn create_test_session() -> Arc<RwLock<SessionState>> {
Arc::new(RwLock::new(SessionState::new(
"test-client".to_string(),
SessionConfig::default(),
true,
)))
}
// NOTE(review): placeholder. `packet_reader_task` is never called. The mock
// transport prepared below is shadowed by an unconnected TCP transport, and the
// `timeout` wraps a future that returns `Ok(())` immediately.
#[tokio::test]
async fn test_packet_reader_task_handles_packets() {
let transport = MockTransport::new();
transport
.inject_packet(encode_packet(&Packet::PingResp).unwrap())
.await;
transport
.set_behavior(MockBehavior {
fail_read: false,
read_delay_ms: 10,
..Default::default()
})
.await;
let transport = Arc::new(tokio::sync::Mutex::new(TransportType::Tcp(
crate::transport::tcp::TcpTransport::from_addr(std::net::SocketAddr::from((
[127, 0, 0, 1],
1883,
))),
)));
let session = create_test_session();
let callback_manager = Arc::new(CallbackManager::new());
assert!(Arc::strong_count(&transport) >= 1);
assert!(Arc::strong_count(&session) >= 1);
assert!(Arc::strong_count(&callback_manager) >= 1);
let result = timeout(Duration::from_millis(100), async {
Ok::<(), MqttError>(())
})
.await;
assert!(result.is_ok());
}
// QoS 0 publishes send no acknowledgement, so the unconnected TCP transport is
// never written to and `handle_publish` should return Ok.
#[tokio::test]
async fn test_handle_publish_qos0() {
let transport = Arc::new(tokio::sync::Mutex::new(TransportType::Tcp(
crate::transport::tcp::TcpTransport::from_addr(std::net::SocketAddr::from((
[127, 0, 0, 1],
1883,
))),
)));
let session = create_test_session();
let callback_manager = Arc::new(CallbackManager::new());
let publish = PublishPacket {
topic_name: "test/topic".to_string(),
payload: b"test payload".to_vec().into(),
qos: crate::QoS::AtMostOnce,
retain: false,
dup: false,
packet_id: None,
properties: Properties::default(),
protocol_version: 5,
stream_id: None,
};
let result = handle_publish(publish, &transport, &session, &callback_manager).await;
assert!(result.is_ok());
}
// NOTE(review): placeholder. `handle_publish` is never called, so the PUBACK
// path is untested. Only the mock's empty write buffer and the packet id field
// are checked.
#[tokio::test]
async fn test_handle_publish_qos1() {
let mock_transport = MockTransport::new();
assert_eq!(mock_transport.get_written_data().await.len(), 0);
let publish = PublishPacket {
topic_name: "test/topic".to_string(),
payload: b"test payload".to_vec().into(),
qos: crate::QoS::AtLeastOnce,
retain: false,
dup: false,
packet_id: Some(123),
properties: Properties::default(),
protocol_version: 5,
stream_id: None,
};
assert_eq!(publish.packet_id, Some(123));
}
// NOTE(review): placeholder. `handle_publish` is never called, so the PUBREC
// path is untested. Only the packet id field is checked.
#[tokio::test]
async fn test_handle_publish_qos2() {
let publish = PublishPacket {
topic_name: "test/topic".to_string(),
payload: b"test payload".to_vec().into(),
qos: crate::QoS::ExactlyOnce,
retain: false,
dup: false,
packet_id: Some(456),
properties: Properties::default(),
protocol_version: 5,
stream_id: None,
};
assert_eq!(publish.packet_id, Some(456));
}
// Stores an unacknowledged QoS 1 publish with id 100, then acknowledges it.
// Only asserts that the handler returns Ok, not that the session entry is gone.
#[tokio::test]
async fn test_handle_puback() {
let session = create_test_session();
let publish = PublishPacket {
topic_name: "test".to_string(),
payload: vec![].into(),
qos: crate::QoS::AtLeastOnce,
retain: false,
dup: false,
packet_id: Some(100),
properties: Properties::default(),
protocol_version: 5,
stream_id: None,
};
session
.write()
.await
.store_unacked_publish(publish)
.await
.unwrap();
let result = handle_puback(100, &session).await;
assert!(result.is_ok());
}
// A server DISCONNECT must surface as `MqttError::ConnectionError`.
#[tokio::test]
async fn test_handle_disconnect() {
let transport = Arc::new(tokio::sync::Mutex::new(TransportType::Tcp(
crate::transport::tcp::TcpTransport::from_addr(std::net::SocketAddr::from((
[127, 0, 0, 1],
1883,
))),
)));
let session = create_test_session();
let callback_manager = Arc::new(CallbackManager::new());
let disconnect = Packet::Disconnect(DisconnectPacket {
reason_code: ReasonCode::UnspecifiedError,
properties: Properties::default(),
});
let result =
handle_incoming_packet(disconnect, &transport, &session, &callback_manager).await;
assert!(result.is_err());
assert!(matches!(result, Err(MqttError::ConnectionError(_))));
}
// NOTE(review): placeholder. `keepalive_task` is never called. This only checks
// that a tokio interval's second tick arrives roughly one period after the first.
#[tokio::test]
async fn test_keepalive_task() {
let transport = Arc::new(tokio::sync::Mutex::new(TransportType::Tcp(
crate::transport::tcp::TcpTransport::from_addr(std::net::SocketAddr::from((
[127, 0, 0, 1],
1883,
))),
)));
assert!(Arc::strong_count(&transport) >= 1);
let keepalive_interval = Duration::from_millis(100);
let mut interval = tokio::time::interval(keepalive_interval);
interval.tick().await; let start = tokio::time::Instant::now();
interval.tick().await; let elapsed = start.elapsed();
assert!(elapsed >= keepalive_interval.saturating_sub(Duration::from_millis(10)));
}
// A "test/+" subscription should receive exactly one dispatch for "test/data".
#[tokio::test]
async fn test_route_message_with_callbacks() {
let callback_manager = Arc::new(CallbackManager::new());
let counter = Arc::new(AtomicU32::new(0));
let counter_clone = counter.clone();
callback_manager
.register(
"test/+",
Arc::new(move |_msg: PublishPacket| {
counter_clone.fetch_add(1, Ordering::SeqCst);
}),
)
.unwrap();
let publish = PublishPacket {
topic_name: "test/data".to_string(),
payload: b"hello".to_vec().into(),
qos: crate::QoS::AtMostOnce,
retain: false,
dup: false,
packet_id: None,
properties: Properties::default(),
protocol_version: 5,
stream_id: None,
};
route_message(&publish, &callback_manager);
tokio::task::yield_now().await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// Covers only the final QoS 2 step: a stored PUBREL followed by PUBCOMP.
// Only asserts that the handler returns Ok.
#[tokio::test]
async fn test_qos2_flow() {
let session = create_test_session();
let packet_id = 789;
session.write().await.store_pubrel(packet_id).await;
let result = handle_pubcomp(packet_id, &session).await;
assert!(result.is_ok());
}
}