use std::sync::Arc;
use sea_orm::{ConnectionTrait, DbBackend, FromQueryResult, Statement, TransactionTrait};
use tokio_util::sync::CancellationToken;
use tracing::warn;
use super::super::Outbox;
use super::super::dialect::{AllocSql, Dialect};
use super::super::prioritizer::SharedPrioritizer;
use super::super::taskward::{Directive, WorkerAction};
use super::super::types::{OutboxError, SequencerConfig};
use crate::Db;
use crate::contention::is_retryable_contention;
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SequencerReport {
pub partition_id: i64,
pub rows_claimed: u32,
}
pub struct Sequencer {
config: SequencerConfig,
outbox: Arc<Outbox>,
db: Db,
shared_prioritizer: Arc<SharedPrioritizer>,
}
#[derive(Debug, FromQueryResult)]
struct ClaimedIncoming {
id: i64,
body_id: i64,
}
impl Sequencer {
#[must_use]
pub fn new(
config: SequencerConfig,
outbox: Arc<Outbox>,
db: Db,
shared_prioritizer: Arc<SharedPrioritizer>,
) -> Self {
Self {
config,
outbox,
db,
shared_prioritizer,
}
}
async fn process_partition(
&self,
partition_id: i64,
) -> Result<PartitionProcessResult, PartitionError> {
let conn = self.db.sea_internal();
let backend = conn.get_database_backend();
let dialect = Dialect::from(backend);
let mut drained = true;
let mut total_claimed: u32 = 0;
for _iteration in 0..self.config.max_inner_iterations {
let txn = conn.begin().await?;
if let Some(lock_sql) = dialect.lock_partition() {
let locked = self
.try_lock_partition(&txn, backend, partition_id, lock_sql)
.await?;
if !locked {
drop(txn);
return Err(PartitionError::Skipped);
}
}
let claimed = self
.claim_incoming_for_partition(&txn, backend, &dialect, partition_id)
.await?;
if claimed.is_empty() {
drained = true;
drop(txn);
break;
}
#[allow(clippy::cast_possible_wrap)]
let item_count = claimed.len() as i64;
#[allow(clippy::cast_possible_truncation)]
let drained_this_iteration = (claimed.len() as u32) < self.config.batch_size;
let start_seq = self
.allocate_sequences(&txn, backend, &dialect, partition_id, item_count)
.await?;
let outgoing_sql = dialect.build_insert_outgoing_batch(claimed.len());
let mut values: Vec<sea_orm::Value> = Vec::with_capacity(claimed.len() * 3);
for (i, item) in claimed.iter().enumerate() {
#[allow(clippy::cast_possible_wrap)]
let seq = start_seq + 1 + i as i64;
values.push(partition_id.into());
values.push(item.body_id.into());
values.push(seq.into());
}
txn.execute(Statement::from_sql_and_values(
backend,
&outgoing_sql,
values,
))
.await?;
txn.commit().await?;
#[allow(clippy::cast_possible_truncation)]
{
total_claimed += claimed.len() as u32;
}
self.outbox.notify_partition(partition_id);
if drained_this_iteration {
drained = true;
break;
}
drained = false;
}
Ok(PartitionProcessResult {
drained,
rows_claimed: total_claimed,
})
}
}
struct PartitionProcessResult {
drained: bool,
rows_claimed: u32,
}
enum PartitionError {
Skipped,
Db(OutboxError),
}
impl From<sea_orm::DbErr> for PartitionError {
fn from(e: sea_orm::DbErr) -> Self {
Self::Db(OutboxError::Database(e))
}
}
impl From<OutboxError> for PartitionError {
fn from(e: OutboxError) -> Self {
Self::Db(e)
}
}
impl WorkerAction for Sequencer {
type Payload = SequencerReport;
type Error = OutboxError;
async fn execute(
&mut self,
_cancel: &CancellationToken,
) -> Result<Directive<SequencerReport>, OutboxError> {
let Some(guard) = self.shared_prioritizer.take() else {
return Ok(Directive::Idle(SequencerReport {
partition_id: -1,
rows_claimed: 0,
}));
};
let pid = guard.partition_id();
match self.process_partition(pid).await {
Ok(result) => {
let report = SequencerReport {
partition_id: pid,
rows_claimed: result.rows_claimed,
};
guard.processed();
if !result.drained {
self.shared_prioritizer.push_dirty(pid);
}
if result.rows_claimed == 0 {
Ok(Directive::Idle(report))
} else {
Ok(Directive::Proceed(report))
}
}
Err(PartitionError::Skipped) => {
guard.skipped();
Ok(Directive::Idle(SequencerReport {
partition_id: pid,
rows_claimed: 0,
}))
}
Err(PartitionError::Db(e)) => {
let backend = self.db.sea_internal().get_database_backend();
if matches!(&e, OutboxError::Database(db_err) if is_retryable_contention(backend, db_err))
{
tracing::debug!(
partition_id = pid,
"sequencer: lock contention, retrying partition"
);
guard.skipped();
self.shared_prioritizer.push_dirty(pid);
return Ok(Directive::Proceed(SequencerReport {
partition_id: pid,
rows_claimed: 0,
}));
}
warn!(partition_id = pid, error = %e, "sequencer partition error");
guard.error();
Ok(Directive::Proceed(SequencerReport {
partition_id: pid,
rows_claimed: 0,
}))
}
}
}
}
impl Sequencer {
async fn try_lock_partition(
&self,
txn: &impl ConnectionTrait,
backend: DbBackend,
partition_id: i64,
sql: &str,
) -> Result<bool, OutboxError> {
let row = txn
.query_one(Statement::from_sql_and_values(
backend,
sql,
[partition_id.into()],
))
.await?;
Ok(row.is_some())
}
async fn claim_incoming_for_partition(
&self,
txn: &impl ConnectionTrait,
backend: DbBackend,
dialect: &Dialect,
partition_id: i64,
) -> Result<Vec<ClaimedIncoming>, OutboxError> {
let claim = dialect.claim_incoming(self.config.batch_size);
let rows = ClaimedIncoming::find_by_statement(Statement::from_sql_and_values(
backend,
&claim.select,
[partition_id.into()],
))
.all(txn)
.await?;
if rows.is_empty() {
return Ok(rows);
}
let delete_sql = dialect.delete_incoming_batch(rows.len());
let values: Vec<sea_orm::Value> = rows.iter().map(|r| r.id.into()).collect();
txn.execute(Statement::from_sql_and_values(backend, &delete_sql, values))
.await?;
Ok(rows)
}
async fn allocate_sequences(
&self,
txn: &impl ConnectionTrait,
backend: DbBackend,
dialect: &Dialect,
partition_id: i64,
count: i64,
) -> Result<i64, OutboxError> {
match dialect.allocate_sequences() {
AllocSql::UpdateReturning(sql) => {
let row = txn
.query_one(Statement::from_sql_and_values(
backend,
sql,
[partition_id.into(), count.into()],
))
.await?
.ok_or_else(|| {
OutboxError::Database(sea_orm::DbErr::Custom(
"UPDATE RETURNING returned no row for sequence allocation".to_owned(),
))
})?;
let start_seq: i64 = row.try_get_by_index(0).map_err(|e| {
OutboxError::Database(sea_orm::DbErr::Custom(format!("start_seq column: {e}")))
})?;
Ok(start_seq)
}
AllocSql::UpdateThenSelect { update, select } => {
txn.execute(Statement::from_sql_and_values(
backend,
update,
[count.into(), partition_id.into()],
))
.await?;
let row = txn
.query_one(Statement::from_sql_and_values(
backend,
select,
[count.into(), partition_id.into()],
))
.await?
.ok_or_else(|| {
OutboxError::Database(sea_orm::DbErr::Custom(
"SELECT returned no row for sequence allocation".to_owned(),
))
})?;
let start_seq: i64 = row.try_get_by_index(0).map_err(|e| {
OutboxError::Database(sea_orm::DbErr::Custom(format!("start_seq column: {e}")))
})?;
Ok(start_seq)
}
}
}
}