use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use crate::client::MqttClientTrait;
use crate::error::{MqttError, Result};
use crate::types::{
ConnectOptions, ConnectResult, Message, MessageProperties, PublishOptions, PublishResult,
SubscribeOptions,
};
use crate::QoS;
/// Boxed callback stored per subscription and invoked with each message that
/// `MockMqttClient::simulate_message` delivers to it.
type SubscriptionCallback = Box<dyn Fn(Message) + Send + Sync>;
/// In-memory stand-in for a real MQTT client, implementing `MqttClientTrait`.
///
/// Every trait call is recorded as a [`MockCall`], results can be scripted through
/// [`MockResponses`], and incoming traffic can be injected with `simulate_message`.
/// Cloning is cheap: all clones share the same `Arc`-backed state, so calls made on one clone
/// are visible through every other.
#[derive(Clone)]
pub struct MockMqttClient {
    state: Arc<MockState>,
}
/// Shared mutable state behind every clone of a [`MockMqttClient`].
struct MockState {
    /// Flag reported by `is_connected`; written by `set_connected`.
    connected: AtomicBool,
    /// Identifier returned by `client_id`.
    client_id: RwLock<String>,
    /// Counter backing `next_packet_id`.
    packet_id_counter: AtomicU16,
    /// Flag read by `is_queue_on_disconnect` and written by `set_queue_on_disconnect`.
    queue_on_disconnect: AtomicBool,
    /// Every recorded call, appended in the order the calls were made.
    calls: Mutex<Vec<MockCall>>,
    /// Canned results returned instead of the built-in defaults.
    responses: RwLock<MockResponses>,
    /// Registered callbacks keyed by the topic filter string passed to subscribe.
    subscriptions: RwLock<HashMap<String, SubscriptionCallback>>,
}
/// One recorded invocation of a [`MockMqttClient`] method.
///
/// Calls are appended as they happen and returned, oldest first, by
/// `MockMqttClient::get_calls`.
#[derive(Debug, Clone)]
pub enum MockCall {
    /// `connect(address)`.
    Connect {
        address: String,
    },
    /// `connect_with_options(address, options)`.
    ConnectWithOptions {
        address: String,
        // NOTE(review): boxed presumably to keep the enum small
        // (clippy::large_enum_variant) — confirm.
        options: Box<ConnectOptions>,
    },
    /// `disconnect()`.
    Disconnect,
    /// `publish(topic, payload)`.
    Publish {
        topic: String,
        payload: Vec<u8>,
    },
    /// `publish_with_options(topic, payload, options)`; `publish_qos`, `publish_retain` and
    /// `publish_qos0/1/2` delegate to it and therefore also record this variant.
    PublishWithOptions {
        topic: String,
        payload: Vec<u8>,
        options: PublishOptions,
    },
    /// `subscribe(topic_filter, callback)`.
    Subscribe {
        topic: String,
    },
    /// `subscribe_with_options(topic_filter, options, callback)`; `subscribe_many` records one
    /// of these per topic.
    SubscribeWithOptions {
        topic: String,
        options: SubscribeOptions,
    },
    /// `unsubscribe(topic_filter)`; `unsubscribe_many` records one of these per topic.
    Unsubscribe {
        topic: String,
    },
    /// `set_queue_on_disconnect(enabled)`.
    SetQueueOnDisconnect {
        enabled: bool,
    },
}
/// Scripted results for [`MockMqttClient`] operations.
///
/// `None` (the default) makes an operation use its built-in success value; `Some(r)` makes
/// every subsequent call return a clone of `r` (the value is not consumed).
#[derive(Debug, Default)]
pub struct MockResponses {
    /// Result of `connect`; an `Ok` also marks the client as connected.
    pub connect_response: Option<Result<()>>,
    /// Result of `connect_with_options`; an `Ok` also marks the client as connected.
    pub connect_with_options_response: Option<Result<ConnectResult>>,
    /// Result of `disconnect`; an `Ok` also marks the client as disconnected.
    pub disconnect_response: Option<Result<()>>,
    /// Result of `publish` and `publish_with_options` (and the helpers delegating to them).
    pub publish_response: Option<Result<PublishResult>>,
    /// Result of `subscribe` and `subscribe_with_options`: `(packet_id, granted QoS)`.
    pub subscribe_response: Option<Result<(u16, QoS)>>,
    /// Result of `unsubscribe`.
    pub unsubscribe_response: Option<Result<()>>,
}
impl MockMqttClient {
    /// Creates a disconnected mock client with the given client identifier.
    ///
    /// No responses are scripted, no calls are recorded and no subscriptions exist yet.
    pub fn new(client_id: impl Into<String>) -> Self {
        Self {
            state: Arc::new(MockState {
                connected: AtomicBool::new(false),
                client_id: RwLock::new(client_id.into()),
                packet_id_counter: AtomicU16::new(0),
                queue_on_disconnect: AtomicBool::new(false),
                calls: Mutex::new(Vec::new()),
                responses: RwLock::new(MockResponses::default()),
                subscriptions: RwLock::new(HashMap::new()),
            }),
        }
    }

