nodo 0.18.5

A realtime framework for robotics
Documentation
// Copyright 2025 David Weikersdorfer

use crate::{
    channels::{BackStage, FlushResult, OverflowPolicy, SharedBackStage, Tx, TxConnectable},
    prelude::RetentionPolicy,
};
use std::fmt;

/// Upper bound on how many receivers may be connected to one transmitter.
///
/// Receiver indices are recorded in error bitmasks that are 64 bits wide, so the limit is the
/// bit width of `u64` rather than a freely chosen number.
pub const MAX_RECEIVER_COUNT: usize = u64::BITS as usize;

/// The transmitting side of a double-buffered SP-MC (single-producer, multi-consumer) channel
///
/// Messages are collected in the outbox and handed to all connected receivers when the channel
/// is flushed. Each receiver gets its own copy: the first connected receiver takes the original
/// message by move and every further receiver gets a `clone`. Messages with large data blocks
/// should use memory sharing like `Arc` to avoid costly memory copies. The channel traits
/// require messages to be `Send + Sync`, which rules out `Rc`.
pub struct DoubleBufferTx<T> {
    /// Messages waiting to be published by the next flush.
    outbox: BackStage<T>,
    /// Back stages of the connected receivers, in the order they were connected. Index 0 is the
    /// receiver that gets messages by move. `has_max_connection_count` reports when
    /// `MAX_RECEIVER_COUNT` entries have been reached.
    connections: Vec<SharedBackStage<T>>,
}

impl<T> DoubleBufferTx<T> {
    /// Creates a TX channel whose outbox uses `OverflowPolicy::Reject(capacity)`.
    ///
    /// When the outbox refuses a message, `push` reports `TxSendError::QueueFull`.
    /// TODO rename to `new_fixed`
    pub fn new(capacity: usize) -> Self {
        let outbox = BackStage::new(OverflowPolicy::Reject(capacity), RetentionPolicy::Drop);
        Self {
            outbox,
            connections: Vec::new(),
        }
    }

    /// Creates a TX channel whose outbox grows on demand (`OverflowPolicy::Resize`). This lets
    /// every message be sent.
    ///
    /// WARNING: Unbounded growth can lead to data congestion and ever-growing queues. A fixed
    /// capacity, or a policy that forgets old messages, is usually the better choice.
    pub fn new_auto_size() -> Self {
        let outbox = BackStage::new(OverflowPolicy::Resize, RetentionPolicy::Drop);
        Self {
            outbox,
            connections: Vec::new(),
        }
    }

    /// Appends `value` to the outbox. It is handed to the receivers on the next flush.
    ///
    /// # Errors
    /// Returns `TxSendError::QueueFull` if the outbox rejects the message. The outbox's own error
    /// value is discarded.
    pub fn push(&mut self, value: T) -> Result<(), TxSendError> {
        if self.outbox.push(value).is_err() {
            return Err(TxSendError::QueueFull);
        }
        Ok(())
    }

    /// Appends every item of `values` to the outbox, in iteration order.
    ///
    /// # Errors
    /// Stops at the first rejected message and returns its error. Messages pushed before that
    /// point stay in the outbox; the remaining items are not pushed.
    pub fn push_many<I: IntoIterator<Item = T>>(&mut self, values: I) -> Result<(), TxSendError> {
        values.into_iter().try_for_each(|value| self.push(value))
    }
}

impl<V: Send + Sync + Clone> TxConnectable for DoubleBufferTx<V> {
    type Message = V;

    /// Returns `true` once `MAX_RECEIVER_COUNT` receivers are connected; no further receiver
    /// should be added after that.
    fn has_max_connection_count(&self) -> bool {
        let connected = self.connections.len();
        connected >= MAX_RECEIVER_COUNT
    }

    /// Returns a copy of the overflow policy the outbox was created with.
    fn overflow_policy(&self) -> OverflowPolicy {
        let policy = self.outbox.overflow_policy();
        *policy
    }

    /// Registers a receiver's back stage as a destination for future flushes. It is appended
    /// after all previously connected receivers.
    fn on_connect(&mut self, stage: SharedBackStage<Self::Message>) {
        let receivers = &mut self.connections;
        receivers.push(stage);
    }
}

#[derive(Debug, thiserror::Error)]
pub enum TxConnectError {
    #[error("RX cannot be connected to more than one transmitter")]
    ReceiverAlreadyConnected,

    #[error("TX exceeded maximum connection count")]
    MaxConnectionCountExceeded,

    #[error(
        "Cannot connect a TX with policy `Resize` to an RX with policy `Reject`.
             Either change the TX policy to `Reject` or the RX policy to `Resize` or `Forget`."
    )]
    PolicyMismatch,
}

