use anyhow::{Context, Result};
use chrono::{TimeZone, Utc};
use duroxide::providers::{
DeleteInstanceResult, DispatcherCapabilityFilter, ExecutionInfo, ExecutionMetadata,
InstanceFilter, InstanceInfo, OrchestrationItem, Provider, ProviderAdmin, ProviderError,
PruneOptions, PruneResult, QueueDepths, ScheduledActivityIdentifier, SessionFetchConfig,
SystemMetrics, TagFilter, WorkItem,
};
use duroxide::{Event, EventKind, SystemStats};
use sqlx::postgres::{PgConnectOptions, PgSslMode};
use sqlx::{postgres::PgPoolOptions, Error as SqlxError, PgPool};
use std::sync::Arc;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::task::AbortHandle;
use tokio::time::sleep;
use tracing::{debug, error, instrument, warn};
use crate::entra::{EntraAuthOptions, TokenSource};
use crate::migrations::MigrationRunner;
/// Coarse classification of a PostgreSQL SQLSTATE code, used to decide whether a
/// failed database call is surfaced as a retryable or a permanent [`ProviderError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SqlStateClass {
/// Retrying the operation may succeed (see `classify_pg_sqlstate`).
Retryable,
/// Retrying is not expected to help; the error is reported as permanent.
Permanent,
}
/// Classifies an optional PostgreSQL SQLSTATE `code` as retryable or permanent.
///
/// Retryable:
/// - `40P01` (deadlock detected)
/// - `0A000` (reported by this provider as a cached-plan invalidation)
/// - `28000` / `28P01` (authentication failures), but only when `is_entra` is true,
///   where they are expected while access tokens rotate
///
/// Everything else is [`SqlStateClass::Permanent`]. That includes a missing code,
/// `23505` (unique violation), `23503` (foreign-key violation) and `40001`
/// (serialization failure).
///
/// NOTE(review): PostgreSQL documents `40001` as safe to retry, but it is
/// classified as permanent here. Confirm that this is intended.
pub(crate) fn classify_pg_sqlstate(code: Option<&str>, is_entra: bool) -> SqlStateClass {
    let retryable = match code {
        Some("40P01" | "0A000") => true,
        Some("28000" | "28P01") => is_entra,
        _ => false,
    };
    if retryable {
        SqlStateClass::Retryable
    } else {
        SqlStateClass::Permanent
    }
}
/// PostgreSQL-backed duroxide [`Provider`].
///
/// Most operations call stored procedures that live in `schema_name`, for example
/// `<schema>.fetch_orchestration_item`.
pub struct PostgresProvider {
/// Shared connection pool. For Entra connections it is also cloned into the
/// token-refresh task.
pool: Arc<PgPool>,
/// Schema holding the provider's objects. Defaults to `public` when none is configured.
schema_name: String,
/// True when connections authenticate with Entra access tokens. This makes auth
/// SQLSTATEs (`28000`/`28P01`) retryable in `sqlx_to_provider_error`.
is_entra: bool,
/// Aborts the background token-refresh task when the provider is dropped.
/// `None` for URL-based connections.
_refresh_task: Option<AbortOnDropHandle>,
}
/// What a constructor does with schema migrations once the pool is connected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MigrationPolicy {
/// Run `MigrationRunner::migrate`. This is the default.
#[default]
ApplyAll,
/// Run `MigrationRunner::verify` instead of applying migrations.
VerifyOnly,
}
/// Configuration accepted by [`PostgresProvider::new_with_config`].
///
/// Build one with [`ProviderConfig::url`] or [`ProviderConfig::entra`], then adjust
/// the public fields. The struct is `#[non_exhaustive]`, so code outside this crate
/// cannot build it with a struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ProviderConfig {
/// How to connect to PostgreSQL.
pub connection: ConnectionConfig,
/// Schema to use (`public` when `None`). It must match `[A-Za-z_][A-Za-z0-9_]*`.
pub schema_name: Option<String>,
/// Whether construction applies migrations or only verifies them.
pub migration_policy: MigrationPolicy,
}
impl ProviderConfig {
pub fn url(database_url: impl Into<String>) -> Self {
Self {
connection: ConnectionConfig::Url(database_url.into()),
schema_name: None,
migration_policy: MigrationPolicy::default(),
}
}
pub fn entra(
host: impl Into<String>,
port: u16,
database: impl Into<String>,
user: impl Into<String>,
options: EntraAuthOptions,
) -> Self {
Self {
connection: ConnectionConfig::Entra {
host: host.into(),
port,
database: database.into(),
user: user.into(),
options,
},
schema_name: None,
migration_policy: MigrationPolicy::default(),
}
}
}
/// How [`PostgresProvider`] connects to PostgreSQL.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ConnectionConfig {
/// A PostgreSQL connection URL handed to sqlx as-is.
Url(String),
/// Entra token authentication. The password for `user` is an access token from
/// the credential described by `options`, refreshed by a background task.
/// Through `new_with_config`, connections use `PgSslMode::VerifyFull`.
Entra {
/// Server host name.
host: String,
/// Server port.
port: u16,
/// Database name.
database: String,
/// Database role that authenticates with the token.
user: String,
/// Token acquisition and pool settings.
options: EntraAuthOptions,
},
}
/// Rejects schema names that are not plain ASCII identifiers matching
/// `[A-Za-z_][A-Za-z0-9_]*`. The empty string is rejected.
///
/// The schema name is interpolated unquoted into SQL text (e.g.
/// `SELECT <schema>.cleanup_schema()`), so this check keeps it from carrying SQL.
///
/// NOTE(review): PostgreSQL lower-cases unquoted identifiers and truncates them to
/// 63 bytes. Neither is checked here.
fn validate_schema_name(schema_name: &str) -> Result<()> {
    let mut chars = schema_name.chars();
    let head_ok = matches!(chars.next(), Some(c) if c == '_' || c.is_ascii_alphabetic());
    let tail_ok = chars.all(|c| c == '_' || c.is_ascii_alphanumeric());
    if head_ok && tail_ok {
        Ok(())
    } else {
        anyhow::bail!("Invalid schema_name '{schema_name}': must match [A-Za-z_][A-Za-z0-9_]*")
    }
}
/// Guard that owns a Tokio [`AbortHandle`] and aborts the associated task when the
/// guard is dropped. This ties a background task's lifetime to its owner.
struct AbortOnDropHandle(AbortHandle);

