use crate::types::AgentId;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// A named schedule that ties an agent to a cron expression.
///
/// Instances are created by `AgentScheduler::schedule` and stored in the
/// scheduler's table keyed by schedule name. The type is also
/// serde-serializable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleConfig {
/// The agent this schedule belongs to.
pub agent_id: AgentId,
/// Cron expression text. `AgentScheduler::schedule` validates it with
/// `parse_cron` before storing; a value obtained via deserialization is not
/// re-validated here.
pub cron_expr: String,
/// Free-form text attached to the schedule (e.g. "Run daily analysis").
/// NOTE(review): presumably handed to the agent when the schedule fires —
/// confirm with the caller that consumes `due_schedules`.
pub observation: String,
/// Whether the schedule is active; `due_schedules` returns only enabled entries.
pub enabled: bool,
/// When the schedule was registered (`Utc::now()` at `schedule` time).
pub created_at: DateTime<Utc>,
/// Time of the most recent `record_execution` call, or `None` if none yet.
pub last_run: Option<DateTime<Utc>>,
/// Number of `record_execution` calls made for this schedule.
pub run_count: u64,
}
/// Result of a successful `parse_cron` call.
#[derive(Debug, Clone)]
pub struct ParsedCron {
/// The cron expression exactly as passed to `parse_cron`.
pub expression: String,
/// Next fire time. `parse_cron` currently always sets this to `None`; it
/// does not compute fire times.
pub next_fire: Option<DateTime<Utc>>,
}
/// Syntax-checks a cron expression and wraps it in a [`ParsedCron`].
///
/// The expression must have between 5 and 7 whitespace-separated fields,
/// and each field must pass `validate_cron_field`. Fields are checked in
/// order and the first failure is reported. The returned value keeps the
/// original text; `next_fire` is always `None` because no fire time is
/// computed here.
///
/// # Errors
///
/// Returns [`SchedulerError::InvalidCron`] when the field count is outside
/// 5..=7 or when any field is malformed.
pub fn parse_cron(expr: &str) -> Result<ParsedCron, SchedulerError> {
    // Every error produced here carries the full original expression.
    let invalid = |message: String| SchedulerError::InvalidCron {
        expression: expr.to_string(),
        message,
    };

    let fields: Vec<&str> = expr.split_whitespace().collect();
    let count = fields.len();
    if !(5..=7).contains(&count) {
        return Err(invalid(format!("Expected 5-7 fields, got {}", count)));
    }

    fields
        .iter()
        .enumerate()
        .try_for_each(|(position, field)| validate_cron_field(field, position))
        .map_err(invalid)?;

    Ok(ParsedCron {
        expression: expr.to_string(),
        next_fire: None,
    })
}
/// Checks the syntax of one whitespace-separated cron field.
///
/// `index` is the field's zero-based position in the expression. It is used
/// only to label the field in error messages. A field is either a lone `*`
/// or `?`, or a comma-separated list whose items are each one of:
///
/// * a number, e.g. `5`
/// * a range `a-b`, e.g. `9-17`
/// * a step `base/n`, where `base` is `*`, a number, or a range (e.g.
///   `*/15`, `5/10`, `1-30/5`) and `n` is greater than zero
///
/// Only syntax is checked. Values are not bounded to the field's legal range
/// (e.g. `99` is accepted as a minute), and ranges are not required to be
/// ascending.
///
/// # Errors
///
/// Returns a human-readable message naming the field and the offending item.
fn validate_cron_field(field: &str, index: usize) -> Result<(), String> {
    // NOTE(review): these labels only affect error messages. For 6- and
    // 7-field expressions the positional mapping may not match the dialect in
    // use (e.g. Quartz-style cron puts seconds first) — confirm.
    let field_name = match index {
        0 => "minute",
        1 => "hour",
        2 => "day-of-month",
        3 => "month",
        4 => "day-of-week",
        5 => "year",
        6 => "seconds",
        _ => "unknown",
    };
    if field == "*" || field == "?" {
        return Ok(());
    }
    for part in field.split(',') {
        let part = part.trim();
        if part.contains('/') {
            let parts: Vec<&str> = part.split('/').collect();
            if parts.len() != 2 {
                return Err(format!("Invalid step in {} field: {}", field_name, part));
            }
            let (base, step_text) = (parts[0], parts[1]);
            if base.contains('-') {
                // Standard cron allows stepping over a range, e.g. `1-30/5`.
                validate_cron_range(base, field_name)?;
            } else if base != "*" {
                base.parse::<u32>().map_err(|_| {
                    format!("Invalid base value in {} field: {}", field_name, base)
                })?;
            }
            let step = step_text.parse::<u32>().map_err(|_| {
                format!("Invalid step value in {} field: {}", field_name, step_text)
            })?;
            if step == 0 {
                // A zero step never advances, so no fire times could be derived from it.
                return Err(format!(
                    "Step value must be greater than zero in {} field: {}",
                    field_name, part
                ));
            }
        } else if part.contains('-') {
            validate_cron_range(part, field_name)?;
        } else {
            part.parse::<u32>()
                .map_err(|_| format!("Invalid value in {} field: {}", field_name, part))?;
        }
    }
    Ok(())
}

