Skip to main content

autumn_web/
task.rs

//! Scheduled task infrastructure.
//!
//! Provides the [`TaskInfo`] and [`Schedule`] types used by the `#[scheduled]`
//! macro and the `tasks![]` collection macro.
//!
//! Tasks are registered via [`AppBuilder::tasks`](crate::app::AppBuilder::tasks)
//! and run alongside the HTTP server using Tokio timers.
8
9use std::future::Future;
10use std::pin::Pin;
11use std::time::Duration;
12
13use axum::extract::FromRequestParts;
14use serde::{Serialize, de::DeserializeOwned};
15
16use crate::state::AppState;
17use crate::{AutumnError, AutumnResult};
18
19/// Handler function type for scheduled tasks.
20pub type TaskHandler =
21    fn(AppState) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send>>;
22
23/// Handler function type for named one-off operational tasks.
24pub type OneOffTaskHandler =
25    fn(AppState, Vec<String>) -> Pin<Box<dyn Future<Output = AutumnResult<()>> + Send>>;
26
27/// Metadata for a named one-off operational task generated by `#[task]`.
28pub struct OneOffTaskInfo {
29    /// Name operators pass to `autumn task <name>`.
30    pub name: String,
31    /// First doc-comment line, used by `autumn task --list`.
32    pub description: String,
33    /// Handler invoked with fully booted app state and raw CLI args.
34    pub handler: OneOffTaskHandler,
35}
36
37/// Serializable task metadata printed by `autumn task --list`.
38#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
39pub struct OneOffTaskListing {
40    /// Task name accepted by `autumn task <name>`.
41    pub name: String,
42    /// First doc-comment line captured by `#[task]`.
43    pub description: String,
44}
45
/// Typed wrapper around the structured CLI arguments of a one-off task.
///
/// The raw arguments are first turned into a query string — for example
/// `autumn task cleanup --user-id 42 --dry-run` yields
/// `user_id=42&dry_run=true` — and `T` is then deserialized from that string,
/// the same way `Query<T>` works, while keeping task signatures explicit.
pub struct TaskArgs<T>(pub T);
52
53/// Extractor bridge used by generated `#[task]` handlers.
54pub trait TaskExtractor: Sized {
55    /// Resolve an argument from task request parts and app state.
56    fn from_task_parts<'a>(
57        parts: &'a mut http::request::Parts,
58        state: &'a AppState,
59    ) -> Pin<Box<dyn Future<Output = AutumnResult<Self>> + Send + 'a>>;
60}
61
62impl<T> TaskExtractor for T
63where
64    T: FromRequestParts<AppState> + Send,
65    T::Rejection: Into<AutumnError> + Send,
66{
67    fn from_task_parts<'a>(
68        parts: &'a mut http::request::Parts,
69        state: &'a AppState,
70    ) -> Pin<Box<dyn Future<Output = AutumnResult<Self>> + Send + 'a>> {
71        Box::pin(async move {
72            T::from_request_parts(parts, state)
73                .await
74                .map_err(Into::into)
75        })
76    }
77}
78
79impl<T, S> FromRequestParts<S> for TaskArgs<T>
80where
81    T: DeserializeOwned + Send,
82    S: Send + Sync,
83{
84    type Rejection = AutumnError;
85
86    async fn from_request_parts(
87        parts: &mut http::request::Parts,
88        _state: &S,
89    ) -> Result<Self, Self::Rejection> {
90        let query = parts.uri.query().unwrap_or_default();
91        serde_urlencoded::from_str(query)
92            .map(Self)
93            .map_err(|error| AutumnError::bad_request_msg(format!("invalid task args: {error}")))
94    }
95}
96
97/// Metadata for a scheduled task, generated by the `#[scheduled]` macro.
98pub struct TaskInfo {
99    /// Human-readable task name (for logging and health checks).
100    pub name: String,
101    /// When/how often to run.
102    pub schedule: Schedule,
103    /// Whether this task is coordinated fleet-wide or intentionally per-replica.
104    pub coordination: TaskCoordination,
105    /// The task handler, invoked with a clone of `AppState` each run.
106    pub handler: TaskHandler,
107}
108
109/// Cross-replica coordination mode for a scheduled task.
110#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, serde::Deserialize)]
111#[serde(rename_all = "snake_case")]
112pub enum TaskCoordination {
113    /// Run at most once per scheduled tick across the configured scheduler backend.
114    #[default]
115    Fleet,
116    /// Run on every replica, bypassing fleet coordination.
117    PerReplica,
118}
119
120impl std::fmt::Display for TaskCoordination {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            Self::Fleet => f.write_str("fleet"),
124            Self::PerReplica => f.write_str("per_replica"),
125        }
126    }
127}
128
/// Trigger policy for a scheduled task.
#[non_exhaustive]
pub enum Schedule {
    /// Wait a fixed delay, measured from the end of the previous run, before
    /// running again.
    FixedDelay(Duration),
    /// Fire on a 6-field cron schedule (`sec min hour day month weekday`).
    Cron {
        /// The 6-field cron expression, e.g. `"0 * * * * *"` for every minute.
        expression: String,
        /// Optional timezone for the expression, e.g. `"America/New_York"`.
        timezone: Option<String>,
    },
}

