1use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
/// Book-keeping for a request that has been sent but has not been answered yet.
struct PendingRequest {
    /// Time at which the request was recorded as sent; compared against a timeout
    /// when looking for stale requests.
    sent_at: Instant,
    /// Topic of the request, kept so that timeouts can be reported meaningfully.
    topic: String,
}
40
/// Correlates outgoing command requests with their responses by transaction id.
///
/// Requests are serialized to JSON and pushed onto `write_tx`. Incoming messages
/// are pulled from `response_rx` by [`CommandClient::poll`] and kept only if they
/// answer a request that is still pending.
pub struct CommandClient {
    /// Sender for serialized (JSON) outgoing messages.
    write_tx: mpsc::UnboundedSender<String>,
    /// Receiver for incoming messages; drained without blocking by `poll`.
    response_rx: mpsc::UnboundedReceiver<CommandMessage>,
    /// Requests that were sent and are still awaiting a response, keyed by
    /// transaction id.
    pending: HashMap<u32, PendingRequest>,
    /// Responses matched to a pending request that the caller has not taken yet,
    /// keyed by transaction id.
    responses: HashMap<u32, CommandMessage>,
}
64
65impl CommandClient {
66 pub fn new(
68 write_tx: mpsc::UnboundedSender<String>,
69 response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70 ) -> Self {
71 Self {
72 write_tx,
73 response_rx,
74 pending: HashMap::new(),
75 responses: HashMap::new(),
76 }
77 }
78
79 pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91 let msg = CommandMessage::request(topic, data);
92 self.send_message(msg)
93 }
94
95 pub fn send_message(&mut self, msg: CommandMessage) -> u32 {
100 let transaction_id = msg.transaction_id;
101
102 if let Ok(json) = serde_json::to_string(&msg) {
103 let _ = self.write_tx.send(json);
104 }
105
106 self.pending.insert(transaction_id, PendingRequest {
107 topic: msg.topic.clone(),
108 sent_at: Instant::now(),
109 });
110
111 transaction_id
112 }
113
114 pub fn poll(&mut self) {
121 while let Ok(msg) = self.response_rx.try_recv() {
122 let tid = msg.transaction_id;
123 if self.pending.remove(&tid).is_some() {
124 self.responses.insert(tid, msg);
125 }
126 }
128 }
129
130 pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
153 self.responses.remove(&transaction_id)
154 }
155
156 pub fn is_pending(&self, transaction_id: u32) -> bool {
163 self.pending.contains_key(&transaction_id)
164 }
165
166 pub fn pending_count(&self) -> usize {
168 self.pending.len()
169 }
170
171 pub fn response_count(&self) -> usize {
173 self.responses.len()
174 }
175
176 pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
178 let now = Instant::now();
179 let stale: Vec<u32> = self.pending.iter()
180 .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
181 .map(|(&tid, _)| tid)
182 .collect();
183
184 for tid in &stale {
185 if let Some(req) = self.pending.remove(tid) {
186 log::warn!("Command request {} ('{}') timed out after {:?}",
187 tid, req.topic, timeout);
188 }
189 }
190
191 stale
192 }
193}
194
195#[cfg(test)]
196mod tests {
197 use super::*;
198 use mechutil::ipc::MessageType;
199 use serde_json::json;
200
201 #[test]
202 fn test_send_pushes_to_channel() {
203 let (write_tx, mut write_rx) = mpsc::unbounded_channel();
204 let (_response_tx, response_rx) = mpsc::unbounded_channel();
205 let mut client = CommandClient::new(write_tx, response_rx);
206
207 let tid = client.send("test.command", json!({"key": "value"}));
208
209 let msg_json = write_rx.try_recv().expect("should have a message");
211 let msg: CommandMessage = serde_json::from_str(&msg_json).unwrap();
212 assert_eq!(msg.transaction_id, tid);
213 assert_eq!(msg.topic, "test.command");
214 assert_eq!(msg.message_type, MessageType::Request);
215 assert_eq!(msg.data, json!({"key": "value"}));
216
217 assert!(client.is_pending(tid));
219 assert_eq!(client.pending_count(), 1);
220 }
221
222 #[test]
223 fn test_poll_and_take_response() {
224 let (write_tx, _write_rx) = mpsc::unbounded_channel();
225 let (response_tx, response_rx) = mpsc::unbounded_channel();
226 let mut client = CommandClient::new(write_tx, response_rx);
227
228 let tid = client.send("test.command", json!(null));
229 assert!(client.is_pending(tid));
230
231 response_tx.send(CommandMessage::response(tid, json!("ok"))).unwrap();
233
234 assert!(client.take_response(tid).is_none());
236
237 client.poll();
239 assert!(!client.is_pending(tid));
240 assert_eq!(client.response_count(), 1);
241
242 let recv = client.take_response(tid).unwrap();
243 assert_eq!(recv.transaction_id, tid);
244 assert_eq!(client.response_count(), 0);
245 }
246
247 #[test]
248 fn test_multi_consumer_isolation() {
249 let (write_tx, _write_rx) = mpsc::unbounded_channel();
250 let (response_tx, response_rx) = mpsc::unbounded_channel();
251 let mut client = CommandClient::new(write_tx, response_rx);
252
253 let tid_a = client.send("labelit.inspect", json!(null));
255 let tid_b = client.send("other.status", json!(null));
256
257 response_tx.send(CommandMessage::response(tid_b, json!("b_result"))).unwrap();
259 response_tx.send(CommandMessage::response(tid_a, json!("a_result"))).unwrap();
260
261 client.poll();
263 assert_eq!(client.response_count(), 2);
264
265 let resp_a = client.take_response(tid_a).unwrap();
267 assert_eq!(resp_a.data, json!("a_result"));
268
269 let resp_b = client.take_response(tid_b).unwrap();
270 assert_eq!(resp_b.data, json!("b_result"));
271
272 assert!(client.take_response(tid_a).is_none());
274 assert!(client.take_response(tid_b).is_none());
275 assert_eq!(client.response_count(), 0);
276 }
277
278 #[test]
279 fn test_drain_stale() {
280 let (write_tx, _write_rx) = mpsc::unbounded_channel();
281 let (_response_tx, response_rx) = mpsc::unbounded_channel();
282 let mut client = CommandClient::new(write_tx, response_rx);
283
284 let tid = client.send("test.command", json!(null));
285 assert_eq!(client.pending_count(), 1);
286
287 let stale = client.drain_stale(Duration::from_secs(0));
289 assert_eq!(stale, vec![tid]);
290 assert_eq!(client.pending_count(), 0);
291 }
292
293 #[test]
294 fn test_drain_stale_keeps_fresh() {
295 let (write_tx, _write_rx) = mpsc::unbounded_channel();
296 let (_response_tx, response_rx) = mpsc::unbounded_channel();
297 let mut client = CommandClient::new(write_tx, response_rx);
298
299 let tid = client.send("test.command", json!(null));
300
301 let stale = client.drain_stale(Duration::from_secs(3600));
303 assert!(stale.is_empty());
304 assert!(client.is_pending(tid));
305 }
306
307 #[test]
308 fn test_drain_stale_ignores_received() {
309 let (write_tx, _write_rx) = mpsc::unbounded_channel();
310 let (response_tx, response_rx) = mpsc::unbounded_channel();
311 let mut client = CommandClient::new(write_tx, response_rx);
312
313 let tid = client.send("test.command", json!(null));
314
315 response_tx.send(CommandMessage::response(tid, json!("ok"))).unwrap();
317 client.poll();
318
319 let stale = client.drain_stale(Duration::from_secs(0));
321 assert!(stale.is_empty());
322
323 assert!(client.take_response(tid).is_some());
325 }
326
327 #[test]
328 fn test_multiple_pending() {
329 let (write_tx, _write_rx) = mpsc::unbounded_channel();
330 let (response_tx, response_rx) = mpsc::unbounded_channel();
331 let mut client = CommandClient::new(write_tx, response_rx);
332
333 let tid1 = client.send("cmd.first", json!(1));
334 let tid2 = client.send("cmd.second", json!(2));
335 let tid3 = client.send("cmd.third", json!(3));
336 assert_eq!(client.pending_count(), 3);
337
338 response_tx.send(CommandMessage::response(tid2, json!("ok"))).unwrap();
340 client.poll();
341
342 assert_eq!(client.pending_count(), 2);
343 assert!(client.is_pending(tid1));
344 assert!(!client.is_pending(tid2));
345 assert!(client.is_pending(tid3));
346
347 let recv = client.take_response(tid2).unwrap();
348 assert_eq!(recv.transaction_id, tid2);
349 }
350
351 #[test]
352 fn test_unsolicited_responses_discarded() {
353 let (write_tx, _write_rx) = mpsc::unbounded_channel();
354 let (response_tx, response_rx) = mpsc::unbounded_channel();
355 let mut client = CommandClient::new(write_tx, response_rx);
356
357 response_tx.send(CommandMessage::response(99999, json!("stale1"))).unwrap();
360 response_tx.send(CommandMessage::response(99998, json!("stale2"))).unwrap();
361
362 client.poll();
363
364 assert_eq!(client.response_count(), 0);
366 assert!(client.take_response(99999).is_none());
367 assert!(client.take_response(99998).is_none());
368 }
369
370 #[test]
371 fn test_mix_of_solicited_and_unsolicited() {
372 let (write_tx, _write_rx) = mpsc::unbounded_channel();
373 let (response_tx, response_rx) = mpsc::unbounded_channel();
374 let mut client = CommandClient::new(write_tx, response_rx);
375
376 let tid = client.send("real.command", json!(null));
378
379 response_tx.send(CommandMessage::response(77777, json!("unsolicited"))).unwrap();
381 response_tx.send(CommandMessage::response(tid, json!("real_result"))).unwrap();
382
383 client.poll();
384
385 assert_eq!(client.response_count(), 1);
387 let resp = client.take_response(tid).unwrap();
388 assert_eq!(resp.data, json!("real_result"));
389 assert!(client.take_response(77777).is_none());
390 }
391}