use crate::error::AwaError;
use crate::job::{InsertOpts, InsertParams, JobRow, JobState};
use crate::unique::compute_unique_key;
use crate::JobArgs;
use sqlx::postgres::PgConnection;
use sqlx::{PgExecutor, PgPool};
/// Enqueue a single job with the default [`InsertOpts`].
///
/// Convenience shorthand for [`insert_with`] using `InsertOpts::default()`.
pub async fn insert<'e, E>(executor: E, args: &impl JobArgs) -> Result<JobRow, AwaError>
where
    E: PgExecutor<'e>,
{
    let opts = InsertOpts::default();
    insert_with(executor, args, opts).await
}
/// Enqueue a single job with explicit [`InsertOpts`].
///
/// Serializes `args`, derives the initial state and (optionally) a unique
/// key / unique-state bitmask, then inserts the row and returns it. A
/// database unique violation (SQLSTATE 23505) is surfaced as
/// [`AwaError::UniqueConflict`]; all other database errors pass through as
/// [`AwaError::Database`].
#[tracing::instrument(skip(executor, args), fields(job.kind = args.kind_str(), job.queue = %opts.queue))]
pub async fn insert_with<'e, E>(
    executor: E,
    args: &impl JobArgs,
    opts: InsertOpts,
) -> Result<JobRow, AwaError>
where
    E: PgExecutor<'e>,
{
    let kind = args.kind_str();
    let args_json = args.to_args()?;

    // A job with an explicit run_at starts out scheduled; otherwise it is
    // immediately available.
    let state = match opts.run_at {
        Some(_) => JobState::Scheduled,
        None => JobState::Available,
    };

    let unique_opts = opts.unique.as_ref();

    // Derive the dedup key from whichever dimensions the unique policy selects.
    let unique_key = unique_opts.map(|u| {
        let queue_part = if u.by_queue { Some(opts.queue.as_str()) } else { None };
        let args_part = if u.by_args { Some(&args_json) } else { None };
        compute_unique_key(kind, queue_part, args_part, u.by_period)
    });

    // Render the states bitmask as an eight-character '0'/'1' string for the
    // bit(8) column; bit i of the mask maps to character i.
    let unique_states_bits: Option<String> = unique_opts.map(|u| {
        (0..8)
            .map(|bit| if u.states & (1 << bit) != 0 { '1' } else { '0' })
            .collect()
    });

    let inserted = sqlx::query_as::<_, JobRow>(
        r#"
INSERT INTO awa.jobs (kind, queue, args, state, priority, max_attempts, run_at, metadata, tags, unique_key, unique_states)
VALUES ($1, $2, $3, $4, $5, $6, COALESCE($7, now()), $8, $9, $10, $11::bit(8))
RETURNING *
"#,
    )
    .bind(kind)
    .bind(&opts.queue)
    .bind(&args_json)
    .bind(state)
    .bind(opts.priority)
    .bind(opts.max_attempts)
    .bind(opts.run_at)
    .bind(&opts.metadata)
    .bind(&opts.tags)
    .bind(&unique_key)
    .bind(&unique_states_bits)
    .fetch_one(executor)
    .await
    .map_err(|err| {
        // Translate unique violations into the dedicated conflict error.
        if let sqlx::Error::Database(ref db_err) = err {
            if db_err.code().as_deref() == Some("23505") {
                return AwaError::UniqueConflict {
                    constraint: db_err.constraint().map(str::to_string),
                };
            }
        }
        AwaError::Database(err)
    })?;
    Ok(inserted)
}
/// Precomputed, fully-owned column values for one job row, ready to be bound
/// to a multi-row INSERT (`insert_many`) or serialized into the COPY staging
/// CSV (`insert_many_copy`).
struct RowValues {
    // Job kind discriminator (from JobArgs::kind_str via InsertParams).
    kind: String,
    // Destination queue name.
    queue: String,
    // JSON-serialized job arguments.
    args: serde_json::Value,
    // Initial state: Scheduled when run_at was set, Available otherwise.
    state: JobState,
    priority: i16,
    max_attempts: i16,
    // Explicit start time; None means the SQL substitutes now() via COALESCE.
    run_at: Option<chrono::DateTime<chrono::Utc>>,
    metadata: serde_json::Value,
    tags: Vec<String>,
    // Dedup key bytes from compute_unique_key, when a unique policy is set.
    unique_key: Option<Vec<u8>>,
    // Eight-char '0'/'1' string destined for the bit(8) unique_states column.
    unique_states: Option<String>,
}
/// Expand each [`InsertParams`] into the concrete, owned column values that
/// the bulk insert paths bind or serialize.
fn precompute_row_values(jobs: &[InsertParams]) -> Vec<RowValues> {
    let mut rows = Vec::with_capacity(jobs.len());
    for job in jobs {
        let unique = job.opts.unique.as_ref();

        // Dedup key built from whichever dimensions the unique policy selects.
        let unique_key = unique.map(|u| {
            let queue_part = if u.by_queue {
                Some(job.opts.queue.as_str())
            } else {
                None
            };
            let args_part = if u.by_args { Some(&job.args) } else { None };
            compute_unique_key(&job.kind, queue_part, args_part, u.by_period)
        });

        // Render the states bitmask as an eight-character '0'/'1' string for
        // the bit(8) column; bit i of the mask maps to character i.
        let unique_states = unique.map(|u| {
            (0..8)
                .map(|bit| if u.states & (1 << bit) != 0 { '1' } else { '0' })
                .collect::<String>()
        });

        // An explicit run_at means the job starts out scheduled.
        let state = match job.opts.run_at {
            Some(_) => JobState::Scheduled,
            None => JobState::Available,
        };

        rows.push(RowValues {
            kind: job.kind.clone(),
            queue: job.opts.queue.clone(),
            args: job.args.clone(),
            state,
            priority: job.opts.priority,
            max_attempts: job.opts.max_attempts,
            run_at: job.opts.run_at,
            metadata: job.opts.metadata.clone(),
            tags: job.opts.tags.clone(),
            unique_key,
            unique_states,
        });
    }
    rows
}
#[tracing::instrument(skip(executor, jobs), fields(job.count = jobs.len()))]
pub async fn insert_many<'e, E>(executor: E, jobs: &[InsertParams]) -> Result<Vec<JobRow>, AwaError>
where
E: PgExecutor<'e>,
{
if jobs.is_empty() {
return Ok(Vec::new());
}
let count = jobs.len();
let rows = precompute_row_values(jobs);
let mut query = String::from(
"INSERT INTO awa.jobs (kind, queue, args, state, priority, max_attempts, run_at, metadata, tags, unique_key, unique_states) VALUES ",
);
let params_per_row = 11u32;
let mut param_index = 1u32;
for i in 0..count {
if i > 0 {
query.push_str(", ");
}
query.push_str(&format!(
"(${}, ${}, ${}, ${}, ${}, ${}, COALESCE(${}, now()), ${}, ${}, ${}, ${}::bit(8))",
param_index,
param_index + 1,
param_index + 2,
param_index + 3,
param_index + 4,
param_index + 5,
param_index + 6,
param_index + 7,
param_index + 8,
param_index + 9,
param_index + 10,
));
param_index += params_per_row;
}
query.push_str(" RETURNING *");
let mut sql_query = sqlx::query_as::<_, JobRow>(&query);
for row in &rows {
sql_query = sql_query
.bind(&row.kind)
.bind(&row.queue)
.bind(&row.args)
.bind(row.state)
.bind(row.priority)
.bind(row.max_attempts)
.bind(row.run_at)
.bind(&row.metadata)
.bind(&row.tags)
.bind(&row.unique_key)
.bind(&row.unique_states);
}
let results = sql_query.fetch_all(executor).await?;
Ok(results)
}
/// Bulk insert via Postgres COPY: rows are serialized to CSV, streamed into a
/// TEXT-typed TEMP staging table, then moved into awa.jobs — in a single
/// INSERT..SELECT when no unique policy is involved, or row-by-row under
/// savepoints (skipping 23505 conflicts) when one is.
///
/// NOTE(review): SAVEPOINT and ON COMMIT DROP are transaction-scoped in
/// Postgres, so this appears to require `conn` to be inside an open
/// transaction — `insert_many_copy_from_pool` provides that wrapper; confirm
/// for direct callers.
#[tracing::instrument(skip(conn, jobs), fields(job.count = jobs.len()))]
pub async fn insert_many_copy(
    conn: &mut PgConnection,
    jobs: &[InsertParams],
) -> Result<Vec<JobRow>, AwaError> {
    if jobs.is_empty() {
        return Ok(Vec::new());
    }
    let rows = precompute_row_values(jobs);
    // Staging columns are TEXT where the real type can't round-trip through
    // CSV directly (run_at, unique_key, unique_states, tags); casts happen
    // when rows leave the staging table. NULLs are spelled '\N' per the COPY
    // options below.
    sqlx::query(
        r#"
CREATE TEMP TABLE awa_copy_staging (
kind TEXT NOT NULL,
queue TEXT NOT NULL,
args JSONB NOT NULL,
state TEXT NOT NULL,
priority SMALLINT NOT NULL,
max_attempts SMALLINT NOT NULL,
run_at TEXT,
metadata JSONB NOT NULL,
tags TEXT NOT NULL,
unique_key TEXT,
unique_states TEXT
) ON COMMIT DROP
"#,
    )
    .execute(&mut *conn)
    .await?;
    // Serialize every row into one CSV buffer and stream it with COPY.
    let mut csv_buf = Vec::with_capacity(rows.len() * 256);
    for row in &rows {
        write_csv_row(&mut csv_buf, row);
    }
    let mut copy_in = conn
        .copy_in_raw(
            "COPY awa_copy_staging (kind, queue, args, state, priority, max_attempts, run_at, metadata, tags, unique_key, unique_states) FROM STDIN WITH (FORMAT csv, NULL '\\N')",
        )
        .await?;
    copy_in.send(csv_buf).await?;
    copy_in.finish().await?;
    let has_unique = rows.iter().any(|r| r.unique_key.is_some());
    let results = if has_unique {
        // Unique keys can collide with existing jobs (SQLSTATE 23505). Read
        // the staged rows back with their real types, then insert them one at
        // a time under a savepoint so a conflicting row can be skipped
        // without aborting the surrounding transaction.
        let staged_rows = sqlx::query_as::<
            _,
            (
                String,
                String,
                serde_json::Value,
                String,
                i16,
                i16,
                Option<chrono::DateTime<chrono::Utc>>,
                serde_json::Value,
                Vec<String>,
                Option<Vec<u8>>,
                Option<String>,
            ),
        >(
            r#"
SELECT
kind,
queue,
args,
state,
priority,
max_attempts,
CASE WHEN run_at = '\N' OR run_at IS NULL THEN NULL ELSE run_at::timestamptz END,
metadata,
tags::text[],
CASE WHEN unique_key = '\N' OR unique_key IS NULL THEN NULL ELSE decode(unique_key, 'hex') END,
unique_states
FROM awa_copy_staging
"#,
        )
        .fetch_all(&mut *conn)
        .await?;
        let mut inserted = Vec::with_capacity(staged_rows.len());
        for (
            kind,
            queue,
            args,
            state,
            priority,
            max_attempts,
            run_at,
            metadata,
            tags,
            unique_key,
            unique_states,
        ) in staged_rows
        {
            sqlx::query("SAVEPOINT awa_copy_unique_row")
                .execute(&mut *conn)
                .await?;
            let result = sqlx::query_as::<_, JobRow>(
                r#"
INSERT INTO awa.jobs (kind, queue, args, state, priority, max_attempts, run_at, metadata, tags, unique_key, unique_states)
VALUES ($1, $2, $3, $4::awa.job_state, $5, $6, COALESCE($7, now()), $8, $9, $10, $11::bit(8))
RETURNING *
"#,
            )
            .bind(&kind)
            .bind(&queue)
            .bind(&args)
            .bind(&state)
            .bind(priority)
            .bind(max_attempts)
            .bind(run_at)
            .bind(&metadata)
            .bind(&tags)
            .bind(&unique_key)
            .bind(&unique_states)
            .fetch_one(&mut *conn)
            .await;
            match result {
                Ok(row) => {
                    // Inserted cleanly: keep the row, discard the savepoint.
                    inserted.push(row);
                    sqlx::query("RELEASE SAVEPOINT awa_copy_unique_row")
                        .execute(&mut *conn)
                        .await?;
                }
                Err(sqlx::Error::Database(db_err)) if db_err.code().as_deref() == Some("23505") => {
                    // Unique-key conflict: undo just this row and move on.
                    sqlx::query("ROLLBACK TO SAVEPOINT awa_copy_unique_row")
                        .execute(&mut *conn)
                        .await?;
                    sqlx::query("RELEASE SAVEPOINT awa_copy_unique_row")
                        .execute(&mut *conn)
                        .await?;
                    continue;
                }
                Err(err) => {
                    // Any other failure: unwind the savepoint, then surface
                    // the error to the caller.
                    sqlx::query("ROLLBACK TO SAVEPOINT awa_copy_unique_row")
                        .execute(&mut *conn)
                        .await?;
                    sqlx::query("RELEASE SAVEPOINT awa_copy_unique_row")
                        .execute(&mut *conn)
                        .await?;
                    return Err(AwaError::Database(err));
                }
            }
        }
        inserted
    } else {
        // No unique policies: one INSERT .. SELECT moves every staged row
        // into awa.jobs at once.
        let insert_sql = r#"
INSERT INTO awa.jobs (kind, queue, args, state, priority, max_attempts, run_at, metadata, tags, unique_key, unique_states)
SELECT
s.kind,
s.queue,
s.args,
s.state::awa.job_state,
s.priority,
s.max_attempts,
CASE WHEN s.run_at = '\N' OR s.run_at IS NULL THEN now() ELSE s.run_at::timestamptz END,
s.metadata,
s.tags::text[],
CASE WHEN s.unique_key = '\N' OR s.unique_key IS NULL THEN NULL ELSE decode(s.unique_key, 'hex') END,
CASE WHEN s.unique_states = '\N' OR s.unique_states IS NULL THEN NULL ELSE s.unique_states::bit(8) END
FROM awa_copy_staging s
RETURNING *
"#;
        sqlx::query_as::<_, JobRow>(insert_sql)
            .fetch_all(&mut *conn)
            .await?
    };
    Ok(results)
}
/// COPY-based bulk insert that manages its own transaction.
///
/// Begins a transaction on `pool`, delegates to [`insert_many_copy`] inside
/// it, and commits on success.
#[tracing::instrument(skip(pool, jobs), fields(job.count = jobs.len()))]
pub async fn insert_many_copy_from_pool(
    pool: &PgPool,
    jobs: &[InsertParams],
) -> Result<Vec<JobRow>, AwaError> {
    if jobs.is_empty() {
        return Ok(Vec::new());
    }
    let mut tx = pool.begin().await?;
    let inserted = insert_many_copy(&mut tx, jobs).await?;
    tx.commit().await?;
    Ok(inserted)
}
/// Serialize one staging row as a single CSV line (Postgres COPY csv format,
/// NULL spelled `\N`) and append it to `buf`. Column order matches the COPY
/// column list used by `insert_many_copy`.
fn write_csv_row(buf: &mut Vec<u8>, row: &RowValues) {
    const NULL_MARKER: &[u8] = b"\\N";

    let args_json = serde_json::to_string(&row.args).expect("JSON serialization should not fail");
    let metadata_json =
        serde_json::to_string(&row.metadata).expect("JSON serialization should not fail");

    write_csv_field(buf, &row.kind);
    buf.push(b',');
    write_csv_field(buf, &row.queue);
    buf.push(b',');
    write_csv_field(buf, &args_json);
    buf.push(b',');
    write_csv_field(buf, &row.state.to_string());
    buf.push(b',');
    buf.extend_from_slice(row.priority.to_string().as_bytes());
    buf.push(b',');
    buf.extend_from_slice(row.max_attempts.to_string().as_bytes());
    buf.push(b',');
    if let Some(dt) = &row.run_at {
        write_csv_field(buf, &dt.to_rfc3339());
    } else {
        buf.extend_from_slice(NULL_MARKER);
    }
    buf.push(b',');
    write_csv_field(buf, &metadata_json);
    buf.push(b',');
    write_pg_text_array(buf, &row.tags);
    buf.push(b',');
    if let Some(key) = &row.unique_key {
        // Hex-encode the raw key bytes; the SQL side restores them via decode(.., 'hex').
        write_csv_field(buf, &hex::encode(key));
    } else {
        buf.extend_from_slice(NULL_MARKER);
    }
    buf.push(b',');
    if let Some(bits) = &row.unique_states {
        write_csv_field(buf, bits);
    } else {
        buf.extend_from_slice(NULL_MARKER);
    }
    buf.push(b'\n');
}
/// Append `value` to `buf` as one CSV field, quoting it (and doubling any
/// embedded `"`) only when it contains a character that would otherwise
/// confuse the CSV parser or collide with the `\N` NULL marker.
fn write_csv_field(buf: &mut Vec<u8>, value: &str) {
    let needs_quoting = value
        .bytes()
        .any(|b| matches!(b, b',' | b'"' | b'\n' | b'\r' | b'\\'));
    if !needs_quoting {
        buf.extend_from_slice(value.as_bytes());
        return;
    }
    buf.push(b'"');
    for byte in value.bytes() {
        // A literal double quote inside a quoted field is doubled (RFC 4180).
        if byte == b'"' {
            buf.push(b'"');
        }
        buf.push(byte);
    }
    buf.push(b'"');
}
/// Append `values` as a Postgres `text[]` literal wrapped in a CSV-quoted
/// field: the outer `"` pair is CSV quoting, embedded `"` bytes are doubled
/// per CSV, and `\"` / `\\` escape quotes and backslashes inside quoted
/// array elements.
fn write_pg_text_array(buf: &mut Vec<u8>, values: &[String]) {
    buf.push(b'"'); // open the CSV field
    buf.push(b'{');
    for (i, val) in values.iter().enumerate() {
        if i > 0 {
            buf.push(b',');
        }
        // An element must be quoted in the array literal when it is empty,
        // spells NULL, or contains a character significant to the array
        // parser. Fix: the original only checked ' ', but any whitespace
        // (tab, newline, ...) around/inside an unquoted element is
        // significant to Postgres' array input parsing — quote on all of it.
        let needs_quoting = val.is_empty()
            || val.eq_ignore_ascii_case("NULL")
            || val
                .chars()
                .any(|c| c.is_whitespace() || matches!(c, ',' | '"' | '\\' | '{' | '}'));
        if needs_quoting {
            // `""` is one literal `"` after CSV unquoting: the element's
            // opening quote in the array literal.
            buf.extend_from_slice(b"\"\"");
            for ch in val.chars() {
                match ch {
                    // `\"` in the array literal, with the `"` CSV-doubled.
                    '"' => buf.extend_from_slice(b"\\\"\""),
                    // `\\` escapes the backslash for the array parser.
                    '\\' => buf.extend_from_slice(b"\\\\"),
                    _ => {
                        let mut utf8_buf = [0u8; 4];
                        buf.extend_from_slice(ch.encode_utf8(&mut utf8_buf).as_bytes());
                    }
                }
            }
            // Element's closing quote, CSV-doubled.
            buf.extend_from_slice(b"\"\"");
        } else {
            buf.extend_from_slice(val.as_bytes());
        }
    }
    buf.push(b'}');
    buf.push(b'"'); // close the CSV field
}
/// Build [`InsertParams`] for `args` using the default [`InsertOpts`].
pub fn params(args: &impl JobArgs) -> Result<InsertParams, AwaError> {
    params_with(args, Default::default())
}
/// Build [`InsertParams`] for `args` with explicit options.
///
/// Serializes the arguments up front, so a bad payload fails here rather
/// than at insert time.
pub fn params_with(args: &impl JobArgs, opts: InsertOpts) -> Result<InsertParams, AwaError> {
    let kind = args.kind_str().to_string();
    let args = args.to_args()?;
    Ok(InsertParams { kind, args, opts })
}