use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use axum::extract::FromRequestParts;
use serde::{Serialize, de::DeserializeOwned};
use crate::state::AppState;
use crate::{AutumnError, AutumnResult};
pub type TaskHandler =
fn(AppState) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send>>;
/// Entry point of a one-off task (the handler type stored in [`OneOffTaskInfo`]).
///
/// Receives the application state plus a list of string arguments. Presumably
/// these are the raw `--flag value` tokens understood by [`task_args_to_query`];
/// confirm against the caller. Returns a boxed, `Send` future resolving to the
/// task's outcome.
pub type OneOffTaskHandler =
    fn(AppState, Vec<String>) -> Pin<Box<dyn Future<Output = AutumnResult<()>> + Send>>;
/// Registration record for a one-off task: one with a handler but no
/// [`Schedule`], unlike [`TaskInfo`].
pub struct OneOffTaskInfo {
    /// Name the task is listed under; [`validate_unique_one_off_task_names`]
    /// checks these are unique.
    pub name: String,
    /// Human-readable summary copied into [`OneOffTaskListing`].
    pub description: String,
    /// Function invoked to run the task.
    pub handler: OneOffTaskHandler,
}
/// Serializable name/description pair describing one one-off task, as
/// produced by [`list_one_off_tasks`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct OneOffTaskListing {
    /// The task's registered name.
    pub name: String,
    /// The task's human-readable description.
    pub description: String,
}
/// Extractor wrapper: `TaskArgs<T>` deserializes `T` from the request's URI
/// query string, which is how CLI-style task arguments are delivered (see
/// [`request_parts_for_task_args`]).
pub struct TaskArgs<T>(pub T);
/// Produces a value for a task handler from (synthetic) request parts and the
/// application state.
///
/// Any axum `FromRequestParts<AppState>` extractor whose rejection converts into
/// [`AutumnError`] implements this trait through a blanket impl in this module.
pub trait TaskExtractor: Sized {
    /// Extracts `Self`, borrowing `parts` and `state` for the lifetime of the
    /// returned future.
    ///
    /// The future resolves to an [`AutumnError`] when extraction fails.
    fn from_task_parts<'a>(
        parts: &'a mut http::request::Parts,
        state: &'a AppState,
    ) -> Pin<Box<dyn Future<Output = AutumnResult<Self>> + Send + 'a>>;
}
/// Bridges axum extractors into [`TaskExtractor`].
///
/// Every `FromRequestParts<AppState>` type can be used as a task extractor; its
/// rejection is converted into [`AutumnError`].
impl<T> TaskExtractor for T
where
    T: FromRequestParts<AppState> + Send,
    T::Rejection: Into<AutumnError> + Send,
{
    fn from_task_parts<'a>(
        parts: &'a mut http::request::Parts,
        state: &'a AppState,
    ) -> Pin<Box<dyn Future<Output = AutumnResult<Self>> + Send + 'a>> {
        Box::pin(async move {
            let extracted =
                <T as FromRequestParts<AppState>>::from_request_parts(parts, state).await;
            extracted.map_err(|rejection| -> AutumnError { rejection.into() })
        })
    }
}
impl<T, S> FromRequestParts<S> for TaskArgs<T>
where
T: DeserializeOwned + Send,
S: Send + Sync,
{
type Rejection = AutumnError;
async fn from_request_parts(
parts: &mut http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
serde_urlencoded::from_str(query)
.map(Self)
.map_err(|error| AutumnError::bad_request_msg(format!("invalid task args: {error}")))
}
}
/// Registration record for a scheduled task.
pub struct TaskInfo {
    /// The task's name.
    pub name: String,
    /// When the task runs.
    pub schedule: Schedule,
    /// How runs are coordinated between replicas (see [`TaskCoordination`]).
    pub coordination: TaskCoordination,
    /// Function invoked on each run.
    pub handler: TaskHandler,
}
/// How a scheduled task is coordinated between application replicas.
///
/// The meanings below are inferred from the variant names; confirm them against
/// the scheduler. (De)serialized in snake_case: `fleet`, `per_replica`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskCoordination {
    /// The default. Presumably one run is shared across the whole fleet of
    /// replicas.
    #[default]
    Fleet,
    /// Presumably every replica runs the task independently.
    PerReplica,
}
impl std::fmt::Display for TaskCoordination {
    /// Writes the snake_case label (`fleet` / `per_replica`), matching the
    /// serde `rename_all = "snake_case"` names.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Fleet => "fleet",
            Self::PerReplica => "per_replica",
        };
        f.write_str(label)
    }
}
/// When a scheduled task runs.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum Schedule {
    /// Run repeatedly with this fixed delay.
    ///
    /// Presumably the delay is measured between consecutive runs; confirm in the
    /// scheduler.
    FixedDelay(Duration),
    /// Run according to a cron expression.
    Cron {
        /// The cron expression text.
        expression: String,
        /// Optional timezone for evaluating `expression`.
        ///
        /// Presumably an IANA zone name, with `None` meaning the scheduler's
        /// default; confirm in the scheduler.
        timezone: Option<String>,
    },
}

