1use std::future::Future;
10use std::pin::Pin;
11use std::time::Duration;
12
13use axum::extract::FromRequestParts;
14use serde::{Serialize, de::DeserializeOwned};
15
16use crate::state::AppState;
17use crate::{AutumnError, AutumnResult};
18
19pub type TaskHandler =
21 fn(AppState) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send>>;
22
23pub type OneOffTaskHandler =
25 fn(AppState, Vec<String>) -> Pin<Box<dyn Future<Output = AutumnResult<()>> + Send>>;
26
27pub struct OneOffTaskInfo {
29 pub name: String,
31 pub description: String,
33 pub handler: OneOffTaskHandler,
35}
36
37#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
39pub struct OneOffTaskListing {
40 pub name: String,
42 pub description: String,
44}
45
/// Extractor that deserializes a task's arguments into `T`.
///
/// The arguments are read from the request URI's query string (see the
/// `FromRequestParts` impl); [`request_parts_for_task_args`] builds such a
/// query from CLI-style `--flag value` tokens.
pub struct TaskArgs<T>(pub T);
52
53pub trait TaskExtractor: Sized {
55 fn from_task_parts<'a>(
57 parts: &'a mut http::request::Parts,
58 state: &'a AppState,
59 ) -> Pin<Box<dyn Future<Output = AutumnResult<Self>> + Send + 'a>>;
60}
61
62impl<T> TaskExtractor for T
63where
64 T: FromRequestParts<AppState> + Send,
65 T::Rejection: Into<AutumnError> + Send,
66{
67 fn from_task_parts<'a>(
68 parts: &'a mut http::request::Parts,
69 state: &'a AppState,
70 ) -> Pin<Box<dyn Future<Output = AutumnResult<Self>> + Send + 'a>> {
71 Box::pin(async move {
72 T::from_request_parts(parts, state)
73 .await
74 .map_err(Into::into)
75 })
76 }
77}
78
79impl<T, S> FromRequestParts<S> for TaskArgs<T>
80where
81 T: DeserializeOwned + Send,
82 S: Send + Sync,
83{
84 type Rejection = AutumnError;
85
86 async fn from_request_parts(
87 parts: &mut http::request::Parts,
88 _state: &S,
89 ) -> Result<Self, Self::Rejection> {
90 let query = parts.uri.query().unwrap_or_default();
91 serde_urlencoded::from_str(query)
92 .map(Self)
93 .map_err(|error| AutumnError::bad_request_msg(format!("invalid task args: {error}")))
94 }
95}
96
97pub struct TaskInfo {
99 pub name: String,
101 pub schedule: Schedule,
103 pub coordination: TaskCoordination,
105 pub handler: TaskHandler,
107}
108
109#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, serde::Deserialize)]
111#[serde(rename_all = "snake_case")]
112pub enum TaskCoordination {
113 #[default]
115 Fleet,
116 PerReplica,
118}
119
120impl std::fmt::Display for TaskCoordination {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 Self::Fleet => f.write_str("fleet"),
124 Self::PerReplica => f.write_str("per_replica"),
125 }
126 }
127}
128
/// When a scheduled task runs.
#[non_exhaustive]
pub enum Schedule {
    /// Run repeatedly, waiting this long between runs.
    FixedDelay(Duration),
    /// Run according to a cron expression.
    Cron {
        /// The cron expression, stored verbatim.
        expression: String,
        /// Optional timezone for evaluating `expression`; presumably an IANA
        /// name — TODO confirm with the scheduler.
        timezone: Option<String>,
    },
}

