1use anyhow::{Context, Result};
2use chrono::{TimeZone, Utc};
3use duroxide::providers::{
4 DeleteInstanceResult, DispatcherCapabilityFilter, ExecutionInfo, ExecutionMetadata,
5 InstanceFilter, InstanceInfo, OrchestrationItem, Provider, ProviderAdmin, ProviderError,
6 PruneOptions, PruneResult, QueueDepths, ScheduledActivityIdentifier, SessionFetchConfig,
7 SystemMetrics, TagFilter, WorkItem,
8};
9use duroxide::{Event, EventKind, SystemStats};
10use sqlx::postgres::{PgConnectOptions, PgSslMode};
11use sqlx::{postgres::PgPoolOptions, Error as SqlxError, PgPool};
12use std::sync::Arc;
13use std::time::Duration;
14use std::time::{SystemTime, UNIX_EPOCH};
15use tokio::task::AbortHandle;
16use tokio::time::sleep;
17use tracing::{debug, error, instrument, warn};
18
19use crate::entra::{EntraAuthOptions, TokenSource};
20use crate::migrations::MigrationRunner;
21
/// Coarse retry classification assigned to a PostgreSQL SQLSTATE code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SqlStateClass {
    /// The failed operation may be attempted again.
    Retryable,
    /// The failure is reported as permanent; callers are not expected to retry.
    Permanent,
}

/// Classifies a SQLSTATE `code` as retryable or permanent.
///
/// * `40P01` and `0A000` are always retryable (the provider's error mapper
///   labels them "Deadlock detected" and "Cached plan invalidated").
/// * `28000` / `28P01` (authentication failures) are retryable only when
///   `is_entra` is true; otherwise they are permanent.
/// * Everything else — including `40001`, `23505`, `23503`, unknown codes and
///   a missing code (`None`) — is permanent.
pub(crate) fn classify_pg_sqlstate(code: Option<&str>, is_entra: bool) -> SqlStateClass {
    let retryable = match code {
        Some("40P01" | "0A000") => true,
        // Auth failures only count as transient for Entra-authenticated pools.
        Some("28000" | "28P01") => is_entra,
        _ => false,
    };

    if retryable {
        SqlStateClass::Retryable
    } else {
        SqlStateClass::Permanent
    }
}
98
/// PostgreSQL-backed duroxide provider.
///
/// Holds the shared connection pool, the schema that qualifies every
/// stored-procedure call, and (for Entra-authenticated pools) the handle of
/// the background token-refresh task.
pub struct PostgresProvider {
    /// Shared sqlx connection pool used by every query.
    pool: Arc<PgPool>,
    /// Schema name interpolated into SQL; `"public"` when none was supplied.
    schema_name: String,
    /// `true` when built through the Entra constructors; makes authentication
    /// SQLSTATEs (`28000` / `28P01`) retryable in `sqlx_to_provider_error`.
    is_entra: bool,
    /// Background token-refresh task, aborted when the provider is dropped.
    /// `None` for providers created from a plain database URL.
    _refresh_task: Option<AbortOnDropHandle>,
}
109
/// Guard that aborts the wrapped tokio task when it goes out of scope.
struct AbortOnDropHandle(AbortHandle);

impl Drop for AbortOnDropHandle {
    /// Requests cancellation of the associated task via its abort handle.
    fn drop(&mut self) {
        self.0.abort();
    }
}
120
121impl PostgresProvider {
122 pub async fn new(database_url: &str) -> Result<Self> {
123 Self::new_with_schema(database_url, None).await
124 }
125
126 pub async fn new_with_schema(database_url: &str, schema_name: Option<&str>) -> Result<Self> {
127 let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
128 .ok()
129 .and_then(|s| s.parse::<u32>().ok())
130 .unwrap_or(10);
131
132 let pool = PgPoolOptions::new()
133 .max_connections(max_connections)
134 .min_connections(1)
135 .acquire_timeout(std::time::Duration::from_secs(30))
136 .connect(database_url)
137 .await?;
138
139 let schema_name = schema_name.unwrap_or("public").to_string();
140
141 let provider = Self {
142 pool: Arc::new(pool),
143 schema_name: schema_name.clone(),
144 is_entra: false,
145 _refresh_task: None,
146 };
147
148 let migration_runner = MigrationRunner::new(provider.pool.clone(), schema_name.clone());
150 migration_runner.migrate().await?;
151
152 Ok(provider)
153 }
154
155 pub async fn new_with_entra(
182 host: &str,
183 port: u16,
184 database: &str,
185 user: &str,
186 options: EntraAuthOptions,
187 ) -> Result<Self> {
188 Self::new_with_schema_and_entra(host, port, database, user, None, options).await
189 }
190
191 #[instrument(
194 skip(options),
195 fields(host = %host, port = %port, database = %database, user = %user, schema = ?schema_name),
196 target = "duroxide::providers::postgres",
197 )]
198 pub async fn new_with_schema_and_entra(
199 host: &str,
200 port: u16,
201 database: &str,
202 user: &str,
203 schema_name: Option<&str>,
204 options: EntraAuthOptions,
205 ) -> Result<Self> {
206 let token_source = options.default_token_source().context(
207 "Entra credential resolution failed: could not build the default credential chain",
208 )?;
209
210 Self::new_with_entra_with_token_source(
211 host,
212 port,
213 database,
214 user,
215 schema_name,
216 options,
217 token_source,
218 PgSslMode::VerifyFull,
219 )
220 .await
221 }
222
223 pub(crate) async fn new_with_entra_with_token_source(
233 host: &str,
234 port: u16,
235 database: &str,
236 user: &str,
237 schema_name: Option<&str>,
238 options: EntraAuthOptions,
239 token_source: Arc<dyn TokenSource>,
240 ssl_mode: PgSslMode,
241 ) -> Result<Self> {
242 let audience = options.audience_str().to_string();
243 let token = token_source
244 .fetch_token(&[audience.as_str()])
245 .await
246 .context(
247 "Entra credential resolution failed: could not acquire an initial access token",
248 )?;
249
250 let base_options = build_entra_connect_options(host, port, database, user, ssl_mode);
251
252 let pool = PgPoolOptions::new()
253 .max_connections(options.max_connections_value())
254 .min_connections(1)
255 .acquire_timeout(options.acquire_timeout_value())
256 .connect_with(base_options.clone().password(&token.secret))
257 .await?;
258
259 let pool = Arc::new(pool);
260 let schema_name = schema_name.unwrap_or("public").to_string();
261
262 let migration_runner = MigrationRunner::new(pool.clone(), schema_name.clone());
263 migration_runner.migrate().await?;
264
265 let refresh_handle = spawn_token_refresh_task(
266 pool.clone(),
267 token_source,
268 base_options,
269 audience,
270 options.refresh_interval_value(),
271 token.expires_at,
272 );
273
274 Ok(Self {
275 pool,
276 schema_name,
277 is_entra: true,
278 _refresh_task: Some(AbortOnDropHandle(refresh_handle)),
279 })
280 }
281
282 #[instrument(skip(self), target = "duroxide::providers::postgres")]
283 pub async fn initialize_schema(&self) -> Result<()> {
284 let migration_runner = MigrationRunner::new(self.pool.clone(), self.schema_name.clone());
287 migration_runner.migrate().await?;
288 Ok(())
289 }
290
291 fn now_millis() -> i64 {
293 SystemTime::now()
294 .duration_since(UNIX_EPOCH)
295 .unwrap()
296 .as_millis() as i64
297 }
298
299 fn table_name(&self, table: &str) -> String {
301 format!("{}.{}", self.schema_name, table)
302 }
303
304 pub fn pool(&self) -> &PgPool {
306 &self.pool
307 }
308
309 pub fn schema_name(&self) -> &str {
311 &self.schema_name
312 }
313
314 fn sqlx_to_provider_error(&self, operation: &str, e: SqlxError) -> ProviderError {
323 match e {
324 SqlxError::Database(ref db_err) => {
325 let code_opt = db_err.code();
326 let code = code_opt.as_deref();
327 match classify_pg_sqlstate(code, self.is_entra) {
328 SqlStateClass::Retryable => ProviderError::retryable(
329 operation,
330 match code {
331 Some("40P01") => format!("Deadlock detected: {e}"),
332 Some("28000") | Some("28P01") => {
333 format!("Authentication error (likely token rotation): {e}")
334 }
335 Some("0A000") => format!("Cached plan invalidated: {e}"),
336 _ => format!("Retryable database error: {e}"),
337 },
338 ),
339 SqlStateClass::Permanent => ProviderError::permanent(
340 operation,
341 match code {
342 Some("40001") => format!("Serialization failure: {e}"),
343 Some("23505") => format!("Duplicate detected: {e}"),
344 Some("23503") => format!("Foreign key violation: {e}"),
345 _ => format!("Database error: {e}"),
346 },
347 ),
348 }
349 }
350 SqlxError::PoolClosed | SqlxError::PoolTimedOut => {
351 ProviderError::retryable(operation, format!("Connection pool error: {e}"))
352 }
353 SqlxError::Io(_) => ProviderError::retryable(operation, format!("I/O error: {e}")),
354 _ => ProviderError::permanent(operation, format!("Unexpected error: {e}")),
355 }
356 }
357
358 fn tag_filter_to_sql(filter: &TagFilter) -> (&'static str, Vec<String>) {
360 match filter {
361 TagFilter::DefaultOnly => ("default_only", vec![]),
362 TagFilter::Tags(set) => {
363 let mut tags: Vec<String> = set.iter().cloned().collect();
364 tags.sort();
365 ("tags", tags)
366 }
367 TagFilter::DefaultAnd(set) => {
368 let mut tags: Vec<String> = set.iter().cloned().collect();
369 tags.sort();
370 ("default_and", tags)
371 }
372 TagFilter::Any => ("any", vec![]),
373 TagFilter::None => ("none", vec![]),
374 }
375 }
376
377 pub async fn cleanup_schema(&self) -> Result<()> {
382 const MAX_RETRIES: u32 = 5;
383 const BASE_RETRY_DELAY_MS: u64 = 50;
384
385 for attempt in 0..=MAX_RETRIES {
386 let cleanup_result = async {
387 sqlx::query(&format!("SELECT {}.cleanup_schema()", self.schema_name))
389 .execute(&*self.pool)
390 .await?;
391
392 if self.schema_name != "public" {
395 sqlx::query(&format!(
396 "DROP SCHEMA IF EXISTS {} CASCADE",
397 self.schema_name
398 ))
399 .execute(&*self.pool)
400 .await?;
401 } else {
402 }
405
406 Ok::<(), SqlxError>(())
407 }
408 .await;
409
410 match cleanup_result {
411 Ok(()) => return Ok(()),
412 Err(SqlxError::Database(db_err)) if db_err.code().as_deref() == Some("40P01") => {
413 if attempt < MAX_RETRIES {
414 warn!(
415 target = "duroxide::providers::postgres",
416 schema = %self.schema_name,
417 attempt = attempt + 1,
418 "Deadlock during cleanup_schema, retrying"
419 );
420 sleep(Duration::from_millis(
421 BASE_RETRY_DELAY_MS * (attempt as u64 + 1),
422 ))
423 .await;
424 continue;
425 }
426 return Err(anyhow::anyhow!(db_err.to_string()));
427 }
428 Err(e) => return Err(anyhow::anyhow!(e.to_string())),
429 }
430 }
431
432 Ok(())
433 }
434}
435
436pub(crate) fn build_entra_connect_options(
445 host: &str,
446 port: u16,
447 database: &str,
448 user: &str,
449 ssl_mode: PgSslMode,
450) -> PgConnectOptions {
451 PgConnectOptions::new()
452 .host(host)
453 .port(port)
454 .database(database)
455 .username(user)
456 .ssl_mode(ssl_mode)
457}
458
/// Lower bound (30 s) for any sleep computed by `compute_next_refresh_sleep`;
/// also the fixed backoff `next_sleep_after_iteration` returns after a failed
/// or panicked refresh.
const ENTRA_REFRESH_MIN_INTERVAL: Duration = Duration::from_secs(30);

