/// Error describing why a cron policy was rejected.
///
/// Carries a single human-readable message; see [`CronPolicyError::message`].
#[cfg(feature = "workflow")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CronPolicyError {
    /// Reason for the rejection, without the `Display` prefix.
    message: String,
}

#[cfg(feature = "workflow")]
impl CronPolicyError {
    /// Builds an error from anything convertible into a `String`.
    ///
    /// Not `pub`: only code in this module can create the error.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Returns the stored reason as a string slice.
    ///
    /// This is the raw message; the `Display` output additionally prefixes it
    /// with `invalid cron policy: `.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

#[cfg(feature = "workflow")]
impl std::fmt::Display for CronPolicyError {
    /// Writes `invalid cron policy: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "invalid cron policy: {message}")
    }
}

#[cfg(feature = "workflow")]
impl std::error::Error for CronPolicyError {}
31
/// A recurring schedule described by a cron expression, with an optional cap
/// on the number of occurrences.
///
/// The constructors in this block only store an expression that
/// `cron::Schedule::from_str` accepted, and only store a `max_occurrences`
/// value that is at least 1.
#[cfg(feature = "workflow")]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CronPolicy {
    /// Cron expression, kept verbatim as passed to [`CronPolicy::new`].
    expression: String,
    /// Optional occurrence cap; omitted from serialized output when `None`.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Option::is_none"))]
    max_occurrences: Option<u32>,
}

#[cfg(feature = "workflow")]
impl CronPolicy {
    /// Creates a policy from a cron expression, with no occurrence cap.
    ///
    /// # Errors
    ///
    /// Returns a [`CronPolicyError`] carrying the parser's error text when
    /// `cron::Schedule::from_str` rejects the expression.
    pub fn new(expression: impl Into<String>) -> Result<Self, CronPolicyError> {
        use std::str::FromStr as _;
        let expression = expression.into();
        // Parse only to validate; the schedule itself is not stored.
        cron::Schedule::from_str(&expression).map_err(|e| CronPolicyError::new(e.to_string()))?;
        Ok(Self { expression, max_occurrences: None })
    }

    /// Sets the occurrence cap, consuming and returning the policy.
    ///
    /// # Errors
    ///
    /// Returns a [`CronPolicyError`] when `max` is `0`.
    pub fn with_max_occurrences(mut self, max: u32) -> Result<Self, CronPolicyError> {
        if max == 0 {
            return Err(CronPolicyError::new("max_occurrences must be at least 1"));
        }
        self.max_occurrences = Some(max);
        Ok(self)
    }

    /// Returns the cron expression exactly as it was supplied.
    pub fn expression(&self) -> &str {
        &self.expression
    }

    /// Returns the occurrence cap, if one was configured.
    pub fn max_occurrences(&self) -> Option<u32> {
        self.max_occurrences
    }

    /// Returns up to `count` upcoming occurrences as Unix timestamps (seconds, UTC).
    ///
    /// `after_secs` is interpreted as seconds since the Unix epoch in UTC and
    /// is handed to `cron::Schedule::after` as the starting instant. The
    /// configured `max_occurrences` is not consulted here; only `count`
    /// bounds the result.
    ///
    /// An empty vector is returned when `count` is `0`, or when `after_secs`
    /// cannot be represented as a `chrono` UTC timestamp (for example
    /// `u64::MAX`): no schedulable occurrence can follow such an instant.
    ///
    /// # Panics
    ///
    /// Panics if the stored expression no longer parses, which would mean the
    /// validation performed in [`CronPolicy::new`] was bypassed.
    pub fn next_occurrences_after(&self, after_secs: u64, count: usize) -> Vec<u64> {
        use std::str::FromStr as _;

        use chrono::{TimeZone as _, Utc};

        if count == 0 {
            return Vec::new();
        }

        let schedule = cron::Schedule::from_str(&self.expression)
            .expect("cron expression pre-validated at CronPolicy::new");

        // Previously an unrepresentable `after_secs` silently fell back to the
        // Unix epoch, so callers received occurrences from 1970 that were NOT
        // after the requested instant. Report "nothing upcoming" instead.
        let after_dt = match i64::try_from(after_secs)
            .ok()
            .and_then(|ts| Utc.timestamp_opt(ts, 0).single())
        {
            Some(dt) => dt,
            None => return Vec::new(),
        };

        schedule
            .after(&after_dt)
            .take(count)
            // Timestamps before the epoch cannot be expressed as `u64`; drop them.
            .filter_map(|dt| u64::try_from(dt.timestamp()).ok())
            .collect()
    }
}
129
#[cfg(all(feature = "serde", feature = "workflow"))]
impl<'de> serde::Deserialize<'de> for CronPolicy {
    /// Deserializes a policy and re-runs the constructor validation.
    ///
    /// The raw fields are read into a private wire struct first, then passed
    /// through [`CronPolicy::new`] and, when a cap is present,
    /// [`CronPolicy::with_max_occurrences`]. Validation failures are reported
    /// as custom deserializer errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Unvalidated mirror of the policy's fields.
        #[derive(serde::Deserialize)]
        struct CronPolicyWire {
            expression: String,
            #[serde(default)]
            max_occurrences: Option<u32>,
        }

