use std::collections::HashMap;
use std::sync::Arc;

use chrono::{DateTime, Datelike, Duration, Timelike, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::time::interval;
use tracing::{debug, info, warn};

use crate::scenarios::ChaosScenario;
/// When, and how often, a scheduled chaos scenario should run.
///
/// Serialized as an internally tagged object: the variant name is stored in a
/// `"type"` field alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ScheduleType {
    /// Run a single time at the absolute instant `at`.
    Once { at: DateTime<Utc> },
    /// Run a single time, `delay_seconds` after the next execution is computed.
    Delayed { delay_seconds: u64 },
    /// Run repeatedly, `interval_seconds` apart.
    Periodic {
        interval_seconds: u64,
        /// Upper bound on executions; `0` means no limit.
        max_executions: usize,
    },
    /// Cron-like schedule with optional hour / minute / day-of-week
    /// constraints. How the constraints are evaluated is defined by
    /// `ScheduledScenario::calculate_next_execution`.
    Cron {
        hour: Option<u8>,
        minute: Option<u8>,
        day_of_week: Option<u8>,
        /// Upper bound on executions; `0` means no limit.
        max_executions: usize,
    },
}
/// A chaos scenario paired with the schedule that decides when it runs,
/// plus the bookkeeping needed to compute its next execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduledScenario {
    /// Identifier under which the scheduler stores this entry.
    pub id: String,
    /// The scenario handed to the scheduler's callback when this entry fires.
    pub scenario: ChaosScenario,
    /// Timing rule for this entry.
    pub schedule: ScheduleType,
    /// Disabled entries are never reported as due by `should_execute`.
    pub enabled: bool,
    /// Number of times `mark_executed` has been called.
    pub execution_count: usize,
    /// Time of the most recent `mark_executed` call, if any.
    pub last_executed: Option<DateTime<Utc>>,
    /// When this entry is next due; `None` means nothing further is scheduled.
    pub next_execution: Option<DateTime<Utc>>,
}
impl ScheduledScenario {
pub fn new(id: impl Into<String>, scenario: ChaosScenario, schedule: ScheduleType) -> Self {
let mut scheduled = Self {
id: id.into(),
scenario,
schedule,
enabled: true,
execution_count: 0,
last_executed: None,
next_execution: None,
};
scheduled.calculate_next_execution();
scheduled
}
pub fn calculate_next_execution(&mut self) {
let now = Utc::now();
self.next_execution = match &self.schedule {
ScheduleType::Once { at } => {
if *at > now && self.execution_count == 0 {
Some(*at)
} else {
None
}
}
ScheduleType::Delayed { delay_seconds } => {
if self.execution_count == 0 {
Some(now + Duration::seconds(*delay_seconds as i64))
} else {
None
}
}
ScheduleType::Periodic {
interval_seconds,
max_executions,
} => {
if *max_executions == 0 || self.execution_count < *max_executions {
Some(now + Duration::seconds(*interval_seconds as i64))
} else {
None
}
}
ScheduleType::Cron {
hour: _,
minute: _,
day_of_week: _,
max_executions,
} => {
if *max_executions > 0 && self.execution_count >= *max_executions {
None
} else {
let next = now + Duration::hours(1);
Some(next)
}
}
};
}
pub fn should_execute(&self) -> bool {
if !self.enabled {
return false;
}
if let Some(next) = self.next_execution {
Utc::now() >= next
} else {
false
}
}
pub fn mark_executed(&mut self) {
self.execution_count += 1;
self.last_executed = Some(Utc::now());
self.calculate_next_execution();
}
}
/// Runs registered [`ScheduledScenario`]s from a background task, invoking a
/// user-supplied callback whenever one becomes due or is triggered manually.
pub struct ScenarioScheduler {
    /// Registered schedules keyed by id; shared with the background task.
    schedules: Arc<RwLock<HashMap<String, ScheduledScenario>>>,
    /// Sender feeding manual triggers to the background task while it runs.
    execution_tx: Arc<RwLock<Option<mpsc::Sender<ScheduledScenario>>>>,
    /// Handle of the background task while the scheduler is started.
    task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
impl ScenarioScheduler {
    /// Creates a stopped scheduler with no registered schedules.
    pub fn new() -> Self {
        Self {
            schedules: Arc::new(RwLock::new(HashMap::new())),
            execution_tx: Arc::new(RwLock::new(None)),
            task_handle: Arc::new(RwLock::new(None)),
        }
    }

