use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::SystemTime;
use tokio::sync::broadcast;
use crate::param::ParamValue;
/// Selection criteria for [`InterruptManager::register_interrupt_user`].
///
/// Every field that is `Some` must match for an interrupt to be delivered; `None`
/// accepts any value, so the [`Default`] filter (all `None`) matches every interrupt.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct InterruptFilter {
    /// Deliver only interrupts whose `reason` equals this value.
    pub reason: Option<usize>,
    /// Deliver only interrupts whose `addr` equals this value.
    pub addr: Option<i32>,
    /// Deliver only interrupts whose `uint32_changed_mask` has at least one bit in
    /// common with this mask.
    pub uint32_mask: Option<u32>,
}
/// One interrupt published through [`InterruptManager::notify`]: a [`ParamValue`]
/// plus the routing fields that subscription filters are matched against.
#[derive(Debug, Clone)]
pub struct InterruptValue {
    /// Reason index; compared for equality against `InterruptFilter::reason`.
    pub reason: usize,
    /// Address; compared for equality against `InterruptFilter::addr`.
    pub addr: i32,
    /// The payload delivered to subscribers.
    pub value: ParamValue,
    /// Timestamp supplied by the publisher; this module never sets or reads it.
    pub timestamp: SystemTime,
    /// Bitmask of changed bits. A filter with `uint32_mask` set only receives this value
    /// when the two masks share at least one set bit.
    /// NOTE(review): the tests pass 0 for non-bitmask values — confirm that convention
    /// with callers.
    pub uint32_changed_mask: u32,
}
/// Per-subscriber state shared (via `Arc`) by the manager's registry, the
/// `InterruptSubscription` guard and the `InterruptReceiver`.
struct SubscriptionMailbox {
    /// Criteria checked by `matches` before a value is stored in `latest`.
    filter: InterruptFilter,
    /// Single-slot mailbox: `notify` overwrites any unread value (coalescing) and
    /// `recv` takes it out.
    latest: parking_lot::Mutex<Option<InterruptValue>>,
    /// Signalled with `notify_one` after a value is stored or the subscription is dropped.
    wakeup: tokio::sync::Notify,
    /// Cleared when the subscription guard is dropped. `notify` skips inactive mailboxes,
    /// and `recv` returns `None` once the slot is empty and this is false.
    active: AtomicBool,
}
impl SubscriptionMailbox {
fn matches(&self, iv: &InterruptValue) -> bool {
if let Some(r) = self.filter.reason {
if iv.reason != r {
return false;
}
}
if let Some(a) = self.filter.addr {
if iv.addr != a {
return false;
}
}
if let Some(m) = self.filter.uint32_mask {
if iv.uint32_changed_mask & m == 0 {
return false;
}
}
true
}
}
/// Receiving half of a filtered subscription created by
/// [`InterruptManager::register_interrupt_user`].
///
/// Only the most recent matching value is kept, so values published faster than
/// `recv` is called are coalesced: earlier unread ones are overwritten.
pub struct InterruptReceiver {
    mailbox: Arc<SubscriptionMailbox>,
}
impl InterruptReceiver {
    /// Waits for the next value delivered to this subscription.
    ///
    /// Returns `Some(value)` with the newest pending value; older unread values were
    /// overwritten by `notify`. Returns `None` once the paired `InterruptSubscription`
    /// has been dropped and no value is pending. A value stored before the drop is
    /// still returned first.
    pub async fn recv(&mut self) -> Option<InterruptValue> {
        loop {
            // Create the `Notified` future before inspecting the slot and the flag, so a
            // `notify_one` issued after those checks still completes this await (tokio
            // keeps a permit when no task is waiting yet).
            let notified = self.mailbox.wakeup.notified();
            if let Some(value) = self.mailbox.latest.lock().take() {
                return Some(value);
            }
            // Acquire pairs with the Release store in `InterruptSubscription::drop`.
            if !self.mailbox.active.load(Ordering::Acquire) {
                return None;
            }
            // A leftover permit can end this await with an empty slot; the loop
            // simply re-checks.
            notified.await;
        }
    }
}
/// Registration guard returned by [`InterruptManager::register_interrupt_user`].
///
/// The subscription stays active while this guard is alive. Dropping it deactivates
/// the mailbox, wakes the receiver and removes the mailbox from the shared registry.
pub struct InterruptSubscription {
    mailbox: Arc<SubscriptionMailbox>,
    state: Arc<InterruptSharedState>,
}
impl Drop for InterruptSubscription {
    /// Deactivates this subscription and unregisters its mailbox.
    fn drop(&mut self) {
        let mailbox = &self.mailbox;
        // Mark inactive first: `notify` skips inactive mailboxes, and `recv` returns
        // `None` once it observes the flag (Release pairs with its Acquire load).
        mailbox.active.store(false, Ordering::Release);
        // Wake a receiver parked in `recv` so it re-checks the flag.
        mailbox.wakeup.notify_one();
        // Prune every inactive mailbox from the shared registry, this one included.
        let mut registry = self.state.mailboxes.lock();
        registry.retain(|entry| entry.active.load(Ordering::Relaxed));
    }
}
/// State shared by every [`InterruptManager`] built from the same
/// [`InterruptManager::shared_state`] handle, and by live subscription guards.
pub struct InterruptSharedState {
    /// Sender feeding the unfiltered receivers returned by `subscribe_async`.
    async_tx: broadcast::Sender<InterruptValue>,
    /// Registry of filtered subscriptions. Entries are pruned when their
    /// `InterruptSubscription` is dropped.
    mailboxes: parking_lot::Mutex<Vec<Arc<SubscriptionMailbox>>>,
    /// Number of `notify` calls made against this state.
    notify_count: AtomicU64,
    /// Number of times a pending mailbox value was overwritten before being read.
    coalesce_count: AtomicU64,
}
/// Fans out interrupt values to two kinds of listener:
/// - filtered, coalescing subscribers created with `register_interrupt_user`;
/// - unfiltered broadcast receivers created with `subscribe_async`.
pub struct InterruptManager {
    state: Arc<InterruptSharedState>,
}
impl InterruptManager {
    /// Creates a manager with a fresh broadcast channel of capacity `async_capacity`
    /// and no filtered subscriptions.
    ///
    /// # Panics
    ///
    /// Panics if `async_capacity` is 0 or greater than `usize::MAX / 2`, as documented
    /// for [`broadcast::channel`].
    pub fn new(async_capacity: usize) -> Self {
        // The initial receiver is discarded: async listeners attach later through
        // `subscribe_async`, and `notify` ignores send errors when none exist.
        let (async_tx, _) = broadcast::channel(async_capacity);
        Self::from_broadcast_sender(async_tx)
    }

