autocore-std 3.3.20

Standard library for AutoCore control programs - shared memory, IPC, and logging utilities
Documentation
//! Client for sending IPC commands to external modules via WebSocket.
//!
//! `CommandClient` lets control programs send [`CommandMessage`] requests
//! (e.g., `labelit.translate_check`) to any external module over the existing
//! WebSocket connection and poll for the responses from `process_tick`
//! without blocking.
//!
//! # Multi-consumer pattern
//!
//! The framework passes a `&mut CommandClient` to the control program each cycle
//! via [`TickContext`](crate::TickContext). The framework calls [`poll()`](CommandClient::poll)
//! before `process_tick`, so incoming responses are already buffered. Each
//! subsystem sends requests via [`send()`](CommandClient::send) and retrieves
//! its own responses by transaction ID via
//! [`take_response()`](CommandClient::take_response).
//!
//! ```ignore
//! fn process_tick(&mut self, ctx: &mut TickContext<Self::Memory>) {
//!     // poll() is already called by the framework before process_tick
//!
//!     // Each subsystem checks for its own responses by transaction_id
//!     self.labelit.tick(ctx.client);
//!     self.other_camera.tick(ctx.client);
//!
//!     // Clean up stale requests
//!     ctx.client.drain_stale(Duration::from_secs(10));
//! }
//! ```

use std::collections::HashMap;
use std::time::{Duration, Instant};

use mechutil::ipc::CommandMessage;
use serde_json::Value;
use tokio::sync::mpsc;

/// Bookkeeping for a request that has been sent but not yet answered.
struct PendingRequest {
    /// Topic the request was addressed to; kept for timeout diagnostics.
    topic: String,
    /// Moment the request was handed to the write channel.
    sent_at: Instant,
}

/// Non-blocking IPC command client shared by every subsystem of a control program.
///
/// [`ControlRunner`](crate::ControlRunner) builds one `CommandClient` at startup
/// and lends it to the control program on every cycle through
/// [`TickContext::client`](crate::TickContext::client).
///
/// No method ever waits: outgoing requests go onto an unbounded channel, and
/// incoming responses are read with a non-waiting receive. This makes every
/// call safe inside `process_tick`.
///
/// Several subsystems (state machines, modules) can share the same instance.
/// Each one issues requests with [`send()`](Self::send) and later claims only
/// the replies carrying its own `transaction_id` via
/// [`take_response()`](Self::take_response). Before `process_tick` runs, the
/// framework invokes [`poll()`](Self::poll) once per tick, so replies that
/// have already arrived are waiting in the buffer.
pub struct CommandClient {
    /// Outbound side: serialized JSON handed to the WebSocket write task.
    write_tx: mpsc::UnboundedSender<String>,
    /// Inbound side: response messages forwarded by the WebSocket read task.
    response_rx: mpsc::UnboundedReceiver<CommandMessage>,
    /// Requests still awaiting a reply, keyed by `transaction_id`.
    pending: HashMap<u32, PendingRequest>,
    /// Replies received but not yet claimed, keyed by `transaction_id`.
    responses: HashMap<u32, CommandMessage>,
}

impl CommandClient {
    /// Create a new `CommandClient` from channels created by `ControlRunner`.
    pub fn new(
        write_tx: mpsc::UnboundedSender<String>,
        response_rx: mpsc::UnboundedReceiver<CommandMessage>,
    ) -> Self {
        Self {
            write_tx,
            response_rx,
            pending: HashMap::new(),
            responses: HashMap::new(),
        }
    }

    /// Send a command request to an external module.
    ///
    /// Creates a [`CommandMessage::request`] with the given topic and data,
    /// serializes it, and pushes it into the WebSocket write channel.
    ///
    /// Returns the `transaction_id` which can be used to match the response.
    ///
    /// # Arguments
    ///
    /// * `topic` - Fully-qualified topic name (e.g., `"labelit.translate_check"`)
    /// * `data` - JSON payload for the request
    pub fn send(&mut self, topic: &str, data: Value) -> u32 {
        let msg = CommandMessage::request(topic, data);
        let transaction_id = msg.transaction_id;

        if let Ok(json) = serde_json::to_string(&msg) {
            let _ = self.write_tx.send(json);
        }

        self.pending.insert(transaction_id, PendingRequest {
            topic: topic.to_string(),
            sent_at: Instant::now(),
        });

        transaction_id
    }

