use std::collections::HashSet;
use std::time::Duration;
use serde_json::Value;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::error::ProtocolError;
/// Describes how many MQTT responses a request is expected to produce.
#[derive(Debug, Clone, Default)]
pub enum ResponseSpec {
    /// Exactly one response is awaited. This is the default.
    #[default]
    Single,
    /// One response per expected topic suffix is awaited, bounded by `timeout`.
    Multiple {
        /// Topic suffixes that should each produce a response.
        expected_topics: Vec<String>,
        /// Upper bound on the time spent waiting for the whole set.
        timeout: Duration,
    },
}

impl ResponseSpec {
    /// Returns the [`ResponseSpec::Single`] variant.
    #[must_use]
    pub const fn single() -> Self {
        Self::Single
    }

    /// Builds a [`ResponseSpec::Multiple`] from the given topic suffixes and timeout.
    #[must_use]
    pub fn multiple(expected_topics: Vec<String>, timeout: Duration) -> Self {
        Self::Multiple {
            timeout,
            expected_topics,
        }
    }

    /// Builds a [`ResponseSpec::Multiple`] awaiting `STATUS`, `STATUS1` through
    /// `STATUS7`, and `STATUS11` (in that order), bounded by `timeout`.
    #[must_use]
    pub fn status_all(timeout: Duration) -> Self {
        const SUFFIXES: [&str; 9] = [
            "STATUS", "STATUS1", "STATUS2", "STATUS3", "STATUS4", "STATUS5", "STATUS6",
            "STATUS7", "STATUS11",
        ];
        let expected_topics: Vec<String> = SUFFIXES.iter().copied().map(String::from).collect();
        Self::multiple(expected_topics, timeout)
    }

    /// Returns `true` for the [`ResponseSpec::Multiple`] variant.
    #[must_use]
    pub const fn is_multiple(&self) -> bool {
        match self {
            Self::Multiple { .. } => true,
            Self::Single => false,
        }
    }
}
/// An MQTT message as seen by the response-collection logic.
#[derive(Debug, Clone)]
pub(crate) struct MqttMessage {
    /// Topic suffix (e.g. `STATUS11`) compared against a spec's expected topics.
    pub(crate) topic_suffix: String,
    /// Message payload as text.
    pub(crate) payload: String,
}

impl MqttMessage {
    /// Creates a message from its topic suffix and payload.
    #[must_use]
    pub(crate) fn new(topic_suffix: String, payload: String) -> Self {
        Self { payload, topic_suffix }
    }
}
struct ResponseCollector {
expected: HashSet<String>,
collected: Vec<(String, Value)>,
deadline: Instant,
}
impl ResponseCollector {
#[must_use]
fn new(spec: &ResponseSpec) -> Self {
match spec {
ResponseSpec::Single => {
panic!("ResponseCollector should not be used for single responses")
}
ResponseSpec::Multiple {
expected_topics,
timeout,
} => Self {
expected: expected_topics.iter().cloned().collect(),
collected: Vec::with_capacity(expected_topics.len()),
deadline: Instant::now() + *timeout,
},
}
}
#[must_use]
fn remaining_time(&self) -> Duration {
self.deadline.saturating_duration_since(Instant::now())
}
#[must_use]
fn is_timed_out(&self) -> bool {
Instant::now() >= self.deadline
}
#[must_use]
fn is_complete(&self) -> bool {
self.expected.is_empty()
}
fn process_message(&mut self, msg: &MqttMessage) -> bool {
if self.expected.remove(&msg.topic_suffix)
&& let Ok(value) = serde_json::from_str::<Value>(&msg.payload)
{
self.collected.push((msg.topic_suffix.clone(), value));
return true;
}
false
}
#[must_use]
fn merge_responses(self) -> String {
let mut merged = serde_json::Map::new();
for (_, value) in self.collected {
if let Value::Object(obj) = value {
for (key, val) in obj {
merged.insert(key, val);
}
}
}
Value::Object(merged).to_string()
}
#[must_use]
fn pending_count(&self) -> usize {
self.expected.len()
}
#[must_use]
fn collected_count(&self) -> usize {
self.collected.len()
}
}
/// Waits on `rx` for the response(s) described by `spec`.
///
/// * [`ResponseSpec::Single`]: returns the raw payload of the first message
///   received within `single_timeout`.
/// * [`ResponseSpec::Multiple`]: collects messages until every expected topic has
///   been seen, the spec's own timeout elapses, or the channel closes. It then
///   returns the collected JSON objects merged into one JSON object. Partial
///   results are returned as long as at least one response was collected.
///   `single_timeout` is not used for this variant.
///
/// # Errors
///
/// * [`ProtocolError::Timeout`] (milliseconds, saturating at `u64::MAX`) if no usable
///   response arrived before the relevant timeout elapsed.
/// * [`ProtocolError::ConnectionFailed`] if the channel closed before any usable
///   response arrived.
pub(crate) async fn collect_responses(
    rx: &mut mpsc::Receiver<MqttMessage>,
    spec: &ResponseSpec,
    single_timeout: Duration,
) -> Result<String, ProtocolError> {
    match spec {
        ResponseSpec::Single => {
            let msg = tokio::time::timeout(single_timeout, rx.recv())
                .await
                .map_err(|_| ProtocolError::Timeout(duration_to_millis(single_timeout)))?
                .ok_or_else(|| {
                    ProtocolError::ConnectionFailed("Response channel closed".to_string())
                })?;
            Ok(msg.payload)
        }
        ResponseSpec::Multiple { timeout, .. } => {
            let mut collector = ResponseCollector::new(spec);
            let mut channel_closed = false;
            while !collector.is_complete() && !collector.is_timed_out() {
                let remaining = collector.remaining_time();
                match tokio::time::timeout(remaining, rx.recv()).await {
                    Ok(Some(msg)) => {
                        tracing::trace!(
                            topic = %msg.topic_suffix,
                            collected = collector.collected_count(),
                            pending = collector.pending_count(),
                            "Received response message"
                        );
                        collector.process_message(&msg);
                    }
                    Ok(None) => {
                        tracing::debug!(
                            collected = collector.collected_count(),
                            pending = collector.pending_count(),
                            "Response channel closed during collection"
                        );
                        channel_closed = true;
                        break;
                    }
                    Err(_) => {
                        tracing::debug!(
                            collected = collector.collected_count(),
                            pending = collector.pending_count(),
                            "Response collection timed out"
                        );
                        break;
                    }
                }
            }
            if collector.collected_count() == 0 {
                // A closed channel is not a timeout (it may close long before the
                // deadline); report it the same way the Single path does.
                return Err(if channel_closed {
                    ProtocolError::ConnectionFailed("Response channel closed".to_string())
                } else {
                    ProtocolError::Timeout(duration_to_millis(*timeout))
                });
            }
            Ok(collector.merge_responses())
        }
    }
}