impl Drop for AbortOnDropHandle {
    fn drop(&mut self) {
        let Self(task) = self;
        task.abort();
    }
}
impl PostgresProvider {
pub async fn new(database_url: &str) -> Result<Self> {
Self::new_with_config(ProviderConfig::url(database_url)).await
}
pub async fn new_with_schema(database_url: &str, schema_name: Option<&str>) -> Result<Self> {
let mut config = ProviderConfig::url(database_url);
config.schema_name = schema_name.map(str::to_string);
Self::new_with_config(config).await
}
pub async fn new_with_config(config: ProviderConfig) -> Result<Self> {
let ProviderConfig {
connection,
schema_name,
migration_policy,
} = config;
if let Some(ref s) = schema_name {
validate_schema_name(s)?;
}
match connection {
ConnectionConfig::Url(database_url) => {
Self::new_from_url(&database_url, schema_name.as_deref(), migration_policy).await
}
ConnectionConfig::Entra {
host,
port,
database,
user,
options,
} => {
let token_source = options.default_token_source().context(
"Entra credential resolution failed: could not build the default credential chain",
)?;
Self::new_with_entra_with_token_source(
&host,
port,
&database,
&user,
schema_name.as_deref(),
options,
token_source,
PgSslMode::VerifyFull,
migration_policy,
)
.await
}
}
}
async fn new_from_url(
database_url: &str,
schema_name: Option<&str>,
migration_policy: MigrationPolicy,
) -> Result<Self> {
let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
.ok()
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(10);
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.min_connections(1)
.acquire_timeout(std::time::Duration::from_secs(30))
.connect(database_url)
.await?;
let schema_name = schema_name.unwrap_or("public").to_string();
let provider = Self {
pool: Arc::new(pool),
schema_name: schema_name.clone(),
is_entra: false,
_refresh_task: None,
};
let migration_runner = MigrationRunner::new(provider.pool.clone(), schema_name);
match migration_policy {
MigrationPolicy::ApplyAll => migration_runner.migrate().await?,
MigrationPolicy::VerifyOnly => migration_runner.verify().await?,
}
Ok(provider)
}
#[deprecated(
since = "0.1.34",
note = "use `PostgresProvider::new_with_config(ProviderConfig::entra(...))` instead"
)]
pub async fn new_with_entra(
host: &str,
port: u16,
database: &str,
user: &str,
options: EntraAuthOptions,
) -> Result<Self> {
Self::new_with_config(ProviderConfig::entra(host, port, database, user, options)).await
}
#[deprecated(
since = "0.1.34",
note = "use `PostgresProvider::new_with_config(ProviderConfig::entra(...))` with `schema_name` set instead"
)]
#[instrument(
skip(options),
fields(host = %host, port = %port, database = %database, user = %user, schema = ?schema_name),
target = "duroxide::providers::postgres",
)]
pub async fn new_with_schema_and_entra(
host: &str,
port: u16,
database: &str,
user: &str,
schema_name: Option<&str>,
options: EntraAuthOptions,
) -> Result<Self> {
let mut config = ProviderConfig::entra(host, port, database, user, options);
config.schema_name = schema_name.map(str::to_string);
Self::new_with_config(config).await
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn new_with_entra_with_token_source(
host: &str,
port: u16,
database: &str,
user: &str,
schema_name: Option<&str>,
options: EntraAuthOptions,
token_source: Arc<dyn TokenSource>,
ssl_mode: PgSslMode,
migration_policy: MigrationPolicy,
) -> Result<Self> {
let audience = options.audience_str().to_string();
let token = token_source
.fetch_token(&[audience.as_str()])
.await
.context(
"Entra credential resolution failed: could not acquire an initial access token",
)?;
let base_options = build_entra_connect_options(host, port, database, user, ssl_mode);
let pool = PgPoolOptions::new()
.max_connections(options.max_connections_value())
.min_connections(1)
.acquire_timeout(options.acquire_timeout_value())
.connect_with(base_options.clone().password(&token.secret))
.await?;
let pool = Arc::new(pool);
let schema_name = schema_name.unwrap_or("public").to_string();
let migration_runner = MigrationRunner::new(pool.clone(), schema_name.clone());
match migration_policy {
MigrationPolicy::ApplyAll => migration_runner.migrate().await?,
MigrationPolicy::VerifyOnly => migration_runner.verify().await?,
}
let refresh_handle = spawn_token_refresh_task(
pool.clone(),
token_source,
base_options,
audience,
options.refresh_interval_value(),
token.expires_at,
);
Ok(Self {
pool,
schema_name,
is_entra: true,
_refresh_task: Some(AbortOnDropHandle(refresh_handle)),
})
}
#[deprecated(
since = "0.1.34",
note = "schema initialization is now run automatically by every constructor; this shim will be removed in a future release"
)]
#[instrument(skip(self), target = "duroxide::providers::postgres")]
pub async fn initialize_schema(&self) -> Result<()> {
let migration_runner = MigrationRunner::new(self.pool.clone(), self.schema_name.clone());
migration_runner.migrate().await?;
Ok(())
}
fn now_millis() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}
fn table_name(&self, table: &str) -> String {
format!("{}.{}", self.schema_name, table)
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
pub fn schema_name(&self) -> &str {
&self.schema_name
}
fn sqlx_to_provider_error(&self, operation: &str, e: SqlxError) -> ProviderError {
match e {
SqlxError::Database(ref db_err) => {
let code_opt = db_err.code();
let code = code_opt.as_deref();
match classify_pg_sqlstate(code, self.is_entra) {
SqlStateClass::Retryable => ProviderError::retryable(
operation,
match code {
Some("40P01") => format!("Deadlock detected: {e}"),
Some("28000") | Some("28P01") => {
format!("Authentication error (likely token rotation): {e}")
}
Some("0A000") => format!("Cached plan invalidated: {e}"),
_ => format!("Retryable database error: {e}"),
},
),
SqlStateClass::Permanent => ProviderError::permanent(
operation,
match code {
Some("40001") => format!("Serialization failure: {e}"),
Some("23505") => format!("Duplicate detected: {e}"),
Some("23503") => format!("Foreign key violation: {e}"),
_ => format!("Database error: {e}"),
},
),
}
}
SqlxError::PoolClosed | SqlxError::PoolTimedOut => {
ProviderError::retryable(operation, format!("Connection pool error: {e}"))
}
SqlxError::Io(_) => ProviderError::retryable(operation, format!("I/O error: {e}")),
_ => ProviderError::permanent(operation, format!("Unexpected error: {e}")),
}
}
fn tag_filter_to_sql(filter: &TagFilter) -> (&'static str, Vec<String>) {
match filter {
TagFilter::DefaultOnly => ("default_only", vec![]),
TagFilter::Tags(set) => {
let mut tags: Vec<String> = set.iter().cloned().collect();
tags.sort();
("tags", tags)
}
TagFilter::DefaultAnd(set) => {
let mut tags: Vec<String> = set.iter().cloned().collect();
tags.sort();
("default_and", tags)
}
TagFilter::Any => ("any", vec![]),
TagFilter::None => ("none", vec![]),
}
}
pub async fn cleanup_schema(&self) -> Result<()> {
const MAX_RETRIES: u32 = 5;
const BASE_RETRY_DELAY_MS: u64 = 50;
for attempt in 0..=MAX_RETRIES {
let cleanup_result = async {
sqlx::query(&format!("SELECT {}.cleanup_schema()", self.schema_name))
.execute(&*self.pool)
.await?;
if self.schema_name != "public" {
sqlx::query(&format!(
"DROP SCHEMA IF EXISTS {} CASCADE",
self.schema_name
))
.execute(&*self.pool)
.await?;
} else {
}
Ok::<(), SqlxError>(())
}
.await;
match cleanup_result {
Ok(()) => return Ok(()),
Err(SqlxError::Database(db_err)) if db_err.code().as_deref() == Some("40P01") => {
if attempt < MAX_RETRIES {
warn!(
target = "duroxide::providers::postgres",
schema = %self.schema_name,
attempt = attempt + 1,
"Deadlock during cleanup_schema, retrying"
);
sleep(Duration::from_millis(
BASE_RETRY_DELAY_MS * (attempt as u64 + 1),
))
.await;
continue;
}
return Err(anyhow::anyhow!(db_err.to_string()));
}
Err(e) => return Err(anyhow::anyhow!(e.to_string())),
}
}
Ok(())
}
}
/// Base connect options for an Entra-authenticated pool: host, port, database, user
/// and SSL mode.
///
/// No password is set here. Callers attach the current access token with
/// `.password(...)` on a clone of these options.
///
/// NOTE(review): `PgConnectOptions::new()` may read defaults from `PG*` environment
/// variables; the values set here override those fields.
pub(crate) fn build_entra_connect_options(
    host: &str,
    port: u16,
    database: &str,
    user: &str,
    ssl_mode: PgSslMode,
) -> PgConnectOptions {
    let endpoint = PgConnectOptions::new().host(host).port(port);
    let identity = endpoint.database(database).username(user);
    identity.ssl_mode(ssl_mode)
}
/// Lower bound on every sleep in the Entra refresh loop. It is also the retry delay
/// after a failed or panicking refresh attempt.
const ENTRA_REFRESH_MIN_INTERVAL: Duration = Duration::from_secs(30);
/// How far ahead of token expiry the refresh loop aims to fetch a new token
/// (five minutes).
pub(crate) const ENTRA_REFRESH_SAFETY_MARGIN: Duration = Duration::from_secs(300);
/// Maximum byte length of a panic message logged by the refresh task before it is cut.
const ENTRA_PANIC_MSG_TRUNCATION_LIMIT: usize = 256;
/// Polls `fut` to completion and converts a panic inside it into `Err(message)`.
///
/// The message is the panic payload when it is a `&'static str` or `String`, and a
/// placeholder otherwise. It is cut to `ENTRA_PANIC_MSG_TRUNCATION_LIMIT` bytes.
///
/// The future is wrapped in `AssertUnwindSafe`, so callers must tolerate any state it
/// left half-updated when it panicked.
async fn run_with_panic_guard<Fut, T>(fut: Fut) -> Result<T, String>
where
    Fut: std::future::Future<Output = T>,
{
    use futures_util::FutureExt;
    use std::panic::AssertUnwindSafe;

    let outcome = AssertUnwindSafe(fut).catch_unwind().await;
    outcome.map_err(|payload| {
        let message = match payload.downcast_ref::<&'static str>() {
            Some(text) => (*text).to_string(),
            None => payload
                .downcast_ref::<String>()
                .cloned()
                .unwrap_or_else(|| "<non-string panic payload>".to_string()),
        };
        truncate_panic_message(message, ENTRA_PANIC_MSG_TRUNCATION_LIMIT)
    })
}
/// Returns `s` unchanged if it is at most `limit` bytes long.
///
/// Otherwise it cuts `s` at the last UTF-8 character boundary at or below `limit`
/// and appends `…[truncated]`.
fn truncate_panic_message(s: String, limit: usize) -> String {
    if s.len() <= limit {
        return s;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let boundary = (0..=limit)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    let mut shortened = s;
    shortened.truncate(boundary);
    shortened.push_str("…[truncated]");
    shortened
}
/// Spawns the background task that keeps the pool's Entra password fresh.
///
/// Each cycle the task sleeps, fetches a new token for `audience`, and installs it
/// via `PgPool::set_connect_options`.
/// - After a successful refresh, the next sleep is derived from the token's expiry,
///   capped by `refresh_interval_ceiling` (see `compute_next_refresh_sleep`).
/// - After a failed or panicking cycle, it waits `ENTRA_REFRESH_MIN_INTERVAL`.
///
/// The loop never exits on its own. The returned handle is how it is aborted, and
/// the `JoinHandle` itself is dropped, which detaches the task.
fn spawn_token_refresh_task(
    pool: Arc<PgPool>,
    token_source: Arc<dyn TokenSource>,
    base_options: PgConnectOptions,
    audience: String,
    refresh_interval_ceiling: Duration,
    initial_expires_at: SystemTime,
) -> AbortHandle {
    let refresh_loop = async move {
        let mut expires_at = initial_expires_at;
        let mut delay =
            compute_next_refresh_sleep(refresh_interval_ceiling, expires_at, SystemTime::now());
        loop {
            debug!(
                target: "duroxide::providers::postgres",
                sleep_secs = delay.as_secs(),
                "Entra refresh task sleeping",
            );
            sleep(delay).await;
            let outcome = run_with_panic_guard(refresh_loop_iteration(
                &pool,
                token_source.as_ref(),
                &base_options,
                &audience,
                &mut expires_at,
            ))
            .await;
            if let Err(panic_msg) = &outcome {
                error!(
                    target: "duroxide::providers::postgres",
                    panic = %panic_msg,
                    "Entra refresh task body panicked; continuing with bounded backoff",
                );
            }
            delay = next_sleep_after_iteration(
                &outcome,
                refresh_interval_ceiling,
                expires_at,
                SystemTime::now(),
            );
        }
    };
    tokio::spawn(refresh_loop).abort_handle()
}
/// Picks the sleep before the next refresh cycle.
///
/// `result` is the panic-guarded outcome of the previous cycle. A clean success
/// schedules by token expiry via `compute_next_refresh_sleep`. A refresh failure
/// (`Ok(Err(()))`) or a panic (`Err(_)`) backs off for `ENTRA_REFRESH_MIN_INTERVAL`.
fn next_sleep_after_iteration(
    result: &Result<Result<(), ()>, String>,
    refresh_interval_ceiling: Duration,
    next_expires_at: SystemTime,
    now: SystemTime,
) -> Duration {
    if matches!(result, Ok(Ok(()))) {
        compute_next_refresh_sleep(refresh_interval_ceiling, next_expires_at, now)
    } else {
        ENTRA_REFRESH_MIN_INTERVAL
    }
}
/// One refresh cycle.
///
/// Fetches a token for `audience`, installs it as the pool's password, and records
/// its expiry in `next_expires_at`.
///
/// Returns `Err(())` after logging a warning if the token fetch fails. In that case
/// the pool options and `next_expires_at` are left untouched.
///
/// NOTE(review): `set_connect_options` presumably affects only connections opened
/// afterwards; existing pooled connections keep their original credentials.
async fn refresh_loop_iteration(
    pool: &Arc<PgPool>,
    token_source: &dyn TokenSource,
    base_options: &PgConnectOptions,
    audience: &str,
    next_expires_at: &mut SystemTime,
) -> Result<(), ()> {
    let token = match token_source.fetch_token(&[audience]).await {
        Ok(token) => token,
        Err(e) => {
            warn!(
                target: "duroxide::providers::postgres",
                error = %e,
                "Entra token refresh failed; will retry after bounded backoff",
            );
            return Err(());
        }
    };
    pool.set_connect_options(base_options.clone().password(&token.secret));
    *next_expires_at = token.expires_at;
    debug!(
        target: "duroxide::providers::postgres",
        "Entra token refreshed and applied to pool",
    );
    Ok(())
}
/// Sleep until the next token refresh.
///
/// The target is the time until `expires_at` minus `ENTRA_REFRESH_SAFETY_MARGIN`,
/// saturating at zero (also when the token has already expired). That target is
/// capped at `ceiling`, and the result is never below `ENTRA_REFRESH_MIN_INTERVAL`.
/// The floor wins over a ceiling that is smaller than it.
fn compute_next_refresh_sleep(
    ceiling: Duration,
    expires_at: SystemTime,
    now: SystemTime,
) -> Duration {
    let lead_time = expires_at
        .duration_since(now)
        .unwrap_or_default()
        .saturating_sub(ENTRA_REFRESH_SAFETY_MARGIN);
    ceiling.min(lead_time).max(ENTRA_REFRESH_MIN_INTERVAL)
}
#[async_trait::async_trait]
impl Provider for PostgresProvider {
/// Name of this provider implementation.
fn name(&self) -> &str {
"duroxide-pg"
}
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
fn version(&self) -> &str {
env!("CARGO_PKG_VERSION")
}
/// Locks and returns the next orchestration item that is ready to run.
///
/// Calls `<schema>.fetch_orchestration_item` with the current time, the lock timeout,
/// and an optional packed duroxide-version range taken from `filter`. Returns
/// `(item, lock_token, attempt_count)` when a row is locked, and `Ok(None)` when
/// nothing is available.
///
/// Behaviour:
/// - `_poll_timeout` is ignored; the call does not wait for work.
/// - A `filter` with no supported version ranges returns `Ok(None)` without
///   querying. Only the first range is forwarded to the database.
/// - Retryable database errors are retried up to 3 times with linear backoff.
/// - If the history fails to deserialize, the item is still returned, with an empty
///   history and `history_error` set.
/// - KV snapshots: JSON `null` means "no entries". A snapshot that cannot be
///   decoded is logged and treated as empty. An entry whose timestamp is missing or
///   not an integer keeps its value, with `last_updated_at_ms = 0`.
/// - Orphan messages are dropped. This applies when the orchestration is named
///   `"Unknown"`, has no history, and holds only queue messages: the item is
///   acknowledged with no changes and `Ok(None)` is returned.
///
/// # Errors
/// Non-retryable or exhausted database errors, undecodable messages, or a failure
/// while acknowledging dropped orphan messages.
#[instrument(skip(self), target = "duroxide::providers::postgres")]
async fn fetch_orchestration_item(
    &self,
    lock_timeout: Duration,
    _poll_timeout: Duration,
    filter: Option<&DispatcherCapabilityFilter>,
) -> Result<Option<(OrchestrationItem, String, u32)>, ProviderError> {
    const MAX_RETRIES: u32 = 3;
    const RETRY_DELAY_MS: u64 = 50;
    let start = std::time::Instant::now();
    let lock_timeout_ms = lock_timeout.as_millis() as i64;
    // Versions are packed as major * 1_000_000 + minor * 1_000 + patch so the stored
    // procedure can compare them as integers. This assumes minor and patch stay below 1000.
    let (min_packed, max_packed) = match filter {
        Some(f) => {
            let Some(range) = f.supported_duroxide_versions.first() else {
                return Ok(None);
            };
            let min = range.min.major as i64 * 1_000_000
                + range.min.minor as i64 * 1_000
                + range.min.patch as i64;
            let max = range.max.major as i64 * 1_000_000
                + range.max.minor as i64 * 1_000
                + range.max.patch as i64;
            (Some(min), Some(max))
        }
        None => (None, None),
    };
    for attempt in 0..=MAX_RETRIES {
        let now_ms = Self::now_millis();
        let result: Result<
            Option<(
                String,
                String,
                String,
                i64,
                serde_json::Value,
                serde_json::Value,
                String,
                i32,
                serde_json::Value,
            )>,
            SqlxError,
        > = sqlx::query_as(&format!(
            "SELECT * FROM {}.fetch_orchestration_item($1, $2, $3, $4)",
            self.schema_name
        ))
        .bind(now_ms)
        .bind(lock_timeout_ms)
        .bind(min_packed)
        .bind(max_packed)
        .fetch_optional(&*self.pool)
        .await;
        let row = match result {
            Ok(r) => r,
            Err(e) => {
                let provider_err = self.sqlx_to_provider_error("fetch_orchestration_item", e);
                if provider_err.is_retryable() && attempt < MAX_RETRIES {
                    warn!(
                        target: "duroxide::providers::postgres",
                        operation = "fetch_orchestration_item",
                        attempt = attempt + 1,
                        error = %provider_err,
                        "Retryable error, will retry"
                    );
                    sleep(Duration::from_millis(RETRY_DELAY_MS * (attempt as u64 + 1))).await;
                    continue;
                }
                return Err(provider_err);
            }
        };
        let Some((
            instance_id,
            orchestration_name,
            orchestration_version,
            execution_id,
            history_json,
            messages_json,
            lock_token,
            attempt_count,
            kv_snapshot_json,
        )) = row
        else {
            return Ok(None);
        };
        let (history, history_error) = match serde_json::from_value::<Vec<Event>>(history_json) {
            Ok(h) => (h, None),
            Err(e) => {
                let error_msg = format!("Failed to deserialize history: {e}");
                warn!(
                    target: "duroxide::providers::postgres",
                    instance = %instance_id,
                    error = %error_msg,
                    "History deserialization failed, returning item with history_error"
                );
                (vec![], Some(error_msg))
            }
        };
        let messages: Vec<WorkItem> = serde_json::from_value(messages_json).map_err(|e| {
            ProviderError::permanent(
                "fetch_orchestration_item",
                format!("Failed to deserialize messages: {e}"),
            )
        })?;
        let kv_snapshot: std::collections::HashMap<String, duroxide::providers::KvEntry> = {
            let raw: std::collections::HashMap<String, serde_json::Value> =
                if kv_snapshot_json.is_null() {
                    std::collections::HashMap::new()
                } else {
                    serde_json::from_value(kv_snapshot_json).unwrap_or_else(|e| {
                        warn!(
                            target: "duroxide::providers::postgres",
                            instance = %instance_id,
                            error = %e,
                            "KV snapshot deserialization failed, treating snapshot as empty"
                        );
                        std::collections::HashMap::new()
                    })
                };
            raw.into_iter()
                .filter_map(|(k, v)| {
                    let value = v.get("value")?.as_str()?.to_string();
                    // A missing timestamp must not discard an otherwise valid entry.
                    let last_updated_at_ms = v
                        .get("last_updated_at_ms")
                        .and_then(serde_json::Value::as_u64)
                        .unwrap_or(0);
                    Some((
                        k,
                        duroxide::providers::KvEntry {
                            value,
                            last_updated_at_ms,
                        },
                    ))
                })
                .collect()
        };
        let duration_ms = start.elapsed().as_millis() as u64;
        debug!(
            target: "duroxide::providers::postgres",
            operation = "fetch_orchestration_item",
            instance_id = %instance_id,
            execution_id = execution_id,
            message_count = messages.len(),
            history_count = history.len(),
            attempt_count = attempt_count,
            duration_ms = duration_ms,
            attempts = attempt + 1,
            "Fetched orchestration item via stored procedure"
        );
        if orchestration_name == "Unknown"
            && history.is_empty()
            && messages
                .iter()
                .all(|m| matches!(m, WorkItem::QueueMessage { .. }))
        {
            let message_count = messages.len();
            warn!(
                target: "duroxide::providers::postgres",
                instance = %instance_id,
                message_count,
                "Dropping orphan queue messages — events enqueued before orchestration started are not supported"
            );
            self.ack_orchestration_item(
                &lock_token,
                execution_id as u64,
                vec![],
                vec![],
                vec![],
                ExecutionMetadata::default(),
                vec![],
            )
            .await?;
            return Ok(None);
        }
        return Ok(Some((
            OrchestrationItem {
                instance: instance_id,
                orchestration_name,
                execution_id: execution_id as u64,
                version: orchestration_version,
                history,
                messages,
                history_error,
                kv_snapshot,
            },
            lock_token,
            attempt_count as u32,
        )));
    }
    // Every loop iteration returns or `continue`s only while retries remain.
    Ok(None)
}
/// Commits the result of processing a locked orchestration item.
///
/// Prepares and sends everything in a single call to
/// `<schema>.ack_orchestration_item`:
/// - `history_delta`, after validating and serializing each event;
/// - the custom-status change and key/value mutations found in that delta;
/// - the serialized work items, metadata and cancelled activities.
///
/// Retryable database errors are retried up to 3 times with linear backoff.
///
/// Only `lock_token` and `execution_id` are recorded on the tracing span. The
/// payload arguments are skipped so that the span does not `Debug`-format whole
/// histories and work items.
///
/// # Errors
/// - Permanent error if any event has `event_id == 0` or fails to serialize.
/// - Permanent "Invalid lock token" error when the database reports that message.
/// - Otherwise, database errors mapped with `sqlx_to_provider_error`.
#[instrument(
    skip(self, history_delta, worker_items, orchestrator_items, metadata, cancelled_activities),
    fields(lock_token = %lock_token, execution_id = execution_id),
    target = "duroxide::providers::postgres"
)]
async fn ack_orchestration_item(
    &self,
    lock_token: &str,
    execution_id: u64,
    history_delta: Vec<Event>,
    worker_items: Vec<WorkItem>,
    orchestrator_items: Vec<WorkItem>,
    metadata: ExecutionMetadata,
    cancelled_activities: Vec<ScheduledActivityIdentifier>,
) -> Result<(), ProviderError> {
    const MAX_RETRIES: u32 = 3;
    const RETRY_DELAY_MS: u64 = 50;
    let start = std::time::Instant::now();
    let mut history_delta_payload = Vec::with_capacity(history_delta.len());
    for event in &history_delta {
        if event.event_id() == 0 {
            return Err(ProviderError::permanent(
                "ack_orchestration_item",
                "event_id must be set by runtime",
            ));
        }
        let event_json = serde_json::to_string(event).map_err(|e| {
            ProviderError::permanent(
                "ack_orchestration_item",
                format!("Failed to serialize event: {e}"),
            )
        })?;
        // NOTE(review): this takes the text before the first `{` of the event's `Debug`
        // output. If `Event` is a struct with a derived `Debug`, that is always the
        // type name, not the event kind. Confirm what the `event_type` column should hold.
        let event_type = format!("{event:?}")
            .split('{')
            .next()
            .unwrap_or("Unknown")
            .trim()
            .to_string();
        history_delta_payload.push(serde_json::json!({
            "event_id": event.event_id(),
            "event_type": event_type,
            "event_data": event_json,
        }));
    }
    let history_delta_json = serde_json::Value::Array(history_delta_payload);
    let worker_items_json = serde_json::to_value(&worker_items).map_err(|e| {
        ProviderError::permanent(
            "ack_orchestration_item",
            format!("Failed to serialize worker items: {e}"),
        )
    })?;
    let orchestrator_items_json = serde_json::to_value(&orchestrator_items).map_err(|e| {
        ProviderError::permanent(
            "ack_orchestration_item",
            format!("Failed to serialize orchestrator items: {e}"),
        )
    })?;
    // The last CustomStatusUpdated in the delta wins:
    // `Some(s)` -> "set", `None` -> "clear", no such event -> no action.
    let (custom_status_action, custom_status_value): (Option<&str>, Option<&str>) = {
        let mut last_status: Option<&Option<String>> = None;
        for event in &history_delta {
            if let EventKind::CustomStatusUpdated { ref status } = event.kind {
                last_status = Some(status);
            }
        }
        match last_status {
            Some(Some(s)) => (Some("set"), Some(s.as_str())),
            Some(None) => (Some("clear"), None),
            None => (None, None),
        }
    };
    // Key/value events become ordered mutations for the stored procedure to apply.
    let kv_mutations: Vec<serde_json::Value> = history_delta
        .iter()
        .filter_map(|event| match &event.kind {
            EventKind::KeyValueSet {
                key,
                value,
                last_updated_at_ms,
            } => Some(serde_json::json!({
                "action": "set",
                "key": key,
                "value": value,
                "last_updated_at_ms": last_updated_at_ms,
            })),
            EventKind::KeyValueCleared { key } => Some(serde_json::json!({
                "action": "clear_key",
                "key": key,
            })),
            EventKind::KeyValuesCleared => Some(serde_json::json!({
                "action": "clear_all",
            })),
            _ => None,
        })
        .collect();
    let metadata_json = serde_json::json!({
        "orchestration_name": metadata.orchestration_name,
        "orchestration_version": metadata.orchestration_version,
        "status": metadata.status,
        "output": metadata.output,
        "parent_instance_id": metadata.parent_instance_id,
        "pinned_duroxide_version": metadata.pinned_duroxide_version.as_ref().map(|v| {
            serde_json::json!({
                "major": v.major,
                "minor": v.minor,
                "patch": v.patch,
            })
        }),
        "custom_status_action": custom_status_action,
        "custom_status_value": custom_status_value,
        "kv_mutations": kv_mutations,
    });
    let cancelled_activities_json: Vec<serde_json::Value> = cancelled_activities
        .iter()
        .map(|a| {
            serde_json::json!({
                "instance": a.instance,
                "execution_id": a.execution_id,
                "activity_id": a.activity_id,
            })
        })
        .collect();
    let cancelled_activities_json = serde_json::Value::Array(cancelled_activities_json);
    for attempt in 0..=MAX_RETRIES {
        let now_ms = Self::now_millis();
        let result = sqlx::query(&format!(
            "SELECT {}.ack_orchestration_item($1, $2, $3, $4, $5, $6, $7, $8)",
            self.schema_name
        ))
        .bind(lock_token)
        .bind(now_ms)
        .bind(execution_id as i64)
        .bind(&history_delta_json)
        .bind(&worker_items_json)
        .bind(&orchestrator_items_json)
        .bind(&metadata_json)
        .bind(&cancelled_activities_json)
        .execute(&*self.pool)
        .await;
        match result {
            Ok(_) => {
                let duration_ms = start.elapsed().as_millis() as u64;
                debug!(
                    target: "duroxide::providers::postgres",
                    operation = "ack_orchestration_item",
                    execution_id = execution_id,
                    history_count = history_delta.len(),
                    worker_items_count = worker_items.len(),
                    orchestrator_items_count = orchestrator_items.len(),
                    cancelled_activities_count = cancelled_activities.len(),
                    duration_ms = duration_ms,
                    attempts = attempt + 1,
                    "Acknowledged orchestration item via stored procedure"
                );
                return Ok(());
            }
            Err(e) => {
                // A stale or unknown lock token cannot be fixed by retrying.
                let invalid_lock = match &e {
                    SqlxError::Database(db_err) => {
                        db_err.message().contains("Invalid lock token")
                    }
                    other => other.to_string().contains("Invalid lock token"),
                };
                if invalid_lock {
                    return Err(ProviderError::permanent(
                        "ack_orchestration_item",
                        "Invalid lock token",
                    ));
                }
                let provider_err = self.sqlx_to_provider_error("ack_orchestration_item", e);
                if provider_err.is_retryable() && attempt < MAX_RETRIES {
                    warn!(
                        target: "duroxide::providers::postgres",
                        operation = "ack_orchestration_item",
                        attempt = attempt + 1,
                        error = %provider_err,
                        "Retryable error, will retry"
                    );
                    sleep(Duration::from_millis(RETRY_DELAY_MS * (attempt as u64 + 1))).await;
                    continue;
                }
                return Err(provider_err);
            }
        }
    }
    // Every loop iteration returns or `continue`s only while retries remain.
    Ok(())
}
#[instrument(skip(self), fields(lock_token = %lock_token), target = "duroxide::providers::postgres")]
async fn abandon_orchestration_item(
&self,
lock_token: &str,
delay: Option<Duration>,
ignore_attempt: bool,
) -> Result<(), ProviderError> {
let start = std::time::Instant::now();
let now_ms = Self::now_millis();
let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
let instance_id = match sqlx::query_scalar::<_, String>(&format!(
"SELECT {}.abandon_orchestration_item($1, $2, $3, $4)",
self.schema_name
))
.bind(lock_token)
.bind(now_ms)
.bind(delay_param)
.bind(ignore_attempt)
.fetch_one(&*self.pool)
.await
{
Ok(instance_id) => instance_id,
Err(e) => {
if let SqlxError::Database(db_err) = &e {
if db_err.message().contains("Invalid lock token") {
return Err(ProviderError::permanent(
"abandon_orchestration_item",
"Invalid lock token",
));
}
} else if e.to_string().contains("Invalid lock token") {
return Err(ProviderError::permanent(
"abandon_orchestration_item",
"Invalid lock token",
));
}
return Err(self.sqlx_to_provider_error("abandon_orchestration_item", e));
}
};
let duration_ms = start.elapsed().as_millis() as u64;
debug!(
target = "duroxide::providers::postgres",
operation = "abandon_orchestration_item",
instance_id = %instance_id,
delay_ms = delay.map(|d| d.as_millis() as u64),
ignore_attempt = ignore_attempt,
duration_ms = duration_ms,
"Abandoned orchestration item via stored procedure"
);
Ok(())
}
/// Loads the stored history of `instance` via `<schema>.fetch_history`.
///
/// Each row's `out_event_data` is decoded as an [`Event`], in the order returned
/// by the procedure.
///
/// # Errors
/// Database errors are mapped with `sqlx_to_provider_error`. The first event that
/// fails to deserialize produces a permanent error.
#[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
async fn read(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
    let sql = format!(
        "SELECT out_event_data FROM {}.fetch_history($1)",
        self.schema_name
    );
    let raw_events = match sqlx::query_scalar::<_, String>(&sql)
        .bind(instance)
        .fetch_all(&*self.pool)
        .await
    {
        Ok(rows) => rows,
        Err(e) => return Err(self.sqlx_to_provider_error("read", e)),
    };
    let mut events = Vec::with_capacity(raw_events.len());
    for raw in &raw_events {
        let event = serde_json::from_str::<Event>(raw).map_err(|e| {
            ProviderError::permanent("read", format!("Failed to deserialize event: {e}"))
        })?;
        events.push(event);
    }
    Ok(events)
}
/// Appends `new_events` to the history of `instance` / `execution_id` via
/// `<schema>.append_history`.
///
/// An empty batch is a no-op. The event payloads are not recorded on the tracing
/// span; only `instance` and `execution_id` are.
///
/// # Errors
/// - Permanent error if any event has `event_id == 0` (also logged) or fails to
///   serialize.
/// - Database errors are mapped with `sqlx_to_provider_error`.
#[instrument(skip(self, new_events), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
async fn append_with_execution(
    &self,
    instance: &str,
    execution_id: u64,
    new_events: Vec<Event>,
) -> Result<(), ProviderError> {
    if new_events.is_empty() {
        return Ok(());
    }
    let mut events_payload = Vec::with_capacity(new_events.len());
    for event in &new_events {
        if event.event_id() == 0 {
            error!(
                target: "duroxide::providers::postgres",
                operation = "append_with_execution",
                error_type = "validation_error",
                instance_id = %instance,
                execution_id = execution_id,
                "event_id must be set by runtime"
            );
            return Err(ProviderError::permanent(
                "append_with_execution",
                "event_id must be set by runtime",
            ));
        }
        let event_json = serde_json::to_string(event).map_err(|e| {
            ProviderError::permanent(
                "append_with_execution",
                format!("Failed to serialize event: {e}"),
            )
        })?;
        // NOTE(review): like ack_orchestration_item, this stores the text before the
        // first `{` of the event's `Debug` output. Confirm that this is the intended
        // event_type.
        let event_type = format!("{event:?}")
            .split('{')
            .next()
            .unwrap_or("Unknown")
            .trim()
            .to_string();
        events_payload.push(serde_json::json!({
            "event_id": event.event_id(),
            "event_type": event_type,
            "event_data": event_json,
        }));
    }
    let events_json = serde_json::Value::Array(events_payload);
    sqlx::query(&format!(
        "SELECT {}.append_history($1, $2, $3)",
        self.schema_name
    ))
    .bind(instance)
    .bind(execution_id as i64)
    .bind(events_json)
    .execute(&*self.pool)
    .await
    .map_err(|e| self.sqlx_to_provider_error("append_with_execution", e))?;
    debug!(
        target: "duroxide::providers::postgres",
        operation = "append_with_execution",
        instance_id = %instance,
        execution_id = execution_id,
        event_count = new_events.len(),
        "Appended history events via stored procedure"
    );
    Ok(())
}
/// Enqueues `item` on the worker queue via `<schema>.enqueue_worker_work`.
///
/// For `ActivityExecute` items, the instance, execution id, activity id, session id
/// and tag are passed as separate columns. For every other item they are NULL. The
/// item itself is not recorded on the tracing span, because it may carry large
/// activity inputs.
///
/// # Errors
/// A permanent error if the item cannot be serialized. Database errors are logged
/// and mapped with `sqlx_to_provider_error`.
#[instrument(skip(self, item), target = "duroxide::providers::postgres")]
async fn enqueue_for_worker(&self, item: WorkItem) -> Result<(), ProviderError> {
    let work_item = serde_json::to_string(&item).map_err(|e| {
        ProviderError::permanent(
            "enqueue_worker_work",
            format!("Failed to serialize work item: {e}"),
        )
    })?;
    let now_ms = Self::now_millis();
    let (instance_id, execution_id, activity_id, session_id, tag) = match &item {
        WorkItem::ActivityExecute {
            instance,
            execution_id,
            id,
            session_id,
            tag,
            ..
        } => (
            Some(instance.clone()),
            Some(*execution_id as i64),
            Some(*id as i64),
            session_id.clone(),
            tag.clone(),
        ),
        _ => (None, None, None, None, None),
    };
    sqlx::query(&format!(
        "SELECT {}.enqueue_worker_work($1, $2, $3, $4, $5, $6, $7)",
        self.schema_name
    ))
    .bind(work_item)
    .bind(now_ms)
    .bind(&instance_id)
    .bind(execution_id)
    .bind(activity_id)
    .bind(&session_id)
    .bind(&tag)
    .execute(&*self.pool)
    .await
    .map_err(|e| {
        error!(
            target: "duroxide::providers::postgres",
            operation = "enqueue_worker_work",
            error_type = "database_error",
            error = %e,
            "Failed to enqueue worker work"
        );
        self.sqlx_to_provider_error("enqueue_worker_work", e)
    })?;
    Ok(())
}
/// Locks and returns the next worker item that matches `tag_filter`.
///
/// [`TagFilter::None`] short-circuits to `Ok(None)` without querying. Otherwise it
/// calls `<schema>.fetch_work_item` with:
/// - the current time and lock timeout;
/// - the optional session owner and session lock timeout;
/// - the sorted tag list and filter mode.
///
/// Returns `(item, lock_token, attempt_count)`, or `Ok(None)` when nothing is
/// available. `_poll_timeout` is ignored.
///
/// # Errors
/// Database errors are mapped with `sqlx_to_provider_error`. A stored item that
/// fails to deserialize produces a permanent error.
#[instrument(skip(self), target = "duroxide::providers::postgres")]
async fn fetch_work_item(
    &self,
    lock_timeout: Duration,
    _poll_timeout: Duration,
    session: Option<&SessionFetchConfig>,
    tag_filter: &TagFilter,
) -> Result<Option<(WorkItem, String, u32)>, ProviderError> {
    if matches!(tag_filter, TagFilter::None) {
        return Ok(None);
    }
    let start = std::time::Instant::now();
    let lock_timeout_ms = lock_timeout.as_millis() as i64;
    let (owner_id, session_lock_timeout_ms): (Option<&str>, Option<i64>) = match session {
        Some(config) => (
            Some(&config.owner_id),
            Some(config.lock_timeout.as_millis() as i64),
        ),
        None => (None, None),
    };
    let (tag_mode, tag_names) = Self::tag_filter_to_sql(tag_filter);
    let row = match sqlx::query_as::<_, (String, String, i32)>(&format!(
        "SELECT * FROM {}.fetch_work_item($1, $2, $3, $4, $5, $6)",
        self.schema_name
    ))
    .bind(Self::now_millis())
    .bind(lock_timeout_ms)
    .bind(owner_id)
    .bind(session_lock_timeout_ms)
    .bind(&tag_names)
    .bind(tag_mode)
    .fetch_optional(&*self.pool)
    .await
    {
        Ok(row) => row,
        Err(e) => {
            return Err(self.sqlx_to_provider_error("fetch_work_item", e));
        }
    };
    let (work_item_json, lock_token, attempt_count) = match row {
        Some(row) => row,
        None => return Ok(None),
    };
    let work_item: WorkItem = serde_json::from_str(&work_item_json).map_err(|e| {
        ProviderError::permanent(
            "fetch_work_item",
            format!("Failed to deserialize worker item: {e}"),
        )
    })?;
    let duration_ms = start.elapsed().as_millis() as u64;
    // Used only for logging. Sub-orchestration results are attributed to their parent.
    let instance_id = match &work_item {
        WorkItem::ActivityExecute { instance, .. } => instance.as_str(),
        WorkItem::ActivityCompleted { instance, .. } => instance.as_str(),
        WorkItem::ActivityFailed { instance, .. } => instance.as_str(),
        WorkItem::StartOrchestration { instance, .. } => instance.as_str(),
        WorkItem::TimerFired { instance, .. } => instance.as_str(),
        WorkItem::ExternalRaised { instance, .. } => instance.as_str(),
        WorkItem::CancelInstance { instance, .. } => instance.as_str(),
        WorkItem::ContinueAsNew { instance, .. } => instance.as_str(),
        WorkItem::SubOrchCompleted {
            parent_instance, ..
        } => parent_instance.as_str(),
        WorkItem::SubOrchFailed {
            parent_instance, ..
        } => parent_instance.as_str(),
        WorkItem::QueueMessage { instance, .. } => instance.as_str(),
    };
    debug!(
        target: "duroxide::providers::postgres",
        operation = "fetch_work_item",
        instance_id = %instance_id,
        attempt_count = attempt_count,
        duration_ms = duration_ms,
        "Fetched activity work item via stored procedure"
    );
    Ok(Some((work_item, lock_token, attempt_count as u32)))
}
/// Acknowledges the locked worker item identified by `token`.
///
/// - With `completion == None`, calls `<schema>.ack_worker` with NULL instance and
///   completion. This is the path for cancelled activities.
/// - Otherwise the completion must be `ActivityCompleted` or `ActivityFailed`. It is
///   serialized and passed along with its instance id.
///
/// The completion payload is not recorded on the tracing span.
///
/// # Errors
/// A permanent error when:
/// - the completion has another type;
/// - serialization fails;
/// - the database reports "Worker queue item not found".
///
/// Other database errors are mapped with `sqlx_to_provider_error`.
#[instrument(skip(self, completion), fields(token = %token), target = "duroxide::providers::postgres")]
async fn ack_work_item(
    &self,
    token: &str,
    completion: Option<WorkItem>,
) -> Result<(), ProviderError> {
    let start = std::time::Instant::now();
    let Some(completion) = completion else {
        let now_ms = Self::now_millis();
        sqlx::query(&format!(
            "SELECT {}.ack_worker($1, NULL, NULL, $2)",
            self.schema_name
        ))
        .bind(token)
        .bind(now_ms)
        .execute(&*self.pool)
        .await
        .map_err(|e| {
            if e.to_string().contains("Worker queue item not found") {
                ProviderError::permanent(
                    "ack_worker",
                    "Worker queue item not found or already processed",
                )
            } else {
                self.sqlx_to_provider_error("ack_worker", e)
            }
        })?;
        let duration_ms = start.elapsed().as_millis() as u64;
        debug!(
            target: "duroxide::providers::postgres",
            operation = "ack_worker",
            token = %token,
            duration_ms = duration_ms,
            "Acknowledged worker without completion (cancelled)"
        );
        return Ok(());
    };
    let instance_id = match &completion {
        WorkItem::ActivityCompleted { instance, .. }
        | WorkItem::ActivityFailed { instance, .. } => instance,
        _ => {
            error!(
                target: "duroxide::providers::postgres",
                operation = "ack_worker",
                error_type = "invalid_completion_type",
                "Invalid completion work item type"
            );
            return Err(ProviderError::permanent(
                "ack_worker",
                "Invalid completion work item type",
            ));
        }
    };
    let completion_json = serde_json::to_string(&completion).map_err(|e| {
        ProviderError::permanent("ack_worker", format!("Failed to serialize completion: {e}"))
    })?;
    let now_ms = Self::now_millis();
    sqlx::query(&format!(
        "SELECT {}.ack_worker($1, $2, $3, $4)",
        self.schema_name
    ))
    .bind(token)
    .bind(instance_id)
    .bind(completion_json)
    .bind(now_ms)
    .execute(&*self.pool)
    .await
    .map_err(|e| {
        if e.to_string().contains("Worker queue item not found") {
            error!(
                target: "duroxide::providers::postgres",
                operation = "ack_worker",
                error_type = "worker_item_not_found",
                token = %token,
                "Worker queue item not found or already processed"
            );
            ProviderError::permanent(
                "ack_worker",
                "Worker queue item not found or already processed",
            )
        } else {
            self.sqlx_to_provider_error("ack_worker", e)
        }
    })?;
    let duration_ms = start.elapsed().as_millis() as u64;
    debug!(
        target: "duroxide::providers::postgres",
        operation = "ack_worker",
        instance_id = %instance_id,
        duration_ms = duration_ms,
        "Acknowledged worker and enqueued completion"
    );
    Ok(())
}
/// Extends the lock on the worker-queue item identified by `token` by `extend_for`.
///
/// The `renew_work_item_lock` stored procedure takes whole seconds. Any sub-second
/// remainder is rounded *up*, so a short extension (e.g. 500ms) still extends the lock
/// instead of truncating to zero. Durations beyond `i64::MAX` seconds saturate rather than
/// wrapping to a negative extension.
///
/// # Errors
///
/// A permanent error when the database reports "Lock token invalid". Any other failure is
/// mapped through `sqlx_to_provider_error`.
#[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
async fn renew_work_item_lock(
    &self,
    token: &str,
    extend_for: Duration,
) -> Result<(), ProviderError> {
    let start = std::time::Instant::now();
    let now_ms = Self::now_millis();
    let whole_secs = extend_for
        .as_secs()
        .saturating_add(u64::from(extend_for.subsec_nanos() > 0));
    let extend_secs = i64::try_from(whole_secs).unwrap_or(i64::MAX);
    let result = sqlx::query(&format!(
        "SELECT {}.renew_work_item_lock($1, $2, $3)",
        self.schema_name
    ))
    .bind(token)
    .bind(now_ms)
    .bind(extend_secs)
    .execute(&*self.pool)
    .await;
    match result {
        Ok(_) => {
            let duration_ms = start.elapsed().as_millis() as u64;
            debug!(
                target = "duroxide::providers::postgres",
                operation = "renew_work_item_lock",
                token = %token,
                extend_for_secs = extend_secs,
                duration_ms = duration_ms,
                "Work item lock renewed successfully"
            );
            Ok(())
        }
        Err(e) => {
            // `Display` for `sqlx::Error::Database` embeds the server message, so a single
            // substring check covers both database-raised and driver-level errors.
            if e.to_string().contains("Lock token invalid") {
                return Err(ProviderError::permanent(
                    "renew_work_item_lock",
                    "Lock token invalid, expired, or already acked",
                ));
            }
            Err(self.sqlx_to_provider_error("renew_work_item_lock", e))
        }
    }
}
/// Releases the worker-queue item locked by `token` without completing it.
///
/// `delay` (forwarded in milliseconds, or NULL when `None`) and `ignore_attempt` are
/// passed straight to the `abandon_work_item` stored procedure.
///
/// # Errors
///
/// A permanent error when the database reports "Invalid lock token" or "already acked".
/// Other failures are mapped through `sqlx_to_provider_error`.
#[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
async fn abandon_work_item(
    &self,
    token: &str,
    delay: Option<Duration>,
    ignore_attempt: bool,
) -> Result<(), ProviderError> {
    let started = std::time::Instant::now();
    let now_ms = Self::now_millis();
    let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
    let abandon_sql = format!(
        "SELECT {}.abandon_work_item($1, $2, $3, $4)",
        self.schema_name
    );
    let outcome = sqlx::query(&abandon_sql)
        .bind(token)
        .bind(now_ms)
        .bind(delay_param)
        .bind(ignore_attempt)
        .execute(&*self.pool)
        .await;

    if let Err(err) = outcome {
        // Database errors render as "error returned from database: <message>", so checking
        // the rendered text matches both database-raised and driver-level errors.
        let rendered = err.to_string();
        let token_rejected =
            rendered.contains("Invalid lock token") || rendered.contains("already acked");
        return Err(if token_rejected {
            ProviderError::permanent("abandon_work_item", "Invalid lock token or already acked")
        } else {
            self.sqlx_to_provider_error("abandon_work_item", err)
        });
    }

    let elapsed_ms = started.elapsed().as_millis() as u64;
    debug!(
        target = "duroxide::providers::postgres",
        operation = "abandon_work_item",
        token = %token,
        delay_ms = delay.map(|d| d.as_millis() as u64),
        ignore_attempt = ignore_attempt,
        duration_ms = elapsed_ms,
        "Abandoned work item via stored procedure"
    );
    Ok(())
}
/// Extends the lock on the orchestration item identified by `token` by `extend_for`.
///
/// The `renew_orchestration_item_lock` stored procedure takes whole seconds. Any
/// sub-second remainder is rounded *up*, so short extensions are not truncated to zero.
/// Durations beyond `i64::MAX` seconds saturate instead of wrapping negative.
///
/// # Errors
///
/// A permanent error when the error text contains "Lock token invalid", "expired", or
/// "already released". Other failures are mapped through `sqlx_to_provider_error`.
#[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
async fn renew_orchestration_item_lock(
    &self,
    token: &str,
    extend_for: Duration,
) -> Result<(), ProviderError> {
    let start = std::time::Instant::now();
    let now_ms = Self::now_millis();
    let whole_secs = extend_for
        .as_secs()
        .saturating_add(u64::from(extend_for.subsec_nanos() > 0));
    let extend_secs = i64::try_from(whole_secs).unwrap_or(i64::MAX);
    let result = sqlx::query(&format!(
        "SELECT {}.renew_orchestration_item_lock($1, $2, $3)",
        self.schema_name
    ))
    .bind(token)
    .bind(now_ms)
    .bind(extend_secs)
    .execute(&*self.pool)
    .await;
    match result {
        Ok(_) => {
            let duration_ms = start.elapsed().as_millis() as u64;
            debug!(
                target = "duroxide::providers::postgres",
                operation = "renew_orchestration_item_lock",
                token = %token,
                extend_for_secs = extend_secs,
                duration_ms = duration_ms,
                "Orchestration item lock renewed successfully"
            );
            Ok(())
        }
        Err(e) => {
            // `Display` for `sqlx::Error::Database` embeds the server message, so checking
            // the rendered text covers both database-raised and driver-level errors.
            let message = e.to_string();
            if message.contains("Lock token invalid")
                || message.contains("expired")
                || message.contains("already released")
            {
                return Err(ProviderError::permanent(
                    "renew_orchestration_item_lock",
                    "Lock token invalid, expired, or already released",
                ));
            }
            Err(self.sqlx_to_provider_error("renew_orchestration_item_lock", e))
        }
    }
}
/// Enqueues `item` on the orchestrator queue of the instance it targets.
///
/// Visibility rules:
/// * A `TimerFired` item with a non-zero `fire_at_ms` becomes visible at `fire_at_ms`.
///   If a `delay` is given and `now + delay` is later, it becomes visible at that later time.
/// * Every other item becomes visible at `now + delay`, or immediately when `delay` is
///   `None`.
///
/// # Errors
///
/// Permanent errors are returned in three cases:
/// * the item cannot be serialized;
/// * the item is an `ActivityExecute` (those belong on the worker queue);
/// * the computed visibility time cannot be represented as a timestamp.
///
/// Database failures go through `sqlx_to_provider_error`.
#[instrument(skip(self), target = "duroxide::providers::postgres")]
async fn enqueue_for_orchestrator(
    &self,
    item: WorkItem,
    delay: Option<Duration>,
) -> Result<(), ProviderError> {
    let work_item = serde_json::to_string(&item).map_err(|e| {
        ProviderError::permanent(
            "enqueue_orchestrator_work",
            format!("Failed to serialize work item: {e}"),
        )
    })?;
    // Sub-orchestration results are routed to the parent instance's queue.
    let instance_id = match &item {
        WorkItem::StartOrchestration { instance, .. }
        | WorkItem::ActivityCompleted { instance, .. }
        | WorkItem::ActivityFailed { instance, .. }
        | WorkItem::TimerFired { instance, .. }
        | WorkItem::ExternalRaised { instance, .. }
        | WorkItem::CancelInstance { instance, .. }
        | WorkItem::ContinueAsNew { instance, .. }
        | WorkItem::QueueMessage { instance, .. } => instance,
        WorkItem::SubOrchCompleted {
            parent_instance, ..
        }
        | WorkItem::SubOrchFailed {
            parent_instance, ..
        } => parent_instance,
        WorkItem::ActivityExecute { .. } => {
            return Err(ProviderError::permanent(
                "enqueue_orchestrator_work",
                "ActivityExecute should go to worker queue, not orchestrator queue",
            ));
        }
    };
    let now = Self::now_millis() as u64;
    // `Duration::as_millis` is u128: saturate instead of truncating/overflowing, so an
    // enormous delay cannot wrap around to an earlier visibility time (or panic in debug).
    let delayed_until =
        delay.map(|d| now.saturating_add(u64::try_from(d.as_millis()).unwrap_or(u64::MAX)));
    let visible_at_ms = match (&item, delayed_until) {
        (WorkItem::TimerFired { fire_at_ms, .. }, Some(until)) if *fire_at_ms > 0 => {
            std::cmp::max(*fire_at_ms, until)
        }
        (WorkItem::TimerFired { fire_at_ms, .. }, None) if *fire_at_ms > 0 => *fire_at_ms,
        (_, until) => until.unwrap_or(now),
    };
    // Reject values outside i64/chrono's range. A plain `as i64` would wrap them to a
    // negative (pre-epoch, i.e. immediately visible) timestamp.
    let visible_at = i64::try_from(visible_at_ms)
        .ok()
        .and_then(|ms| Utc.timestamp_millis_opt(ms).single())
        .ok_or_else(|| {
            ProviderError::permanent(
                "enqueue_orchestrator_work",
                "Invalid visible_at timestamp",
            )
        })?;
    sqlx::query(&format!(
        "SELECT {}.enqueue_orchestrator_work($1, $2, $3, $4, $5, $6)",
        self.schema_name
    ))
    .bind(instance_id)
    .bind(&work_item)
    .bind(visible_at)
    // Parameters $4..$6 are always passed as NULL by this method.
    .bind::<Option<String>>(None)
    .bind::<Option<String>>(None)
    .bind::<Option<i64>>(None)
    .execute(&*self.pool)
    .await
    .map_err(|e| {
        error!(
            target = "duroxide::providers::postgres",
            operation = "enqueue_orchestrator_work",
            error_type = "database_error",
            error = %e,
            instance_id = %instance_id,
            "Failed to enqueue orchestrator work"
        );
        self.sqlx_to_provider_error("enqueue_orchestrator_work", e)
    })?;
    debug!(
        target = "duroxide::providers::postgres",
        operation = "enqueue_orchestrator_work",
        instance_id = %instance_id,
        delay_ms = delay.map(|d| d.as_millis() as u64),
        "Enqueued orchestrator work"
    );
    Ok(())
}
/// Reads the history of execution `execution_id` of `instance`, ordered by `event_id`.
///
/// Queries the history table directly and deserializes each `event_data` row into an
/// [`Event`].
///
/// # Errors
///
/// Database failures are mapped through `sqlx_to_provider_error`. The first row that fails
/// to deserialize aborts the read with a permanent error.
#[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
async fn read_with_execution(
    &self,
    instance: &str,
    execution_id: u64,
) -> Result<Vec<Event>, ProviderError> {
    let history_table = self.table_name("history");
    let history_sql = format!(
        "SELECT event_data FROM {history_table} WHERE instance_id = $1 AND execution_id = $2 ORDER BY event_id"
    );
    let raw_events: Vec<String> = sqlx::query_scalar(&history_sql)
        .bind(instance)
        .bind(execution_id as i64)
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("read_with_execution", e))?;

    let mut events = Vec::with_capacity(raw_events.len());
    for raw in &raw_events {
        let event = serde_json::from_str::<Event>(raw).map_err(|e| {
            ProviderError::permanent(
                "read_with_execution",
                format!("Failed to deserialize event: {e}"),
            )
        })?;
        events.push(event);
    }
    Ok(events)
}
/// Renews session locks held by any of `owner_ids` via the `renew_session_lock` procedure.
///
/// `extend_for` and `idle_timeout` are forwarded in milliseconds. The result is the count
/// returned by the procedure. An empty `owner_ids` short-circuits to `Ok(0)` without
/// touching the database.
#[instrument(skip(self), target = "duroxide::providers::postgres")]
async fn renew_session_lock(
    &self,
    owner_ids: &[&str],
    extend_for: Duration,
    idle_timeout: Duration,
) -> Result<usize, ProviderError> {
    if owner_ids.is_empty() {
        return Ok(0);
    }

    let now_ms = Self::now_millis();
    let extend_ms = extend_for.as_millis() as i64;
    let idle_timeout_ms = idle_timeout.as_millis() as i64;
    let owners: Vec<&str> = owner_ids.to_vec();
    let renew_sql = format!(
        "SELECT {}.renew_session_lock($1, $2, $3, $4)",
        self.schema_name
    );

    let renewed: i64 = sqlx::query_scalar(&renew_sql)
        .bind(&owners)
        .bind(now_ms)
        .bind(extend_ms)
        .bind(idle_timeout_ms)
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("renew_session_lock", e))?;

    debug!(
        target = "duroxide::providers::postgres",
        operation = "renew_session_lock",
        owner_count = owner_ids.len(),
        sessions_renewed = renewed,
        "Session locks renewed"
    );
    Ok(renewed as usize)
}
/// Calls the `cleanup_orphaned_sessions` procedure with the current time.
///
/// Returns the count the procedure reports. `_idle_timeout` is accepted for the trait
/// signature but is not forwarded; the procedure only receives the current time.
#[instrument(skip(self), target = "duroxide::providers::postgres")]
async fn cleanup_orphaned_sessions(
    &self,
    _idle_timeout: Duration,
) -> Result<usize, ProviderError> {
    let cleanup_sql = format!(
        "SELECT {}.cleanup_orphaned_sessions($1)",
        self.schema_name
    );
    let cleaned: i64 = sqlx::query_scalar(&cleanup_sql)
        .bind(Self::now_millis())
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("cleanup_orphaned_sessions", e))?;

    debug!(
        target = "duroxide::providers::postgres",
        operation = "cleanup_orphaned_sessions",
        sessions_cleaned = cleaned,
        "Orphaned sessions cleaned up"
    );
    Ok(cleaned as usize)
}
/// Exposes this provider's administrative API.
///
/// Always `Some`, because `PostgresProvider` implements [`ProviderAdmin`].
fn as_management_capability(&self) -> Option<&dyn ProviderAdmin> {
    let admin: &dyn ProviderAdmin = self;
    Some(admin)
}
/// Fetches `(custom_status, version)` for `instance` via `get_custom_status`.
///
/// Returns `None` when the procedure yields no row. Presumably that happens when nothing
/// newer than `last_seen_version` exists; confirm against the SQL definition.
#[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
async fn get_custom_status(
    &self,
    instance: &str,
    last_seen_version: u64,
) -> Result<Option<(Option<String>, u64)>, ProviderError> {
    let status_sql = format!(
        "SELECT * FROM {}.get_custom_status($1, $2)",
        self.schema_name
    );
    let fetched = sqlx::query_as::<_, (Option<String>, i64)>(&status_sql)
        .bind(instance)
        .bind(last_seen_version as i64)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("get_custom_status", e))?;
    Ok(fetched.map(|(custom_status, version)| (custom_status, version as u64)))
}
async fn get_kv_value(
&self,
instance_id: &str,
key: &str,
) -> Result<Option<String>, ProviderError> {
let row: Option<(Option<String>, bool)> = sqlx::query_as(&format!(
"SELECT * FROM {}.get_kv_value($1, $2)",
self.schema_name
))
.bind(instance_id)
.bind(key)
.fetch_optional(&*self.pool)
.await
.map_err(|e| self.sqlx_to_provider_error("get_kv_value", e))?;
Ok(row.and_then(|(value, found)| if found { value } else { None }))
}
async fn get_kv_all_values(
&self,
instance_id: &str,
) -> Result<std::collections::HashMap<String, String>, ProviderError> {
let rows: Vec<(String, String)> = sqlx::query_as(&format!(
"SELECT * FROM {}.get_kv_all_values($1)",
self.schema_name
))
.bind(instance_id)
.fetch_all(&*self.pool)
.await
.map_err(|e| self.sqlx_to_provider_error("get_kv_all_values", e))?;
Ok(rows.into_iter().collect())
}
/// Returns history, queue and KV size statistics for `instance` via `get_instance_stats`.
///
/// Only a row whose leading flag is `true` produces stats. A missing row, or a `false`
/// flag, yields `Ok(None)`.
#[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
async fn get_instance_stats(
    &self,
    instance: &str,
) -> Result<Option<SystemStats>, ProviderError> {
    let stats_sql = format!(
        "SELECT * FROM {}.get_instance_stats($1)",
        self.schema_name
    );
    let row: Option<(bool, i64, i64, i64, i64, i64)> = sqlx::query_as(&stats_sql)
        .bind(instance)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("get_instance_stats", e))?;

    let Some((true, event_count, history_bytes, pending_count, kv_keys, kv_bytes)) = row else {
        return Ok(None);
    };
    Ok(Some(SystemStats {
        history_event_count: event_count as u64,
        history_size_bytes: history_bytes as u64,
        queue_pending_count: pending_count as u64,
        kv_user_key_count: kv_keys as u64,
        kv_total_value_bytes: kv_bytes as u64,
    }))
}
}
#[async_trait::async_trait]
impl ProviderAdmin for PostgresProvider {
    /// Lists every instance ID returned by the `list_instances()` procedure.
    #[instrument(skip(self), target = "duroxide::providers::postgres")]
    async fn list_instances(&self) -> Result<Vec<String>, ProviderError> {
        sqlx::query_scalar(&format!(
            "SELECT instance_id FROM {}.list_instances()",
            self.schema_name
        ))
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("list_instances", e))
    }

    /// Lists the IDs of instances the `list_instances_by_status` procedure matches to `status`.
    #[instrument(skip(self), fields(status = %status), target = "duroxide::providers::postgres")]
    async fn list_instances_by_status(&self, status: &str) -> Result<Vec<String>, ProviderError> {
        sqlx::query_scalar(&format!(
            "SELECT instance_id FROM {}.list_instances_by_status($1)",
            self.schema_name
        ))
        .bind(status)
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("list_instances_by_status", e))
    }

    /// Returns the execution IDs recorded for `instance`.
    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
    async fn list_executions(&self, instance: &str) -> Result<Vec<u64>, ProviderError> {
        let execution_ids: Vec<i64> = sqlx::query_scalar(&format!(
            "SELECT execution_id FROM {}.list_executions($1)",
            self.schema_name
        ))
        .bind(instance)
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("list_executions", e))?;
        Ok(execution_ids.into_iter().map(|id| id as u64).collect())
    }

    /// Reads and deserializes the history events of execution `execution_id` of `instance`.
    ///
    /// # Errors
    ///
    /// Database failures are mapped through `sqlx_to_provider_error`. A row that fails to
    /// deserialize yields a permanent error.
    #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
    async fn read_history_with_execution_id(
        &self,
        instance: &str,
        execution_id: u64,
    ) -> Result<Vec<Event>, ProviderError> {
        let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
            "SELECT out_event_data FROM {}.fetch_history_with_execution($1, $2)",
            self.schema_name
        ))
        .bind(instance)
        .bind(execution_id as i64)
        .fetch_all(&*self.pool)
        .await
        // Report under this method's own name, matching the deserialization error below.
        .map_err(|e| self.sqlx_to_provider_error("read_history_with_execution_id", e))?;
        event_data_rows
            .into_iter()
            .map(|event_data| {
                serde_json::from_str::<Event>(&event_data).map_err(|e| {
                    ProviderError::permanent(
                        "read_history_with_execution_id",
                        format!("Failed to deserialize event: {e}"),
                    )
                })
            })
            .collect()
    }

    /// Reads the history of the latest execution of `instance`.
    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
    async fn read_history(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
        let execution_id = self.latest_execution_id(instance).await?;
        self.read_history_with_execution_id(instance, execution_id)
            .await
    }

    /// Returns the latest execution ID of `instance`.
    ///
    /// # Errors
    ///
    /// A permanent "Instance not found" error when the procedure yields no row or NULL.
    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
    async fn latest_execution_id(&self, instance: &str) -> Result<u64, ProviderError> {
        // Decode as nullable. `SELECT f($1)` on a scalar function returns one row even for
        // unknown input, presumably NULL for a missing instance. Decoding that into a bare
        // i64 would fail with a decode error instead of reporting "Instance not found".
        let latest: Option<Option<i64>> = sqlx::query_scalar(&format!(
            "SELECT {}.latest_execution_id($1)",
            self.schema_name
        ))
        .bind(instance)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("latest_execution_id", e))?;
        latest
            .flatten()
            .map(|id| id as u64)
            .ok_or_else(|| ProviderError::permanent("latest_execution_id", "Instance not found"))
    }

    /// Returns summary information for `instance`.
    ///
    /// A NULL status is reported as `"Running"`. A NULL `updated_at` falls back to
    /// `created_at`. Timestamps are converted to milliseconds since the Unix epoch.
    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
    async fn get_instance_info(&self, instance: &str) -> Result<InstanceInfo, ProviderError> {
        let row: Option<(
            String,
            String,
            String,
            i64,
            chrono::DateTime<Utc>,
            Option<chrono::DateTime<Utc>>,
            Option<String>,
            Option<String>,
            Option<String>,
        )> = sqlx::query_as(&format!(
            "SELECT * FROM {}.get_instance_info($1)",
            self.schema_name
        ))
        .bind(instance)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("get_instance_info", e))?;
        let (
            instance_id,
            orchestration_name,
            orchestration_version,
            current_execution_id,
            created_at,
            updated_at,
            status,
            output,
            parent_instance_id,
        ) = row.ok_or_else(|| ProviderError::permanent("get_instance_info", "Instance not found"))?;
        let created_ms = created_at.timestamp_millis() as u64;
        Ok(InstanceInfo {
            instance_id,
            orchestration_name,
            orchestration_version,
            current_execution_id: current_execution_id as u64,
            status: status.unwrap_or_else(|| "Running".to_string()),
            output,
            created_at: created_ms,
            updated_at: updated_at
                .map(|dt| dt.timestamp_millis() as u64)
                .unwrap_or(created_ms),
            parent_instance_id,
        })
    }

    /// Returns status, output, timing and event count for one execution of `instance`.
    #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
    async fn get_execution_info(
        &self,
        instance: &str,
        execution_id: u64,
    ) -> Result<ExecutionInfo, ProviderError> {
        let row: Option<(
            i64,
            String,
            Option<String>,
            chrono::DateTime<Utc>,
            Option<chrono::DateTime<Utc>>,
            i64,
        )> = sqlx::query_as(&format!(
            "SELECT * FROM {}.get_execution_info($1, $2)",
            self.schema_name
        ))
        .bind(instance)
        .bind(execution_id as i64)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("get_execution_info", e))?;
        let (exec_id, status, output, started_at, completed_at, event_count) = row
            .ok_or_else(|| ProviderError::permanent("get_execution_info", "Execution not found"))?;
        Ok(ExecutionInfo {
            execution_id: exec_id as u64,
            status,
            output,
            started_at: started_at.timestamp_millis() as u64,
            completed_at: completed_at.map(|dt| dt.timestamp_millis() as u64),
            event_count: event_count as usize,
        })
    }

    /// Returns store-wide instance/execution/event counters from `get_system_metrics()`.
    #[instrument(skip(self), target = "duroxide::providers::postgres")]
    async fn get_system_metrics(&self) -> Result<SystemMetrics, ProviderError> {
        let row: Option<(i64, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!(
            "SELECT * FROM {}.get_system_metrics()",
            self.schema_name
        ))
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("get_system_metrics", e))?;
        let (
            total_instances,
            total_executions,
            running_instances,
            completed_instances,
            failed_instances,
            total_events,
        ) = row.ok_or_else(|| {
            ProviderError::permanent("get_system_metrics", "Failed to get system metrics")
        })?;
        Ok(SystemMetrics {
            total_instances: total_instances as u64,
            total_executions: total_executions as u64,
            running_instances: running_instances as u64,
            completed_instances: completed_instances as u64,
            failed_instances: failed_instances as u64,
            total_events: total_events as u64,
        })
    }

    /// Returns orchestrator and worker queue depths as of now.
    #[instrument(skip(self), target = "duroxide::providers::postgres")]
    async fn get_queue_depths(&self) -> Result<QueueDepths, ProviderError> {
        let now_ms = Self::now_millis();
        let row: Option<(i64, i64)> = sqlx::query_as(&format!(
            "SELECT * FROM {}.get_queue_depths($1)",
            self.schema_name
        ))
        .bind(now_ms)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("get_queue_depths", e))?;
        let (orchestrator_queue, worker_queue) = row.ok_or_else(|| {
            ProviderError::permanent("get_queue_depths", "Failed to get queue depths")
        })?;
        Ok(QueueDepths {
            orchestrator_queue: orchestrator_queue as usize,
            worker_queue: worker_queue as usize,
            // `get_queue_depths` only reports orchestrator and worker depths.
            timer_queue: 0,
        })
    }

    /// Lists the direct child (sub-orchestration) instance IDs of `instance_id`.
    #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
    async fn list_children(&self, instance_id: &str) -> Result<Vec<String>, ProviderError> {
        sqlx::query_scalar(&format!(
            "SELECT child_instance_id FROM {}.list_children($1)",
            self.schema_name
        ))
        .bind(instance_id)
        .fetch_all(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("list_children", e))
    }

    /// Returns the parent instance ID of `instance_id`, or `None` for a root instance.
    ///
    /// # Errors
    ///
    /// A permanent error naming the instance when the database reports "Instance not found".
    #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
    async fn get_parent_id(&self, instance_id: &str) -> Result<Option<String>, ProviderError> {
        sqlx::query_scalar::<_, Option<String>>(&format!(
            "SELECT {}.get_parent_id($1)",
            self.schema_name
        ))
        .bind(instance_id)
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| {
            if e.to_string().contains("Instance not found") {
                ProviderError::permanent(
                    "get_parent_id",
                    format!("Instance not found: {}", instance_id),
                )
            } else {
                self.sqlx_to_provider_error("get_parent_id", e)
            }
        })
    }

    /// Deletes `ids` in a single `delete_instances_atomic` call and returns the counts.
    ///
    /// An empty `ids` is a no-op.
    ///
    /// # Errors
    ///
    /// Errors mentioning "is Running" or "Orphan detected" are returned as permanent errors
    /// carrying the database message. Other failures go through `sqlx_to_provider_error`.
    #[instrument(skip(self), target = "duroxide::providers::postgres")]
    async fn delete_instances_atomic(
        &self,
        ids: &[String],
        force: bool,
    ) -> Result<DeleteInstanceResult, ProviderError> {
        if ids.is_empty() {
            return Ok(DeleteInstanceResult::default());
        }
        let row: Option<(i64, i64, i64, i64)> = sqlx::query_as(&format!(
            "SELECT * FROM {}.delete_instances_atomic($1, $2)",
            self.schema_name
        ))
        .bind(ids)
        .bind(force)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| {
            let err_str = e.to_string();
            if err_str.contains("is Running") || err_str.contains("Orphan detected") {
                ProviderError::permanent("delete_instances_atomic", err_str)
            } else {
                self.sqlx_to_provider_error("delete_instances_atomic", e)
            }
        })?;
        let (instances_deleted, executions_deleted, events_deleted, queue_messages_deleted) =
            row.unwrap_or((0, 0, 0, 0));
        debug!(
            target = "duroxide::providers::postgres",
            operation = "delete_instances_atomic",
            instances_deleted = instances_deleted,
            executions_deleted = executions_deleted,
            events_deleted = events_deleted,
            queue_messages_deleted = queue_messages_deleted,
            "Deleted instances atomically"
        );
        Ok(DeleteInstanceResult {
            instances_deleted: instances_deleted as u64,
            executions_deleted: executions_deleted as u64,
            events_deleted: events_deleted as u64,
            queue_messages_deleted: queue_messages_deleted as u64,
        })
    }

    /// Deletes root instances whose current execution is `Completed`, `Failed` or
    /// `ContinuedAsNew`, together with their whole instance trees.
    ///
    /// Candidates are narrowed by `filter` (explicit IDs, `completed_before` in epoch ms,
    /// `limit` defaulting to 1000). Each tree is removed with
    /// `delete_instances_atomic(.., force = true)`.
    #[instrument(skip(self), target = "duroxide::providers::postgres")]
    async fn delete_instance_bulk(
        &self,
        filter: InstanceFilter,
    ) -> Result<DeleteInstanceResult, ProviderError> {
        if matches!(&filter.instance_ids, Some(ids) if ids.is_empty()) {
            return Ok(DeleteInstanceResult::default());
        }
        let base_query = format!(
            r#"
            SELECT i.instance_id
            FROM {schema}.instances i
            LEFT JOIN {schema}.executions e ON i.instance_id = e.instance_id
                AND i.current_execution_id = e.execution_id
            WHERE i.parent_instance_id IS NULL
                AND e.status IN ('Completed', 'Failed', 'ContinuedAsNew')
            "#,
            schema = self.schema_name
        );
        let instance_ids = self
            .fetch_filtered_instance_ids("delete_instance_bulk", base_query, &filter)
            .await?;
        let mut result = DeleteInstanceResult::default();
        for instance_id in &instance_ids {
            let tree = self.get_instance_tree(instance_id).await?;
            let delete_result = self.delete_instances_atomic(&tree.all_ids, true).await?;
            result.instances_deleted += delete_result.instances_deleted;
            result.executions_deleted += delete_result.executions_deleted;
            result.events_deleted += delete_result.events_deleted;
            result.queue_messages_deleted += delete_result.queue_messages_deleted;
        }
        debug!(
            target = "duroxide::providers::postgres",
            operation = "delete_instance_bulk",
            instances_deleted = result.instances_deleted,
            executions_deleted = result.executions_deleted,
            events_deleted = result.events_deleted,
            queue_messages_deleted = result.queue_messages_deleted,
            "Bulk deleted instances"
        );
        Ok(result)
    }

    /// Prunes old executions of `instance_id` via the `prune_executions` procedure.
    ///
    /// `options.keep_last` and `options.completed_before` (epoch ms) are forwarded as
    /// nullable parameters.
    #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
    async fn prune_executions(
        &self,
        instance_id: &str,
        options: PruneOptions,
    ) -> Result<PruneResult, ProviderError> {
        // Saturate rather than letting `as` wrap out-of-range values to negatives.
        let keep_last: Option<i32> = options
            .keep_last
            .map(|v| i32::try_from(v).unwrap_or(i32::MAX));
        let completed_before_ms: Option<i64> = options
            .completed_before
            .map(|v| i64::try_from(v).unwrap_or(i64::MAX));
        let row: Option<(i64, i64, i64)> = sqlx::query_as(&format!(
            "SELECT * FROM {}.prune_executions($1, $2, $3)",
            self.schema_name
        ))
        .bind(instance_id)
        .bind(keep_last)
        .bind(completed_before_ms)
        .fetch_optional(&*self.pool)
        .await
        .map_err(|e| self.sqlx_to_provider_error("prune_executions", e))?;
        let (instances_processed, executions_deleted, events_deleted) = row.unwrap_or((0, 0, 0));
        debug!(
            target = "duroxide::providers::postgres",
            operation = "prune_executions",
            instance_id = %instance_id,
            instances_processed = instances_processed,
            executions_deleted = executions_deleted,
            events_deleted = events_deleted,
            "Pruned executions"
        );
        Ok(PruneResult {
            instances_processed: instances_processed as u64,
            executions_deleted: executions_deleted as u64,
            events_deleted: events_deleted as u64,
        })
    }

    /// Applies `prune_executions` with `options` to every instance selected by `filter`.
    ///
    /// `filter` accepts explicit IDs, `completed_before` in epoch ms, and `limit`
    /// (default 1000).
    #[instrument(skip(self), target = "duroxide::providers::postgres")]
    async fn prune_executions_bulk(
        &self,
        filter: InstanceFilter,
        options: PruneOptions,
    ) -> Result<PruneResult, ProviderError> {
        if matches!(&filter.instance_ids, Some(ids) if ids.is_empty()) {
            return Ok(PruneResult::default());
        }
        let base_query = format!(
            r#"
            SELECT i.instance_id
            FROM {schema}.instances i
            LEFT JOIN {schema}.executions e ON i.instance_id = e.instance_id
                AND i.current_execution_id = e.execution_id
            WHERE 1=1
            "#,
            schema = self.schema_name
        );
        let instance_ids = self
            .fetch_filtered_instance_ids("prune_executions_bulk", base_query, &filter)
            .await?;
        let mut result = PruneResult::default();
        for instance_id in &instance_ids {
            let single_result = self.prune_executions(instance_id, options.clone()).await?;
            result.instances_processed += single_result.instances_processed;
            result.executions_deleted += single_result.executions_deleted;
            result.events_deleted += single_result.events_deleted;
        }
        debug!(
            target = "duroxide::providers::postgres",
            operation = "prune_executions_bulk",
            instances_processed = result.instances_processed,
            executions_deleted = result.executions_deleted,
            events_deleted = result.events_deleted,
            "Bulk pruned executions"
        );
        Ok(result)
    }
}