    /// Drain all available responses from the WebSocket channel into the
    /// internal buffer.
    ///
    /// Call this **once per tick** at the top of `process_tick`, before any
    /// subsystem calls [`take_response()`](Self::take_response). This ensures
    /// every subsystem sees responses that arrived since the last cycle.
    pub fn poll(&mut self) {
        while let Ok(msg) = self.response_rx.try_recv() {
            let tid = msg.transaction_id;
            if self.pending.remove(&tid).is_some() {
                self.responses.insert(tid, msg);
            }
            // else: unsolicited response (e.g. gm.write from control loop), discard
        }
    }

    /// Take a response for a specific `transaction_id` from the buffer.
    ///
    /// Returns `Some(response)` if a response with that ID has been received,
    /// or `None` if it hasn't arrived yet. The response is removed from the
    /// buffer on retrieval.
    ///
    /// This is the recommended way for subsystems to retrieve their responses,
    /// since each subsystem only claims its own `transaction_id` and cannot
    /// accidentally consume another subsystem's response.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // In a subsystem's tick method:
    /// if let Some(response) = client.take_response(self.my_tid) {
    ///     if response.success {
    ///         // handle success
    ///     } else {
    ///         // handle error
    ///     }
    /// }
    /// ```
    pub fn take_response(&mut self, transaction_id: u32) -> Option<CommandMessage> {
        self.responses.remove(&transaction_id)
    }

    /// Check if a request is still awaiting a response.
    ///
    /// Returns `true` if the request has been sent but no response has arrived
    /// yet. Returns `false` if the response is already buffered (even if not
    /// yet claimed via [`take_response()`](Self::take_response)) or if the
    /// transaction ID is unknown.
    pub fn is_pending(&self, transaction_id: u32) -> bool {
        self.pending.contains_key(&transaction_id)
    }

    /// Number of outstanding requests (sent but no response received yet).
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Number of responses buffered and ready to be claimed.
    pub fn response_count(&self) -> usize {
        self.responses.len()
    }

