use super::{PeriodicMessage, PeriodicStats};
use crate::{CanBackendAsync, CanError};
use std::collections::{BinaryHeap, HashMap};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
/// Control messages consumed by the scheduler task (see `run_scheduler`).
///
/// Request variants carry a `oneshot::Sender` on which the task posts the
/// reply; the sender is dropped unanswered if the task has already stopped.
#[derive(Debug)]
pub enum SchedulerCommand {
    /// Register a new periodic message; replies with the assigned id.
    Add {
        message: PeriodicMessage,
        reply: oneshot::Sender<Result<u32, CanError>>,
    },
    /// Unregister the message with the given id.
    Remove {
        id: u32,
        reply: oneshot::Sender<Result<(), CanError>>,
    },
    /// Replace the payload of an existing message.
    UpdateData {
        id: u32,
        data: Vec<u8>,
        reply: oneshot::Sender<Result<(), CanError>>,
    },
    /// Change the send interval of an existing message.
    UpdateInterval {
        id: u32,
        interval: Duration,
        reply: oneshot::Sender<Result<(), CanError>>,
    },
    /// Pause or resume transmission of an existing message.
    SetEnabled {
        id: u32,
        enabled: bool,
        reply: oneshot::Sender<Result<(), CanError>>,
    },
    /// Fetch a snapshot of a message's send statistics (`None` if unknown).
    GetStats {
        id: u32,
        reply: oneshot::Sender<Option<PeriodicStats>>,
    },
    /// List the ids of all currently registered messages.
    ListIds {
        reply: oneshot::Sender<Vec<u32>>,
    },
    /// Stop the scheduler task; no reply is sent.
    Shutdown,
}
/// Heap entry pairing a message id with its next transmission deadline.
///
/// The `Ord` impl below reverses the comparison so that `BinaryHeap`
/// (a max-heap) yields the *earliest* deadline first.
#[derive(Debug, Clone, Eq, PartialEq)]
struct ScheduledEntry {
    // Deadline of the next send for this message.
    next_send: Instant,
    // Id assigned by `SchedulerState::add`.
    message_id: u32,
}
impl Ord for ScheduledEntry {
    /// Reversed deadline comparison: the entry with the *earliest*
    /// `next_send` compares greatest, turning the max-oriented
    /// `BinaryHeap` into a min-heap on deadlines.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.next_send.cmp(&other.next_send).reverse()
    }
}
impl PartialOrd for ScheduledEntry {
    /// Delegates to the total order defined by `Ord` (canonical form).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Bookkeeping owned exclusively by the scheduler task.
struct SchedulerState {
    // Live periodic messages keyed by their assigned id.
    messages: HashMap<u32, PeriodicMessage>,
    // Per-message send statistics, keyed like `messages`.
    stats: HashMap<u32, PeriodicStats>,
    // Deadline min-heap (via the reversed `Ord` on `ScheduledEntry`).
    // May contain stale entries for removed ids; the run loop drops them.
    schedule: BinaryHeap<ScheduledEntry>,
    // Next id to hand out; starts at 1 and increments per `add`.
    next_id: u32,
    // Maximum number of concurrently registered messages.
    capacity: usize,
}
impl SchedulerState {
fn new(capacity: usize) -> Self {
Self {
messages: HashMap::new(),
stats: HashMap::new(),
schedule: BinaryHeap::new(),
next_id: 1,
capacity,
}
}
fn add(&mut self, mut message: PeriodicMessage) -> Result<u32, CanError> {
if self.messages.len() >= self.capacity {
return Err(CanError::InsufficientResources {
resource: format!(
"periodic message capacity exceeded (max: {})",
self.capacity
),
});
}
let id = self.next_id;
self.next_id += 1;
message.set_id(id);
let interval = message.interval();
self.messages.insert(id, message);
self.stats.insert(id, PeriodicStats::new());
self.schedule.push(ScheduledEntry {
next_send: Instant::now() + interval,
message_id: id,
});
Ok(id)
}
fn remove(&mut self, id: u32) -> Result<(), CanError> {
if self.messages.remove(&id).is_none() {
return Err(CanError::InvalidParameter {
parameter: "id".to_string(),
reason: format!("periodic message with id {id} not found"),
});
}
self.stats.remove(&id);
Ok(())
}
}
/// Cloneable handle for controlling the scheduler task over its command
/// channel; all methods fail with `CanError::Other` once the task stops.
#[derive(Clone)]
pub struct PeriodicScheduler {
    // Sender half of the channel drained by `run_scheduler`.
    command_tx: mpsc::Sender<SchedulerCommand>,
}
impl PeriodicScheduler {
    /// Creates a scheduler handle plus the command receiver that must be
    /// handed to `run_scheduler`. `channel_size` bounds queued commands.
    #[must_use]
    pub fn new(channel_size: usize) -> (Self, mpsc::Receiver<SchedulerCommand>) {
        let (command_tx, command_rx) = mpsc::channel(channel_size);
        (Self { command_tx }, command_rx)
    }