/// Checks that `range` has the form `a-b` with both ends unsigned integers.
///
/// `field_name` is used only in the error message.
fn validate_cron_range(range: &str, field_name: &str) -> Result<(), String> {
    let bounds: Vec<&str> = range.split('-').collect();
    if bounds.len() != 2 {
        return Err(format!("Invalid range in {} field: {}", field_name, range));
    }
    bounds[0].parse::<u32>().map_err(|_| {
        format!("Invalid range start in {} field: {}", field_name, bounds[0])
    })?;
    bounds[1]
        .parse::<u32>()
        .map_err(|_| format!("Invalid range end in {} field: {}", field_name, bounds[1]))?;
    Ok(())
}
/// Registry of named agent schedules.
///
/// The table lives in memory inside this value; this type does not persist it.
pub struct AgentScheduler {
// Schedule name -> configuration. Guarded by a tokio `RwLock` so the async
// methods can read and write it.
schedules: Arc<RwLock<HashMap<String, ScheduleConfig>>>,
}
impl Default for AgentScheduler {
fn default() -> Self {
Self::new()
}
}
impl AgentScheduler {
    /// Creates a scheduler with an empty schedule table.
    pub fn new() -> Self {
        let table: HashMap<String, ScheduleConfig> = HashMap::new();
        Self {
            schedules: Arc::new(RwLock::new(table)),
        }
    }

    /// Registers a schedule under `name`, replacing any existing entry with that name.
    ///
    /// `cron_expr` is syntax-checked with [`parse_cron`] before anything is stored. The new
    /// entry starts enabled, with `created_at` set to the current UTC time, no `last_run`,
    /// and a `run_count` of 0.
    ///
    /// # Errors
    ///
    /// Returns the error from [`parse_cron`] if the expression is rejected; the schedule
    /// table is not modified in that case.
    pub async fn schedule(
        &self,
        name: impl Into<String>,
        agent_id: AgentId,
        cron_expr: impl Into<String>,
        observation: impl Into<String>,
    ) -> Result<(), SchedulerError> {
        let expression: String = cron_expr.into();
        // Validation only; the parsed form is not kept.
        parse_cron(&expression)?;
        let entry = ScheduleConfig {
            agent_id,
            cron_expr: expression,
            observation: observation.into(),
            enabled: true,
            created_at: Utc::now(),
            last_run: None,
            run_count: 0,
        };
        let mut table = self.schedules.write().await;
        table.insert(name.into(), entry);
        Ok(())
    }

    /// Returns a copy of the schedule stored under `name`, or `None` if there is none.
    pub async fn get_schedule(&self, name: &str) -> Option<ScheduleConfig> {
        let table = self.schedules.read().await;
        table.get(name).cloned()
    }

    /// Returns copies of every schedule as `(name, config)` pairs.
    ///
    /// Order follows `HashMap` iteration and is unspecified.
    pub async fn list_schedules(&self) -> Vec<(String, ScheduleConfig)> {
        let table = self.schedules.read().await;
        let mut entries = Vec::with_capacity(table.len());
        for (key, entry) in table.iter() {
            entries.push((key.clone(), entry.clone()));
        }
        entries
    }

    /// Sets the `enabled` flag of the named schedule.
    ///
    /// Returns `false` (and changes nothing) if no schedule has that name.
    pub async fn set_enabled(&self, name: &str, enabled: bool) -> bool {
        let mut table = self.schedules.write().await;
        match table.get_mut(name) {
            Some(entry) => {
                entry.enabled = enabled;
                true
            }
            None => false,
        }
    }

