use sayiir_core::codec::{self, Decoder, Encoder};
use sayiir_core::snapshot::{PauseRequest, SignalKind, SignalRequest, WorkflowSnapshot};
use sayiir_persistence::validation::validate_signal_allowed;
use sayiir_persistence::{BackendError, SignalStore};
use sqlx::Row;
use crate::backend::PostgresBackend;
use crate::error::PgError;
impl<C> SignalStore for PostgresBackend<C>
where
C: Encoder
+ Decoder
+ codec::sealed::EncodeValue<WorkflowSnapshot>
+ codec::sealed::DecodeValue<WorkflowSnapshot>,
{
#[tracing::instrument(
name = "db.store_signal",
skip(self, request),
fields(db.system = "postgresql", kind = %kind.as_ref()),
err(level = tracing::Level::ERROR),
)]
async fn store_signal(
&self,
instance_id: &str,
kind: SignalKind,
request: SignalRequest,
) -> Result<(), BackendError> {
tracing::debug!("storing signal");
let row =
sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
.bind(instance_id)
.fetch_optional(&self.pool)
.await
.map_err(PgError)?
.ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
let status: String = row.get("status");
validate_signal_allowed(&status, kind)?;
sqlx::query(
"INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
VALUES ($1, $2, $3, $4)
ON CONFLICT (instance_id, kind) DO UPDATE SET
reason = $3, requested_by = $4, created_at = now()",
)
.bind(instance_id)
.bind(kind.as_ref())
.bind(&request.reason)
.bind(&request.requested_by)
.execute(&self.pool)
.await
.map_err(PgError)?;
Ok(())
}
#[tracing::instrument(
name = "db.get_signal",
skip(self),
fields(db.system = "postgresql", kind = %kind.as_ref()),
err(level = tracing::Level::ERROR),
)]
async fn get_signal(
&self,
instance_id: &str,
kind: SignalKind,
) -> Result<Option<SignalRequest>, BackendError> {
tracing::debug!("getting signal");
let row = sqlx::query(
"SELECT reason, requested_by, created_at
FROM sayiir_workflow_signals
WHERE instance_id = $1 AND kind = $2",
)
.bind(instance_id)
.bind(kind.as_ref())
.fetch_optional(&self.pool)
.await
.map_err(PgError)?;
Ok(row.map(|r| SignalRequest {
reason: r.get("reason"),
requested_by: r.get("requested_by"),
requested_at: r.get("created_at"),
}))
}
#[tracing::instrument(
name = "db.clear_signal",
skip(self),
fields(db.system = "postgresql", kind = %kind.as_ref()),
err(level = tracing::Level::ERROR),
)]
async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
tracing::debug!("clearing signal");
sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
.bind(instance_id)
.bind(kind.as_ref())
.execute(&self.pool)
.await
.map_err(PgError)?;
Ok(())
}
#[tracing::instrument(
name = "db.send_event",
skip(self, payload),
fields(db.system = "postgresql"),
err(level = tracing::Level::ERROR),
)]
async fn send_event(
&self,
instance_id: &str,
signal_name: &str,
payload: bytes::Bytes,
) -> Result<(), BackendError> {
tracing::debug!("buffering external event");
sqlx::query(
"INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
VALUES ($1, $2, $3)",
)
.bind(instance_id)
.bind(signal_name)
.bind(payload.as_ref())
.execute(&self.pool)
.await
.map_err(PgError)?;
Ok(())
}
#[tracing::instrument(
name = "db.consume_event",
skip(self),
fields(db.system = "postgresql"),
err(level = tracing::Level::ERROR),
)]
async fn consume_event(
&self,
instance_id: &str,
signal_name: &str,
) -> Result<Option<bytes::Bytes>, BackendError> {
tracing::debug!("consuming oldest buffered event");
let row = sqlx::query(
"DELETE FROM sayiir_workflow_events
WHERE id = (
SELECT id FROM sayiir_workflow_events
WHERE instance_id = $1 AND signal_name = $2
ORDER BY id ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING payload",
)
.bind(instance_id)
.bind(signal_name)
.fetch_optional(&self.pool)
.await
.map_err(PgError)?;
Ok(row.map(|r| {
let raw: Vec<u8> = r.get("payload");
bytes::Bytes::from(raw)
}))
}
#[tracing::instrument(
name = "db.check_and_cancel",
skip(self),
fields(db.system = "postgresql"),
err(level = tracing::Level::ERROR),
)]
async fn check_and_cancel(
&self,
instance_id: &str,
interrupted_at_task: Option<&str>,
) -> Result<bool, BackendError> {
tracing::debug!("checking for cancel signal");
let mut tx = self.pool.begin().await.map_err(PgError)?;
let signal_row = sqlx::query(
"SELECT reason, requested_by
FROM sayiir_workflow_signals
WHERE instance_id = $1 AND kind = $2
FOR UPDATE",
)
.bind(instance_id)
.bind(SignalKind::Cancel.as_ref())
.fetch_optional(&mut *tx)
.await
.map_err(PgError)?;
let Some(signal_row) = signal_row else {
tx.rollback().await.map_err(PgError)?;
return Ok(false);
};
let snap_row = sqlx::query(
"SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
)
.bind(instance_id)
.fetch_one(&mut *tx)
.await
.map_err(PgError)?;
let raw: &[u8] = snap_row.get("data");
let mut snapshot = self.decode(raw)?;
if !snapshot.state.is_in_progress() {
tx.rollback().await.map_err(PgError)?;
return Ok(false);
}
let reason: Option<String> = signal_row.get("reason");
let requested_by: Option<String> = signal_row.get("requested_by");
snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
let data = self.encode(&snapshot)?;
let status = snapshot.state.as_ref();
let error = snapshot.error_message().map(ToString::to_string);
let pos_kind = snapshot.position_kind();
let wake_at = snapshot.delay_wake_at();
sqlx::query(
"UPDATE sayiir_workflow_snapshots
SET data = $1, status = $2, error = $3,
position_kind = $4, delay_wake_at = $5,
completed_at = now(), updated_at = now()
WHERE instance_id = $6",
)
.bind(&data)
.bind(status)
.bind(&error)
.bind(pos_kind)
.bind(wake_at)
.bind(instance_id)
.execute(&mut *tx)
.await
.map_err(PgError)?;
sqlx::query(
"UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
WHERE instance_id = $1 AND status = 'active'",
)
.bind(instance_id)
.execute(&mut *tx)
.await
.map_err(PgError)?;
sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
.bind(instance_id)
.bind(SignalKind::Cancel.as_ref())
.execute(&mut *tx)
.await
.map_err(PgError)?;
tx.commit().await.map_err(PgError)?;
tracing::info!(instance_id, "workflow cancelled");
Ok(true)
}
#[tracing::instrument(
name = "db.check_and_pause",
skip(self),
fields(db.system = "postgresql"),
err(level = tracing::Level::ERROR),
)]
async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
tracing::debug!("checking for pause signal");
let mut tx = self.pool.begin().await.map_err(PgError)?;
let signal_row = sqlx::query(
"SELECT reason, requested_by
FROM sayiir_workflow_signals
WHERE instance_id = $1 AND kind = $2
FOR UPDATE",
)
.bind(instance_id)
.bind(SignalKind::Pause.as_ref())
.fetch_optional(&mut *tx)
.await
.map_err(PgError)?;
let Some(signal_row) = signal_row else {
tx.rollback().await.map_err(PgError)?;
return Ok(false);
};
let snap_row = sqlx::query(
"SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
)
.bind(instance_id)
.fetch_one(&mut *tx)
.await
.map_err(PgError)?;
let raw: &[u8] = snap_row.get("data");
let mut snapshot = self.decode(raw)?;
if !snapshot.state.is_in_progress() {
tx.rollback().await.map_err(PgError)?;
return Ok(false);
}
let reason: Option<String> = signal_row.get("reason");
let requested_by: Option<String> = signal_row.get("requested_by");
let pause_request = PauseRequest::new(reason, requested_by);
snapshot.mark_paused(&pause_request);
let data = self.encode(&snapshot)?;
let status = snapshot.state.as_ref();
let task_id = snapshot.current_task_id().map(ToString::to_string);
let task_count = snapshot.completed_task_count();
let pos_kind = snapshot.position_kind();
let wake_at = snapshot.delay_wake_at();
sqlx::query(
"UPDATE sayiir_workflow_snapshots
SET data = $1, status = $2, current_task_id = $3,
completed_task_count = $4, position_kind = $5,
delay_wake_at = $6, updated_at = now()
WHERE instance_id = $7",
)
.bind(&data)
.bind(status)
.bind(&task_id)
.bind(task_count)
.bind(pos_kind)
.bind(wake_at)
.bind(instance_id)
.execute(&mut *tx)
.await
.map_err(PgError)?;
sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
.bind(instance_id)
.bind(SignalKind::Pause.as_ref())
.execute(&mut *tx)
.await
.map_err(PgError)?;
tx.commit().await.map_err(PgError)?;
tracing::info!(instance_id, "workflow paused");
Ok(true)
}
#[tracing::instrument(
name = "db.unpause",
skip(self),
fields(db.system = "postgresql"),
err(level = tracing::Level::ERROR),
)]
async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
tracing::debug!("unpausing workflow");
let mut tx = self.pool.begin().await.map_err(PgError)?;
let row = sqlx::query(
"SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
)
.bind(instance_id)
.fetch_optional(&mut *tx)
.await
.map_err(PgError)?
.ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
let raw: &[u8] = row.get("data");
let mut snapshot = self.decode(raw)?;
if !snapshot.state.is_paused() {
let state_name = snapshot.state.as_ref();
return Err(BackendError::CannotPause(format!(
"Workflow is not paused (current state: {state_name:?})"
)));
}
snapshot.mark_unpaused();
let data = self.encode(&snapshot)?;
let status = snapshot.state.as_ref();
let task_id = snapshot.current_task_id().map(ToString::to_string);
let task_count = snapshot.completed_task_count();
let pos_kind = snapshot.position_kind();
let wake_at = snapshot.delay_wake_at();
sqlx::query(
"UPDATE sayiir_workflow_snapshots
SET data = $1, status = $2, current_task_id = $3,
completed_task_count = $4, position_kind = $5,
delay_wake_at = $6, updated_at = now()
WHERE instance_id = $7",
)
.bind(&data)
.bind(status)
.bind(&task_id)
.bind(task_count)
.bind(pos_kind)
.bind(wake_at)
.bind(instance_id)
.execute(&mut *tx)
.await
.map_err(PgError)?;
tx.commit().await.map_err(PgError)?;
tracing::info!(instance_id, "workflow unpaused");
Ok(snapshot)
}
}