    /// Sends one command to the scheduler task and awaits its reply.
    ///
    /// `make_cmd` receives the oneshot reply sender and builds the command
    /// that carries it. Centralizes the request/reply plumbing that was
    /// previously duplicated in every public method.
    ///
    /// # Errors
    /// `CanError::Other` when the scheduler task has stopped (command
    /// channel closed, or reply sender dropped without answering).
    async fn request<T>(
        &self,
        make_cmd: impl FnOnce(oneshot::Sender<T>) -> SchedulerCommand,
    ) -> Result<T, CanError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.command_tx
            .send(make_cmd(reply_tx))
            .await
            .map_err(|_| CanError::Other {
                message: "scheduler channel closed".to_string(),
            })?;
        reply_rx.await.map_err(|_| CanError::Other {
            message: "scheduler reply channel closed".to_string(),
        })
    }

    /// Registers a periodic message and returns its assigned id.
    ///
    /// # Errors
    /// Scheduler-side errors (e.g. capacity exceeded) or channel closure.
    pub async fn add(&self, message: PeriodicMessage) -> Result<u32, CanError> {
        // Outer `?` unwraps the transport result; the inner scheduler
        // result is returned to the caller as-is.
        self.request(|reply| SchedulerCommand::Add { message, reply })
            .await?
    }

    /// Unregisters the periodic message with `id`.
    ///
    /// # Errors
    /// `InvalidParameter` for an unknown id, or channel closure.
    pub async fn remove(&self, id: u32) -> Result<(), CanError> {
        self.request(|reply| SchedulerCommand::Remove { id, reply })
            .await?
    }

    /// Replaces the payload of the message with `id`.
    ///
    /// # Errors
    /// `InvalidParameter` for an unknown id, or channel closure.
    pub async fn update_data(&self, id: u32, data: Vec<u8>) -> Result<(), CanError> {
        self.request(|reply| SchedulerCommand::UpdateData { id, data, reply })
            .await?
    }

    /// Changes the send interval of the message with `id`.
    ///
    /// # Errors
    /// `InvalidParameter` for an unknown id, or channel closure.
    pub async fn update_interval(&self, id: u32, interval: Duration) -> Result<(), CanError> {
        self.request(|reply| SchedulerCommand::UpdateInterval { id, interval, reply })
            .await?
    }

    /// Pauses (`false`) or resumes (`true`) transmission of the message.
    ///
    /// # Errors
    /// `InvalidParameter` for an unknown id, or channel closure.
    pub async fn set_enabled(&self, id: u32, enabled: bool) -> Result<(), CanError> {
        self.request(|reply| SchedulerCommand::SetEnabled { id, enabled, reply })
            .await?
    }