/// Human-readable schedule, e.g. `every 30s` or `cron 0 * * * *`.
///
/// The cron timezone is not included in the rendered text.
impl std::fmt::Display for Schedule {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FixedDelay(d) if d.subsec_nanos() == 0 => write!(f, "every {}s", d.as_secs()),
            // `as_secs` truncates: 500ms would print as "every 0s". Use
            // Duration's Debug form ("500ms", "1.5s") for fractional delays.
            Self::FixedDelay(d) => write!(f, "every {d:?}"),
            Self::Cron { expression, .. } => write!(f, "cron {expression}"),
        }
    }
}
151
/// Parses a human-readable duration such as `"30s"`, `"5m"` or `"1h 30m"`.
///
/// The input is a sequence of `<digits><unit>` components, where the unit is
/// `s` (seconds), `m` (minutes), `h` (hours) or `d` (days). The components'
/// values are summed. ASCII spaces may separate components and may also sit
/// between a number and its unit (`"5 s"`).
///
/// Returns `None` if any of the following hold:
/// - the input contains any other character or unit;
/// - a number has no unit, or a unit has no number;
/// - a number is split by a space (`"1 0s"`);
/// - the total overflows `u64` seconds;
/// - the total is zero (this includes empty input).
#[must_use]
pub fn parse_duration(s: &str) -> Option<Duration> {
    let mut total_secs = 0u64;
    let mut current_num = String::new();
    // True once a space follows digits of the number being read. Another digit
    // at that point would otherwise be glued on ("1 0s" read as 10s).
    let mut number_closed = false;

    for ch in s.chars() {
        if ch.is_ascii_digit() {
            if number_closed {
                return None;
            }
            current_num.push(ch);
        } else if ch.is_ascii_alphabetic() {
            let num: u64 = current_num.parse().ok()?;
            current_num.clear();
            number_closed = false;
            let multiplier: u64 = match ch {
                's' => 1,
                'm' => 60,
                'h' => 3_600,
                'd' => 86_400,
                _ => return None,
            };
            total_secs = total_secs.checked_add(num.checked_mul(multiplier)?)?;
        } else if ch == ' ' {
            number_closed = !current_num.is_empty();
        } else {
            return None;
        }
    }

    // A leftover number has no unit; a zero total is not a usable delay.
    if !current_num.is_empty() || total_secs == 0 {
        return None;
    }

    Some(Duration::from_secs(total_secs))
}
194
195pub fn task_args_to_query(args: &[String]) -> AutumnResult<String> {
204 let mut serializer = url::form_urlencoded::Serializer::new(String::new());
205 let mut i = 0;
206 while i < args.len() {
207 let token = &args[i];
208 let Some(flag) = token.strip_prefix("--") else {
209 return Err(AutumnError::bad_request_msg(format!(
210 "unexpected positional argument {token:?}; task args must use --flag value syntax"
211 )));
212 };
213 if flag.is_empty() {
214 return Err(AutumnError::bad_request_msg(
215 "empty task argument flag is not allowed",
216 ));
217 }
218
219 let (key, value) = if let Some((key, value)) = flag.split_once('=') {
220 (key, value.to_owned())
221 } else if args.get(i + 1).is_some_and(|next| !next.starts_with("--")) {
222 i += 1;
223 (flag, args[i].clone())
224 } else {
225 (flag, "true".to_owned())
226 };
227
228 serializer.append_pair(&key.replace('-', "_"), &value);
229 i += 1;
230 }
231
232 Ok(serializer.finish())
233}
234
235pub fn request_parts_for_task_args(args: &[String]) -> AutumnResult<http::request::Parts> {
242 let query = task_args_to_query(args)?;
243 let uri = if query.is_empty() {
244 "/".to_owned()
245 } else {
246 format!("/?{query}")
247 };
248 let request = http::Request::builder()
249 .uri(uri)
250 .body(())
251 .map_err(AutumnError::internal_server_error)?;
252 Ok(request.into_parts().0)
253}
254
255#[must_use]
257pub fn list_one_off_tasks(tasks: &[OneOffTaskInfo]) -> Vec<OneOffTaskListing> {
258 let mut listing: Vec<_> = tasks
259 .iter()
260 .map(|task| OneOffTaskListing {
261 name: task.name.clone(),
262 description: task.description.clone(),
263 })
264 .collect();
265 listing.sort_by(|a, b| a.name.cmp(&b.name));
266 listing
267}
268
269pub fn validate_unique_one_off_task_names(tasks: &[OneOffTaskInfo]) -> Result<(), String> {
275 let mut names = std::collections::HashSet::new();
276 for task in tasks {
277 if !names.insert(task.name.as_str()) {
278 return Err(format!("duplicate task name '{}'", task.name));
279 }
280 }
281 Ok(())
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// No-op handler used when constructing `OneOffTaskInfo` values.
    fn noop_one_off_handler(
        _state: AppState,
        _args: Vec<String>,
    ) -> Pin<Box<dyn Future<Output = AutumnResult<()>> + Send>> {
        Box::pin(async { Ok(()) })
    }

    #[test]
    fn parse_seconds() {
        assert_eq!(parse_duration("5s"), Some(Duration::from_secs(5)));
    }

    #[test]
    fn parse_minutes() {
        assert_eq!(parse_duration("5m"), Some(Duration::from_secs(300)));
    }

    #[test]
    fn parse_hours() {
        assert_eq!(parse_duration("2h"), Some(Duration::from_secs(7200)));
    }

    #[test]
    fn parse_compound() {
        assert_eq!(parse_duration("1h 30m"), Some(Duration::from_secs(5400)));
    }

    #[test]
    fn parse_day() {
        assert_eq!(parse_duration("1d"), Some(Duration::from_secs(86400)));
    }

    #[test]
    fn invalid_unit() {
        assert!(parse_duration("5x").is_none());
    }

    #[test]
    fn trailing_number() {
        assert!(parse_duration("5").is_none());
    }

    #[test]
    fn empty() {
        assert!(parse_duration("").is_none());
    }

    #[test]
    fn zero_duration() {
        assert!(parse_duration("0s").is_none());
        assert!(parse_duration("0m").is_none());
    }

    #[test]
    fn invalid_characters() {
        assert!(parse_duration("1h_30m").is_none());
        assert!(parse_duration("1h-30m").is_none());
    }

    #[test]
    fn multiple_spaces() {
        // Two spaces between components; with a single space this would only
        // duplicate `parse_compound`.
        assert_eq!(parse_duration("1h  30m"), Some(Duration::from_secs(5400)));
    }

    #[test]
    fn compound_trailing_number() {
        assert!(parse_duration("1h 30").is_none());
    }

    #[test]
    fn task_args_to_query_converts_long_flags_to_snake_case_fields() {
        let args = vec![
            "--user-id".to_string(),
            "42".to_string(),
            "--dry-run".to_string(),
        ];

        let query = task_args_to_query(&args).expect("task args should parse");

        assert_eq!(query, "user_id=42&dry_run=true");
    }

    #[test]
    fn task_args_to_query_accepts_inline_and_negative_values() {
        let args = vec![
            "--user-id=42".to_string(),
            "--offset".to_string(),
            "-5".to_string(),
        ];

        let query = task_args_to_query(&args).expect("task args should parse");

        assert_eq!(query, "user_id=42&offset=-5");
    }

    #[tokio::test]
    async fn task_args_extractor_parses_struct_from_cli_style_args() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct CleanupArgs {
            user_id: i64,
            dry_run: bool,
        }

        let raw = vec![
            "--user-id".to_string(),
            "42".to_string(),
            "--dry-run".to_string(),
        ];
        let mut parts =
            request_parts_for_task_args(&raw).expect("parts should be built from task args");
        let state = AppState::for_test().with_profile("dev");

        let TaskArgs(args) = <TaskArgs<CleanupArgs> as axum::extract::FromRequestParts<
            AppState,
        >>::from_request_parts(&mut parts, &state)
        .await
        .expect("task args should deserialize");

        assert_eq!(
            args,
            CleanupArgs {
                user_id: 42,
                dry_run: true,
            }
        );
    }

    #[test]
    fn request_parts_without_args_have_no_query() {
        let parts = request_parts_for_task_args(&[]).expect("empty args should build parts");

        assert_eq!(parts.uri.path(), "/");
        assert_eq!(parts.uri.query(), None);
    }

    #[test]
    fn task_args_to_query_rejects_values_without_a_flag_name() {
        let error = task_args_to_query(&["42".to_string()])
            .expect_err("bare positional values should be rejected");

        assert!(
            error.to_string().contains("unexpected positional argument"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn task_args_to_query_rejects_bare_double_dash() {
        let error = task_args_to_query(&["--".to_string()])
            .expect_err("a flag without a name should be rejected");

        assert!(
            error.to_string().contains("empty task argument flag"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn list_one_off_tasks_sorts_by_name_and_keeps_descriptions() {
        let tasks = vec![
            OneOffTaskInfo {
                name: "zeta".to_string(),
                description: "Last task".to_string(),
                handler: noop_one_off_handler,
            },
            OneOffTaskInfo {
                name: "alpha".to_string(),
                description: "First task".to_string(),
                handler: noop_one_off_handler,
            },
        ];

        let listing = list_one_off_tasks(&tasks);

        assert_eq!(listing[0].name, "alpha");
        assert_eq!(listing[0].description, "First task");
        assert_eq!(listing[1].name, "zeta");
        assert_eq!(listing[1].description, "Last task");
    }

    #[test]
    fn validate_unique_one_off_task_names_rejects_duplicates() {
        let tasks = vec![
            OneOffTaskInfo {
                name: "cleanup".to_string(),
                description: String::new(),
                handler: noop_one_off_handler,
            },
            OneOffTaskInfo {
                name: "cleanup".to_string(),
                description: String::new(),
                handler: noop_one_off_handler,
            },
        ];

        let error = validate_unique_one_off_task_names(&tasks)
            .expect_err("duplicate task names should be rejected");

        assert!(error.contains("duplicate task name 'cleanup'"));
    }
}
465
#[cfg(test)]
mod havoc_proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Very long digit runs must never panic. They either parse or are
        // rejected (u64 parse / checked-arithmetic overflow); a parsed value
        // must never be zero.
        #[test]
        fn parse_duration_fuzz_panic(s in "[0-9]{15,30}[smhd]") {
            if let Some(duration) = parse_duration(&s) {
                prop_assert!(duration.as_secs() > 0);
            }
        }

        // Any positive whole number of seconds written as `<n>s` round-trips.
        #[test]
        fn parse_duration_round_trips_seconds(n in 1u64..=1_000_000_000_000u64) {
            prop_assert_eq!(parse_duration(&format!("{n}s")), Some(Duration::from_secs(n)));
        }
    }
}