use std::sync::Mutex;
use std::time::SystemTime;
use tokio::sync::broadcast;
use crate::error::{AsynError, AsynResult};
use crate::param::ParamValue;
/// Criteria used to select which interrupts a subscriber wants to see.
///
/// A field left as `None` places no restriction on that attribute, so the
/// `Default` value (both fields `None`) accepts every interrupt.
#[derive(Clone, Debug, Default)]
pub struct InterruptFilter {
    /// When `Some`, only interrupts with exactly this `reason` are accepted.
    pub reason: Option<usize>,
    /// When `Some`, only interrupts with exactly this `addr` are accepted.
    pub addr: Option<i32>,
}
/// Handle representing a live interrupt registration.
///
/// The handle owns the sending half of a one-shot cancellation channel; the
/// `Drop` implementation fires it, so letting the handle go out of scope is
/// how a registration is ended.
pub struct InterruptSubscription {
    /// Cancellation signal. Held in an `Option` so it can be moved out of
    /// `&mut self` and consumed exactly once.
    cancel_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
impl InterruptSubscription {
    /// Builds a subscription handle that owns `cancel_tx`, the sender used to
    /// signal cancellation.
    fn new(cancel_tx: tokio::sync::oneshot::Sender<()>) -> Self {
        let cancel_tx = Some(cancel_tx);
        Self { cancel_tx }
    }
}
impl Drop for InterruptSubscription {
    /// Sends the cancellation signal, if it has not been consumed already.
    fn drop(&mut self) {
        let pending = self.cancel_tx.take();
        if let Some(signal) = pending {
            // A send error only means the receiving side is already gone, in
            // which case there is nothing left to cancel.
            let _ = signal.send(());
        }
    }
}
/// A single interrupt event as delivered to subscribers.
#[derive(Clone, Debug)]
pub struct InterruptValue {
    /// Identifier compared against `InterruptFilter::reason`.
    // NOTE(review): presumably the parameter index — confirm against callers.
    pub reason: usize,
    /// Address compared against `InterruptFilter::addr`.
    pub addr: i32,
    /// Payload carried by the interrupt.
    pub value: ParamValue,
    /// Time associated with the event, supplied by whoever builds the value.
    pub timestamp: SystemTime,
}
/// Fans interrupt values out to asynchronous and synchronous consumers.
///
/// Asynchronous consumers each get their own receiver from a tokio broadcast
/// channel. Synchronous consumers share a single `std::sync::mpsc` receiver
/// that can be handed out only once.
pub struct InterruptManager {
    /// Broadcast sender feeding every asynchronous subscriber.
    async_tx: broadcast::Sender<InterruptValue>,
    /// Sender feeding the single synchronous receiver.
    sync_tx: std::sync::mpsc::Sender<InterruptValue>,
    /// The synchronous receiver, parked here until `subscribe_sync` takes it;
    /// `None` afterwards.
    sync_rx: Mutex<Option<std::sync::mpsc::Receiver<InterruptValue>>>,
}
impl InterruptManager {
    /// Creates a manager with a new broadcast channel of capacity
    /// `async_capacity` and a new synchronous channel.
    ///
    /// # Panics
    ///
    /// Per the tokio documentation, `broadcast::channel` panics when the
    /// capacity is `0` (or larger than `usize::MAX / 2`).
    pub fn new(async_capacity: usize) -> Self {
        let (async_tx, _) = broadcast::channel(async_capacity);
        let (sync_tx, sync_rx) = std::sync::mpsc::channel();
        Self {
            async_tx,
            sync_tx,
            sync_rx: Mutex::new(Some(sync_rx)),
        }
    }

    /// Creates a manager that publishes asynchronous interrupts through an
    /// existing broadcast `sender`, paired with a new synchronous channel.
    pub fn from_broadcast_sender(sender: broadcast::Sender<InterruptValue>) -> Self {
        let (sync_tx, sync_rx) = std::sync::mpsc::channel();
        Self {
            async_tx: sender,
            sync_tx,
            sync_rx: Mutex::new(Some(sync_rx)),
        }
    }

    /// Returns a new broadcast receiver for asynchronous consumers.
    pub fn subscribe_async(&self) -> broadcast::Receiver<InterruptValue> {
        self.async_tx.subscribe()
    }

