Skip to main content

forge_runtime/cron/
scheduler.rs

1use std::str::FromStr;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use chrono::{DateTime, Utc};
6use forge_core::observability::{Metric, Span, SpanKind};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use super::registry::CronRegistry;
11use crate::observability::ObservabilityState;
12use forge_core::cron::CronContext;
13
/// Lifecycle state of a single cron run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CronStatus {
    /// Recorded but not yet started.
    Pending,
    /// Execution is in progress.
    Running,
    /// Finished without an error.
    Completed,
    /// Finished with an error.
    Failed,
}

impl CronStatus {
    /// Returns the lowercase label used when the status is stored in the database.
    pub fn as_str(&self) -> &'static str {
        match *self {
            CronStatus::Pending => "pending",
            CronStatus::Running => "running",
            CronStatus::Completed => "completed",
            CronStatus::Failed => "failed",
        }
    }
}

impl FromStr for CronStatus {
    type Err = std::convert::Infallible;

    /// Parses a stored label back into a status.
    ///
    /// Parsing never fails: any label that is not one of the known values
    /// resolves to [`CronStatus::Pending`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let status = match s {
            "running" => CronStatus::Running,
            "completed" => CronStatus::Completed,
            "failed" => CronStatus::Failed,
            // "pending" and every unrecognised label both resolve to `Pending`.
            _ => CronStatus::Pending,
        };
        Ok(status)
    }
}
52
53/// A cron run record from the database.
54#[derive(Debug, Clone)]
55pub struct CronRecord {
56    /// Run ID.
57    pub id: Uuid,
58    /// Cron name.
59    pub cron_name: String,
60    /// Scheduled time.
61    pub scheduled_time: DateTime<Utc>,
62    /// Timezone.
63    pub timezone: String,
64    /// Current status.
65    pub status: CronStatus,
66    /// Node that executed the cron.
67    pub node_id: Option<Uuid>,
68    /// When execution started.
69    pub started_at: Option<DateTime<Utc>>,
70    /// When execution completed.
71    pub completed_at: Option<DateTime<Utc>>,
72    /// Error message if failed.
73    pub error: Option<String>,
74}
75
76impl CronRecord {
77    /// Create a new pending cron record.
78    pub fn new(
79        cron_name: impl Into<String>,
80        scheduled_time: DateTime<Utc>,
81        timezone: impl Into<String>,
82    ) -> Self {
83        Self {
84            id: Uuid::new_v4(),
85            cron_name: cron_name.into(),
86            scheduled_time,
87            timezone: timezone.into(),
88            status: CronStatus::Pending,
89            node_id: None,
90            started_at: None,
91            completed_at: None,
92            error: None,
93        }
94    }
95}
96
97/// Configuration for the cron runner.
98#[derive(Debug, Clone)]
99pub struct CronRunnerConfig {
100    /// How often to check for due crons.
101    pub poll_interval: Duration,
102    /// Node ID for this runner.
103    pub node_id: Uuid,
104    /// Whether this node is the leader (only leaders run crons).
105    pub is_leader: bool,
106}
107
108impl Default for CronRunnerConfig {
109    fn default() -> Self {
110        Self {
111            poll_interval: Duration::from_secs(1),
112            node_id: Uuid::new_v4(),
113            is_leader: true,
114        }
115    }
116}
117
118/// Cron scheduler and executor.
119pub struct CronRunner {
120    registry: Arc<CronRegistry>,
121    pool: sqlx::PgPool,
122    http_client: reqwest::Client,
123    config: CronRunnerConfig,
124    is_running: Arc<RwLock<bool>>,
125    observability: Option<ObservabilityState>,
126}
127
128impl CronRunner {
129    /// Create a new cron runner.
130    pub fn new(
131        registry: Arc<CronRegistry>,
132        pool: sqlx::PgPool,
133        http_client: reqwest::Client,
134        config: CronRunnerConfig,
135    ) -> Self {
136        Self {
137            registry,
138            pool,
139            http_client,
140            config,
141            is_running: Arc::new(RwLock::new(false)),
142            observability: None,
143        }
144    }
145
146    /// Create a new cron runner with observability.
147    pub fn with_observability(
148        registry: Arc<CronRegistry>,
149        pool: sqlx::PgPool,
150        http_client: reqwest::Client,
151        config: CronRunnerConfig,
152        observability: ObservabilityState,
153    ) -> Self {
154        Self {
155            registry,
156            pool,
157            http_client,
158            config,
159            is_running: Arc::new(RwLock::new(false)),
160            observability: Some(observability),
161        }
162    }
163
164    /// Start the cron runner loop.
165    pub async fn run(&self) -> forge_core::Result<()> {
166        {
167            let mut running = self.is_running.write().await;
168            if *running {
169                return Ok(());
170            }
171            *running = true;
172        }
173
174        tracing::info!("Cron runner starting");
175
176        loop {
177            if !*self.is_running.read().await {
178                break;
179            }
180
181            if self.config.is_leader {
182                if let Err(e) = self.tick().await {
183                    tracing::error!(error = %e, "Cron tick failed");
184                }
185            }
186
187            tokio::time::sleep(self.config.poll_interval).await;
188        }
189
190        tracing::info!("Cron runner stopped");
191        Ok(())
192    }
193
194    /// Stop the cron runner.
195    pub async fn stop(&self) {
196        let mut running = self.is_running.write().await;
197        *running = false;
198    }
199
200    /// Execute one tick of the scheduler.
201    async fn tick(&self) -> forge_core::Result<()> {
202        let now = Utc::now();
203        // Look back 2x poll interval to catch any scheduled times we might have missed
204        let window_start = now
205            - chrono::Duration::from_std(self.config.poll_interval * 2)
206                .unwrap_or(chrono::Duration::seconds(2));
207
208        let cron_list = self.registry.list();
209
210        if cron_list.is_empty() {
211            tracing::trace!("Cron tick: no crons registered");
212        } else {
213            tracing::trace!(
214                cron_count = cron_list.len(),
215                "Cron tick checking {} registered crons",
216                cron_list.len()
217            );
218        }
219
220        for entry in cron_list {
221            let info = &entry.info;
222
223            let scheduled_times = info
224                .schedule
225                .between_in_tz(window_start, now, info.timezone);
226            if !scheduled_times.is_empty() {
227                tracing::trace!(
228                    cron = info.name,
229                    schedule = info.schedule.expression(),
230                    scheduled_count = scheduled_times.len(),
231                    "Found scheduled cron runs"
232                );
233            }
234
235            for scheduled in scheduled_times {
236                // Try to claim this cron run (database ensures exactly-once execution)
237                if let Ok(claimed) = self.try_claim(info.name, scheduled, info.timezone).await {
238                    if claimed {
239                        // Execute the cron
240                        self.execute_cron(entry, scheduled, false).await;
241                    }
242                }
243            }
244
245            // Handle catch-up if enabled
246            if info.catch_up {
247                if let Err(e) = self.handle_catch_up(entry).await {
248                    tracing::warn!(
249                        cron = info.name,
250                        error = %e,
251                        "Failed to process catch-up runs"
252                    );
253                }
254            }
255        }
256
257        Ok(())
258    }
259
260    /// Try to claim a cron run (returns true if claimed successfully).
261    async fn try_claim(
262        &self,
263        cron_name: &str,
264        scheduled_time: DateTime<Utc>,
265        _timezone: &str,
266    ) -> forge_core::Result<bool> {
267        // Insert with ON CONFLICT DO NOTHING to ensure exactly-once execution
268        let result = sqlx::query(
269            r#"
270            INSERT INTO forge_cron_runs (id, cron_name, scheduled_time, status, node_id, started_at)
271            VALUES ($1, $2, $3, 'running', $4, NOW())
272            ON CONFLICT (cron_name, scheduled_time) DO NOTHING
273            "#,
274        )
275        .bind(Uuid::new_v4())
276        .bind(cron_name)
277        .bind(scheduled_time)
278        .bind(self.config.node_id)
279        .execute(&self.pool)
280        .await
281        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
282
283        Ok(result.rows_affected() > 0)
284    }
285
286    /// Execute a cron job.
287    async fn execute_cron(
288        &self,
289        entry: &super::registry::CronEntry,
290        scheduled_time: DateTime<Utc>,
291        is_catch_up: bool,
292    ) {
293        let info = &entry.info;
294        let run_id = Uuid::new_v4();
295        let start = Instant::now();
296
297        tracing::debug!(
298            cron = info.name,
299            scheduled_time = %scheduled_time,
300            is_catch_up = is_catch_up,
301            "Executing cron"
302        );
303
304        // Record cron run metric
305        if let Some(ref obs) = self.observability {
306            let mut metric = Metric::counter("cron_runs_total", 1.0);
307            metric
308                .labels
309                .insert("cron_name".to_string(), info.name.to_string());
310            metric
311                .labels
312                .insert("is_catch_up".to_string(), is_catch_up.to_string());
313            obs.record_metric(metric).await;
314        }
315
316        let ctx = CronContext::new(
317            run_id,
318            info.name.to_string(),
319            scheduled_time,
320            info.timezone.to_string(),
321            is_catch_up,
322            self.pool.clone(),
323            self.http_client.clone(),
324        );
325
326        // Execute with timeout
327        let handler = entry.handler.clone();
328        let result = tokio::time::timeout(info.timeout, handler(&ctx)).await;
329        let duration = start.elapsed();
330
331        // Record duration metric
332        if let Some(ref obs) = self.observability {
333            let mut duration_metric =
334                Metric::gauge("cron_duration_seconds", duration.as_secs_f64());
335            duration_metric
336                .labels
337                .insert("cron_name".to_string(), info.name.to_string());
338            obs.record_metric(duration_metric).await;
339        }
340
341        // Record cron execution span
342        if let Some(ref obs) = self.observability {
343            let mut span = Span::new(format!("cron.{}", info.name));
344            span.kind = SpanKind::Internal;
345            span.attributes.insert(
346                "cron.name".to_string(),
347                serde_json::Value::String(info.name.to_string()),
348            );
349            span.attributes.insert(
350                "cron.run_id".to_string(),
351                serde_json::Value::String(run_id.to_string()),
352            );
353            span.attributes.insert(
354                "cron.scheduled_time".to_string(),
355                serde_json::Value::String(scheduled_time.to_rfc3339()),
356            );
357            span.attributes.insert(
358                "cron.is_catch_up".to_string(),
359                serde_json::Value::Bool(is_catch_up),
360            );
361            span.attributes.insert(
362                "cron.duration_ms".to_string(),
363                serde_json::Value::Number(serde_json::Number::from(duration.as_millis() as u64)),
364            );
365
366            match &result {
367                Ok(Ok(())) => {
368                    span.end_ok();
369                }
370                Ok(Err(e)) => {
371                    span.end_error(e.to_string());
372                }
373                Err(_) => {
374                    span.end_error("Cron timed out");
375                }
376            }
377
378            obs.record_span(span).await;
379        }
380
381        match result {
382            Ok(Ok(())) => {
383                tracing::info!(
384                    cron = info.name,
385                    scheduled_time = %scheduled_time,
386                    duration_ms = start.elapsed().as_millis(),
387                    "Cron executed"
388                );
389                self.mark_completed(info.name, scheduled_time).await;
390
391                // Record success metric
392                if let Some(ref obs) = self.observability {
393                    let mut metric = Metric::counter("cron_success_total", 1.0);
394                    metric
395                        .labels
396                        .insert("cron_name".to_string(), info.name.to_string());
397                    obs.record_metric(metric).await;
398                }
399            }
400            Ok(Err(e)) => {
401                tracing::error!(cron = info.name, error = %e, "Cron failed");
402                self.mark_failed(info.name, scheduled_time, &e.to_string())
403                    .await;
404
405                // Record failure metric
406                if let Some(ref obs) = self.observability {
407                    let mut metric = Metric::counter("cron_failures_total", 1.0);
408                    metric
409                        .labels
410                        .insert("cron_name".to_string(), info.name.to_string());
411                    metric
412                        .labels
413                        .insert("reason".to_string(), "error".to_string());
414                    obs.record_metric(metric).await;
415                }
416            }
417            Err(_) => {
418                tracing::error!(cron = info.name, "Cron timed out");
419                self.mark_failed(info.name, scheduled_time, "Execution timed out")
420                    .await;
421
422                // Record timeout metric
423                if let Some(ref obs) = self.observability {
424                    let mut metric = Metric::counter("cron_failures_total", 1.0);
425                    metric
426                        .labels
427                        .insert("cron_name".to_string(), info.name.to_string());
428                    metric
429                        .labels
430                        .insert("reason".to_string(), "timeout".to_string());
431                    obs.record_metric(metric).await;
432                }
433            }
434        }
435    }
436
437    /// Mark a cron run as completed.
438    async fn mark_completed(&self, cron_name: &str, scheduled_time: DateTime<Utc>) {
439        let _ = sqlx::query(
440            r#"
441            UPDATE forge_cron_runs
442            SET status = 'completed', completed_at = NOW()
443            WHERE cron_name = $1 AND scheduled_time = $2
444            "#,
445        )
446        .bind(cron_name)
447        .bind(scheduled_time)
448        .execute(&self.pool)
449        .await;
450    }
451
452    /// Mark a cron run as failed.
453    async fn mark_failed(&self, cron_name: &str, scheduled_time: DateTime<Utc>, error: &str) {
454        let _ = sqlx::query(
455            r#"
456            UPDATE forge_cron_runs
457            SET status = 'failed', completed_at = NOW(), error = $3
458            WHERE cron_name = $1 AND scheduled_time = $2
459            "#,
460        )
461        .bind(cron_name)
462        .bind(scheduled_time)
463        .bind(error)
464        .execute(&self.pool)
465        .await;
466    }
467
468    /// Handle catch-up for missed runs.
469    async fn handle_catch_up(&self, entry: &super::registry::CronEntry) -> forge_core::Result<()> {
470        let info = &entry.info;
471        let now = Utc::now();
472
473        // Find the last completed run
474        let last_run: Option<(DateTime<Utc>,)> = sqlx::query_as(
475            r#"
476            SELECT scheduled_time
477            FROM forge_cron_runs
478            WHERE cron_name = $1 AND status = 'completed'
479            ORDER BY scheduled_time DESC
480            LIMIT 1
481            "#,
482        )
483        .bind(info.name)
484        .fetch_optional(&self.pool)
485        .await
486        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
487
488        let start_time = last_run
489            .map(|(t,)| t)
490            .unwrap_or(now - chrono::Duration::days(1));
491
492        // Get all scheduled times between last run and now
493        let missed_times = info.schedule.between_in_tz(start_time, now, info.timezone);
494
495        // Limit catch-up runs
496        let to_catch_up: Vec<_> = missed_times
497            .into_iter()
498            .take(info.catch_up_limit as usize)
499            .collect();
500
501        for scheduled in to_catch_up {
502            // Try to claim and execute
503            if self.try_claim(info.name, scheduled, info.timezone).await? {
504                self.execute_cron(entry, scheduled, true).await;
505            }
506        }
507
508        Ok(())
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Every status maps to its label and parses back from it.
    #[test]
    fn test_cron_status_conversion() {
        let cases = [
            (CronStatus::Pending, "pending"),
            (CronStatus::Running, "running"),
            (CronStatus::Completed, "completed"),
            (CronStatus::Failed, "failed"),
        ];

        for &(status, label) in &cases {
            assert_eq!(status.as_str(), label);
            assert_eq!(label.parse::<CronStatus>(), Ok(status));
        }
    }

    /// A new record starts pending and unassigned.
    #[test]
    fn test_cron_record_creation() {
        let scheduled = Utc::now();
        let record = CronRecord::new("daily_cleanup", scheduled, "UTC");

        assert_eq!(record.cron_name, "daily_cleanup");
        assert_eq!(record.timezone, "UTC");
        assert_eq!(record.status, CronStatus::Pending);
        assert!(record.node_id.is_none());
    }

    /// The default config polls every second on a leader node.
    #[test]
    fn test_cron_runner_config_default() {
        let CronRunnerConfig {
            poll_interval,
            is_leader,
            ..
        } = CronRunnerConfig::default();

        assert_eq!(poll_interval, Duration::from_secs(1));
        assert!(is_leader);
    }
}