    /// Creates a manager that operates on an existing shared state. The new manager
    /// shares subscriptions, broadcast channel and counters with every other manager
    /// using that state.
    pub fn from_shared_state(state: Arc<InterruptSharedState>) -> Self {
        Self { state }
    }

    /// Returns a handle to this manager's shared state, for use with
    /// [`from_shared_state`](Self::from_shared_state).
    pub fn shared_state(&self) -> Arc<InterruptSharedState> {
        Arc::clone(&self.state)
    }

    /// Creates a manager that publishes on an existing broadcast `sender`.
    ///
    /// Only the broadcast channel is shared. Filtered subscriptions and counters start
    /// empty and are private to the new manager.
    pub fn from_broadcast_sender(sender: broadcast::Sender<InterruptValue>) -> Self {
        Self {
            state: Arc::new(InterruptSharedState {
                async_tx: sender,
                mailboxes: parking_lot::Mutex::new(Vec::new()),
                notify_count: AtomicU64::new(0),
                coalesce_count: AtomicU64::new(0),
            }),
        }
    }

    /// Returns a new broadcast receiver for values passed to [`notify`](Self::notify)
    /// after this call. Values are neither filtered nor coalesced, subject to the
    /// channel's capacity.
    pub fn subscribe_async(&self) -> broadcast::Receiver<InterruptValue> {
        self.state.async_tx.subscribe()
    }

    /// Returns a clone of the underlying broadcast sender.
    pub fn broadcast_sender(&self) -> broadcast::Sender<InterruptValue> {
        self.state.async_tx.clone()
    }

    /// Publishes `value` to all matching filtered subscriptions, then to the broadcast
    /// channel.
    ///
    /// Each active subscription whose filter matches receives a clone of `value` in its
    /// single-slot mailbox and is woken. If the previous value was still unread it is
    /// overwritten and `coalesce_count` is incremented. A broadcast send error (no
    /// receivers) is deliberately ignored.
    pub fn notify(&self, value: InterruptValue) {
        self.state.notify_count.fetch_add(1, Ordering::Relaxed);
        {
            let subs = self.state.mailboxes.lock();
            let matching = subs
                .iter()
                .filter(|sub| sub.active.load(Ordering::Relaxed) && sub.matches(&value));
            for sub in matching {
                // Keep only the newest value per subscriber. An occupied slot means the
                // receiver has not caught up, so the older value is discarded.
                if sub.latest.lock().replace(value.clone()).is_some() {
                    self.state.coalesce_count.fetch_add(1, Ordering::Relaxed);
                }
                sub.wakeup.notify_one();
            }
        }
        let _ = self.state.async_tx.send(value);
    }

