use std::collections::HashMap;
use std::time::{Duration, Instant};
use mechutil::ipc::CommandMessage;
use serde_json::Value;
use tokio::sync::mpsc;
/// Bookkeeping kept for a request that has been sent but not yet answered.
struct PendingRequest {
/// Topic the request was sent to; reported in the timeout log message.
topic: String,
/// Moment the request was registered; `drain_stale` compares it against the timeout.
sent_at: Instant,
}
/// Tracks command requests sent over a channel and the responses that come back.
///
/// Requests are serialized to JSON and pushed onto `write_tx`. Incoming messages
/// are read from `response_rx` by `poll` and matched to outstanding requests by
/// transaction id.
pub struct CommandClient {
/// Outgoing side: carries each request as a serialized JSON string.
write_tx: mpsc::UnboundedSender<String>,
/// Incoming side: messages are pulled from here without blocking by `poll`.
response_rx: mpsc::UnboundedReceiver<CommandMessage>,
/// Requests that were sent and have neither been answered nor drained as stale,
/// keyed by transaction id.
pending: HashMap<u32, PendingRequest>,
/// Responses matched to a previously pending request, held until the caller
/// claims them with `take_response`.
responses: HashMap<u32, CommandMessage>,
}
impl CommandClient {
    /// Creates a client with no outstanding requests and no buffered responses.
    ///
    /// * `write_tx` - channel onto which serialized (JSON) requests are pushed.
    /// * `response_rx` - channel from which incoming messages are read by [`poll`](Self::poll).
    pub fn new(
        write_tx: mpsc::UnboundedSender<String>,
        response_rx: mpsc::UnboundedReceiver<CommandMessage>,
    ) -> Self {
        Self {
            write_tx,
            response_rx,
            pending: HashMap::new(),
            responses: HashMap::new(),
        }
    }

    /// Builds a request for `topic` carrying `data`, queues it on the write
    /// channel and registers it as pending.
    ///
    /// Returns the transaction id of the message built by
    /// `CommandMessage::request`; use it with [`is_pending`](Self::is_pending)
    /// and [`take_response`](Self::take_response).
    ///
    /// If the message cannot be serialized or the write channel is closed, a
    /// warning is logged. The request is still registered as pending so that
    /// callers observe the failure through [`drain_stale`](Self::drain_stale),
    /// exactly as they would for a request that never gets a reply.
    pub fn send(&mut self, topic: &str, data: Value) -> u32 {
        let msg = CommandMessage::request(topic, data);
        let transaction_id = msg.transaction_id;

        match serde_json::to_string(&msg) {
            Ok(json) => {
                if self.write_tx.send(json).is_err() {
                    log::warn!(
                        "Command request {} ('{}') could not be queued: write channel closed",
                        transaction_id,
                        topic
                    );
                }
            }
            Err(err) => {
                log::warn!(
                    "Command request {} ('{}') could not be serialized: {}",
                    transaction_id,
                    topic,
                    err
                );
            }
        }

        self.pending.insert(
            transaction_id,
            PendingRequest {
                topic: topic.to_string(),
                sent_at: Instant::now(),
            },
        );
        transaction_id
    }

    /// Pulls every message currently available on the response channel without
    /// blocking.
    ///
    /// A message whose transaction id matches a pending request moves that
    /// request out of the pending set and is stored for
    /// [`take_response`](Self::take_response). Messages with an unknown
    /// transaction id are dropped.
    pub fn poll(&mut self) {
        // `try_recv` fails on both "empty" and "disconnected"; either way there
        // is nothing more to read right now.
        while let Ok(msg) = self.response_rx.try_recv() {
            let tid = msg.transaction_id;
            if self.pending.remove(&tid).is_some() {
                self.responses.insert(tid, msg);
            }
        }
    }

