forge_runtime/cron/
scheduler.rs

1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use chrono::{DateTime, Utc};
6use forge_core::observability::{Metric, Span, SpanKind};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use super::registry::CronRegistry;
11use crate::observability::ObservabilityState;
12use forge_core::cron::CronContext;
13
/// Lifecycle state of a cron run, stored as lowercase text in the database.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CronStatus {
    /// Recorded, but execution has not started.
    Pending,
    /// Claimed by a node and currently executing.
    Running,
    /// Finished without error.
    Completed,
    /// Finished with an error.
    Failed,
}

impl CronStatus {
    /// The lowercase string written to the database for this status.
    pub fn as_str(&self) -> &'static str {
        match *self {
            CronStatus::Pending => "pending",
            CronStatus::Running => "running",
            CronStatus::Completed => "completed",
            CronStatus::Failed => "failed",
        }
    }
}

impl FromStr for CronStatus {
    type Err = std::convert::Infallible;

    /// Parse a status previously produced by [`CronStatus::as_str`].
    ///
    /// Matching is exact and case-sensitive. Parsing never fails: any unrecognised
    /// string (including the empty string) falls back to [`CronStatus::Pending`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let known = [CronStatus::Running, CronStatus::Completed, CronStatus::Failed];
        let parsed = known
            .iter()
            .copied()
            .find(|status| status.as_str() == s)
            .unwrap_or(CronStatus::Pending);
        Ok(parsed)
    }
}
52
53/// A cron run record from the database.
54#[derive(Debug, Clone)]
55pub struct CronRecord {
56    /// Run ID.
57    pub id: Uuid,
58    /// Cron name.
59    pub cron_name: String,
60    /// Scheduled time.
61    pub scheduled_time: DateTime<Utc>,
62    /// Timezone.
63    pub timezone: String,
64    /// Current status.
65    pub status: CronStatus,
66    /// Node that executed the cron.
67    pub node_id: Option<Uuid>,
68    /// When execution started.
69    pub started_at: Option<DateTime<Utc>>,
70    /// When execution completed.
71    pub completed_at: Option<DateTime<Utc>>,
72    /// Error message if failed.
73    pub error: Option<String>,
74}
75
76impl CronRecord {
77    /// Create a new pending cron record.
78    pub fn new(
79        cron_name: impl Into<String>,
80        scheduled_time: DateTime<Utc>,
81        timezone: impl Into<String>,
82    ) -> Self {
83        Self {
84            id: Uuid::new_v4(),
85            cron_name: cron_name.into(),
86            scheduled_time,
87            timezone: timezone.into(),
88            status: CronStatus::Pending,
89            node_id: None,
90            started_at: None,
91            completed_at: None,
92            error: None,
93        }
94    }
95}
96
97/// Configuration for the cron runner.
98#[derive(Debug, Clone)]
99pub struct CronRunnerConfig {
100    /// How often to check for due crons.
101    pub poll_interval: Duration,
102    /// Node ID for this runner.
103    pub node_id: Uuid,
104    /// Whether this node is the leader (only leaders run crons).
105    pub is_leader: bool,
106}
107
108impl Default for CronRunnerConfig {
109    fn default() -> Self {
110        Self {
111            poll_interval: Duration::from_secs(1),
112            node_id: Uuid::new_v4(),
113            is_leader: true,
114        }
115    }
116}
117
118/// Cron scheduler and executor.
119pub struct CronRunner {
120    registry: Arc<CronRegistry>,
121    pool: sqlx::PgPool,
122    http_client: reqwest::Client,
123    config: CronRunnerConfig,
124    is_running: Arc<RwLock<bool>>,
125    observability: Option<ObservabilityState>,
126}
127
128impl CronRunner {
129    /// Create a new cron runner.
130    pub fn new(
131        registry: Arc<CronRegistry>,
132        pool: sqlx::PgPool,
133        http_client: reqwest::Client,
134        config: CronRunnerConfig,
135    ) -> Self {
136        Self {
137            registry,
138            pool,
139            http_client,
140            config,
141            is_running: Arc::new(RwLock::new(false)),
142            observability: None,
143        }
144    }
145
146    /// Create a new cron runner with observability.
147    pub fn with_observability(
148        registry: Arc<CronRegistry>,
149        pool: sqlx::PgPool,
150        http_client: reqwest::Client,
151        config: CronRunnerConfig,
152        observability: ObservabilityState,
153    ) -> Self {
154        Self {
155            registry,
156            pool,
157            http_client,
158            config,
159            is_running: Arc::new(RwLock::new(false)),
160            observability: Some(observability),
161        }
162    }
163
164    /// Start the cron runner loop.
165    pub async fn run(&self) -> forge_core::Result<()> {
166        {
167            let mut running = self.is_running.write().await;
168            if *running {
169                return Ok(());
170            }
171            *running = true;
172        }
173
174        tracing::info!("Cron runner starting");
175
176        loop {
177            if !*self.is_running.read().await {
178                break;
179            }
180
181            if self.config.is_leader {
182                if let Err(e) = self.tick().await {
183                    tracing::error!(error = %e, "Cron tick failed");
184                }
185            }
186
187            tokio::time::sleep(self.config.poll_interval).await;
188        }
189
190        tracing::info!("Cron runner stopped");
191        Ok(())
192    }
193
194    /// Stop the cron runner.
195    pub async fn stop(&self) {
196        let mut running = self.is_running.write().await;
197        *running = false;
198    }
199
200    /// Execute one tick of the scheduler.
201    async fn tick(&self) -> forge_core::Result<()> {
202        let now = Utc::now();
203        // Look back 2x poll interval to catch any scheduled times we might have missed
204        let window_start = now
205            - chrono::Duration::from_std(self.config.poll_interval * 2)
206                .unwrap_or(chrono::Duration::seconds(2));
207
208        let cron_list = self.registry.list();
209
210        if cron_list.is_empty() {
211            tracing::debug!("Cron tick: no crons registered");
212        } else {
213            tracing::debug!(
214                cron_count = cron_list.len(),
215                "Cron tick checking {} registered crons",
216                cron_list.len()
217            );
218        }
219
220        for entry in cron_list {
221            let info = &entry.info;
222
223            let scheduled_times = info
224                .schedule
225                .between_in_tz(window_start, now, info.timezone);
226            if scheduled_times.is_empty() {
227                tracing::debug!(
228                    cron = info.name,
229                    schedule = info.schedule.expression(),
230                    "No scheduled runs in window"
231                );
232            } else {
233                tracing::info!(
234                    cron = info.name,
235                    schedule = info.schedule.expression(),
236                    scheduled_count = scheduled_times.len(),
237                    "Found scheduled cron runs"
238                );
239            }
240
241            for scheduled in scheduled_times {
242                // Try to claim this cron run (database ensures exactly-once execution)
243                if let Ok(claimed) = self.try_claim(info.name, scheduled, info.timezone).await {
244                    if claimed {
245                        // Execute the cron
246                        self.execute_cron(entry, scheduled, false).await;
247                    }
248                }
249            }
250
251            // Handle catch-up if enabled
252            if info.catch_up {
253                if let Err(e) = self.handle_catch_up(entry).await {
254                    tracing::warn!(
255                        cron = info.name,
256                        error = %e,
257                        "Failed to process catch-up runs"
258                    );
259                }
260            }
261        }
262
263        Ok(())
264    }
265
266    /// Try to claim a cron run (returns true if claimed successfully).
267    async fn try_claim(
268        &self,
269        cron_name: &str,
270        scheduled_time: DateTime<Utc>,
271        _timezone: &str,
272    ) -> forge_core::Result<bool> {
273        // Insert with ON CONFLICT DO NOTHING to ensure exactly-once execution
274        let result = sqlx::query(
275            r#"
276            INSERT INTO forge_cron_runs (id, cron_name, scheduled_time, status, node_id, started_at)
277            VALUES ($1, $2, $3, 'running', $4, NOW())
278            ON CONFLICT (cron_name, scheduled_time) DO NOTHING
279            "#,
280        )
281        .bind(Uuid::new_v4())
282        .bind(cron_name)
283        .bind(scheduled_time)
284        .bind(self.config.node_id)
285        .execute(&self.pool)
286        .await
287        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
288
289        Ok(result.rows_affected() > 0)
290    }
291
292    /// Execute a cron job.
293    async fn execute_cron(
294        &self,
295        entry: &super::registry::CronEntry,
296        scheduled_time: DateTime<Utc>,
297        is_catch_up: bool,
298    ) {
299        let info = &entry.info;
300        let run_id = Uuid::new_v4();
301        let start = Instant::now();
302
303        tracing::info!(
304            cron = info.name,
305            scheduled_time = %scheduled_time,
306            is_catch_up = is_catch_up,
307            "Executing cron"
308        );
309
310        // Record cron run metric
311        if let Some(ref obs) = self.observability {
312            let mut metric = Metric::counter("cron_runs_total", 1.0);
313            metric
314                .labels
315                .insert("cron_name".to_string(), info.name.to_string());
316            metric
317                .labels
318                .insert("is_catch_up".to_string(), is_catch_up.to_string());
319            obs.record_metric(metric).await;
320        }
321
322        let ctx = CronContext::new(
323            run_id,
324            info.name.to_string(),
325            scheduled_time,
326            info.timezone.to_string(),
327            is_catch_up,
328            self.pool.clone(),
329            self.http_client.clone(),
330        );
331
332        // Execute with timeout
333        let handler = entry.handler.clone();
334        let result = tokio::time::timeout(info.timeout, handler(&ctx)).await;
335        let duration = start.elapsed();
336
337        // Record duration metric
338        if let Some(ref obs) = self.observability {
339            let mut duration_metric =
340                Metric::gauge("cron_duration_seconds", duration.as_secs_f64());
341            duration_metric
342                .labels
343                .insert("cron_name".to_string(), info.name.to_string());
344            obs.record_metric(duration_metric).await;
345        }
346
347        // Record cron execution span
348        if let Some(ref obs) = self.observability {
349            let mut span = Span::new(format!("cron.{}", info.name));
350            span.kind = SpanKind::Internal;
351            span.attributes.insert(
352                "cron.name".to_string(),
353                serde_json::Value::String(info.name.to_string()),
354            );
355            span.attributes.insert(
356                "cron.run_id".to_string(),
357                serde_json::Value::String(run_id.to_string()),
358            );
359            span.attributes.insert(
360                "cron.scheduled_time".to_string(),
361                serde_json::Value::String(scheduled_time.to_rfc3339()),
362            );
363            span.attributes.insert(
364                "cron.is_catch_up".to_string(),
365                serde_json::Value::Bool(is_catch_up),
366            );
367            span.attributes.insert(
368                "cron.duration_ms".to_string(),
369                serde_json::Value::Number(serde_json::Number::from(duration.as_millis() as u64)),
370            );
371
372            match &result {
373                Ok(Ok(())) => {
374                    span.end_ok();
375                }
376                Ok(Err(e)) => {
377                    span.end_error(e.to_string());
378                }
379                Err(_) => {
380                    span.end_error("Cron timed out");
381                }
382            }
383
384            obs.record_span(span).await;
385        }
386
387        match result {
388            Ok(Ok(())) => {
389                tracing::info!(cron = info.name, "Cron completed successfully");
390                self.mark_completed(info.name, scheduled_time).await;
391
392                // Record success metric
393                if let Some(ref obs) = self.observability {
394                    let mut metric = Metric::counter("cron_success_total", 1.0);
395                    metric
396                        .labels
397                        .insert("cron_name".to_string(), info.name.to_string());
398                    obs.record_metric(metric).await;
399                }
400            }
401            Ok(Err(e)) => {
402                tracing::error!(cron = info.name, error = %e, "Cron failed");
403                self.mark_failed(info.name, scheduled_time, &e.to_string())
404                    .await;
405
406                // Record failure metric
407                if let Some(ref obs) = self.observability {
408                    let mut metric = Metric::counter("cron_failures_total", 1.0);
409                    metric
410                        .labels
411                        .insert("cron_name".to_string(), info.name.to_string());
412                    metric
413                        .labels
414                        .insert("reason".to_string(), "error".to_string());
415                    obs.record_metric(metric).await;
416                }
417            }
418            Err(_) => {
419                tracing::error!(cron = info.name, "Cron timed out");
420                self.mark_failed(info.name, scheduled_time, "Execution timed out")
421                    .await;
422
423                // Record timeout metric
424                if let Some(ref obs) = self.observability {
425                    let mut metric = Metric::counter("cron_failures_total", 1.0);
426                    metric
427                        .labels
428                        .insert("cron_name".to_string(), info.name.to_string());
429                    metric
430                        .labels
431                        .insert("reason".to_string(), "timeout".to_string());
432                    obs.record_metric(metric).await;
433                }
434            }
435        }
436    }
437
438    /// Mark a cron run as completed.
439    async fn mark_completed(&self, cron_name: &str, scheduled_time: DateTime<Utc>) {
440        let _ = sqlx::query(
441            r#"
442            UPDATE forge_cron_runs
443            SET status = 'completed', completed_at = NOW()
444            WHERE cron_name = $1 AND scheduled_time = $2
445            "#,
446        )
447        .bind(cron_name)
448        .bind(scheduled_time)
449        .execute(&self.pool)
450        .await;
451    }
452
453    /// Mark a cron run as failed.
454    async fn mark_failed(&self, cron_name: &str, scheduled_time: DateTime<Utc>, error: &str) {
455        let _ = sqlx::query(
456            r#"
457            UPDATE forge_cron_runs
458            SET status = 'failed', completed_at = NOW(), error = $3
459            WHERE cron_name = $1 AND scheduled_time = $2
460            "#,
461        )
462        .bind(cron_name)
463        .bind(scheduled_time)
464        .bind(error)
465        .execute(&self.pool)
466        .await;
467    }
468
469    /// Handle catch-up for missed runs.
470    async fn handle_catch_up(&self, entry: &super::registry::CronEntry) -> forge_core::Result<()> {
471        let info = &entry.info;
472        let now = Utc::now();
473
474        // Find the last completed run
475        let last_run: Option<(DateTime<Utc>,)> = sqlx::query_as(
476            r#"
477            SELECT scheduled_time
478            FROM forge_cron_runs
479            WHERE cron_name = $1 AND status = 'completed'
480            ORDER BY scheduled_time DESC
481            LIMIT 1
482            "#,
483        )
484        .bind(info.name)
485        .fetch_optional(&self.pool)
486        .await
487        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
488
489        let start_time = last_run
490            .map(|(t,)| t)
491            .unwrap_or(now - chrono::Duration::days(1));
492
493        // Get all scheduled times between last run and now
494        let missed_times = info.schedule.between_in_tz(start_time, now, info.timezone);
495
496        // Limit catch-up runs
497        let to_catch_up: Vec<_> = missed_times
498            .into_iter()
499            .take(info.catch_up_limit as usize)
500            .collect();
501
502        for scheduled in to_catch_up {
503            // Try to claim and execute
504            if self.try_claim(info.name, scheduled, info.timezone).await? {
505                self.execute_cron(entry, scheduled, true).await;
506            }
507        }
508
509        Ok(())
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Every status paired with its database spelling.
    const STATUS_TABLE: [(CronStatus, &str); 4] = [
        (CronStatus::Pending, "pending"),
        (CronStatus::Running, "running"),
        (CronStatus::Completed, "completed"),
        (CronStatus::Failed, "failed"),
    ];

    #[test]
    fn test_cron_status_conversion() {
        for &(status, text) in STATUS_TABLE.iter() {
            assert_eq!(status.as_str(), text);
            assert_eq!(text.parse::<CronStatus>(), Ok(status));
        }
    }

    #[test]
    fn test_cron_record_creation() {
        let CronRecord {
            cron_name,
            timezone,
            status,
            node_id,
            ..
        } = CronRecord::new("daily_cleanup", Utc::now(), "UTC");

        assert_eq!(cron_name, "daily_cleanup");
        assert_eq!(timezone, "UTC");
        assert_eq!(status, CronStatus::Pending);
        assert!(node_id.is_none());
    }

    #[test]
    fn test_cron_runner_config_default() {
        let CronRunnerConfig {
            poll_interval,
            is_leader,
            ..
        } = CronRunnerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert!(is_leader);
    }
}