    /// Registers a filtered, coalescing subscription.
    ///
    /// Returns the registration guard and the receiver. Dropping the
    /// `InterruptSubscription` deactivates the mailbox and removes it from the
    /// registry. The receiver then yields any value still pending, followed by `None`.
    #[must_use = "dropping the InterruptSubscription immediately unsubscribes"]
    pub fn register_interrupt_user(
        &self,
        filter: InterruptFilter,
    ) -> (InterruptSubscription, InterruptReceiver) {
        let mailbox = Arc::new(SubscriptionMailbox {
            filter,
            latest: parking_lot::Mutex::new(None),
            wakeup: tokio::sync::Notify::new(),
            active: AtomicBool::new(true),
        });
        self.state.mailboxes.lock().push(Arc::clone(&mailbox));
        (
            InterruptSubscription {
                mailbox: Arc::clone(&mailbox),
                state: Arc::clone(&self.state),
            },
            InterruptReceiver { mailbox },
        )
    }

    /// Number of [`notify`](Self::notify) calls made against this shared state.
    pub fn notify_count(&self) -> u64 {
        self.state.notify_count.load(Ordering::Relaxed)
    }

    /// Number of times a filtered subscriber's pending value was overwritten before
    /// being read.
    pub fn coalesce_count(&self) -> u64 {
        self.state.coalesce_count.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a filtered receiver to produce a result.
    const RECV_TIMEOUT: Duration = Duration::from_millis(100);

    /// Builds an interrupt with the given routing fields, the current time and an empty
    /// `uint32_changed_mask`. Use struct-update syntax to override the mask.
    fn iv(reason: usize, addr: i32, value: ParamValue) -> InterruptValue {
        InterruptValue {
            reason,
            addr,
            value,
            timestamp: SystemTime::now(),
            uint32_changed_mask: 0,
        }
    }

    /// Awaits the next filtered value, failing the test if `recv` does not complete in time.
    async fn recv_within(rx: &mut InterruptReceiver) -> Option<InterruptValue> {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("recv did not complete before the timeout")
    }

    /// Extracts an `Int32` payload, panicking on any other variant.
    fn as_i32(value: &ParamValue) -> i32 {
        match value {
            ParamValue::Int32(n) => *n,
            other => panic!("expected Int32, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_async_subscribe_receive() {
        let im = InterruptManager::new(16);
        let mut rx = im.subscribe_async();
        // Not 3.14: clippy's deny-by-default `approx_constant` lint rejects PI look-alikes.
        im.notify(iv(1, 0, ParamValue::Float64(2.5)));
        let v = rx.recv().await.unwrap();
        assert_eq!(v.reason, 1);
    }

    #[tokio::test]
    async fn test_async_multiple_subscribers() {
        let im = InterruptManager::new(16);
        let mut rx1 = im.subscribe_async();
        let mut rx2 = im.subscribe_async();
        im.notify(iv(0, 0, ParamValue::Int32(99)));
        let v1 = rx1.recv().await.unwrap();
        let v2 = rx2.recv().await.unwrap();
        assert_eq!(v1.reason, 0);
        assert_eq!(v2.reason, 0);
        assert_eq!(as_i32(&v1.value), 99);
        assert_eq!(as_i32(&v2.value), 99);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_filter_by_reason() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            reason: Some(1),
            ..Default::default()
        });
        im.notify(iv(0, 0, ParamValue::Int32(10)));
        im.notify(iv(1, 0, ParamValue::Int32(20)));
        let v = recv_within(&mut rx).await.expect("matching interrupt");
        assert_eq!(v.reason, 1);
        assert_eq!(as_i32(&v.value), 20);
        // The non-matching interrupt never occupied the slot, so nothing was coalesced.
        assert_eq!(im.coalesce_count(), 0);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_filter_by_addr() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            addr: Some(3),
            ..Default::default()
        });
        im.notify(iv(0, 0, ParamValue::Int32(1)));
        im.notify(iv(0, 3, ParamValue::Int32(2)));
        let v = recv_within(&mut rx).await.expect("matching interrupt");
        assert_eq!(v.addr, 3);
        assert_eq!(as_i32(&v.value), 2);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_filter_by_uint32_mask() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            uint32_mask: Some(0b0100),
            ..Default::default()
        });
        // No bit in common with the filter mask: filtered out.
        im.notify(InterruptValue {
            uint32_changed_mask: 0b0011,
            ..iv(0, 0, ParamValue::Int32(1))
        });
        // Shares bit 2 with the filter mask: delivered.
        im.notify(InterruptValue {
            uint32_changed_mask: 0b0110,
            ..iv(0, 0, ParamValue::Int32(2))
        });
        let v = recv_within(&mut rx).await.expect("matching interrupt");
        assert_eq!(v.uint32_changed_mask, 0b0110);
        assert_eq!(as_i32(&v.value), 2);
        assert_eq!(im.coalesce_count(), 0);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_no_filter() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter::default());
        im.notify(iv(5, 2, ParamValue::Float64(1.5)));
        let v = recv_within(&mut rx).await.expect("unfiltered interrupt");
        assert_eq!(v.reason, 5);
        assert_eq!(v.addr, 2);
    }

    #[tokio::test]
    async fn test_register_interrupt_user_drop_unsubscribes() {
        let im = InterruptManager::new(16);
        let (sub, mut rx) = im.register_interrupt_user(InterruptFilter::default());
        drop(sub);
        // Dropping the subscription must wake and close the receiver; a timeout here
        // would indicate a lost wake-up rather than a pass.
        assert!(recv_within(&mut rx).await.is_none());
        // Interrupts published after unsubscribing are not delivered.
        im.notify(iv(0, 0, ParamValue::Int32(1)));
        assert!(recv_within(&mut rx).await.is_none());
    }

    #[tokio::test]
    async fn test_pending_value_delivered_before_close() {
        let im = InterruptManager::new(16);
        let (sub, mut rx) = im.register_interrupt_user(InterruptFilter::default());
        im.notify(iv(0, 0, ParamValue::Int32(7)));
        drop(sub);
        let first = recv_within(&mut rx)
            .await
            .expect("value stored before the drop is still delivered");
        assert_eq!(as_i32(&first.value), 7);
        assert!(recv_within(&mut rx).await.is_none());
    }

    #[tokio::test]
    async fn test_register_interrupt_user_multiple_subscribers() {
        let im = InterruptManager::new(16);
        let (_sub1, mut rx1) = im.register_interrupt_user(InterruptFilter {
            reason: Some(0),
            ..Default::default()
        });
        let (_sub2, mut rx2) = im.register_interrupt_user(InterruptFilter {
            reason: Some(1),
            ..Default::default()
        });
        im.notify(iv(0, 0, ParamValue::Int32(10)));
        im.notify(iv(1, 0, ParamValue::Int32(20)));
        let v1 = recv_within(&mut rx1).await.expect("reason 0 interrupt");
        assert_eq!(v1.reason, 0);
        assert_eq!(as_i32(&v1.value), 10);
        let v2 = recv_within(&mut rx2).await.expect("reason 1 interrupt");
        assert_eq!(v2.reason, 1);
        assert_eq!(as_i32(&v2.value), 20);
    }

    #[test]
    fn test_notify_no_subscribers_no_panic() {
        let im = InterruptManager::new(16);
        im.notify(iv(0, 0, ParamValue::Int32(1)));
        assert_eq!(im.notify_count(), 1);
        assert_eq!(im.coalesce_count(), 0);
    }

    #[tokio::test]
    async fn test_coalescing() {
        let im = InterruptManager::new(16);
        let (_sub, mut rx) = im.register_interrupt_user(InterruptFilter {
            reason: Some(0),
            ..Default::default()
        });
        for n in 1..=3 {
            im.notify(iv(0, 0, ParamValue::Int32(n)));
        }
        let v = recv_within(&mut rx).await.expect("coalesced interrupt");
        assert_eq!(as_i32(&v.value), 3);
        assert_eq!(im.coalesce_count(), 2);
    }

    #[tokio::test]
    async fn test_shared_state_between_managers() {
        let im1 = InterruptManager::new(16);
        let im2 = InterruptManager::from_shared_state(im1.shared_state());
        let (_sub, mut rx) = im2.register_interrupt_user(InterruptFilter {
            reason: Some(0),
            ..Default::default()
        });
        im1.notify(iv(0, 0, ParamValue::Int32(42)));
        let v = recv_within(&mut rx).await.expect("interrupt via shared state");
        assert_eq!(v.reason, 0);
        assert_eq!(as_i32(&v.value), 42);
        // Counters live in the shared state, so both managers observe the same totals.
        assert_eq!(im2.notify_count(), 1);
    }
}