    /// Remove and return transaction IDs that have been pending longer than `timeout`.
    pub fn drain_stale(&mut self, timeout: Duration) -> Vec<u32> {
        let now = Instant::now();
        let stale: Vec<u32> = self.pending.iter()
            .filter(|(_, req)| now.duration_since(req.sent_at) > timeout)
            .map(|(&tid, _)| tid)
            .collect();

        for tid in &stale {
            if let Some(req) = self.pending.remove(tid) {
                log::warn!("Command request {} ('{}') timed out after {:?}",
                    tid, req.topic, timeout);
            }
        }

        stale
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mechutil::ipc::MessageType;
    use serde_json::json;

    /// Build a client wired to fresh channels.
    ///
    /// Returns, in order: the client, the receiving end of its outbound
    /// (write) channel, and the sending end of its inbound (response) channel.
    /// Callers keep both channel ends bound for the whole test so the
    /// channels stay open.
    fn harness() -> (
        CommandClient,
        mpsc::UnboundedReceiver<String>,
        mpsc::UnboundedSender<CommandMessage>,
    ) {
        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
        (CommandClient::new(outbound_tx, inbound_rx), outbound_rx, inbound_tx)
    }

    #[test]
    fn test_send_pushes_to_channel() {
        let (mut client, mut outbound, _inbound) = harness();

        let tid = client.send("test.command", json!({"key": "value"}));

        // The serialized request must land on the write channel...
        let wire = outbound.try_recv().expect("should have a message");
        let sent: CommandMessage = serde_json::from_str(&wire).unwrap();
        assert_eq!(sent.transaction_id, tid);
        assert_eq!(sent.topic, "test.command");
        assert_eq!(sent.message_type, MessageType::Request);
        assert_eq!(sent.data, json!({"key": "value"}));

        // ...and the client must remember it as outstanding.
        assert!(client.is_pending(tid));
        assert_eq!(client.pending_count(), 1);
    }

    #[test]
    fn test_poll_and_take_response() {
        let (mut client, _outbound, inbound) = harness();

        let tid = client.send("test.command", json!(null));
        assert!(client.is_pending(tid));

        inbound.send(CommandMessage::response(tid, json!("ok"))).unwrap();

        // Nothing is visible until poll() moves the reply into the buffer.
        assert!(client.take_response(tid).is_none());

        client.poll();
        assert!(!client.is_pending(tid));
        assert_eq!(client.response_count(), 1);

        let reply = client.take_response(tid).unwrap();
        assert_eq!(reply.transaction_id, tid);
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_multi_consumer_isolation() {
        let (mut client, _outbound, inbound) = harness();

        // Two independent subsystems issue requests.
        let tid_a = client.send("labelit.inspect", json!(null));
        let tid_b = client.send("other.status", json!(null));

        // Replies come back in the opposite order.
        inbound.send(CommandMessage::response(tid_b, json!("b_result"))).unwrap();
        inbound.send(CommandMessage::response(tid_a, json!("a_result"))).unwrap();

        // One poll picks up both.
        client.poll();
        assert_eq!(client.response_count(), 2);

        // Each subsystem claims exactly its own reply.
        for (tid, expected) in [(tid_a, json!("a_result")), (tid_b, json!("b_result"))] {
            let reply = client.take_response(tid).unwrap();
            assert_eq!(reply.data, expected);
        }

        // A claimed reply cannot be taken twice.
        for tid in [tid_a, tid_b] {
            assert!(client.take_response(tid).is_none());
        }
        assert_eq!(client.response_count(), 0);
    }

    #[test]
    fn test_drain_stale() {
        let (mut client, _outbound, _inbound) = harness();

        let tid = client.send("test.command", json!(null));
        assert_eq!(client.pending_count(), 1);

        // A zero timeout should expire the request right away.
        let expired = client.drain_stale(Duration::from_secs(0));
        assert_eq!(expired, vec![tid]);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_drain_stale_keeps_fresh() {
        let (mut client, _outbound, _inbound) = harness();

        let tid = client.send("test.command", json!(null));

        // An hour-long timeout leaves a freshly sent request alone.
        let expired = client.drain_stale(Duration::from_secs(3600));
        assert!(expired.is_empty());
        assert!(client.is_pending(tid));
    }

    #[test]
    fn test_drain_stale_ignores_received() {
        let (mut client, _outbound, inbound) = harness();

        let tid = client.send("test.command", json!(null));

        // The reply is already buffered before the stale sweep runs.
        inbound.send(CommandMessage::response(tid, json!("ok"))).unwrap();
        client.poll();

        // Answered requests are no longer pending, so nothing expires...
        let expired = client.drain_stale(Duration::from_secs(0));
        assert!(expired.is_empty());

        // ...and the buffered reply survives the sweep.
        assert!(client.take_response(tid).is_some());
    }

    #[test]
    fn test_multiple_pending() {
        let (mut client, _outbound, inbound) = harness();

        let tid1 = client.send("cmd.first", json!(1));
        let tid2 = client.send("cmd.second", json!(2));
        let tid3 = client.send("cmd.third", json!(3));
        assert_eq!(client.pending_count(), 3);

        // Only the middle request gets an answer.
        inbound.send(CommandMessage::response(tid2, json!("ok"))).unwrap();
        client.poll();

        assert_eq!(client.pending_count(), 2);
        for (tid, still_pending) in [(tid1, true), (tid2, false), (tid3, true)] {
            assert_eq!(client.is_pending(tid), still_pending);
        }

        let reply = client.take_response(tid2).unwrap();
        assert_eq!(reply.transaction_id, tid2);
    }

    #[test]
    fn test_unsolicited_responses_discarded() {
        let (mut client, _outbound, inbound) = harness();

        // Replies for IDs that never went through send() (e.g. gm.write
        // pushed straight onto ws_write_tx) must not be kept.
        for (tid, payload) in [(99999, json!("stale1")), (99998, json!("stale2"))] {
            inbound.send(CommandMessage::response(tid, payload)).unwrap();
        }

        client.poll();

        // Nothing may accumulate in the response buffer.
        assert_eq!(client.response_count(), 0);
        for tid in [99999, 99998] {
            assert!(client.take_response(tid).is_none());
        }
    }

    #[test]
    fn test_mix_of_solicited_and_unsolicited() {
        let (mut client, _outbound, inbound) = harness();

        // A single genuine request.
        let tid = client.send("real.command", json!(null));

        // An unsolicited reply arrives ahead of the real one.
        inbound.send(CommandMessage::response(77777, json!("unsolicited"))).unwrap();
        inbound.send(CommandMessage::response(tid, json!("real_result"))).unwrap();

        client.poll();

        // Only the reply we asked for is buffered.
        assert_eq!(client.response_count(), 1);
        let reply = client.take_response(tid).unwrap();
        assert_eq!(reply.data, json!("real_result"));
        assert!(client.take_response(77777).is_none());
    }
}