Skip to main content

actionqueue_core/task/
run_policy.rs

//! Run policy definitions for task execution scheduling.
2
/// Failure produced while building or adjusting a [`CronPolicy`].
///
/// The error carries a single human-readable description of what was
/// rejected; read it through [`CronPolicyError::message`].
#[cfg(feature = "workflow")]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CronPolicyError {
    /// Human-readable explanation of the rejection.
    message: String,
}
9
#[cfg(feature = "workflow")]
impl CronPolicyError {
    /// Builds an error from anything convertible into an owned message.
    ///
    /// Private: errors are only produced by the cron policy code in this module.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Borrows the human-readable description carried by this error.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}
21
#[cfg(feature = "workflow")]
impl std::fmt::Display for CronPolicyError {
    /// Renders the error as `invalid cron policy: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed prefix first, then the stored description verbatim.
        f.write_str("invalid cron policy: ")?;
        f.write_str(&self.message)
    }
}

// Marker impl: the error has no underlying source, so the defaults suffice.
#[cfg(feature = "workflow")]
impl std::error::Error for CronPolicyError {}
31
/// A cron scheduling policy whose expression has already been validated.
///
/// Instances are created through [`CronPolicy::new`], which parses the
/// expression up front. The expression uses the 7-field layout accepted by
/// the `cron` crate: `sec min hour dom month dow year`.
///
/// Examples:
/// - `"0 * * * * * *"` — every minute at second 0
/// - `"0 0 * * * * *"` — every hour at minute 0
/// - `"0 0 9 * * MON *"` — every Monday at 09:00 UTC
#[cfg(feature = "workflow")]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CronPolicy {
    /// The raw cron expression, stored exactly as supplied.
    expression: String,
    /// Optional cap on derived occurrences; `None` means unlimited and is
    /// omitted from serialized output.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Option::is_none"))]
    max_occurrences: Option<u32>,
}
50
#[cfg(feature = "workflow")]
impl CronPolicy {
    /// Creates a validated cron policy from a cron expression string.
    ///
    /// The expression is parsed immediately and only kept if parsing
    /// succeeds; the resulting policy has no occurrence cap.
    ///
    /// # Errors
    ///
    /// Returns [`CronPolicyError`] (carrying the parser's message) if the
    /// expression cannot be parsed.
    pub fn new(expression: impl Into<String>) -> Result<Self, CronPolicyError> {
        use std::str::FromStr as _;
        let expression = expression.into();
        // Parse purely for validation; the parsed schedule is discarded.
        cron::Schedule::from_str(&expression).map_err(|e| CronPolicyError::new(e.to_string()))?;
        Ok(Self { expression, max_occurrences: None })
    }

    /// Sets a maximum number of occurrences to ever derive for this task.
    ///
    /// A policy without a cap (`None`, the default) is unlimited.
    ///
    /// # Errors
    ///
    /// Returns [`CronPolicyError`] if `max` is zero (a cron task that never
    /// derives runs is invalid, matching the `RepeatPolicy::new(0, _)` rejection).
    pub fn with_max_occurrences(mut self, max: u32) -> Result<Self, CronPolicyError> {
        if max == 0 {
            return Err(CronPolicyError::new("max_occurrences must be at least 1"));
        }
        self.max_occurrences = Some(max);
        Ok(self)
    }

    /// Returns the raw cron expression string.
    pub fn expression(&self) -> &str {
        &self.expression
    }

    /// Returns the maximum number of occurrences, if configured.
    pub fn max_occurrences(&self) -> Option<u32> {
        self.max_occurrences
    }

