Skip to main content

autocore_std/
command_client.rs

//! Client for sending IPC commands to external modules via WebSocket.
//!
//! `CommandClient` lets control programs send [`CommandMessage`] requests
//! (e.g., `labelit.translate_check`) to any external module over the existing
//! WebSocket connection, and poll for responses from `process_tick` without blocking.
//!
//! # Multi-consumer pattern
//!
//! The framework passes a `&mut CommandClient` to the control program each cycle
//! via [`TickContext`](crate::TickContext). The framework calls [`poll()`](CommandClient::poll)
//! before `process_tick`, so incoming responses are already buffered. Each
//! subsystem sends requests via [`send()`](CommandClient::send) and retrieves
//! its own responses by transaction ID via
//! [`take_response()`](CommandClient::take_response).
//!
//! ```ignore
//! fn process_tick(&mut self, ctx: &mut TickContext<Self::Memory>) {
//!     // poll() is already called by the framework before process_tick
//!
//!     // Each subsystem checks for its own responses by transaction_id
//!     self.labelit.tick(ctx.client);
//!     self.other_camera.tick(ctx.client);
//!
//!     // Clean up stale requests
//!     ctx.client.drain_stale(Duration::from_secs(10));
//! }
//! ```
28
29use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
/// Bookkeeping for a request that has been sent but has not been answered yet.
struct PendingRequest {
    /// Monotonic timestamp recorded by `send()`; compared against the caller's
    /// timeout when looking for expired requests.
    sent_at: Instant,
    /// Topic the request was addressed to, kept for diagnostics (timeout logging).
    topic: String,
}
40
41/// A non-blocking client for sending IPC commands to external modules.
42///
43/// `CommandClient` is constructed by [`ControlRunner`](crate::ControlRunner) during
44/// startup and passed to the control program each cycle via
45/// [`TickContext::client`](crate::TickContext::client).
46///
47/// All methods are non-blocking and safe to call from `process_tick`.
48///
49/// Multiple subsystems (state machines, modules) can share a single `CommandClient`.
50/// Each subsystem calls [`send()`](Self::send) to issue requests and
51/// [`take_response()`](Self::take_response) to claim its own responses by
52/// `transaction_id`. The framework calls [`poll()`](Self::poll) once per tick
53/// before `process_tick`, so incoming messages are already buffered.
54pub struct CommandClient {
55    /// Channel to send serialized messages for the WS write task.
56    write_tx: mpsc::UnboundedSender<String>,
57    /// Channel to receive response CommandMessages from the WS read task.
58    response_rx: mpsc::UnboundedReceiver<CommandMessage>,
59    /// Track pending requests by transaction_id for timeout/diagnostics.
60    pending: HashMap<u32, PendingRequest>,
61    /// Buffered responses keyed by transaction_id, ready for consumers to claim.
62    responses: HashMap<u32, CommandMessage>,
63}
64
65impl CommandClient {
66    /// Create a new `CommandClient` from channels created by `ControlRunner`.
67    pub fn new(
68        write_tx: mpsc::UnboundedSender<String>,
69        response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70    ) -> Self {
71        Self {
72            write_tx,
73            response_rx,
74            pending: HashMap::new(),
75            responses: HashMap::new(),
76        }
77    }
78
79    /// Send a command request to an external module.
80    ///
81    /// Creates a [`CommandMessage::request`] with the given topic and data,
82    /// serializes it, and pushes it into the WebSocket write channel.
83    ///
84    /// Returns the `transaction_id` which can be used to match the response.
85    ///
86    /// # Arguments
87    ///
88    /// * `topic` - Fully-qualified topic name (e.g., `"labelit.translate_check"`)
89    /// * `data` - JSON payload for the request
90    pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91        let msg = CommandMessage::request(topic, data);
92        let transaction_id = msg.transaction_id;
93
94        if let Ok(json) = serde_json::to_string(&msg) {
95            let _ = self.write_tx.send(json);
96        }
97
98        self.pending.insert(transaction_id, PendingRequest {
99            topic: topic.to_string(),
100            sent_at: Instant::now(),
101        });
102
103        transaction_id
104    }
105
106    /// Drain all available responses from the WebSocket channel into the
107    /// internal buffer.
108    ///
109    /// Call this **once per tick** at the top of `process_tick`, before any
110    /// subsystem calls [`take_response()`](Self::take_response). This ensures
111    /// every subsystem sees responses that arrived since the last cycle.
112    pub fn poll(&mut self) {
113        while let Ok(msg) = self.response_rx.try_recv() {
114            self.pending.remove(&msg.transaction_id);
115            self.responses.insert(msg.transaction_id, msg);
116        }
117    }
118
119    /// Take a response for a specific `transaction_id` from the buffer.
120    ///
121    /// Returns `Some(response)` if a response with that ID has been received,
122    /// or `None` if it hasn't arrived yet. The response is removed from the
123    /// buffer on retrieval.
124    ///
125    /// This is the recommended way for subsystems to retrieve their responses,
126    /// since each subsystem only claims its own `transaction_id` and cannot
127    /// accidentally consume another subsystem's response.
128    ///
129    /// # Example
130    ///
131    /// ```ignore
132    /// // In a subsystem's tick method:
133    /// if let Some(response) = client.take_response(self.my_tid) {
134    ///     if response.success {
135    ///         // handle success
136    ///     } else {
137    ///         // handle error
138    ///     }
139    /// }
140    /// ```
141    pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
142        self.responses.remove(&transaction_id)
143    }
144
145    /// Check if a request is still awaiting a response.
146    ///
147    /// Returns `true` if the request has been sent but no response has arrived
148    /// yet. Returns `false` if the response is already buffered (even if not
149    /// yet claimed via [`take_response()`](Self::take_response)) or if the
150    /// transaction ID is unknown.
151    pub fn is_pending(&self, transaction_id: u32) -> bool {
152        self.pending.contains_key(&transaction_id)
153    }
154
155    /// Number of outstanding requests (sent but no response received yet).
156    pub fn pending_count(&self) -> usize {
157        self.pending.len()
158    }
159
160    /// Number of responses buffered and ready to be claimed.
161    pub fn response_count(&self) -> usize {
162        self.responses.len()
163    }
164
165    /// Remove and return transaction IDs that have been pending longer than `timeout`.
166    pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
167        let now = Instant::now();
168        let stale: Vec<u32> = self.pending.iter()
169            .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
170            .map(|(&tid, _)| tid)
171            .collect();
172
173        for tid in &stale {
174            if let Some(req) = self.pending.remove(tid) {
175                log::warn!("Command request {} ('{}') timed out after {:?}",
176                    tid, req.topic, timeout);
177            }
178        }
179
180        stale
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Build a client wired to test-owned channel ends.
    ///
    /// Returns the client, the receiver of outgoing serialized requests, and the
    /// sender used to inject responses. Holding the returned ends keeps both
    /// channels open for the duration of a test.
    fn make_client() -> (
        CommandClient,
        mpsc::UnboundedReceiver<String>,
        mpsc::UnboundedSender<CommandMessage>,
    ) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut write_rx, _response_tx) = make_client();

        let tid = client.send("test.command", json!({"key": "value"}));

        // Should have pushed a message to the write channel
        let msg_json = write_rx.try_recv().expect("should have a message");
        let msg: CommandMessage = serde_json::from_str(&msg_json).unwrap();
        assert_eq!(msg.transaction_id, tid);
        assert_eq!(msg.topic, "test.command");
        assert_eq!(msg.message_type, MessageType::Request);
        assert_eq!(msg.data, json!({"key": "value"}));

        // Should be tracked as pending
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_send_with_closed_write_channel_still_tracks_pending() {
        let (mut client, write_rx, _response_tx) = make_client();
        drop(write_rx);

        // The send cannot be delivered, but the request must still be tracked so
        // it eventually surfaces through drain_stale().
        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _write_rx, response_tx) = make_client();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        // Simulate response arriving
        response_tx.send(CommandMessage::response(tid, json!("ok"))).unwrap();

        // Before poll, take_response finds nothing
        assert!(client.take_response(tid).is_none());

        // After poll, take_response returns the response
        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let recv = client.take_response(tid).unwrap();
        assert_eq!(recv.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_take_response_unknown_id() {
        let (mut client, _write_rx, _response_tx) = make_client();

        client.poll();
        assert!(client.take_response(u32::MAX).is_none());
        assert!(!client.is_pending(u32::MAX));
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_poll_buffers_unsolicited_response() {
        let (mut client, _write_rx, response_tx) = make_client();

        // A response for an ID this client never sent (or already expired) is
        // still buffered and can be claimed.
        response_tx.send(CommandMessage::response(777, json!("late"))).unwrap();
        client.poll();

        assert_eq!(client.pending_count(), 0);
        assert_eq!(client.response_count(), 1);
        assert_eq!(client.take_response(777).unwrap().data, json!("late"));
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _write_rx, response_tx) = make_client();

        // Two subsystems send requests
        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Both responses arrive
        response_tx.send(CommandMessage::response(tid_b, json!("b_result"))).unwrap();
        response_tx.send(CommandMessage::response(tid_a, json!("a_result"))).unwrap();

        // Single poll drains both
        client.poll();
        assert_eq!(client.response_count(), 2);

        // Each subsystem claims only its own response
        let resp_a = client.take_response(tid_a).unwrap();
        assert_eq!(resp_a.data, json!("a_result"));

        let resp_b = client.take_response(tid_b).unwrap();
        assert_eq!(resp_b.data, json!("b_result"));

        // Neither response is available again
        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _write_rx, _response_tx) = make_client();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        // Let the monotonic clock advance measurably so the zero timeout is
        // exceeded even on platforms with coarse timer resolution.
        std::thread::sleep(Duration::from_millis(1));

        // With a zero timeout, the request should be stale
        let stale = client.drain_stale(Duration::from_secs(0));
        assert_eq!(stale, vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _write_rx, _response_tx) = make_client();

        let tid = client.send("test.command", json!(null));

        // With a long timeout, nothing should be stale
        let stale = client.drain_stale(Duration::from_secs(3600));
        assert!(stale.is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _write_rx, response_tx) = make_client();

        let tid = client.send("test.command", json!(null));

        // Response arrives before we drain
        response_tx.send(CommandMessage::response(tid, json!("ok"))).unwrap();
        client.poll();

        // drain_stale should not report it since it's no longer pending
        let stale = client.drain_stale(Duration::from_secs(0));
        assert!(stale.is_empty());

        // But it's still in the response buffer
        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _write_rx, response_tx) = make_client();

        let tid1 = client.send("cmd.first", json!(1));
        let tid2 = client.send("cmd.second", json!(2));
        let tid3 = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        // Respond to the second one
        response_tx.send(CommandMessage::response(tid2, json!("ok"))).unwrap();
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(tid1));
        assert!(!client.is_pending(tid2));
        assert!(client.is_pending(tid3));

        let recv = client.take_response(tid2).unwrap();
        assert_eq!(recv.transaction_id, tid2);
    }
}