    /// Fetches a stats snapshot for `id`; `Ok(None)` if the id is unknown.
    ///
    /// # Errors
    /// Channel closure only.
    pub async fn get_stats(&self, id: u32) -> Result<Option<PeriodicStats>, CanError> {
        self.request(|reply| SchedulerCommand::GetStats { id, reply })
            .await
    }

    /// Lists the ids of all registered periodic messages.
    ///
    /// # Errors
    /// Channel closure only.
    pub async fn list_ids(&self) -> Result<Vec<u32>, CanError> {
        self.request(|reply| SchedulerCommand::ListIds { reply })
            .await
    }

    /// Asks the scheduler task to stop; fire-and-forget, no reply.
    ///
    /// # Errors
    /// Channel closure only.
    pub async fn shutdown(&self) -> Result<(), CanError> {
        self.command_tx
            .send(SchedulerCommand::Shutdown)
            .await
            .map_err(|_| CanError::Other {
                message: "scheduler channel closed".to_string(),
            })
    }
}
#[allow(clippy::too_many_lines)]
pub async fn run_scheduler<B>(
mut backend: B,
mut command_rx: mpsc::Receiver<SchedulerCommand>,
capacity: usize,
) where
B: CanBackendAsync,
{
let mut state = SchedulerState::new(capacity);
loop {
let sleep_duration = state
.schedule
.peek()
.map_or(Duration::from_secs(1), |entry| {
let now = Instant::now();
if entry.next_send > now {
entry.next_send - now
} else {
Duration::ZERO
}
});
tokio::select! {
Some(cmd) = command_rx.recv() => {
match cmd {
SchedulerCommand::Add { message, reply } => {
let _ = reply.send(state.add(message));
}
SchedulerCommand::Remove { id, reply } => {
let _ = reply.send(state.remove(id));
}
SchedulerCommand::UpdateData { id, data, reply } => {
let result = if let Some(msg) = state.messages.get_mut(&id) {
msg.update_data(data)
} else {
Err(CanError::InvalidParameter {
parameter: "id".to_string(),
reason: format!("periodic message with id {id} not found"),
})
};
let _ = reply.send(result);
}
SchedulerCommand::UpdateInterval { id, interval, reply } => {
let result = if let Some(msg) = state.messages.get_mut(&id) {
msg.set_interval(interval)
} else {
Err(CanError::InvalidParameter {
parameter: "id".to_string(),
reason: format!("periodic message with id {id} not found"),
})
};
let _ = reply.send(result);
}
SchedulerCommand::SetEnabled { id, enabled, reply } => {
let result = if let Some(msg) = state.messages.get_mut(&id) {
msg.set_enabled(enabled);
Ok(())
} else {
Err(CanError::InvalidParameter {
parameter: "id".to_string(),
reason: format!("periodic message with id {id} not found"),
})
};
let _ = reply.send(result);
}
SchedulerCommand::GetStats { id, reply } => {
let message_stats = state.stats.get(&id).cloned();
let _ = reply.send(message_stats);
}
SchedulerCommand::ListIds { reply } => {
let ids: Vec<u32> = state.messages.keys().copied().collect();
let _ = reply.send(ids);
}
SchedulerCommand::Shutdown => {
break;
}
}
}
() = tokio::time::sleep(sleep_duration) => {
let now = Instant::now();
while let Some(entry) = state.schedule.peek() {
if entry.next_send > now {
break;
}
let Some(entry) = state.schedule.pop() else {
break;
};
let id = entry.message_id;
if let Some(msg) = state.messages.get(&id) {
if msg.is_enabled() {
let send_result = backend.send_message_async(msg.message()).await;
if let Some(stats) = state.stats.get_mut(&id) {
stats.record_send(now.into());
}
if let Err(e) = send_result {
#[cfg(feature = "tracing")]
tracing::warn!(
"Periodic send failed for message {}: {}",
id,
e
);
let _ = e; }
}
if let Some(msg) = state.messages.get(&id) {
state.schedule.push(ScheduledEntry {
next_send: now + msg.interval(),
message_id: id,
});
}
}
}
}
}
}
}