    /// Removes the named schedule, returning whether one was present.
    pub async fn remove_schedule(&self, name: &str) -> bool {
        let removed = self.schedules.write().await.remove(name);
        removed.is_some()
    }

    /// Records that the named schedule just ran.
    ///
    /// Sets `last_run` to the current UTC time and increments `run_count`. Returns
    /// `false` (and changes nothing) if no schedule has that name.
    pub async fn record_execution(&self, name: &str) -> bool {
        let mut table = self.schedules.write().await;
        match table.get_mut(name) {
            Some(entry) => {
                entry.last_run = Some(Utc::now());
                entry.run_count += 1;
                true
            }
            None => false,
        }
    }

    /// Returns copies of all enabled schedules as `(name, config)` pairs.
    ///
    /// NOTE(review): despite the name, this does not evaluate `cron_expr` or
    /// `last_run`; every enabled schedule is returned regardless of the current time.
    pub async fn due_schedules(&self) -> Vec<(String, ScheduleConfig)> {
        let table = self.schedules.read().await;
        table
            .iter()
            .filter_map(|(key, entry)| entry.enabled.then(|| (key.clone(), entry.clone())))
            .collect()
    }
}
/// Errors produced by cron parsing and the scheduler.
#[derive(Debug, thiserror::Error)]
pub enum SchedulerError {
/// The cron expression failed validation in `parse_cron`; `message` says why.
#[error("Invalid cron expression '{expression}': {message}")]
InvalidCron { expression: String, message: String },
/// No schedule exists under `name`.
/// NOTE(review): not returned by any function visible in this module
/// (lookups there report absence via `bool`/`Option`) — confirm external use.
#[error("Schedule '{name}' not found")]
NotFound { name: String },
}
#[cfg(test)]
mod tests {
    // Unit tests for cron validation and the in-memory scheduler table.
    use super::*;