    /// Hands out the synchronous receiver.
    ///
    /// # Errors
    ///
    /// Returns [`AsynError::AlreadySubscribed`] on every call after the first,
    /// because the single receiver has already been taken.
    pub fn subscribe_sync(&self) -> AsynResult<std::sync::mpsc::Receiver<InterruptValue>> {
        // The mutex only guards an `Option::take`, which cannot leave the slot
        // in an inconsistent state, so a poisoned lock is safe to recover
        // from instead of propagating the panic to this caller.
        self.sync_rx
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take()
            .ok_or(AsynError::AlreadySubscribed)
    }

    /// Returns a clone of the broadcast sender used for asynchronous delivery.
    pub fn broadcast_sender(&self) -> broadcast::Sender<InterruptValue> {
        self.async_tx.clone()
    }

    /// Publishes `value` to the asynchronous and the synchronous channel.
    ///
    /// Delivery is best effort: send errors on either channel (no broadcast
    /// receivers, or the synchronous receiver was dropped) are ignored.
    pub fn notify(&self, value: InterruptValue) {
        let _ = self.async_tx.send(value.clone());
        // NOTE(review): the synchronous channel is unbounded and its receiver
        // stays alive inside `self` until `subscribe_sync` is called, so values
        // accumulate for as long as nobody subscribes. Confirm whether
        // buffering pre-subscription interrupts is intended.
        let _ = self.sync_tx.send(value);
    }

