1use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
/// Bookkeeping for a request that has been sent but not yet answered.
struct PendingRequest {
    /// Topic of the outgoing request; kept so timeouts can be logged with context.
    topic: String,
    /// Moment the request was handed to the write channel; used for staleness checks.
    sent_at: Instant,
}
40
41pub struct CommandClient {
55 write_tx: mpsc::UnboundedSender<String>,
57 response_rx: mpsc::UnboundedReceiver<CommandMessage>,
59 pending: HashMap<u32, PendingRequest>,
61 responses: HashMap<u32, CommandMessage>,
63}
64
65impl CommandClient {
66 pub fn new(
68 write_tx: mpsc::UnboundedSender<String>,
69 response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70 ) -> Self {
71 Self {
72 write_tx,
73 response_rx,
74 pending: HashMap::new(),
75 responses: HashMap::new(),
76 }
77 }
78
79 pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91 let msg = CommandMessage::request(topic, data);
92 self.send_message(msg)
93 }
94
95 pub fn send_message(&mut self, msg: CommandMessage) -> u32 {
100 let transaction_id = msg.transaction_id;
101
102 if let Ok(json) = serde_json::to_string(&msg) {
103 let _ = self.write_tx.send(json);
104 }
105
106 self.pending.insert(transaction_id, PendingRequest {
107 topic: msg.topic.clone(),
108 sent_at: Instant::now(),
109 });
110
111 transaction_id
112 }
113
114 pub fn poll(&mut self) {
121 while let Ok(msg) = self.response_rx.try_recv() {
122 let tid = msg.transaction_id;
123 if self.pending.remove(&tid).is_some() {
124 self.responses.insert(tid, msg);
126 }
127 }
129 }
130
131 pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
154 self.responses.remove(&transaction_id)
155 }
156
157 pub fn is_pending(&self, transaction_id: u32) -> bool {
164 self.pending.contains_key(&transaction_id)
165 }
166
167 pub fn pending_count(&self) -> usize {
169 self.pending.len()
170 }
171
172 pub fn response_count(&self) -> usize {
174 self.responses.len()
175 }
176
177 pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
179 let now = Instant::now();
180 let stale: Vec<u32> = self.pending.iter()
181 .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
182 .map(|(&tid, _)| tid)
183 .collect();
184
185 for tid in &stale {
186 if let Some(req) = self.pending.remove(tid) {
187 log::warn!("Command request {} ('{}') timed out after {:?}",
188 tid, req.topic, timeout);
189 }
190 }
191
192 stale
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Wires a fresh client to two unbounded channels.
    ///
    /// Returns the client, the receiving end of its write channel (to inspect
    /// what was sent), and the sending end of its response channel (to inject
    /// responses). Tests bind both ends to names so the channels stay open.
    fn harness() -> (
        CommandClient,
        mpsc::UnboundedReceiver<String>,
        mpsc::UnboundedSender<CommandMessage>,
    ) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut writes, _responses) = harness();

        let tid = client.send("test.command", json!({"key": "value"}));

        let wire = writes.try_recv().expect("should have a message");
        let sent: CommandMessage = serde_json::from_str(&wire).unwrap();
        assert_eq!(sent.transaction_id, tid);
        assert_eq!(sent.topic, "test.command");
        assert_eq!(sent.message_type, MessageType::Request);
        assert_eq!(sent.data, json!({"key": "value"}));

        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _writes, responses) = harness();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        responses
            .send(CommandMessage::response(tid, json!("ok")))
            .unwrap();

        // The response is invisible until poll() moves it into the buffer.
        assert!(client.take_response(tid).is_none());

        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let received = client.take_response(tid).unwrap();
        assert_eq!(received.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _writes, responses) = harness();

        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Responses arrive in the reverse order of the requests.
        responses
            .send(CommandMessage::response(tid_b, json!("b_result")))
            .unwrap();
        responses
            .send(CommandMessage::response(tid_a, json!("a_result")))
            .unwrap();

        client.poll();
        assert_eq!(client.response_count(), 2);

        assert_eq!(client.take_response(tid_a).unwrap().data, json!("a_result"));
        assert_eq!(client.take_response(tid_b).unwrap().data, json!("b_result"));

        // Each response can be taken exactly once.
        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _writes, _responses) = harness();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        assert_eq!(client.drain_stale(Duration::ZERO), vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _writes, _responses) = harness();

        let tid = client.send("test.command", json!(null));

        assert!(client.drain_stale(Duration::from_secs(3600)).is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _writes, responses) = harness();

        let tid = client.send("test.command", json!(null));
        responses
            .send(CommandMessage::response(tid, json!("ok")))
            .unwrap();
        client.poll();

        // Already-answered requests are no longer pending, so nothing expires.
        assert!(client.drain_stale(Duration::ZERO).is_empty());
        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _writes, responses) = harness();

        let first = client.send("cmd.first", json!(1));
        let second = client.send("cmd.second", json!(2));
        let third = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        responses
            .send(CommandMessage::response(second, json!("ok")))
            .unwrap();
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(first));
        assert!(!client.is_pending(second));
        assert!(client.is_pending(third));

        assert_eq!(client.take_response(second).unwrap().transaction_id, second);
    }

    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _writes, responses) = harness();

        responses
            .send(CommandMessage::response(99999, json!("stale1")))
            .unwrap();
        responses
            .send(CommandMessage::response(99998, json!("stale2")))
            .unwrap();

        client.poll();

        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(99999).is_none());
        assert!(client.take_response(99998).is_none());
    }

    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _writes, responses) = harness();

        let tid = client.send("real.command", json!(null));

        responses
            .send(CommandMessage::response(77777, json!("unsolicited")))
            .unwrap();
        responses
            .send(CommandMessage::response(tid, json!("real_result")))
            .unwrap();

        client.poll();

        assert_eq!(client.response_count(), 1);
        assert_eq!(client.take_response(tid).unwrap().data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }
}