impl PostgresProvider {
    /// Runs `base_query` with the `InstanceFilter` constraints appended and returns the IDs.
    ///
    /// `base_query` must select `i.instance_id`, alias executions as `e`, and end inside an
    /// open `WHERE` clause so `AND ...` predicates can be appended. It must use no bind
    /// parameters of its own.
    ///
    /// Explicit IDs are bound as a single array parameter (`= ANY($n)`) rather than one
    /// placeholder per ID. This keeps large ID lists under Postgres's bind-parameter limit.
    /// `limit` defaults to 1000.
    async fn fetch_filtered_instance_ids(
        &self,
        operation: &'static str,
        base_query: String,
        filter: &InstanceFilter,
    ) -> Result<Vec<String>, ProviderError> {
        let mut sql = base_query;
        let mut next_param = 1usize;
        if filter.instance_ids.is_some() {
            sql.push_str(&format!(" AND i.instance_id = ANY(${next_param})"));
            next_param += 1;
        }
        if filter.completed_before.is_some() {
            sql.push_str(&format!(
                " AND e.completed_at < TO_TIMESTAMP(${next_param} / 1000.0)"
            ));
            next_param += 1;
        }
        sql.push_str(&format!(" LIMIT ${next_param}"));

        // Bind order must mirror the placeholder order assigned above.
        let mut query = sqlx::query_scalar::<_, String>(&sql);
        if let Some(ids) = &filter.instance_ids {
            query = query.bind(ids.as_slice());
        }
        if let Some(completed_before) = filter.completed_before {
            query = query.bind(completed_before as i64);
        }
        query = query.bind(filter.limit.unwrap_or(1000) as i64);
        query
            .fetch_all(&*self.pool)
            .await
            .map_err(|e| self.sqlx_to_provider_error(operation, e))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::entra::test_support::{token, RecordingFakeTokenSource};
#[test]
fn build_entra_connect_options_uses_verify_full() {
let opts =
build_entra_connect_options("h.example.com", 5432, "db", "u", PgSslMode::VerifyFull);
assert!(matches!(opts.get_ssl_mode(), PgSslMode::VerifyFull));
assert_eq!(opts.get_host(), "h.example.com");
assert_eq!(opts.get_port(), 5432);
assert_eq!(opts.get_database(), Some("db"));
assert_eq!(opts.get_username(), "u");
}
#[test]
fn compute_next_refresh_sleep_is_capped_by_ceiling() {
let now = SystemTime::now();
let expires = now + Duration::from_secs(24 * 3600);
let sleep = compute_next_refresh_sleep(Duration::from_secs(5 * 60), expires, now);
assert_eq!(sleep, Duration::from_secs(5 * 60));
}
#[test]
fn compute_next_refresh_sleep_drives_from_expiry() {
let now = SystemTime::now();
let expires = now + Duration::from_secs(6 * 60);
let sleep = compute_next_refresh_sleep(Duration::from_secs(3600), expires, now);
assert!(sleep <= Duration::from_secs(60), "got {sleep:?}");
assert!(sleep >= ENTRA_REFRESH_MIN_INTERVAL, "got {sleep:?}");
}
#[test]
fn compute_next_refresh_sleep_floors_at_min_interval() {
let now = SystemTime::now();
let expires = now + Duration::from_secs(60); let sleep = compute_next_refresh_sleep(Duration::from_secs(3600), expires, now);
assert_eq!(sleep, ENTRA_REFRESH_MIN_INTERVAL);
}
#[tokio::test]
async fn recording_token_source_returns_distinct_tokens_in_script_order() {
let fake = RecordingFakeTokenSource::with_tokens(vec![
token("token-A", 3600),
token("token-B", 3600),
token("token-C", 3600),
token("token-D", 3600),
token("token-E", 3600),
token("token-F", 3600),
]);
let token_source: Arc<dyn TokenSource> = fake.clone();
let base_options =
build_entra_connect_options("127.0.0.1", 5432, "db", "u", PgSslMode::VerifyFull);
let pool: Arc<PgPool> = Arc::new(
PgPoolOptions::new()
.max_connections(1)
.connect_lazy_with(base_options.clone().password("placeholder")),
);
let initial_expires_at = SystemTime::now() + Duration::from_secs(3600);
let _ = pool;
let _ = initial_expires_at;
let t1 = token_source.fetch_token(&["aud"]).await.unwrap();
let t2 = token_source.fetch_token(&["aud"]).await.unwrap();
let t3 = token_source.fetch_token(&["aud"]).await.unwrap();
assert_ne!(t1.secret, t2.secret);
assert_ne!(t2.secret, t3.secret);
assert_eq!(fake.call_count(), 3);
}
#[tokio::test]
async fn audience_override_is_passed_to_token_source() {
let fake = RecordingFakeTokenSource::with_tokens(vec![token("t", 3600)]);
let source: Arc<dyn TokenSource> = fake.clone();
let opts =
crate::entra::EntraAuthOptions::new().audience("https://custom.example/.default");
let _t = source.fetch_token(&[opts.audience_str()]).await.unwrap();
let scopes = fake.recorded_scopes();
assert_eq!(scopes.len(), 1);
assert_eq!(
scopes[0],
vec!["https://custom.example/.default".to_string()]
);
}
#[tokio::test]
async fn missing_credential_surfaces_descriptive_error() {
let fake = RecordingFakeTokenSource::always_failing("no credential available");
let source: Arc<dyn TokenSource> = fake;
let result: anyhow::Result<crate::entra::EntraToken> = source.fetch_token(&["aud"]).await;
let err = result.expect_err("should fail");
let msg = format!("{err:#}");
assert!(msg.contains("no credential available"), "got: {msg}");
}
#[test]
fn next_sleep_after_iteration_uses_expiry_schedule_on_success() {
let now = SystemTime::now();
let expires = now + Duration::from_secs(3600);
let result: Result<Result<(), ()>, String> = Ok(Ok(()));
let sleep = next_sleep_after_iteration(&result, Duration::from_secs(20 * 60), expires, now);
let expected = compute_next_refresh_sleep(Duration::from_secs(20 * 60), expires, now);
assert_eq!(sleep, expected);
assert_eq!(sleep, Duration::from_secs(20 * 60));
}
#[test]
fn next_sleep_after_iteration_returns_min_interval_on_fetch_failure() {
let now = SystemTime::now();
let expires = now + Duration::from_secs(3600);
let result: Result<Result<(), ()>, String> = Ok(Err(()));
let sleep = next_sleep_after_iteration(&result, Duration::from_secs(20 * 60), expires, now);
assert_eq!(sleep, ENTRA_REFRESH_MIN_INTERVAL);
}
#[test]
fn next_sleep_after_iteration_returns_min_interval_on_panic() {
    let now = SystemTime::now();
    let expires = now + Duration::from_secs(3600);
    // Outer `Err`: the iteration panicked; the payload is the panic text.
    let panicked: Result<Result<(), ()>, String> = Err("simulated panic".to_string());
    let delay =
        next_sleep_after_iteration(&panicked, Duration::from_secs(20 * 60), expires, now);
    assert_eq!(delay, ENTRA_REFRESH_MIN_INTERVAL);
}
#[test]
fn compute_next_refresh_sleep_floors_when_ceiling_is_tiny() {
    let now = SystemTime::now();
    let tiny_ceiling = Duration::from_secs(1);
    // With an hour of validity left, the 1s ceiling is expected to be lifted
    // to the configured floor rather than honoured as-is.
    let delay = compute_next_refresh_sleep(tiny_ceiling, now + Duration::from_secs(3600), now);
    assert_eq!(delay, ENTRA_REFRESH_MIN_INTERVAL);
}
#[test]
fn entra_token_debug_redacts_secret() {
    use crate::entra::test_support::token;
    // `Debug` output must never contain the bearer secret, only a marker.
    const SECRET: &str = "super-secret-bearer-string";
    let debug = format!("{:?}", token(SECRET, 3600));
    assert!(!debug.contains(SECRET), "leaked: {debug}");
    assert!(debug.contains("<redacted>"), "expected redaction marker: {debug}");
}
#[test]
fn classify_pg_sqlstate_gates_28xxx_on_is_entra() {
    use crate::provider::{
        classify_pg_sqlstate,
        SqlStateClass::{Permanent, Retryable},
    };
    // (SQLSTATE, is_entra, expected class). The auth failures 28000 / 28P01
    // are retryable only when `is_entra` is set; the other rows pin classes
    // that the classifier assigns without consulting the flag.
    let cases = [
        (Some("28000"), true, Retryable),
        (Some("28P01"), true, Retryable),
        (Some("28000"), false, Permanent),
        (Some("28P01"), false, Permanent),
        (Some("40P01"), true, Retryable),
        (Some("40P01"), false, Retryable),
        (Some("23505"), true, Permanent),
        (Some("23505"), false, Permanent),
        (Some("0A000"), true, Retryable),
        (None, true, Permanent),
    ];
    for (code, is_entra, expected) in cases.iter().copied() {
        assert_eq!(
            classify_pg_sqlstate(code, is_entra),
            expected,
            "code={code:?} is_entra={is_entra}"
        );
    }
}
#[tokio::test]
async fn run_with_panic_guard_catches_string_panic_and_continues() {
    // A string panic inside the guarded future comes back as `Err` carrying
    // the panic text instead of unwinding through the test.
    let outcome: Result<(), String> =
        run_with_panic_guard(async { panic!("boom") }).await;
    let msg = outcome.expect_err("must catch the panic");
    assert!(msg.contains("boom"), "got: {msg}");
}
#[tokio::test]
async fn run_with_panic_guard_returns_ok_when_future_completes() {
    // Without a panic the guard is transparent: the future's value passes through.
    let outcome: Result<i32, String> = run_with_panic_guard(async { 42 }).await;
    assert_eq!(outcome, Ok(42));
}
#[tokio::test]
async fn run_with_panic_guard_handles_non_string_panic_payload() {
    // An `i32` payload is neither `&str` nor `String`; the guard must still
    // report the panic with a generic description.
    let outcome: Result<(), String> = run_with_panic_guard(async {
        std::panic::panic_any(42_i32)
    })
    .await;
    let msg = outcome.expect_err("must catch");
    assert!(msg.contains("non-string panic payload"), "got: {msg}");
}
#[test]
fn truncate_panic_message_passes_through_short_input() {
    // Input below the limit is returned unchanged, with no marker appended.
    const SHORT: &str = "short message";
    assert_eq!(truncate_panic_message(SHORT.to_owned(), 256), SHORT);
}
#[test]
fn truncate_panic_message_truncates_long_input_with_marker() {
    const MARKER: &str = "…[truncated]";
    const LIMIT: usize = 256;
    // ASCII input: exactly LIMIT bytes are kept, followed by the marker.
    let out = truncate_panic_message("A".repeat(1024), LIMIT);
    assert!(out.starts_with(&"A".repeat(LIMIT)));
    assert!(out.ends_with(MARKER), "got: {out}");
    assert_eq!(out.len(), LIMIT + MARKER.len());
}
#[test]
fn truncate_panic_message_respects_utf8_char_boundaries() {
    // '✨' is 3 bytes in UTF-8, so 100 of them (300 bytes) cannot be cut at
    // exactly 256 bytes without landing inside a character.
    let raw = "✨".repeat(100);
    let out = truncate_panic_message(raw, 256);
    let kept = out
        .strip_suffix("…[truncated]")
        .unwrap_or_else(|| panic!("missing truncation marker: {out}"));
    // Only checking for the marker would also accept a byte-level cut followed
    // by lossy decoding (which leaves U+FFFD). Require whole '✨' chars only.
    assert!(!kept.is_empty(), "truncation dropped everything: {out}");
    assert!(
        kept.chars().all(|c| c == '✨'),
        "multi-byte char was split: {out}"
    );
    assert!(kept.chars().count() < 100, "input was not shortened: {out}");
}
#[tokio::test]
async fn run_with_panic_guard_truncates_oversized_panic_message() {
    // A 10 000-byte panic message must come back shortened and tagged with
    // the truncation marker.
    let outcome: Result<(), String> = run_with_panic_guard(async {
        panic!("{}", "S".repeat(10_000));
    })
    .await;
    let msg = outcome.expect_err("must catch");
    let observed_len = msg.len();
    assert!(
        observed_len < 10_000,
        "panic message not truncated: len={}",
        observed_len
    );
    assert!(msg.ends_with("…[truncated]"), "missing truncation marker: {msg}");
}
}
/// Integration tests for the Entra ID connection pipeline, run against a real
/// PostgreSQL server with a fake [`TokenSource`] in place of the Azure credential.
///
/// The fake returns the database password (or a deliberately wrong one) as the
/// token. The tests expect the provider to authenticate with that token as the
/// connection password. Every test returns early (effectively skipping) when
/// `DATABASE_URL` is unset or cannot be parsed.
#[cfg(test)]
mod entra_pipeline_tests {
    use super::*;
    use crate::entra::test_support::{token, RecordingFakeTokenSource};
    use sqlx::Row;

