Skip to main content

autocore_std/
command_client.rs

1//! Client for sending IPC commands to external modules via WebSocket.
2//!
3//! `CommandClient` allows control programs to send [`CommandMessage`] requests
4//! (e.g., `labelit.translate_check`) to any external module through the existing
5//! WebSocket connection and poll for responses non-blockingly from `process_tick`.
6//!
7//! # Multi-consumer pattern
8//!
9//! The framework passes a `&mut CommandClient` to the control program each cycle
10//! via [`TickContext`](crate::TickContext). The framework calls [`poll()`](CommandClient::poll)
11//! before `process_tick`, so incoming responses are already buffered. Each
12//! subsystem sends requests via [`send()`](CommandClient::send) and retrieves
13//! its own responses by transaction ID via
14//! [`take_response()`](CommandClient::take_response).
15//!
16//! ```ignore
17//! fn process_tick(&mut self, ctx: &mut TickContext<Self::Memory>) {
18//!     // poll() is already called by the framework before process_tick
19//!
20//!     // Each subsystem checks for its own responses by transaction_id
21//!     self.labelit.tick(ctx.client);
22//!     self.other_camera.tick(ctx.client);
23//!
24//!     // Clean up stale requests
25//!     ctx.client.drain_stale(Duration::from_secs(10));
26//! }
27//! ```
28
29use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
/// Bookkeeping for a request that has been sent but not yet answered.
struct PendingRequest {
    /// Moment the request was registered by `send`; compared against the
    /// caller's timeout when draining stale requests.
    sent_at: Instant,
    /// Topic the request was addressed to, kept for timeout diagnostics.
    topic: String,
}
40
41/// A non-blocking client for sending IPC commands to external modules.
42///
43/// `CommandClient` is constructed by [`ControlRunner`](crate::ControlRunner) during
44/// startup and passed to the control program each cycle via
45/// [`TickContext::client`](crate::TickContext::client).
46///
47/// All methods are non-blocking and safe to call from `process_tick`.
48///
49/// Multiple subsystems (state machines, modules) can share a single `CommandClient`.
50/// Each subsystem calls [`send()`](Self::send) to issue requests and
51/// [`take_response()`](Self::take_response) to claim its own responses by
52/// `transaction_id`. The framework calls [`poll()`](Self::poll) once per tick
53/// before `process_tick`, so incoming messages are already buffered.
54pub struct CommandClient {
55    /// Channel to send serialized messages for the WS write task.
56    write_tx: mpsc::UnboundedSender<String>,
57    /// Channel to receive response CommandMessages from the WS read task.
58    response_rx: mpsc::UnboundedReceiver<CommandMessage>,
59    /// Track pending requests by transaction_id for timeout/diagnostics.
60    pending: HashMap<u32, PendingRequest>,
61    /// Buffered responses keyed by transaction_id, ready for consumers to claim.
62    responses: HashMap<u32, CommandMessage>,
63}
64
65impl CommandClient {
66    /// Create a new `CommandClient` from channels created by `ControlRunner`.
67    pub fn new(
68        write_tx: mpsc::UnboundedSender<String>,
69        response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70    ) -> Self {
71        Self {
72            write_tx,
73            response_rx,
74            pending: HashMap::new(),
75            responses: HashMap::new(),
76        }
77    }
78
79    /// Send a command request to an external module.
80    ///
81    /// Creates a [`CommandMessage::request`] with the given topic and data,
82    /// serializes it, and pushes it into the WebSocket write channel.
83    ///
84    /// Returns the `transaction_id` which can be used to match the response.
85    ///
86    /// # Arguments
87    ///
88    /// * `topic` - Fully-qualified topic name (e.g., `"labelit.translate_check"`)
89    /// * `data` - JSON payload for the request
90    pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91        let msg = CommandMessage::request(topic, data);
92        let transaction_id = msg.transaction_id;
93
94        if let Ok(json) = serde_json::to_string(&msg) {
95            let _ = self.write_tx.send(json);
96        }
97
98        self.pending.insert(transaction_id, PendingRequest {
99            topic: topic.to_string(),
100            sent_at: Instant::now(),
101        });
102
103        transaction_id
104    }
105
106    /// Drain all available responses from the WebSocket channel into the
107    /// internal buffer.
108    ///
109    /// Call this **once per tick** at the top of `process_tick`, before any
110    /// subsystem calls [`take_response()`](Self::take_response). This ensures
111    /// every subsystem sees responses that arrived since the last cycle.
112    pub fn poll(&mut self) {
113        while let Ok(msg) = self.response_rx.try_recv() {
114            let tid = msg.transaction_id;
115            if self.pending.remove(&tid).is_some() {
116                self.responses.insert(tid, msg);
117            }
118            // else: unsolicited response (e.g. gm.write from control loop), discard
119        }
120    }
121
122    /// Take a response for a specific `transaction_id` from the buffer.
123    ///
124    /// Returns `Some(response)` if a response with that ID has been received,
125    /// or `None` if it hasn't arrived yet. The response is removed from the
126    /// buffer on retrieval.
127    ///
128    /// This is the recommended way for subsystems to retrieve their responses,
129    /// since each subsystem only claims its own `transaction_id` and cannot
130    /// accidentally consume another subsystem's response.
131    ///
132    /// # Example
133    ///
134    /// ```ignore
135    /// // In a subsystem's tick method:
136    /// if let Some(response) = client.take_response(self.my_tid) {
137    ///     if response.success {
138    ///         // handle success
139    ///     } else {
140    ///         // handle error
141    ///     }
142    /// }
143    /// ```
144    pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
145        self.responses.remove(&transaction_id)
146    }
147
148    /// Check if a request is still awaiting a response.
149    ///
150    /// Returns `true` if the request has been sent but no response has arrived
151    /// yet. Returns `false` if the response is already buffered (even if not
152    /// yet claimed via [`take_response()`](Self::take_response)) or if the
153    /// transaction ID is unknown.
154    pub fn is_pending(&self, transaction_id: u32) -> bool {
155        self.pending.contains_key(&transaction_id)
156    }
157
158    /// Number of outstanding requests (sent but no response received yet).
159    pub fn pending_count(&self) -> usize {
160        self.pending.len()
161    }
162
163    /// Number of responses buffered and ready to be claimed.
164    pub fn response_count(&self) -> usize {
165        self.responses.len()
166    }
167
168    /// Remove and return transaction IDs that have been pending longer than `timeout`.
169    pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
170        let now = Instant::now();
171        let stale: Vec<u32> = self.pending.iter()
172            .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
173            .map(|(&tid, _)| tid)
174            .collect();
175
176        for tid in &stale {
177            if let Some(req) = self.pending.remove(tid) {
178                log::warn!("Command request {} ('{}') timed out after {:?}",
179                    tid, req.topic, timeout);
180            }
181        }
182
183        stale
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Build a client wired to fresh channels.
    ///
    /// Returns `(client, outbound, inbound)`:
    /// * `outbound` observes what the client pushes toward the WS write task;
    /// * `inbound` injects simulated responses as if from the WS read task.
    ///
    /// Callers bind every element, so both channel ends stay alive for the
    /// whole test.
    fn harness() -> (
        CommandClient,
        mpsc::UnboundedReceiver<String>,
        mpsc::UnboundedSender<CommandMessage>,
    ) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut outbound, _inbound) = harness();

        let tid = client.send("test.command", json!({"key": "value"}));

        // The request must have been serialized onto the write channel...
        let wire = outbound.try_recv().expect("should have a message");
        let sent: CommandMessage = serde_json::from_str(&wire).unwrap();
        assert_eq!(sent.transaction_id, tid);
        assert_eq!(sent.topic, "test.command");
        assert_eq!(sent.message_type, MessageType::Request);
        assert_eq!(sent.data, json!({"key": "value"}));

        // ...and recorded as awaiting a response.
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _outbound, inbound) = harness();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        // A response shows up on the inbound channel.
        inbound.send(CommandMessage::response(tid, json!("ok"))).unwrap();

        // Nothing is claimable until poll() moves it into the buffer.
        assert!(client.take_response(tid).is_none());

        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let claimed = client.take_response(tid).unwrap();
        assert_eq!(claimed.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _outbound, inbound) = harness();

        // Two independent subsystems issue requests.
        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Their responses arrive in the opposite order.
        inbound.send(CommandMessage::response(tid_b, json!("b_result"))).unwrap();
        inbound.send(CommandMessage::response(tid_a, json!("a_result"))).unwrap();

        // A single poll buffers both.
        client.poll();
        assert_eq!(client.response_count(), 2);

        // Each subsystem gets exactly its own payload.
        let from_a = client.take_response(tid_a).unwrap();
        assert_eq!(from_a.data, json!("a_result"));

        let from_b = client.take_response(tid_b).unwrap();
        assert_eq!(from_b.data, json!("b_result"));

        // Claiming is one-shot.
        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _outbound, _inbound) = harness();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        // A zero timeout makes every pending request stale immediately.
        let expired = client.drain_stale(Duration::from_secs(0));
        assert_eq!(expired, vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _outbound, _inbound) = harness();

        let tid = client.send("test.command", json!(null));

        // An hour-long timeout leaves a just-sent request untouched.
        let expired = client.drain_stale(Duration::from_secs(3600));
        assert!(expired.is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _outbound, inbound) = harness();

        let tid = client.send("test.command", json!(null));

        // The answer is buffered before any draining happens.
        inbound.send(CommandMessage::response(tid, json!("ok"))).unwrap();
        client.poll();

        // No longer pending, so drain_stale has nothing to report...
        let expired = client.drain_stale(Duration::from_secs(0));
        assert!(expired.is_empty());

        // ...and the buffered response survives.
        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _outbound, inbound) = harness();

        let first = client.send("cmd.first", json!(1));
        let second = client.send("cmd.second", json!(2));
        let third = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        // Only the middle request is answered.
        inbound.send(CommandMessage::response(second, json!("ok"))).unwrap();
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(first));
        assert!(!client.is_pending(second));
        assert!(client.is_pending(third));

        let answered = client.take_response(second).unwrap();
        assert_eq!(answered.transaction_id, second);
    }

    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _outbound, inbound) = harness();

        // Responses for IDs never registered through send() (e.g. gm.write
        // pushed directly via ws_write_tx).
        inbound.send(CommandMessage::response(99999, json!("stale1"))).unwrap();
        inbound.send(CommandMessage::response(99998, json!("stale2"))).unwrap();

        client.poll();

        // They must not pile up in the response buffer.
        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(99999).is_none());
        assert!(client.take_response(99998).is_none());
    }

    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _outbound, inbound) = harness();

        // One genuine request.
        let tid = client.send("real.command", json!(null));

        // An unsolicited response interleaved with the genuine one.
        inbound.send(CommandMessage::response(77777, json!("unsolicited"))).unwrap();
        inbound.send(CommandMessage::response(tid, json!("real_result"))).unwrap();

        client.poll();

        // Only the genuine response is kept.
        assert_eq!(client.response_count(), 1);
        let genuine = client.take_response(tid).unwrap();
        assert_eq!(genuine.data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }
}
383}