impl<T: Send + Sync + Clone> Tx for DoubleBufferTx<T> {
    fn flush(&mut self) -> FlushResult {
        let mut result = FlushResult::default();
        result.available = self.outbox.len();

        // clone messages for connections 2..N
        for (i, rx) in self.connections.iter().enumerate().skip(1) {
            let mut q = rx.write().unwrap();
            for v in self.outbox.iter() {
                if matches!(q.push((*v).clone()), Err(_)) {
                    result.error_indicator.mark(i);
                    break;
                }
                result.cloned += 1;
                result.published += 1;
            }
        }

        // move messages for connection 1
        if let Some(first_rx) = self.connections.get(0) {
            let mut q = first_rx.write().unwrap();
            for v in self.outbox.drain_all() {
                if matches!(q.push(v), Err(_)) {
                    result.error_indicator.mark(0);
                    break;
                }
                result.published += 1;
            }
        } else {
            // still clear outbox if there is no connection
            self.outbox.clear();
        }

        result
    }

    fn is_connected(&self) -> bool {
        !self.connections.is_empty()
    }

    fn len(&self) -> usize {
        self.outbox.len()
    }
}

/// Error returned by `DoubleBufferTx::push` and `push_many` when a message cannot be placed in
/// the outbox.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TxSendError {
    /// The outbox rejected the message.
    QueueFull,
}

impl fmt::Display for TxSendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TxSendError::QueueFull => f.write_str("QueueFull"),
        }
    }
}

impl std::error::Error for TxSendError {}

#[cfg(test)]
mod tests {
    use crate::{
        channels::{FlushResult, SyncResult},
        prelude::*,
    };
    use std::sync::mpsc;

    /// Creates a connected TX/RX pair. Both sides use `OverflowPolicy::Reject` with capacity
    /// `size`, and the RX uses `RetentionPolicy::EnforceEmpty`.
    fn fixed_channel<T: Clone + Send + Sync>(
        size: usize,
    ) -> (DoubleBufferTx<T>, DoubleBufferRx<T>) {
        let mut tx = DoubleBufferTx::new(size);
        let mut rx =
            DoubleBufferRx::new(OverflowPolicy::Reject(size), RetentionPolicy::EnforceEmpty);
        connect(&mut tx, &mut rx).unwrap();
        (tx, rx)
    }

    /// Streams `NUM_ROUNDS` rounds of `NUM_MESSAGES` messages from a sender thread to a
    /// receiver thread and checks that every message arrives with the expected content and
    /// order.
    ///
    /// Each round follows a handshake:
    /// 1. The sender pushes and flushes the round's messages.
    /// 2. The sender signals `sync_tx`.
    /// 3. The receiver calls `sync()` and acknowledges on `rep_tx`.
    /// 4. The sender moves on to the next round while the receiver is still popping the
    ///    current round's messages. This exercises both buffers at the same time.
    #[test]
    fn test() {
        const NUM_MESSAGES: usize = 100;
        const NUM_ROUNDS: usize = 100;

        let (mut tx, mut rx) = fixed_channel(NUM_MESSAGES);

        // channel used for synchronizing tx and rx threads
        // (sync: sender -> receiver "flushed"; rep: receiver -> sender "synced")
        let (sync_tx, sync_rx) = mpsc::sync_channel(1);
        let (rep_tx, rep_rx) = mpsc::sync_channel(1);

        // receiver
        let t1 = std::thread::spawn(move || {
            for k in 0..NUM_ROUNDS {
                // wait for signal to sync
                sync_rx.recv().unwrap();

                // exactly one round of messages is expected to be picked up per sync
                assert_eq!(
                    rx.sync(),
                    SyncResult {
                        received: NUM_MESSAGES,
                        ..Default::default()
                    }
                );

                // let the sender start the next round while we drain this one
                rep_tx.send(()).unwrap();

                // receive messages
                // All of them are popped before the next sync. Presumably this is what
                // `RetentionPolicy::EnforceEmpty` requires; confirm against DoubleBufferRx.
                for i in 0..NUM_MESSAGES {
                    assert_eq!(rx.pop().unwrap(), format!("hello {k} {i}"));
                }
            }
        });

        // sender
        let t2 = std::thread::spawn(move || {
            for k in 0..NUM_ROUNDS {
                // send messages
                for i in 0..NUM_MESSAGES {
                    tx.push(format!("hello {k} {i}")).unwrap();
                }
                // single receiver: every message is moved, so `cloned` stays at its default
                assert_eq!(
                    tx.flush(),
                    FlushResult {
                        available: NUM_MESSAGES,
                        published: NUM_MESSAGES,
                        ..Default::default()
                    }
                );

                // send sync signal
                sync_tx.send(()).unwrap();
                rep_rx.recv().unwrap();
            }
        });

        t1.join().unwrap();
        t2.join().unwrap();
    }
}