    /// Removes and returns the stored response for `transaction_id`, or `None`
    /// if no response for it has been collected by [`poll`](Self::poll).
    pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
        self.responses.remove(&transaction_id)
    }

    /// Returns `true` while `transaction_id` has been sent and has neither been
    /// answered nor drained as stale.
    pub fn is_pending(&self, transaction_id: u32) -> bool {
        self.pending.contains_key(&transaction_id)
    }

    /// Number of requests still waiting for a response.
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Number of responses collected but not yet taken.
    pub fn response_count(&self) -> usize {
        self.responses.len()
    }

    /// Removes every pending request that has been waiting for at least
    /// `timeout`, logs a warning for each, and returns their transaction ids
    /// (in no particular order).
    ///
    /// Responses that were already collected are not affected.
    pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
        let now = Instant::now();
        let mut stale = Vec::new();

        self.pending.retain(|&tid, req| {
            // `>=` rather than `>`: a request is stale once the full timeout
            // has elapsed, and a zero timeout must expire everything even when
            // the clock has not advanced since the request was registered.
            let expired = now.duration_since(req.sent_at) >= timeout;
            if expired {
                log::warn!("Command request {} ('{}') timed out after {:?}",
                    tid, req.topic, timeout);
                stale.push(tid);
            }
            !expired
        });

        stale
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Receiving end of the client's outgoing (serialized request) channel.
    type Outbound = mpsc::UnboundedReceiver<String>;
    /// Sending end of the client's incoming (response) channel.
    type Inbound = mpsc::UnboundedSender<CommandMessage>;

    /// Builds a client wired to fresh channels and hands back the far ends.
    ///
    /// Callers must keep both returned ends bound (even if unused) so the
    /// channels stay open for the duration of the test.
    fn wired_client() -> (CommandClient, Outbound, Inbound) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    /// Injects a response for `tid` carrying `data` into the client's inbox.
    fn reply(inbound: &Inbound, tid: u32, data: Value) {
        inbound.send(CommandMessage::response(tid, data)).unwrap();
    }

    /// `send` must put a well-formed request on the write channel and track it.
    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut outbound, _inbound) = wired_client();

        let tid = client.send("test.command", json!({"key": "value"}));

        let raw = outbound.try_recv().expect("should have a message");
        let sent: CommandMessage = serde_json::from_str(&raw).unwrap();
        assert_eq!(sent.transaction_id, tid);
        assert_eq!(sent.topic, "test.command");
        assert_eq!(sent.message_type, MessageType::Request);
        assert_eq!(sent.data, json!({"key": "value"}));

        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    /// A response only becomes visible after `poll`, and can be taken once.
    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _outbound, inbound) = wired_client();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        reply(&inbound, tid, json!("ok"));
        // Not polled yet, so nothing is available.
        assert!(client.take_response(tid).is_none());

        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let taken = client.take_response(tid).unwrap();
        assert_eq!(taken.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    /// Out-of-order responses are routed to the right transaction id.
    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _outbound, inbound) = wired_client();

        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Deliver B before A.
        reply(&inbound, tid_b, json!("b_result"));
        reply(&inbound, tid_a, json!("a_result"));
        client.poll();
        assert_eq!(client.response_count(), 2);

        let resp_a = client.take_response(tid_a).unwrap();
        assert_eq!(resp_a.data, json!("a_result"));
        let resp_b = client.take_response(tid_b).unwrap();
        assert_eq!(resp_b.data, json!("b_result"));

        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    /// A zero timeout expires the outstanding request.
    #[test]
    fn test_drain_stale() {
        let (mut client, _outbound, _inbound) = wired_client();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        let expired = client.drain_stale(Duration::from_secs(0));
        assert_eq!(expired, vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    /// A generous timeout leaves a fresh request untouched.
    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _outbound, _inbound) = wired_client();

        let tid = client.send("test.command", json!(null));

        let expired = client.drain_stale(Duration::from_secs(3600));
        assert!(expired.is_empty());
        assert!(client.is_pending(tid));
    }

    /// Requests that already received a response are never reported as stale.
    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _outbound, inbound) = wired_client();

        let tid = client.send("test.command", json!(null));
        reply(&inbound, tid, json!("ok"));
        client.poll();

        let expired = client.drain_stale(Duration::from_secs(0));
        assert!(expired.is_empty());
        assert!(client.take_response(tid).is_some());
    }

    /// Answering one of several requests leaves the others pending.
    #[test]
    fn test_multiple_pending() {
        let (mut client, _outbound, inbound) = wired_client();

        let tid1 = client.send("cmd.first", json!(1));
        let tid2 = client.send("cmd.second", json!(2));
        let tid3 = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        reply(&inbound, tid2, json!("ok"));
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(tid1));
        assert!(!client.is_pending(tid2));
        assert!(client.is_pending(tid3));

        let taken = client.take_response(tid2).unwrap();
        assert_eq!(taken.transaction_id, tid2);
    }

    /// Responses for ids that were never sent are dropped by `poll`.
    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _outbound, inbound) = wired_client();

        reply(&inbound, 99999, json!("stale1"));
        reply(&inbound, 99998, json!("stale2"));
        client.poll();

        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(99999).is_none());
        assert!(client.take_response(99998).is_none());
    }

    /// An unsolicited response does not disturb delivery of a real one.
    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _outbound, inbound) = wired_client();

        let tid = client.send("real.command", json!(null));
        reply(&inbound, 77777, json!("unsolicited"));
        reply(&inbound, tid, json!("real_result"));
        client.poll();

        assert_eq!(client.response_count(), 1);
        let resp = client.take_response(tid).unwrap();
        assert_eq!(resp.data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }
}