use crate::config::{Config, Job, JobExecution, JobStatus};
use crate::error::{Error, Result};
use crate::executor::JobExecutor;
use crate::scheduler::state::{ScheduledJob, SchedulerState};
use crate::scheduler::{JobTrigger, SchedulerEvent, SchedulerMessage};
use chrono::Utc;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time::{self, Duration, Instant};
use tracing::{debug, error, info, warn};
/// A [`JobTrigger`] wrapped so it can live in the scheduler's `BinaryHeap`.
///
/// Equality and ordering consider only `trigger.scheduled_at`. The ordering is
/// deliberately reversed: `BinaryHeap` is a max-heap, so ranking the earliest
/// time as the greatest element makes `peek`/`pop` return the next-due trigger.
#[derive(Debug, Clone)]
struct TimedTrigger {
    /// The queued trigger.
    trigger: JobTrigger,
}

impl Ord for TimedTrigger {
    /// Compares by `scheduled_at` with the operands swapped (earlier is greater).
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = &self.trigger.scheduled_at;
        let theirs = &other.trigger.scheduled_at;
        theirs.cmp(mine)
    }
}

impl PartialOrd for TimedTrigger {
    /// Delegates to [`Ord::cmp`], so the ordering is total.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for TimedTrigger {
    /// Two entries are equal when they are scheduled for the same instant,
    /// regardless of which job they belong to.
    fn eq(&self, other: &Self) -> bool {
        let mine = &self.trigger.scheduled_at;
        let theirs = &other.trigger.scheduled_at;
        mine == theirs
    }
}

