#![doc(html_root_url = "https://docs.rs/crosstalk/1.0")]
#![doc = include_str!("../README.md")]
use std::sync::Arc;
use tokio::sync::Mutex;
use std::collections::HashMap;
use tokio::sync::broadcast::{
Sender as TokioSender,
Receiver as TokioReceiver,
};
pub use crosstalk_macros::init;
pub use crosstalk_macros::AsTopic;
pub mod __macro_exports {
pub use tokio::runtime;
pub use tokio::sync::broadcast;
#[inline(always)]
pub fn downcast<T>(buf: Box<dyn std::any::Any + 'static>, on_error: crate::Error) -> Result<T, crate::Error>
where
T: 'static,
{
match buf.downcast::<T>() {
Ok(t) => Ok(*t),
Err(_) => Err(on_error),
}
}
}
/// Marker trait for the key type that names topics on a node.
///
/// Topics are used as `HashMap` keys by `ImplementedBoundedNode::senders` and
/// passed around by value, hence `Copy + Eq + Hash` (which also imply `Clone`
/// and `PartialEq`).
pub trait CrosstalkTopic: Copy + Eq + std::hash::Hash {}
/// Bound satisfied by every payload type that can travel over a topic.
///
/// Implemented automatically for any `Send + Clone + 'static` type by the
/// blanket impl below; there is nothing to implement by hand.
pub trait CrosstalkData: Send + Clone + 'static {}

impl<T> CrosstalkData for T where T: Send + Clone + 'static {}
/// Error returned when a topic is accessed with a payload type other than
/// the one declared for it.
///
/// Both variants carry two type names: first the type the caller asked for
/// (printed as "cast"), then the type the topic expects (printed as
/// "expected").
///
/// `PartialEq`/`Eq` are derived so callers can compare errors directly
/// (e.g. `assert_eq!(err, Error::PublisherMismatch(..))`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Error {
    /// A publisher was requested with the wrong payload type: `(cast, expected)`.
    PublisherMismatch(&'static str, &'static str),
    /// A subscriber was requested with the wrong payload type: `(cast, expected)`.
    SubscriberMismatch(&'static str, &'static str),
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::PublisherMismatch(input, output) => write!(
                f,
                "Publisher type mismatch: {} (cast) != {} (expected)",
                input, output
            ),
            Error::SubscriberMismatch(input, output) => write!(
                f,
                "Subscriber type mismatch: {} (cast) != {} (expected)",
                input, output
            ),
        }
    }
}
pub trait CrosstalkPubSub<T> {
fn publisher<D: CrosstalkData>(&mut self, topic: T) -> Result<Publisher<D, T>, crate::Error>;
fn subscriber<D: CrosstalkData>(&mut self, topic: T) -> Result<Subscriber<D, T>, crate::Error>;
#[allow(clippy::type_complexity)]
fn pubsub<D: CrosstalkData>(&mut self, topic: T) -> Result<(Publisher<D, T>, Subscriber<D, T>), crate::Error>;
}
/// Cloneable handle to a registry of topic channels that all share one
/// capacity.
///
/// Every clone shares the same [`ImplementedBoundedNode`] through an
/// `Arc<Mutex<_>>`, so endpoints created from any clone come from the same
/// registry. Each accessor comes in an async form (awaits the lock) and a
/// `_blocking` form for synchronous code.
#[derive(Clone)]
pub struct BoundedNode<T> {
    /// Shared, lock-protected registry that hands out publishers/subscribers.
    pub node: Arc<Mutex<ImplementedBoundedNode<T>>>,
    /// Channel capacity this node was created with.
    pub size: usize,
}