    /// Registers `scheduled` under its `id`, replacing any schedule that
    /// already used that id.
    pub fn add_schedule(&self, scheduled: ScheduledScenario) {
        let id = scheduled.id.clone();
        self.schedules.write().insert(id.clone(), scheduled);
        info!("Added scheduled scenario: {}", id);
    }

    /// Removes and returns the schedule registered under `id`, if any.
    pub fn remove_schedule(&self, id: &str) -> Option<ScheduledScenario> {
        let removed = self.schedules.write().remove(id);
        if removed.is_some() {
            info!("Removed scheduled scenario: {}", id);
        }
        removed
    }

    /// Returns a snapshot (clone) of the schedule registered under `id`.
    pub fn get_schedule(&self, id: &str) -> Option<ScheduledScenario> {
        self.schedules.read().get(id).cloned()
    }

    /// Returns snapshots of every registered schedule, in unspecified order.
    pub fn get_all_schedules(&self) -> Vec<ScheduledScenario> {
        self.schedules.read().values().cloned().collect()
    }

    /// Enables the schedule `id` and recomputes its next execution from now.
    ///
    /// # Errors
    /// Returns a message if no schedule is registered under `id`.
    pub fn enable_schedule(&self, id: &str) -> Result<(), String> {
        let mut schedules = self.schedules.write();
        let scheduled = schedules
            .get_mut(id)
            .ok_or_else(|| format!("Schedule '{}' not found", id))?;
        scheduled.enabled = true;
        scheduled.calculate_next_execution();
        info!("Enabled scheduled scenario: {}", id);
        Ok(())
    }

    /// Disables the schedule `id`. Its `next_execution` is left untouched;
    /// it is simply no longer reported as due.
    ///
    /// # Errors
    /// Returns a message if no schedule is registered under `id`.
    pub fn disable_schedule(&self, id: &str) -> Result<(), String> {
        let mut schedules = self.schedules.write();
        let scheduled = schedules
            .get_mut(id)
            .ok_or_else(|| format!("Schedule '{}' not found", id))?;
        scheduled.enabled = false;
        info!("Disabled scheduled scenario: {}", id);
        Ok(())
    }

    /// Spawns the background scheduler task, which invokes `callback` for
    /// each due or manually triggered schedule. Does nothing (besides a
    /// warning) if the scheduler is already running.
    ///
    /// Must be called from within a Tokio runtime (it uses `tokio::spawn`).
    pub async fn start<F>(&self, callback: F)
    where
        F: Fn(ScheduledScenario) + Send + 'static,
    {
        // Hold the handle's write lock across check-and-spawn so two
        // concurrent `start` calls cannot both spawn a scheduler task.
        let mut task_handle = self.task_handle.write();
        if task_handle.is_some() {
            warn!("Scheduler already running");
            return;
        }
        info!("Starting scenario scheduler");
        let (tx, rx) = mpsc::channel::<ScheduledScenario>(100);
        *self.execution_tx.write() = Some(tx);
        let schedules = Arc::clone(&self.schedules);
        *task_handle = Some(tokio::spawn(async move {
            Self::scheduler_task(schedules, rx, callback).await;
        }));
    }