    /// Returns up to `count` occurrence timestamps (Unix seconds, UTC)
    /// that are strictly after `after_secs`.
    ///
    /// Fewer than `count` items (possibly none) are returned when the
    /// schedule has fewer remaining occurrences, or when `after_secs` lies
    /// beyond the range of instants `chrono` can represent. This method does
    /// not consult `max_occurrences`; enforcing that cap is left to the
    /// caller.
    ///
    /// # Panics
    ///
    /// Panics only if the stored expression no longer parses, which would
    /// violate the invariant established by [`CronPolicy::new`].
    pub fn next_occurrences_after(&self, after_secs: u64, count: usize) -> Vec<u64> {
        use std::str::FromStr as _;

        use chrono::{TimeZone as _, Utc};

        if count == 0 {
            return Vec::new();
        }

        // Every occurrence is produced as a chrono `DateTime`, so an instant
        // chrono cannot represent is later than any possible occurrence.
        // Falling back to the epoch here would return timestamps that are
        // *before* `after_secs`, breaking the "strictly after" contract.
        let after_dt = match i64::try_from(after_secs)
            .ok()
            .and_then(|ts| Utc.timestamp_opt(ts, 0).single())
        {
            Some(dt) => dt,
            None => return Vec::new(),
        };

        // perf: re-parses expression on each call; acceptable for alpha.
        // cron::Schedule is not Clone/Serialize, making OnceCell caching impractical.
        let schedule = cron::Schedule::from_str(&self.expression)
            .expect("cron expression pre-validated at CronPolicy::new");

        schedule
            .after(&after_dt)
            .take(count)
            // Occurrences follow a non-negative instant, so the conversion is
            // expected to succeed; a failure drops the item rather than panicking.
            .filter_map(|dt| u64::try_from(dt.timestamp()).ok())
            .collect()
    }
}
129
#[cfg(all(feature = "serde", feature = "workflow"))]
impl<'de> serde::Deserialize<'de> for CronPolicy {
    /// Deserializes through an unvalidated mirror struct and then re-runs the
    /// constructor checks, so an invalid expression or a zero
    /// `max_occurrences` is rejected with a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Raw field layout as it appears on the wire, before validation.
        #[derive(serde::Deserialize)]
        struct CronPolicyWire {
            expression: String,
            #[serde(default)]
            max_occurrences: Option<u32>,
        }

        // Fully-qualified call: matches the `RunPolicy` impl and does not rely
        // on the `Deserialize` trait being imported into scope.
        let CronPolicyWire { expression, max_occurrences } =
            <CronPolicyWire as serde::Deserialize>::deserialize(deserializer)?;