impl<T> BoundedNode<T>
where
    T: CrosstalkTopic,
    ImplementedBoundedNode<T>: CrosstalkPubSub<T>,
{
    /// Creates a node whose topic channels use capacity `size`.
    ///
    /// # Panics
    /// Panics if `size` is 0. A zero-capacity tokio broadcast channel would
    /// panic later anyway, so this fails fast at construction instead of on
    /// the first endpoint request.
    #[inline(always)]
    pub fn new(size: usize) -> Self {
        if size == 0 {
            panic!("Size must be greater than 0. Attempting to make `tokio::sync::broadcast::channels` later will result in a panic.");
        }
        Self {
            node: Arc::new(Mutex::new(ImplementedBoundedNode::<T>::new(size))),
            size,
        }
    }

    /// Awaits the node lock and returns a publisher for `topic`.
    ///
    /// # Errors
    /// Returns the inner node's error, e.g. [`Error::PublisherMismatch`] when
    /// `D` is not the payload type declared for `topic`.
    #[inline(always)]
    pub async fn publisher<D: CrosstalkData>(&mut self, topic: T) -> Result<Publisher<D, T>, crate::Error> {
        self.node.lock().await.publisher(topic)
    }

    /// Blocking counterpart of [`Self::publisher`] for synchronous code.
    ///
    /// # Errors
    /// Same as [`Self::publisher`].
    ///
    /// # Panics
    /// `tokio::sync::Mutex::blocking_lock` panics when called from within an
    /// asynchronous execution context; use [`Self::publisher`] there.
    #[inline(always)]
    pub fn publisher_blocking<D: CrosstalkData>(&mut self, topic: T) -> Result<Publisher<D, T>, crate::Error> {
        self.node.blocking_lock().publisher(topic)
    }

    /// Awaits the node lock and returns a subscriber for `topic`.
    ///
    /// # Errors
    /// Returns the inner node's error, e.g. [`Error::SubscriberMismatch`] when
    /// `D` is not the payload type declared for `topic`.
    #[inline(always)]
    pub async fn subscriber<D: CrosstalkData>(&mut self, topic: T) -> Result<Subscriber<D, T>, crate::Error> {
        self.node.lock().await.subscriber(topic)
    }

    /// Blocking counterpart of [`Self::subscriber`] for synchronous code.
    ///
    /// # Errors
    /// Same as [`Self::subscriber`].
    ///
    /// # Panics
    /// `tokio::sync::Mutex::blocking_lock` panics when called from within an
    /// asynchronous execution context; use [`Self::subscriber`] there.
    #[inline(always)]
    pub fn subscriber_blocking<D: CrosstalkData>(&mut self, topic: T) -> Result<Subscriber<D, T>, crate::Error> {
        self.node.blocking_lock().subscriber(topic)
    }

    /// Awaits the node lock and returns a publisher and a subscriber for
    /// `topic`.
    ///
    /// # Errors
    /// Returns the inner node's error when `D` does not match the topic.
    #[inline(always)]
    // The tuple inside `Result` trips clippy's `type_complexity` lint, as on
    // the blocking variant.
    #[allow(clippy::type_complexity)]
    pub async fn pubsub<D: CrosstalkData>(&mut self, topic: T) -> Result<(Publisher<D, T>, Subscriber<D, T>), crate::Error> {
        self.node.lock().await.pubsub(topic)
    }

    /// Blocking counterpart of [`Self::pubsub`] for synchronous code.
    ///
    /// # Errors
    /// Same as [`Self::pubsub`].
    ///
    /// # Panics
    /// `tokio::sync::Mutex::blocking_lock` panics when called from within an
    /// asynchronous execution context; use [`Self::pubsub`] there.
    #[inline(always)]
    #[allow(clippy::type_complexity)]
    pub fn pubsub_blocking<D: CrosstalkData>(&mut self, topic: T) -> Result<(Publisher<D, T>, Subscriber<D, T>), crate::Error> {
        self.node.blocking_lock().pubsub(topic)
    }
}
/// Registry behind [`BoundedNode`]'s mutex: one type-erased channel sender
/// per topic.
pub struct ImplementedBoundedNode<T> {
    /// Type-erased per-topic senders, recovered with
    /// `__macro_exports::downcast`. Presumably these are always
    /// `tokio::sync::broadcast::Sender<D>` for the topic's payload `D` —
    /// NOTE(review): confirm against the `init!` macro.
    pub senders: HashMap<T, Box<dyn std::any::Any + 'static>>,
    /// Capacity for every channel this node creates.
    pub size: usize,
}