impl Eq for TimedTrigger {}
/// The scheduler: owns the queue of pending job triggers, reacts to
/// `SchedulerMessage`s, and spawns job executions.
///
/// Created with `Scheduler::new` and driven by `Scheduler::run`.
pub struct Scheduler {
/// Active configuration; replaced wholesale by `reload_config`.
config: Arc<RwLock<Config>>,
/// File that `reload_config` re-reads the configuration from.
config_path: PathBuf,
/// Per-job runtime state, shared with the spawned job tasks and cloned out
/// for `GetStatus` requests.
pub(crate) state: Arc<RwLock<SchedulerState>>,
/// Runs the jobs; its shell settings are refreshed on config reload.
executor: Arc<JobExecutor>,
/// Receiving end of the control channel, polled by `run`.
message_rx: mpsc::Receiver<SchedulerMessage>,
/// Clone of the sender handed to the `SchedulerHandle`; never read by the
/// scheduler itself, hence the `dead_code` allowance.
// NOTE(review): presumably retained so the control channel cannot close
// while the scheduler is alive — confirm.
#[allow(dead_code)]
message_tx: mpsc::Sender<SchedulerMessage>,
/// Broadcast channel for `SchedulerEvent`s; send errors are ignored.
event_tx: broadcast::Sender<SchedulerEvent>,
/// Pending triggers; `TimedTrigger`'s reversed ordering puts the earliest
/// one at the top of the heap.
trigger_queue: BinaryHeap<TimedTrigger>,
/// Each spawned job task acquires one permit before executing; sized from
/// `settings.max_concurrent_jobs` in `new`.
job_semaphore: Arc<tokio::sync::Semaphore>,
/// Set when a `Shutdown` message is handled; checked by the `run` loop.
shutdown: bool,
}
impl Scheduler {
/// Builds a scheduler from `config` together with the [`SchedulerHandle`]
/// used to control it.
///
/// `config_path` is remembered so the configuration can be re-read on reload.
/// `settings.max_concurrent_jobs` sizes the job semaphore, with `0` meaning
/// "no limit". Nothing is scheduled until [`Scheduler::run`] is called.
pub fn new(config: Config, config_path: PathBuf) -> (Self, SchedulerHandle) {
    // tokio's `Semaphore::new` panics when asked for more than its
    // `MAX_PERMITS` (`usize::MAX >> 3`), so "unlimited" must be that ceiling
    // rather than `usize::MAX`. The value is spelled out instead of using the
    // associated constant so older tokio releases still compile.
    const MAX_JOB_PERMITS: usize = usize::MAX >> 3;

    let (message_tx, message_rx) = mpsc::channel(100);
    let (event_tx, _) = broadcast::channel(100);

    let permits = match config.settings.max_concurrent_jobs {
        0 => MAX_JOB_PERMITS,
        // Also clamp absurdly large configured limits instead of panicking.
        limit => limit.min(MAX_JOB_PERMITS),
    };
    let job_semaphore = Arc::new(tokio::sync::Semaphore::new(permits));

    let executor = Arc::new(JobExecutor::new(
        config.settings.shell.clone(),
        config.settings.shell_args.clone(),
    ));

    let scheduler = Self {
        config: Arc::new(RwLock::new(config)),
        config_path,
        state: Arc::new(RwLock::new(SchedulerState::new())),
        executor,
        message_rx,
        message_tx: message_tx.clone(),
        event_tx: event_tx.clone(),
        trigger_queue: BinaryHeap::new(),
        job_semaphore,
        shutdown: false,
    };
    let handle = SchedulerHandle {
        message_tx,
        event_tx,
    };
    (scheduler, handle)
}
/// Registers every enabled job from the current configuration in the shared
/// state and queues a trigger for each job that has a next run time.
///
/// Always returns `Ok(())`.
async fn initialize(&mut self) -> Result<()> {
    let cfg = self.config.read().await;
    let mut st = self.state.write().await;

    for (job_name, job_def) in cfg.enabled_jobs() {
        let upcoming = job_def.next_run();

        // Jobs without a next run time are tracked in state but not queued.
        if let Some(due_at) = upcoming {
            let trigger = JobTrigger::new(job_name.clone(), due_at, job_def.clone());
            self.trigger_queue.push(TimedTrigger { trigger });
        }

        let entry = ScheduledJob::new(
            job_name.clone(),
            job_def.schedule.clone(),
            job_def.enabled,
        )
        .with_next_run(upcoming);
        st.add_job(entry);
    }

    info!(
        "Scheduler initialized with {} jobs ({} enabled)",
        st.total_jobs, st.enabled_jobs
    );
    Ok(())
}
/// Triggers every enabled job whose `run_on_startup` flag is set.
///
/// The job names are collected first so the configuration lock is released
/// before any job is triggered. Stops at, and returns, the first error from
/// `trigger_job`.
async fn run_startup_jobs(&mut self) -> Result<()> {
    let pending: Vec<_> = {
        let cfg = self.config.read().await;
        let mut names = Vec::new();
        for (name, job) in cfg.enabled_jobs() {
            if job.run_on_startup {
                names.push(name.clone());
            }
        }
        names
    };

    for job_name in pending {
        info!(job = %job_name, "Running startup job");
        self.trigger_job(&job_name).await?;
    }
    Ok(())
}
/// Runs the scheduler until a shutdown is requested.
///
/// Initializes the job table, emits `SchedulerEvent::Started`, runs the
/// startup jobs, then loops: sleep until the next trigger is due (or a
/// message arrives), process whichever happened first, repeat. Emits
/// `SchedulerEvent::Stopped` on the way out.
///
/// # Errors
///
/// Propagates errors from initialization, startup jobs, trigger processing
/// and message handling.
pub async fn run(mut self) -> Result<()> {
    self.initialize().await?;
    let _ = self.event_tx.send(SchedulerEvent::Started);
    self.run_startup_jobs().await?;
    info!("Scheduler started");

    loop {
        let wait = self.calculate_sleep_duration();

        // Each branch evaluates to "should the loop stop?".
        let stop_requested = tokio::select! {
            _ = time::sleep(wait) => {
                self.process_due_triggers().await?;
                false
            }
            Some(msg) = self.message_rx.recv() => {
                self.handle_message(msg).await?
            }
        };

        if stop_requested || self.shutdown {
            break;
        }
    }

    info!("Scheduler stopped");
    let _ = self.event_tx.send(SchedulerEvent::Stopped);
    Ok(())
}
/// Returns how long the run loop should sleep before checking triggers again.
///
/// With an empty queue this is a fixed 60 seconds; otherwise it is the time
/// until the earliest queued trigger is due, floored at zero for triggers
/// that are already overdue.
fn calculate_sleep_duration(&self) -> Duration {
    match self.trigger_queue.peek() {
        None => Duration::from_secs(60),
        Some(head) => {
            let remaining_ms = head.trigger.ms_until_due();
            let clamped = if remaining_ms > 0 {
                remaining_ms as u64
            } else {
                0
            };
            Duration::from_millis(clamped)
        }
    }
}
/// Pops and executes every queued trigger scheduled at or before "now".
///
/// "Now" is sampled once on entry, so triggers that `execute_trigger`
/// re-queues for a later time are left for a subsequent call.
async fn process_due_triggers(&mut self) -> Result<()> {
    let cutoff = Utc::now();

    while self
        .trigger_queue
        .peek()
        .map_or(false, |head| head.trigger.scheduled_at <= cutoff)
    {
        // The peek above guarantees the heap is non-empty here.
        if let Some(due) = self.trigger_queue.pop() {
            self.execute_trigger(due.trigger).await?;
        }
    }
    Ok(())
}
/// Handles one due trigger: re-queues the job's next occurrence, then starts
/// the job unless the shared state already marks it as running.
///
/// The next occurrence is queued *before* the running check, so a skipped
/// run does not drop the job from the schedule.
async fn execute_trigger(&mut self, trigger: JobTrigger) -> Result<()> {
    let job_name = trigger.job_name.clone();
    let job = trigger.job.clone();

    match job.next_run() {
        // Only re-queue when the next occurrence is strictly later than the
        // one being handled.
        Some(following) if following > trigger.scheduled_at => {
            let requeued = JobTrigger::new(job_name.clone(), following, job.clone());
            self.trigger_queue.push(TimedTrigger { trigger: requeued });

            let mut st = self.state.write().await;
            if let Some(entry) = st.get_job_mut(&job_name) {
                entry.next_run = Some(following);
            }
        }
        _ => {}
    }

    let already_running = self
        .state
        .read()
        .await
        .get_job(&job_name)
        .map_or(false, |entry| entry.is_running);
    if already_running {
        debug!(job = %job_name, "Job already running, skipping");
        return Ok(());
    }

    self.spawn_job_execution(job_name, job).await
}
async fn spawn_job_execution(&self, job_name: String, job: Job) -> Result<()> {
let executor = Arc::clone(&self.executor);
let state = Arc::clone(&self.state);
let event_tx = self.event_tx.clone();
let semaphore = Arc::clone(&self.job_semaphore);
let history_size = self.config.read().await.settings.history_size;
tokio::spawn(async move {
let _permit = match semaphore.acquire().await {
Ok(p) => p,
Err(_) => {
error!(job = %job_name, "Failed to acquire job semaphore");
return;
}
};
let mut execution = JobExecution::new(&job_name);
let execution_id = execution.id;
{
let mut s = state.write().await;
s.record_job_start(&job_name, execution_id);
}
let _ = event_tx.send(SchedulerEvent::JobStarting {
job_name: job_name.clone(),
execution_id,
});
info!(job = %job_name, execution_id = %execution_id, "Starting job");
let start = Instant::now();
let result = executor.execute(&job_name, &job).await;
let duration = start.elapsed();
let (status, next_run) = match result {
Ok((exit_code, stdout, stderr)) => {
if exit_code == 0 {
execution.complete_success(exit_code, stdout, stderr);
info!(
job = %job_name,
duration_ms = %duration.as_millis(),
"Job completed successfully"
);
(JobStatus::Success, job.next_run())
} else {
let error = format!("Exit code: {}", exit_code);
execution.complete_failed(error.clone(), Some(exit_code), stdout, stderr);
warn!(
job = %job_name,
exit_code = %exit_code,
"Job failed"
);
(JobStatus::Failed { error }, job.next_run())
}
}
Err(Error::JobTimeout { .. }) => {
execution.complete_timeout();
warn!(job = %job_name, "Job timed out");
(JobStatus::Timeout, job.next_run())
}
Err(e) => {
let error = e.to_string();
execution.complete_failed(error.clone(), None, String::new(), String::new());
error!(job = %job_name, error = %e, "Job execution error");
(JobStatus::Failed { error }, job.next_run())
}
};
let success = matches!(status, JobStatus::Success);
{
let mut s = state.write().await;
s.record_job_completion(&job_name, status, execution, next_run, history_size);
}
let _ = event_tx.send(SchedulerEvent::JobCompleted {
job_name,
execution_id,
success,
duration_ms: duration.as_millis() as u64,
});
});
Ok(())
}
/// Starts `job_name` immediately, outside its schedule.
///
/// The job is looked up with `Config::get_job`; no "already running" check is
/// made here.
///
/// # Errors
///
/// Returns `Error::job_not_found(job_name)` when the configuration has no
/// such job.
async fn trigger_job(&mut self, job_name: &str) -> Result<()> {
    // The read guard is a temporary, so it is released at the end of this
    // statement, before the job is spawned.
    let job = self
        .config
        .read()
        .await
        .get_job(job_name)
        .ok_or_else(|| Error::job_not_found(job_name))?
        .clone();

    info!(job = %job_name, "Manually triggering job");
    self.spawn_job_execution(job_name.to_string(), job).await
}
/// Re-reads the configuration file and rebuilds the schedule from it.
///
/// On success the executor's shell settings are updated, the trigger queue
/// and job table are rebuilt from the enabled jobs of the new configuration
/// (jobs currently running keep their running flag and execution id), the
/// stored configuration is replaced, and `ConfigReloaded` is emitted.
///
/// # Errors
///
/// Returns the error from `Config::from_file`; in that case nothing has been
/// modified.
// NOTE(review): `job_semaphore` is not touched here, so a changed
// `max_concurrent_jobs` does not take effect on reload — confirm intended.
async fn reload_config(&mut self) -> Result<()> {
    info!("Reloading configuration from {:?}", self.config_path);
    let new_config = Config::from_file(&self.config_path)?;

    self.executor.update_shell(
        new_config.settings.shell.clone(),
        new_config.settings.shell_args.clone(),
    );
    self.trigger_queue.clear();

    let job_count = {
        let mut state = self.state.write().await;

        // Remember in-flight executions before the table is wiped.
        let mut still_running = Vec::new();
        for (job_name, entry) in state.jobs.iter() {
            if entry.is_running {
                still_running.push((job_name.clone(), entry.current_execution_id));
            }
        }

        state.jobs.clear();
        state.total_jobs = 0;
        state.enabled_jobs = 0;

        for (name, job) in new_config.enabled_jobs() {
            let upcoming = job.next_run();
            let mut entry = ScheduledJob::new(name.clone(), job.schedule.clone(), job.enabled)
                .with_next_run(upcoming);

            if let Some((_, exec_id)) = still_running.iter().find(|(n, _)| n == name) {
                entry.is_running = true;
                entry.current_execution_id = *exec_id;
            }
            state.add_job(entry);

            if let Some(due_at) = upcoming {
                let trigger = JobTrigger::new(name.clone(), due_at, job.clone());
                self.trigger_queue.push(TimedTrigger { trigger });
            }
        }

        state.enabled_jobs
    };

    *self.config.write().await = new_config;

    info!("Configuration reloaded: {} enabled jobs", job_count);
    let _ = self
        .event_tx
        .send(SchedulerEvent::ConfigReloaded { job_count });
    Ok(())
}
/// Dispatches one control message.
///
/// Returns `Ok(true)` only for `Shutdown` (after setting the shutdown flag);
/// every other message yields `Ok(false)`. Failures while triggering a job or
/// reloading the configuration are logged, not returned.
async fn handle_message(&mut self, msg: SchedulerMessage) -> Result<bool> {
    let stop = match msg {
        SchedulerMessage::Shutdown => {
            info!("Shutdown requested");
            self.shutdown = true;
            true
        }
        SchedulerMessage::TriggerJob { job_name } => {
            if let Err(e) = self.trigger_job(&job_name).await {
                warn!(job = %job_name, error = %e, "Failed to trigger job");
            }
            false
        }
        SchedulerMessage::ReloadConfig => {
            if let Err(e) = self.reload_config().await {
                error!(error = %e, "Failed to reload configuration");
            }
            false
        }
        SchedulerMessage::GetStatus { response_tx } => {
            // Reply with a snapshot; a dropped receiver is not an error.
            let mut snapshot = self.state.read().await.clone();
            snapshot.update_time();
            let _ = response_tx.send(snapshot);
            false
        }
        SchedulerMessage::StopJob { job_name } => {
            warn!(job = %job_name, "Job cancellation not yet implemented");
            false
        }
    };
    Ok(stop)
}
}
/// Cloneable handle for controlling a `Scheduler` from other tasks: sends
/// `SchedulerMessage`s to it and subscribes to its `SchedulerEvent`s.
#[derive(Clone)]
pub struct SchedulerHandle {
/// Sending end of the scheduler's control channel.
message_tx: mpsc::Sender<SchedulerMessage>,
/// Broadcast sender used to create new event subscriptions.
event_tx: broadcast::Sender<SchedulerEvent>,
}
impl SchedulerHandle {
pub async fn send(&self, msg: SchedulerMessage) -> Result<()> {
self.message_tx
.send(msg)
.await
.map_err(|_| Error::ChannelSend)
}
pub async fn trigger_job(&self, job_name: impl Into<String>) -> Result<()> {
self.send(SchedulerMessage::TriggerJob {
job_name: job_name.into(),
})
.await
}
pub async fn reload_config(&self) -> Result<()> {
self.send(SchedulerMessage::ReloadConfig).await
}
pub async fn get_status(&self) -> Result<SchedulerState> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.send(SchedulerMessage::GetStatus { response_tx: tx })
.await?;
rx.await.map_err(|_| Error::ChannelSend)
}
pub async fn stop_job(&self, job_name: impl Into<String>) -> Result<()> {
self.send(SchedulerMessage::StopJob {
job_name: job_name.into(),
})
.await
}
pub async fn shutdown(&self) -> Result<()> {
self.send(SchedulerMessage::Shutdown).await
}
pub fn subscribe(&self) -> broadcast::Receiver<SchedulerEvent> {
self.event_tx.subscribe()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration: a concurrency limit of two and one job named
    /// `test`. The TOML body is kept flush-left so its bytes stay unchanged.
    fn test_config() -> Config {
        Config::from_str(
            r#"
[settings]
max_concurrent_jobs = 2
[jobs.test]
schedule = "* * * * *"
command = "echo test"
"#,
            "test.toml",
        )
        .unwrap()
    }

    /// A handle can queue a message while the scheduler is alive, even though
    /// the scheduler is not running.
    #[tokio::test]
    async fn test_scheduler_creation() {
        let (_scheduler, handle) = Scheduler::new(test_config(), PathBuf::from("test.toml"));
        let sent = handle.trigger_job("test").await;
        assert!(sent.is_ok());
    }

    /// Initialization registers the single configured job as enabled.
    #[tokio::test]
    async fn test_scheduler_status() {
        let (mut scheduler, _handle) = Scheduler::new(test_config(), PathBuf::from("test.toml"));
        scheduler.initialize().await.unwrap();

        let snapshot = scheduler.state.read().await;
        assert_eq!(snapshot.total_jobs, 1);
        assert_eq!(snapshot.enabled_jobs, 1);
    }
}