/// Lead time (5 min) subtracted from the time remaining until token expiry
/// when scheduling the next refresh.
pub(crate) const ENTRA_REFRESH_SAFETY_MARGIN: Duration = Duration::from_secs(5 * 60);

/// Maximum number of bytes of a panic message kept by `run_with_panic_guard`
/// before it is truncated.
const ENTRA_PANIC_MSG_TRUNCATION_LIMIT: usize = 256;
472
473async fn run_with_panic_guard<Fut, T>(fut: Fut) -> Result<T, String>
485where
486 Fut: std::future::Future<Output = T>,
487{
488 use futures_util::FutureExt;
489 use std::panic::AssertUnwindSafe;
490
491 AssertUnwindSafe(fut).catch_unwind().await.map_err(|panic| {
492 let raw = if let Some(s) = panic.downcast_ref::<&'static str>() {
493 (*s).to_string()
494 } else if let Some(s) = panic.downcast_ref::<String>() {
495 s.clone()
496 } else {
497 "<non-string panic payload>".to_string()
498 };
499 truncate_panic_message(raw, ENTRA_PANIC_MSG_TRUNCATION_LIMIT)
500 })
501}
502
/// Shortens `s` to at most `limit` bytes and appends `…[truncated]`.
///
/// Strings already within the limit are returned unchanged. The cut point is
/// moved backwards to the nearest UTF-8 character boundary so the slice never
/// splits a multi-byte character.
fn truncate_panic_message(s: String, limit: usize) -> String {
    if s.len() <= limit {
        return s;
    }

    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=limit)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    let mut shortened = String::with_capacity(cut + 16);
    shortened += &s[..cut];
    shortened += "…[truncated]";
    shortened
}
521
522fn spawn_token_refresh_task(
541 pool: Arc<PgPool>,
542 token_source: Arc<dyn TokenSource>,
543 base_options: PgConnectOptions,
544 audience: String,
545 refresh_interval_ceiling: Duration,
546 initial_expires_at: SystemTime,
547) -> AbortHandle {
548 let handle = tokio::spawn(async move {
549 let mut next_expires_at = initial_expires_at;
563 let mut sleep_duration = compute_next_refresh_sleep(
564 refresh_interval_ceiling,
565 next_expires_at,
566 SystemTime::now(),
567 );
568 loop {
569 debug!(
570 target: "duroxide::providers::postgres",
571 sleep_secs = sleep_duration.as_secs(),
572 "Entra refresh task sleeping",
573 );
574 sleep(sleep_duration).await;
575
576 let result = run_with_panic_guard(refresh_loop_iteration(
577 &pool,
578 token_source.as_ref(),
579 &base_options,
580 &audience,
581 &mut next_expires_at,
582 ))
583 .await;
584
585 if let Err(panic_msg) = &result {
586 error!(
587 target: "duroxide::providers::postgres",
588 panic = %panic_msg,
589 "Entra refresh task body panicked; continuing with bounded backoff",
590 );
591 }
592
593 sleep_duration = next_sleep_after_iteration(
594 &result,
595 refresh_interval_ceiling,
596 next_expires_at,
597 SystemTime::now(),
598 );
599 }
600 });
601 handle.abort_handle()
602}
603
604fn next_sleep_after_iteration(
617 result: &Result<Result<(), ()>, String>,
618 refresh_interval_ceiling: Duration,
619 next_expires_at: SystemTime,
620 now: SystemTime,
621) -> Duration {
622 match result {
623 Ok(Ok(())) => compute_next_refresh_sleep(refresh_interval_ceiling, next_expires_at, now),
624 Ok(Err(())) | Err(_) => ENTRA_REFRESH_MIN_INTERVAL,
625 }
626}
627
628async fn refresh_loop_iteration(
636 pool: &Arc<PgPool>,
637 token_source: &dyn TokenSource,
638 base_options: &PgConnectOptions,
639 audience: &str,
640 next_expires_at: &mut SystemTime,
641) -> Result<(), ()> {
642 match token_source.fetch_token(&[audience]).await {
643 Ok(token) => {
644 let new_options = base_options.clone().password(&token.secret);
645 pool.set_connect_options(new_options);
646 *next_expires_at = token.expires_at;
647 debug!(
648 target: "duroxide::providers::postgres",
649 "Entra token refreshed and applied to pool",
650 );
651 Ok(())
652 }
653 Err(e) => {
654 warn!(
655 target: "duroxide::providers::postgres",
656 error = %e,
657 "Entra token refresh failed; will retry after bounded backoff",
658 );
659 Err(())
660 }
661 }
662}
663
664fn compute_next_refresh_sleep(
671 ceiling: Duration,
672 expires_at: SystemTime,
673 now: SystemTime,
674) -> Duration {
675 let until_expiry = expires_at.duration_since(now).unwrap_or(Duration::ZERO);
676
677 let expiry_driven = until_expiry
678 .checked_sub(ENTRA_REFRESH_SAFETY_MARGIN)
679 .unwrap_or(Duration::ZERO);
680
681 let expiry_driven = expiry_driven.max(ENTRA_REFRESH_MIN_INTERVAL);
682
683 ceiling.min(expiry_driven).max(ENTRA_REFRESH_MIN_INTERVAL)
686}
687
688#[async_trait::async_trait]
689impl Provider for PostgresProvider {
    /// Identifier reported for this provider implementation.
    fn name(&self) -> &str {
        "duroxide-pg"
    }
693
    /// Crate version, captured at compile time from `CARGO_PKG_VERSION`.
    fn version(&self) -> &str {
        env!("CARGO_PKG_VERSION")
    }
697
698 #[instrument(skip(self), target = "duroxide::providers::postgres")]
699 async fn fetch_orchestration_item(
700 &self,
701 lock_timeout: Duration,
702 _poll_timeout: Duration,
703 filter: Option<&DispatcherCapabilityFilter>,
704 ) -> Result<Option<(OrchestrationItem, String, u32)>, ProviderError> {
705 let start = std::time::Instant::now();
706
707 const MAX_RETRIES: u32 = 3;
708 const RETRY_DELAY_MS: u64 = 50;
709
710 let lock_timeout_ms = lock_timeout.as_millis() as i64;
712 let mut _last_error: Option<ProviderError> = None;
713
714 let (min_packed, max_packed) = if let Some(f) = filter {
716 if let Some(range) = f.supported_duroxide_versions.first() {
717 let min = range.min.major as i64 * 1_000_000
718 + range.min.minor as i64 * 1_000
719 + range.min.patch as i64;
720 let max = range.max.major as i64 * 1_000_000
721 + range.max.minor as i64 * 1_000
722 + range.max.patch as i64;
723 (Some(min), Some(max))
724 } else {
725 return Ok(None);
727 }
728 } else {
729 (None, None)
730 };
731
732 for attempt in 0..=MAX_RETRIES {
733 let now_ms = Self::now_millis();
734
735 let result: Result<
736 Option<(
737 String,
738 String,
739 String,
740 i64,
741 serde_json::Value,
742 serde_json::Value,
743 String,
744 i32,
745 serde_json::Value,
746 )>,
747 SqlxError,
748 > = sqlx::query_as(&format!(
749 "SELECT * FROM {}.fetch_orchestration_item($1, $2, $3, $4)",
750 self.schema_name
751 ))
752 .bind(now_ms)
753 .bind(lock_timeout_ms)
754 .bind(min_packed)
755 .bind(max_packed)
756 .fetch_optional(&*self.pool)
757 .await;
758
759 let row = match result {
760 Ok(r) => r,
761 Err(e) => {
762 let provider_err = self.sqlx_to_provider_error("fetch_orchestration_item", e);
763 if provider_err.is_retryable() && attempt < MAX_RETRIES {
764 warn!(
765 target = "duroxide::providers::postgres",
766 operation = "fetch_orchestration_item",
767 attempt = attempt + 1,
768 error = %provider_err,
769 "Retryable error, will retry"
770 );
771 _last_error = Some(provider_err);
772 sleep(std::time::Duration::from_millis(
773 RETRY_DELAY_MS * (attempt as u64 + 1),
774 ))
775 .await;
776 continue;
777 }
778 return Err(provider_err);
779 }
780 };
781
782 if let Some((
783 instance_id,
784 orchestration_name,
785 orchestration_version,
786 execution_id,
787 history_json,
788 messages_json,
789 lock_token,
790 attempt_count,
791 kv_snapshot_json,
792 )) = row
793 {
794 let (history, history_error) =
795 match serde_json::from_value::<Vec<Event>>(history_json) {
796 Ok(h) => (h, None),
797 Err(e) => {
798 let error_msg = format!("Failed to deserialize history: {e}");
799 warn!(
800 target = "duroxide::providers::postgres",
801 instance = %instance_id,
802 error = %error_msg,
803 "History deserialization failed, returning item with history_error"
804 );
805 (vec![], Some(error_msg))
806 }
807 };
808
809 let messages: Vec<WorkItem> =
810 serde_json::from_value(messages_json).map_err(|e| {
811 ProviderError::permanent(
812 "fetch_orchestration_item",
813 format!("Failed to deserialize messages: {e}"),
814 )
815 })?;
816 let kv_snapshot: std::collections::HashMap<String, duroxide::providers::KvEntry> = {
817 let raw: std::collections::HashMap<String, serde_json::Value> =
818 serde_json::from_value(kv_snapshot_json).unwrap_or_default();
819 raw.into_iter()
820 .filter_map(|(k, v)| {
821 let value = v.get("value")?.as_str()?.to_string();
822 let last_updated_at_ms =
823 v.get("last_updated_at_ms")?.as_u64().unwrap_or(0);
824 Some((
825 k,
826 duroxide::providers::KvEntry {
827 value,
828 last_updated_at_ms,
829 },
830 ))
831 })
832 .collect()
833 };
834
835 let duration_ms = start.elapsed().as_millis() as u64;
836 debug!(
837 target = "duroxide::providers::postgres",
838 operation = "fetch_orchestration_item",
839 instance_id = %instance_id,
840 execution_id = execution_id,
841 message_count = messages.len(),
842 history_count = history.len(),
843 attempt_count = attempt_count,
844 duration_ms = duration_ms,
845 attempts = attempt + 1,
846 "Fetched orchestration item via stored procedure"
847 );
848
849 if orchestration_name == "Unknown"
855 && history.is_empty()
856 && messages
857 .iter()
858 .all(|m| matches!(m, WorkItem::QueueMessage { .. }))
859 {
860 let message_count = messages.len();
861 tracing::warn!(
862 target = "duroxide::providers::postgres",
863 instance = %instance_id,
864 message_count,
865 "Dropping orphan queue messages — events enqueued before orchestration started are not supported"
866 );
867 self.ack_orchestration_item(
868 &lock_token,
869 execution_id as u64,
870 vec![],
871 vec![],
872 vec![],
873 ExecutionMetadata::default(),
874 vec![],
875 )
876 .await?;
877 return Ok(None);
878 }
879
880 return Ok(Some((
881 OrchestrationItem {
882 instance: instance_id,
883 orchestration_name,
884 execution_id: execution_id as u64,
885 version: orchestration_version,
886 history,
887 messages,
888 history_error,
889 kv_snapshot,
890 },
891 lock_token,
892 attempt_count as u32,
893 )));
894 }
895
896 return Ok(None);
899 }
900
901 Ok(None)
902 }
    /// Acknowledges a locked orchestration item through the
    /// `ack_orchestration_item` stored procedure.
    ///
    /// Serializes the history delta, the worker and orchestrator work items,
    /// the execution metadata (including the last custom-status change and any
    /// key/value mutations found in `history_delta`) and the cancelled
    /// activities to JSON, then sends them in a single procedure call.
    /// Retryable database errors are retried up to `MAX_RETRIES` times with a
    /// linearly growing delay.
    ///
    /// # Errors
    /// Returns a permanent [`ProviderError`] when an event has `event_id == 0`,
    /// when serialization fails, or when the database reports
    /// "Invalid lock token". Other database errors are mapped by
    /// `sqlx_to_provider_error`.
    #[instrument(skip(self), fields(lock_token = %lock_token, execution_id = execution_id), target = "duroxide::providers::postgres")]
    async fn ack_orchestration_item(
        &self,
        lock_token: &str,
        execution_id: u64,
        history_delta: Vec<Event>,
        worker_items: Vec<WorkItem>,
        orchestrator_items: Vec<WorkItem>,
        metadata: ExecutionMetadata,
        cancelled_activities: Vec<ScheduledActivityIdentifier>,
    ) -> Result<(), ProviderError> {
        let start = std::time::Instant::now();

        const MAX_RETRIES: u32 = 3;
        const RETRY_DELAY_MS: u64 = 50;

        // Build one JSON object per event: id, type label, and the serialized event.
        let mut history_delta_payload = Vec::with_capacity(history_delta.len());
        for event in &history_delta {
            // The runtime must assign event ids before handing events to the provider.
            if event.event_id() == 0 {
                return Err(ProviderError::permanent(
                    "ack_orchestration_item",
                    "event_id must be set by runtime",
                ));
            }

            let event_json = serde_json::to_string(event).map_err(|e| {
                ProviderError::permanent(
                    "ack_orchestration_item",
                    format!("Failed to serialize event: {e}"),
                )
            })?;

            // The type label is the text of the event's `Debug` output before the first `{`.
            let event_type = format!("{event:?}")
                .split('{')
                .next()
                .unwrap_or("Unknown")
                .trim()
                .to_string();

            history_delta_payload.push(serde_json::json!({
                "event_id": event.event_id(),
                "event_type": event_type,
                "event_data": event_json,
            }));
        }

        let history_delta_json = serde_json::Value::Array(history_delta_payload);

        let worker_items_json = serde_json::to_value(&worker_items).map_err(|e| {
            ProviderError::permanent(
                "ack_orchestration_item",
                format!("Failed to serialize worker items: {e}"),
            )
        })?;

        let orchestrator_items_json = serde_json::to_value(&orchestrator_items).map_err(|e| {
            ProviderError::permanent(
                "ack_orchestration_item",
                format!("Failed to serialize orchestrator items: {e}"),
            )
        })?;

        // Only the last CustomStatusUpdated event in the delta decides the action:
        // "set" with a value, "clear" without one, or no action at all.
        let (custom_status_action, custom_status_value): (Option<&str>, Option<&str>) = {
            let mut last_status: Option<&Option<String>> = None;
            for event in &history_delta {
                if let EventKind::CustomStatusUpdated { ref status } = event.kind {
                    last_status = Some(status);
                }
            }
            match last_status {
                Some(Some(s)) => (Some("set"), Some(s.as_str())),
                Some(None) => (Some("clear"), None),
                None => (None, None),
            }
        };

        // Key/value events are forwarded in history order as mutation records.
        let kv_mutations: Vec<serde_json::Value> = history_delta
            .iter()
            .filter_map(|event| match &event.kind {
                EventKind::KeyValueSet {
                    key,
                    value,
                    last_updated_at_ms,
                } => Some(serde_json::json!({
                    "action": "set",
                    "key": key,
                    "value": value,
                    "last_updated_at_ms": last_updated_at_ms,
                })),
                EventKind::KeyValueCleared { key } => Some(serde_json::json!({
                    "action": "clear_key",
                    "key": key,
                })),
                EventKind::KeyValuesCleared => Some(serde_json::json!({
                    "action": "clear_all",
                })),
                _ => None,
            })
            .collect();

        let metadata_json = serde_json::json!({
            "orchestration_name": metadata.orchestration_name,
            "orchestration_version": metadata.orchestration_version,
            "status": metadata.status,
            "output": metadata.output,
            "parent_instance_id": metadata.parent_instance_id,
            "pinned_duroxide_version": metadata.pinned_duroxide_version.as_ref().map(|v| {
                serde_json::json!({
                    "major": v.major,
                    "minor": v.minor,
                    "patch": v.patch,
                })
            }),
            "custom_status_action": custom_status_action,
            "custom_status_value": custom_status_value,
            "kv_mutations": kv_mutations,
        });

        let cancelled_activities_json: Vec<serde_json::Value> = cancelled_activities
            .iter()
            .map(|a| {
                serde_json::json!({
                    "instance": a.instance,
                    "execution_id": a.execution_id,
                    "activity_id": a.activity_id,
                })
            })
            .collect();
        let cancelled_activities_json = serde_json::Value::Array(cancelled_activities_json);

        // The payloads are built once above; each attempt only refreshes the timestamp.
        for attempt in 0..=MAX_RETRIES {
            let now_ms = Self::now_millis();
            let result = sqlx::query(&format!(
                "SELECT {}.ack_orchestration_item($1, $2, $3, $4, $5, $6, $7, $8)",
                self.schema_name
            ))
            .bind(lock_token)
            .bind(now_ms)
            .bind(execution_id as i64)
            .bind(&history_delta_json)
            .bind(&worker_items_json)
            .bind(&orchestrator_items_json)
            .bind(&metadata_json)
            .bind(&cancelled_activities_json)
            .execute(&*self.pool)
            .await;

            match result {
                Ok(_) => {
                    let duration_ms = start.elapsed().as_millis() as u64;
                    debug!(
                        target = "duroxide::providers::postgres",
                        operation = "ack_orchestration_item",
                        execution_id = execution_id,
                        history_count = history_delta.len(),
                        worker_items_count = worker_items.len(),
                        orchestrator_items_count = orchestrator_items.len(),
                        cancelled_activities_count = cancelled_activities.len(),
                        duration_ms = duration_ms,
                        attempts = attempt + 1,
                        "Acknowledged orchestration item via stored procedure"
                    );
                    return Ok(());
                }
                Err(e) => {
                    // An invalid lock token is reported as permanent and never retried.
                    if let SqlxError::Database(db_err) = &e {
                        if db_err.message().contains("Invalid lock token") {
                            return Err(ProviderError::permanent(
                                "ack_orchestration_item",
                                "Invalid lock token",
                            ));
                        }
                    } else if e.to_string().contains("Invalid lock token") {
                        return Err(ProviderError::permanent(
                            "ack_orchestration_item",
                            "Invalid lock token",
                        ));
                    }

                    let provider_err = self.sqlx_to_provider_error("ack_orchestration_item", e);
                    if provider_err.is_retryable() && attempt < MAX_RETRIES {
                        warn!(
                            target = "duroxide::providers::postgres",
                            operation = "ack_orchestration_item",
                            attempt = attempt + 1,
                            error = %provider_err,
                            "Retryable error, will retry"
                        );
                        sleep(std::time::Duration::from_millis(
                            RETRY_DELAY_MS * (attempt as u64 + 1),
                        ))
                        .await;
                        continue;
                    }
                    return Err(provider_err);
                }
            }
        }

        // Every loop iteration returns or continues; this only satisfies the type checker.
        Ok(())
    }
1108 #[instrument(skip(self), fields(lock_token = %lock_token), target = "duroxide::providers::postgres")]
1109 async fn abandon_orchestration_item(
1110 &self,
1111 lock_token: &str,
1112 delay: Option<Duration>,
1113 ignore_attempt: bool,
1114 ) -> Result<(), ProviderError> {
1115 let start = std::time::Instant::now();
1116 let now_ms = Self::now_millis();
1117 let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
1118
1119 let instance_id = match sqlx::query_scalar::<_, String>(&format!(
1120 "SELECT {}.abandon_orchestration_item($1, $2, $3, $4)",
1121 self.schema_name
1122 ))
1123 .bind(lock_token)
1124 .bind(now_ms)
1125 .bind(delay_param)
1126 .bind(ignore_attempt)
1127 .fetch_one(&*self.pool)
1128 .await
1129 {
1130 Ok(instance_id) => instance_id,
1131 Err(e) => {
1132 if let SqlxError::Database(db_err) = &e {
1133 if db_err.message().contains("Invalid lock token") {
1134 return Err(ProviderError::permanent(
1135 "abandon_orchestration_item",
1136 "Invalid lock token",
1137 ));
1138 }
1139 } else if e.to_string().contains("Invalid lock token") {
1140 return Err(ProviderError::permanent(
1141 "abandon_orchestration_item",
1142 "Invalid lock token",
1143 ));
1144 }
1145
1146 return Err(self.sqlx_to_provider_error("abandon_orchestration_item", e));
1147 }
1148 };
1149
1150 let duration_ms = start.elapsed().as_millis() as u64;
1151 debug!(
1152 target = "duroxide::providers::postgres",
1153 operation = "abandon_orchestration_item",
1154 instance_id = %instance_id,
1155 delay_ms = delay.map(|d| d.as_millis() as u64),
1156 ignore_attempt = ignore_attempt,
1157 duration_ms = duration_ms,
1158 "Abandoned orchestration item via stored procedure"
1159 );
1160
1161 Ok(())
1162 }
1163
1164 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
1165 async fn read(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
1166 let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
1167 "SELECT out_event_data FROM {}.fetch_history($1)",
1168 self.schema_name
1169 ))
1170 .bind(instance)
1171 .fetch_all(&*self.pool)
1172 .await
1173 .map_err(|e| self.sqlx_to_provider_error("read", e))?;
1174
1175 event_data_rows
1176 .into_iter()
1177 .map(|event_data| {
1178 serde_json::from_str::<Event>(&event_data).map_err(|e| {
1179 ProviderError::permanent("read", format!("Failed to deserialize event: {e}"))
1180 })
1181 })
1182 .collect()
1183 }
1184
1185 #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
1186 async fn append_with_execution(
1187 &self,
1188 instance: &str,
1189 execution_id: u64,
1190 new_events: Vec<Event>,
1191 ) -> Result<(), ProviderError> {
1192 if new_events.is_empty() {
1193 return Ok(());
1194 }
1195
1196 let mut events_payload = Vec::with_capacity(new_events.len());
1197 for event in &new_events {
1198 if event.event_id() == 0 {
1199 error!(
1200 target = "duroxide::providers::postgres",
1201 operation = "append_with_execution",
1202 error_type = "validation_error",
1203 instance_id = %instance,
1204 execution_id = execution_id,
1205 "event_id must be set by runtime"
1206 );
1207 return Err(ProviderError::permanent(
1208 "append_with_execution",
1209 "event_id must be set by runtime",
1210 ));
1211 }
1212
1213 let event_json = serde_json::to_string(event).map_err(|e| {
1214 ProviderError::permanent(
1215 "append_with_execution",
1216 format!("Failed to serialize event: {e}"),
1217 )
1218 })?;
1219
1220 let event_type = format!("{event:?}")
1221 .split('{')
1222 .next()
1223 .unwrap_or("Unknown")
1224 .trim()
1225 .to_string();
1226
1227 events_payload.push(serde_json::json!({
1228 "event_id": event.event_id(),
1229 "event_type": event_type,
1230 "event_data": event_json,
1231 }));
1232 }
1233
1234 let events_json = serde_json::Value::Array(events_payload);
1235
1236 sqlx::query(&format!(
1237 "SELECT {}.append_history($1, $2, $3)",
1238 self.schema_name
1239 ))
1240 .bind(instance)
1241 .bind(execution_id as i64)
1242 .bind(events_json)
1243 .execute(&*self.pool)
1244 .await
1245 .map_err(|e| self.sqlx_to_provider_error("append_with_execution", e))?;
1246
1247 debug!(
1248 target = "duroxide::providers::postgres",
1249 operation = "append_with_execution",
1250 instance_id = %instance,
1251 execution_id = execution_id,
1252 event_count = new_events.len(),
1253 "Appended history events via stored procedure"
1254 );
1255
1256 Ok(())
1257 }
1258
1259 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1260 async fn enqueue_for_worker(&self, item: WorkItem) -> Result<(), ProviderError> {
1261 let work_item = serde_json::to_string(&item).map_err(|e| {
1262 ProviderError::permanent(
1263 "enqueue_worker_work",
1264 format!("Failed to serialize work item: {e}"),
1265 )
1266 })?;
1267
1268 let now_ms = Self::now_millis();
1269
1270 let (instance_id, execution_id, activity_id, session_id, tag) = match &item {
1272 WorkItem::ActivityExecute {
1273 instance,
1274 execution_id,
1275 id,
1276 session_id,
1277 tag,
1278 ..
1279 } => (
1280 Some(instance.clone()),
1281 Some(*execution_id as i64),
1282 Some(*id as i64),
1283 session_id.clone(),
1284 tag.clone(),
1285 ),
1286 _ => (None, None, None, None, None),
1287 };
1288
1289 sqlx::query(&format!(
1290 "SELECT {}.enqueue_worker_work($1, $2, $3, $4, $5, $6, $7)",
1291 self.schema_name
1292 ))
1293 .bind(work_item)
1294 .bind(now_ms)
1295 .bind(&instance_id)
1296 .bind(execution_id)
1297 .bind(activity_id)
1298 .bind(&session_id)
1299 .bind(&tag)
1300 .execute(&*self.pool)
1301 .await
1302 .map_err(|e| {
1303 error!(
1304 target = "duroxide::providers::postgres",
1305 operation = "enqueue_worker_work",
1306 error_type = "database_error",
1307 error = %e,
1308 "Failed to enqueue worker work"
1309 );
1310 self.sqlx_to_provider_error("enqueue_worker_work", e)
1311 })?;
1312
1313 Ok(())
1314 }
1315
1316 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1317 async fn fetch_work_item(
1318 &self,
1319 lock_timeout: Duration,
1320 _poll_timeout: Duration,
1321 session: Option<&SessionFetchConfig>,
1322 tag_filter: &TagFilter,
1323 ) -> Result<Option<(WorkItem, String, u32)>, ProviderError> {
1324 if matches!(tag_filter, TagFilter::None) {
1326 return Ok(None);
1327 }
1328
1329 let start = std::time::Instant::now();
1330
1331 let lock_timeout_ms = lock_timeout.as_millis() as i64;
1333
1334 let (owner_id, session_lock_timeout_ms): (Option<&str>, Option<i64>) = match session {
1336 Some(config) => (
1337 Some(&config.owner_id),
1338 Some(config.lock_timeout.as_millis() as i64),
1339 ),
1340 None => (None, None),
1341 };
1342
1343 let (tag_mode, tag_names) = Self::tag_filter_to_sql(tag_filter);
1345
1346 let row = match sqlx::query_as::<_, (String, String, i32)>(&format!(
1347 "SELECT * FROM {}.fetch_work_item($1, $2, $3, $4, $5, $6)",
1348 self.schema_name
1349 ))
1350 .bind(Self::now_millis())
1351 .bind(lock_timeout_ms)
1352 .bind(owner_id)
1353 .bind(session_lock_timeout_ms)
1354 .bind(&tag_names)
1355 .bind(tag_mode)
1356 .fetch_optional(&*self.pool)
1357 .await
1358 {
1359 Ok(row) => row,
1360 Err(e) => {
1361 return Err(self.sqlx_to_provider_error("fetch_work_item", e));
1362 }
1363 };
1364
1365 let (work_item_json, lock_token, attempt_count) = match row {
1366 Some(row) => row,
1367 None => return Ok(None),
1368 };
1369
1370 let work_item: WorkItem = serde_json::from_str(&work_item_json).map_err(|e| {
1371 ProviderError::permanent(
1372 "fetch_work_item",
1373 format!("Failed to deserialize worker item: {e}"),
1374 )
1375 })?;
1376
1377 let duration_ms = start.elapsed().as_millis() as u64;
1378
1379 let instance_id = match &work_item {
1381 WorkItem::ActivityExecute { instance, .. } => instance.as_str(),
1382 WorkItem::ActivityCompleted { instance, .. } => instance.as_str(),
1383 WorkItem::ActivityFailed { instance, .. } => instance.as_str(),
1384 WorkItem::StartOrchestration { instance, .. } => instance.as_str(),
1385 WorkItem::TimerFired { instance, .. } => instance.as_str(),
1386 WorkItem::ExternalRaised { instance, .. } => instance.as_str(),
1387 WorkItem::CancelInstance { instance, .. } => instance.as_str(),
1388 WorkItem::ContinueAsNew { instance, .. } => instance.as_str(),
1389 WorkItem::SubOrchCompleted {
1390 parent_instance, ..
1391 } => parent_instance.as_str(),
1392 WorkItem::SubOrchFailed {
1393 parent_instance, ..
1394 } => parent_instance.as_str(),
1395 WorkItem::QueueMessage { instance, .. } => instance.as_str(),
1396 };
1397
1398 debug!(
1399 target = "duroxide::providers::postgres",
1400 operation = "fetch_work_item",
1401 instance_id = %instance_id,
1402 attempt_count = attempt_count,
1403 duration_ms = duration_ms,
1404 "Fetched activity work item via stored procedure"
1405 );
1406
1407 Ok(Some((work_item, lock_token, attempt_count as u32)))
1408 }
1409
1410 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1411 async fn ack_work_item(
1412 &self,
1413 token: &str,
1414 completion: Option<WorkItem>,
1415 ) -> Result<(), ProviderError> {
1416 let start = std::time::Instant::now();
1417
1418 let Some(completion) = completion else {
1420 let now_ms = Self::now_millis();
1421 sqlx::query(&format!(
1423 "SELECT {}.ack_worker($1, NULL, NULL, $2)",
1424 self.schema_name
1425 ))
1426 .bind(token)
1427 .bind(now_ms)
1428 .execute(&*self.pool)
1429 .await
1430 .map_err(|e| {
1431 if e.to_string().contains("Worker queue item not found") {
1432 ProviderError::permanent(
1433 "ack_worker",
1434 "Worker queue item not found or already processed",
1435 )
1436 } else {
1437 self.sqlx_to_provider_error("ack_worker", e)
1438 }
1439 })?;
1440
1441 let duration_ms = start.elapsed().as_millis() as u64;
1442 debug!(
1443 target = "duroxide::providers::postgres",
1444 operation = "ack_worker",
1445 token = %token,
1446 duration_ms = duration_ms,
1447 "Acknowledged worker without completion (cancelled)"
1448 );
1449 return Ok(());
1450 };
1451
1452 let instance_id = match &completion {
1454 WorkItem::ActivityCompleted { instance, .. }
1455 | WorkItem::ActivityFailed { instance, .. } => instance,
1456 _ => {
1457 error!(
1458 target = "duroxide::providers::postgres",
1459 operation = "ack_worker",
1460 error_type = "invalid_completion_type",
1461 "Invalid completion work item type"
1462 );
1463 return Err(ProviderError::permanent(
1464 "ack_worker",
1465 "Invalid completion work item type",
1466 ));
1467 }
1468 };
1469
1470 let completion_json = serde_json::to_string(&completion).map_err(|e| {
1471 ProviderError::permanent("ack_worker", format!("Failed to serialize completion: {e}"))
1472 })?;
1473
1474 let now_ms = Self::now_millis();
1475
1476 sqlx::query(&format!(
1478 "SELECT {}.ack_worker($1, $2, $3, $4)",
1479 self.schema_name
1480 ))
1481 .bind(token)
1482 .bind(instance_id)
1483 .bind(completion_json)
1484 .bind(now_ms)
1485 .execute(&*self.pool)
1486 .await
1487 .map_err(|e| {
1488 if e.to_string().contains("Worker queue item not found") {
1489 error!(
1490 target = "duroxide::providers::postgres",
1491 operation = "ack_worker",
1492 error_type = "worker_item_not_found",
1493 token = %token,
1494 "Worker queue item not found or already processed"
1495 );
1496 ProviderError::permanent(
1497 "ack_worker",
1498 "Worker queue item not found or already processed",
1499 )
1500 } else {
1501 self.sqlx_to_provider_error("ack_worker", e)
1502 }
1503 })?;
1504
1505 let duration_ms = start.elapsed().as_millis() as u64;
1506 debug!(
1507 target = "duroxide::providers::postgres",
1508 operation = "ack_worker",
1509 instance_id = %instance_id,
1510 duration_ms = duration_ms,
1511 "Acknowledged worker and enqueued completion"
1512 );
1513
1514 Ok(())
1515 }
1516
1517 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1518 async fn renew_work_item_lock(
1519 &self,
1520 token: &str,
1521 extend_for: Duration,
1522 ) -> Result<(), ProviderError> {
1523 let start = std::time::Instant::now();
1524
1525 let now_ms = Self::now_millis();
1527
1528 let extend_secs = extend_for.as_secs() as i64;
1530
1531 match sqlx::query(&format!(
1532 "SELECT {}.renew_work_item_lock($1, $2, $3)",
1533 self.schema_name
1534 ))
1535 .bind(token)
1536 .bind(now_ms)
1537 .bind(extend_secs)
1538 .execute(&*self.pool)
1539 .await
1540 {
1541 Ok(_) => {
1542 let duration_ms = start.elapsed().as_millis() as u64;
1543 debug!(
1544 target = "duroxide::providers::postgres",
1545 operation = "renew_work_item_lock",
1546 token = %token,
1547 extend_for_secs = extend_secs,
1548 duration_ms = duration_ms,
1549 "Work item lock renewed successfully"
1550 );
1551 Ok(())
1552 }
1553 Err(e) => {
1554 if let SqlxError::Database(db_err) = &e {
1555 if db_err.message().contains("Lock token invalid") {
1556 return Err(ProviderError::permanent(
1557 "renew_work_item_lock",
1558 "Lock token invalid, expired, or already acked",
1559 ));
1560 }
1561 } else if e.to_string().contains("Lock token invalid") {
1562 return Err(ProviderError::permanent(
1563 "renew_work_item_lock",
1564 "Lock token invalid, expired, or already acked",
1565 ));
1566 }
1567
1568 Err(self.sqlx_to_provider_error("renew_work_item_lock", e))
1569 }
1570 }
1571 }
1572
1573 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1574 async fn abandon_work_item(
1575 &self,
1576 token: &str,
1577 delay: Option<Duration>,
1578 ignore_attempt: bool,
1579 ) -> Result<(), ProviderError> {
1580 let start = std::time::Instant::now();
1581 let now_ms = Self::now_millis();
1582 let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
1583
1584 match sqlx::query(&format!(
1585 "SELECT {}.abandon_work_item($1, $2, $3, $4)",
1586 self.schema_name
1587 ))
1588 .bind(token)
1589 .bind(now_ms)
1590 .bind(delay_param)
1591 .bind(ignore_attempt)
1592 .execute(&*self.pool)
1593 .await
1594 {
1595 Ok(_) => {
1596 let duration_ms = start.elapsed().as_millis() as u64;
1597 debug!(
1598 target = "duroxide::providers::postgres",
1599 operation = "abandon_work_item",
1600 token = %token,
1601 delay_ms = delay.map(|d| d.as_millis() as u64),
1602 ignore_attempt = ignore_attempt,
1603 duration_ms = duration_ms,
1604 "Abandoned work item via stored procedure"
1605 );
1606 Ok(())
1607 }
1608 Err(e) => {
1609 if let SqlxError::Database(db_err) = &e {
1610 if db_err.message().contains("Invalid lock token")
1611 || db_err.message().contains("already acked")
1612 {
1613 return Err(ProviderError::permanent(
1614 "abandon_work_item",
1615 "Invalid lock token or already acked",
1616 ));
1617 }
1618 } else if e.to_string().contains("Invalid lock token")
1619 || e.to_string().contains("already acked")
1620 {
1621 return Err(ProviderError::permanent(
1622 "abandon_work_item",
1623 "Invalid lock token or already acked",
1624 ));
1625 }
1626
1627 Err(self.sqlx_to_provider_error("abandon_work_item", e))
1628 }
1629 }
1630 }
1631
1632 #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
1633 async fn renew_orchestration_item_lock(
1634 &self,
1635 token: &str,
1636 extend_for: Duration,
1637 ) -> Result<(), ProviderError> {
1638 let start = std::time::Instant::now();
1639
1640 let now_ms = Self::now_millis();
1642
1643 let extend_secs = extend_for.as_secs() as i64;
1645
1646 match sqlx::query(&format!(
1647 "SELECT {}.renew_orchestration_item_lock($1, $2, $3)",
1648 self.schema_name
1649 ))
1650 .bind(token)
1651 .bind(now_ms)
1652 .bind(extend_secs)
1653 .execute(&*self.pool)
1654 .await
1655 {
1656 Ok(_) => {
1657 let duration_ms = start.elapsed().as_millis() as u64;
1658 debug!(
1659 target = "duroxide::providers::postgres",
1660 operation = "renew_orchestration_item_lock",
1661 token = %token,
1662 extend_for_secs = extend_secs,
1663 duration_ms = duration_ms,
1664 "Orchestration item lock renewed successfully"
1665 );
1666 Ok(())
1667 }
1668 Err(e) => {
1669 if let SqlxError::Database(db_err) = &e {
1670 if db_err.message().contains("Lock token invalid")
1671 || db_err.message().contains("expired")
1672 || db_err.message().contains("already released")
1673 {
1674 return Err(ProviderError::permanent(
1675 "renew_orchestration_item_lock",
1676 "Lock token invalid, expired, or already released",
1677 ));
1678 }
1679 } else if e.to_string().contains("Lock token invalid")
1680 || e.to_string().contains("expired")
1681 || e.to_string().contains("already released")
1682 {
1683 return Err(ProviderError::permanent(
1684 "renew_orchestration_item_lock",
1685 "Lock token invalid, expired, or already released",
1686 ));
1687 }
1688
1689 Err(self.sqlx_to_provider_error("renew_orchestration_item_lock", e))
1690 }
1691 }
1692 }
1693
1694 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1695 async fn enqueue_for_orchestrator(
1696 &self,
1697 item: WorkItem,
1698 delay: Option<Duration>,
1699 ) -> Result<(), ProviderError> {
1700 let work_item = serde_json::to_string(&item).map_err(|e| {
1701 ProviderError::permanent(
1702 "enqueue_orchestrator_work",
1703 format!("Failed to serialize work item: {e}"),
1704 )
1705 })?;
1706
1707 let instance_id = match &item {
1709 WorkItem::StartOrchestration { instance, .. }
1710 | WorkItem::ActivityCompleted { instance, .. }
1711 | WorkItem::ActivityFailed { instance, .. }
1712 | WorkItem::TimerFired { instance, .. }
1713 | WorkItem::ExternalRaised { instance, .. }
1714 | WorkItem::CancelInstance { instance, .. }
1715 | WorkItem::ContinueAsNew { instance, .. }
1716 | WorkItem::QueueMessage { instance, .. } => instance,
1717 WorkItem::SubOrchCompleted {
1718 parent_instance, ..
1719 }
1720 | WorkItem::SubOrchFailed {
1721 parent_instance, ..
1722 } => parent_instance,
1723 WorkItem::ActivityExecute { .. } => {
1724 return Err(ProviderError::permanent(
1725 "enqueue_orchestrator_work",
1726 "ActivityExecute should go to worker queue, not orchestrator queue",
1727 ));
1728 }
1729 };
1730
1731 let now_ms = Self::now_millis();
1733
1734 let visible_at_ms = if let WorkItem::TimerFired { fire_at_ms, .. } = &item {
1735 if *fire_at_ms > 0 {
1736 if let Some(delay) = delay {
1738 std::cmp::max(*fire_at_ms, now_ms as u64 + delay.as_millis() as u64)
1739 } else {
1740 *fire_at_ms
1741 }
1742 } else {
1743 delay
1745 .map(|d| now_ms as u64 + d.as_millis() as u64)
1746 .unwrap_or(now_ms as u64)
1747 }
1748 } else {
1749 delay
1751 .map(|d| now_ms as u64 + d.as_millis() as u64)
1752 .unwrap_or(now_ms as u64)
1753 };
1754
1755 let visible_at = Utc
1756 .timestamp_millis_opt(visible_at_ms as i64)
1757 .single()
1758 .ok_or_else(|| {
1759 ProviderError::permanent(
1760 "enqueue_orchestrator_work",
1761 "Invalid visible_at timestamp",
1762 )
1763 })?;
1764
1765 sqlx::query(&format!(
1770 "SELECT {}.enqueue_orchestrator_work($1, $2, $3, $4, $5, $6)",
1771 self.schema_name
1772 ))
1773 .bind(instance_id)
1774 .bind(&work_item)
1775 .bind(visible_at)
1776 .bind::<Option<String>>(None) .bind::<Option<String>>(None) .bind::<Option<i64>>(None) .execute(&*self.pool)
1780 .await
1781 .map_err(|e| {
1782 error!(
1783 target = "duroxide::providers::postgres",
1784 operation = "enqueue_orchestrator_work",
1785 error_type = "database_error",
1786 error = %e,
1787 instance_id = %instance_id,
1788 "Failed to enqueue orchestrator work"
1789 );
1790 self.sqlx_to_provider_error("enqueue_orchestrator_work", e)
1791 })?;
1792
1793 debug!(
1794 target = "duroxide::providers::postgres",
1795 operation = "enqueue_orchestrator_work",
1796 instance_id = %instance_id,
1797 delay_ms = delay.map(|d| d.as_millis() as u64),
1798 "Enqueued orchestrator work"
1799 );
1800
1801 Ok(())
1802 }
1803
1804 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
1805 async fn read_with_execution(
1806 &self,
1807 instance: &str,
1808 execution_id: u64,
1809 ) -> Result<Vec<Event>, ProviderError> {
1810 let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
1811 "SELECT event_data FROM {} WHERE instance_id = $1 AND execution_id = $2 ORDER BY event_id",
1812 self.table_name("history")
1813 ))
1814 .bind(instance)
1815 .bind(execution_id as i64)
1816 .fetch_all(&*self.pool)
1817 .await
1818 .map_err(|e| self.sqlx_to_provider_error("read_with_execution", e))?;
1819
1820 event_data_rows
1821 .into_iter()
1822 .map(|event_data| {
1823 serde_json::from_str::<Event>(&event_data).map_err(|e| {
1824 ProviderError::permanent(
1825 "read_with_execution",
1826 format!("Failed to deserialize event: {e}"),
1827 )
1828 })
1829 })
1830 .collect()
1831 }
1832
1833 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1834 async fn renew_session_lock(
1835 &self,
1836 owner_ids: &[&str],
1837 extend_for: Duration,
1838 idle_timeout: Duration,
1839 ) -> Result<usize, ProviderError> {
1840 if owner_ids.is_empty() {
1841 return Ok(0);
1842 }
1843
1844 let now_ms = Self::now_millis();
1845 let extend_ms = extend_for.as_millis() as i64;
1846 let idle_timeout_ms = idle_timeout.as_millis() as i64;
1847 let owner_ids_vec: Vec<&str> = owner_ids.to_vec();
1848
1849 let result = sqlx::query_scalar::<_, i64>(&format!(
1850 "SELECT {}.renew_session_lock($1, $2, $3, $4)",
1851 self.schema_name
1852 ))
1853 .bind(&owner_ids_vec)
1854 .bind(now_ms)
1855 .bind(extend_ms)
1856 .bind(idle_timeout_ms)
1857 .fetch_one(&*self.pool)
1858 .await
1859 .map_err(|e| self.sqlx_to_provider_error("renew_session_lock", e))?;
1860
1861 debug!(
1862 target = "duroxide::providers::postgres",
1863 operation = "renew_session_lock",
1864 owner_count = owner_ids.len(),
1865 sessions_renewed = result,
1866 "Session locks renewed"
1867 );
1868
1869 Ok(result as usize)
1870 }
1871
1872 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1873 async fn cleanup_orphaned_sessions(
1874 &self,
1875 _idle_timeout: Duration,
1876 ) -> Result<usize, ProviderError> {
1877 let now_ms = Self::now_millis();
1878
1879 let result = sqlx::query_scalar::<_, i64>(&format!(
1880 "SELECT {}.cleanup_orphaned_sessions($1)",
1881 self.schema_name
1882 ))
1883 .bind(now_ms)
1884 .fetch_one(&*self.pool)
1885 .await
1886 .map_err(|e| self.sqlx_to_provider_error("cleanup_orphaned_sessions", e))?;
1887
1888 debug!(
1889 target = "duroxide::providers::postgres",
1890 operation = "cleanup_orphaned_sessions",
1891 sessions_cleaned = result,
1892 "Orphaned sessions cleaned up"
1893 );
1894
1895 Ok(result as usize)
1896 }
1897
1898 fn as_management_capability(&self) -> Option<&dyn ProviderAdmin> {
1899 Some(self)
1900 }
1901
1902 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
1903 async fn get_custom_status(
1904 &self,
1905 instance: &str,
1906 last_seen_version: u64,
1907 ) -> Result<Option<(Option<String>, u64)>, ProviderError> {
1908 let row = sqlx::query_as::<_, (Option<String>, i64)>(&format!(
1909 "SELECT * FROM {}.get_custom_status($1, $2)",
1910 self.schema_name
1911 ))
1912 .bind(instance)
1913 .bind(last_seen_version as i64)
1914 .fetch_optional(&*self.pool)
1915 .await
1916 .map_err(|e| self.sqlx_to_provider_error("get_custom_status", e))?;
1917
1918 match row {
1919 Some((custom_status, version)) => Ok(Some((custom_status, version as u64))),
1920 None => Ok(None),
1921 }
1922 }
1923
1924 async fn get_kv_value(
1925 &self,
1926 instance_id: &str,
1927 key: &str,
1928 ) -> Result<Option<String>, ProviderError> {
1929 let row: Option<(Option<String>, bool)> = sqlx::query_as(&format!(
1930 "SELECT * FROM {}.get_kv_value($1, $2)",
1931 self.schema_name
1932 ))
1933 .bind(instance_id)
1934 .bind(key)
1935 .fetch_optional(&*self.pool)
1936 .await
1937 .map_err(|e| self.sqlx_to_provider_error("get_kv_value", e))?;
1938
1939 Ok(row.and_then(|(value, found)| if found { value } else { None }))
1940 }
1941
1942 async fn get_kv_all_values(
1943 &self,
1944 instance_id: &str,
1945 ) -> Result<std::collections::HashMap<String, String>, ProviderError> {
1946 let rows: Vec<(String, String)> = sqlx::query_as(&format!(
1947 "SELECT * FROM {}.get_kv_all_values($1)",
1948 self.schema_name
1949 ))
1950 .bind(instance_id)
1951 .fetch_all(&*self.pool)
1952 .await
1953 .map_err(|e| self.sqlx_to_provider_error("get_kv_all_values", e))?;
1954
1955 Ok(rows.into_iter().collect())
1956 }
1957
1958 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
1959 async fn get_instance_stats(
1960 &self,
1961 instance: &str,
1962 ) -> Result<Option<SystemStats>, ProviderError> {
1963 let row: Option<(bool, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!(
1964 "SELECT * FROM {}.get_instance_stats($1)",
1965 self.schema_name
1966 ))
1967 .bind(instance)
1968 .fetch_optional(&*self.pool)
1969 .await
1970 .map_err(|e| self.sqlx_to_provider_error("get_instance_stats", e))?;
1971
1972 match row {
1973 Some((
1974 true,
1975 history_event_count,
1976 history_size_bytes,
1977 queue_pending_count,
1978 kv_user_key_count,
1979 kv_total_value_bytes,
1980 )) => Ok(Some(SystemStats {
1981 history_event_count: history_event_count as u64,
1982 history_size_bytes: history_size_bytes as u64,
1983 queue_pending_count: queue_pending_count as u64,
1984 kv_user_key_count: kv_user_key_count as u64,
1985 kv_total_value_bytes: kv_total_value_bytes as u64,
1986 })),
1987 _ => Ok(None),
1988 }
1989 }
1990}
1991
1992#[async_trait::async_trait]
1993impl ProviderAdmin for PostgresProvider {
1994 #[instrument(skip(self), target = "duroxide::providers::postgres")]
1995 async fn list_instances(&self) -> Result<Vec<String>, ProviderError> {
1996 sqlx::query_scalar(&format!(
1997 "SELECT instance_id FROM {}.list_instances()",
1998 self.schema_name
1999 ))
2000 .fetch_all(&*self.pool)
2001 .await
2002 .map_err(|e| self.sqlx_to_provider_error("list_instances", e))
2003 }
2004
2005 #[instrument(skip(self), fields(status = %status), target = "duroxide::providers::postgres")]
2006 async fn list_instances_by_status(&self, status: &str) -> Result<Vec<String>, ProviderError> {
2007 sqlx::query_scalar(&format!(
2008 "SELECT instance_id FROM {}.list_instances_by_status($1)",
2009 self.schema_name
2010 ))
2011 .bind(status)
2012 .fetch_all(&*self.pool)
2013 .await
2014 .map_err(|e| self.sqlx_to_provider_error("list_instances_by_status", e))
2015 }
2016
2017 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2018 async fn list_executions(&self, instance: &str) -> Result<Vec<u64>, ProviderError> {
2019 let execution_ids: Vec<i64> = sqlx::query_scalar(&format!(
2020 "SELECT execution_id FROM {}.list_executions($1)",
2021 self.schema_name
2022 ))
2023 .bind(instance)
2024 .fetch_all(&*self.pool)
2025 .await
2026 .map_err(|e| self.sqlx_to_provider_error("list_executions", e))?;
2027
2028 Ok(execution_ids.into_iter().map(|id| id as u64).collect())
2029 }
2030
2031 #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
2032 async fn read_history_with_execution_id(
2033 &self,
2034 instance: &str,
2035 execution_id: u64,
2036 ) -> Result<Vec<Event>, ProviderError> {
2037 let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
2038 "SELECT out_event_data FROM {}.fetch_history_with_execution($1, $2)",
2039 self.schema_name
2040 ))
2041 .bind(instance)
2042 .bind(execution_id as i64)
2043 .fetch_all(&*self.pool)
2044 .await
2045 .map_err(|e| self.sqlx_to_provider_error("read_execution", e))?;
2046
2047 event_data_rows
2048 .into_iter()
2049 .map(|event_data| {
2050 serde_json::from_str::<Event>(&event_data).map_err(|e| {
2051 ProviderError::permanent(
2052 "read_history_with_execution_id",
2053 format!("Failed to deserialize event: {e}"),
2054 )
2055 })
2056 })
2057 .collect()
2058 }
2059
2060 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2061 async fn read_history(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
2062 let execution_id = self.latest_execution_id(instance).await?;
2063 self.read_history_with_execution_id(instance, execution_id)
2064 .await
2065 }
2066
2067 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2068 async fn latest_execution_id(&self, instance: &str) -> Result<u64, ProviderError> {
2069 sqlx::query_scalar(&format!(
2070 "SELECT {}.latest_execution_id($1)",
2071 self.schema_name
2072 ))
2073 .bind(instance)
2074 .fetch_optional(&*self.pool)
2075 .await
2076 .map_err(|e| self.sqlx_to_provider_error("latest_execution_id", e))?
2077 .map(|id: i64| id as u64)
2078 .ok_or_else(|| ProviderError::permanent("latest_execution_id", "Instance not found"))
2079 }
2080
2081 #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
2082 async fn get_instance_info(&self, instance: &str) -> Result<InstanceInfo, ProviderError> {
2083 let row: Option<(
2084 String,
2085 String,
2086 String,
2087 i64,
2088 chrono::DateTime<Utc>,
2089 Option<chrono::DateTime<Utc>>,
2090 Option<String>,
2091 Option<String>,
2092 Option<String>,
2093 )> = sqlx::query_as(&format!(
2094 "SELECT * FROM {}.get_instance_info($1)",
2095 self.schema_name
2096 ))
2097 .bind(instance)
2098 .fetch_optional(&*self.pool)
2099 .await
2100 .map_err(|e| self.sqlx_to_provider_error("get_instance_info", e))?;
2101
2102 let (
2103 instance_id,
2104 orchestration_name,
2105 orchestration_version,
2106 current_execution_id,
2107 created_at,
2108 updated_at,
2109 status,
2110 output,
2111 parent_instance_id,
2112 ) =
2113 row.ok_or_else(|| ProviderError::permanent("get_instance_info", "Instance not found"))?;
2114
2115 Ok(InstanceInfo {
2116 instance_id,
2117 orchestration_name,
2118 orchestration_version,
2119 current_execution_id: current_execution_id as u64,
2120 status: status.unwrap_or_else(|| "Running".to_string()),
2121 output,
2122 created_at: created_at.timestamp_millis() as u64,
2123 updated_at: updated_at
2124 .map(|dt| dt.timestamp_millis() as u64)
2125 .unwrap_or(created_at.timestamp_millis() as u64),
2126 parent_instance_id,
2127 })
2128 }
2129
2130 #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
2131 async fn get_execution_info(
2132 &self,
2133 instance: &str,
2134 execution_id: u64,
2135 ) -> Result<ExecutionInfo, ProviderError> {
2136 let row: Option<(
2137 i64,
2138 String,
2139 Option<String>,
2140 chrono::DateTime<Utc>,
2141 Option<chrono::DateTime<Utc>>,
2142 i64,
2143 )> = sqlx::query_as(&format!(
2144 "SELECT * FROM {}.get_execution_info($1, $2)",
2145 self.schema_name
2146 ))
2147 .bind(instance)
2148 .bind(execution_id as i64)
2149 .fetch_optional(&*self.pool)
2150 .await
2151 .map_err(|e| self.sqlx_to_provider_error("get_execution_info", e))?;
2152
2153 let (exec_id, status, output, started_at, completed_at, event_count) = row
2154 .ok_or_else(|| ProviderError::permanent("get_execution_info", "Execution not found"))?;
2155
2156 Ok(ExecutionInfo {
2157 execution_id: exec_id as u64,
2158 status,
2159 output,
2160 started_at: started_at.timestamp_millis() as u64,
2161 completed_at: completed_at.map(|dt| dt.timestamp_millis() as u64),
2162 event_count: event_count as usize,
2163 })
2164 }
2165
2166 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2167 async fn get_system_metrics(&self) -> Result<SystemMetrics, ProviderError> {
2168 let row: Option<(i64, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!(
2169 "SELECT * FROM {}.get_system_metrics()",
2170 self.schema_name
2171 ))
2172 .fetch_optional(&*self.pool)
2173 .await
2174 .map_err(|e| self.sqlx_to_provider_error("get_system_metrics", e))?;
2175
2176 let (
2177 total_instances,
2178 total_executions,
2179 running_instances,
2180 completed_instances,
2181 failed_instances,
2182 total_events,
2183 ) = row.ok_or_else(|| {
2184 ProviderError::permanent("get_system_metrics", "Failed to get system metrics")
2185 })?;
2186
2187 Ok(SystemMetrics {
2188 total_instances: total_instances as u64,
2189 total_executions: total_executions as u64,
2190 running_instances: running_instances as u64,
2191 completed_instances: completed_instances as u64,
2192 failed_instances: failed_instances as u64,
2193 total_events: total_events as u64,
2194 })
2195 }
2196
2197 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2198 async fn get_queue_depths(&self) -> Result<QueueDepths, ProviderError> {
2199 let now_ms = Self::now_millis();
2200
2201 let row: Option<(i64, i64)> = sqlx::query_as(&format!(
2202 "SELECT * FROM {}.get_queue_depths($1)",
2203 self.schema_name
2204 ))
2205 .bind(now_ms)
2206 .fetch_optional(&*self.pool)
2207 .await
2208 .map_err(|e| self.sqlx_to_provider_error("get_queue_depths", e))?;
2209
2210 let (orchestrator_queue, worker_queue) = row.ok_or_else(|| {
2211 ProviderError::permanent("get_queue_depths", "Failed to get queue depths")
2212 })?;
2213
2214 Ok(QueueDepths {
2215 orchestrator_queue: orchestrator_queue as usize,
2216 worker_queue: worker_queue as usize,
2217 timer_queue: 0, })
2219 }
2220
2221 #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
2224 async fn list_children(&self, instance_id: &str) -> Result<Vec<String>, ProviderError> {
2225 sqlx::query_scalar(&format!(
2226 "SELECT child_instance_id FROM {}.list_children($1)",
2227 self.schema_name
2228 ))
2229 .bind(instance_id)
2230 .fetch_all(&*self.pool)
2231 .await
2232 .map_err(|e| self.sqlx_to_provider_error("list_children", e))
2233 }
2234
2235 #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
2236 async fn get_parent_id(&self, instance_id: &str) -> Result<Option<String>, ProviderError> {
2237 let result: Result<Option<String>, _> =
2240 sqlx::query_scalar(&format!("SELECT {}.get_parent_id($1)", self.schema_name))
2241 .bind(instance_id)
2242 .fetch_one(&*self.pool)
2243 .await;
2244
2245 match result {
2246 Ok(parent_id) => Ok(parent_id),
2247 Err(e) => {
2248 let err_str = e.to_string();
2249 if err_str.contains("Instance not found") {
2250 Err(ProviderError::permanent(
2251 "get_parent_id",
2252 format!("Instance not found: {}", instance_id),
2253 ))
2254 } else {
2255 Err(self.sqlx_to_provider_error("get_parent_id", e))
2256 }
2257 }
2258 }
2259 }
2260
2261 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2264 async fn delete_instances_atomic(
2265 &self,
2266 ids: &[String],
2267 force: bool,
2268 ) -> Result<DeleteInstanceResult, ProviderError> {
2269 if ids.is_empty() {
2270 return Ok(DeleteInstanceResult::default());
2271 }
2272
2273 let row: Option<(i64, i64, i64, i64)> = sqlx::query_as(&format!(
2274 "SELECT * FROM {}.delete_instances_atomic($1, $2)",
2275 self.schema_name
2276 ))
2277 .bind(ids)
2278 .bind(force)
2279 .fetch_optional(&*self.pool)
2280 .await
2281 .map_err(|e| {
2282 let err_str = e.to_string();
2283 if err_str.contains("is Running") {
2284 ProviderError::permanent("delete_instances_atomic", err_str)
2285 } else if err_str.contains("Orphan detected") {
2286 ProviderError::permanent("delete_instances_atomic", err_str)
2287 } else {
2288 self.sqlx_to_provider_error("delete_instances_atomic", e)
2289 }
2290 })?;
2291
2292 let (instances_deleted, executions_deleted, events_deleted, queue_messages_deleted) =
2293 row.unwrap_or((0, 0, 0, 0));
2294
2295 debug!(
2296 target = "duroxide::providers::postgres",
2297 operation = "delete_instances_atomic",
2298 instances_deleted = instances_deleted,
2299 executions_deleted = executions_deleted,
2300 events_deleted = events_deleted,
2301 queue_messages_deleted = queue_messages_deleted,
2302 "Deleted instances atomically"
2303 );
2304
2305 Ok(DeleteInstanceResult {
2306 instances_deleted: instances_deleted as u64,
2307 executions_deleted: executions_deleted as u64,
2308 events_deleted: events_deleted as u64,
2309 queue_messages_deleted: queue_messages_deleted as u64,
2310 })
2311 }
2312
2313 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2314 async fn delete_instance_bulk(
2315 &self,
2316 filter: InstanceFilter,
2317 ) -> Result<DeleteInstanceResult, ProviderError> {
2318 let mut sql = format!(
2320 r#"
2321 SELECT i.instance_id
2322 FROM {}.instances i
2323 LEFT JOIN {}.executions e ON i.instance_id = e.instance_id
2324 AND i.current_execution_id = e.execution_id
2325 WHERE i.parent_instance_id IS NULL
2326 AND e.status IN ('Completed', 'Failed', 'ContinuedAsNew')
2327 "#,
2328 self.schema_name, self.schema_name
2329 );
2330
2331 if let Some(ref ids) = filter.instance_ids {
2333 if ids.is_empty() {
2334 return Ok(DeleteInstanceResult::default());
2335 }
2336 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
2337 sql.push_str(&format!(
2338 " AND i.instance_id IN ({})",
2339 placeholders.join(", ")
2340 ));
2341 }
2342
2343 if filter.completed_before.is_some() {
2345 let param_num = filter
2346 .instance_ids
2347 .as_ref()
2348 .map(|ids| ids.len())
2349 .unwrap_or(0)
2350 + 1;
2351 sql.push_str(&format!(
2352 " AND e.completed_at < TO_TIMESTAMP(${} / 1000.0)",
2353 param_num
2354 ));
2355 }
2356
2357 let limit = filter.limit.unwrap_or(1000);
2359 let limit_param_num = filter
2360 .instance_ids
2361 .as_ref()
2362 .map(|ids| ids.len())
2363 .unwrap_or(0)
2364 + if filter.completed_before.is_some() {
2365 1
2366 } else {
2367 0
2368 }
2369 + 1;
2370 sql.push_str(&format!(" LIMIT ${}", limit_param_num));
2371
2372 let mut query = sqlx::query_scalar::<_, String>(&sql);
2374 if let Some(ref ids) = filter.instance_ids {
2375 for id in ids {
2376 query = query.bind(id);
2377 }
2378 }
2379 if let Some(completed_before) = filter.completed_before {
2380 query = query.bind(completed_before as i64);
2381 }
2382 query = query.bind(limit as i64);
2383
2384 let instance_ids: Vec<String> = query
2385 .fetch_all(&*self.pool)
2386 .await
2387 .map_err(|e| self.sqlx_to_provider_error("delete_instance_bulk", e))?;
2388
2389 if instance_ids.is_empty() {
2390 return Ok(DeleteInstanceResult::default());
2391 }
2392
2393 let mut result = DeleteInstanceResult::default();
2395
2396 for instance_id in &instance_ids {
2397 let tree = self.get_instance_tree(instance_id).await?;
2399
2400 let delete_result = self.delete_instances_atomic(&tree.all_ids, true).await?;
2402 result.instances_deleted += delete_result.instances_deleted;
2403 result.executions_deleted += delete_result.executions_deleted;
2404 result.events_deleted += delete_result.events_deleted;
2405 result.queue_messages_deleted += delete_result.queue_messages_deleted;
2406 }
2407
2408 debug!(
2409 target = "duroxide::providers::postgres",
2410 operation = "delete_instance_bulk",
2411 instances_deleted = result.instances_deleted,
2412 executions_deleted = result.executions_deleted,
2413 events_deleted = result.events_deleted,
2414 queue_messages_deleted = result.queue_messages_deleted,
2415 "Bulk deleted instances"
2416 );
2417
2418 Ok(result)
2419 }
2420
2421 #[instrument(skip(self), fields(instance = %instance_id), target = "duroxide::providers::postgres")]
2424 async fn prune_executions(
2425 &self,
2426 instance_id: &str,
2427 options: PruneOptions,
2428 ) -> Result<PruneResult, ProviderError> {
2429 let keep_last: Option<i32> = options.keep_last.map(|v| v as i32);
2430 let completed_before_ms: Option<i64> = options.completed_before.map(|v| v as i64);
2431
2432 let row: Option<(i64, i64, i64)> = sqlx::query_as(&format!(
2433 "SELECT * FROM {}.prune_executions($1, $2, $3)",
2434 self.schema_name
2435 ))
2436 .bind(instance_id)
2437 .bind(keep_last)
2438 .bind(completed_before_ms)
2439 .fetch_optional(&*self.pool)
2440 .await
2441 .map_err(|e| self.sqlx_to_provider_error("prune_executions", e))?;
2442
2443 let (instances_processed, executions_deleted, events_deleted) = row.unwrap_or((0, 0, 0));
2444
2445 debug!(
2446 target = "duroxide::providers::postgres",
2447 operation = "prune_executions",
2448 instance_id = %instance_id,
2449 instances_processed = instances_processed,
2450 executions_deleted = executions_deleted,
2451 events_deleted = events_deleted,
2452 "Pruned executions"
2453 );
2454
2455 Ok(PruneResult {
2456 instances_processed: instances_processed as u64,
2457 executions_deleted: executions_deleted as u64,
2458 events_deleted: events_deleted as u64,
2459 })
2460 }
2461
2462 #[instrument(skip(self), target = "duroxide::providers::postgres")]
2463 async fn prune_executions_bulk(
2464 &self,
2465 filter: InstanceFilter,
2466 options: PruneOptions,
2467 ) -> Result<PruneResult, ProviderError> {
2468 let mut sql = format!(
2473 r#"
2474 SELECT i.instance_id
2475 FROM {}.instances i
2476 LEFT JOIN {}.executions e ON i.instance_id = e.instance_id
2477 AND i.current_execution_id = e.execution_id
2478 WHERE 1=1
2479 "#,
2480 self.schema_name, self.schema_name
2481 );
2482
2483 if let Some(ref ids) = filter.instance_ids {
2485 if ids.is_empty() {
2486 return Ok(PruneResult::default());
2487 }
2488 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
2489 sql.push_str(&format!(
2490 " AND i.instance_id IN ({})",
2491 placeholders.join(", ")
2492 ));
2493 }
2494
2495 if filter.completed_before.is_some() {
2497 let param_num = filter
2498 .instance_ids
2499 .as_ref()
2500 .map(|ids| ids.len())
2501 .unwrap_or(0)
2502 + 1;
2503 sql.push_str(&format!(
2504 " AND e.completed_at < TO_TIMESTAMP(${} / 1000.0)",
2505 param_num
2506 ));
2507 }
2508
2509 let limit = filter.limit.unwrap_or(1000);
2511 let limit_param_num = filter
2512 .instance_ids
2513 .as_ref()
2514 .map(|ids| ids.len())
2515 .unwrap_or(0)
2516 + if filter.completed_before.is_some() {
2517 1
2518 } else {
2519 0
2520 }
2521 + 1;
2522 sql.push_str(&format!(" LIMIT ${}", limit_param_num));
2523
2524 let mut query = sqlx::query_scalar::<_, String>(&sql);
2526 if let Some(ref ids) = filter.instance_ids {
2527 for id in ids {
2528 query = query.bind(id);
2529 }
2530 }
2531 if let Some(completed_before) = filter.completed_before {
2532 query = query.bind(completed_before as i64);
2533 }
2534 query = query.bind(limit as i64);
2535
2536 let instance_ids: Vec<String> = query
2537 .fetch_all(&*self.pool)
2538 .await
2539 .map_err(|e| self.sqlx_to_provider_error("prune_executions_bulk", e))?;
2540
2541 let mut result = PruneResult::default();
2543
2544 for instance_id in &instance_ids {
2545 let single_result = self.prune_executions(instance_id, options.clone()).await?;
2546 result.instances_processed += single_result.instances_processed;
2547 result.executions_deleted += single_result.executions_deleted;
2548 result.events_deleted += single_result.events_deleted;
2549 }
2550
2551 debug!(
2552 target = "duroxide::providers::postgres",
2553 operation = "prune_executions_bulk",
2554 instances_processed = result.instances_processed,
2555 executions_deleted = result.executions_deleted,
2556 events_deleted = result.events_deleted,
2557 "Bulk pruned executions"
2558 );
2559
2560 Ok(result)
2561 }
2562}
2563
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entra::test_support::{token, RecordingFakeTokenSource};

    /// Suffix the truncation tests expect on a shortened panic message.
    const TRUNCATION_MARKER: &str = "…[truncated]";

    /// Host, port, database, user and TLS mode passed to
    /// `build_entra_connect_options` are readable back from the result.
    #[test]
    fn build_entra_connect_options_uses_verify_full() {
        let options =
            build_entra_connect_options("h.example.com", 5432, "db", "u", PgSslMode::VerifyFull);
        assert!(matches!(options.get_ssl_mode(), PgSslMode::VerifyFull));
        assert_eq!(options.get_host(), "h.example.com");
        assert_eq!(options.get_port(), 5432);
        assert_eq!(options.get_database(), Some("db"));
        assert_eq!(options.get_username(), "u");
    }

    /// Expiry 24 hours away with a 5-minute ceiling: the result is the ceiling.
    #[test]
    fn compute_next_refresh_sleep_is_capped_by_ceiling() {
        let ceiling = Duration::from_secs(5 * 60);
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(24 * 3600);
        let delay = compute_next_refresh_sleep(ceiling, expires, now);
        assert_eq!(delay, ceiling);
    }

    /// Expiry 6 minutes away with a 1-hour ceiling: the result is at most 60
    /// seconds and not below `ENTRA_REFRESH_MIN_INTERVAL`.
    #[test]
    fn compute_next_refresh_sleep_drives_from_expiry() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(6 * 60);
        let delay = compute_next_refresh_sleep(Duration::from_secs(3600), expires, now);
        assert!(delay <= Duration::from_secs(60), "got {delay:?}");
        assert!(delay >= ENTRA_REFRESH_MIN_INTERVAL, "got {delay:?}");
    }

    /// Expiry only 60 seconds away: the result is exactly the minimum interval.
    #[test]
    fn compute_next_refresh_sleep_floors_at_min_interval() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(60);
        let delay = compute_next_refresh_sleep(Duration::from_secs(3600), expires, now);
        assert_eq!(delay, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// Three consecutive fetches from a scripted fake yield pairwise-adjacent
    /// distinct secrets and are counted by the fake.
    #[tokio::test]
    async fn recording_token_source_returns_distinct_tokens_in_script_order() {
        let script: Vec<_> = [
            "token-A", "token-B", "token-C", "token-D", "token-E", "token-F",
        ]
        .into_iter()
        .map(|secret| token(secret, 3600))
        .collect();
        let fake = RecordingFakeTokenSource::with_tokens(script);
        let token_source: Arc<dyn TokenSource> = fake.clone();

        // The lazy pool and expiry below are constructed but never exercised;
        // they are only bound and then explicitly ignored.
        let connect_options =
            build_entra_connect_options("127.0.0.1", 5432, "db", "u", PgSslMode::VerifyFull);
        let lazy_pool: Arc<PgPool> = Arc::new(
            PgPoolOptions::new()
                .max_connections(1)
                .connect_lazy_with(connect_options.clone().password("placeholder")),
        );
        let initial_expires_at = SystemTime::now() + Duration::from_secs(3600);

        let _ = lazy_pool;
        let _ = initial_expires_at;

        let first = token_source.fetch_token(&["aud"]).await.unwrap();
        let second = token_source.fetch_token(&["aud"]).await.unwrap();
        let third = token_source.fetch_token(&["aud"]).await.unwrap();
        assert_ne!(first.secret, second.secret);
        assert_ne!(second.secret, third.secret);
        assert_eq!(fake.call_count(), 3);
    }

    /// The audience configured on `EntraAuthOptions` is the single scope the
    /// token source records when fetched with `audience_str()`.
    #[tokio::test]
    async fn audience_override_is_passed_to_token_source() {
        let fake = RecordingFakeTokenSource::with_tokens(vec![token("t", 3600)]);
        let source: Arc<dyn TokenSource> = fake.clone();
        let auth_options =
            crate::entra::EntraAuthOptions::new().audience("https://custom.example/.default");
        let _fetched = source
            .fetch_token(&[auth_options.audience_str()])
            .await
            .unwrap();
        let recorded = fake.recorded_scopes();
        assert_eq!(recorded.len(), 1);
        assert_eq!(
            recorded[0],
            vec!["https://custom.example/.default".to_string()]
        );
    }

    /// An always-failing token source surfaces its message in the error chain.
    #[tokio::test]
    async fn missing_credential_surfaces_descriptive_error() {
        let source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::always_failing("no credential available");
        let outcome: anyhow::Result<crate::entra::EntraToken> =
            source.fetch_token(&["aud"]).await;
        let err = outcome.expect_err("should fail");
        let msg = format!("{err:#}");
        assert!(msg.contains("no credential available"), "got: {msg}");
    }

    /// A successful iteration is scheduled exactly like
    /// `compute_next_refresh_sleep`, here resolving to the 20-minute ceiling.
    #[test]
    fn next_sleep_after_iteration_uses_expiry_schedule_on_success() {
        let ceiling = Duration::from_secs(20 * 60);
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let outcome: Result<Result<(), ()>, String> = Ok(Ok(()));
        let delay = next_sleep_after_iteration(&outcome, ceiling, expires, now);
        let expected = compute_next_refresh_sleep(ceiling, expires, now);
        assert_eq!(delay, expected);
        assert_eq!(delay, Duration::from_secs(20 * 60));
    }

    /// An inner `Err` (failed iteration) schedules the minimum interval.
    #[test]
    fn next_sleep_after_iteration_returns_min_interval_on_fetch_failure() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let outcome: Result<Result<(), ()>, String> = Ok(Err(()));
        let delay =
            next_sleep_after_iteration(&outcome, Duration::from_secs(20 * 60), expires, now);
        assert_eq!(delay, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// An outer `Err` (captured panic message) schedules the minimum interval.
    #[test]
    fn next_sleep_after_iteration_returns_min_interval_on_panic() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let outcome: Result<Result<(), ()>, String> = Err("simulated panic".to_string());
        let delay =
            next_sleep_after_iteration(&outcome, Duration::from_secs(20 * 60), expires, now);
        assert_eq!(delay, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// A 1-second ceiling still yields the minimum interval.
    #[test]
    fn compute_next_refresh_sleep_floors_when_ceiling_is_tiny() {
        let now = SystemTime::now();
        let expires = now + Duration::from_secs(3600);
        let delay = compute_next_refresh_sleep(Duration::from_secs(1), expires, now);
        assert_eq!(delay, ENTRA_REFRESH_MIN_INTERVAL);
    }

    /// `Debug` output of a token hides the secret and shows `<redacted>`.
    #[test]
    fn entra_token_debug_redacts_secret() {
        let bearer = token("super-secret-bearer-string", 3600);
        let rendered = format!("{bearer:?}");
        assert!(
            !rendered.contains("super-secret-bearer-string"),
            "leaked: {rendered}"
        );
        assert!(
            rendered.contains("<redacted>"),
            "expected redaction marker: {rendered}"
        );
    }

    /// SQLSTATE classification table: the 28xxx codes are retryable only when
    /// `is_entra` is true; the remaining codes do not depend on the flag.
    #[test]
    fn classify_pg_sqlstate_gates_28xxx_on_is_entra() {
        use crate::provider::{classify_pg_sqlstate, SqlStateClass};

        let cases = [
            (Some("28000"), true, SqlStateClass::Retryable),
            (Some("28P01"), true, SqlStateClass::Retryable),
            (Some("28000"), false, SqlStateClass::Permanent),
            (Some("28P01"), false, SqlStateClass::Permanent),
            (Some("40P01"), true, SqlStateClass::Retryable),
            (Some("40P01"), false, SqlStateClass::Retryable),
            (Some("23505"), true, SqlStateClass::Permanent),
            (Some("23505"), false, SqlStateClass::Permanent),
            (Some("0A000"), true, SqlStateClass::Retryable),
            (None, true, SqlStateClass::Permanent),
        ];

        for (code, is_entra, expected) in cases {
            assert_eq!(
                classify_pg_sqlstate(code, is_entra),
                expected,
                "code={code:?} is_entra={is_entra}"
            );
        }
    }

    /// A string panic inside the guarded future becomes an `Err` containing
    /// the panic text.
    #[tokio::test]
    async fn run_with_panic_guard_catches_string_panic_and_continues() {
        let outcome: Result<(), String> = run_with_panic_guard(async { panic!("boom") }).await;
        let msg = outcome.expect_err("must catch the panic");
        assert!(msg.contains("boom"), "got: {msg}");
    }

    /// A future that completes normally has its value passed through as `Ok`.
    #[tokio::test]
    async fn run_with_panic_guard_returns_ok_when_future_completes() {
        let outcome: Result<i32, String> = run_with_panic_guard(async { 42 }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    /// A panic payload that is not a string is reported with a fixed phrase.
    #[tokio::test]
    async fn run_with_panic_guard_handles_non_string_panic_payload() {
        let outcome: Result<(), String> =
            run_with_panic_guard(async { std::panic::panic_any(42_i32) }).await;
        let msg = outcome.expect_err("must catch");
        assert!(msg.contains("non-string panic payload"), "got: {msg}");
    }

    /// Input shorter than the limit is returned unchanged.
    #[test]
    fn truncate_panic_message_passes_through_short_input() {
        let original = "short message".to_string();
        assert_eq!(truncate_panic_message(original.clone(), 256), original);
    }

    /// 1024 ASCII bytes with a limit of 256: the first 256 bytes are kept and
    /// the marker is appended, with nothing else added.
    #[test]
    fn truncate_panic_message_truncates_long_input_with_marker() {
        let out = truncate_panic_message("A".repeat(1024), 256);
        assert!(out.starts_with(&"A".repeat(256)));
        assert!(out.ends_with(TRUNCATION_MARKER), "got: {out}");
        assert_eq!(out.len(), 256 + TRUNCATION_MARKER.len());
    }

    /// Multi-byte input longer than the limit is truncated without panicking
    /// and ends with the marker.
    #[test]
    fn truncate_panic_message_respects_utf8_char_boundaries() {
        let out = truncate_panic_message("✨".repeat(100), 256);
        assert!(out.ends_with(TRUNCATION_MARKER));
    }

    /// A 10,000-character panic message comes back shorter and marked as
    /// truncated.
    #[tokio::test]
    async fn run_with_panic_guard_truncates_oversized_panic_message() {
        let outcome: Result<(), String> = run_with_panic_guard(async {
            panic!("{}", "S".repeat(10_000));
        })
        .await;
        let msg = outcome.expect_err("must catch");
        assert!(
            msg.len() < 10_000,
            "panic message not truncated: len={}",
            msg.len()
        );
        assert!(
            msg.ends_with(TRUNCATION_MARKER),
            "missing truncation marker: {msg}"
        );
    }
}
2869
#[cfg(test)]
mod entra_pipeline_tests {
    //! Integration tests that drive
    //! `PostgresProvider::new_with_entra_with_token_source` against the
    //! database named by `DATABASE_URL`, using a fake token source whose token
    //! is the URL's password (TLS disabled). Every test returns early without
    //! failing when `DATABASE_URL` is missing or cannot be parsed.

    use super::*;
    use crate::entra::test_support::{token, RecordingFakeTokenSource};
    use sqlx::Row;

    /// Splits a `postgres://user:password@host[:port]/db[?query]` URL into
    /// `(host, port, db, user, password)`.
    ///
    /// The port defaults to 5432 when absent. Returns `None` when the scheme,
    /// the `user:password@` section, the `/db` section or the port number is
    /// missing or malformed. No percent-decoding is performed.
    fn parse_database_url(url: &str) -> Option<(String, u16, String, String, String)> {
        let stripped = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))?;
        let (creds, rest) = stripped.split_once('@')?;
        let (user, password) = creds.split_once(':')?;
        let (hostport, db_with_query) = rest.split_once('/')?;
        let (host, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
        let port: u16 = port_str.parse().ok()?;
        let db = db_with_query.split('?').next()?;
        Some((
            host.to_string(),
            port,
            db.to_string(),
            user.to_string(),
            password.to_string(),
        ))
    }

    /// Loads `.env` (best effort), then reads and parses `DATABASE_URL`.
    ///
    /// Returns `None`, after printing the reason to stderr, when the variable
    /// is unset or unparseable so callers can skip the test.
    fn pg_connection_or_skip() -> Option<(String, u16, String, String, String)> {
        dotenvy::dotenv().ok();
        let Ok(url) = std::env::var("DATABASE_URL") else {
            eprintln!("DATABASE_URL not set; skipping Entra pipeline integration test");
            return None;
        };
        let parts = parse_database_url(&url);
        if parts.is_none() {
            // Deliberately do not echo the URL: it normally embeds the
            // database password and test output ends up in CI logs.
            eprintln!(
                "DATABASE_URL not parseable; skipping Entra pipeline integration test \
                 (value withheld: it may contain credentials)"
            );
        }
        parts
    }

    /// Builds a schema name of the form `entra_inj_<8 chars>` from the tail of
    /// a fresh v4 UUID.
    fn unique_schema() -> String {
        let id = uuid::Uuid::new_v4().to_string();
        format!("entra_inj_{}", &id[id.len() - 8..])
    }

    /// Drops `schema` (with `CASCADE`); a failure is only reported on stderr.
    async fn drop_schema(pool: &PgPool, schema: &str) {
        let stmt = format!("DROP SCHEMA IF EXISTS \"{schema}\" CASCADE");
        if let Err(e) = sqlx::query(&stmt).execute(pool).await {
            eprintln!("warning: failed to drop schema {schema}: {e}");
        }
    }

    /// With the correct password supplied as the token, construction succeeds
    /// and `<schema>.instances` exists afterwards.
    #[tokio::test]
    async fn pipeline_with_injected_token_authenticates_and_runs_migrations() {
        let Some((host, port, db, user, password)) = pg_connection_or_skip() else {
            return;
        };

        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token(&password, 3600)]);
        let schema = unique_schema();

        let provider = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new(),
            token_source,
            PgSslMode::Disable,
        )
        .await
        .expect("Entra pipeline must succeed against local PG with correct token");

        let lookup = sqlx::query(&format!(
            "SELECT to_regclass('{}.instances')::text AS r",
            schema
        ))
        .fetch_one(provider.pool())
        .await;

        // Clean up before asserting so a failed check does not leave the
        // per-test schema behind in the shared database.
        drop_schema(provider.pool(), &schema).await;

        let row = lookup.expect("smoke query must succeed");
        let regclass: Option<String> = row.get("r");
        assert!(
            regclass.is_some(),
            "expected migrations to create {}.instances",
            schema
        );
    }

    /// With a wrong token, construction fails with an authentication-looking
    /// error (mentions "password", `28P01` or `28000`).
    #[tokio::test]
    async fn pipeline_with_wrong_token_fails_before_migrations() {
        let Some((host, port, db, user, _password)) = pg_connection_or_skip() else {
            return;
        };

        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token("definitely-wrong-password", 3600)]);
        let schema = unique_schema();

        let result = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new(),
            token_source,
            PgSslMode::Disable,
        )
        .await;

        let err = match result {
            Ok(provider) => {
                // Construction unexpectedly succeeded, so the schema may have
                // been created; remove it before failing the test.
                drop_schema(provider.pool(), &schema).await;
                panic!("wrong token must fail pool construction, but provider was built")
            }
            Err(e) => e,
        };
        let msg = format!("{err:#}");
        assert!(
            msg.to_lowercase().contains("password")
                || msg.contains("28P01")
                || msg.contains("28000"),
            "expected authentication failure, got: {msg}"
        );
    }

    /// Same happy path with an explicit one-hour `refresh_interval`; the
    /// provider reports the schema name it was given.
    #[tokio::test]
    async fn pipeline_default_constructor_path_with_injected_token() {
        let Some((host, port, db, user, password)) = pg_connection_or_skip() else {
            return;
        };

        let schema = unique_schema();
        let token_source: Arc<dyn TokenSource> =
            RecordingFakeTokenSource::with_tokens(vec![token(&password, 3600)]);

        let provider = PostgresProvider::new_with_entra_with_token_source(
            &host,
            port,
            &db,
            &user,
            Some(&schema),
            EntraAuthOptions::new().refresh_interval(Duration::from_secs(60 * 60)),
            token_source,
            PgSslMode::Disable,
        )
        .await
        .expect("default-constructor variant must succeed");

        // Capture the value, clean up, then assert, so a mismatch cannot leak
        // the schema.
        let reported_schema = provider.schema_name().to_string();
        drop_schema(provider.pool(), &schema).await;
        assert_eq!(reported_schema, schema);
    }
}