// SAFETY: `Box<dyn Any>` is neither `Send` nor `Sync` by itself, so these impls
// are asserted by hand. They rely on the invariant that every value stored in
// `senders` is a tokio broadcast `Sender<D>` with `D: CrosstalkData` (hence
// `D: Send`), which is `Send + Sync`.
// NOTE(review): `senders` is `pub`, so nothing in the type system enforces that
// invariant; inserting a non-`Send` value here would be undefined behaviour.
// The topic keys `T` are stored by value, so they must themselves be
// `Send`/`Sync`. The bounds below require that instead of claiming it for
// every `T` (e.g. raw-pointer topics).
unsafe impl<T: Send> Send for ImplementedBoundedNode<T> {}
unsafe impl<T: Sync> Sync for ImplementedBoundedNode<T> {}
impl<T> ImplementedBoundedNode<T>
where
    T: CrosstalkTopic,
{
    /// Creates a registry with no topics yet.
    ///
    /// `size` is only stored here; presumably it is used as the capacity of
    /// channels created later by the `CrosstalkPubSub` implementation.
    pub fn new(size: usize) -> Self {
        let senders = HashMap::new();
        Self { senders, size }
    }
}
/// Sending half of a topic, obtained from a node's `publisher`/`pubsub`
/// methods.
///
/// Cloning copies the underlying broadcast sender, so every clone writes to
/// the same channel.
#[derive(Clone)]
pub struct Publisher<D, T> {
    /// Topic this publisher writes to.
    pub topic: T,
    buf: TokioSender<D>,
}

impl<D, T> Publisher<D, T> {
    /// Wraps an existing broadcast sender for `topic`.
    #[inline(always)]
    pub fn new(topic: T, buf: TokioSender<D>) -> Self {
        Publisher { topic, buf }
    }

    /// Broadcasts `sample` to the topic's current receivers.
    ///
    /// Fire-and-forget: tokio's `send` fails when no receiver exists, and
    /// that failure (and the returned sample) is deliberately discarded.
    #[inline(always)]
    pub fn write(&self, sample: D) {
        drop(self.buf.send(sample));
    }
}
/// Receiving half of a topic, obtained from a node's `subscriber`/`pubsub`
/// methods.
///
/// Besides its receiver, a subscriber keeps a shared handle to the channel's
/// sender. That handle lets `clone` subscribe fresh receivers, and it keeps
/// the channel open while the subscriber is alive.
pub struct Subscriber<D, T> {
    /// Topic this subscriber reads from.
    pub topic: T,
    rcvr: Receiver<D>,
    sndr: Arc<TokioSender<D>>,
}

impl<D: Clone, T: Clone> Subscriber<D, T> {
    /// Builds a subscriber for `topic`.
    ///
    /// If `rcvr` is `Some`, that receiver is used as-is and keeps its read
    /// position. Otherwise a new receiver is subscribed from `sndr`; it only
    /// sees messages sent after this call.
    #[inline(always)]
    pub fn new(
        topic: T,
        rcvr: Option<TokioReceiver<D>>,
        sndr: Arc<TokioSender<D>>,
    ) -> Self {
        // `unwrap_or_else` subscribes only when no receiver was supplied.
        // `unwrap_or(sndr.subscribe())` would always create (and then drop)
        // an extra receiver.
        let rcvr = rcvr.unwrap_or_else(|| sndr.subscribe());
        Self {
            topic,
            rcvr: Receiver::new(rcvr),
            sndr,
        }
    }

    /// Awaits the next message.
    ///
    /// `Lagged` gaps are skipped. Returns `None` if the receive fails
    /// otherwise.
    #[inline(always)]
    pub async fn read(&mut self) -> Option<D> {
        self.rcvr.read().await
    }

    /// Returns the next queued message without waiting, or `None` if none is
    /// available. `Lagged` gaps are skipped.
    #[inline(always)]
    pub fn try_read(&mut self) -> Option<D> {
        self.rcvr.try_read()
    }