impl std::fmt::Display for Schedule {
    /// Renders a short human-readable summary.
    ///
    /// - Fixed delays render as `every {secs}s`.
    /// - Cron schedules render as `cron {expression}`; the timezone is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Whole-second delays keep the integer form, e.g. `every 5s`.
            Self::FixedDelay(d) if d.subsec_nanos() == 0 => {
                write!(f, "every {}s", d.as_secs())
            }
            // `as_secs()` would truncate sub-second delays (500ms -> `every 0s`),
            // so print fractional seconds instead, e.g. `every 0.5s`.
            Self::FixedDelay(d) => write!(f, "every {}s", d.as_secs_f64()),
            Self::Cron { expression, .. } => write!(f, "cron {expression}"),
        }
    }
}
/// Parses a human-friendly duration such as `"30s"`, `"5m"` or `"1h 30m"`.
///
/// The input is a sequence of `<digits><unit>` components, and their values are
/// summed. The unit is one of:
///
/// - `s` (seconds)
/// - `m` (minutes)
/// - `h` (hours)
/// - `d` (days)
///
/// Components may be separated by spaces, and a space may also sit between a
/// number and its unit (`"1 h"`).
///
/// Returns `None` when the input:
/// - contains any character other than ASCII digits, the four unit letters,
///   or `' '`;
/// - has a unit with no number before it, or a number with no unit after it;
/// - splits a number with spaces (`"1 30m"`); this previously read as `130m`;
/// - totals zero seconds (this includes the empty string);
/// - overflows `u64` seconds.
#[must_use]
pub fn parse_duration(s: &str) -> Option<Duration> {
    let mut total_secs = 0u64;
    let mut current_num = String::new();
    // True once a space has followed the digits in `current_num`. Another digit
    // after that space would be glued onto the earlier digits, so it is rejected.
    let mut number_closed = false;
    for ch in s.chars() {
        if ch.is_ascii_digit() {
            if number_closed {
                return None;
            }
            current_num.push(ch);
        } else if ch == ' ' {
            number_closed = !current_num.is_empty();
        } else {
            // Any other character must be a unit terminating the pending number.
            let unit_secs: u64 = match ch {
                's' => 1,
                'm' => 60,
                'h' => 3_600,
                'd' => 86_400,
                _ => return None,
            };
            // An empty or oversized number fails to parse here.
            let num: u64 = current_num.parse().ok()?;
            current_num.clear();
            number_closed = false;
            total_secs = total_secs.checked_add(num.checked_mul(unit_secs)?)?;
        }
    }
    // Trailing digits without a unit are ambiguous.
    if !current_num.is_empty() {
        return None;
    }
    if total_secs == 0 {
        return None;
    }
    Some(Duration::from_secs(total_secs))
}
/// Converts CLI-style task arguments into a form-urlencoded query string.
///
/// Accepted forms:
/// - `--flag value`: the next token is the value, unless it also starts with
///   `--`. Single-dash values such as `-5` are therefore accepted.
/// - `--flag=value`: everything after the first `=` is the value.
/// - `--flag` alone, or followed by another `--flag`: the value is `true`.
///
/// Dashes in flag names become underscores (`--user-id` → `user_id`), so keys
/// line up with snake_case struct fields.
///
/// # Errors
///
/// Returns a bad-request error when:
/// - a token does not start with `--` (a stray positional value);
/// - a flag has an empty name (`--`, `--=value`).
pub fn task_args_to_query(args: &[String]) -> AutumnResult<String> {
    let mut serializer = url::form_urlencoded::Serializer::new(String::new());
    let mut tokens = args.iter().peekable();
    while let Some(token) = tokens.next() {
        let Some(flag) = token.strip_prefix("--") else {
            return Err(AutumnError::bad_request_msg(format!(
                "unexpected positional argument {token:?}; task args must use --flag value syntax"
            )));
        };
        let (key, value): (&str, &str) = if let Some((key, value)) = flag.split_once('=') {
            (key, value)
        } else if let Some(next) = tokens.next_if(|next| !next.starts_with("--")) {
            (flag, next.as_str())
        } else {
            // A bare flag acts as a boolean switch.
            (flag, "true")
        };
        // Checked after splitting so `--=value` is rejected as well as `--`.
        // Otherwise an empty key would be serialized as `=value` and most likely
        // ignored by the deserializer.
        if key.is_empty() {
            return Err(AutumnError::bad_request_msg(
                "empty task argument flag is not allowed",
            ));
        }
        serializer.append_pair(&key.replace('-', "_"), value);
    }
    Ok(serializer.finish())
}
/// Builds synthetic HTTP request parts whose URI query string encodes `args`.
///
/// The URI is `/` when there are no arguments and `/?<query>` otherwise. This
/// lets query-based extractors such as [`TaskArgs`] run against CLI-style task
/// arguments.
///
/// # Errors
///
/// - Propagates the error from [`task_args_to_query`] for malformed arguments.
/// - Returns an internal-server error if the request cannot be built from the
///   generated URI.
pub fn request_parts_for_task_args(args: &[String]) -> AutumnResult<http::request::Parts> {
    let query = task_args_to_query(args)?;
    let uri = match query.as_str() {
        "" => String::from("/"),
        encoded => format!("/?{encoded}"),
    };
    let (parts, ()) = http::Request::builder()
        .uri(uri)
        .body(())
        .map_err(AutumnError::internal_server_error)?
        .into_parts();
    Ok(parts)
}
/// Summarises registered one-off tasks as name/description pairs, ordered by
/// name.
///
/// The sort is stable, so tasks that share a name keep their input order.
#[must_use]
pub fn list_one_off_tasks(tasks: &[OneOffTaskInfo]) -> Vec<OneOffTaskListing> {
    let mut by_name: Vec<&OneOffTaskInfo> = tasks.iter().collect();
    by_name.sort_by(|left, right| left.name.cmp(&right.name));
    by_name
        .into_iter()
        .map(|task| OneOffTaskListing {
            name: task.name.clone(),
            description: task.description.clone(),
        })
        .collect()
}
/// Checks that no two one-off tasks share a name.
///
/// # Errors
///
/// Returns `duplicate task name '<name>'` for the first name that repeats,
/// scanning `tasks` in order.
pub fn validate_unique_one_off_task_names(tasks: &[OneOffTaskInfo]) -> Result<(), String> {
    let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
    // `insert` returns false for a name already present, which stops `find`.
    let repeated = tasks.iter().find(|&task| !seen.insert(task.name.as_str()));
    match repeated {
        Some(task) => Err(format!("duplicate task name '{}'", task.name)),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Shared no-op one-off task handler for registration tests.
    fn noop_handler(
        _state: AppState,
        _args: Vec<String>,
    ) -> Pin<Box<dyn Future<Output = AutumnResult<()>> + Send>> {
        Box::pin(async { Ok(()) })
    }

    #[test]
    fn parse_seconds() {
        assert_eq!(parse_duration("5s"), Some(Duration::from_secs(5)));
    }

    #[test]
    fn parse_minutes() {
        assert_eq!(parse_duration("5m"), Some(Duration::from_secs(300)));
    }

    #[test]
    fn parse_hours() {
        assert_eq!(parse_duration("2h"), Some(Duration::from_secs(7200)));
    }

    #[test]
    fn parse_compound() {
        assert_eq!(parse_duration("1h 30m"), Some(Duration::from_secs(5400)));
    }

    #[test]
    fn parse_day() {
        assert_eq!(parse_duration("1d"), Some(Duration::from_secs(86400)));
    }

    #[test]
    fn invalid_unit() {
        assert!(parse_duration("5x").is_none());
    }

    #[test]
    fn trailing_number() {
        assert!(parse_duration("5").is_none());
    }

    #[test]
    fn empty() {
        assert!(parse_duration("").is_none());
    }

    #[test]
    fn zero_duration() {
        assert!(parse_duration("0s").is_none());
        assert!(parse_duration("0m").is_none());
    }

    #[test]
    fn invalid_characters() {
        assert!(parse_duration("1h_30m").is_none());
        assert!(parse_duration("1h-30m").is_none());
    }

    #[test]
    fn multiple_spaces() {
        // `parse_compound` already covers a single separator; this checks runs of
        // spaces between, before and after components.
        assert_eq!(parse_duration("1h  30m"), Some(Duration::from_secs(5400)));
        assert_eq!(parse_duration("  1h   30m  "), Some(Duration::from_secs(5400)));
    }

    #[test]
    fn compound_trailing_number() {
        assert!(parse_duration("1h 30").is_none());
    }

    #[test]
    fn task_args_to_query_converts_long_flags_to_snake_case_fields() {
        let args = vec![
            "--user-id".to_string(),
            "42".to_string(),
            "--dry-run".to_string(),
        ];
        let query = task_args_to_query(&args).expect("task args should parse");
        assert_eq!(query, "user_id=42&dry_run=true");
    }

    #[test]
    fn task_args_to_query_accepts_inline_values_and_negative_numbers() {
        let args = vec![
            "--user-id=42".to_string(),
            "--offset".to_string(),
            "-5".to_string(),
            "--name".to_string(),
            "a b".to_string(),
        ];
        let query = task_args_to_query(&args).expect("task args should parse");
        assert_eq!(query, "user_id=42&offset=-5&name=a+b");
    }

    #[tokio::test]
    async fn task_args_extractor_parses_struct_from_cli_style_args() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct CleanupArgs {
            user_id: i64,
            dry_run: bool,
        }

        let raw = vec![
            "--user-id".to_string(),
            "42".to_string(),
            "--dry-run".to_string(),
        ];
        let mut parts =
            request_parts_for_task_args(&raw).expect("parts should be built from task args");
        let state = AppState::for_test().with_profile("dev");
        let TaskArgs(args) = <TaskArgs<CleanupArgs> as axum::extract::FromRequestParts<
            AppState,
        >>::from_request_parts(&mut parts, &state)
        .await
        .expect("task args should deserialize");
        assert_eq!(
            args,
            CleanupArgs {
                user_id: 42,
                dry_run: true,
            }
        );
    }

    #[test]
    fn task_args_to_query_rejects_values_without_a_flag_name() {
        let error = task_args_to_query(&["42".to_string()])
            .expect_err("bare positional values should be rejected");
        assert!(
            error.to_string().contains("unexpected positional argument"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn schedule_display_for_whole_seconds_and_cron() {
        assert_eq!(
            Schedule::FixedDelay(Duration::from_secs(90)).to_string(),
            "every 90s"
        );
        let cron = Schedule::Cron {
            expression: "0 0 * * *".to_string(),
            timezone: Some("UTC".to_string()),
        };
        assert_eq!(cron.to_string(), "cron 0 0 * * *");
    }

    #[test]
    fn task_coordination_defaults_to_fleet_and_displays_snake_case() {
        assert_eq!(TaskCoordination::default(), TaskCoordination::Fleet);
        assert_eq!(TaskCoordination::Fleet.to_string(), "fleet");
        assert_eq!(TaskCoordination::PerReplica.to_string(), "per_replica");
    }

    #[test]
    fn list_one_off_tasks_sorts_by_name_and_keeps_descriptions() {
        let tasks = vec![
            OneOffTaskInfo {
                name: "zeta".to_string(),
                description: "Last task".to_string(),
                handler: noop_handler,
            },
            OneOffTaskInfo {
                name: "alpha".to_string(),
                description: "First task".to_string(),
                handler: noop_handler,
            },
        ];
        let listing = list_one_off_tasks(&tasks);
        assert_eq!(listing[0].name, "alpha");
        assert_eq!(listing[0].description, "First task");
        assert_eq!(listing[1].name, "zeta");
        assert_eq!(listing[1].description, "Last task");
    }

    #[test]
    fn validate_unique_one_off_task_names_rejects_duplicates() {
        let tasks = vec![
            OneOffTaskInfo {
                name: "cleanup".to_string(),
                description: String::new(),
                handler: noop_handler,
            },
            OneOffTaskInfo {
                name: "cleanup".to_string(),
                description: String::new(),
                handler: noop_handler,
            },
        ];
        let error = validate_unique_one_off_task_names(&tasks)
            .expect_err("duplicate task names should be rejected");
        assert!(error.contains("duplicate task name 'cleanup'"));
    }
}
#[cfg(test)]
mod havoc_proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // 15-30 digit numbers can exceed `u64` when parsed or when scaled by the
        // unit. `parse_duration` must handle them by returning (Some or None),
        // never by panicking; the returned value itself is not checked.
        #[test]
        fn parse_duration_fuzz_panic(input in "[0-9]{15,30}[smhd]") {
            let _outcome = parse_duration(&input);
        }
    }
}