        let CronPolicyWire { expression, max_occurrences } =
            <CronPolicyWire as serde::Deserialize>::deserialize(deserializer)?;

        let base = CronPolicy::new(expression).map_err(serde::de::Error::custom)?;
        match max_occurrences {
            Some(max) => base.with_max_occurrences(max).map_err(serde::de::Error::custom),
            None => Ok(base),
        }
    }
}
151
/// Reasons a repeat-based run policy can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunPolicyError {
    /// The repeat count was rejected; `count` is the offending value.
    InvalidRepeatCount {
        /// Rejected repeat count.
        count: u32,
    },
    /// The repeat interval was rejected; `interval_secs` is the offending value.
    InvalidRepeatIntervalSecs {
        /// Rejected interval, in seconds.
        interval_secs: u64,
    },
}

impl std::fmt::Display for RunPolicyError {
    /// Writes a one-line description that includes the rejected value and the
    /// minimum accepted value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum is `Copy`, so matching on `*self` binds the fields by value.
        match *self {
            Self::InvalidRepeatCount { count } => {
                write!(f, "invalid repeat count: {count} (must be >= 1)")
            }
            Self::InvalidRepeatIntervalSecs { interval_secs } => {
                write!(f, "invalid repeat interval_secs: {interval_secs} (must be >= 1)")
            }
        }
    }
}

impl std::error::Error for RunPolicyError {}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
187pub struct RepeatPolicy {
188 count: u32,
189 interval_secs: u64,
190}
191
192impl RepeatPolicy {
193 pub fn new(count: u32, interval_secs: u64) -> Result<Self, RunPolicyError> {
200 if count == 0 {
201 return Err(RunPolicyError::InvalidRepeatCount { count });
202 }
203 if interval_secs == 0 {
204 return Err(RunPolicyError::InvalidRepeatIntervalSecs { interval_secs });
205 }
206 Ok(Self { count, interval_secs })
207 }
208
209 pub fn count(&self) -> u32 {
211 self.count
212 }
213
214 pub fn interval_secs(&self) -> u64 {
216 self.interval_secs
217 }
218}
219
220#[derive(Debug, Clone, PartialEq, Eq)]
222#[cfg_attr(feature = "serde", derive(serde::Serialize))]
223pub enum RunPolicy {
224 Once,
226 Repeat(RepeatPolicy),
228 #[cfg(feature = "workflow")]
230 Cron(CronPolicy),
231}
232
233impl RunPolicy {
234 pub fn repeat(count: u32, interval_secs: u64) -> Result<Self, RunPolicyError> {
240 Ok(Self::Repeat(RepeatPolicy::new(count, interval_secs)?))
241 }
242
243 #[cfg(feature = "workflow")]
249 pub fn cron(expression: impl Into<String>) -> Result<Self, CronPolicyError> {
250 Ok(Self::Cron(CronPolicy::new(expression)?))
251 }
252
253 pub fn validate(&self) -> Result<(), RunPolicyError> {
258 Ok(())
262 }
263}
264
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for RepeatPolicy {
    /// Deserializes `count` and `interval_secs`, then validates them through
    /// [`RepeatPolicy::new`].
    ///
    /// A validation failure is reported as a custom deserializer error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Unvalidated mirror of the policy's fields.
        #[derive(serde::Deserialize)]
        struct RepeatPolicyWire {
            count: u32,
            interval_secs: u64,
        }

        let RepeatPolicyWire { count, interval_secs } =
            <RepeatPolicyWire as serde::Deserialize>::deserialize(deserializer)?;
        Self::new(count, interval_secs).map_err(serde::de::Error::custom)
    }
}
281
#[cfg(feature = "serde")]
impl serde::Serialize for RepeatPolicy {
    /// Serializes the policy as a struct named `RepeatPolicy` with the fields
    /// `count` and `interval_secs`, in that order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        // Destructuring `&Self` yields references, which is what
        // `serialize_field` expects.
        let Self { count, interval_secs } = self;

