//! PostgreSQL provider implementation for Duroxide (`duroxide_pg/provider.rs`).

1use anyhow::Result;
2use chrono::{TimeZone, Utc};
3use duroxide::providers::{
4    ExecutionInfo, ExecutionMetadata, InstanceInfo, OrchestrationItem, Provider, ProviderAdmin,
5    ProviderError, QueueDepths, SystemMetrics, WorkItem,
6};
7use duroxide::Event;
8use sqlx::{postgres::PgPoolOptions, Error as SqlxError, PgPool};
9use std::sync::Arc;
10use std::time::Duration;
11use std::time::{SystemTime, UNIX_EPOCH};
12use tokio::time::sleep;
13use tracing::{debug, error, instrument, warn};
14
15use crate::migrations::MigrationRunner;
16
17/// PostgreSQL-based provider for Duroxide durable orchestrations.
18///
19/// Implements the [`Provider`] and [`ProviderAdmin`] traits from Duroxide,
20/// storing orchestration state, history, and work queues in PostgreSQL.
21///
22/// # Example
23///
24/// ```rust,no_run
25/// use duroxide_pg::PostgresProvider;
26///
27/// # async fn example() -> anyhow::Result<()> {
28/// // Connect using DATABASE_URL or explicit connection string
29/// let provider = PostgresProvider::new("postgres://localhost/mydb").await?;
30///
31/// // Or use a custom schema for isolation
32/// let provider = PostgresProvider::new_with_schema(
33///     "postgres://localhost/mydb",
34///     Some("my_app"),
35/// ).await?;
36/// # Ok(())
37/// # }
38/// ```
pub struct PostgresProvider {
    /// Shared connection pool; also cloned into the migration runner.
    pool: Arc<PgPool>,
    /// PostgreSQL schema that scopes all tables and stored procedures
    /// (defaults to "public" when none is given).
    schema_name: String,
}
43
44impl PostgresProvider {
45    pub async fn new(database_url: &str) -> Result<Self> {
46        Self::new_with_schema(database_url, None).await
47    }
48
49    pub async fn new_with_schema(database_url: &str, schema_name: Option<&str>) -> Result<Self> {
50        let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
51            .ok()
52            .and_then(|s| s.parse::<u32>().ok())
53            .unwrap_or(10);
54
55        let pool = PgPoolOptions::new()
56            .max_connections(max_connections)
57            .min_connections(1)
58            .acquire_timeout(std::time::Duration::from_secs(30))
59            .connect(database_url)
60            .await?;
61
62        let schema_name = schema_name.unwrap_or("public").to_string();
63
64        let provider = Self {
65            pool: Arc::new(pool),
66            schema_name: schema_name.clone(),
67        };
68
69        // Run migrations to initialize schema
70        let migration_runner = MigrationRunner::new(provider.pool.clone(), schema_name.clone());
71        migration_runner.migrate().await?;
72
73        Ok(provider)
74    }
75
76    #[instrument(skip(self), target = "duroxide::providers::postgres")]
77    pub async fn initialize_schema(&self) -> Result<()> {
78        // Schema initialization is now handled by migrations
79        // This method is kept for backward compatibility but delegates to migrations
80        let migration_runner = MigrationRunner::new(self.pool.clone(), self.schema_name.clone());
81        migration_runner.migrate().await?;
82        Ok(())
83    }
84
85    /// Get current timestamp in milliseconds (Unix epoch)
86    fn now_millis() -> i64 {
87        SystemTime::now()
88            .duration_since(UNIX_EPOCH)
89            .unwrap()
90            .as_millis() as i64
91    }
92
93    /// Get schema-qualified table name
94    fn table_name(&self, table: &str) -> String {
95        format!("{}.{}", self.schema_name, table)
96    }
97
98    /// Get the database pool (for testing)
99    pub fn pool(&self) -> &PgPool {
100        &self.pool
101    }
102
103    /// Get the schema name (for testing)
104    pub fn schema_name(&self) -> &str {
105        &self.schema_name
106    }
107
108    /// Convert sqlx::Error to ProviderError with proper classification
109    fn sqlx_to_provider_error(operation: &str, e: SqlxError) -> ProviderError {
110        match e {
111            SqlxError::Database(ref db_err) => {
112                // PostgreSQL error codes
113                let code_opt = db_err.code();
114                let code = code_opt.as_deref();
115                if code == Some("40P01") {
116                    // Deadlock detected
117                    ProviderError::retryable(operation, format!("Deadlock detected: {e}"))
118                } else if code == Some("40001") {
119                    // Serialization failure - permanent error (transaction conflict, not transient)
120                    ProviderError::permanent(operation, format!("Serialization failure: {e}"))
121                } else if code == Some("23505") {
122                    // Unique constraint violation (duplicate event)
123                    ProviderError::permanent(operation, format!("Duplicate detected: {e}"))
124                } else if code == Some("23503") {
125                    // Foreign key constraint violation
126                    ProviderError::permanent(operation, format!("Foreign key violation: {e}"))
127                } else {
128                    ProviderError::permanent(operation, format!("Database error: {e}"))
129                }
130            }
131            SqlxError::PoolClosed | SqlxError::PoolTimedOut => {
132                ProviderError::retryable(operation, format!("Connection pool error: {e}"))
133            }
134            SqlxError::Io(_) => ProviderError::retryable(operation, format!("I/O error: {e}")),
135            _ => ProviderError::permanent(operation, format!("Unexpected error: {e}")),
136        }
137    }
138
139    /// Clean up schema after tests (drops all tables and optionally the schema)
140    ///
141    /// **SAFETY**: Never drops the "public" schema itself, only tables within it.
142    /// Only drops the schema if it's a custom schema (not "public").
143    pub async fn cleanup_schema(&self) -> Result<()> {
144        // Call the stored procedure to drop all tables
145        sqlx::query(&format!("SELECT {}.cleanup_schema()", self.schema_name))
146            .execute(&*self.pool)
147            .await?;
148
149        // SAFETY: Never drop the "public" schema - it's a PostgreSQL system schema
150        // Only drop custom schemas created for testing
151        if self.schema_name != "public" {
152            sqlx::query(&format!(
153                "DROP SCHEMA IF EXISTS {} CASCADE",
154                self.schema_name
155            ))
156            .execute(&*self.pool)
157            .await?;
158        } else {
159            // Explicit safeguard: we only drop tables from public schema, never the schema itself
160            // This ensures we don't accidentally drop the default PostgreSQL schema
161        }
162
163        Ok(())
164    }
165}
166
167#[async_trait::async_trait]
168impl Provider for PostgresProvider {
169    #[instrument(skip(self), target = "duroxide::providers::postgres")]
170    async fn fetch_orchestration_item(
171        &self,
172        lock_timeout: Duration,
173    ) -> Result<Option<OrchestrationItem>, ProviderError> {
174        let start = std::time::Instant::now();
175
176        const MAX_RETRIES: u32 = 3;
177        const RETRY_DELAY_MS: u64 = 50;
178
179        // Convert Duration to milliseconds
180        let lock_timeout_ms = lock_timeout.as_millis() as i64;
181        let mut _last_error: Option<ProviderError> = None;
182
183        for attempt in 0..=MAX_RETRIES {
184            let now_ms = Self::now_millis();
185
186            let result: Result<
187                Option<(
188                    String,
189                    String,
190                    String,
191                    i64,
192                    serde_json::Value,
193                    serde_json::Value,
194                    String,
195                )>,
196                SqlxError,
197            > = sqlx::query_as(&format!(
198                "SELECT * FROM {}.fetch_orchestration_item($1, $2)",
199                self.schema_name
200            ))
201            .bind(now_ms)
202            .bind(lock_timeout_ms)
203            .fetch_optional(&*self.pool)
204            .await;
205
206            let row = match result {
207                Ok(r) => r,
208                Err(e) => {
209                    let provider_err = Self::sqlx_to_provider_error("fetch_orchestration_item", e);
210                    if provider_err.is_retryable() && attempt < MAX_RETRIES {
211                        warn!(
212                            target = "duroxide::providers::postgres",
213                            operation = "fetch_orchestration_item",
214                            attempt = attempt + 1,
215                            error = %provider_err,
216                            "Retryable error, will retry"
217                        );
218                        _last_error = Some(provider_err);
219                        sleep(std::time::Duration::from_millis(
220                            RETRY_DELAY_MS * (attempt as u64 + 1),
221                        ))
222                        .await;
223                        continue;
224                    }
225                    return Err(provider_err);
226                }
227            };
228
229            if let Some((
230                instance_id,
231                orchestration_name,
232                orchestration_version,
233                execution_id,
234                history_json,
235                messages_json,
236                lock_token,
237            )) = row
238            {
239                let history: Vec<Event> = serde_json::from_value(history_json).map_err(|e| {
240                    ProviderError::permanent(
241                        "fetch_orchestration_item",
242                        format!("Failed to deserialize history: {e}"),
243                    )
244                })?;
245
246                let messages: Vec<WorkItem> =
247                    serde_json::from_value(messages_json).map_err(|e| {
248                        ProviderError::permanent(
249                            "fetch_orchestration_item",
250                            format!("Failed to deserialize messages: {e}"),
251                        )
252                    })?;
253
254                let duration_ms = start.elapsed().as_millis() as u64;
255                debug!(
256                    target = "duroxide::providers::postgres",
257                    operation = "fetch_orchestration_item",
258                    instance_id = %instance_id,
259                    execution_id = execution_id,
260                    message_count = messages.len(),
261                    history_count = history.len(),
262                    duration_ms = duration_ms,
263                    attempts = attempt + 1,
264                    "Fetched orchestration item via stored procedure"
265                );
266
267                return Ok(Some(OrchestrationItem {
268                    instance: instance_id,
269                    orchestration_name,
270                    execution_id: execution_id as u64,
271                    version: orchestration_version,
272                    history,
273                    messages,
274                    lock_token,
275                }));
276            }
277
278            if attempt < MAX_RETRIES {
279                sleep(std::time::Duration::from_millis(RETRY_DELAY_MS)).await;
280            }
281        }
282
283        Ok(None)
284    }
    /// Acknowledge a locked orchestration item: atomically persist the history
    /// delta, enqueue follow-up work items, and update execution metadata via
    /// the `ack_orchestration_item` stored procedure, retrying transient
    /// database failures.
    #[instrument(skip(self), fields(lock_token = %lock_token, execution_id = execution_id), target = "duroxide::providers::postgres")]
    async fn ack_orchestration_item(
        &self,
        lock_token: &str,
        execution_id: u64,
        history_delta: Vec<Event>,
        worker_items: Vec<WorkItem>,
        orchestrator_items: Vec<WorkItem>,
        metadata: ExecutionMetadata,
    ) -> Result<(), ProviderError> {
        let start = std::time::Instant::now();

        // Transient errors are retried with linear backoff (50ms, 100ms, 150ms).
        const MAX_RETRIES: u32 = 3;
        const RETRY_DELAY_MS: u64 = 50;

        // Build the history-delta JSON payload once, up front, so retries
        // don't re-serialize.
        let mut history_delta_payload = Vec::with_capacity(history_delta.len());
        for event in &history_delta {
            // event_id 0 means the runtime failed to assign an id - reject.
            if event.event_id() == 0 {
                return Err(ProviderError::permanent(
                    "ack_orchestration_item",
                    "event_id must be set by runtime",
                ));
            }

            let event_json = serde_json::to_string(event).map_err(|e| {
                ProviderError::permanent(
                    "ack_orchestration_item",
                    format!("Failed to serialize event: {e}"),
                )
            })?;

            // Derive the variant name from the Debug representation (text up
            // to the first `{`). NOTE(review): assumes struct-like Debug
            // output for every Event variant - confirm for tuple variants.
            let event_type = format!("{event:?}")
                .split('{')
                .next()
                .unwrap_or("Unknown")
                .trim()
                .to_string();

            history_delta_payload.push(serde_json::json!({
                "event_id": event.event_id(),
                "event_type": event_type,
                "event_data": event_json,
            }));
        }

        let history_delta_json = serde_json::Value::Array(history_delta_payload);

        let worker_items_json = serde_json::to_value(&worker_items).map_err(|e| {
            ProviderError::permanent(
                "ack_orchestration_item",
                format!("Failed to serialize worker items: {e}"),
            )
        })?;

        let orchestrator_items_json = serde_json::to_value(&orchestrator_items).map_err(|e| {
            ProviderError::permanent(
                "ack_orchestration_item",
                format!("Failed to serialize orchestrator items: {e}"),
            )
        })?;

        // Execution metadata forwarded to the stored procedure as JSON.
        let metadata_json = serde_json::json!({
            "orchestration_name": metadata.orchestration_name,
            "orchestration_version": metadata.orchestration_version,
            "status": metadata.status,
            "output": metadata.output,
        });

        for attempt in 0..=MAX_RETRIES {
            let result = sqlx::query(&format!(
                "SELECT {}.ack_orchestration_item($1, $2, $3, $4, $5, $6)",
                self.schema_name
            ))
            .bind(lock_token)
            .bind(execution_id as i64)
            .bind(&history_delta_json)
            .bind(&worker_items_json)
            .bind(&orchestrator_items_json)
            .bind(&metadata_json)
            .execute(&*self.pool)
            .await;

            match result {
                Ok(_) => {
                    let duration_ms = start.elapsed().as_millis() as u64;
                    debug!(
                        target = "duroxide::providers::postgres",
                        operation = "ack_orchestration_item",
                        execution_id = execution_id,
                        history_count = history_delta.len(),
                        worker_items_count = worker_items.len(),
                        orchestrator_items_count = orchestrator_items.len(),
                        duration_ms = duration_ms,
                        attempts = attempt + 1,
                        "Acknowledged orchestration item via stored procedure"
                    );
                    return Ok(());
                }
                Err(e) => {
                    // Check for permanent errors first: a bad lock token can
                    // never succeed, so fail fast instead of retrying.
                    if let SqlxError::Database(db_err) = &e {
                        if db_err.message().contains("Invalid lock token") {
                            return Err(ProviderError::permanent(
                                "ack_orchestration_item",
                                "Invalid lock token",
                            ));
                        }
                    } else if e.to_string().contains("Invalid lock token") {
                        return Err(ProviderError::permanent(
                            "ack_orchestration_item",
                            "Invalid lock token",
                        ));
                    }

                    let provider_err = Self::sqlx_to_provider_error("ack_orchestration_item", e);
                    if provider_err.is_retryable() && attempt < MAX_RETRIES {
                        warn!(
                            target = "duroxide::providers::postgres",
                            operation = "ack_orchestration_item",
                            attempt = attempt + 1,
                            error = %provider_err,
                            "Retryable error, will retry"
                        );
                        // Linear backoff before the next attempt.
                        sleep(std::time::Duration::from_millis(
                            RETRY_DELAY_MS * (attempt as u64 + 1),
                        ))
                        .await;
                        continue;
                    }
                    return Err(provider_err);
                }
            }
        }

        // Should never reach here, but just in case
        // (every loop iteration either returns or continues).
        Ok(())
    }
422    #[instrument(skip(self), fields(lock_token = %lock_token), target = "duroxide::providers::postgres")]
423    async fn abandon_orchestration_item(
424        &self,
425        lock_token: &str,
426        delay: Option<Duration>,
427    ) -> Result<(), ProviderError> {
428        let start = std::time::Instant::now();
429        let delay_param: Option<i64> = delay.map(|d| d.as_millis() as i64);
430
431        let instance_id = match sqlx::query_scalar::<_, String>(&format!(
432            "SELECT {}.abandon_orchestration_item($1, $2)",
433            self.schema_name
434        ))
435        .bind(lock_token)
436        .bind(delay_param)
437        .fetch_one(&*self.pool)
438        .await
439        {
440            Ok(instance_id) => instance_id,
441            Err(e) => {
442                if let SqlxError::Database(db_err) = &e {
443                    if db_err.message().contains("Invalid lock token") {
444                        return Err(ProviderError::permanent(
445                            "abandon_orchestration_item",
446                            "Invalid lock token",
447                        ));
448                    }
449                } else if e.to_string().contains("Invalid lock token") {
450                    return Err(ProviderError::permanent(
451                        "abandon_orchestration_item",
452                        "Invalid lock token",
453                    ));
454                }
455
456                return Err(Self::sqlx_to_provider_error(
457                    "abandon_orchestration_item",
458                    e,
459                ));
460            }
461        };
462
463        let duration_ms = start.elapsed().as_millis() as u64;
464        debug!(
465            target = "duroxide::providers::postgres",
466            operation = "abandon_orchestration_item",
467            instance_id = %instance_id,
468            delay_ms = delay.map(|d| d.as_millis() as u64),
469            duration_ms = duration_ms,
470            "Abandoned orchestration item via stored procedure"
471        );
472
473        Ok(())
474    }
475
476    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
477    async fn read(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
478        let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
479            "SELECT out_event_data FROM {}.fetch_history($1)",
480            self.schema_name
481        ))
482        .bind(instance)
483        .fetch_all(&*self.pool)
484        .await
485        .map_err(|e| Self::sqlx_to_provider_error("read", e))?;
486
487        Ok(event_data_rows
488            .into_iter()
489            .filter_map(|event_data| serde_json::from_str::<Event>(&event_data).ok())
490            .collect())
491    }
492
493    #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
494    async fn append_with_execution(
495        &self,
496        instance: &str,
497        execution_id: u64,
498        new_events: Vec<Event>,
499    ) -> Result<(), ProviderError> {
500        if new_events.is_empty() {
501            return Ok(());
502        }
503
504        let mut events_payload = Vec::with_capacity(new_events.len());
505        for event in &new_events {
506            if event.event_id() == 0 {
507                error!(
508                    target = "duroxide::providers::postgres",
509                    operation = "append_with_execution",
510                    error_type = "validation_error",
511                    instance_id = %instance,
512                    execution_id = execution_id,
513                    "event_id must be set by runtime"
514                );
515                return Err(ProviderError::permanent(
516                    "append_with_execution",
517                    "event_id must be set by runtime",
518                ));
519            }
520
521            let event_json = serde_json::to_string(event).map_err(|e| {
522                ProviderError::permanent(
523                    "append_with_execution",
524                    format!("Failed to serialize event: {e}"),
525                )
526            })?;
527
528            let event_type = format!("{event:?}")
529                .split('{')
530                .next()
531                .unwrap_or("Unknown")
532                .trim()
533                .to_string();
534
535            events_payload.push(serde_json::json!({
536                "event_id": event.event_id(),
537                "event_type": event_type,
538                "event_data": event_json,
539            }));
540        }
541
542        let events_json = serde_json::Value::Array(events_payload);
543
544        sqlx::query(&format!(
545            "SELECT {}.append_history($1, $2, $3)",
546            self.schema_name
547        ))
548        .bind(instance)
549        .bind(execution_id as i64)
550        .bind(events_json)
551        .execute(&*self.pool)
552        .await
553        .map_err(|e| Self::sqlx_to_provider_error("append_with_execution", e))?;
554
555        debug!(
556            target = "duroxide::providers::postgres",
557            operation = "append_with_execution",
558            instance_id = %instance,
559            execution_id = execution_id,
560            event_count = new_events.len(),
561            "Appended history events via stored procedure"
562        );
563
564        Ok(())
565    }
566
567    #[instrument(skip(self), target = "duroxide::providers::postgres")]
568    async fn enqueue_for_worker(&self, item: WorkItem) -> Result<(), ProviderError> {
569        let work_item = serde_json::to_string(&item).map_err(|e| {
570            ProviderError::permanent(
571                "enqueue_worker_work",
572                format!("Failed to serialize work item: {e}"),
573            )
574        })?;
575
576        sqlx::query(&format!(
577            "SELECT {}.enqueue_worker_work($1)",
578            self.schema_name
579        ))
580        .bind(work_item)
581        .execute(&*self.pool)
582        .await
583        .map_err(|e| {
584            error!(
585                target = "duroxide::providers::postgres",
586                operation = "enqueue_worker_work",
587                error_type = "database_error",
588                error = %e,
589                "Failed to enqueue worker work"
590            );
591            Self::sqlx_to_provider_error("enqueue_worker_work", e)
592        })?;
593
594        Ok(())
595    }
596
597    #[instrument(skip(self), target = "duroxide::providers::postgres")]
598    async fn fetch_work_item(
599        &self,
600        lock_timeout: Duration,
601    ) -> Result<Option<(WorkItem, String)>, ProviderError> {
602        let start = std::time::Instant::now();
603
604        // Convert Duration to milliseconds
605        let lock_timeout_ms = lock_timeout.as_millis() as i64;
606
607        let row = match sqlx::query_as::<_, (String, String)>(&format!(
608            "SELECT * FROM {}.fetch_work_item($1, $2)",
609            self.schema_name
610        ))
611        .bind(Self::now_millis())
612        .bind(lock_timeout_ms)
613        .fetch_optional(&*self.pool)
614        .await
615        {
616            Ok(row) => row,
617            Err(e) => {
618                return Err(Self::sqlx_to_provider_error("fetch_work_item", e));
619            }
620        };
621
622        let (work_item_json, lock_token) = match row {
623            Some(row) => row,
624            None => return Ok(None),
625        };
626
627        let work_item: WorkItem = serde_json::from_str(&work_item_json).map_err(|e| {
628            ProviderError::permanent(
629                "fetch_work_item",
630                format!("Failed to deserialize worker item: {e}"),
631            )
632        })?;
633
634        let duration_ms = start.elapsed().as_millis() as u64;
635
636        // Extract instance for logging - different work item types have different structures
637        let instance_id = match &work_item {
638            WorkItem::ActivityExecute { instance, .. } => instance.as_str(),
639            WorkItem::ActivityCompleted { instance, .. } => instance.as_str(),
640            WorkItem::ActivityFailed { instance, .. } => instance.as_str(),
641            WorkItem::StartOrchestration { instance, .. } => instance.as_str(),
642            WorkItem::TimerFired { instance, .. } => instance.as_str(),
643            WorkItem::ExternalRaised { instance, .. } => instance.as_str(),
644            WorkItem::CancelInstance { instance, .. } => instance.as_str(),
645            WorkItem::ContinueAsNew { instance, .. } => instance.as_str(),
646            WorkItem::SubOrchCompleted {
647                parent_instance, ..
648            } => parent_instance.as_str(),
649            WorkItem::SubOrchFailed {
650                parent_instance, ..
651            } => parent_instance.as_str(),
652        };
653
654        debug!(
655            target = "duroxide::providers::postgres",
656            operation = "fetch_work_item",
657            instance_id = %instance_id,
658            duration_ms = duration_ms,
659            "Fetched activity work item via stored procedure"
660        );
661
662        Ok(Some((work_item, lock_token)))
663    }
664
665    #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
666    async fn ack_work_item(&self, token: &str, completion: WorkItem) -> Result<(), ProviderError> {
667        let start = std::time::Instant::now();
668
669        // Extract instance ID from completion WorkItem
670        let instance_id = match &completion {
671            WorkItem::ActivityCompleted { instance, .. }
672            | WorkItem::ActivityFailed { instance, .. } => instance,
673            _ => {
674                error!(
675                    target = "duroxide::providers::postgres",
676                    operation = "ack_worker",
677                    error_type = "invalid_completion_type",
678                    "Invalid completion work item type"
679                );
680                return Err(ProviderError::permanent(
681                    "ack_worker",
682                    "Invalid completion work item type",
683                ));
684            }
685        };
686
687        let completion_json = serde_json::to_string(&completion).map_err(|e| {
688            ProviderError::permanent(
689                "ack_worker",
690                format!("Failed to serialize completion: {e}"),
691            )
692        })?;
693
694        // Call stored procedure to atomically delete worker item and enqueue completion
695        sqlx::query(&format!(
696            "SELECT {}.ack_worker($1, $2, $3)",
697            self.schema_name
698        ))
699        .bind(token)
700        .bind(instance_id)
701        .bind(completion_json)
702        .execute(&*self.pool)
703        .await
704        .map_err(|e| {
705            if e.to_string().contains("Worker queue item not found") {
706                error!(
707                    target = "duroxide::providers::postgres",
708                    operation = "ack_worker",
709                    error_type = "worker_item_not_found",
710                    token = %token,
711                    "Worker queue item not found or already processed"
712                );
713                ProviderError::permanent(
714                    "ack_worker",
715                    "Worker queue item not found or already processed",
716                )
717            } else {
718                Self::sqlx_to_provider_error("ack_worker", e)
719            }
720        })?;
721
722        let duration_ms = start.elapsed().as_millis() as u64;
723        debug!(
724            target = "duroxide::providers::postgres",
725            operation = "ack_worker",
726            instance_id = %instance_id,
727            duration_ms = duration_ms,
728            "Acknowledged worker and enqueued completion"
729        );
730
731        Ok(())
732    }
733
734    #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")]
735    async fn renew_work_item_lock(
736        &self,
737        token: &str,
738        extend_for: Duration,
739    ) -> Result<(), ProviderError> {
740        let start = std::time::Instant::now();
741
742        // Get current time from application for consistent time reference
743        let now_ms = Self::now_millis();
744
745        // Convert Duration to seconds for the stored procedure
746        let extend_secs = extend_for.as_secs() as i64;
747
748        match sqlx::query(&format!(
749            "SELECT {}.renew_work_item_lock($1, $2, $3)",
750            self.schema_name
751        ))
752        .bind(token)
753        .bind(now_ms)
754        .bind(extend_secs)
755        .execute(&*self.pool)
756        .await
757        {
758            Ok(_) => {
759                let duration_ms = start.elapsed().as_millis() as u64;
760                debug!(
761                    target = "duroxide::providers::postgres",
762                    operation = "renew_work_item_lock",
763                    token = %token,
764                    extend_for_secs = extend_secs,
765                    duration_ms = duration_ms,
766                    "Work item lock renewed successfully"
767                );
768                Ok(())
769            }
770            Err(e) => {
771                if let SqlxError::Database(db_err) = &e {
772                    if db_err.message().contains("Lock token invalid") {
773                        return Err(ProviderError::permanent(
774                            "renew_work_item_lock",
775                            "Lock token invalid, expired, or already acked",
776                        ));
777                    }
778                } else if e.to_string().contains("Lock token invalid") {
779                    return Err(ProviderError::permanent(
780                        "renew_work_item_lock",
781                        "Lock token invalid, expired, or already acked",
782                    ));
783                }
784
785                Err(Self::sqlx_to_provider_error("renew_work_item_lock", e))
786            }
787        }
788    }
789
790    #[instrument(skip(self), target = "duroxide::providers::postgres")]
791    async fn enqueue_for_orchestrator(
792        &self,
793        item: WorkItem,
794        delay: Option<Duration>,
795    ) -> Result<(), ProviderError> {
796        let work_item = serde_json::to_string(&item).map_err(|e| {
797            ProviderError::permanent(
798                "enqueue_orchestrator_work",
799                format!("Failed to serialize work item: {e}"),
800            )
801        })?;
802
803        // Extract instance ID from WorkItem enum
804        let instance_id = match &item {
805            WorkItem::StartOrchestration { instance, .. }
806            | WorkItem::ActivityCompleted { instance, .. }
807            | WorkItem::ActivityFailed { instance, .. }
808            | WorkItem::TimerFired { instance, .. }
809            | WorkItem::ExternalRaised { instance, .. }
810            | WorkItem::CancelInstance { instance, .. }
811            | WorkItem::ContinueAsNew { instance, .. } => instance,
812            WorkItem::SubOrchCompleted {
813                parent_instance, ..
814            }
815            | WorkItem::SubOrchFailed {
816                parent_instance, ..
817            } => parent_instance,
818            WorkItem::ActivityExecute { .. } => {
819                return Err(ProviderError::permanent(
820                    "enqueue_orchestrator_work",
821                    "ActivityExecute should go to worker queue, not orchestrator queue",
822                ));
823            }
824        };
825
826        // Determine visible_at: use max of fire_at_ms (for TimerFired) and delay
827        let now_ms = Self::now_millis();
828
829        let visible_at_ms = if let WorkItem::TimerFired { fire_at_ms, .. } = &item {
830            if *fire_at_ms > 0 {
831                // Take max of fire_at_ms and delay (if provided)
832                if let Some(delay) = delay {
833                    std::cmp::max(*fire_at_ms, now_ms as u64 + delay.as_millis() as u64)
834                } else {
835                    *fire_at_ms
836                }
837            } else {
838                // fire_at_ms is 0, use delay or NOW()
839                delay
840                    .map(|d| now_ms as u64 + d.as_millis() as u64)
841                    .unwrap_or(now_ms as u64)
842            }
843        } else {
844            // Non-timer item: use delay or NOW()
845            delay
846                .map(|d| now_ms as u64 + d.as_millis() as u64)
847                .unwrap_or(now_ms as u64)
848        };
849
850        let visible_at = Utc
851            .timestamp_millis_opt(visible_at_ms as i64)
852            .single()
853            .ok_or_else(|| {
854                ProviderError::permanent(
855                    "enqueue_orchestrator_work",
856                    "Invalid visible_at timestamp",
857                )
858            })?;
859
860        // ⚠️ CRITICAL: DO NOT extract orchestration metadata - instance creation happens via ack_orchestration_item metadata
861        // Pass NULL for orchestration_name, orchestration_version, execution_id parameters
862
863        // Call stored procedure to enqueue work
864        sqlx::query(&format!(
865            "SELECT {}.enqueue_orchestrator_work($1, $2, $3, $4, $5, $6)",
866            self.schema_name
867        ))
868        .bind(instance_id)
869        .bind(&work_item)
870        .bind(visible_at)
871        .bind::<Option<String>>(None) // orchestration_name - NULL
872        .bind::<Option<String>>(None) // orchestration_version - NULL
873        .bind::<Option<i64>>(None) // execution_id - NULL
874        .execute(&*self.pool)
875        .await
876        .map_err(|e| {
877            error!(
878                target = "duroxide::providers::postgres",
879                operation = "enqueue_orchestrator_work",
880                error_type = "database_error",
881                error = %e,
882                instance_id = %instance_id,
883                "Failed to enqueue orchestrator work"
884            );
885            Self::sqlx_to_provider_error("enqueue_orchestrator_work", e)
886        })?;
887
888        debug!(
889            target = "duroxide::providers::postgres",
890            operation = "enqueue_orchestrator_work",
891            instance_id = %instance_id,
892            delay_ms = delay.map(|d| d.as_millis() as u64),
893            "Enqueued orchestrator work"
894        );
895
896        Ok(())
897    }
898
899    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
900    async fn read_with_execution(
901        &self,
902        instance: &str,
903        execution_id: u64,
904    ) -> Result<Vec<Event>, ProviderError> {
905        let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
906            "SELECT event_data FROM {} WHERE instance_id = $1 AND execution_id = $2 ORDER BY event_id",
907            self.table_name("history")
908        ))
909        .bind(instance)
910        .bind(execution_id as i64)
911        .fetch_all(&*self.pool)
912        .await
913        .ok()
914        .unwrap_or_default();
915
916        Ok(event_data_rows
917            .into_iter()
918            .filter_map(|event_data| serde_json::from_str::<Event>(&event_data).ok())
919            .collect())
920    }
921
922    fn as_management_capability(&self) -> Option<&dyn ProviderAdmin> {
923        Some(self)
924    }
925}
926
927#[async_trait::async_trait]
928impl ProviderAdmin for PostgresProvider {
929    #[instrument(skip(self), target = "duroxide::providers::postgres")]
930    async fn list_instances(&self) -> Result<Vec<String>, ProviderError> {
931        sqlx::query_scalar(&format!(
932            "SELECT instance_id FROM {}.list_instances()",
933            self.schema_name
934        ))
935        .fetch_all(&*self.pool)
936        .await
937        .map_err(|e| Self::sqlx_to_provider_error("list_instances", e))
938    }
939
940    #[instrument(skip(self), fields(status = %status), target = "duroxide::providers::postgres")]
941    async fn list_instances_by_status(&self, status: &str) -> Result<Vec<String>, ProviderError> {
942        sqlx::query_scalar(&format!(
943            "SELECT instance_id FROM {}.list_instances_by_status($1)",
944            self.schema_name
945        ))
946        .bind(status)
947        .fetch_all(&*self.pool)
948        .await
949        .map_err(|e| Self::sqlx_to_provider_error("list_instances_by_status", e))
950    }
951
952    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
953    async fn list_executions(&self, instance: &str) -> Result<Vec<u64>, ProviderError> {
954        let execution_ids: Vec<i64> = sqlx::query_scalar(&format!(
955            "SELECT execution_id FROM {}.list_executions($1)",
956            self.schema_name
957        ))
958        .bind(instance)
959        .fetch_all(&*self.pool)
960        .await
961        .map_err(|e| Self::sqlx_to_provider_error("list_executions", e))?;
962
963        Ok(execution_ids.into_iter().map(|id| id as u64).collect())
964    }
965
966    #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
967    async fn read_history_with_execution_id(
968        &self,
969        instance: &str,
970        execution_id: u64,
971    ) -> Result<Vec<Event>, ProviderError> {
972        let event_data_rows: Vec<String> = sqlx::query_scalar(&format!(
973            "SELECT out_event_data FROM {}.fetch_history_with_execution($1, $2)",
974            self.schema_name
975        ))
976        .bind(instance)
977        .bind(execution_id as i64)
978        .fetch_all(&*self.pool)
979        .await
980        .map_err(|e| Self::sqlx_to_provider_error("read_execution", e))?;
981
982        event_data_rows
983            .into_iter()
984            .filter_map(|event_data| serde_json::from_str::<Event>(&event_data).ok())
985            .collect::<Vec<Event>>()
986            .into_iter()
987            .map(Ok)
988            .collect()
989    }
990
991    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
992    async fn read_history(&self, instance: &str) -> Result<Vec<Event>, ProviderError> {
993        let execution_id = self.latest_execution_id(instance).await?;
994        self.read_history_with_execution_id(instance, execution_id)
995            .await
996    }
997
998    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
999    async fn latest_execution_id(&self, instance: &str) -> Result<u64, ProviderError> {
1000        sqlx::query_scalar(&format!(
1001            "SELECT {}.latest_execution_id($1)",
1002            self.schema_name
1003        ))
1004        .bind(instance)
1005        .fetch_optional(&*self.pool)
1006        .await
1007        .map_err(|e| Self::sqlx_to_provider_error("latest_execution_id", e))?
1008        .map(|id: i64| id as u64)
1009        .ok_or_else(|| ProviderError::permanent("latest_execution_id", "Instance not found"))
1010    }
1011
1012    #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")]
1013    async fn get_instance_info(&self, instance: &str) -> Result<InstanceInfo, ProviderError> {
1014        let row: Option<(
1015            String,
1016            String,
1017            String,
1018            i64,
1019            chrono::DateTime<Utc>,
1020            Option<chrono::DateTime<Utc>>,
1021            Option<String>,
1022            Option<String>,
1023        )> = sqlx::query_as(&format!(
1024            "SELECT * FROM {}.get_instance_info($1)",
1025            self.schema_name
1026        ))
1027        .bind(instance)
1028        .fetch_optional(&*self.pool)
1029        .await
1030        .map_err(|e| Self::sqlx_to_provider_error("get_instance_info", e))?;
1031
1032        let (
1033            instance_id,
1034            orchestration_name,
1035            orchestration_version,
1036            current_execution_id,
1037            created_at,
1038            updated_at,
1039            status,
1040            output,
1041        ) =
1042            row.ok_or_else(|| ProviderError::permanent("get_instance_info", "Instance not found"))?;
1043
1044        Ok(InstanceInfo {
1045            instance_id,
1046            orchestration_name,
1047            orchestration_version,
1048            current_execution_id: current_execution_id as u64,
1049            status: status.unwrap_or_else(|| "Running".to_string()),
1050            output,
1051            created_at: created_at.timestamp_millis() as u64,
1052            updated_at: updated_at
1053                .map(|dt| dt.timestamp_millis() as u64)
1054                .unwrap_or(created_at.timestamp_millis() as u64),
1055        })
1056    }
1057
1058    #[instrument(skip(self), fields(instance = %instance, execution_id = execution_id), target = "duroxide::providers::postgres")]
1059    async fn get_execution_info(
1060        &self,
1061        instance: &str,
1062        execution_id: u64,
1063    ) -> Result<ExecutionInfo, ProviderError> {
1064        let row: Option<(
1065            i64,
1066            String,
1067            Option<String>,
1068            chrono::DateTime<Utc>,
1069            Option<chrono::DateTime<Utc>>,
1070            i64,
1071        )> = sqlx::query_as(&format!(
1072            "SELECT * FROM {}.get_execution_info($1, $2)",
1073            self.schema_name
1074        ))
1075        .bind(instance)
1076        .bind(execution_id as i64)
1077        .fetch_optional(&*self.pool)
1078        .await
1079        .map_err(|e| Self::sqlx_to_provider_error("get_execution_info", e))?;
1080
1081        let (exec_id, status, output, started_at, completed_at, event_count) = row
1082            .ok_or_else(|| ProviderError::permanent("get_execution_info", "Execution not found"))?;
1083
1084        Ok(ExecutionInfo {
1085            execution_id: exec_id as u64,
1086            status,
1087            output,
1088            started_at: started_at.timestamp_millis() as u64,
1089            completed_at: completed_at.map(|dt| dt.timestamp_millis() as u64),
1090            event_count: event_count as usize,
1091        })
1092    }
1093
1094    #[instrument(skip(self), target = "duroxide::providers::postgres")]
1095    async fn get_system_metrics(&self) -> Result<SystemMetrics, ProviderError> {
1096        let row: Option<(i64, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!(
1097            "SELECT * FROM {}.get_system_metrics()",
1098            self.schema_name
1099        ))
1100        .fetch_optional(&*self.pool)
1101        .await
1102        .map_err(|e| Self::sqlx_to_provider_error("get_system_metrics", e))?;
1103
1104        let (
1105            total_instances,
1106            total_executions,
1107            running_instances,
1108            completed_instances,
1109            failed_instances,
1110            total_events,
1111        ) = row.ok_or_else(|| {
1112            ProviderError::permanent("get_system_metrics", "Failed to get system metrics")
1113        })?;
1114
1115        Ok(SystemMetrics {
1116            total_instances: total_instances as u64,
1117            total_executions: total_executions as u64,
1118            running_instances: running_instances as u64,
1119            completed_instances: completed_instances as u64,
1120            failed_instances: failed_instances as u64,
1121            total_events: total_events as u64,
1122        })
1123    }
1124
1125    #[instrument(skip(self), target = "duroxide::providers::postgres")]
1126    async fn get_queue_depths(&self) -> Result<QueueDepths, ProviderError> {
1127        let now_ms = Self::now_millis();
1128
1129        let row: Option<(i64, i64)> = sqlx::query_as(&format!(
1130            "SELECT * FROM {}.get_queue_depths($1)",
1131            self.schema_name
1132        ))
1133        .bind(now_ms)
1134        .fetch_optional(&*self.pool)
1135        .await
1136        .map_err(|e| Self::sqlx_to_provider_error("get_queue_depths", e))?;
1137
1138        let (orchestrator_queue, worker_queue) = row.ok_or_else(|| {
1139            ProviderError::permanent("get_queue_depths", "Failed to get queue depths")
1140        })?;
1141
1142        Ok(QueueDepths {
1143            orchestrator_queue: orchestrator_queue as usize,
1144            worker_queue: worker_queue as usize,
1145            timer_queue: 0, // Timers are in orchestrator queue with delayed visibility
1146        })
1147    }
1148}