    /// `(host, port, database, user, password)` extracted from `DATABASE_URL`.
    type PgParts = (String, u16, String, String, String);

    /// Decodes RFC 3986 `%XX` escapes in a single URL component.
    ///
    /// A `%` that is not followed by two hex digits is kept literally.
    /// Returns `None` when the decoded bytes are not valid UTF-8.
    fn percent_decode(component: &str) -> Option<String> {
        let bytes = component.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while let Some(&byte) = bytes.get(i) {
            let escape = (byte == b'%')
                .then(|| component.get(i + 1..i + 3))
                .flatten()
                .filter(|hex| hex.bytes().all(|h| h.is_ascii_hexdigit()));
            match escape {
                Some(hex) => {
                    decoded.push(u8::from_str_radix(hex, 16).ok()?);
                    i += 3;
                }
                None => {
                    decoded.push(byte);
                    i += 1;
                }
            }
        }
        String::from_utf8(decoded).ok()
    }

    /// Splits a `postgres://` or `postgresql://` URL of the form
    /// `user:password@host[:port]/database[?query]` into its parts.
    /// The port defaults to 5432.
    ///
    /// User, password and database name are percent-decoded. A password written
    /// with escapes in the URL (e.g. `%40` for `@`) is therefore used in its
    /// real form. This is a minimal parser for test setup only. Bracketed IPv6
    /// hosts and URLs without a `user:password` pair yield `None`.
    fn parse_database_url(url: &str) -> Option<PgParts> {
        let stripped = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))?;
        let (creds, rest) = stripped.split_once('@')?;
        let (user, password) = creds.split_once(':')?;
        let (hostport, db_with_query) = rest.split_once('/')?;
        let (host, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
        let port: u16 = port_str.parse().ok()?;
        let db = db_with_query
            .split_once('?')
            .map_or(db_with_query, |(db, _query)| db);
        Some((
            host.to_string(),
            port,
            percent_decode(db)?,
            percent_decode(user)?,
            percent_decode(password)?,
        ))
    }