    #[test]
    fn test_parse_cron_valid_5_field() {
        let result = parse_cron("0 * * * *");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().expression, "0 * * * *");
    }

    #[test]
    fn test_parse_cron_valid_with_ranges() {
        assert!(parse_cron("0 9-17 * * 1-5").is_ok());
    }

    #[test]
    fn test_parse_cron_valid_with_steps() {
        assert!(parse_cron("*/15 * * * *").is_ok());
    }

    #[test]
    fn test_parse_cron_valid_with_lists() {
        assert!(parse_cron("0 0 1,15 * *").is_ok());
    }

    #[test]
    fn test_parse_cron_valid_with_question_mark() {
        assert!(parse_cron("0 0 ? * 1").is_ok());
    }

    #[test]
    fn test_parse_cron_invalid_too_few_fields() {
        let result = parse_cron("0 *");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("5-7 fields"));
    }

    #[test]
    fn test_parse_cron_invalid_too_many_fields() {
        // Upper bound of the accepted field count.
        let err = parse_cron("0 0 0 0 0 0 0 0").unwrap_err().to_string();
        assert!(err.contains("5-7 fields"));
        assert!(err.contains("got 8"));
    }

    #[test]
    fn test_parse_cron_empty_expression() {
        let err = parse_cron("").unwrap_err().to_string();
        assert!(err.contains("got 0"));
    }

    #[test]
    fn test_parse_cron_invalid_field_value() {
        let result = parse_cron("abc * * * *");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_cron_error_names_field_and_expression() {
        let err = parse_cron("0 abc * * *").unwrap_err().to_string();
        assert!(err.contains("hour field"));
        assert!(err.contains("'0 abc * * *'"));
    }

    #[test]
    fn test_parse_cron_valid_6_field() {
        assert!(parse_cron("0 30 9 * * 1-5").is_ok());
    }

    #[test]
    fn test_parse_cron_valid_7_field() {
        assert!(parse_cron("0 0 12 * * 2024 0").is_ok());
    }

    #[tokio::test]
    async fn test_scheduler_add_and_get() {
        let scheduler = AgentScheduler::new();
        let agent_id = AgentId::new();
        scheduler
            .schedule("daily_check", agent_id, "0 9 * * *", "Run daily analysis")
            .await
            .unwrap();
        let config = scheduler.get_schedule("daily_check").await.unwrap();
        assert_eq!(config.agent_id, agent_id);
        assert_eq!(config.cron_expr, "0 9 * * *");
        assert!(config.enabled);
        assert_eq!(config.run_count, 0);
    }

    #[tokio::test]
    async fn test_scheduler_list() {
        let scheduler = AgentScheduler::new();
        scheduler
            .schedule("a", AgentId::new(), "0 * * * *", "task a")
            .await
            .unwrap();
        scheduler
            .schedule("b", AgentId::new(), "*/5 * * * *", "task b")
            .await
            .unwrap();
        let schedules = scheduler.list_schedules().await;
        assert_eq!(schedules.len(), 2);
    }

    #[tokio::test]
    async fn test_scheduler_schedule_overwrites_same_name() {
        // Re-scheduling under an existing name replaces the previous entry.
        let scheduler = AgentScheduler::new();
        scheduler
            .schedule("job", AgentId::new(), "0 * * * *", "first")
            .await
            .unwrap();
        scheduler
            .schedule("job", AgentId::new(), "*/5 * * * *", "second")
            .await
            .unwrap();
        assert_eq!(scheduler.list_schedules().await.len(), 1);
        let config = scheduler.get_schedule("job").await.unwrap();
        assert_eq!(config.cron_expr, "*/5 * * * *");
        assert_eq!(config.observation, "second");
    }

    #[tokio::test]
    async fn test_scheduler_enable_disable() {
        let scheduler = AgentScheduler::new();
        scheduler
            .schedule("job", AgentId::new(), "0 * * * *", "task")
            .await
            .unwrap();
        assert!(scheduler.set_enabled("job", false).await);
        assert!(!scheduler.get_schedule("job").await.unwrap().enabled);
        assert!(scheduler.set_enabled("job", true).await);
        assert!(scheduler.get_schedule("job").await.unwrap().enabled);
        assert!(!scheduler.set_enabled("nonexistent", false).await);
    }

    #[tokio::test]
    async fn test_scheduler_remove() {
        let scheduler = AgentScheduler::new();
        scheduler
            .schedule("temp", AgentId::new(), "0 * * * *", "task")
            .await
            .unwrap();
        assert!(scheduler.remove_schedule("temp").await);
        assert!(scheduler.get_schedule("temp").await.is_none());
        assert!(!scheduler.remove_schedule("temp").await);
    }

    #[tokio::test]
    async fn test_scheduler_record_execution() {
        let scheduler = AgentScheduler::new();
        scheduler
            .schedule("job", AgentId::new(), "0 * * * *", "task")
            .await
            .unwrap();
        assert!(scheduler.record_execution("job").await);
        let config = scheduler.get_schedule("job").await.unwrap();
        assert_eq!(config.run_count, 1);
        assert!(config.last_run.is_some());
        assert!(scheduler.record_execution("job").await);
        let config = scheduler.get_schedule("job").await.unwrap();
        assert_eq!(config.run_count, 2);
        assert!(!scheduler.record_execution("nonexistent").await);
    }

    #[tokio::test]
    async fn test_scheduler_due_schedules() {
        let scheduler = AgentScheduler::new();
        scheduler
            .schedule("enabled", AgentId::new(), "0 * * * *", "task")
            .await
            .unwrap();
        scheduler
            .schedule("disabled", AgentId::new(), "0 * * * *", "task")
            .await
            .unwrap();
        scheduler.set_enabled("disabled", false).await;
        let due = scheduler.due_schedules().await;
        assert_eq!(due.len(), 1);
        assert_eq!(due[0].0, "enabled");
    }

    #[tokio::test]
    async fn test_scheduler_invalid_cron_rejected() {
        let scheduler = AgentScheduler::new();
        let result = scheduler
            .schedule("bad", AgentId::new(), "invalid cron", "task")
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_scheduler_invalid_cron_not_stored() {
        // A rejected expression must leave the table untouched.
        let scheduler = AgentScheduler::new();
        let result = scheduler
            .schedule("bad", AgentId::new(), "61 x * * *", "task")
            .await;
        assert!(result.is_err());
        assert!(scheduler.get_schedule("bad").await.is_none());
        assert!(scheduler.list_schedules().await.is_empty());
    }

    #[test]
    fn test_schedule_config_serialization() {
        let config = ScheduleConfig {
            agent_id: AgentId::new(),
            cron_expr: "0 9 * * *".into(),
            observation: "Run analysis".into(),
            enabled: true,
            created_at: Utc::now(),
            last_run: None,
            run_count: 0,
        };
        let json = serde_json::to_string(&config).unwrap();
        let restored: ScheduleConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.cron_expr, "0 9 * * *");
        assert!(restored.enabled);
    }
}