    /// Forces the connection flag reported by `is_connected`. Does not record a call.
    pub fn set_connected(&self, connected: bool) {
        self.state.connected.store(connected, Ordering::SeqCst);
    }

    /// Returns a snapshot of every recorded call, oldest first.
    pub async fn get_calls(&self) -> Vec<MockCall> {
        self.state.calls.lock().await.clone()
    }

    /// Discards all recorded calls.
    pub async fn clear_calls(&self) {
        self.state.calls.lock().await.clear();
    }

    /// Scripts the result of every subsequent `connect` call.
    pub async fn set_connect_response(&self, response: Result<()>) {
        self.state.responses.write().await.connect_response = Some(response);
    }

    /// Scripts the result of every subsequent `connect_with_options` call.
    pub async fn set_connect_with_options_response(&self, response: Result<ConnectResult>) {
        self.state
            .responses
            .write()
            .await
            .connect_with_options_response = Some(response);
    }

    /// Scripts the result of every subsequent `disconnect` call.
    pub async fn set_disconnect_response(&self, response: Result<()>) {
        self.state.responses.write().await.disconnect_response = Some(response);
    }

    /// Scripts the result of every subsequent publish call (`publish`, `publish_with_options`
    /// and the helpers that delegate to them).
    pub async fn set_publish_response(&self, response: Result<PublishResult>) {
        self.state.responses.write().await.publish_response = Some(response);
    }

    /// Scripts the result of every subsequent `subscribe` / `subscribe_with_options` call.
    pub async fn set_subscribe_response(&self, response: Result<(u16, QoS)>) {
        self.state.responses.write().await.subscribe_response = Some(response);
    }

    /// Scripts the result of every subsequent `unsubscribe` call.
    pub async fn set_unsubscribe_response(&self, response: Result<()>) {
        self.state.responses.write().await.unsubscribe_response = Some(response);
    }

    /// Injects an incoming message as if the broker had sent it.
    ///
    /// The message is delivered to *every* subscription whose filter matches `topic` (in
    /// unspecified `HashMap` order), each callback receiving its own copy. Callbacks run
    /// synchronously while the subscription map's read lock is held.
    ///
    /// # Errors
    ///
    /// Returns `MqttError::ProtocolError` when no subscription matches `topic`.
    pub async fn simulate_message(&self, topic: &str, payload: Vec<u8>, qos: QoS) -> Result<()> {
        let subscriptions = self.state.subscriptions.read().await;
        let mut delivered = false;
        for (_, callback) in subscriptions
            .iter()
            .filter(|(filter, _)| Self::topic_matches(filter, topic))
        {
            callback(Message {
                topic: topic.to_string(),
                payload: payload.clone(),
                qos,
                retain: false,
                properties: MessageProperties::default(),
                stream_id: None,
            });
            delivered = true;
        }
        if delivered {
            Ok(())
        } else {
            Err(MqttError::ProtocolError(format!(
                "No subscription found for topic: {topic}"
            )))
        }
    }

