1use anyhow::{Context, Result};
2use chrono::{TimeZone, Utc};
3use duroxide::providers::{
4 DeleteInstanceResult, DispatcherCapabilityFilter, ExecutionInfo, ExecutionMetadata,
5 InstanceFilter, InstanceInfo, OrchestrationItem, Provider, ProviderAdmin, ProviderError,
6 PruneOptions, PruneResult, QueueDepths, ScheduledActivityIdentifier, SessionFetchConfig,
7 SystemMetrics, TagFilter, WorkItem,
8};
9use duroxide::{Event, EventKind, SystemStats};
10use sqlx::postgres::{PgConnectOptions, PgSslMode};
11use sqlx::{postgres::PgPoolOptions, Error as SqlxError, PgPool};
12use std::sync::Arc;
13use std::time::Duration;
14use std::time::{SystemTime, UNIX_EPOCH};
15use tokio::task::AbortHandle;
16use tokio::time::sleep;
17use tracing::{debug, error, instrument, warn};
18
19use crate::entra::{EntraAuthOptions, TokenSource};
20use crate::migrations::MigrationRunner;
21
/// Retry classification assigned to a PostgreSQL SQLSTATE code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SqlStateClass {
    /// The failed operation may be attempted again.
    Retryable,
    /// The failed operation must not be retried.
    Permanent,
}

/// Maps a SQLSTATE `code` to a [`SqlStateClass`].
///
/// * `40P01` and `0A000` are always retryable.
/// * `28000` and `28P01` are retryable only when `is_entra` is `true`.
/// * Everything else is permanent: `40001`, `23505`, `23503`, any
///   unrecognised code, and a missing code.
pub(crate) fn classify_pg_sqlstate(code: Option<&str>, is_entra: bool) -> SqlStateClass {
    // No SQLSTATE at all is treated the same as an unrecognised one.
    let Some(sqlstate) = code else {
        return SqlStateClass::Permanent;
    };

    let retryable = match sqlstate {
        "40P01" | "0A000" => true,
        // Authentication failures are only considered transient on the Entra path.
        "28000" | "28P01" => is_entra,
        _ => false,
    };

    if retryable {
        SqlStateClass::Retryable
    } else {
        SqlStateClass::Permanent
    }
}
98
/// PostgreSQL-backed implementation of the duroxide `Provider` trait.
pub struct PostgresProvider {
    /// Shared sqlx connection pool used for every query.
    pool: Arc<PgPool>,
    /// Schema that qualifies the stored procedures; constructors fall back to
    /// `"public"` when no schema is configured.
    schema_name: String,
    /// `true` when built through the Entra constructor; consulted when
    /// classifying authentication SQLSTATEs as retryable.
    is_entra: bool,
    /// Handle that aborts the background token refresh task when the provider
    /// is dropped. `None` for URL-based providers.
    _refresh_task: Option<AbortOnDropHandle>,
}
109
/// Selects what a provider constructor does with schema migrations.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MigrationPolicy {
    /// Run `MigrationRunner::migrate` during construction (the default).
    #[default]
    ApplyAll,
    /// Run `MigrationRunner::verify` during construction instead of migrating.
    VerifyOnly,
}
132
/// Everything `PostgresProvider::new_with_config` needs to build a provider.
///
/// Use [`ProviderConfig::url`] or [`ProviderConfig::entra`] to create a value
/// and then adjust the public fields as needed.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ProviderConfig {
    /// How to reach and authenticate against the database.
    pub connection: ConnectionConfig,
    /// Optional schema name. When set it is validated against
    /// `[A-Za-z_][A-Za-z0-9_]*`; when `None` the constructors use `"public"`.
    pub schema_name: Option<String>,
    /// Whether migrations are applied or only verified at construction time.
    pub migration_policy: MigrationPolicy,
}
160
161impl ProviderConfig {
162 pub fn url(database_url: impl Into<String>) -> Self {
166 Self {
167 connection: ConnectionConfig::Url(database_url.into()),
168 schema_name: None,
169 migration_policy: MigrationPolicy::default(),
170 }
171 }
172
173 pub fn entra(
178 host: impl Into<String>,
179 port: u16,
180 database: impl Into<String>,
181 user: impl Into<String>,
182 options: EntraAuthOptions,
183 ) -> Self {
184 Self {
185 connection: ConnectionConfig::Entra {
186 host: host.into(),
187 port,
188 database: database.into(),
189 user: user.into(),
190 options,
191 },
192 schema_name: None,
193 migration_policy: MigrationPolicy::default(),
194 }
195 }
196}
197
/// How the provider connects to PostgreSQL.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ConnectionConfig {
    /// Connect using a database URL passed straight to the sqlx pool.
    Url(String),
    /// Connect using an Entra access token as the password.
    /// `PostgresProvider::new_with_config` resolves the token source from
    /// `options` and connects with `PgSslMode::VerifyFull`.
    Entra {
        /// Database server host name.
        host: String,
        /// Database server port.
        port: u16,
        /// Database name.
        database: String,
        /// Database user name.
        user: String,
        /// Entra settings (audience, pool sizing, refresh interval).
        options: EntraAuthOptions,
    },
}
216
217fn validate_schema_name(schema_name: &str) -> Result<()> {
232 let mut chars = schema_name.chars();
233 let Some(first) = chars.next() else {
234 anyhow::bail!("Invalid schema_name '': must match [A-Za-z_][A-Za-z0-9_]*");
235 };
236 if !(first == '_' || first.is_ascii_alphabetic()) {
237 anyhow::bail!("Invalid schema_name '{schema_name}': must match [A-Za-z_][A-Za-z0-9_]*");
238 }
239 for ch in chars {
240 if !(ch == '_' || ch.is_ascii_alphanumeric()) {
241 anyhow::bail!("Invalid schema_name '{schema_name}': must match [A-Za-z_][A-Za-z0-9_]*");
242 }
243 }
244 Ok(())
245}
246
/// Wrapper around a tokio [`AbortHandle`] that aborts the task when dropped.
struct AbortOnDropHandle(AbortHandle);

impl Drop for AbortOnDropHandle {
    /// Aborts the wrapped task so it does not outlive its owner.
    fn drop(&mut self) {
        self.0.abort();
    }
}
257
258impl PostgresProvider {
259 pub async fn new(database_url: &str) -> Result<Self> {
264 Self::new_with_config(ProviderConfig::url(database_url)).await
265 }
266
267 pub async fn new_with_schema(database_url: &str, schema_name: Option<&str>) -> Result<Self> {
272 let mut config = ProviderConfig::url(database_url);
273 config.schema_name = schema_name.map(str::to_string);
274 Self::new_with_config(config).await
275 }
276
277 pub async fn new_with_config(config: ProviderConfig) -> Result<Self> {
283 let ProviderConfig {
284 connection,
285 schema_name,
286 migration_policy,
287 } = config;
288
289 if let Some(ref s) = schema_name {
290 validate_schema_name(s)?;
291 }
292
293 match connection {
294 ConnectionConfig::Url(database_url) => {
295 Self::new_from_url(&database_url, schema_name.as_deref(), migration_policy).await
296 }
297 ConnectionConfig::Entra {
298 host,
299 port,
300 database,
301 user,
302 options,
303 } => {
304 let token_source = options.default_token_source().context(
305 "Entra credential resolution failed: could not build the default credential chain",
306 )?;
307 Self::new_with_entra_with_token_source(
308 &host,
309 port,
310 &database,
311 &user,
312 schema_name.as_deref(),
313 options,
314 token_source,
315 PgSslMode::VerifyFull,
316 migration_policy,
317 )
318 .await
319 }
320 }
321 }
322
323 async fn new_from_url(
324 database_url: &str,
325 schema_name: Option<&str>,
326 migration_policy: MigrationPolicy,
327 ) -> Result<Self> {
328 let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
329 .ok()
330 .and_then(|s| s.parse::<u32>().ok())
331 .unwrap_or(10);
332
333 let pool = PgPoolOptions::new()
334 .max_connections(max_connections)
335 .min_connections(1)
336 .acquire_timeout(std::time::Duration::from_secs(30))
337 .connect(database_url)
338 .await?;
339
340 let schema_name = schema_name.unwrap_or("public").to_string();
341
342 let provider = Self {
343 pool: Arc::new(pool),
344 schema_name: schema_name.clone(),
345 is_entra: false,
346 _refresh_task: None,
347 };
348
349 let migration_runner = MigrationRunner::new(provider.pool.clone(), schema_name);
350 match migration_policy {
351 MigrationPolicy::ApplyAll => migration_runner.migrate().await?,
352 MigrationPolicy::VerifyOnly => migration_runner.verify().await?,
353 }
354
355 Ok(provider)
356 }
357
358 #[deprecated(
385 since = "0.1.34",
386 note = "use `PostgresProvider::new_with_config(ProviderConfig::entra(...))` instead"
387 )]
388 pub async fn new_with_entra(
389 host: &str,
390 port: u16,
391 database: &str,
392 user: &str,
393 options: EntraAuthOptions,
394 ) -> Result<Self> {
395 Self::new_with_config(ProviderConfig::entra(host, port, database, user, options)).await
396 }
397
398 #[deprecated(
401 since = "0.1.34",
402 note = "use `PostgresProvider::new_with_config(ProviderConfig::entra(...))` with `schema_name` set instead"
403 )]
404 #[instrument(
405 skip(options),
406 fields(host = %host, port = %port, database = %database, user = %user, schema = ?schema_name),
407 target = "duroxide::providers::postgres",
408 )]
409 pub async fn new_with_schema_and_entra(
410 host: &str,
411 port: u16,
412 database: &str,
413 user: &str,
414 schema_name: Option<&str>,
415 options: EntraAuthOptions,
416 ) -> Result<Self> {
417 let mut config = ProviderConfig::entra(host, port, database, user, options);
418 config.schema_name = schema_name.map(str::to_string);
419 Self::new_with_config(config).await
420 }
421
422 #[allow(clippy::too_many_arguments)]
432 pub(crate) async fn new_with_entra_with_token_source(
433 host: &str,
434 port: u16,
435 database: &str,
436 user: &str,
437 schema_name: Option<&str>,
438 options: EntraAuthOptions,
439 token_source: Arc<dyn TokenSource>,
440 ssl_mode: PgSslMode,
441 migration_policy: MigrationPolicy,
442 ) -> Result<Self> {
443 let audience = options.audience_str().to_string();
444 let token = token_source
445 .fetch_token(&[audience.as_str()])
446 .await
447 .context(
448 "Entra credential resolution failed: could not acquire an initial access token",
449 )?;
450
451 let base_options = build_entra_connect_options(host, port, database, user, ssl_mode);
452
453 let pool = PgPoolOptions::new()
454 .max_connections(options.max_connections_value())
455 .min_connections(1)
456 .acquire_timeout(options.acquire_timeout_value())
457 .connect_with(base_options.clone().password(&token.secret))
458 .await?;
459
460 let pool = Arc::new(pool);
461 let schema_name = schema_name.unwrap_or("public").to_string();
462
463 let migration_runner = MigrationRunner::new(pool.clone(), schema_name.clone());
464 match migration_policy {
465 MigrationPolicy::ApplyAll => migration_runner.migrate().await?,
466 MigrationPolicy::VerifyOnly => migration_runner.verify().await?,
467 }
468
469 let refresh_handle = spawn_token_refresh_task(
470 pool.clone(),
471 token_source,
472 base_options,
473 audience,
474 options.refresh_interval_value(),
475 token.expires_at,
476 );
477
478 Ok(Self {
479 pool,
480 schema_name,
481 is_entra: true,
482 _refresh_task: Some(AbortOnDropHandle(refresh_handle)),
483 })
484 }
485
486 #[deprecated(
487 since = "0.1.34",
488 note = "schema initialization is now run automatically by every constructor; this shim will be removed in a future release"
489 )]
490 #[instrument(skip(self), target = "duroxide::providers::postgres")]
491 pub async fn initialize_schema(&self) -> Result<()> {
492 let migration_runner = MigrationRunner::new(self.pool.clone(), self.schema_name.clone());
493 migration_runner.migrate().await?;
494 Ok(())
495 }
496
497 fn now_millis() -> i64 {
499 SystemTime::now()
500 .duration_since(UNIX_EPOCH)
501 .unwrap()
502 .as_millis() as i64
503 }
504
505 fn table_name(&self, table: &str) -> String {
507 format!("{}.{}", self.schema_name, table)
508 }
509
510 pub fn pool(&self) -> &PgPool {
512 &self.pool
513 }
514
515 pub fn schema_name(&self) -> &str {
517 &self.schema_name
518 }
519
520 fn sqlx_to_provider_error(&self, operation: &str, e: SqlxError) -> ProviderError {
529 match e {
530 SqlxError::Database(ref db_err) => {
531 let code_opt = db_err.code();
532 let code = code_opt.as_deref();
533 match classify_pg_sqlstate(code, self.is_entra) {
534 SqlStateClass::Retryable => ProviderError::retryable(
535 operation,
536 match code {
537 Some("40P01") => format!("Deadlock detected: {e}"),
538 Some("28000") | Some("28P01") => {
539 format!("Authentication error (likely token rotation): {e}")
540 }
541 Some("0A000") => format!("Cached plan invalidated: {e}"),
542 _ => format!("Retryable database error: {e}"),
543 },
544 ),
545 SqlStateClass::Permanent => ProviderError::permanent(
546 operation,
547 match code {
548 Some("40001") => format!("Serialization failure: {e}"),
549 Some("23505") => format!("Duplicate detected: {e}"),
550 Some("23503") => format!("Foreign key violation: {e}"),
551 _ => format!("Database error: {e}"),
552 },
553 ),
554 }
555 }
556 SqlxError::PoolClosed | SqlxError::PoolTimedOut => {
557 ProviderError::retryable(operation, format!("Connection pool error: {e}"))
558 }
559 SqlxError::Io(_) => ProviderError::retryable(operation, format!("I/O error: {e}")),
560 _ => ProviderError::permanent(operation, format!("Unexpected error: {e}")),
561 }
562 }
563
564 fn tag_filter_to_sql(filter: &TagFilter) -> (&'static str, Vec<String>) {
566 match filter {
567 TagFilter::DefaultOnly => ("default_only", vec![]),
568 TagFilter::Tags(set) => {
569 let mut tags: Vec<String> = set.iter().cloned().collect();
570 tags.sort();
571 ("tags", tags)
572 }
573 TagFilter::DefaultAnd(set) => {
574 let mut tags: Vec<String> = set.iter().cloned().collect();
575 tags.sort();
576 ("default_and", tags)
577 }
578 TagFilter::Any => ("any", vec![]),
579 TagFilter::None => ("none", vec![]),
580 }
581 }
582
583 pub async fn cleanup_schema(&self) -> Result<()> {
588 const MAX_RETRIES: u32 = 5;
589 const BASE_RETRY_DELAY_MS: u64 = 50;
590
591 for attempt in 0..=MAX_RETRIES {
592 let cleanup_result = async {
593 sqlx::query(&format!("SELECT {}.cleanup_schema()", self.schema_name))
595 .execute(&*self.pool)
596 .await?;
597
598 if self.schema_name != "public" {
601 sqlx::query(&format!(
602 "DROP SCHEMA IF EXISTS {} CASCADE",
603 self.schema_name
604 ))
605 .execute(&*self.pool)
606 .await?;
607 } else {
608 }
611
612 Ok::<(), SqlxError>(())
613 }
614 .await;
615
616 match cleanup_result {
617 Ok(()) => return Ok(()),
618 Err(SqlxError::Database(db_err)) if db_err.code().as_deref() == Some("40P01") => {
619 if attempt < MAX_RETRIES {
620 warn!(
621 target = "duroxide::providers::postgres",
622 schema = %self.schema_name,
623 attempt = attempt + 1,
624 "Deadlock during cleanup_schema, retrying"
625 );
626 sleep(Duration::from_millis(
627 BASE_RETRY_DELAY_MS * (attempt as u64 + 1),
628 ))
629 .await;
630 continue;
631 }
632 return Err(anyhow::anyhow!(db_err.to_string()));
633 }
634 Err(e) => return Err(anyhow::anyhow!(e.to_string())),
635 }
636 }
637
638 Ok(())
639 }
640}
641
642pub(crate) fn build_entra_connect_options(
651 host: &str,
652 port: u16,
653 database: &str,
654 user: &str,
655 ssl_mode: PgSslMode,
656) -> PgConnectOptions {
657 PgConnectOptions::new()
658 .host(host)
659 .port(port)
660 .database(database)
661 .username(user)
662 .ssl_mode(ssl_mode)
663}
664
/// Shortest sleep between token refresh attempts (30 seconds). It is the lower
/// bound on the computed refresh sleep and the backoff after a failed or
/// panicked refresh iteration.
const ENTRA_REFRESH_MIN_INTERVAL: Duration = Duration::from_secs(30);