    /// Background loop. Once per second it fires every schedule whose
    /// `should_execute` is true, and it also fires schedules sent through the
    /// manual-trigger channel.
    ///
    /// Due schedules are snapshotted and marked executed under the write lock;
    /// the callback runs only after the lock is released. The snapshot is
    /// taken *before* `mark_executed`, so the callback sees the pre-run
    /// `execution_count`/`last_executed`. Manual triggers do not update the
    /// stored schedule's counters.
    async fn scheduler_task<F>(
        schedules: Arc<RwLock<HashMap<String, ScheduledScenario>>>,
        mut rx: mpsc::Receiver<ScheduledScenario>,
        callback: F,
    ) where
        F: Fn(ScheduledScenario),
    {
        let mut interval = interval(std::time::Duration::from_secs(1));
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    let mut to_execute = Vec::new();
                    {
                        let mut schedules_guard = schedules.write();
                        for (id, scheduled) in schedules_guard.iter_mut() {
                            if scheduled.should_execute() {
                                debug!("Triggering scheduled scenario: {}", id);
                                to_execute.push(scheduled.clone());
                                scheduled.mark_executed();
                            }
                        }
                    }
                    for scheduled in to_execute {
                        info!("Executing scheduled scenario: {}", scheduled.id);
                        callback(scheduled);
                    }
                }
                Some(scheduled) = rx.recv() => {
                    info!("Manual execution of scheduled scenario: {}", scheduled.id);
                    callback(scheduled);
                }
                // Only reached when every branch is disabled; the interval
                // branch never is, so in practice the task runs until `stop`
                // aborts it.
                else => break,
            }
        }
        info!("Scheduler task stopped");
    }

    /// Closes the manual-trigger channel and aborts the background task, if any.
    pub async fn stop(&self) {
        info!("Stopping scenario scheduler");
        *self.execution_tx.write() = None;
        if let Some(handle) = self.task_handle.write().take() {
            handle.abort();
        }
    }

    /// Sends a snapshot of schedule `id` to the running scheduler for
    /// immediate execution, regardless of its timing or enabled flag.
    ///
    /// # Errors
    /// Returns a message if `id` is unknown, the scheduler is not started, or
    /// the trigger channel is closed.
    pub async fn trigger_now(&self, id: &str) -> Result<(), String> {
        let scheduled = self
            .get_schedule(id)
            .ok_or_else(|| format!("Schedule '{}' not found", id))?;
        // Clone the sender out so the lock guard is dropped before `.await`.
        let tx = self
            .execution_tx
            .read()
            .as_ref()
            .cloned()
            .ok_or_else(|| "Scheduler not started".to_string())?;
        tx.send(scheduled)
            .await
            .map_err(|e| format!("Failed to trigger: {}", e))
    }

    /// Returns the id and time of the earliest upcoming execution among
    /// *enabled* schedules. Disabled schedules keep their stale
    /// `next_execution`, so they are skipped explicitly.
    pub fn get_next_execution(&self) -> Option<(String, DateTime<Utc>)> {
        let schedules = self.schedules.read();
        schedules
            .iter()
            .filter(|(_, s)| s.enabled)
            .filter_map(|(id, s)| s.next_execution.map(|t| (id.clone(), t)))
            .min_by_key(|(_, t)| *t)
    }
}
impl Default for ScenarioScheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ChaosConfig;

    /// Throwaway scenario built from the default chaos configuration.
    fn sample_scenario() -> ChaosScenario {
        ChaosScenario::new("test", ChaosConfig::default())
    }

    /// Schedule `sched1` that fires ten seconds after creation.
    fn delayed_by_ten_seconds() -> ScheduledScenario {
        ScheduledScenario::new(
            "sched1",
            sample_scenario(),
            ScheduleType::Delayed { delay_seconds: 10 },
        )
    }

    #[test]
    fn test_scheduled_scenario_once() {
        let one_hour_ahead = Utc::now() + Duration::hours(1);
        let scheduled = ScheduledScenario::new(
            "sched1",
            sample_scenario(),
            ScheduleType::Once { at: one_hour_ahead },
        );

        assert_eq!(scheduled.id, "sched1");
        assert!(scheduled.enabled);
        assert_eq!(scheduled.execution_count, 0);
        assert!(scheduled.next_execution.is_some());
    }

    #[test]
    fn test_scheduled_scenario_periodic() {
        let every_minute = ScheduleType::Periodic {
            interval_seconds: 60,
            max_executions: 10,
        };
        let scheduled = ScheduledScenario::new("sched1", sample_scenario(), every_minute);

        assert!(scheduled.next_execution.is_some());
    }

    #[test]
    fn test_scheduler_add_remove() {
        let scheduler = ScenarioScheduler::new();
        scheduler.add_schedule(delayed_by_ten_seconds());
        assert!(scheduler.get_schedule("sched1").is_some());

        assert!(scheduler.remove_schedule("sched1").is_some());
        assert!(scheduler.get_schedule("sched1").is_none());
    }

    #[test]
    fn test_enable_disable() {
        let scheduler = ScenarioScheduler::new();
        scheduler.add_schedule(delayed_by_ten_seconds());

        scheduler.disable_schedule("sched1").unwrap();
        assert!(!scheduler.get_schedule("sched1").unwrap().enabled);

        scheduler.enable_schedule("sched1").unwrap();
        assert!(scheduler.get_schedule("sched1").unwrap().enabled);
    }
}