use std::collections::{HashMap, HashSet};
use apalis_sql::{DateTime, DateTimeExt};
use diesel::{
Connection, PgConnection, RunQueryDsl, sql_query,
sql_types::{Array, Binary, Integer, Nullable, Text, Timestamptz},
};
use ulid::Ulid;
use crate::{CompactType, Config, Error, PgPool, PgTask, queries::with_conn};
/// Largest accepted queue (job type) name, measured in bytes.
pub(crate) const MAX_QUEUE_NAME_LEN: usize = 255;
/// Largest accepted task metadata payload, measured in bytes of its JSON serialization.
pub(crate) const MAX_METADATA_PAYLOAD_LEN: usize = 8_192;
/// Largest accepted idempotency key, measured in bytes.
pub(crate) const MAX_IDEMPOTENCY_KEY_LEN: usize = 1_024;
/// One row produced by the `RETURNING idempotency_key` clause of the batch insert in
/// `push_tasks_on_conn`; each returned row corresponds to one inserted job.
///
/// `QueryableByName` resolves columns by field name, so the field below must keep the
/// name of the returned column.
#[derive(diesel::QueryableByName)]
struct ReturnedIdempotencyKey {
/// The inserted row's idempotency key; `None` when the task was submitted without one.
#[diesel(sql_type = Nullable<Text>)]
idempotency_key: Option<String>,
}
/// Pushes `tasks` onto the queue configured in `config`, using a connection from `pool`.
///
/// This is a thin wrapper: `with_conn` hands a connection to `push_tasks_on_conn`, which
/// performs all validation, insertion and idempotency-conflict handling. The returned
/// future is whatever `with_conn` produces for that closure.
pub(crate) fn push_tasks(
    pool: PgPool,
    config: Config,
    tasks: Vec<PgTask<CompactType>>,
) -> impl Future<Output = Result<(), Error>> + Send {
    with_conn(pool, move |conn| {
        // `config` is owned by the closure; the worker only needs to borrow it.
        let config = &config;
        push_tasks_on_conn(conn, config, tasks)
    })
}
pub(crate) fn push_tasks_on_conn(
conn: &mut PgConnection,
config: &Config,
tasks: Vec<PgTask<CompactType>>,
) -> Result<(), Error> {
if tasks.is_empty() {
return Ok(());
}
let job_type = config.queue().to_string();
if job_type.len() > MAX_QUEUE_NAME_LEN {
return Err(Error::InvalidArgument(format!(
"queue name is {} bytes, exceeds the {MAX_QUEUE_NAME_LEN}-byte cap",
job_type.len(),
)));
}
let mut ids = Vec::with_capacity(tasks.len());
let mut jobs = Vec::with_capacity(tasks.len());
let mut max_attempts = Vec::with_capacity(tasks.len());
let mut run_ats = Vec::with_capacity(tasks.len());
let mut priorities = Vec::with_capacity(tasks.len());
let mut metadata = Vec::with_capacity(tasks.len());
let mut idempotency_keys = Vec::with_capacity(tasks.len());
for task in tasks {
ids.push(
task.parts
.task_id
.map(|task_id| task_id.to_string())
.unwrap_or_else(|| Ulid::new().to_string()),
);
jobs.push(task.args);
max_attempts.push(task.parts.ctx.max_attempts());
let run_at_secs = i64::try_from(task.parts.run_at).map_err(|_| {
Error::InvalidArgument(format!(
"run_at {} exceeds i64::MAX seconds and cannot be stored",
task.parts.run_at
))
})?;
run_ats.push(<DateTime as DateTimeExt>::from_unix_timestamp(run_at_secs));
priorities.push(task.parts.ctx.priority());
let meta_json = serde_json::to_string(task.parts.ctx.meta())
.map_err(|err| Error::InvalidArgument(format!("serializing task metadata: {err}")))?;
if meta_json.len() > MAX_METADATA_PAYLOAD_LEN {
return Err(Error::InvalidArgument(format!(
"task metadata is {} bytes, exceeds the {MAX_METADATA_PAYLOAD_LEN}-byte cap",
meta_json.len(),
)));
}
metadata.push(meta_json);
let idempotency_key = task.parts.idempotency_key;
if let Some(key) = idempotency_key.as_deref()
&& key.len() > MAX_IDEMPOTENCY_KEY_LEN
{
return Err(Error::InvalidArgument(format!(
"idempotency_key is {} bytes, exceeds the {MAX_IDEMPOTENCY_KEY_LEN}-byte cap",
key.len(),
)));
}
idempotency_keys.push(idempotency_key);
}
let task_count = ids.len();
let any_idempotency_key = idempotency_keys.iter().any(Option::is_some);
let conflict_job_type = job_type.clone();
let submitted_keys: Vec<String> = idempotency_keys.iter().flatten().cloned().collect();
conn.transaction(|conn| {
let inserted_rows = sql_query(
"INSERT INTO apalis.jobs (
id,
job_type,
job,
status,
attempts,
max_attempts,
run_at,
priority,
metadata,
idempotency_key
)
SELECT
unnest($1::text[]) AS id,
$2::text AS job_type,
unnest($3::bytea[]) AS job,
'Pending' AS status,
0 AS attempts,
unnest($4::integer[]) AS max_attempts,
unnest($5::timestamptz[]) AS run_at,
unnest($6::integer[]) AS priority,
unnest($7::text[])::jsonb AS metadata,
unnest($8::text[]) AS idempotency_key
ON CONFLICT (job_type, idempotency_key)
WHERE idempotency_key IS NOT NULL
DO NOTHING
RETURNING idempotency_key",
)
.bind::<Array<Text>, _>(ids)
.bind::<Text, _>(job_type)
.bind::<Array<Binary>, _>(jobs)
.bind::<Array<Integer>, _>(max_attempts)
.bind::<Array<Timestamptz>, _>(run_ats)
.bind::<Array<Integer>, _>(priorities)
.bind::<Array<Text>, _>(metadata)
.bind::<Array<Nullable<Text>>, _>(idempotency_keys)
.load::<ReturnedIdempotencyKey>(conn)
.map_err(Error::database("inserting jobs"))?;
let inserted = inserted_rows.len();
if inserted < task_count && any_idempotency_key {
let mut inserted_remaining: HashMap<&str, usize> = HashMap::new();
for row in &inserted_rows {
if let Some(key) = row.idempotency_key.as_deref() {
*inserted_remaining.entry(key).or_insert(0) += 1;
}
}
let mut seen: HashSet<&str> = HashSet::new();
let mut conflicting_keys: Vec<String> = Vec::new();
for key in &submitted_keys {
let inserted_here = inserted_remaining
.get_mut(key.as_str())
.is_some_and(|count| {
if *count > 0 {
*count -= 1;
true
} else {
false
}
});
if !inserted_here && seen.insert(key.as_str()) {
conflicting_keys.push(key.clone());
}
}
return Err(Error::idempotency_conflict(
conflict_job_type,
conflicting_keys,
task_count,
));
}
Ok(())
})
}