Skip to main content

autocore_std/
command_client.rs

//! Client for sending IPC commands to external modules via WebSocket.
//!
//! `CommandClient` allows control programs to send [`CommandMessage`] requests
//! (e.g., `labelit.translate_check`) to any external module through the existing
//! WebSocket connection, and to poll for responses without blocking from `process_tick`.
//!
//! # Multi-consumer pattern
//!
//! The framework passes a `&mut CommandClient` to the control program each cycle
//! via [`TickContext`](crate::TickContext). The framework calls [`poll()`](CommandClient::poll)
//! before `process_tick`, so incoming responses are already buffered. Each
//! subsystem sends requests via [`send()`](CommandClient::send) and retrieves
//! its own responses by transaction ID via
//! [`take_response()`](CommandClient::take_response). Requests that never get a
//! response can be expired with [`drain_stale()`](CommandClient::drain_stale).
//!
//! ```ignore
//! fn process_tick(&mut self, ctx: &mut TickContext<Self::Memory>) {
//!     // poll() is already called by the framework before process_tick
//!
//!     // Each subsystem checks for its own responses by transaction_id
//!     self.labelit.tick(ctx.client);
//!     self.other_camera.tick(ctx.client);
//!
//!     // Clean up stale requests
//!     ctx.client.drain_stale(Duration::from_secs(10));
//! }
//! ```
28
29use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
/// Bookkeeping for a request that has been sent but not yet answered.
struct PendingRequest {
    /// Topic the request was sent on; reported when the request times out.
    topic: String,
    /// When the request was registered; compared against the timeout passed
    /// to `drain_stale`.
    sent_at: Instant,
}
40
41/// A non-blocking client for sending IPC commands to external modules.
42///
43/// `CommandClient` is constructed by [`ControlRunner`](crate::ControlRunner) during
44/// startup and passed to the control program each cycle via
45/// [`TickContext::client`](crate::TickContext::client).
46///
47/// All methods are non-blocking and safe to call from `process_tick`.
48///
49/// Multiple subsystems (state machines, modules) can share a single `CommandClient`.
50/// Each subsystem calls [`send()`](Self::send) to issue requests and
51/// [`take_response()`](Self::take_response) to claim its own responses by
52/// `transaction_id`. The framework calls [`poll()`](Self::poll) once per tick
53/// before `process_tick`, so incoming messages are already buffered.
54pub struct CommandClient {
55    /// Channel to send serialized messages for the WS write task.
56    write_tx: mpsc::UnboundedSender<String>,
57    /// Channel to receive response CommandMessages from the WS read task.
58    response_rx: mpsc::UnboundedReceiver<CommandMessage>,
59    /// Track pending requests by transaction_id for timeout/diagnostics.
60    pending: HashMap<u32, PendingRequest>,
61    /// Buffered responses keyed by transaction_id, ready for consumers to claim.
62    responses: HashMap<u32, CommandMessage>,
63}
64
65impl CommandClient {
66    /// Create a new `CommandClient` from channels created by `ControlRunner`.
67    pub fn new(
68        write_tx: mpsc::UnboundedSender<String>,
69        response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70    ) -> Self {
71        Self {
72            write_tx,
73            response_rx,
74            pending: HashMap::new(),
75            responses: HashMap::new(),
76        }
77    }
78
79    /// Send a command request to an external module.
80    ///
81    /// Creates a [`CommandMessage::request`] with the given topic and data,
82    /// serializes it, and pushes it into the WebSocket write channel.
83    ///
84    /// Returns the `transaction_id` which can be used to match the response.
85    ///
86    /// # Arguments
87    ///
88    /// * `topic` - Fully-qualified topic name (e.g., `"labelit.translate_check"`)
89    /// * `data` - JSON payload for the request
90    pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91        let msg = CommandMessage::request(topic, data);
92        self.send_message(msg)
93    }
94
95    /// Send an arbitrary `CommandMessage`.
96    /// 
97    /// Useful for sending specific `MessageType` requests like `Read` or `Write`.
98    /// Returns the `transaction_id` which can be used to match the response.
99    pub fn send_message(&mut self, msg: CommandMessage) -> u32 {
100        let transaction_id = msg.transaction_id;
101
102        if let Ok(json) = serde_json::to_string(&msg) {
103            let _ = self.write_tx.send(json);
104        }
105
106        self.pending.insert(transaction_id, PendingRequest {
107            topic: msg.topic.clone(),
108            sent_at: Instant::now(),
109        });
110
111        transaction_id
112    }
113
114    /// Drain all available responses from the WebSocket channel into the
115    /// internal buffer.
116    ///
117    /// Call this **once per tick** at the top of `process_tick`, before any
118    /// subsystem calls [`take_response()`](Self::take_response). This ensures
119    /// every subsystem sees responses that arrived since the last cycle.
120    pub fn poll(&mut self) {
121        while let Ok(msg) = self.response_rx.try_recv() {
122            let tid = msg.transaction_id;
123            if self.pending.remove(&tid).is_some() {
124                self.responses.insert(tid, msg);
125            }
126            // else: unsolicited response (e.g. gm.write from control loop), discard
127        }
128    }
129
130    /// Take a response for a specific `transaction_id` from the buffer.
131    ///
132    /// Returns `Some(response)` if a response with that ID has been received,
133    /// or `None` if it hasn't arrived yet. The response is removed from the
134    /// buffer on retrieval.
135    ///
136    /// This is the recommended way for subsystems to retrieve their responses,
137    /// since each subsystem only claims its own `transaction_id` and cannot
138    /// accidentally consume another subsystem's response.
139    ///
140    /// # Example
141    ///
142    /// ```ignore
143    /// // In a subsystem's tick method:
144    /// if let Some(response) = client.take_response(self.my_tid) {
145    ///     if response.success {
146    ///         // handle success
147    ///     } else {
148    ///         // handle error
149    ///     }
150    /// }
151    /// ```
152    pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
153        self.responses.remove(&transaction_id)
154    }
155
156    /// Check if a request is still awaiting a response.
157    ///
158    /// Returns `true` if the request has been sent but no response has arrived
159    /// yet. Returns `false` if the response is already buffered (even if not
160    /// yet claimed via [`take_response()`](Self::take_response)) or if the
161    /// transaction ID is unknown.
162    pub fn is_pending(&self, transaction_id: u32) -> bool {
163        self.pending.contains_key(&transaction_id)
164    }
165
166    /// Number of outstanding requests (sent but no response received yet).
167    pub fn pending_count(&self) -> usize {
168        self.pending.len()
169    }
170
171    /// Number of responses buffered and ready to be claimed.
172    pub fn response_count(&self) -> usize {
173        self.responses.len()
174    }
175
176    /// Remove and return transaction IDs that have been pending longer than `timeout`.
177    pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
178        let now = Instant::now();
179        let stale: Vec<u32> = self.pending.iter()
180            .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
181            .map(|(&tid, _)| tid)
182            .collect();
183
184        for tid in &stale {
185            if let Some(req) = self.pending.remove(tid) {
186                log::warn!("Command request {} ('{}') timed out after {:?}",
187                    tid, req.topic, timeout);
188            }
189        }
190
191        stale
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Build a client on fresh channels.
    ///
    /// Returns the client together with the far end of each channel: the
    /// receiver that observes what the client writes, and the sender used to
    /// inject responses. Keep both bound for the whole test so the channels
    /// stay open.
    fn setup() -> (
        CommandClient,
        mpsc::UnboundedReceiver<String>,
        mpsc::UnboundedSender<CommandMessage>,
    ) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut write_rx, _response_tx) = setup();

        let tid = client.send("test.command", json!({"key": "value"}));

        // Should have pushed a message to the write channel
        let msg_json = write_rx.try_recv().expect("should have a message");
        let msg: CommandMessage = serde_json::from_str(&msg_json).unwrap();
        assert_eq!(msg.transaction_id, tid);
        assert_eq!(msg.topic, "test.command");
        assert_eq!(msg.message_type, MessageType::Request);
        assert_eq!(msg.data, json!({"key": "value"}));

        // Should be tracked as pending
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _write_rx, response_tx) = setup();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        // Simulate response arriving
        response_tx.send(CommandMessage::response(tid, json!("ok"))).unwrap();

        // Before poll, take_response finds nothing
        assert!(client.take_response(tid).is_none());

        // After poll, take_response returns the response
        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let recv = client.take_response(tid).unwrap();
        assert_eq!(recv.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _write_rx, response_tx) = setup();

        // Two subsystems send requests
        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Both responses arrive, out of order
        response_tx.send(CommandMessage::response(tid_b, json!("b_result"))).unwrap();
        response_tx.send(CommandMessage::response(tid_a, json!("a_result"))).unwrap();

        // Single poll drains both
        client.poll();
        assert_eq!(client.response_count(), 2);

        // Each subsystem claims only its own response
        let resp_a = client.take_response(tid_a).unwrap();
        assert_eq!(resp_a.data, json!("a_result"));

        let resp_b = client.take_response(tid_b).unwrap();
        assert_eq!(resp_b.data, json!("b_result"));

        // Neither response is available again
        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _write_rx, _response_tx) = setup();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        // With a zero timeout, the request should be immediately stale
        let stale = client.drain_stale(Duration::from_secs(0));
        assert_eq!(stale, vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _write_rx, _response_tx) = setup();

        let tid = client.send("test.command", json!(null));

        // With a long timeout, nothing should be stale
        let stale = client.drain_stale(Duration::from_secs(3600));
        assert!(stale.is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _write_rx, response_tx) = setup();

        let tid = client.send("test.command", json!(null));

        // Response arrives before we drain
        response_tx.send(CommandMessage::response(tid, json!("ok"))).unwrap();
        client.poll();

        // drain_stale should not report it since it's no longer pending
        let stale = client.drain_stale(Duration::from_secs(0));
        assert!(stale.is_empty());

        // But it's still in the response buffer
        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _write_rx, response_tx) = setup();

        let tid1 = client.send("cmd.first", json!(1));
        let tid2 = client.send("cmd.second", json!(2));
        let tid3 = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        // Respond to the second one
        response_tx.send(CommandMessage::response(tid2, json!("ok"))).unwrap();
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(tid1));
        assert!(!client.is_pending(tid2));
        assert!(client.is_pending(tid3));

        let recv = client.take_response(tid2).unwrap();
        assert_eq!(recv.transaction_id, tid2);
    }

    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _write_rx, response_tx) = setup();

        // Simulate responses arriving for transaction IDs that were never
        // registered via send() (e.g. gm.write sent directly via ws_write_tx).
        response_tx.send(CommandMessage::response(99999, json!("stale1"))).unwrap();
        response_tx.send(CommandMessage::response(99998, json!("stale2"))).unwrap();

        client.poll();

        // Unsolicited responses must NOT accumulate in the responses HashMap
        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(99999).is_none());
        assert!(client.take_response(99998).is_none());
    }

    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _write_rx, response_tx) = setup();

        // One real request
        let tid = client.send("real.command", json!(null));

        // One unsolicited + one solicited response
        response_tx.send(CommandMessage::response(77777, json!("unsolicited"))).unwrap();
        response_tx.send(CommandMessage::response(tid, json!("real_result"))).unwrap();

        client.poll();

        // Only the solicited response should be buffered
        assert_eq!(client.response_count(), 1);
        let resp = client.take_response(tid).unwrap();
        assert_eq!(resp.data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }

    #[test]
    fn test_send_with_closed_write_channel_still_tracks_request() {
        let (mut client, write_rx, _response_tx) = setup();

        // The WS write task has gone away.
        drop(write_rx);

        // The request is still registered, so drain_stale can report it later.
        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_response_after_timeout_is_discarded() {
        let (mut client, _write_rx, response_tx) = setup();

        let tid = client.send("test.command", json!(null));

        // Guarantee the request is older than the timeout on any clock.
        std::thread::sleep(Duration::from_millis(2));
        assert_eq!(client.drain_stale(Duration::from_millis(1)), vec![tid]);

        // A late response for an expired request must not be buffered.
        response_tx.send(CommandMessage::response(tid, json!("late"))).unwrap();
        client.poll();

        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(tid).is_none());
    }
}