        let policy = CronPolicy::new(expression).map_err(serde::de::Error::custom)?;
        match max_occurrences {
            Some(max) => policy.with_max_occurrences(max).map_err(serde::de::Error::custom),
            None => Ok(policy),
        }
    }
}
151
/// Validation failures that can arise when building a [`RunPolicy`].
///
/// Each variant records the offending value so callers can report it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RunPolicyError {
    /// The repeat `count` was zero; it must be at least one.
    InvalidRepeatCount {
        /// The `count` value that was rejected.
        count: u32,
    },
    /// The repeat `interval_secs` was zero; it must be at least one.
    InvalidRepeatIntervalSecs {
        /// The `interval_secs` value that was rejected.
        interval_secs: u64,
    },
}
166
167impl std::fmt::Display for RunPolicyError {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        match self {
170            RunPolicyError::InvalidRepeatCount { count } => {
171                write!(f, "invalid repeat count: {count} (must be >= 1)")
172            }
173            RunPolicyError::InvalidRepeatIntervalSecs { interval_secs } => {
174                write!(f, "invalid repeat interval_secs: {interval_secs} (must be >= 1)")
175            }
176        }
177    }
178}
179
180impl std::error::Error for RunPolicyError {}
181
/// A repeat schedule: how many runs to derive and how far apart they are.
///
/// Values built through [`RepeatPolicy::new`] always have a `count` and an
/// `interval_secs` that are strictly positive.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RepeatPolicy {
    /// Total number of runs to derive.
    count: u32,
    /// Seconds between consecutive derived runs.
    interval_secs: u64,
}
191
192impl RepeatPolicy {
193    /// Creates a validated repeat policy.
194    ///
195    /// # Errors
196    ///
197    /// Returns [`RunPolicyError::InvalidRepeatCount`] if `count` is zero.
198    /// Returns [`RunPolicyError::InvalidRepeatIntervalSecs`] if `interval_secs` is zero.
199    pub fn new(count: u32, interval_secs: u64) -> Result<Self, RunPolicyError> {
200        if count == 0 {
201            return Err(RunPolicyError::InvalidRepeatCount { count });
202        }
203        if interval_secs == 0 {
204            return Err(RunPolicyError::InvalidRepeatIntervalSecs { interval_secs });
205        }
206        Ok(Self { count, interval_secs })
207    }
208
209    /// Returns the total number of runs to derive.
210    pub fn count(&self) -> u32 {
211        self.count
212    }
213
214    /// Returns the interval in seconds between derived runs.
215    pub fn interval_secs(&self) -> u64 {
216        self.interval_secs
217    }
218}
219
220/// A policy that defines how many times a task should be run.
221#[derive(Debug, Clone, PartialEq, Eq)]
222#[cfg_attr(feature = "serde", derive(serde::Serialize))]
223pub enum RunPolicy {
224    /// Run exactly once.
225    Once,
226    /// Run a specific number of times at a fixed interval.
227    Repeat(RepeatPolicy),
228    /// Run on a cron schedule (rolling window derivation, UTC only).
229    #[cfg(feature = "workflow")]
230    Cron(CronPolicy),
231}
232
233impl RunPolicy {
234    /// Constructs a validated [`RunPolicy::Repeat`] policy.
235    ///
236    /// This is the canonical constructor for repeat policies and enforces
237    /// the contract requirement that both `count` and `interval_secs`
238    /// are strictly positive.
239    pub fn repeat(count: u32, interval_secs: u64) -> Result<Self, RunPolicyError> {
240        Ok(Self::Repeat(RepeatPolicy::new(count, interval_secs)?))
241    }
242
243    /// Constructs a validated [`RunPolicy::Cron`] policy.
244    ///
245    /// # Errors
246    ///
247    /// Returns [`CronPolicyError`] if the cron expression cannot be parsed.
248    #[cfg(feature = "workflow")]
249    pub fn cron(expression: impl Into<String>) -> Result<Self, CronPolicyError> {
250        Ok(Self::Cron(CronPolicy::new(expression)?))
251    }
252
253    /// Validates this run policy against contract invariants.
254    ///
255    /// For [`RunPolicy::Repeat`], validation is guaranteed at construction time
256    /// by [`RepeatPolicy::new`], so this method always returns `Ok(())`.
257    pub fn validate(&self) -> Result<(), RunPolicyError> {
258        // RepeatPolicy invariants are enforced at construction time.
259        // CronPolicy expression is validated at construction time.
260        // No additional validation needed for any variant.
261        Ok(())
262    }
263}
264
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for RepeatPolicy {
    /// Deserializes through an unvalidated mirror struct and then applies
    /// [`RepeatPolicy::new`], so zero values are rejected with a custom
    /// deserialization error instead of producing an invalid policy.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Raw field layout as it appears on the wire, before validation.
        #[derive(serde::Deserialize)]
        struct RepeatPolicyWire {
            count: u32,
            interval_secs: u64,
        }

        // Fully-qualified call: matches the `RunPolicy` impl and does not rely
        // on the `Deserialize` trait being imported into scope.
        let RepeatPolicyWire { count, interval_secs } =
            <RepeatPolicyWire as serde::Deserialize>::deserialize(deserializer)?;

        RepeatPolicy::new(count, interval_secs).map_err(serde::de::Error::custom)
    }
}
281
#[cfg(feature = "serde")]
impl serde::Serialize for RepeatPolicy {
    /// Serializes as a two-field struct named `RepeatPolicy`, emitting
    /// `count` followed by `interval_secs`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        let Self { count, interval_secs } = self;
        let mut fields = serializer.serialize_struct("RepeatPolicy", 2)?;
        fields.serialize_field("count", count)?;
        fields.serialize_field("interval_secs", interval_secs)?;
        fields.end()
    }
}
295
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for RunPolicy {
    /// Deserializes through an unvalidated mirror enum and then rebuilds each
    /// variant with its validating constructor, surfacing any validation
    /// failure as a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Raw variant layout as it appears on the wire, before validation.
        #[derive(serde::Deserialize)]
        enum RunPolicyWire {
            Once,
            Repeat {
                count: u32,
                interval_secs: u64,
            },
            #[cfg(feature = "workflow")]
            Cron {
                expression: String,
                #[serde(default)]
                max_occurrences: Option<u32>,
            },
        }

        match <RunPolicyWire as serde::Deserialize>::deserialize(deserializer)? {
            RunPolicyWire::Once => Ok(RunPolicy::Once),
            RunPolicyWire::Repeat { count, interval_secs } => {
                RepeatPolicy::new(count, interval_secs)
                    .map(RunPolicy::Repeat)
                    .map_err(serde::de::Error::custom)
            }
            #[cfg(feature = "workflow")]
            RunPolicyWire::Cron { expression, max_occurrences } => {
                // NOTE: keep in sync with CronPolicy Deserialize impl.
                let base = CronPolicy::new(expression).map_err(serde::de::Error::custom)?;
                let policy = match max_occurrences {
                    Some(max) => {
                        base.with_max_occurrences(max).map_err(serde::de::Error::custom)?
                    }
                    None => base,
                };
                Ok(RunPolicy::Cron(policy))
            }
        }
    }
}
337
#[cfg(test)]
mod tests {
    use super::{RepeatPolicy, RunPolicy, RunPolicyError};