    /// Loads `.env`, then reads and parses `DATABASE_URL`. Prints a skip notice
    /// and returns `None` if the variable is missing or unparseable.
    fn pg_connection_or_skip() -> Option<PgParts> {
        dotenvy::dotenv().ok();
        let Ok(url) = std::env::var("DATABASE_URL") else {
            eprintln!("DATABASE_URL not set; skipping Entra pipeline integration test");
            return None;
        };
        let parts = parse_database_url(&url);
        if parts.is_none() {
            // Deliberately do not echo the URL: it normally embeds the password.
            eprintln!("DATABASE_URL not parseable; skipping Entra pipeline integration test");
        }
        parts
    }

    /// Returns a fresh schema name (`entra_inj_` plus the last 8 hex digits of
    /// a random UUID), so that separate test runs do not collide.
    fn unique_schema() -> String {
        let id = uuid::Uuid::new_v4().to_string();
        format!("entra_inj_{}", &id[id.len() - 8..])
    }

    /// Panics unless `msg` looks like a PostgreSQL authentication failure:
    /// SQLSTATE 28P01 or 28000, or any error text mentioning "password".
    fn assert_auth_failure(msg: &str) {
        assert!(
            msg.to_lowercase().contains("password")
                || msg.contains("28P01")
                || msg.contains("28000"),
            "expected authentication failure, got: {msg}"
        );
    }