    /// Registers a filtered interrupt consumer.
    ///
    /// Spawns a tokio task that forwards every interrupt accepted by `filter`
    /// into a bounded channel (64 entries) and returns the receiving half
    /// together with an [`InterruptSubscription`]. The task ends when the
    /// subscription is dropped, when the returned receiver is dropped or
    /// closed, or when the broadcast channel closes.
    ///
    /// Interrupts skipped because the task lagged behind the broadcast
    /// channel are silently lost.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside a tokio runtime.
    pub fn register_interrupt_user(
        &self,
        filter: InterruptFilter,
    ) -> (InterruptSubscription, tokio::sync::mpsc::Receiver<InterruptValue>) {
        // Subscribe before spawning so interrupts published right after this
        // call returns are buffered for the task instead of being missed.
        let mut intr_rx = self.async_tx.subscribe();
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let (cancel_tx, mut cancel_rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            loop {
                let iv = tokio::select! {
                    // Poll in declaration order so that a pending cancellation
                    // always wins over an interrupt that is ready at the same
                    // time.
                    biased;
                    _ = &mut cancel_rx => break,
                    // Stop promptly once the consumer has dropped its receiver
                    // rather than waiting for the next matching interrupt.
                    _ = tx.closed() => break,
                    recv = intr_rx.recv() => match recv {
                        Ok(iv) => iv,
                        // Missed values cannot be recovered; keep going with
                        // the oldest value still retained.
                        Err(broadcast::error::RecvError::Lagged(_)) => continue,
                        Err(broadcast::error::RecvError::Closed) => break,
                    },
                };

                let rejected = matches!(filter.reason, Some(r) if r != iv.reason)
                    || matches!(filter.addr, Some(a) if a != iv.addr);
                if rejected {
                    continue;
                }

                // Race the send against cancellation: with a full channel and
                // a consumer that is not draining, a bare `send().await` would
                // keep this task alive after the subscription was dropped.
                let delivered = tokio::select! {
                    biased;
                    _ = &mut cancel_rx => break,
                    sent = tx.send(iv) => sent.is_ok(),
                };
                if !delivered {
                    break;
                }
            }
        });
        (InterruptSubscription::new(cancel_tx), rx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an interrupt stamped with the current time.
    fn interrupt(reason: usize, addr: i32, value: ParamValue) -> InterruptValue {
        InterruptValue {
            reason,
            addr,
            value,
            timestamp: SystemTime::now(),
        }
    }

    /// Waits up to 100 ms for the next forwarded interrupt and fails the test
    /// if none arrives or the channel is closed.
    async fn recv_within(rx: &mut tokio::sync::mpsc::Receiver<InterruptValue>) -> InterruptValue {
        tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timed out waiting for interrupt")
            .expect("interrupt channel closed")
    }

    /// Asserts that `value` is `ParamValue::Int32(expected)`.
    fn assert_int32(value: ParamValue, expected: i32) {
        match value {
            ParamValue::Int32(n) => assert_eq!(n, expected),
            _ => panic!("expected Int32"),
        }
    }

    #[test]
    fn test_sync_subscribe_once() {
        let im = InterruptManager::new(16);
        let _rx = im.subscribe_sync().unwrap();
        assert!(im.subscribe_sync().is_err());
    }

    #[test]
    fn test_sync_notify_receive() {
        let im = InterruptManager::new(16);
        let rx = im.subscribe_sync().unwrap();
        im.notify(interrupt(0, 0, ParamValue::Int32(42)));
        let v = rx.try_recv().unwrap();
        assert_eq!(v.reason, 0);
        assert_int32(v.value, 42);
    }

    #[test]
    fn test_notify_after_sync_drop() {
        // Notifying after the synchronous receiver is gone must not panic.
        let im = InterruptManager::new(16);
        let rx = im.subscribe_sync().unwrap();
        drop(rx);
        im.notify(interrupt(0, 0, ParamValue::Int32(1)));
    }

    #[tokio::test]
    async fn test_async_subscribe_receive() {
        let im = InterruptManager::new(16);
        let mut rx = im.subscribe_async();
        // Any float works here; avoid literals close to PI, which clippy's
        // `approx_constant` lint rejects.
        im.notify(interrupt(1, 0, ParamValue::Float64(2.5)));
        let v = rx.recv().await.unwrap();
        assert_eq!(v.reason, 1);
    }

    #[tokio::test]
    async fn test_async_multiple_subscribers() {
        let im = InterruptManager::new(16);
        let mut rx1 = im.subscribe_async();
        let mut rx2 = im.subscribe_async();
        im.notify(interrupt(0, 0, ParamValue::Int32(99)));
        let v1 = rx1.recv().await.unwrap();
        let v2 = rx2.recv().await.unwrap();
        assert_eq!(v1.reason, 0);
        assert_eq!(v2.reason, 0);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_filter_by_reason() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            reason: Some(1),
            addr: None,
        });
        im.notify(interrupt(0, 0, ParamValue::Int32(10)));
        im.notify(interrupt(1, 0, ParamValue::Int32(20)));
        let v = recv_within(&mut rx).await;
        assert_eq!(v.reason, 1);
        assert_int32(v.value, 20);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_filter_by_addr() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            reason: None,
            addr: Some(3),
        });
        im.notify(interrupt(0, 0, ParamValue::Int32(1)));
        im.notify(interrupt(0, 3, ParamValue::Int32(2)));
        let v = recv_within(&mut rx).await;
        assert_eq!(v.addr, 3);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_filter_by_reason_and_addr() {
        // Both criteria must hold for an interrupt to be forwarded.
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            reason: Some(1),
            addr: Some(3),
        });
        im.notify(interrupt(1, 0, ParamValue::Int32(1)));
        im.notify(interrupt(0, 3, ParamValue::Int32(2)));
        im.notify(interrupt(1, 3, ParamValue::Int32(3)));
        let v = recv_within(&mut rx).await;
        assert_eq!(v.reason, 1);
        assert_eq!(v.addr, 3);
        assert_int32(v.value, 3);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_no_filter() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter::default());
        im.notify(interrupt(5, 2, ParamValue::Float64(1.5)));
        let v = recv_within(&mut rx).await;
        assert_eq!(v.reason, 5);
        assert_eq!(v.addr, 2);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_drop_unsubscribes() {
        let im = InterruptManager::new(16);
        let (sub, mut rx) = im.register_interrupt_user(InterruptFilter::default());
        drop(sub);
        // Give the forwarding task a chance to observe the cancellation.
        tokio::task::yield_now().await;
        im.notify(interrupt(0, 0, ParamValue::Int32(999)));
        let result = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        // Either the channel closed (`Ok(None)`) or nothing arrived in time
        // (`Err(_)`); only an actual value is a failure.
        assert!(
            !matches!(result, Ok(Some(_))),
            "should not receive after unsubscribe"
        );
    }

    #[tokio::test]
    async fn test_register_interrupt_user_multiple_subscribers() {
        let im = InterruptManager::new(16);
        let (_sub1, mut rx1) = im.register_interrupt_user(InterruptFilter {
            reason: Some(0),
            addr: None,
        });
        let (_sub2, mut rx2) = im.register_interrupt_user(InterruptFilter {
            reason: Some(1),
            addr: None,
        });
        im.notify(interrupt(0, 0, ParamValue::Int32(10)));
        im.notify(interrupt(1, 0, ParamValue::Int32(20)));
        let v1 = recv_within(&mut rx1).await;
        assert_eq!(v1.reason, 0);
        let v2 = recv_within(&mut rx2).await;
        assert_eq!(v2.reason, 1);
    }
}