    /// A zero `count` is rejected by the policy constructor.
    #[test]
    fn repeat_policy_rejects_zero_count() {
        let expected = Err(RunPolicyError::InvalidRepeatCount { count: 0 });
        assert_eq!(RepeatPolicy::new(0, 60), expected);
    }

    /// A zero `interval_secs` is rejected by the policy constructor.
    #[test]
    fn repeat_policy_rejects_zero_interval() {
        let expected = Err(RunPolicyError::InvalidRepeatIntervalSecs { interval_secs: 0 });
        assert_eq!(RepeatPolicy::new(3, 0), expected);
    }

    /// Positive values are accepted and stored unchanged.
    #[test]
    fn repeat_policy_accepts_valid_values() {
        let policy = RepeatPolicy::new(6, 1800).expect("repeat policy should be valid");
        assert_eq!(policy.count(), 6);
        assert_eq!(policy.interval_secs(), 1800);
    }

    /// The `RunPolicy` convenience constructor reports a zero `count`.
    #[test]
    fn repeat_constructor_rejects_zero_count() {
        let expected = Err(RunPolicyError::InvalidRepeatCount { count: 0 });
        assert_eq!(RunPolicy::repeat(0, 60), expected);
    }

    /// The `RunPolicy` convenience constructor reports a zero interval.
    #[test]
    fn repeat_constructor_rejects_zero_interval() {
        let expected = Err(RunPolicyError::InvalidRepeatIntervalSecs { interval_secs: 0 });
        assert_eq!(RunPolicy::repeat(3, 0), expected);
    }

    /// The convenience constructor wraps the validated policy in `Repeat`.
    #[test]
    fn repeat_constructor_accepts_valid_values() {
        let built = RunPolicy::repeat(6, 1800).expect("repeat policy should be valid");
        let inner = RepeatPolicy::new(6, 1800).unwrap();
        assert_eq!(built, RunPolicy::Repeat(inner));
    }

    /// `validate` succeeds for every policy that could be constructed.
    #[test]
    fn validate_always_succeeds_for_valid_policies() {
        let once = RunPolicy::Once;
        let repeat = RunPolicy::repeat(3, 60).unwrap();
        assert!(once.validate().is_ok());
        assert!(repeat.validate().is_ok());
    }

    /// Accessors return exactly what was passed to the constructor.
    #[test]
    fn repeat_policy_accessors() {
        let policy = RepeatPolicy::new(5, 120).unwrap();
        assert_eq!(policy.count(), 5);
        assert_eq!(policy.interval_secs(), 120);
    }

    /// `u64::MAX` does not fit in an `i64` timestamp; the call must not panic.
    #[cfg(feature = "workflow")]
    #[test]
    fn cron_next_occurrences_after_with_u64_max_returns_empty() {
        let policy = super::CronPolicy::new("* * * * * * *").expect("valid");
        let occurrences = policy.next_occurrences_after(u64::MAX, 5);
        // Deliberately permissive: accepts an empty result or any timestamps
        // that fit in an i64.
        // NOTE(review): consider tightening this to `is_empty()` once the
        // intended behaviour for out-of-range input is pinned down.
        let all_fit_in_i64 = occurrences.iter().all(|&ts| ts <= i64::MAX as u64);
        assert!(occurrences.is_empty() || all_fit_in_i64);
    }

    /// A zero occurrence cap is rejected with a descriptive message.
    #[cfg(feature = "workflow")]
    #[test]
    fn cron_with_max_occurrences_zero_rejected() {
        let policy = super::CronPolicy::new("* * * * * * *").expect("valid");
        let rejection = policy.with_max_occurrences(0).expect_err("zero should be rejected");
        assert!(rejection.message().contains("at least 1"));
    }
}