Skip to main content

autocore_std/
command_client.rs

1//! Client for sending IPC commands to external modules via WebSocket.
2//!
//! `CommandClient` lets control programs send [`CommandMessage`] requests
//! (e.g., `labelit.translate_check`) to any external module over the existing
//! WebSocket connection, and collect the responses from `process_tick` without blocking.
6//!
7//! # Multi-consumer pattern
8//!
9//! The framework passes a `&mut CommandClient` to the control program each cycle
10//! via [`TickContext`](crate::TickContext). The framework calls [`poll()`](CommandClient::poll)
11//! before `process_tick`, so incoming responses are already buffered. Each
12//! subsystem sends requests via [`send()`](CommandClient::send) and retrieves
13//! its own responses by transaction ID via
14//! [`take_response()`](CommandClient::take_response).
15//!
16//! ```ignore
17//! fn process_tick(&mut self, ctx: &mut TickContext<Self::Memory>) {
18//!     // poll() is already called by the framework before process_tick
19//!
20//!     // Each subsystem checks for its own responses by transaction_id
21//!     self.labelit.tick(ctx.client);
22//!     self.other_camera.tick(ctx.client);
23//!
24//!     // Clean up stale requests
25//!     ctx.client.drain_stale(Duration::from_secs(10));
26//! }
27//! ```
28
29use std::collections::HashMap;
30use std::time::{Duration, Instant};
31
32use mechutil::ipc::CommandMessage;
33use serde_json::Value;
34use tokio::sync::mpsc;
35
36struct PendingRequest {
37    topic: String,
38    sent_at: Instant,
39}
40
41/// A non-blocking client for sending IPC commands to external modules.
42///
43/// `CommandClient` is constructed by [`ControlRunner`](crate::ControlRunner) during
44/// startup and passed to the control program each cycle via
45/// [`TickContext::client`](crate::TickContext::client).
46///
47/// All methods are non-blocking and safe to call from `process_tick`.
48///
49/// Multiple subsystems (state machines, modules) can share a single `CommandClient`.
50/// Each subsystem calls [`send()`](Self::send) to issue requests and
51/// [`take_response()`](Self::take_response) to claim its own responses by
52/// `transaction_id`. The framework calls [`poll()`](Self::poll) once per tick
53/// before `process_tick`, so incoming messages are already buffered.
54pub struct CommandClient {
55    /// Channel to send serialized messages for the WS write task.
56    write_tx: mpsc::UnboundedSender<String>,
57    /// Channel to receive response CommandMessages from the WS read task.
58    response_rx: mpsc::UnboundedReceiver<CommandMessage>,
59    /// Track pending requests by transaction_id for timeout/diagnostics.
60    pending: HashMap<u32, PendingRequest>,
61    /// Buffered responses keyed by transaction_id, ready for consumers to claim.
62    responses: HashMap<u32, CommandMessage>,
63}
64
65impl CommandClient {
66    /// Create a new `CommandClient` from channels created by `ControlRunner`.
67    pub fn new(
68        write_tx: mpsc::UnboundedSender<String>,
69        response_rx: mpsc::UnboundedReceiver<CommandMessage>,
70    ) -> Self {
71        Self {
72            write_tx,
73            response_rx,
74            pending: HashMap::new(),
75            responses: HashMap::new(),
76        }
77    }
78
79    /// Send a command request to an external module.
80    ///
81    /// Creates a [`CommandMessage::request`] with the given topic and data,
82    /// serializes it, and pushes it into the WebSocket write channel.
83    ///
84    /// Returns the `transaction_id` which can be used to match the response.
85    ///
86    /// # Arguments
87    ///
88    /// * `topic` - Fully-qualified topic name (e.g., `"labelit.translate_check"`)
89    /// * `data` - JSON payload for the request
90    pub fn send(&mut self, topic: &str, data: Value) -> u32 {
91        let msg = CommandMessage::request(topic, data);
92        self.send_message(msg)
93    }
94
95    /// Send an arbitrary `CommandMessage`.
96    /// 
97    /// Useful for sending specific `MessageType` requests like `Read` or `Write`.
98    /// Returns the `transaction_id` which can be used to match the response.
99    pub fn send_message(&mut self, msg: CommandMessage) -> u32 {
100        let transaction_id = msg.transaction_id;
101
102        if let Ok(json) = serde_json::to_string(&msg) {
103            let _ = self.write_tx.send(json);
104        }
105
106        self.pending.insert(transaction_id, PendingRequest {
107            topic: msg.topic.clone(),
108            sent_at: Instant::now(),
109        });
110
111        transaction_id
112    }
113
114    /// Drain all available responses from the WebSocket channel into the
115    /// internal buffer.
116    ///
117    /// Call this **once per tick** at the top of `process_tick`, before any
118    /// subsystem calls [`take_response()`](Self::take_response). This ensures
119    /// every subsystem sees responses that arrived since the last cycle.
120    pub fn poll(&mut self) {
121        while let Ok(msg) = self.response_rx.try_recv() {
122            let tid = msg.transaction_id;
123            if self.pending.remove(&tid).is_some() {
124                //log::info!("poll: matched response tid={} topic='{}'", tid, msg.topic);
125                self.responses.insert(tid, msg);
126            }
127            // Unsolicited responses (gm.write, ethercat.write) are silently discarded
128        }
129    }
130
131    /// Take a response for a specific `transaction_id` from the buffer.
132    ///
133    /// Returns `Some(response)` if a response with that ID has been received,
134    /// or `None` if it hasn't arrived yet. The response is removed from the
135    /// buffer on retrieval.
136    ///
137    /// This is the recommended way for subsystems to retrieve their responses,
138    /// since each subsystem only claims its own `transaction_id` and cannot
139    /// accidentally consume another subsystem's response.
140    ///
141    /// # Example
142    ///
143    /// ```ignore
144    /// // In a subsystem's tick method:
145    /// if let Some(response) = client.take_response(self.my_tid) {
146    ///     if response.success {
147    ///         // handle success
148    ///     } else {
149    ///         // handle error
150    ///     }
151    /// }
152    /// ```
153    pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
154        self.responses.remove(&transaction_id)
155    }
156
157    /// Check if a request is still awaiting a response.
158    ///
159    /// Returns `true` if the request has been sent but no response has arrived
160    /// yet. Returns `false` if the response is already buffered (even if not
161    /// yet claimed via [`take_response()`](Self::take_response)) or if the
162    /// transaction ID is unknown.
163    pub fn is_pending(&self, transaction_id: u32) -> bool {
164        self.pending.contains_key(&transaction_id)
165    }
166
167    /// Number of outstanding requests (sent but no response received yet).
168    pub fn pending_count(&self) -> usize {
169        self.pending.len()
170    }
171
172    /// Number of responses buffered and ready to be claimed.
173    pub fn response_count(&self) -> usize {
174        self.responses.len()
175    }
176
177    /// Remove and return transaction IDs that have been pending longer than `timeout`.
178    pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
179        let now = Instant::now();
180        let stale: Vec<u32> = self.pending.iter()
181            .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
182            .map(|(&tid, _)| tid)
183            .collect();
184
185        for tid in &stale {
186            if let Some(req) = self.pending.remove(tid) {
187                log::warn!("Command request {} ('{}') timed out after {:?}",
188                    tid, req.topic, timeout);
189            }
190        }
191
192        stale
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Far end of the outgoing (write) channel, held by the test.
    type WriteRx = mpsc::UnboundedReceiver<String>;
    /// Far end of the incoming (response) channel, held by the test.
    type ResponseTx = mpsc::UnboundedSender<CommandMessage>;

    /// Build a client on fresh unbounded channels and hand back the client
    /// together with the opposite ends of both channels.
    fn harness() -> (CommandClient, WriteRx, ResponseTx) {
        let (write_tx, write_rx) = mpsc::unbounded_channel();
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        (CommandClient::new(write_tx, response_rx), write_rx, response_tx)
    }

    /// Simulate the WS read task delivering a response for `tid`.
    fn reply(tx: &ResponseTx, tid: u32, data: Value) {
        tx.send(CommandMessage::response(tid, data)).unwrap();
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut write_rx, _response_tx) = harness();

        let tid = client.send("test.command", json!({"key": "value"}));

        // The serialized request must appear on the write channel...
        let wire = write_rx.try_recv().expect("should have a message");
        let sent: CommandMessage = serde_json::from_str(&wire).unwrap();
        assert_eq!(sent.transaction_id, tid);
        assert_eq!(sent.topic, "test.command");
        assert_eq!(sent.message_type, MessageType::Request);
        assert_eq!(sent.data, json!({"key": "value"}));

        // ...and be tracked as pending.
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        reply(&response_tx, tid, json!("ok"));

        // Nothing is visible until poll() moves it into the buffer.
        assert!(client.take_response(tid).is_none());

        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let received = client.take_response(tid).unwrap();
        assert_eq!(received.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _write_rx, response_tx) = harness();

        // Two independent subsystems issue requests.
        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Replies arrive in the opposite order.
        reply(&response_tx, tid_b, json!("b_result"));
        reply(&response_tx, tid_a, json!("a_result"));

        // A single poll picks up both.
        client.poll();
        assert_eq!(client.response_count(), 2);

        // Each subsystem claims exactly its own reply.
        assert_eq!(client.take_response(tid_a).unwrap().data, json!("a_result"));
        assert_eq!(client.take_response(tid_b).unwrap().data, json!("b_result"));

        // Claimed replies are gone for good.
        assert!(client.take_response(tid_a).is_none());
        assert!(client.take_response(tid_b).is_none());
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _write_rx, _response_tx) = harness();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        // A zero timeout makes every pending request stale immediately.
        assert_eq!(client.drain_stale(Duration::from_secs(0)), vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _write_rx, _response_tx) = harness();

        let tid = client.send("test.command", json!(null));

        // Nothing can be an hour old yet.
        assert!(client.drain_stale(Duration::from_secs(3600)).is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid = client.send("test.command", json!(null));

        // The reply lands before the stale sweep...
        reply(&response_tx, tid, json!("ok"));
        client.poll();

        // ...so the sweep has nothing to report...
        assert!(client.drain_stale(Duration::from_secs(0)).is_empty());

        // ...and the buffered reply is untouched.
        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _write_rx, response_tx) = harness();

        let first = client.send("cmd.first", json!(1));
        let second = client.send("cmd.second", json!(2));
        let third = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        // Only the middle request is answered.
        reply(&response_tx, second, json!("ok"));
        client.poll();

        assert_eq!(client.pending_count(), 2);
        assert!(client.is_pending(first));
        assert!(!client.is_pending(second));
        assert!(client.is_pending(third));

        assert_eq!(client.take_response(second).unwrap().transaction_id, second);
    }

    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _write_rx, response_tx) = harness();

        // Replies for IDs never registered via send() (e.g. gm.write sent
        // directly via ws_write_tx).
        reply(&response_tx, 99999, json!("stale1"));
        reply(&response_tx, 99998, json!("stale2"));

        client.poll();

        // They must not pile up in the response buffer.
        assert_eq!(client.response_count(), 0);
        assert!(client.take_response(99999).is_none());
        assert!(client.take_response(99998).is_none());
    }

    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _write_rx, response_tx) = harness();

        let tid = client.send("real.command", json!(null));

        // One stray reply, one genuine reply.
        reply(&response_tx, 77777, json!("unsolicited"));
        reply(&response_tx, tid, json!("real_result"));

        client.poll();

        // Only the genuine reply is kept.
        assert_eq!(client.response_count(), 1);
        assert_eq!(client.take_response(tid).unwrap().data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }
}