    /// Like [`Self::try_read`], but makes a single attempt: a `Lagged` error
    /// yields `None` instead of being retried.
    #[inline(always)]
    pub fn try_read_raw(&mut self) -> Option<D> {
        self.rcvr.try_read_raw()
    }

    /// Blocking counterpart of [`Self::read`] for synchronous code.
    ///
    /// # Panics
    /// tokio's `Receiver::blocking_recv` panics when called from within an
    /// asynchronous execution context.
    #[inline(always)]
    pub fn read_blocking(&mut self) -> Option<D> {
        self.rcvr.read_blocking()
    }

    /// Like [`Self::read`], but gives up after `timeout`.
    ///
    /// Returns `None` on timeout, when no tokio runtime is current, or when
    /// the receive fails.
    #[inline(always)]
    pub async fn read_timeout(&mut self, timeout: std::time::Duration) -> Option<D> {
        self.rcvr.read_timeout(timeout).await
    }
}
impl<D: Clone, T: Clone> Clone for Subscriber<D, T> {
#[inline(always)]
fn clone(&self) -> Self {
Self {
topic: self.topic.clone(),
rcvr: Receiver::new(self.sndr.subscribe()),
sndr: self.sndr.clone(),
}
}
}
struct Receiver<D> {
buf: TokioReceiver<D>,
}
impl<D: Clone> Receiver<D>{
#[inline(always)]
pub fn new(
buf: TokioReceiver<D>,
) -> Self {
Self { buf }
}
async fn read(&mut self) -> Option<D> {
loop {
match self.buf.recv().await {
Ok(res) => return Some(res),
Err(e) => match e {
tokio::sync::broadcast::error::RecvError::Lagged(_) => { continue; }
#[cfg(not(any(feature = "log", feature = "tracing")))]
_ => return None,
#[cfg(any(feature = "log", feature = "tracing"))]
_ => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
return None
}
}
}
}
}
fn try_read(&mut self) -> Option<D> {
loop {
match self.buf.try_recv() {
Ok(d) => return Some(d),
Err(e) => {
match e {
tokio::sync::broadcast::error::TryRecvError::Lagged(_) => { continue; },
#[cfg(not(any(feature = "log", feature = "tracing")))]
_ => return None,
#[cfg(any(feature = "log", feature = "tracing"))]
_ => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
return None
},
}
},
}
}
}
fn try_read_raw(&mut self) -> Option<D> {
match self.buf.try_recv() {
Ok(d) => Some(d),
#[cfg(not(any(feature = "log", feature = "tracing")))]
Err(_) => None,
#[cfg(any(feature = "log", feature = "tracing"))]
Err(e) => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
None
},
}
}
fn read_blocking(&mut self) -> Option<D> {
loop {
match self.buf.blocking_recv() {
Ok(res) => return Some(res),
Err(e) => match e {
tokio::sync::broadcast::error::RecvError::Lagged(_) => { continue; }
#[cfg(not(any(feature = "log", feature = "tracing")))]
_ => return None,
#[cfg(any(feature = "log", feature = "tracing"))]
_ => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
return None
},
}
}
}
}
async fn read_timeout(&mut self, timeout: std::time::Duration) -> Option<D> {
match tokio::runtime::Handle::try_current() {
Ok(_) => {
match tokio::time::timeout(timeout, self.buf.recv()).await {
Ok(res) => {
match res {
Ok(res) => Some(res),
#[cfg(not(any(feature = "log", feature = "tracing")))]
Err(_) => None,
#[cfg(any(feature = "log", feature = "tracing"))]
Err(e) => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
None
},
}
},
#[cfg(not(any(feature = "log", feature = "tracing")))]
Err(_) => None,
#[cfg(any(feature = "log", feature = "tracing"))]
Err(e) => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
None
},
}
},
#[cfg(not(any(feature = "log", feature = "tracing")))]
Err(_) => None,
#[cfg(any(feature = "log", feature = "tracing"))]
Err(e) => {
#[cfg(feature = "log")]
log::error!("{}", e);
#[cfg(feature = "tracing")]
tracing::error!("{}", e);
None
},
}
}
}
#[allow(unused_imports)]
use crosstalk_macros::init_test;
#[allow(unused_imports)]
use crosstalk_macros::AsTopicTest;
#[cfg(test)]
mod tests {
    //! Behavioural tests for `BoundedNode`, `Publisher` and `Subscriber`.
    //!
    //! The topic enums below get their payload types from the test-only
    //! `init_test!` macro. Each `Topic::Variant => Type` line pairs a topic
    //! with the type its publishers/subscribers must use (e.g. `TestTopic::A`
    //! carries `String`).
    use super::*;
    #[derive(AsTopicTest)]
    enum TestTopic {
        A,
        B,
        C,
    }
    super::init_test! {
        TestTopic::A => String,
        TestTopic::B => bool,
        TestTopic::C => i32,
    }
    #[derive(AsTopicTest)]
    enum AnotherTestTopic {
        Foo,
        Bar,
    }
    super::init_test! {
        AnotherTestTopic::Foo => Vec<String>,
        AnotherTestTopic::Bar => Vec<bool>,
    }
    #[test]
    fn test_single_pubsub_blocking() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::A).unwrap();
        publisher.write("test".to_string());
        assert_eq!(subscriber.try_read().unwrap(), "test");
    }
    #[test]
    fn test_multiple_subscribers_blocking() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut sub1) = node.pubsub_blocking(TestTopic::A).unwrap();
        let mut sub2 = node.subscriber_blocking::<String>(TestTopic::A).unwrap();
        publisher.write("hello".to_string());
        assert_eq!(sub1.try_read().unwrap(), "hello");
        assert_eq!(sub2.try_read().unwrap(), "hello");
    }
    #[test]
    fn test_cross_topic_isolation() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (pub_a, mut sub_a) = node.pubsub_blocking(TestTopic::A).unwrap();
        let (pub_b, mut sub_b) = node.pubsub_blocking(TestTopic::B).unwrap();
        pub_a.write("string".to_string());
        pub_b.write(true);
        assert_eq!(sub_a.try_read().unwrap(), "string");
        assert!(sub_b.try_read().unwrap());
        // Each topic has its own channel: nothing leaks across.
        assert!(sub_a.try_read().is_none());
        assert!(sub_b.try_read().is_none());
    }
    #[test]
    fn test_multiple_threads_blocking() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::A).unwrap();
        let handle = std::thread::spawn(move || {
            publisher.write("threaded".to_string());
        });
        handle.join().unwrap();
        assert_eq!(subscriber.try_read().unwrap(), "threaded");
    }
    #[tokio::test]
    async fn test_async_pubsub_single_runtime() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub(TestTopic::A).await.unwrap();
        publisher.write("async".to_string());
        assert_eq!(subscriber.read().await.unwrap(), "async");
    }
    #[test]
    fn test_high_volume_blocking() {
        let mut node = BoundedNode::<TestTopic>::new(100);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::C).unwrap();
        for i in 0..100 {
            publisher.write(i);
        }
        for i in 0..100 {
            assert_eq!(subscriber.try_read().unwrap(), i);
        }
        assert!(subscriber.try_read().is_none());
    }
    #[test]
    fn test_cloned_subscribers() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut sub1) = node.pubsub_blocking(TestTopic::A).unwrap();
        // The clone subscribes before the write, so it also sees the message.
        let mut sub2 = sub1.clone();
        publisher.write("clone".to_string());
        assert_eq!(sub1.try_read().unwrap(), "clone");
        assert_eq!(sub2.try_read().unwrap(), "clone");
    }
    #[test]
    fn test_buffer_overflow_handling() {
        // Capacity 2 with three writes: `msg1` is overwritten, and `try_read`
        // skips the resulting `Lagged` error to return the retained messages.
        let mut node = BoundedNode::<TestTopic>::new(2);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::A).unwrap();
        publisher.write("msg1".to_string());
        publisher.write("msg2".to_string());
        publisher.write("msg3".to_string());
        assert_eq!(subscriber.try_read().unwrap(), "msg2");
        assert_eq!(subscriber.try_read().unwrap(), "msg3");
        assert!(subscriber.try_read().is_none());
    }
    #[test]
    fn test_type_mismatch_errors() {
        // `TestTopic::A` carries `String`, so asking for `i32` must fail.
        let mut node = BoundedNode::<TestTopic>::new(10);
        let publisher_res = node.publisher_blocking::<i32>(TestTopic::A);
        assert!(matches!(publisher_res, Err(Error::PublisherMismatch(_, _))));
        let subscriber_res = node.subscriber_blocking::<i32>(TestTopic::A);
        assert!(matches!(subscriber_res, Err(Error::SubscriberMismatch(_, _))));
    }
    // A single multi-threaded runtime; the write happens on a spawned task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_multiple_async_runtimes() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub(TestTopic::A).await.unwrap();
        let handle = tokio::spawn(async move {
            publisher.write("async".to_string());
        });
        handle.await.unwrap();
        assert_eq!(subscriber.read().await.unwrap(), "async");
    }
    #[test]
    fn test_mixed_async_blocking() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::A).unwrap();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            publisher.write("mixed".to_string());
        });
        assert_eq!(subscriber.try_read().unwrap(), "mixed");
    }
    #[test]
    fn test_complex_data_types() {
        let mut node = BoundedNode::<AnotherTestTopic>::new(10);
        let (pub_foo, mut sub_foo) = node.pubsub_blocking(AnotherTestTopic::Foo).unwrap();
        let (pub_bar, mut sub_bar) = node.pubsub_blocking(AnotherTestTopic::Bar).unwrap();
        pub_foo.write(vec!["a".to_string(), "b".to_string()]);
        pub_bar.write(vec![true, false]);
        assert_eq!(sub_foo.try_read().unwrap(), vec!["a", "b"]);
        assert_eq!(sub_bar.try_read().unwrap(), vec![true, false]);
    }
    #[test]
    fn test_concurrent_publishers() {
        let mut node = BoundedNode::<TestTopic>::new(100);
        let (pub1, mut sub) = node.pubsub_blocking(TestTopic::C).unwrap();
        let pub2 = node.publisher_blocking::<i32>(TestTopic::C).unwrap();
        let handle1 = std::thread::spawn(move || {
            for i in 0..50 {
                pub1.write(i);
            }
        });
        let handle2 = std::thread::spawn(move || {
            for i in 50..100 {
                pub2.write(i);
            }
        });
        handle1.join().unwrap();
        handle2.join().unwrap();
        let mut received = Vec::new();
        while let Some(msg) = sub.try_read() {
            received.push(msg);
        }
        assert_eq!(received.len(), 100);
    }
    #[tokio::test]
    async fn test_async_subscriber_cloning() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut sub1) = node.pubsub(TestTopic::A).await.unwrap();
        let mut sub2 = sub1.clone();
        publisher.write("async_clone".to_string());
        assert_eq!(sub1.read().await.unwrap(), "async_clone");
        assert_eq!(sub2.read().await.unwrap(), "async_clone");
    }
    #[test]
    fn test_dropped_publisher_behavior() {
        // Nothing is ever written. This checks that dropping the publisher
        // neither panics nor yields data. The subscriber's own shared sender
        // handle keeps the channel open.
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub_blocking::<String>(TestTopic::A).unwrap();
        drop(publisher);
        assert!(subscriber.try_read().is_none());
    }
    #[tokio::test]
    async fn test_multiple_async_publishers() {
        // Two tasks write 2 * LOOP_COUNT messages into a channel of (at least)
        // BUFFER_SIZE slots; tokio may round the capacity up. Older messages
        // can be overwritten, and `read` skips any resulting lag, so at least
        // BUFFER_SIZE messages remain readable.
        const LOOP_COUNT: usize = 10;
        const BUFFER_SIZE: usize = 10;
        let mut node = BoundedNode::<TestTopic>::new(BUFFER_SIZE);
        let (publisher, mut subscriber) = node.pubsub(TestTopic::A).await.unwrap();
        let publisher_1 = publisher.clone();
        let publisher_2 = publisher.clone();
        let task1 = tokio::spawn({
            async move {
                for _ in 0..LOOP_COUNT {
                    publisher_1.write("task1".to_string());
                }
            }
        });
        let task2 = tokio::spawn({
            async move {
                for _ in 0..LOOP_COUNT {
                    publisher_2.write("task2".to_string());
                }
            }
        });
        let _ = tokio::join!(task1, task2);
        let mut task1_count = 0;
        let mut task2_count = 0;
        for _ in 0..BUFFER_SIZE {
            let msg = subscriber.read().await.unwrap();
            if msg == "task1" { task1_count += 1; }
            if msg == "task2" { task2_count += 1; }
        }
        assert_eq!(task1_count + task2_count, BUFFER_SIZE);
    }
    #[test]
    fn test_blocking_read_with_delay() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::A).unwrap();
        std::thread::spawn(move || {
            // The delay only needs to let `read_blocking` start waiting before
            // the write; 50 ms is enough and keeps the suite fast.
            std::thread::sleep(std::time::Duration::from_millis(50));
            publisher.write("delayed".to_string());
        });
        assert_eq!(subscriber.read_blocking().unwrap(), "delayed");
    }
    #[tokio::test]
    async fn test_read_timeout_behavior() {
        let mut node = BoundedNode::<TestTopic>::new(10);
        let (publisher, mut subscriber) = node.pubsub(TestTopic::A).await.unwrap();
        let timeout = std::time::Duration::from_millis(100);
        // Nothing written yet: the read must time out.
        assert!(subscriber.read_timeout(timeout).await.is_none());
        publisher.write("timeout_test".to_string());
        assert_eq!(subscriber.read_timeout(timeout).await.unwrap(), "timeout_test");
    }
    #[test]
    fn test_multiple_topics_concurrently() {
        let mut node = BoundedNode::<AnotherTestTopic>::new(10);
        let (pub_foo, mut sub_foo) = node.pubsub_blocking(AnotherTestTopic::Foo).unwrap();
        let (pub_bar, mut sub_bar) = node.pubsub_blocking(AnotherTestTopic::Bar).unwrap();
        let handle1 = std::thread::spawn(move || {
            pub_foo.write(vec!["thread".to_string()]);
        });
        let handle2 = std::thread::spawn(move || {
            pub_bar.write(vec![true]);
        });
        handle1.join().unwrap();
        handle2.join().unwrap();
        assert_eq!(sub_foo.try_read().unwrap(), vec!["thread"]);
        assert_eq!(sub_bar.try_read().unwrap(), vec![true]);
    }
    #[test]
    #[should_panic]
    fn test_zero_capacity_node() {
        let _ = BoundedNode::<TestTopic>::new(0);
    }
    // Despite the name, the channel is bounded (capacity 1000); five messages
    // fit without any being overwritten.
    #[tokio::test]
    async fn test_async_unbounded_messaging() {
        let mut node = BoundedNode::<TestTopic>::new(1000);
        let (publisher, mut subscriber) = node.pubsub(TestTopic::A).await.unwrap();
        let messages = vec!["msg1", "msg2", "msg3", "msg4", "msg5"];
        for msg in &messages {
            publisher.write(msg.to_string());
        }
        for expected in messages {
            assert_eq!(subscriber.read().await.unwrap(), expected);
        }
    }
    #[test]
    fn test_error_handling_lagged_messages() {
        // Capacity 2 with five writes: only the last two survive, and
        // `try_read` skips the lag to reach them.
        let mut node = BoundedNode::<TestTopic>::new(2);
        let (publisher, mut subscriber) = node.pubsub_blocking(TestTopic::A).unwrap();
        for i in 0..5 {
            publisher.write(format!("msg{}", i));
        }
        let mut received = Vec::new();
        while let Some(msg) = subscriber.try_read() {
            received.push(msg);
        }
        assert_eq!(received, vec!["msg3", "msg4"]);
    }
}