    /// Probes whether the server enforces password authentication at all.
    ///
    /// Returns `false` if a connection with a bogus password succeeds (e.g.
    /// local `trust` auth); negative tests are meaningless in that case. Returns
    /// `true` on an authentication error. Panics if the connection fails for
    /// any other reason.
    async fn wrong_password_is_rejected(host: &str, port: u16, db: &str, user: &str) -> bool {
        let options = PgConnectOptions::new()
            .host(host)
            .port(port)
            .database(db)
            .username(user)
            .password("definitely-wrong-password")
            .ssl_mode(PgSslMode::Disable);
        match PgPoolOptions::new()
            .max_connections(1)
            .connect_with(options)
            .await
        {
            Ok(pool) => {
                pool.close().await;
                false
            }
            Err(err) => {
                assert_auth_failure(&format!("{err:#}"));
                true
            }
        }
    }

    /// Best-effort `DROP SCHEMA ... CASCADE`. A failure is only logged, so that
    /// cleanup never masks the real test outcome.
    async fn drop_schema(pool: &PgPool, schema: &str) {
        // Double embedded quotes so the name stays a single quoted identifier.
        let quoted = schema.replace('"', "\"\"");
        let stmt = format!("DROP SCHEMA IF EXISTS \"{quoted}\" CASCADE");
        if let Err(e) = sqlx::query(&stmt).execute(pool).await {
            eprintln!("warning: failed to drop schema {schema}: {e}");
        }
    }