    /// Reports whether `topic` matches the MQTT topic `filter`, level by level.
    ///
    /// `+` matches exactly one level (possibly empty); a `#` level matches its parent level and
    /// any number of levels below it (so `"a/#"` matches `"a"` and `"a/b/c"`, but not `"ab"`).
    /// Filters whose first level is a wildcard never match topics starting with `$`, as the
    /// MQTT specification requires.
    fn topic_matches(filter: &str, topic: &str) -> bool {
        if topic.starts_with('$') && (filter.starts_with('+') || filter.starts_with('#')) {
            return false;
        }
        let mut filter_levels = filter.split('/');
        let mut topic_levels = topic.split('/');
        loop {
            match (filter_levels.next(), topic_levels.next()) {
                // Multi-level wildcard swallows the remaining levels, including none.
                (Some("#"), _) => return true,
                (Some(f), Some(t)) if f == "+" || f == t => {}
                // Both exhausted at the same depth: every level matched.
                (None, None) => return true,
                _ => return false,
            }
        }
    }

    /// Appends `call` to the call log.
    async fn record_call(&self, call: MockCall) {
        self.state.calls.lock().await.push(call);
    }

    /// Hands out the next packet identifier: 1, 2, …, `u16::MAX`, then wraps back to 1.
    ///
    /// Zero is skipped because MQTT packet identifiers must be non-zero, and the wrap avoids
    /// the arithmetic overflow the previous `fetch_add(1) + 1` hit at `u16::MAX`.
    fn next_packet_id(&self) -> u16 {
        fn successor(id: u16) -> u16 {
            id.checked_add(1).unwrap_or(1)
        }
        // The closure always returns `Some`, so this is always `Ok(previous)`.
        let previous = self
            .state
            .packet_id_counter
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| Some(successor(id)))
            .unwrap_or_else(|id| id);
        successor(previous)
    }
}
// Every method spells out `-> impl Future<Output = _> + Send` and returns an `async move`
// block so the `Send` bound stays explicit at each signature; clippy's `manual_async_fn`
// would otherwise ask for `async fn`.
#[allow(clippy::manual_async_fn)]
impl MqttClientTrait for MockMqttClient {
    /// Reports the mock's connection flag.
    fn is_connected(&self) -> impl Future<Output = bool> + Send + '_ {
        async move { self.state.connected.load(Ordering::SeqCst) }
    }

    /// Returns the client identifier stored in the mock.
    fn client_id(&self) -> impl Future<Output = String> + Send + '_ {
        async move { self.state.client_id.read().await.clone() }
    }

    /// Records `MockCall::Connect` and returns the scripted connect response, or `Ok(())`
    /// when none is set. The client is marked connected only when the result is `Ok`.
    fn connect<'a>(&'a self, address: &'a str) -> impl Future<Output = Result<()>> + Send + 'a {
        async move {
            self.record_call(MockCall::Connect {
                address: address.to_string(),
            })
            .await;
            // Clone the scripted response out so the read guard is released immediately.
            let configured = self.state.responses.read().await.connect_response.clone();
            let result = configured.unwrap_or(Ok(()));
            if result.is_ok() {
                self.set_connected(true);
            }
            result
        }
    }

    /// Records `MockCall::ConnectWithOptions` and returns the scripted response, or a result
    /// with `session_present: false` when none is set. Marks the client connected on `Ok`.
    fn connect_with_options<'a>(
        &'a self,
        address: &'a str,
        options: ConnectOptions,
    ) -> impl Future<Output = Result<ConnectResult>> + Send + 'a {
        async move {
            self.record_call(MockCall::ConnectWithOptions {
                address: address.to_string(),
                options: Box::new(options),
            })
            .await;
            let configured = self
                .state
                .responses
                .read()
                .await
                .connect_with_options_response
                .clone();
            let result = configured.unwrap_or_else(|| {
                Ok(ConnectResult {
                    session_present: false,
                })
            });
            if result.is_ok() {
                self.set_connected(true);
            }
            result
        }
    }

    /// Records `MockCall::Disconnect` and returns the scripted response, or `Ok(())` when none
    /// is set. The client is marked disconnected only when the result is `Ok`.
    fn disconnect(&self) -> impl Future<Output = Result<()>> + Send + '_ {
        async move {
            self.record_call(MockCall::Disconnect).await;
            let configured = self.state.responses.read().await.disconnect_response.clone();
            let result = configured.unwrap_or(Ok(()));
            if result.is_ok() {
                self.set_connected(false);
            }
            result
        }
    }

    /// Records `MockCall::Publish` and returns the scripted publish response, or
    /// `PublishResult::QoS0` when none is set.
    fn publish<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move {
            self.record_call(MockCall::Publish {
                topic: topic.into(),
                payload: payload.into(),
            })
            .await;
            let configured = self.state.responses.read().await.publish_response.clone();
            configured.unwrap_or(Ok(PublishResult::QoS0))
        }
    }

    /// Publishes with default options except for `qos`, via `publish_with_options`.
    fn publish_qos<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
        qos: QoS,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move {
            let options = PublishOptions {
                qos,
                ..Default::default()
            };
            self.publish_with_options(topic, payload, options).await
        }
    }

    /// Records `MockCall::PublishWithOptions` and returns the scripted publish response.
    ///
    /// Without one, QoS 0 yields `PublishResult::QoS0` and QoS 1/2 yield a fresh packet id.
    fn publish_with_options<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
        options: PublishOptions,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move {
            self.record_call(MockCall::PublishWithOptions {
                topic: topic.into(),
                payload: payload.into(),
                options: options.clone(),
            })
            .await;
            let configured = self.state.responses.read().await.publish_response.clone();
            match configured {
                Some(response) => response,
                None => match options.qos {
                    QoS::AtMostOnce => Ok(PublishResult::QoS0),
                    QoS::AtLeastOnce | QoS::ExactlyOnce => Ok(PublishResult::QoS1Or2 {
                        packet_id: self.next_packet_id(),
                    }),
                },
            }
        }
    }

    /// Records `MockCall::Subscribe` and returns the scripted subscribe response, or
    /// `(fresh packet id, QoS::AtMostOnce)` when none is set.
    ///
    /// The callback is registered only when the result is `Ok`.
    fn subscribe<'a, F>(
        &'a self,
        topic_filter: impl Into<String> + Send + 'a,
        callback: F,
    ) -> impl Future<Output = Result<(u16, QoS)>> + Send + 'a
    where
        F: Fn(Message) + Send + Sync + 'static,
    {
        async move {
            let topic_str: String = topic_filter.into();
            self.record_call(MockCall::Subscribe {
                topic: topic_str.clone(),
            })
            .await;
            let configured = self.state.responses.read().await.subscribe_response.clone();
            let result = match configured {
                Some(response) => response,
                None => Ok((self.next_packet_id(), QoS::AtMostOnce)),
            };
            // Like connect/disconnect, only a successful operation mutates mock state, so a
            // scripted failure leaves no callback behind for `simulate_message` to invoke.
            if result.is_ok() {
                self.state
                    .subscriptions
                    .write()
                    .await
                    .insert(topic_str, Box::new(callback));
            }
            result
        }
    }

    /// Records `MockCall::SubscribeWithOptions` and returns the scripted subscribe response,
    /// or `(fresh packet id, options.qos)` when none is set.
    ///
    /// The callback is registered only when the result is `Ok`.
    fn subscribe_with_options<'a, F>(
        &'a self,
        topic_filter: impl Into<String> + Send + 'a,
        options: SubscribeOptions,
        callback: F,
    ) -> impl Future<Output = Result<(u16, QoS)>> + Send + 'a
    where
        F: Fn(Message) + Send + Sync + 'static,
    {
        async move {
            let topic_str: String = topic_filter.into();
            self.record_call(MockCall::SubscribeWithOptions {
                topic: topic_str.clone(),
                options: options.clone(),
            })
            .await;
            let configured = self.state.responses.read().await.subscribe_response.clone();
            let result = match configured {
                Some(response) => response,
                None => Ok((self.next_packet_id(), options.qos)),
            };
            if result.is_ok() {
                self.state
                    .subscriptions
                    .write()
                    .await
                    .insert(topic_str, Box::new(callback));
            }
            result
        }
    }

    /// Records `MockCall::Unsubscribe` and returns the scripted unsubscribe response, or
    /// `Ok(())` when none is set. The subscription is removed only when the result is `Ok`.
    fn unsubscribe<'a>(
        &'a self,
        topic_filter: impl Into<String> + Send + 'a,
    ) -> impl Future<Output = Result<()>> + Send + 'a {
        async move {
            let topic_str: String = topic_filter.into();
            self.record_call(MockCall::Unsubscribe {
                topic: topic_str.clone(),
            })
            .await;
            let configured = self.state.responses.read().await.unsubscribe_response.clone();
            let result = configured.unwrap_or(Ok(()));
            if result.is_ok() {
                self.state.subscriptions.write().await.remove(&topic_str);
            }
            result
        }
    }

    /// Subscribes to each `(topic, qos)` in order via `subscribe_with_options`, giving each a
    /// clone of `callback`. Stops at the first error; earlier subscriptions stay in place.
    fn subscribe_many<'a, F>(
        &'a self,
        topics: Vec<(&'a str, QoS)>,
        callback: F,
    ) -> impl Future<Output = Result<Vec<(u16, QoS)>>> + Send + 'a
    where
        F: Fn(Message) + Send + Sync + 'static + Clone,
    {
        async move {
            let mut results = Vec::with_capacity(topics.len());
            for (topic, qos) in topics {
                let opts = SubscribeOptions {
                    qos,
                    ..Default::default()
                };
                let result = self
                    .subscribe_with_options(topic, opts, callback.clone())
                    .await?;
                results.push(result);
            }
            Ok(results)
        }
    }

    /// Unsubscribes each topic in order, pairing every topic with its own result. The outer
    /// `Result` is always `Ok`.
    fn unsubscribe_many<'a>(
        &'a self,
        topics: Vec<&'a str>,
    ) -> impl Future<Output = Result<Vec<(String, Result<()>)>>> + Send + 'a {
        async move {
            let mut results = Vec::with_capacity(topics.len());
            for topic in topics {
                let topic_string = topic.to_string();
                let result = self.unsubscribe(topic).await;
                results.push((topic_string, result));
            }
            Ok(results)
        }
    }

    /// Publishes with `retain: true` and otherwise default options.
    fn publish_retain<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move {
            let opts = PublishOptions {
                retain: true,
                ..Default::default()
            };
            self.publish_with_options(topic, payload, opts).await
        }
    }

    /// Shorthand for `publish_qos(topic, payload, QoS::AtMostOnce)`.
    fn publish_qos0<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move { self.publish_qos(topic, payload, QoS::AtMostOnce).await }
    }

    /// Shorthand for `publish_qos(topic, payload, QoS::AtLeastOnce)`.
    fn publish_qos1<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move { self.publish_qos(topic, payload, QoS::AtLeastOnce).await }
    }

    /// Shorthand for `publish_qos(topic, payload, QoS::ExactlyOnce)`.
    fn publish_qos2<'a>(
        &'a self,
        topic: impl Into<String> + Send + 'a,
        payload: impl Into<Vec<u8>> + Send + 'a,
    ) -> impl Future<Output = Result<PublishResult>> + Send + 'a {
        async move { self.publish_qos(topic, payload, QoS::ExactlyOnce).await }
    }

    /// Reports the queue-on-disconnect flag.
    fn is_queue_on_disconnect(&self) -> impl Future<Output = bool> + Send + '_ {
        async move { self.state.queue_on_disconnect.load(Ordering::SeqCst) }
    }

    /// Records `MockCall::SetQueueOnDisconnect` and stores the flag.
    fn set_queue_on_disconnect(&self, enabled: bool) -> impl Future<Output = ()> + Send + '_ {
        async move {
            self.record_call(MockCall::SetQueueOnDisconnect { enabled })
                .await;
            self.state
                .queue_on_disconnect
                .store(enabled, Ordering::SeqCst);
        }
    }
}