1use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
/// Bookkeeping entry for a request that has been sent and not yet matched
/// with a response.
struct PendingRequest {
    /// Topic the request was addressed to; reported in the timeout warning.
    topic: String,
    /// Moment the request was registered; its age is compared against the
    /// timeout passed to `drain_stale`.
    sent_at: Instant,
}
40
41pub struct CommandClient {
55 write_tx: mpsc::UnboundedSender<String>,
57 response_rx: mpsc::UnboundedReceiver<CommandMessage>,
59 pending: HashMap<u32, PendingRequest>,
61 responses: HashMap<u32, CommandMessage>,
63}
64
65impl CommandClient {
66 pub fn new(
68 write_tx: mpsc::UnboundedSender<String>,
69 response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70 ) -> Self {
71 Self {
72 write_tx,
73 response_rx,
74 pending: HashMap::new(),
75 responses: HashMap::new(),
76 }
77 }
78
79 pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91 let msg = CommandMessage::request(topic, data);
92 let transaction_id = msg.transaction_id;
93
94 if let Ok(json) = serde_json::to_string(&msg) {
95 let _ = self.write_tx.send(json);
96 }
97
98 self.pending.insert(transaction_id, PendingRequest {
99 topic: topic.to_string(),
100 sent_at: Instant::now(),
101 });
102
103 transaction_id
104 }
105
106 pub fn poll(&mut self) {
113 while let Ok(msg) = self.response_rx.try_recv() {
114 let tid = msg.transaction_id;
115 if self.pending.remove(&tid).is_some() {
116 self.responses.insert(tid, msg);
117 }
118 }
120 }
121
122 pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
145 self.responses.remove(&transaction_id)
146 }
147
148 pub fn is_pending(&self, transaction_id: u32) -> bool {
155 self.pending.contains_key(&transaction_id)
156 }
157
158 pub fn pending_count(&self) -> usize {
160 self.pending.len()
161 }
162
163 pub fn response_count(&self) -> usize {
165 self.responses.len()
166 }
167
168 pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
170 let now = Instant::now();
171 let stale: Vec<u32> = self.pending.iter()
172 .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
173 .map(|(&tid, _)| tid)
174 .collect();
175
176 for tid in &stale {
177 if let Some(req) = self.pending.remove(tid) {
178 log::warn!("Command request {} ('{}') timed out after {:?}",
179 tid, req.topic, timeout);
180 }
181 }
182
183 stale
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Test-side end of the outgoing channel.
    type WriteRx = mpsc::UnboundedReceiver<String>;
    /// Test-side end of the response channel.
    type ResponseTx = mpsc::UnboundedSender<CommandMessage>;

    /// Builds a client on fresh channels and returns it together with the
    /// opposite channel ends, so each test can keep them alive or use them.
    fn harness() -> (CommandClient, WriteRx, ResponseTx) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    /// Queues a response for `tid` carrying `data` on the response channel.
    fn reply(response_tx: &ResponseTx, tid: u32, data: Value) {
        response_tx.send(CommandMessage::response(tid, data)).unwrap();
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut write_rx, _response_tx) = harness();

        let tid = client.send("test.command", json!({"key": "value"}));

        // The serialized request must be on the outgoing channel.
        let msg_json = write_rx.try_recv().expect("should have a message");
        let msg: CommandMessage = serde_json::from_str(&msg_json).unwrap();
        assert_eq!(msg.transaction_id, tid);
        assert_eq!(msg.topic, "test.command");
        assert_eq!(msg.message_type, MessageType::Request);
        assert_eq!(msg.data, json!({"key": "value"}));

        // ...and the request must be tracked.
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        reply(&response_tx, tid, json!("ok"));

        // Nothing is visible before poll() has run.
        assert!(client.take_response(tid).is_none());

        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let recv = client.take_response(tid).unwrap();
        assert_eq!(recv.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Responses arrive in the opposite order to the requests.
        reply(&response_tx, tid_b, json!("b_result"));
        reply(&response_tx, tid_a, json!("a_result"));

        client.poll();
        assert_eq!(client.response_count(), 2);

        // Each consumer gets exactly its own response.
        let resp_a = client.take_response(tid_a).unwrap();
        assert_eq!(resp_a.data, json!("a_result"));

        let resp_b = client.take_response(tid_b).unwrap();
        assert_eq!(resp_b.data, json!("b_result"));

        // A response can only be taken once.
        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _write_rx, _response_tx) = harness();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        // Zero timeout: the request just sent is expected to be stale.
        let stale = client.drain_stale(Duration::from_secs(0));
        assert_eq!(stale, vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _write_rx, _response_tx) = harness();

        let tid = client.send("test.command", json!(null));

        // A one-hour timeout must not touch a request sent just now.
        let stale = client.drain_stale(Duration::from_secs(3600));
        assert!(stale.is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid = client.send("test.command", json!(null));

        reply(&response_tx, tid, json!("ok"));
        client.poll();

        // The request was answered, so it is no longer eligible for draining.
        let stale = client.drain_stale(Duration::from_secs(0));
        assert!(stale.is_empty());

        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid1 = client.send("cmd.first", json!(1));
        let tid2 = client.send("cmd.second", json!(2));
        let tid3 = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        // Answer only the middle request.
        reply(&response_tx, tid2, json!("ok"));
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(tid1));
        assert!(!client.is_pending(tid2));
        assert!(client.is_pending(tid3));

        let recv = client.take_response(tid2).unwrap();
        assert_eq!(recv.transaction_id, tid2);
    }

    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _write_rx, response_tx) = harness();

        // No request was ever sent for these ids.
        reply(&response_tx, 99999, json!("stale1"));
        reply(&response_tx, 99998, json!("stale2"));

        client.poll();

        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(99999).is_none());
        assert!(client.take_response(99998).is_none());
    }

    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid = client.send("real.command", json!(null));

        reply(&response_tx, 77777, json!("unsolicited"));
        reply(&response_tx, tid, json!("real_result"));

        client.poll();

        // Only the response to the real request is kept.
        assert_eq!(client.response_count(), 1);
        let resp = client.take_response(tid).unwrap();
        assert_eq!(resp.data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }
}