        let mut fields = serializer.serialize_struct("RepeatPolicy", 2)?;
        fields.serialize_field("count", count)?;
        fields.serialize_field("interval_secs", interval_secs)?;
        fields.end()
    }
}
295
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for RunPolicy {
    /// Deserializes a run policy and validates the payload of each variant.
    ///
    /// The input is first read into a private wire enum of plain fields. The
    /// `Repeat` payload is then passed through [`RepeatPolicy::new`], and the
    /// `Cron` payload through [`CronPolicy::new`] followed, when a cap is
    /// present, by [`CronPolicy::with_max_occurrences`]. Validation failures
    /// are reported as custom deserializer errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Unvalidated mirror of `RunPolicy`.
        #[derive(serde::Deserialize)]
        enum RunPolicyWire {
            Once,
            Repeat {
                count: u32,
                interval_secs: u64,
            },
            #[cfg(feature = "workflow")]
            Cron {
                expression: String,
                #[serde(default)]
                max_occurrences: Option<u32>,
            },
        }

        match <RunPolicyWire as serde::Deserialize>::deserialize(deserializer)? {
            RunPolicyWire::Once => Ok(RunPolicy::Once),
            RunPolicyWire::Repeat { count, interval_secs } => {
                RepeatPolicy::new(count, interval_secs)
                    .map(RunPolicy::Repeat)
                    .map_err(serde::de::Error::custom)
            }
            #[cfg(feature = "workflow")]
            RunPolicyWire::Cron { expression, max_occurrences } => {
                let base = CronPolicy::new(expression).map_err(serde::de::Error::custom)?;
                // Apply the cap only when one was supplied.
                let configured = match max_occurrences {
                    Some(max) => {
                        base.with_max_occurrences(max).map_err(serde::de::Error::custom)?
                    }
                    None => base,
                };
                Ok(RunPolicy::Cron(configured))
            }
        }
    }
}
337
#[cfg(test)]
mod tests {
    use super::{RepeatPolicy, RunPolicy, RunPolicyError};

    // Errors expected by the zero-input rejection tests.
    const ZERO_COUNT: RunPolicyError = RunPolicyError::InvalidRepeatCount { count: 0 };
    const ZERO_INTERVAL: RunPolicyError =
        RunPolicyError::InvalidRepeatIntervalSecs { interval_secs: 0 };

    #[test]
    fn repeat_policy_rejects_zero_count() {
        assert_eq!(RepeatPolicy::new(0, 60), Err(ZERO_COUNT));
    }

    #[test]
    fn repeat_policy_rejects_zero_interval() {
        assert_eq!(RepeatPolicy::new(3, 0), Err(ZERO_INTERVAL));
    }

    #[test]
    fn repeat_policy_accepts_valid_values() {
        let accepted = RepeatPolicy::new(6, 1800).expect("repeat policy should be valid");
        assert_eq!(accepted.count(), 6);
        assert_eq!(accepted.interval_secs(), 1800);
    }

    #[test]
    fn repeat_constructor_rejects_zero_count() {
        assert_eq!(RunPolicy::repeat(0, 60), Err(ZERO_COUNT));
    }

    #[test]
    fn repeat_constructor_rejects_zero_interval() {
        assert_eq!(RunPolicy::repeat(3, 0), Err(ZERO_INTERVAL));
    }

    #[test]
    fn repeat_constructor_accepts_valid_values() {
        let built = RunPolicy::repeat(6, 1800).expect("repeat policy should be valid");
        let expected = RunPolicy::Repeat(RepeatPolicy::new(6, 1800).unwrap());
        assert_eq!(built, expected);
    }

    #[test]
    fn validate_always_succeeds_for_valid_policies() {
        let once = RunPolicy::Once;
        assert!(once.validate().is_ok());

        let repeating = RunPolicy::repeat(3, 60).unwrap();
        assert!(repeating.validate().is_ok());
    }

    #[test]
    fn repeat_policy_accessors() {
        let policy = RepeatPolicy::new(5, 120).unwrap();
        assert_eq!(policy.count(), 5);
        assert_eq!(policy.interval_secs(), 120);
    }

    #[cfg(feature = "workflow")]
    #[test]
    fn cron_next_occurrences_after_with_u64_max_returns_empty() {
        let cron_policy = super::CronPolicy::new("* * * * * * *").expect("valid");
        let upcoming = cron_policy.next_occurrences_after(u64::MAX, 5);
        // Deliberately permissive: an empty result is accepted, and so is any
        // result whose timestamps all fit in an `i64`.
        let ceiling = i64::MAX as u64;
        assert!(upcoming.is_empty() || upcoming.iter().all(|&ts| ts <= ceiling));
    }

    #[cfg(feature = "workflow")]
    #[test]
    fn cron_with_max_occurrences_zero_rejected() {
        let cron_policy = super::CronPolicy::new("* * * * * * *").expect("valid");
        let rejection = cron_policy
            .with_max_occurrences(0)
            .expect_err("zero should be rejected");
        assert!(rejection.message().contains("at least 1"));
    }
}