/// How long before token expiry a refresh is scheduled (5 minutes).
pub(crate) const ENTRA_REFRESH_SAFETY_MARGIN: Duration = Duration::from_secs(5 * 60);

/// Maximum number of bytes of a panic message kept when a refresh iteration
/// panics; longer messages are truncated.
const ENTRA_PANIC_MSG_TRUNCATION_LIMIT: usize = 256;
678
679async fn run_with_panic_guard<Fut, T>(fut: Fut) -> Result<T, String>
691where
692 Fut: std::future::Future<Output = T>,
693{
694 use futures_util::FutureExt;
695 use std::panic::AssertUnwindSafe;
696
697 AssertUnwindSafe(fut).catch_unwind().await.map_err(|panic| {
698 let raw = if let Some(s) = panic.downcast_ref::<&'static str>() {
699 (*s).to_string()
700 } else if let Some(s) = panic.downcast_ref::<String>() {
701 s.clone()
702 } else {
703 "<non-string panic payload>".to_string()
704 };
705 truncate_panic_message(raw, ENTRA_PANIC_MSG_TRUNCATION_LIMIT)
706 })
707}
708
/// Shortens `s` to at most `limit` bytes and appends `"…[truncated]"`.
///
/// Strings that already fit are returned unchanged. The cut point is moved
/// back to the nearest character boundary so the result stays valid UTF-8.
fn truncate_panic_message(mut s: String, limit: usize) -> String {
    if s.len() <= limit {
        return s;
    }

    // Index 0 is always a boundary, so the search cannot fail.
    let cut = (0..=limit)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    s.truncate(cut);
    s.push_str("…[truncated]");
    s
}
727
728fn spawn_token_refresh_task(
747 pool: Arc<PgPool>,
748 token_source: Arc<dyn TokenSource>,
749 base_options: PgConnectOptions,
750 audience: String,
751 refresh_interval_ceiling: Duration,
752 initial_expires_at: SystemTime,
753) -> AbortHandle {
754 let handle = tokio::spawn(async move {
755 let mut next_expires_at = initial_expires_at;
769 let mut sleep_duration = compute_next_refresh_sleep(
770 refresh_interval_ceiling,
771 next_expires_at,
772 SystemTime::now(),
773 );
774 loop {
775 debug!(
776 target: "duroxide::providers::postgres",
777 sleep_secs = sleep_duration.as_secs(),
778 "Entra refresh task sleeping",
779 );
780 sleep(sleep_duration).await;
781
782 let result = run_with_panic_guard(refresh_loop_iteration(
783 &pool,
784 token_source.as_ref(),
785 &base_options,
786 &audience,
787 &mut next_expires_at,
788 ))
789 .await;
790
791 if let Err(panic_msg) = &result {
792 error!(
793 target: "duroxide::providers::postgres",
794 panic = %panic_msg,
795 "Entra refresh task body panicked; continuing with bounded backoff",
796 );
797 }
798
799 sleep_duration = next_sleep_after_iteration(
800 &result,
801 refresh_interval_ceiling,
802 next_expires_at,
803 SystemTime::now(),
804 );
805 }
806 });
807 handle.abort_handle()
808}
809
810fn next_sleep_after_iteration(
823 result: &Result<Result<(), ()>, String>,
824 refresh_interval_ceiling: Duration,
825 next_expires_at: SystemTime,
826 now: SystemTime,
827) -> Duration {
828 match result {
829 Ok(Ok(())) => compute_next_refresh_sleep(refresh_interval_ceiling, next_expires_at, now),
830 Ok(Err(())) | Err(_) => ENTRA_REFRESH_MIN_INTERVAL,
831 }
832}
833
834async fn refresh_loop_iteration(
842 pool: &Arc<PgPool>,
843 token_source: &dyn TokenSource,
844 base_options: &PgConnectOptions,
845 audience: &str,
846 next_expires_at: &mut SystemTime,
847) -> Result<(), ()> {
848 match token_source.fetch_token(&[audience]).await {
849 Ok(token) => {
850 let new_options = base_options.clone().password(&token.secret);
851 pool.set_connect_options(new_options);
852 *next_expires_at = token.expires_at;
853 debug!(
854 target: "duroxide::providers::postgres",
855 "Entra token refreshed and applied to pool",
856 );
857 Ok(())
858 }
859 Err(e) => {
860 warn!(
861 target: "duroxide::providers::postgres",
862 error = %e,
863 "Entra token refresh failed; will retry after bounded backoff",
864 );
865 Err(())
866 }
867 }
868}
869
870fn compute_next_refresh_sleep(
877 ceiling: Duration,
878 expires_at: SystemTime,
879 now: SystemTime,
880) -> Duration {
881 let until_expiry = expires_at.duration_since(now).unwrap_or(Duration::ZERO);
882
883 let expiry_driven = until_expiry
884 .checked_sub(ENTRA_REFRESH_SAFETY_MARGIN)
885 .unwrap_or(Duration::ZERO);
886
887 let expiry_driven = expiry_driven.max(ENTRA_REFRESH_MIN_INTERVAL);
888
889 ceiling.min(expiry_driven).max(ENTRA_REFRESH_MIN_INTERVAL)
892}
893
894#[async_trait::async_trait]
895impl Provider for PostgresProvider {
    /// Returns the fixed provider name, `"duroxide-pg"`.
    fn name(&self) -> &str {
        "duroxide-pg"
    }
899
    /// Returns this crate's version, captured at compile time from
    /// `CARGO_PKG_VERSION`.
    fn version(&self) -> &str {
        env!("CARGO_PKG_VERSION")
    }
903
904 #[instrument(skip(self), target = "duroxide::providers::postgres")]
905 async fn fetch_orchestration_item(
906 &self,
907 lock_timeout: Duration,
908 _poll_timeout: Duration,
909 filter: Option<&DispatcherCapabilityFilter>,
910 ) -> Result<Option<(OrchestrationItem, String, u32)>, ProviderError> {
911 let start = std::time::Instant::now();
912
913 const MAX_RETRIES: u32 = 3;
914 const RETRY_DELAY_MS: u64 = 50;
915
916 let lock_timeout_ms = lock_timeout.as_millis() as i64;
918 let mut _last_error: Option<ProviderError> = None;
919
920 let (min_packed, max_packed) = if let Some(f) = filter {
922 if let Some(range) = f.supported_duroxide_versions.first() {
923 let min = range.min.major as i64 * 1_000_000
924 + range.min.minor as i64 * 1_000
925 + range.min.patch as i64;
926 let max = range.max.major as i64 * 1_000_000
927 + range.max.minor as i64 * 1_000
928 + range.max.patch as i64;
929 (Some(min), Some(max))
930 } else {
931 return Ok(None);
933 }
934 } else {
935 (None, None)
936 };
937
938 for attempt in 0..=MAX_RETRIES {
939 let now_ms = Self::now_millis();
940
941 let result: Result<
942 Option<(
943 String,
944 String,
945 String,
946 i64,
947 serde_json::Value,
948 serde_json::Value,
949 String,
950 i32,
951 serde_json::Value,
952 )>,
953 SqlxError,
954 > = sqlx::query_as(&format!(
955 "SELECT * FROM {}.fetch_orchestration_item($1, $2, $3, $4)",
956 self.schema_name
957 ))
958 .bind(now_ms)
959 .bind(lock_timeout_ms)
960 .bind(min_packed)
961 .bind(max_packed)
962 .fetch_optional(&*self.pool)
963 .await;
964
965 let row = match result {
966 Ok(r) => r,
967 Err(e) => {
968 let provider_err = self.sqlx_to_provider_error("fetch_orchestration_item", e);
969 if provider_err.is_retryable() && attempt < MAX_RETRIES {
970 warn!(
971 target = "duroxide::providers::postgres",
972 operation = "fetch_orchestration_item",
973 attempt = attempt + 1,
974 error = %provider_err,
975 "Retryable error, will retry"
976 );
977 _last_error = Some(provider_err);
978 sleep(std::time::Duration::from_millis(
979 RETRY_DELAY_MS * (attempt as u64 + 1),
980 ))
981 .await;
982 continue;
983 }
984 return Err(provider_err);
985 }
986 };
987
988 if let Some((
989 instance_id,
990 orchestration_name,
991 orchestration_version,
992 execution_id,
993 history_json,
994 messages_json,
995 lock_token,
996 attempt_count,
997 kv_snapshot_json,
998 )) = row
999 {
1000 let (history, history_error) =
1001 match serde_json::from_value::<Vec<Event>>(history_json) {
1002 Ok(h) => (h, None),
1003 Err(e) => {
1004 let error_msg = format!("Failed to deserialize history: {e}");
1005 warn!(
1006 target = "duroxide::providers::postgres",
1007 instance = %instance_id,
1008 error = %error_msg,
1009 "History deserialization failed, returning item with history_error"
1010 );
1011 (vec![], Some(error_msg))
1012 }
1013 };
1014
1015 let messages: Vec<WorkItem> =
1016 serde_json::from_value(messages_json).map_err(|e| {
1017 ProviderError::permanent(
1018 "fetch_orchestration_item",
1019 format!("Failed to deserialize messages: {e}"),
1020 )
1021 })?;
1022 let kv_snapshot: std::collections::HashMap<String, duroxide::providers::KvEntry> = {
1023 let raw: std::collections::HashMap<String, serde_json::Value> =
1024 serde_json::from_value(kv_snapshot_json).unwrap_or_default();
1025 raw.into_iter()
1026 .filter_map(|(k, v)| {
1027 let value = v.get("value")?.as_str()?.to_string();
1028 let last_updated_at_ms =
1029 v.get("last_updated_at_ms")?.as_u64().unwrap_or(0);
1030 Some((
1031 k,
1032 duroxide::providers::KvEntry {
1033 value,
1034 last_updated_at_ms,
1035 },
1036 ))
1037 })
1038 .collect()
1039 };
1040
1041 let duration_ms = start.elapsed().as_millis() as u64;
1042 debug!(
1043 target = "duroxide::providers::postgres",
1044 operation = "fetch_orchestration_item",
1045 instance_id = %instance_id,
1046 execution_id = execution_id,
1047 message_count = messages.len(),
1048 history_count = history.len(),
1049 attempt_count = attempt_count,
1050 duration_ms = duration_ms,
1051 attempts = attempt + 1,
1052 "Fetched orchestration item via stored procedure"
1053 );
1054
1055 if orchestration_name == "Unknown"
1061 && history.is_empty()
1062 && messages
1063 .iter()
1064 .all(|m| matches!(m, WorkItem::QueueMessage { .. }))
1065 {
1066 let message_count = messages.len();
1067 tracing::warn!(
1068 target = "duroxide::providers::postgres",
1069 instance = %instance_id,
1070 message_count,
1071 "Dropping orphan queue messages — events enqueued before orchestration started are not supported"
1072 );
1073 self.ack_orchestration_item(
1074 &lock_token,
1075 execution_id as u64,
1076 vec![],
1077 vec![],
1078 vec![],
1079 ExecutionMetadata::default(),
1080 vec![],
1081 )
1082 .await?;
1083 return Ok(None);
1084 }
1085
1086 return Ok(Some((
1087 OrchestrationItem {
1088 instance: instance_id,
1089 orchestration_name,
1090 execution_id: execution_id as u64,
1091 version: orchestration_version,
1092 history,
1093 messages,
1094 history_error,
1095 kv_snapshot,
1096 },
1097 lock_token,
1098 attempt_count as u32,
1099 )));
1100 }
1101
1102 return Ok(None);
1105 }
1106
1107 Ok(None)
1108 }
1109 #[instrument(skip(self), fields(lock_token = %lock_token, execution_id = execution_id), target = "duroxide::providers::postgres")]
1110 async fn ack_orchestration_item(
1111 &self,
1112 lock_token: &str,
1113 execution_id: u64,
1114 history_delta: Vec<Event>,
1115 worker_items: Vec<WorkItem>,
1116 orchestrator_items: Vec<WorkItem>,
1117 metadata: ExecutionMetadata,
1118 cancelled_activities: Vec<ScheduledActivityIdentifier>,
1119 ) -> Result<(), ProviderError> {
1120 let start = std::time::Instant::now();
1121
1122 const MAX_RETRIES: u32 = 3;
1123 const RETRY_DELAY_MS: u64 = 50;
1124
1125 let mut history_delta_payload = Vec::with_capacity(history_delta.len());
1126 for event in &history_delta {
1127 if event.event_id() == 0 {
1128 return Err(ProviderError::permanent(
1129 "ack_orchestration_item",
1130 "event_id must be set by runtime",
1131 ));
1132 }
1133
1134 let event_json = serde_json::to_string(event).map_err(|e| {
1135 ProviderError::permanent(
1136 "ack_orchestration_item",
1137 format!("Failed to serialize event: {e}"),
1138 )
1139 })?;
1140
1141 let event_type = format!("{event:?}")
1142 .split('{')
1143 .next()
1144 .unwrap_or("Unknown")
1145 .trim()
1146 .to_string();
1147
1148 history_delta_payload.push(serde_json::json!({
1149 "event_id": event.event_id(),
1150 "event_type": event_type,
1151 "event_data": event_json,
1152 }));
1153 }
1154
1155 let history_delta_json = serde_json::Value::Array(history_delta_payload);
1156
1157 let worker_items_json = serde_json::to_value(&worker_items).map_err(|e| {
1158 ProviderError::permanent(
1159 "ack_orchestration_item",
1160 format!("Failed to serialize worker items: {e}"),
1161 )
1162 })?;
1163
1164 let orchestrator_items_json = serde_json::to_value(&orchestrator_items).map_err(|e| {
1165 ProviderError::permanent(
1166 "ack_orchestration_item",
1167 format!("Failed to serialize orchestrator items: {e}"),
1168 )
1169 })?;
1170
1171 let (custom_status_action, custom_status_value): (Option<&str>, Option<&str>) = {
1173 let mut last_status: Option<&Option<String>> = None;
1174 for event in &history_delta {
1175 if let EventKind::CustomStatusUpdated { ref status } = event.kind {
1176 last_status = Some(status);
1177 }
1178 }
1179 match last_status {
1180 Some(Some(s)) => (Some("set"), Some(s.as_str())),
1181 Some(None) => (Some("clear"), None),
1182 None => (None, None),
1183 }
1184 };
1185
1186 let kv_mutations: Vec<serde_json::Value> = history_delta
1187 .iter()
1188 .filter_map(|event| match &event.kind {
1189 EventKind::KeyValueSet {
1190 key,
1191 value,
1192 last_updated_at_ms,
1193 } => Some(serde_json::json!({
1194 "action": "set",
1195 "key": key,
1196 "value": value,
1197 "last_updated_at_ms": last_updated_at_ms,
1198 })),
1199 EventKind::KeyValueCleared { key } => Some(serde_json::json!({
1200 "action": "clear_key",
1201 "key": key,
1202 })),
1203 EventKind::KeyValuesCleared => Some(serde_json::json!({
1204 "action": "clear_all",
1205 })),
1206 _ => None,
1207 })
1208 .collect();
1209
1210 let metadata_json = serde_json::json!({
1211 "orchestration_name": metadata.orchestration_name,
1212 "orchestration_version": metadata.orchestration_version,
1213 "status": metadata.status,
1214 "output": metadata.output,
1215 "parent_instance_id": metadata.parent_instance_id,
1216 "pinned_duroxide_version": metadata.pinned_duroxide_version.as_ref().map(|v| {
1217 serde_json::json!({
1218 "major": v.major,
1219 "minor": v.minor,
1220 "patch": v.patch,
1221 })
1222 }),
1223 "custom_status_action": custom_status_action,
1224 "custom_status_value": custom_status_value,
1225 "kv_mutations": kv_mutations,
1226 });
1227
1228 let cancelled_activities_json: Vec<serde_json::Value> = cancelled_activities
1230 .iter()
1231 .map(|a| {
1232 serde_json::json!({
1233 "instance": a.instance,
1234 "execution_id": a.execution_id,
1235 "activity_id": a.activity_id,
1236 })
1237 })
1238 .collect();
1239 let cancelled_activities_json = serde_json::Value::Array(cancelled_activities_json);
1240
1241 for attempt in 0..=MAX_RETRIES {
1242 let now_ms = Self::now_millis();
1243 let result = sqlx::query(&format!(
1244 "SELECT {}.ack_orchestration_item($1, $2, $3, $4, $5, $6, $7, $8)",
1245 self.schema_name
1246 ))
1247 .bind(lock_token)
1248 .bind(now_ms)
1249 .bind(execution_id as i64)
1250 .bind(&history_delta_json)
1251 .bind(&worker_items_json)
1252 .bind(&orchestrator_items_json)
1253 .bind(&metadata_json)
1254 .bind(&cancelled_activities_json)
1255 .execute(&*self.pool)
1256 .await;
1257
1258 match result {
1259 Ok(_) => {
1260 let duration_ms = start.elapsed().as_millis() as u64;
1261 debug!(
1262 target = "duroxide::providers::postgres",
1263 operation = "ack_orchestration_item",
1264 execution_id = execution_id,
1265 history_count = history_delta.len(),
1266 worker_items_count = worker_items.len(),
1267 orchestrator_items_count = orchestrator_items.len(),
1268 cancelled_activities_count = cancelled_activities.len(),
1269 duration_ms = duration_ms,
1270 attempts = attempt + 1,
1271 "Acknowledged orchestration item via stored procedure"
1272 );
1273 return Ok(());
1274 }
1275 Err(e) => {
1276 if let SqlxError::Database(db_err) = &e {
1278 if db_err.message().contains("Invalid lock token") {
1279 return Err(ProviderError::permanent(
1280 "ack_orchestration_item",
1281 "Invalid lock token",
1282 ));
1283 }
1284 } else if e.to_string().contains("Invalid lock token") {
1285 return Err(ProviderError::permanent(
1286 "ack_orchestration_item",
1287 "Invalid lock token",
1288 ));
1289 }
1290
1291 let provider_err = self.sqlx_to_provider_error("ack_orchestration_item", e);
1292 if provider_err.is_retryable() && attempt < MAX_RETRIES {
1293 warn!(
1294 target = "duroxide::providers::postgres",
1295 operation = "ack_orchestration_item",
1296 attempt = attempt + 1,
1297 error = %provider_err,
1298 "Retryable error, will retry"
1299 );
1300 sleep(std::time::Duration::from_millis(
1301 RETRY_DELAY_MS * (attempt as u64 + 1),
1302 ))
1303 .await;
1304 continue;
1305 }
1306 return Err(provider_err);
1307 }
1308 }
1309 }
1310
1311 Ok(())
1313 }
1314 #[instrument(skip(self), fields(lock_token = %lock_token), target = "duroxide::providers::postgres")]
1315 async fn abandon_orchestration_item(
1316 &self,
1317 lock_token: &str,
1318 delay: Option<Duration>,
1319 ignore_attempt: bool,
1320 ) -> Result<(), ProviderError> {
1321 let start = std::time::Instant::now();
1322 let now_ms = Self::now_millis();
1323 let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
1324
1325 let instance_id = match sqlx::query_scalar::<_, String>(&format!(
1326 "SELECT {}.abandon_orchestration_item($1, $2, $3, $4)",
1327 self.schema_name
1328 ))
1329 .bind(lock_token)
1330 .bind(now_ms)
1331 .bind(delay_param)
1332 .bind(ignore_attempt)
1333 .fetch_one(&*self.pool)
1334 .await
1335 {
1336 Ok(instance_id) => instance_id,
1337 Err(e) => {
1338 if let SqlxError::Database(db_err) = &e {
1339 if db_err.message().contains("Invalid lock token") {
1340 return Err(ProviderError::permanent(
1341 "abandon_orchestration_item",
1342 "Invalid lock token",
1343 ));
1344 }
1345 } else if e.to_string().contains("Invalid lock token") {
1346 return Err(ProviderError::permanent(
1347 "abandon_orchestration_item",
1348 "Invalid lock token",
1349 ));
1350 }
1351
1352 return Err(self.sqlx_to_provider_error("abandon_orchestration_item", e));
1353 }
1354 };
1355
1356 let duration_ms = start.elapsed().as_millis() as u64;
1357 debug!(
1358 target = "duroxide::providers::postgres",
1359 operation = "abandon_orchestration_item",
1360 instance_id = %instance_id,
1361 delay_ms = delay.map(|d| d.as_millis() as u64),
1362 ignore_attempt = ignore_attempt,
1363 duration_ms = duration_ms,
1364 "Abandoned orchestration item via stored procedure"
1365 );
1366
1367 Ok(())
1368 }
1369
1370 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
1371 async fn read(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
1372 let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
1373 "SELECT out_event_data FROM {}.fetch_history($1)",
1374 self.schema_name
1375 ))
1376 .bind(instance)
1377 .fetch_all(&*self.pool)
1378 .await
1379 .map_err(|e| self.sqlx_to_provider_error("read", e))?;
1380
1381 event_data_rows
1382 .into_iter()
1383 .map(|event_data| {
1384 serde_json::from_str::<Event>(&event_data).map_err(|e| {
1385 ProviderError::permanent("read", format!("Failed to deserialize event: {e}"))
1386 })
1387 })
1388 .collect()
1389 }
1390
1391 #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
1392 async fn append_with_execution(
1393 &self,
1394 instance: &str,
1395 execution_id: u64,
1396 new_events: Vec<Event>,
1397 ) -> Result<(), ProviderError> {
1398 if new_events.is_empty() {
1399 return Ok(());
1400 }
1401
1402 let mut events_payload = Vec::with_capacity(new_events.len());
1403 for event in &new_events {
1404 if event.event_id() == 0 {
1405 error!(
1406 target = "duroxide::providers::postgres",
1407 operation = "append_with_execution",
1408 error_type = "validation_error",
1409 instance_id = %instance,
1410 execution_id = execution_id,
1411 "event_id must be set by runtime"
1412 );
1413 return Err(ProviderError::permanent(
1414 "append_with_execution",
1415 "event_id must be set by runtime",
1416 ));
1417 }
1418
1419 let event_json = serde_json::to_string(event).map_err(|e| {
1420 ProviderError::permanent(
1421 "append_with_execution",
1422 format!("Failed to serialize event: {e}"),
1423 )
1424 })?;
1425
1426 let event_type = format!("{event:?}")
1427 .split('{')
1428 .next()
1429 .unwrap_or("Unknown")
1430 .trim()
1431 .to_string();
1432
1433 events_payload.push(serde_json::json!({
1434 "event_id": event.event_id(),
1435 "event_type": event_type,
1436 "event_data": event_json,
1437 }));
1438 }
1439
1440 let events_json = serde_json::Value::Array(events_payload);
1441
1442 sqlx::query(&format!(
1443 "SELECT {}.append_history($1, $2, $3)",
1444 self.schema_name
1445 ))
1446 .bind(instance)
1447 .bind(execution_id as i64)
1448 .bind(events_json)
1449 .execute(&*self.pool)
1450 .await
1451 .map_err(|e| self.sqlx_to_provider_error("append_with_execution", e))?;
1452
1453 debug!(
1454 target = "duroxide::providers::postgres",
1455 operation = "append_with_execution",
1456 instance_id = %instance,
1457 execution_id = execution_id,
1458 event_count = new_events.len(),
1459 "Appended history events via stored procedure"
1460 );
1461
1462 Ok(())
1463 }
1464
1465 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1466 async fn enqueue_for_worker(&self, item: WorkItem) -> Result<(), ProviderError> {
1467 let work_item = serde_json::to_string(&item).map_err(|e| {
1468 ProviderError::permanent(
1469 "enqueue_worker_work",
1470 format!("Failed to serialize work item: {e}"),
1471 )
1472 })?;
1473
1474 let now_ms = Self::now_millis();
1475
1476 let (instance_id, execution_id, activity_id, session_id, tag) = match &item {
1478 WorkItem::ActivityExecute {
1479 instance,
1480 execution_id,
1481 id,
1482 session_id,
1483 tag,
1484 ..
1485 } => (
1486 Some(instance.clone()),
1487 Some(*execution_id as i64),
1488 Some(*id as i64),
1489 session_id.clone(),
1490 tag.clone(),
1491 ),
1492 _ => (None, None, None, None, None),
1493 };
1494
1495 sqlx::query(&format!(
1496 "SELECT {}.enqueue_worker_work($1, $2, $3, $4, $5, $6, $7)",
1497 self.schema_name
1498 ))
1499 .bind(work_item)
1500 .bind(now_ms)
1501 .bind(&instance_id)
1502 .bind(execution_id)
1503 .bind(activity_id)
1504 .bind(&session_id)
1505 .bind(&tag)
1506 .execute(&*self.pool)
1507 .await
1508 .map_err(|e| {
1509 error!(
1510 target = "duroxide::providers::postgres",
1511 operation = "enqueue_worker_work",
1512 error_type = "database_error",
1513 error = %e,
1514 "Failed to enqueue worker work"
1515 );
1516 self.sqlx_to_provider_error("enqueue_worker_work", e)
1517 })?;
1518
1519 Ok(())
1520 }
1521
1522 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1523 async fn fetch_work_item(
1524 &self,
1525 lock_timeout: Duration,
1526 _poll_timeout: Duration,
1527 session: Option<&SessionFetchConfig>,
1528 tag_filter: &TagFilter,
1529 ) -> Result<Option<(WorkItem, String, u32)>, ProviderError> {
1530 if matches!(tag_filter, TagFilter::None) {
1532 return Ok(None);
1533 }
1534
1535 let start = std::time::Instant::now();
1536
1537 let lock_timeout_ms = lock_timeout.as_millis() as i64;
1539
1540 let (owner_id, session_lock_timeout_ms): (Option<&str>, Option<i64>) = match session {
1542 Some(config) => (
1543 Some(&config.owner_id),
1544 Some(config.lock_timeout.as_millis() as i64),
1545 ),
1546 None => (None, None),
1547 };
1548
1549 let (tag_mode, tag_names) = Self::tag_filter_to_sql(tag_filter);
1551
1552 let row = match sqlx::query_as::<_, (String, String, i32)>(&format!(
1553 "SELECT * FROM {}.fetch_work_item($1, $2, $3, $4, $5, $6)",
1554 self.schema_name
1555 ))
1556 .bind(Self::now_millis())
1557 .bind(lock_timeout_ms)
1558 .bind(owner_id)
1559 .bind(session_lock_timeout_ms)
1560 .bind(&tag_names)
1561 .bind(tag_mode)
1562 .fetch_optional(&*self.pool)
1563 .await
1564 {
1565 Ok(row) => row,
1566 Err(e) => {
1567 return Err(self.sqlx_to_provider_error("fetch_work_item", e));
1568 }
1569 };
1570
1571 let (work_item_json, lock_token, attempt_count) = match row {
1572 Some(row) => row,
1573 None => return Ok(None),
1574 };
1575
1576 let work_item: WorkItem = serde_json::from_str(&work_item_json).map_err(|e| {
1577 ProviderError::permanent(
1578 "fetch_work_item",
1579 format!("Failed to deserialize worker item: {e}"),
1580 )
1581 })?;
1582
1583 let duration_ms = start.elapsed().as_millis() as u64;
1584
1585 let instance_id = match &work_item {
1587 WorkItem::ActivityExecute { instance, .. } => instance.as_str(),
1588 WorkItem::ActivityCompleted { instance, .. } => instance.as_str(),
1589 WorkItem::ActivityFailed { instance, .. } => instance.as_str(),
1590 WorkItem::StartOrchestration { instance, .. } => instance.as_str(),
1591 WorkItem::TimerFired { instance, .. } => instance.as_str(),
1592 WorkItem::ExternalRaised { instance, .. } => instance.as_str(),
1593 WorkItem::CancelInstance { instance, .. } => instance.as_str(),
1594 WorkItem::ContinueAsNew { instance, .. } => instance.as_str(),
1595 WorkItem::SubOrchCompleted {
1596 parent_instance, ..
1597 } => parent_instance.as_str(),
1598 WorkItem::SubOrchFailed {
1599 parent_instance, ..
1600 } => parent_instance.as_str(),
1601 WorkItem::QueueMessage { instance, .. } => instance.as_str(),
1602 };
1603
1604 debug!(
1605 target = "duroxide::providers::postgres",
1606 operation = "fetch_work_item",
1607 instance_id = %instance_id,
1608 attempt_count = attempt_count,
1609 duration_ms = duration_ms,
1610 "Fetched activity work item via stored procedure"
1611 );
1612
1613 Ok(Some((work_item, lock_token, attempt_count as u32)))
1614 }
1615
1616 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1617 async fn ack_work_item(
1618 &self,
1619 token: &str,
1620 completion: Option<WorkItem>,
1621 ) -> Result<(), ProviderError> {
1622 let start = std::time::Instant::now();
1623
1624 let Some(completion) = completion else {
1626 let now_ms = Self::now_millis();
1627 sqlx::query(&format!(
1629 "SELECT {}.ack_worker($1, NULL, NULL, $2)",
1630 self.schema_name
1631 ))
1632 .bind(token)
1633 .bind(now_ms)
1634 .execute(&*self.pool)
1635 .await
1636 .map_err(|e| {
1637 if e.to_string().contains("Worker queue item not found") {
1638 ProviderError::permanent(
1639 "ack_worker",
1640 "Worker queue item not found or already processed",
1641 )
1642 } else {
1643 self.sqlx_to_provider_error("ack_worker", e)
1644 }
1645 })?;
1646
1647 let duration_ms = start.elapsed().as_millis() as u64;
1648 debug!(
1649 target = "duroxide::providers::postgres",
1650 operation = "ack_worker",
1651 token = %token,
1652 duration_ms = duration_ms,
1653 "Acknowledged worker without completion (cancelled)"
1654 );
1655 return Ok(());
1656 };
1657
1658 let instance_id = match &completion {
1660 WorkItem::ActivityCompleted { instance, .. }
1661 | WorkItem::ActivityFailed { instance, .. } => instance,
1662 _ => {
1663 error!(
1664 target = "duroxide::providers::postgres",
1665 operation = "ack_worker",
1666 error_type = "invalid_completion_type",
1667 "Invalid completion work item type"
1668 );
1669 return Err(ProviderError::permanent(
1670 "ack_worker",
1671 "Invalid completion work item type",
1672 ));
1673 }
1674 };
1675
1676 let completion_json = serde_json::to_string(&completion).map_err(|e| {
1677 ProviderError::permanent("ack_worker", format!("Failed to serialize completion: {e}"))
1678 })?;
1679
1680 let now_ms = Self::now_millis();
1681
1682 sqlx::query(&format!(
1684 "SELECT {}.ack_worker($1, $2, $3, $4)",
1685 self.schema_name
1686 ))
1687 .bind(token)
1688 .bind(instance_id)
1689 .bind(completion_json)
1690 .bind(now_ms)
1691 .execute(&*self.pool)
1692 .await
1693 .map_err(|e| {
1694 if e.to_string().contains("Worker queue item not found") {
1695 error!(
1696 target = "duroxide::providers::postgres",
1697 operation = "ack_worker",
1698 error_type = "worker_item_not_found",
1699 token = %token,
1700 "Worker queue item not found or already processed"
1701 );
1702 ProviderError::permanent(
1703 "ack_worker",
1704 "Worker queue item not found or already processed",
1705 )
1706 } else {
1707 self.sqlx_to_provider_error("ack_worker", e)
1708 }
1709 })?;
1710
1711 let duration_ms = start.elapsed().as_millis() as u64;
1712 debug!(
1713 target = "duroxide::providers::postgres",
1714 operation = "ack_worker",
1715 instance_id = %instance_id,
1716 duration_ms = duration_ms,
1717 "Acknowledged worker and enqueued completion"
1718 );
1719
1720 Ok(())
1721 }
1722
1723 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1724 async fn renew_work_item_lock(
1725 &self,
1726 token: &str,
1727 extend_for: Duration,
1728 ) -> Result<(), ProviderError> {
1729 let start = std::time::Instant::now();
1730
1731 let now_ms = Self::now_millis();
1733
1734 let extend_secs = extend_for.as_secs() as i64;
1736
1737 match sqlx::query(&format!(
1738 "SELECT {}.renew_work_item_lock($1, $2, $3)",
1739 self.schema_name
1740 ))
1741 .bind(token)
1742 .bind(now_ms)
1743 .bind(extend_secs)
1744 .execute(&*self.pool)
1745 .await
1746 {
1747 Ok(_) => {
1748 let duration_ms = start.elapsed().as_millis() as u64;
1749 debug!(
1750 target = "duroxide::providers::postgres",
1751 operation = "renew_work_item_lock",
1752 token = %token,
1753 extend_for_secs = extend_secs,
1754 duration_ms = duration_ms,
1755 "Work item lock renewed successfully"
1756 );
1757 Ok(())
1758 }
1759 Err(e) => {
1760 if let SqlxError::Database(db_err) = &e {
1761 if db_err.message().contains("Lock token invalid") {
1762 return Err(ProviderError::permanent(
1763 "renew_work_item_lock",
1764 "Lock token invalid, expired, or already acked",
1765 ));
1766 }
1767 } else if e.to_string().contains("Lock token invalid") {
1768 return Err(ProviderError::permanent(
1769 "renew_work_item_lock",
1770 "Lock token invalid, expired, or already acked",
1771 ));
1772 }
1773
1774 Err(self.sqlx_to_provider_error("renew_work_item_lock", e))
1775 }
1776 }
1777 }
1778
1779 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1780 async fn abandon_work_item(
1781 &self,
1782 token: &str,
1783 delay: Option<Duration>,
1784 ignore_attempt: bool,
1785 ) -> Result<(), ProviderError> {
1786 let start = std::time::Instant::now();
1787 let now_ms = Self::now_millis();
1788 let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
1789
1790 match sqlx::query(&format!(
1791 "SELECT {}.abandon_work_item($1, $2, $3, $4)",
1792 self.schema_name
1793 ))
1794 .bind(token)
1795 .bind(now_ms)
1796 .bind(delay_param)
1797 .bind(ignore_attempt)
1798 .execute(&*self.pool)
1799 .await
1800 {
1801 Ok(_) => {
1802 let duration_ms = start.elapsed().as_millis() as u64;
1803 debug!(
1804 target = "duroxide::providers::postgres",
1805 operation = "abandon_work_item",
1806 token = %token,
1807 delay_ms = delay.map(|d| d.as_millis() as u64),
1808 ignore_attempt = ignore_attempt,
1809 duration_ms = duration_ms,
1810 "Abandoned work item via stored procedure"
1811 );
1812 Ok(())
1813 }
1814 Err(e) => {
1815 if let SqlxError::Database(db_err) = &e {
1816 if db_err.message().contains("Invalid lock token")
1817 || db_err.message().contains("already acked")
1818 {
1819 return Err(ProviderError::permanent(
1820 "abandon_work_item",
1821 "Invalid lock token or already acked",
1822 ));
1823 }
1824 } else if e.to_string().contains("Invalid lock token")
1825 || e.to_string().contains("already acked")
1826 {
1827 return Err(ProviderError::permanent(
1828 "abandon_work_item",
1829 "Invalid lock token or already acked",
1830 ));
1831 }
1832
1833 Err(self.sqlx_to_provider_error("abandon_work_item", e))
1834 }
1835 }
1836 }
1837
1838 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1839 async fn renew_orchestration_item_lock(
1840 &self,
1841 token: &str,
1842 extend_for: Duration,
1843 ) -> Result<(), ProviderError> {
1844 let start = std::time::Instant::now();
1845
1846 let now_ms = Self::now_millis();
1848
1849 let extend_secs = extend_for.as_secs() as i64;
1851
1852 match sqlx::query(&format!(
1853 "SELECT {}.renew_orchestration_item_lock($1, $2, $3)",
1854 self.schema_name
1855 ))
1856 .bind(token)
1857 .bind(now_ms)
1858 .bind(extend_secs)
1859 .execute(&*self.pool)
1860 .await
1861 {
1862 Ok(_) => {
1863 let duration_ms = start.elapsed().as_millis() as u64;
1864 debug!(
1865 target = "duroxide::providers::postgres",
1866 operation = "renew_orchestration_item_lock",
1867 token = %token,
1868 extend_for_secs = extend_secs,
1869 duration_ms = duration_ms,
1870 "Orchestration item lock renewed successfully"
1871 );
1872 Ok(())
1873 }
1874 Err(e) => {
1875 if let SqlxError::Database(db_err) = &e {
1876 if db_err.message().contains("Lock token invalid")
1877 || db_err.message().contains("expired")
1878 || db_err.message().contains("already released")
1879 {
1880 return Err(ProviderError::permanent(
1881 "renew_orchestration_item_lock",
1882 "Lock token invalid, expired, or already released",
1883 ));
1884 }
1885 } else if e.to_string().contains("Lock token invalid")
1886 || e.to_string().contains("expired")
1887 || e.to_string().contains("already released")
1888 {
1889 return Err(ProviderError::permanent(
1890 "renew_orchestration_item_lock",
1891 "Lock token invalid, expired, or already released",
1892 ));
1893 }
1894
1895 Err(self.sqlx_to_provider_error("renew_orchestration_item_lock", e))
1896 }
1897 }
1898 }
1899
1900 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1901 async fn enqueue_for_orchestrator(
1902 &self,
1903 item: WorkItem,
1904 delay: Option<Duration>,
1905 ) -> Result<(), ProviderError> {
1906 let work_item = serde_json::to_string(&item).map_err(|e| {
1907 ProviderError::permanent(
1908 "enqueue_orchestrator_work",
1909 format!("Failed to serialize work item: {e}"),
1910 )
1911 })?;
1912
1913 let instance_id = match &item {
1915 WorkItem::StartOrchestration { instance, .. }
1916 | WorkItem::ActivityCompleted { instance, .. }
1917 | WorkItem::ActivityFailed { instance, .. }
1918 | WorkItem::TimerFired { instance, .. }
1919 | WorkItem::ExternalRaised { instance, .. }
1920 | WorkItem::CancelInstance { instance, .. }
1921 | WorkItem::ContinueAsNew { instance, .. }
1922 | WorkItem::QueueMessage { instance, .. } => instance,
1923 WorkItem::SubOrchCompleted {
1924 parent_instance, ..
1925 }
1926 | WorkItem::SubOrchFailed {
1927 parent_instance, ..
1928 } => parent_instance,
1929 WorkItem::ActivityExecute { .. } => {
1930 return Err(ProviderError::permanent(
1931 "enqueue_orchestrator_work",
1932 "ActivityExecute should go to worker queue, not orchestrator queue",
1933 ));
1934 }
1935 };
1936
1937 let now_ms = Self::now_millis();
1939
1940 let visible_at_ms = if let WorkItem::TimerFired { fire_at_ms, .. } = &item {
1941 if *fire_at_ms > 0 {
1942 if let Some(delay) = delay {
1944 std::cmp::max(*fire_at_ms, now_ms as u64 + delay.as_millis() as u64)
1945 } else {
1946 *fire_at_ms
1947 }
1948 } else {
1949 delay
1951 .map(|d| now_ms as u64 + d.as_millis() as u64)
1952 .unwrap_or(now_ms as u64)
1953 }
1954 } else {
1955 delay
1957 .map(|d| now_ms as u64 + d.as_millis() as u64)
1958 .unwrap_or(now_ms as u64)
1959 };
1960
1961 let visible_at = Utc
1962 .timestamp_millis_opt(visible_at_ms as i64)
1963 .single()
1964 .ok_or_else(|| {
1965 ProviderError::permanent(
1966 "enqueue_orchestrator_work",
1967 "Invalid visible_at timestamp",
1968 )
1969 })?;
1970
1971 sqlx::query(&format!(
1976 "SELECT {}.enqueue_orchestrator_work($1, $2, $3, $4, $5, $6)",
1977 self.schema_name
1978 ))
1979 .bind(instance_id)
1980 .bind(&work_item)
1981 .bind(visible_at)
1982 .bind::<Option<String>>(None) .bind::<Option<String>>(None) .bind::<Option<i64>>(None) .execute(&*self.pool)
1986 .await
1987 .map_err(|e| {
1988 error!(
1989 target = "duroxide::providers::postgres",
1990 operation = "enqueue_orchestrator_work",
1991 error_type = "database_error",
1992 error = %e,
1993 instance_id = %instance_id,
1994 "Failed to enqueue orchestrator work"
1995 );
1996 self.sqlx_to_provider_error("enqueue_orchestrator_work", e)
1997 })?;
1998
1999 debug!(
2000 target = "duroxide::providers::postgres",
2001 operation = "enqueue_orchestrator_work",
2002 instance_id = %instance_id,
2003 delay_ms = delay.map(|d| d.as_millis() as u64),
2004 "Enqueued orchestrator work"
2005 );
2006
2007 Ok(())
2008 }
2009
2010 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2011 async fn read_with_execution(
2012 &self,
2013 instance: &str,
2014 execution_id: u64,
2015 ) -> Result<Vec<Event>, ProviderError> {
2016 let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
2017 "SELECT event_data FROM {} WHERE instance_id = $1 AND execution_id = $2 ORDER BY event_id",
2018 self.table_name("history")
2019 ))
2020 .bind(instance)
2021 .bind(execution_id as i64)
2022 .fetch_all(&*self.pool)
2023 .await
2024 .map_err(|e| self.sqlx_to_provider_error("read_with_execution", e))?;
2025
2026 event_data_rows
2027 .into_iter()
2028 .map(|event_data| {
2029 serde_json::from_str::<Event>(&event_data).map_err(|e| {
2030 ProviderError::permanent(
2031 "read_with_execution",
2032 format!("Failed to deserialize event: {e}"),
2033 )
2034 })
2035 })
2036 .collect()
2037 }
2038
2039 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2040 async fn renew_session_lock(
2041 &self,
2042 owner_ids: &[&str],
2043 extend_for: Duration,
2044 idle_timeout: Duration,
2045 ) -> Result<usize, ProviderError> {
2046 if owner_ids.is_empty() {
2047 return Ok(0);
2048 }
2049
2050 let now_ms = Self::now_millis();
2051 let extend_ms = extend_for.as_millis() as i64;
2052 let idle_timeout_ms = idle_timeout.as_millis() as i64;
2053 let owner_ids_vec: Vec<&str> = owner_ids.to_vec();
2054
2055 let result = sqlx::query_scalar::<_, i64>(&format!(
2056 "SELECT {}.renew_session_lock($1, $2, $3, $4)",
2057 self.schema_name
2058 ))
2059 .bind(&owner_ids_vec)
2060 .bind(now_ms)
2061 .bind(extend_ms)
2062 .bind(idle_timeout_ms)
2063 .fetch_one(&*self.pool)
2064 .await
2065 .map_err(|e| self.sqlx_to_provider_error("renew_session_lock", e))?;
2066
2067 debug!(
2068 target = "duroxide::providers::postgres",
2069 operation = "renew_session_lock",
2070 owner_count = owner_ids.len(),
2071 sessions_renewed = result,
2072 "Session locks renewed"
2073 );
2074
2075 Ok(result as usize)
2076 }
2077
2078 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2079 async fn cleanup_orphaned_sessions(
2080 &self,
2081 _idle_timeout: Duration,
2082 ) -> Result<usize, ProviderError> {
2083 let now_ms = Self::now_millis();
2084
2085 let result = sqlx::query_scalar::<_, i64>(&format!(
2086 "SELECT {}.cleanup_orphaned_sessions($1)",
2087 self.schema_name
2088 ))
2089 .bind(now_ms)
2090 .fetch_one(&*self.pool)
2091 .await
2092 .map_err(|e| self.sqlx_to_provider_error("cleanup_orphaned_sessions", e))?;
2093
2094 debug!(
2095 target = "duroxide::providers::postgres",
2096 operation = "cleanup_orphaned_sessions",
2097 sessions_cleaned = result,
2098 "Orphaned sessions cleaned up"
2099 );
2100
2101 Ok(result as usize)
2102 }
2103
2104 fn as_management_capability(&self) -> Option<&dyn ProviderAdmin> {
2105 Some(self)
2106 }
2107
2108 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2109 async fn get_custom_status(
2110 &self,
2111 instance: &str,
2112 last_seen_version: u64,
2113 ) -> Result<Option<(Option<String>, u64)>, ProviderError> {
2114 let row = sqlx::query_as::<_, (Option<String>, i64)>(&format!(
2115 "SELECT * FROM {}.get_custom_status($1, $2)",
2116 self.schema_name
2117 ))
2118 .bind(instance)
2119 .bind(last_seen_version as i64)
2120 .fetch_optional(&*self.pool)
2121 .await
2122 .map_err(|e| self.sqlx_to_provider_error("get_custom_status", e))?;
2123
2124 match row {
2125 Some((custom_status, version)) => Ok(Some((custom_status, version as u64))),
2126 None => Ok(None),
2127 }
2128 }
2129
2130 async fn get_kv_value(
2131 &self,
2132 instance_id: &str,
2133 key: &str,
2134 ) -> Result<Option<String>, ProviderError> {
2135 let row: Option<(Option<String>, bool)> = sqlx::query_as(&format!(
2136 "SELECT * FROM {}.get_kv_value($1, $2)",
2137 self.schema_name
2138 ))
2139 .bind(instance_id)
2140 .bind(key)
2141 .fetch_optional(&*self.pool)
2142 .await
2143 .map_err(|e| self.sqlx_to_provider_error("get_kv_value", e))?;
2144
2145 Ok(row.and_then(|(value, found)| if found { value } else { None }))
2146 }
2147
2148 async fn get_kv_all_values(
2149 &self,
2150 instance_id: &str,
2151 ) -> Result<std::collections::HashMap<String, String>, ProviderError> {
2152 let rows: Vec<(String, String)> = sqlx::query_as(&format!(
2153 "SELECT * FROM {}.get_kv_all_values($1)",
2154 self.schema_name
2155 ))
2156 .bind(instance_id)
2157 .fetch_all(&*self.pool)
2158 .await
2159 .map_err(|e| self.sqlx_to_provider_error("get_kv_all_values", e))?;
2160
2161 Ok(rows.into_iter().collect())
2162 }
2163
2164 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2165 async fn get_instance_stats(
2166 &self,
2167 instance: &str,
2168 ) -> Result<Option<SystemStats>, ProviderError> {
2169 let row: Option<(bool, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!(
2170 "SELECT * FROM {}.get_instance_stats($1)",
2171 self.schema_name
2172 ))
2173 .bind(instance)
2174 .fetch_optional(&*self.pool)
2175 .await
2176 .map_err(|e| self.sqlx_to_provider_error("get_instance_stats", e))?;
2177
2178 match row {
2179 Some((
2180 true,
2181 history_event_count,
2182 history_size_bytes,
2183 queue_pending_count,
2184 kv_user_key_count,
2185 kv_total_value_bytes,
2186 )) => Ok(Some(SystemStats {
2187 history_event_count: history_event_count as u64,
2188 history_size_bytes: history_size_bytes as u64,
2189 queue_pending_count: queue_pending_count as u64,
2190 kv_user_key_count: kv_user_key_count as u64,
2191 kv_total_value_bytes: kv_total_value_bytes as u64,
2192 })),
2193 _ => Ok(None),
2194 }
2195 }
2196}
2197
2198#[async_trait::async_trait]
2199impl ProviderAdmin for PostgresProvider {
2200 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2201 async fn list_instances(&self) -> Result<Vec<String>, ProviderError> {
2202 sqlx::query_scalar(&format!(
2203 "SELECT instance_id FROM {}.list_instances()",
2204 self.schema_name
2205 ))
2206 .fetch_all(&*self.pool)
2207 .await
2208 .map_err(|e| self.sqlx_to_provider_error("list_instances", e))
2209 }
2210
2211 #[instrument(skip(self), fields(status = %status), target = "duroxide::providers::postgres")]
2212 async fn list_instances_by_status(&self, status: &str) -> Result<Vec<String>, ProviderError> {
2213 sqlx::query_scalar(&format!(
2214 "SELECT instance_id FROM {}.list_instances_by_status($1)",
2215 self.schema_name
2216 ))
2217 .bind(status)
2218 .fetch_all(&*self.pool)
2219 .await
2220 .map_err(|e| self.sqlx_to_provider_error("list_instances_by_status", e))
2221 }
2222
2223 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2224 async fn list_executions(&self, instance: &str) -> Result<Vec<u64>, ProviderError> {
2225 let execution_ids: Vec<i64> = sqlx::query_scalar(&format!(
2226 "SELECT execution_id FROM {}.list_executions($1)",
2227 self.schema_name
2228 ))
2229 .bind(instance)
2230 .fetch_all(&*self.pool)
2231 .await
2232 .map_err(|e| self.sqlx_to_provider_error("list_executions", e))?;
2233
2234 Ok(execution_ids.into_iter().map(|id| id as u64).collect())
2235 }
2236
2237 #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
2238 async fn read_history_with_execution_id(
2239 &self,
2240 instance: &str,
2241 execution_id: u64,
2242 ) -> Result<Vec<Event>, ProviderError> {
2243 let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
2244 "SELECT out_event_data FROM {}.fetch_history_with_execution($1, $2)",
2245 self.schema_name
2246 ))
2247 .bind(instance)
2248 .bind(execution_id as i64)
2249 .fetch_all(&*self.pool)
2250 .await
2251 .map_err(|e| self.sqlx_to_provider_error("read_execution", e))?;
2252
2253 event_data_rows
2254 .into_iter()
2255 .map(|event_data| {
2256 serde_json::from_str::<Event>(&event_data).map_err(|e| {
2257 ProviderError::permanent(
2258 "read_history_with_execution_id",
2259 format!("Failed to deserialize event: {e}"),
2260 )
2261 })
2262 })
2263 .collect()
2264 }
2265
2266 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2267 async fn read_history(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
2268 let execution_id = self.latest_execution_id(instance).await?;
2269 self.read_history_with_execution_id(instance, execution_id)
2270 .await
2271 }
2272
2273 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2274 async fn latest_execution_id(&self, instance: &str) -> Result<u64, ProviderError> {
2275 sqlx::query_scalar(&format!(
2276 "SELECT {}.latest_execution_id($1)",
2277 self.schema_name
2278 ))
2279 .bind(instance)
2280 .fetch_optional(&*self.pool)
2281 .await
2282 .map_err(|e| self.sqlx_to_provider_error("latest_execution_id", e))?
2283 .map(|id: i64| id as u64)
2284 .ok_or_else(|| ProviderError::permanent("latest_execution_id", "Instance not found"))
2285 }
2286
2287 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2288 async fn get_instance_info(&self, instance: &str) -> Result<InstanceInfo, ProviderError> {
2289 let row: Option<(
2290 String,
2291 String,
2292 String,
2293 i64,
2294 chrono::DateTime<Utc>,
2295 Option<chrono::DateTime<Utc>>,
2296 Option<String>,
2297 Option<String>,
2298 Option<String>,
2299 )> = sqlx::query_as(&format!(
2300 "SELECT * FROM {}.get_instance_info($1)",
2301 self.schema_name
2302 ))
2303 .bind(instance)
2304 .fetch_optional(&*self.pool)
2305 .await
2306 .map_err(|e| self.sqlx_to_provider_error("get_instance_info", e))?;
2307
2308 let (
2309 instance_id,
2310 orchestration_name,
2311 orchestration_version,
2312 current_execution_id,
2313 created_at,
2314 updated_at,
2315 status,
2316 output,
2317 parent_instance_id,
2318 ) =
2319 row.ok_or_else(|| ProviderError::permanent("get_instance_info", "Instance not found"))?;
2320
2321 Ok(InstanceInfo {
2322 instance_id,
2323 orchestration_name,
2324 orchestration_version,
2325 current_execution_id: current_execution_id as u64,
2326 status: status.unwrap_or_else(|| "Running".to_string()),
2327 output,
2328 created_at: created_at.timestamp_millis() as u64,
2329 updated_at: updated_at
2330 .map(|dt| dt.timestamp_millis() as u64)
2331 .unwrap_or(created_at.timestamp_millis() as u64),
2332 parent_instance_id,
2333 })
2334 }
2335
2336 #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
2337 async fn get_execution_info(
2338 &self,
2339 instance: &str,
2340 execution_id: u64,
2341 ) -> Result<ExecutionInfo, ProviderError> {
2342 let row: Option<(
2343 i64,
2344 String,
2345 Option<String>,
2346 chrono::DateTime<Utc>,
2347 Option<chrono::DateTime<Utc>>,
2348 i64,
2349 )> = sqlx::query_as(&format!(
2350 "SELECT * FROM {}.get_execution_info($1, $2)",
2351 self.schema_name
2352 ))
2353 .bind(instance)
2354 .bind(execution_id as i64)
2355 .fetch_optional(&*self.pool)
2356 .await
2357 .map_err(|e| self.sqlx_to_provider_error("get_execution_info", e))?;
2358
2359 let (exec_id, status, output, started_at, completed_at, event_count) = row
2360 .ok_or_else(|| ProviderError::permanent("get_execution_info", "Execution not found"))?;
2361
2362 Ok(ExecutionInfo {
2363 execution_id: exec_id as u64,
2364 status,
2365 output,
2366 started_at: started_at.timestamp_millis() as u64,
2367 completed_at: completed_at.map(|dt| dt.timestamp_millis() as u64),
2368 event_count: event_count as usize,
2369 })
2370 }
2371
2372 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2373 async fn get_system_metrics(&self) -> Result<SystemMetrics, ProviderError> {
2374 let row: Option<(i64, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!(
2375 "SELECT * FROM {}.get_system_metrics()",
2376 self.schema_name
2377 ))
2378 .fetch_optional(&*self.pool)
2379 .await
2380 .map_err(|e| self.sqlx_to_provider_error("get_system_metrics", e))?;
2381
2382 let (
2383 total_instances,
2384 total_executions,
2385 running_instances,
2386 completed_instances,
2387 failed_instances,
2388 total_events,
2389 ) = row.ok_or_else(|| {
2390 ProviderError::permanent("get_system_metrics", "Failed to get system metrics")
2391 })?;
2392
2393 Ok(SystemMetrics {
2394 total_instances: total_instances as u64,
2395 total_executions: total_executions as u64,
2396 running_instances: running_instances as u64,
2397 completed_instances: completed_instances as u64,
2398 failed_instances: failed_instances as u64,
2399 total_events: total_events as u64,
2400 })
2401 }
2402
2403 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2404 async fn get_queue_depths(&self) -> Result<QueueDepths, ProviderError> {
2405 let now_ms = Self::now_millis();
2406
2407 let row: Option<(i64, i64)> = sqlx::query_as(&format!(
2408 "SELECT * FROM {}.get_queue_depths($1)",
2409 self.schema_name
2410 ))
2411 .bind(now_ms)
2412 .fetch_optional(&*self.pool)
2413 .await
2414 .map_err(|e| self.sqlx_to_provider_error("get_queue_depths", e))?;
2415
2416 let (orchestrator_queue, worker_queue) = row.ok_or_else(|| {
2417 ProviderError::permanent("get_queue_depths", "Failed to get queue depths")
2418 })?;
2419
2420 Ok(QueueDepths {
2421 orchestrator_queue: orchestrator_queue as usize,
2422 worker_queue: worker_queue as usize,
2423 timer_queue: 0, })
2425 }
2426
2427 #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
2430 async fn list_children(&self, instance_id: &str) -> Result<Vec<String>, ProviderError> {
2431 sqlx::query_scalar(&format!(
2432 "SELECT child_instance_id FROM {}.list_children($1)",
2433 self.schema_name
2434 ))
2435 .bind(instance_id)
2436 .fetch_all(&*self.pool)
2437 .await
2438 .map_err(|e| self.sqlx_to_provider_error("list_children", e))
2439 }
2440
2441 #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
2442 async fn get_parent_id(&self, instance_id: &str) -> Result<Option<String>, ProviderError> {
2443 let result: Result<Option<String>, _> =
2446 sqlx::query_scalar(&format!("SELECT {}.get_parent_id($1)", self.schema_name))
2447 .bind(instance_id)
2448 .fetch_one(&*self.pool)
2449 .await;
2450
2451 match result {
2452 Ok(parent_id) => Ok(parent_id),
2453 Err(e) => {
2454 let err_str = e.to_string();
2455 if err_str.contains("Instance not found") {
2456 Err(ProviderError::permanent(
2457 "get_parent_id",
2458 format!("Instance not found: {}", instance_id),
2459 ))
2460 } else {
2461 Err(self.sqlx_to_provider_error("get_parent_id", e))
2462 }
2463 }
2464 }
2465 }
2466
2467 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2470 async fn delete_instances_atomic(
2471 &self,
2472 ids: &[String],
2473 force: bool,
2474 ) -> Result<DeleteInstanceResult, ProviderError> {
2475 if ids.is_empty() {
2476 return Ok(DeleteInstanceResult::default());
2477 }
2478
2479 let row: Option<(i64, i64, i64, i64)> = sqlx::query_as(&format!(
2480 "SELECT * FROM {}.delete_instances_atomic($1, $2)",
2481 self.schema_name
2482 ))
2483 .bind(ids)
2484 .bind(force)
2485 .fetch_optional(&*self.pool)
2486 .await
2487 .map_err(|e| {
2488 let err_str = e.to_string();
2489 if err_str.contains("is Running") {
2490 ProviderError::permanent("delete_instances_atomic", err_str)
2491 } else if err_str.contains("Orphan detected") {
2492 ProviderError::permanent("delete_instances_atomic", err_str)
2493 } else {
2494 self.sqlx_to_provider_error("delete_instances_atomic", e)
2495 }
2496 })?;
2497
2498 let (instances_deleted, executions_deleted, events_deleted, queue_messages_deleted) =
2499 row.unwrap_or((0, 0, 0, 0));
2500
2501 debug!(
2502 target = "duroxide::providers::postgres",
2503 operation = "delete_instances_atomic",
2504 instances_deleted = instances_deleted,
2505 executions_deleted = executions_deleted,
2506 events_deleted = events_deleted,
2507 queue_messages_deleted = queue_messages_deleted,
2508 "Deleted instances atomically"
2509 );
2510
2511 Ok(DeleteInstanceResult {
2512 instances_deleted: instances_deleted as u64,
2513 executions_deleted: executions_deleted as u64,
2514 events_deleted: events_deleted as u64,
2515 queue_messages_deleted: queue_messages_deleted as u64,
2516 })
2517 }
2518
2519 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2520 async fn delete_instance_bulk(
2521 &self,
2522 filter: InstanceFilter,
2523 ) -> Result<DeleteInstanceResult, ProviderError> {
2524 let mut sql = format!(
2526 r#"
2527 SELECT i.instance_id
2528 FROM {}.instances i
2529 LEFT JOIN {}.executions e ON i.instance_id = e.instance_id
2530 AND i.current_execution_id = e.execution_id
2531 WHERE i.parent_instance_id IS NULL
2532 AND e.status IN ('Completed', 'Failed', 'ContinuedAsNew')
2533 "#,
2534 self.schema_name, self.schema_name
2535 );
2536
2537 if let Some(ref ids) = filter.instance_ids {
2539 if ids.is_empty() {
2540 return Ok(DeleteInstanceResult::default());
2541 }
2542 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
2543 sql.push_str(&format!(
2544 " AND i.instance_id IN ({})",
2545 placeholders.join(", ")
2546 ));
2547 }
2548
2549 if filter.completed_before.is_some() {
2551 let param_num = filter
2552 .instance_ids
2553 .as_ref()
2554 .map(|ids| ids.len())
2555 .unwrap_or(0)
2556 + 1;
2557 sql.push_str(&format!(
2558 " AND e.completed_at < TO_TIMESTAMP(${} / 1000.0)",
2559 param_num
2560 ));
2561 }
2562
2563 let limit = filter.limit.unwrap_or(1000);
2565 let limit_param_num = filter
2566 .instance_ids
2567 .as_ref()
2568 .map(|ids| ids.len())
2569 .unwrap_or(0)
2570 + if filter.completed_before.is_some() {
2571 1
2572 } else {
2573 0
2574 }
2575 + 1;
2576 sql.push_str(&format!(" LIMIT ${}", limit_param_num));
2577
2578 let mut query = sqlx::query_scalar::<_, String>(&sql);
2580 if let Some(ref ids) = filter.instance_ids {
2581 for id in ids {
2582 query = query.bind(id);
2583 }
2584 }
2585 if let Some(completed_before) = filter.completed_before {
2586 query = query.bind(completed_before as i64);
2587 }
2588 query = query.bind(limit as i64);
2589
2590 let instance_ids: Vec<String> = query
2591 .fetch_all(&*self.pool)
2592 .await
2593 .map_err(|e| self.sqlx_to_provider_error("delete_instance_bulk", e))?;
2594
2595 if instance_ids.is_empty() {
2596 return Ok(DeleteInstanceResult::default());
2597 }
2598
2599 let mut result = DeleteInstanceResult::default();
2601
2602 for instance_id in &instance_ids {
2603 let tree = self.get_instance_tree(instance_id).await?;
2605
2606 let delete_result = self.delete_instances_atomic(&tree.all_ids, true).await?;
2608 result.instances_deleted += delete_result.instances_deleted;
2609 result.executions_deleted += delete_result.executions_deleted;
2610 result.events_deleted += delete_result.events_deleted;
2611 result.queue_messages_deleted += delete_result.queue_messages_deleted;
2612 }
2613
2614 debug!(
2615 target = "duroxide::providers::postgres",
2616 operation = "delete_instance_bulk",
2617 instances_deleted = result.instances_deleted,
2618 executions_deleted = result.executions_deleted,
2619 events_deleted = result.events_deleted,
2620 queue_messages_deleted = result.queue_messages_deleted,
2621 "Bulk deleted instances"
2622 );
2623
2624 Ok(result)
2625 }
2626
2627 #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
2630 async fn prune_executions(
2631 &self,
2632 instance_id: &str,
2633 options: PruneOptions,
2634 ) -> Result<PruneResult, ProviderError> {
2635 let keep_last: Option<i32> = options.keep_last.map(|v| v as i32);
2636 let completed_before_ms: Option<i64> = options.completed_before.map(|v| v as i64);
2637
2638 let row: Option<(i64, i64, i64)> = sqlx::query_as(&format!(
2639 "SELECT * FROM {}.prune_executions($1, $2, $3)",
2640 self.schema_name
2641 ))
2642 .bind(instance_id)
2643 .bind(keep_last)
2644 .bind(completed_before_ms)
2645 .fetch_optional(&*self.pool)
2646 .await
2647 .map_err(|e| self.sqlx_to_provider_error("prune_executions", e))?;
2648
2649 let (instances_processed, executions_deleted, events_deleted) = row.unwrap_or((0, 0, 0));
2650
2651 debug!(
2652 target = "duroxide::providers::postgres",
2653 operation = "prune_executions",
2654 instance_id = %instance_id,
2655 instances_processed = instances_processed,
2656 executions_deleted = executions_deleted,
2657 events_deleted = events_deleted,
2658 "Pruned executions"
2659 );
2660
2661 Ok(PruneResult {
2662 instances_processed: instances_processed as u64,
2663 executions_deleted: executions_deleted as u64,
2664 events_deleted: events_deleted as u64,
2665 })
2666 }
2667
2668 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2669 async fn prune_executions_bulk(
2670 &self,
2671 filter: InstanceFilter,
2672 options: PruneOptions,
2673 ) -> Result<PruneResult, ProviderError> {
2674 let mut sql = format!(
2679 r#"
2680 SELECT i.instance_id
2681 FROM {}.instances i
2682 LEFT JOIN {}.executions e ON i.instance_id = e.instance_id
2683 AND i.current_execution_id = e.execution_id
2684 WHERE 1=1
2685 "#,
2686 self.schema_name, self.schema_name
2687 );
2688
2689 if let Some(ref ids) = filter.instance_ids {
2691 if ids.is_empty() {
2692 return Ok(PruneResult::default());
2693 }
2694 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
2695 sql.push_str(&format!(
2696 " AND i.instance_id IN ({})",
2697 placeholders.join(", ")
2698 ));
2699 }
2700
2701 if filter.completed_before.is_some() {
2703 let param_num = filter
2704 .instance_ids
2705 .as_ref()
2706 .map(|ids| ids.len())
2707 .unwrap_or(0)
2708 + 1;
2709 sql.push_str(&format!(
2710 " AND e.completed_at < TO_TIMESTAMP(${} / 1000.0)",
2711 param_num
2712 ));
2713 }
2714
2715 let limit = filter.limit.unwrap_or(1000);
2717 let limit_param_num = filter
2718 .instance_ids
2719 .as_ref()
2720 .map(|ids| ids.len())
2721 .unwrap_or(0)
2722 + if filter.completed_before.is_some() {
2723 1
2724 } else {
2725 0
2726 }
2727 + 1;
2728 sql.push_str(&format!(" LIMIT ${}", limit_param_num));
2729
2730 let mut query = sqlx::query_scalar::<_, String>(&sql);
2732 if let Some(ref ids) = filter.instance_ids {
2733 for id in ids {
2734 query = query.bind(id);
2735 }
2736 }
2737 if let Some(completed_before) = filter.completed_before {
2738 query = query.bind(completed_before as i64);
2739 }
2740 query = query.bind(limit as i64);
2741
2742 let instance_ids: Vec<String> = query
2743 .fetch_all(&*self.pool)
2744 .await
2745 .map_err(|e| self.sqlx_to_provider_error("prune_executions_bulk", e))?;
2746
2747 let mut result = PruneResult::default();
2749
2750 for instance_id in &instance_ids {
2751 let single_result = self.prune_executions(instance_id, options.clone()).await?;
2752 result.instances_processed += single_result.instances_processed;
2753 result.executions_deleted += single_result.executions_deleted;
2754 result.events_deleted += single_result.events_deleted;
2755 }
2756
2757 debug!(
2758 target = "duroxide::providers::postgres",
2759 operation = "prune_executions_bulk",
2760 instances_processed = result.instances_processed,
2761 executions_deleted = result.executions_deleted,
2762 events_deleted = result.events_deleted,
2763 "Bulk pruned executions"
2764 );
2765
2766 Ok(result)
2767 }
2768}
2769
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entra::test_support::{token, RecordingFakeTokenSource};

    /// Suffix the tests expect on a panic message that was shortened.
    const TRUNCATION_MARKER: &str = "…[truncated]";

    /// The connect-options builder must carry through host, port, database,
    /// user and the requested TLS mode.
    #[test]
    fn build_entra_connect_options_uses_verify_full() {
        let options =
            build_entra_connect_options("h.example.com", 5432, "db", "u", PgSslMode::VerifyFull);

        assert!(matches!(options.get_ssl_mode(), PgSslMode::VerifyFull));
        assert_eq!(options.get_host(), "h.example.com");
        assert_eq!(options.get_port(), 5432);
        assert_eq!(options.get_database(), Some("db"));
        assert_eq!(options.get_username(), "u");
    }

    /// With a far-away expiry the configured ceiling decides the sleep.
    #[test]
    fn compute_next_refresh_sleep_is_capped_by_ceiling() {
        let ceiling = Duration::from_secs(5 * 60);
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(24 * 3600);

        assert_eq!(compute_next_refresh_sleep(ceiling, expires, now), ceiling);
    }

    /// With a generous ceiling and a near expiry, the sleep is derived from
    /// the expiry and stays within [min interval, 60s].
    #[test]
    fn compute_next_refresh_sleep_drives_from_expiry() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(6 * 60);

        let sleep = compute_next_refresh_sleep(Duration::from_secs(3600), expires, now);

        assert!(sleep <= Duration::from_secs(60), "got {sleep:?}");
        assert!(sleep >= ENTRA_REFRESH_MIN_INTERVAL, "got {sleep:?}");
    }

    /// A token expiring in one minute yields exactly the minimum interval.
    #[test]
    fn compute_next_refresh_sleep_floors_at_min_interval() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(60);

        let sleep = compute_next_refresh_sleep(Duration::from_secs(3600), expires, now);

        assert_eq!(sleep, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// Consecutive fetches from the scripted fake return different secrets and
    /// each fetch is counted.
    #[tokio::test]
    async fn recording_token_source_returns_distinct_tokens_in_script_order() {
        let scripted = ["token-A", "token-B", "token-C", "token-D", "token-E", "token-F"]
            .iter()
            .copied()
            .map(|secret| token(secret, 3600))
            .collect::<Vec<_>>();
        let fake = RecordingFakeTokenSource::with_tokens(scripted);
        let token_source: Arc<dyn TokenSource> = fake.clone();

        // A lazily-connecting pool and an initial expiry are constructed but
        // intentionally never used by the assertions below.
        let base_options =
            build_entra_connect_options("127.0.0.1", 5432, "db", "u", PgSslMode::VerifyFull);
        let _lazy_pool: Arc<PgPool> = Arc::new(
            PgPoolOptions::new()
                .max_connections(1)
                .connect_lazy_with(base_options.password("placeholder")),
        );
        let _initial_expires_at = SystemTime::now() + Duration::from_secs(3600);

        let first = token_source.fetch_token(&["aud"]).await.unwrap();
        let second = token_source.fetch_token(&["aud"]).await.unwrap();
        let third = token_source.fetch_token(&["aud"]).await.unwrap();

        assert_ne!(first.secret, second.secret);
        assert_ne!(second.secret, third.secret);
        assert_eq!(fake.call_count(), 3);
    }

    /// The audience configured on `EntraAuthOptions` is what reaches the
    /// token source as the requested scope.
    #[tokio::test]
    async fn audience_override_is_passed_to_token_source() {
        let fake = RecordingFakeTokenSource::with_tokens(vec![token("t", 3600)]);
        let source: Arc<dyn TokenSource> = fake.clone();
        let auth =
            crate::entra::EntraAuthOptions::new().audience("https://custom.example/.default");

        let _fetched = source.fetch_token(&[auth.audience_str()]).await.unwrap();

        let recorded = fake.recorded_scopes();
        assert_eq!(recorded.len(), 1);
        assert_eq!(
            recorded[0],
            vec!["https://custom.example/.default".to_string()]
        );
    }

    /// A failing token source reports its failure text in the error chain.
    #[tokio::test]
    async fn missing_credential_surfaces_descriptive_error() {
        let source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::always_failing("no credential available");

        let outcome: anyhow::Result<crate::entra::EntraToken> =
            source.fetch_token(&["aud"]).await;

        let err = outcome.expect_err("should fail");
        let msg = format!("{err:#}");
        assert!(msg.contains("no credential available"), "got: {msg}");
    }

    /// After a successful iteration the sleep matches the expiry-based
    /// schedule (here: the 20-minute ceiling).
    #[test]
    fn next_sleep_after_iteration_uses_expiry_schedule_on_success() {
        let ceiling = Duration::from_secs(20 * 60);
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let outcome: Result<Result<(), ()>, String> = Ok(Ok(()));

        let sleep = next_sleep_after_iteration(&outcome, ceiling, expires, now);

        assert_eq!(sleep, compute_next_refresh_sleep(ceiling, expires, now));
        assert_eq!(sleep, Duration::from_secs(20 * 60));
    }

    /// A failed fetch inside the iteration results in the minimum interval.
    #[test]
    fn next_sleep_after_iteration_returns_min_interval_on_fetch_failure() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let outcome: Result<Result<(), ()>, String> = Ok(Err(()));

        let sleep =
            next_sleep_after_iteration(&outcome, Duration::from_secs(20 * 60), expires, now);

        assert_eq!(sleep, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// A caught panic (outer `Err`) also results in the minimum interval.
    #[test]
    fn next_sleep_after_iteration_returns_min_interval_on_panic() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let outcome: Result<Result<(), ()>, String> = Err("simulated panic".to_string());

        let sleep =
            next_sleep_after_iteration(&outcome, Duration::from_secs(20 * 60), expires, now);

        assert_eq!(sleep, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// A one-second ceiling is raised to the minimum interval.
    #[test]
    fn compute_next_refresh_sleep_floors_when_ceiling_is_tiny() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);

        let sleep = compute_next_refresh_sleep(Duration::from_secs(1), expires, now);

        assert_eq!(sleep, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// `Debug` output of a token must hide the secret and show a marker.
    #[test]
    fn entra_token_debug_redacts_secret() {
        use crate::entra::test_support::token;

        let debug = format!("{:?}", token("super-secret-bearer-string", 3600));

        assert!(
            !debug.contains("super-secret-bearer-string"),
            "leaked: {debug}"
        );
        assert!(
            debug.contains("<redacted>"),
            "expected redaction marker: {debug}"
        );
    }

    /// Authentication SQLSTATEs (28000 / 28P01) are retryable only for Entra
    /// connections; the other checked codes do not depend on the flag.
    #[test]
    fn classify_pg_sqlstate_gates_28xxx_on_is_entra() {
        use crate::provider::{classify_pg_sqlstate, SqlStateClass};

        // (sqlstate, is_entra, expected classification)
        let cases = [
            (Some("28000"), true, SqlStateClass::Retryable),
            (Some("28P01"), true, SqlStateClass::Retryable),
            (Some("28000"), false, SqlStateClass::Permanent),
            (Some("28P01"), false, SqlStateClass::Permanent),
            (Some("40P01"), true, SqlStateClass::Retryable),
            (Some("40P01"), false, SqlStateClass::Retryable),
            (Some("23505"), true, SqlStateClass::Permanent),
            (Some("23505"), false, SqlStateClass::Permanent),
            (Some("0A000"), true, SqlStateClass::Retryable),
            (None, true, SqlStateClass::Permanent),
        ];

        for &(code, is_entra, expected) in &cases {
            assert_eq!(
                classify_pg_sqlstate(code, is_entra),
                expected,
                "code={code:?} is_entra={is_entra}"
            );
        }
    }

    /// A `panic!` with a string payload is turned into an `Err` holding it.
    #[tokio::test]
    async fn run_with_panic_guard_catches_string_panic_and_continues() {
        let outcome: Result<(), String> = run_with_panic_guard(async { panic!("boom") }).await;

        let msg = outcome.expect_err("must catch the panic");
        assert!(msg.contains("boom"), "got: {msg}");
    }

    /// A future that completes normally has its value passed through.
    #[tokio::test]
    async fn run_with_panic_guard_returns_ok_when_future_completes() {
        let outcome: Result<i32, String> = run_with_panic_guard(async { 42 }).await;

        assert_eq!(outcome.unwrap(), 42);
    }

    /// A panic payload that is not a string gets a placeholder description.
    #[tokio::test]
    async fn run_with_panic_guard_handles_non_string_panic_payload() {
        let outcome: Result<(), String> =
            run_with_panic_guard(async { std::panic::panic_any(42_i32) }).await;

        let msg = outcome.expect_err("must catch");
        assert!(msg.contains("non-string panic payload"), "got: {msg}");
    }

    /// Input shorter than the limit is returned unchanged.
    #[test]
    fn truncate_panic_message_passes_through_short_input() {
        let original = "short message".to_string();

        assert_eq!(truncate_panic_message(original.clone(), 256), original);
    }

    /// ASCII input over the limit keeps exactly `limit` bytes plus the marker.
    #[test]
    fn truncate_panic_message_truncates_long_input_with_marker() {
        let out = truncate_panic_message("A".repeat(1024), 256);

        assert!(out.starts_with(&"A".repeat(256)));
        assert!(out.ends_with(TRUNCATION_MARKER), "got: {out}");
        assert_eq!(out.len(), 256 + TRUNCATION_MARKER.len());
    }

    /// Multi-byte input over the limit is still truncated with the marker
    /// (and must not panic on a character boundary).
    #[test]
    fn truncate_panic_message_respects_utf8_char_boundaries() {
        let out = truncate_panic_message("✨".repeat(100), 256);

        assert!(out.ends_with(TRUNCATION_MARKER));
    }

    /// A very large panic message caught by the guard comes back shortened
    /// and marked as truncated.
    #[tokio::test]
    async fn run_with_panic_guard_truncates_oversized_panic_message() {
        let outcome: Result<(), String> = run_with_panic_guard(async {
            panic!("{}", "S".repeat(10_000));
        })
        .await;

        let msg = outcome.expect_err("must catch");
        assert!(
            msg.len() < 10_000,
            "panic message not truncated: len={}",
            msg.len()
        );
        assert!(
            msg.ends_with(TRUNCATION_MARKER),
            "missing truncation marker: {msg}"
        );
    }
}
3075
/// Integration tests that drive the Entra constructor against a real
/// PostgreSQL server, using the database password as the injected "token".
///
/// Every test returns early (without failing) when `DATABASE_URL` is unset or
/// cannot be parsed by `parse_database_url`.
#[cfg(test)]
mod entra_pipeline_tests {
    use super::*;
    use crate::entra::test_support::{token, RecordingFakeTokenSource};
    use sqlx::Row;

    /// Splits a `postgres://user:password@host[:port]/db[?query]` URL (the
    /// `postgresql://` scheme is accepted too) into
    /// `(host, port, database, user, password)`.
    ///
    /// Returns `None` when the scheme, the `user:password@` part, or the `/`
    /// before the database name is missing, or when the port is not a valid
    /// `u16`. The port defaults to 5432. No percent-decoding is performed.
    fn parse_database_url(url: &str) -> Option<(String, u16, String, String, String)> {
        let stripped = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))?;
        let (creds, rest) = stripped.split_once('@')?;
        let (user, password) = creds.split_once(':')?;
        let (hostport, db_with_query) = rest.split_once('/')?;
        let (host, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
        let port: u16 = port_str.parse().ok()?;
        let db = db_with_query.split('?').next()?;
        Some((
            host.to_string(),
            port,
            db.to_string(),
            user.to_string(),
            password.to_string(),
        ))
    }

    /// Loads `.env` (if any) and returns the parsed `DATABASE_URL`, or `None`
    /// with a note on stderr when the variable is missing or unparseable.
    fn pg_connection_or_skip() -> Option<(String, u16, String, String, String)> {
        dotenvy::dotenv().ok();
        let Ok(url) = std::env::var("DATABASE_URL") else {
            eprintln!("DATABASE_URL not set; skipping Entra pipeline integration test");
            return None;
        };
        let parts = parse_database_url(&url);
        if parts.is_none() {
            // The URL embeds the database password, so it is deliberately
            // NOT echoed into the test output.
            eprintln!("DATABASE_URL not parseable; skipping Entra pipeline integration test");
        }
        parts
    }

    /// Builds a schema name of the form `entra_inj_<8 chars>` from the tail of
    /// a fresh v4 UUID so concurrent test runs do not collide.
    fn unique_schema() -> String {
        let id = uuid::Uuid::new_v4().to_string();
        format!("entra_inj_{}", &id[id.len() - 8..])
    }

    /// Probes whether the server actually enforces password authentication.
    ///
    /// Returns `false` if a connection with a bogus password succeeds, and
    /// `true` if it fails. In the failing case the error text must look like
    /// an authentication failure, otherwise this helper panics.
    async fn wrong_password_is_rejected(host: &str, port: u16, db: &str, user: &str) -> bool {
        let result = PgPoolOptions::new()
            .max_connections(1)
            .connect_with(
                PgConnectOptions::new()
                    .host(host)
                    .port(port)
                    .database(db)
                    .username(user)
                    .password("definitely-wrong-password")
                    .ssl_mode(PgSslMode::Disable),
            )
            .await;

        match result {
            Ok(pool) => {
                pool.close().await;
                false
            }
            Err(err) => {
                let msg = format!("{err:#}");
                assert!(
                    msg.to_lowercase().contains("password")
                        || msg.contains("28P01")
                        || msg.contains("28000"),
                    "expected authentication failure, got: {msg}"
                );
                true
            }
        }
    }

    /// Best-effort cleanup: drops `schema` with `CASCADE`, only warning on
    /// stderr if the statement fails.
    async fn drop_schema(pool: &PgPool, schema: &str) {
        let stmt = format!("DROP SCHEMA IF EXISTS \"{schema}\" CASCADE");
        if let Err(e) = sqlx::query(&stmt).execute(pool).await {
            eprintln!("warning: failed to drop schema {schema}: {e}");
        }
    }

    /// With the real password injected as the token, the provider is built
    /// and the `instances` table exists in the fresh schema afterwards.
    #[tokio::test]
    async fn pipeline_with_injected_token_authenticates_and_runs_migrations() {
        let Some((host, port, db, user, password)) = pg_connection_or_skip() else {
            return;
        };

        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token(&password, 3600)]);
        let schema = unique_schema();

        let provider = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new(),
            token_source,
            PgSslMode::Disable,
            MigrationPolicy::ApplyAll,
        )
        .await
        .expect("Entra pipeline must succeed against local PG with correct token");

        // `to_regclass` yields NULL when the relation does not exist.
        let row = sqlx::query(&format!(
            "SELECT to_regclass('{}.instances')::text AS r",
            schema
        ))
        .fetch_one(provider.pool())
        .await
        .expect("smoke query must succeed");
        let regclass: Option<String> = row.get("r");
        assert!(
            regclass.is_some(),
            "expected migrations to create {}.instances",
            schema
        );

        drop_schema(provider.pool(), &schema).await;
    }

    /// With a wrong token the constructor must fail with an
    /// authentication-looking error. Skipped when the server accepts any
    /// password.
    #[tokio::test]
    async fn pipeline_with_wrong_token_fails_before_migrations() {
        let Some((host, port, db, user, _password)) = pg_connection_or_skip() else {
            return;
        };

        if !wrong_password_is_rejected(&host, port, &db, &user).await {
            eprintln!(
                "local PostgreSQL accepts wrong passwords; skipping negative Entra pipeline test"
            );
            return;
        }

        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token("definitely-wrong-password", 3600)]);
        let schema = unique_schema();

        let result = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new(),
            token_source,
            PgSslMode::Disable,
            MigrationPolicy::ApplyAll,
        )
        .await;

        let err = match result {
            Ok(_) => panic!("wrong token must fail pool construction, but provider was built"),
            Err(e) => e,
        };
        let msg = format!("{err:#}");
        assert!(
            msg.to_lowercase().contains("password")
                || msg.contains("28P01")
                || msg.contains("28000"),
            "expected authentication failure, got: {msg}"
        );
    }

    /// Same happy path as above but with an explicit one-hour refresh
    /// interval; checks that the provider reports the requested schema name.
    #[tokio::test]
    async fn pipeline_default_constructor_path_with_injected_token() {
        let Some((host, port, db, user, password)) = pg_connection_or_skip() else {
            return;
        };

        let schema = unique_schema();
        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token(&password, 3600)]);

        let provider = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new().refresh_interval(Duration::from_secs(60 * 60)),
            token_source,
            PgSslMode::Disable,
            MigrationPolicy::ApplyAll,
        )
        .await
        .expect("default-constructor variant must succeed");
        assert_eq!(provider.schema_name(), schema);

        drop_schema(provider.pool(), &schema).await;
    }
}