/// Short human-readable form: `every <n>s` for fixed delays (whole seconds
/// only, via `Duration::as_secs`) and `cron <expression>` for cron schedules.
/// The cron timezone is not part of the output.
impl std::fmt::Display for Schedule {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FixedDelay(delay) => write!(f, "every {}s", delay.as_secs()),
            Self::Cron {
                expression,
                timezone: _,
            } => write!(f, "cron {expression}"),
        }
    }
}
151
/// Parse a human-readable duration string like `"5m"` or `"1h 30m"`.
///
/// The input is a sequence of `<number><unit>` components, optionally
/// separated by spaces; the components are summed. Supported units: `s`
/// (seconds), `m` (minutes), `h` (hours), `d` (days).
///
/// Returns `None` when:
///
/// - a character is not an ASCII digit, a supported unit, or a space;
/// - a unit has no number in front of it, or a number has no unit after it;
/// - the digits of one number are split by spaces (e.g. `"1 5m"`), which
///   would otherwise be silently read as `15m`;
/// - a number or the running total does not fit in a `u64` of seconds;
/// - the total is zero (this includes the empty string).
#[must_use]
pub fn parse_duration(s: &str) -> Option<Duration> {
    let mut total_secs = 0u64;
    let mut digits = String::new();
    // True once a space has been seen while digits are pending. Further digits
    // would glue two separate numbers together, so they are rejected.
    let mut gap_after_digits = false;

    for ch in s.chars() {
        if ch.is_ascii_digit() {
            if gap_after_digits {
                return None;
            }
            digits.push(ch);
        } else if ch == ' ' {
            if !digits.is_empty() {
                gap_after_digits = true;
            }
        } else {
            let unit_secs: u64 = match ch {
                's' => 1,
                'm' => 60,
                'h' => 3_600,
                'd' => 86_400,
                _ => return None,
            };
            // An empty or oversized number fails to parse and yields `None`.
            let count: u64 = digits.parse().ok()?;
            digits.clear();
            gap_after_digits = false;
            total_secs = total_secs.checked_add(count.checked_mul(unit_secs)?)?;
        }
    }

    if !digits.is_empty() {
        return None; // Trailing number without a unit.
    }

    (total_secs > 0).then(|| Duration::from_secs(total_secs))
}
194
195/// Convert raw `autumn task` arguments to a query string for task extractors.
196///
197/// Long flags are converted to `snake_case` field names. A flag with no explicit
198/// value is treated as a boolean `true`.
199///
200/// # Errors
201///
202/// Returns [`AutumnError`] when an argument is not a `--long-flag`.
203pub fn task_args_to_query(args: &[String]) -> AutumnResult<String> {
204    let mut serializer = url::form_urlencoded::Serializer::new(String::new());
205    let mut i = 0;
206    while i < args.len() {
207        let token = &args[i];
208        let Some(flag) = token.strip_prefix("--") else {
209            return Err(AutumnError::bad_request_msg(format!(
210                "unexpected positional argument {token:?}; task args must use --flag value syntax"
211            )));
212        };
213        if flag.is_empty() {
214            return Err(AutumnError::bad_request_msg(
215                "empty task argument flag is not allowed",
216            ));
217        }
218
219        let (key, value) = if let Some((key, value)) = flag.split_once('=') {
220            (key, value.to_owned())
221        } else if args.get(i + 1).is_some_and(|next| !next.starts_with("--")) {
222            i += 1;
223            (flag, args[i].clone())
224        } else {
225            (flag, "true".to_owned())
226        };
227
228        serializer.append_pair(&key.replace('-', "_"), &value);
229        i += 1;
230    }
231
232    Ok(serializer.finish())
233}
234
235/// Build request parts used to resolve task extractors.
236///
237/// # Errors
238///
239/// Returns [`AutumnError`] when task CLI arguments cannot be represented as a
240/// query string or the synthetic URI cannot be built.
241pub fn request_parts_for_task_args(args: &[String]) -> AutumnResult<http::request::Parts> {
242    let query = task_args_to_query(args)?;
243    let uri = if query.is_empty() {
244        "/".to_owned()
245    } else {
246        format!("/?{query}")
247    };
248    let request = http::Request::builder()
249        .uri(uri)
250        .body(())
251        .map_err(AutumnError::internal_server_error)?;
252    Ok(request.into_parts().0)
253}
254
255/// Return task metadata sorted by task name.
256#[must_use]
257pub fn list_one_off_tasks(tasks: &[OneOffTaskInfo]) -> Vec<OneOffTaskListing> {
258    let mut listing: Vec<_> = tasks
259        .iter()
260        .map(|task| OneOffTaskListing {
261            name: task.name.clone(),
262            description: task.description.clone(),
263        })
264        .collect();
265    listing.sort_by(|a, b| a.name.cmp(&b.name));
266    listing
267}
268
269/// Validate that every registered one-off task has a unique name.
270///
271/// # Errors
272///
273/// Returns a message naming the first duplicate task.
274pub fn validate_unique_one_off_task_names(tasks: &[OneOffTaskInfo]) -> Result<(), String> {
275    let mut names = std::collections::HashSet::new();
276    for task in tasks {
277        if !names.insert(task.name.as_str()) {
278            return Err(format!("duplicate task name '{}'", task.name));
279        }
280    }
281    Ok(())
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Build owned CLI-style arguments from string literals.
    fn cli_args(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|arg| (*arg).to_owned()).collect()
    }

    /// One-off task handler that does nothing and succeeds.
    fn noop_handler(
        _state: AppState,
        _args: Vec<String>,
    ) -> Pin<Box<dyn Future<Output = AutumnResult<()>> + Send>> {
        Box::pin(async { Ok(()) })
    }

    /// One-off task metadata backed by [`noop_handler`].
    fn one_off(name: &str, description: &str) -> OneOffTaskInfo {
        OneOffTaskInfo {
            name: name.to_owned(),
            description: description.to_owned(),
            handler: noop_handler,
        }
    }

    /// Shorthand for the expected result of a successful parse.
    fn secs(seconds: u64) -> Option<Duration> {
        Some(Duration::from_secs(seconds))
    }

    #[test]
    fn parse_seconds() {
        assert_eq!(parse_duration("5s"), secs(5));
    }

    #[test]
    fn parse_minutes() {
        assert_eq!(parse_duration("5m"), secs(300));
    }

    #[test]
    fn parse_hours() {
        assert_eq!(parse_duration("2h"), secs(7200));
    }

    #[test]
    fn parse_compound() {
        assert_eq!(parse_duration("1h 30m"), secs(5400));
    }

    #[test]
    fn parse_day() {
        assert_eq!(parse_duration("1d"), secs(86400));
    }

    #[test]
    fn invalid_unit() {
        assert_eq!(parse_duration("5x"), None);
    }

    #[test]
    fn trailing_number() {
        assert_eq!(parse_duration("5"), None);
    }

    #[test]
    fn empty() {
        assert_eq!(parse_duration(""), None);
    }

    #[test]
    fn zero_duration() {
        for input in ["0s", "0m"] {
            assert!(parse_duration(input).is_none());
        }
    }

    #[test]
    fn invalid_characters() {
        for input in ["1h_30m", "1h-30m"] {
            assert!(parse_duration(input).is_none());
        }
    }

    #[test]
    fn multiple_spaces() {
        assert_eq!(parse_duration("1h   30m"), secs(5400));
    }

    #[test]
    fn compound_trailing_number() {
        assert_eq!(parse_duration("1h 30"), None);
    }

    #[test]
    fn task_args_to_query_converts_long_flags_to_snake_case_fields() {
        let args = cli_args(&["--user-id", "42", "--dry-run"]);

        let query = task_args_to_query(&args).expect("task args should parse");

        assert_eq!(query, "user_id=42&dry_run=true");
    }

    #[tokio::test]
    async fn task_args_extractor_parses_struct_from_cli_style_args() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct CleanupArgs {
            user_id: i64,
            dry_run: bool,
        }

        let raw = cli_args(&["--user-id", "42", "--dry-run"]);
        let mut parts =
            request_parts_for_task_args(&raw).expect("parts should be built from task args");
        let state = AppState::for_test().with_profile("dev");

        let TaskArgs(args) =
            <TaskArgs<CleanupArgs> as FromRequestParts<AppState>>::from_request_parts(
                &mut parts, &state,
            )
            .await
            .expect("task args should deserialize");

        let expected = CleanupArgs {
            user_id: 42,
            dry_run: true,
        };
        assert_eq!(args, expected);
    }

    #[test]
    fn task_args_to_query_rejects_values_without_a_flag_name() {
        let error = task_args_to_query(&cli_args(&["42"]))
            .expect_err("bare positional values should be rejected");

        assert!(
            error.to_string().contains("unexpected positional argument"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn list_one_off_tasks_sorts_by_name_and_keeps_descriptions() {
        let tasks = [
            one_off("zeta", "Last task"),
            one_off("alpha", "First task"),
        ];

        let listing = list_one_off_tasks(&tasks);

        assert_eq!(listing[0].name, "alpha");
        assert_eq!(listing[0].description, "First task");
        assert_eq!(listing[1].name, "zeta");
        assert_eq!(listing[1].description, "Last task");
    }

    #[test]
    fn validate_unique_one_off_task_names_rejects_duplicates() {
        let tasks = [one_off("cleanup", ""), one_off("cleanup", "")];

        let error = validate_unique_one_off_task_names(&tasks)
            .expect_err("duplicate task names should be rejected");

        assert!(error.contains("duplicate task name 'cleanup'"));
    }
}
465
#[cfg(test)]
mod havoc_proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Runs of 15 to 30 digits followed by a valid unit must never make
        // `parse_duration` panic; the parsed value itself is ignored.
        #[test]
        fn parse_duration_fuzz_panic(input in "[0-9]{15,30}[smhd]") {
            let _ = parse_duration(&input);
        }
    }
}