    /// A token equal to the real password must authenticate. The `ApplyAll`
    /// migration policy must then create `<schema>.instances`.
    #[tokio::test]
    async fn pipeline_with_injected_token_authenticates_and_runs_migrations() {
        let Some((host, port, db, user, password)) = pg_connection_or_skip() else {
            return;
        };
        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token(&password, 3600)]);
        let schema = unique_schema();
        let provider = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new(),
            token_source,
            PgSslMode::Disable,
            MigrationPolicy::ApplyAll,
        )
        .await
        .expect("Entra pipeline must succeed against local PG with correct token");
        let regclass: Option<String> = sqlx::query(&format!(
            "SELECT to_regclass('{}.instances')::text AS r",
            schema
        ))
        .fetch_one(provider.pool())
        .await
        .expect("smoke query must succeed")
        .get("r");
        // Clean up before asserting so a failed assertion does not leak the schema.
        drop_schema(provider.pool(), &schema).await;
        assert!(
            regclass.is_some(),
            "expected migrations to create {}.instances",
            schema
        );
    }

    /// A wrong token must make provider construction fail with an
    /// authentication error. Skipped when the server accepts any password.
    #[tokio::test]
    async fn pipeline_with_wrong_token_fails_before_migrations() {
        let Some((host, port, db, user, _password)) = pg_connection_or_skip() else {
            return;
        };
        if !wrong_password_is_rejected(&host, port, &db, &user).await {
            eprintln!(
                "local PostgreSQL accepts wrong passwords; skipping negative Entra pipeline test"
            );
            return;
        }
        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token("definitely-wrong-password", 3600)]);
        let schema = unique_schema();
        let result = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new(),
            token_source,
            PgSslMode::Disable,
            MigrationPolicy::ApplyAll,
        )
        .await;
        let err = match result {
            Ok(_) => panic!("wrong token must fail pool construction, but provider was built"),
            Err(e) => e,
        };
        assert_auth_failure(&format!("{err:#}"));
    }

    /// Builds the provider with a non-default `refresh_interval` (one hour).
    /// Checks that the requested schema name is reported back.
    #[tokio::test]
    async fn pipeline_default_constructor_path_with_injected_token() {
        let Some((host, port, db, user, password)) = pg_connection_or_skip() else {
            return;
        };
        let schema = unique_schema();
        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token(&password, 3600)]);
        let provider = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new().refresh_interval(Duration::from_secs(60 * 60)),
            token_source,
            PgSslMode::Disable,
            MigrationPolicy::ApplyAll,
        )
        .await
        .expect("default-constructor variant must succeed");
        let reported_schema = provider.schema_name().to_owned();
        // Clean up before asserting so a failed assertion does not leak the schema.
        drop_schema(provider.pool(), &schema).await;
        assert_eq!(reported_schema, schema);
    }
}