/// Converts `duration` to whole milliseconds, saturating at `u64::MAX` instead of
/// silently truncating the `u128` returned by [`Duration::as_millis`].
fn duration_to_millis(duration: Duration) -> u64 {
    u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn response_spec_default_is_single() {
        let spec = ResponseSpec::default();
        assert!(!spec.is_multiple());
    }

    #[test]
    fn response_spec_multiple() {
        let spec = ResponseSpec::multiple(
            vec!["STATUS".to_string(), "STATUS1".to_string()],
            Duration::from_secs(5),
        );
        assert!(spec.is_multiple());
    }

    #[test]
    fn response_spec_status_all() {
        let ResponseSpec::Multiple {
            expected_topics, ..
        } = ResponseSpec::status_all(Duration::from_secs(5))
        else {
            panic!("Expected Multiple variant");
        };
        assert!(expected_topics.contains(&"STATUS".to_string()));
        assert!(expected_topics.contains(&"STATUS11".to_string()));
        assert!(!expected_topics.contains(&"STATUS10".to_string()));
    }

    #[test]
    fn collector_processes_expected_message() {
        let spec = ResponseSpec::multiple(
            vec!["STATUS".to_string(), "STATUS1".to_string()],
            Duration::from_secs(5),
        );
        let mut collector = ResponseCollector::new(&spec);
        let msg = MqttMessage::new(
            "STATUS".to_string(),
            r#"{"Status":{"Topic":"test"}}"#.to_string(),
        );
        assert!(collector.process_message(&msg));
        assert_eq!(collector.collected_count(), 1);
        assert_eq!(collector.pending_count(), 1);
    }

    #[test]
    fn collector_ignores_unexpected_message() {
        let spec = ResponseSpec::multiple(vec!["STATUS".to_string()], Duration::from_secs(5));
        let mut collector = ResponseCollector::new(&spec);
        let msg = MqttMessage::new("RESULT".to_string(), r#"{"POWER":"ON"}"#.to_string());
        assert!(!collector.process_message(&msg));
        assert_eq!(collector.collected_count(), 0);
    }

    #[test]
    fn collector_consumes_topic_with_invalid_json() {
        let spec = ResponseSpec::multiple(vec!["STATUS".to_string()], Duration::from_secs(5));
        let mut collector = ResponseCollector::new(&spec);
        let msg = MqttMessage::new("STATUS".to_string(), "not json".to_string());
        assert!(!collector.process_message(&msg));
        assert_eq!(collector.collected_count(), 0);
        assert_eq!(collector.pending_count(), 0);
    }

    #[test]
    fn collector_merges_responses() {
        let spec = ResponseSpec::multiple(
            vec!["STATUS".to_string(), "STATUS11".to_string()],
            Duration::from_secs(5),
        );
        let mut collector = ResponseCollector::new(&spec);
        let msg1 = MqttMessage::new(
            "STATUS".to_string(),
            r#"{"Status":{"Topic":"test"}}"#.to_string(),
        );
        let msg2 = MqttMessage::new(
            "STATUS11".to_string(),
            r#"{"StatusSTS":{"UptimeSec":12345}}"#.to_string(),
        );
        collector.process_message(&msg1);
        collector.process_message(&msg2);
        let merged = collector.merge_responses();
        let value: Value = serde_json::from_str(&merged).unwrap();
        assert!(value.get("Status").is_some());
        assert!(value.get("StatusSTS").is_some());
    }

    #[test]
    fn collector_is_complete_when_all_collected() {
        let spec = ResponseSpec::multiple(vec!["STATUS".to_string()], Duration::from_secs(5));
        let mut collector = ResponseCollector::new(&spec);
        assert!(!collector.is_complete());
        let msg = MqttMessage::new("STATUS".to_string(), r#"{"Status":{}}"#.to_string());
        collector.process_message(&msg);
        assert!(collector.is_complete());
    }

    #[tokio::test]
    async fn collect_single_response() {
        let (tx, mut rx) = mpsc::channel(10);
        tx.send(MqttMessage::new(
            "RESULT".to_string(),
            r#"{"POWER":"ON"}"#.to_string(),
        ))
        .await
        .unwrap();
        let result = collect_responses(&mut rx, &ResponseSpec::Single, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(result, r#"{"POWER":"ON"}"#);
    }

    #[tokio::test]
    async fn collect_single_times_out_without_message() {
        // Keep the sender alive so the channel stays open and only the timeout fires.
        let (_tx, mut rx) = mpsc::channel::<MqttMessage>(1);
        let err = collect_responses(&mut rx, &ResponseSpec::Single, Duration::from_millis(20))
            .await
            .expect_err("no message was sent");
        assert!(matches!(err, ProtocolError::Timeout(20)));
    }

    #[tokio::test]
    async fn collect_single_reports_closed_channel() {
        let (tx, mut rx) = mpsc::channel::<MqttMessage>(1);
        drop(tx);
        let err = collect_responses(&mut rx, &ResponseSpec::Single, Duration::from_secs(1))
            .await
            .expect_err("channel was closed");
        assert!(matches!(err, ProtocolError::ConnectionFailed(_)));
    }

    #[tokio::test]
    async fn collect_multiple_responses() {
        let (tx, mut rx) = mpsc::channel(10);
        let spec = ResponseSpec::multiple(
            vec!["STATUS".to_string(), "STATUS11".to_string()],
            Duration::from_secs(5),
        );
        tokio::spawn(async move {
            tx.send(MqttMessage::new(
                "STATUS".to_string(),
                r#"{"Status":{"Topic":"test"}}"#.to_string(),
            ))
            .await
            .unwrap();
            tokio::time::sleep(Duration::from_millis(10)).await;
            tx.send(MqttMessage::new(
                "STATUS11".to_string(),
                r#"{"StatusSTS":{"UptimeSec":100}}"#.to_string(),
            ))
            .await
            .unwrap();
        });
        let result = collect_responses(&mut rx, &spec, Duration::from_secs(1))
            .await
            .unwrap();
        let value: Value = serde_json::from_str(&result).unwrap();
        assert!(value.get("Status").is_some());
        assert!(value.get("StatusSTS").is_some());
    }

    #[tokio::test]
    async fn collect_partial_on_timeout() {
        let (tx, mut rx) = mpsc::channel(10);
        let spec = ResponseSpec::multiple(
            vec!["STATUS".to_string(), "STATUS11".to_string()],
            Duration::from_millis(100),
        );
        tx.send(MqttMessage::new(
            "STATUS".to_string(),
            r#"{"Status":{"Topic":"test"}}"#.to_string(),
        ))
        .await
        .unwrap();
        let result = collect_responses(&mut rx, &spec, Duration::from_secs(1))
            .await
            .unwrap();
        let value: Value = serde_json::from_str(&result).unwrap();
        assert!(value.get("Status").is_some());
        assert!(value.get("StatusSTS").is_none());
    }

    #[tokio::test]
    async fn collect_multiple_times_out_when_nothing_arrives() {
        // Keep the sender alive so the channel stays open and only the timeout fires.
        let (_tx, mut rx) = mpsc::channel::<MqttMessage>(1);
        let spec = ResponseSpec::multiple(vec!["STATUS".to_string()], Duration::from_millis(30));
        let err = collect_responses(&mut rx, &spec, Duration::from_secs(1))
            .await
            .expect_err("no message was sent");
        assert!(matches!(err, ProtocolError::Timeout(30)));
    }
}