rain-engine-store-pg 0.1.0

Postgres ledger store for RainEngine
Documentation
//! Postgres ledger store for RainEngine sessions.
//!
//! Records are stored append-only as typed JSONB payloads with sequence numbers
//! used for deterministic replay and projection.

use async_trait::async_trait;
use rain_engine_core::{
    EngineOutcome, MemoryError, MemoryStore, NewSessionRecord, PendingApprovalRecord, RecordPage,
    RecordPageQuery, SessionListQuery, SessionRecord, SessionRecordKind, SessionSnapshot,
    SessionSummary, StoredSessionRecord,
};
use sqlx::{PgPool, Row};

#[derive(Clone)]
pub struct PgMemoryStore {
    pool: PgPool,
}

impl PgMemoryStore {
    pub async fn connect(database_url: &str) -> Result<Self, MemoryError> {
        let pool = PgPool::connect(database_url)
            .await
            .map_err(|err| MemoryError::new(err.to_string()))?;
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS session_records (
                sequence_no BIGSERIAL PRIMARY KEY,
                session_id TEXT NOT NULL,
                occurred_at_ms BIGINT NOT NULL,
                record_kind TEXT NOT NULL,
                trigger_id TEXT,
                idempotency_key TEXT,
                payload_json JSONB NOT NULL
            )
            "#,
        )
        .execute(&pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;
        sqlx::query(
            "CREATE INDEX IF NOT EXISTS idx_session_records_session_id ON session_records(session_id)",
        )
        .execute(&pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;
        Ok(Self { pool })
    }

    pub fn connect_lazy(database_url: &str) -> Result<Self, MemoryError> {
        let pool =
            PgPool::connect_lazy(database_url).map_err(|err| MemoryError::new(err.to_string()))?;
        Ok(Self { pool })
    }
}

pub type PostgresMemoryStore = PgMemoryStore;

#[async_trait]
impl MemoryStore for PgMemoryStore {
    async fn append_record(
        &self,
        record: NewSessionRecord,
    ) -> Result<StoredSessionRecord, MemoryError> {
        let payload_json = serde_json::to_value(&record.record)
            .map_err(|err| MemoryError::new(err.to_string()))?;
        let sequence_no: i64 = sqlx::query_scalar(
            r#"
            INSERT INTO session_records (session_id, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING sequence_no
            "#,
        )
        .bind(&record.session_id)
        .bind(record.occurred_at_ms)
        .bind(record.record_kind.as_str())
        .bind(&record.trigger_id)
        .bind(&record.idempotency_key)
        .bind(payload_json)
        .fetch_one(&self.pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;

        Ok(StoredSessionRecord {
            session_id: record.session_id,
            sequence_no,
            occurred_at_ms: record.occurred_at_ms,
            record_kind: record.record_kind,
            trigger_id: record.trigger_id,
            idempotency_key: record.idempotency_key,
            record: record.record,
        })
    }

    async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError> {
        let rows = sqlx::query(
            r#"
            SELECT sequence_no, payload_json
            FROM session_records
            WHERE session_id = $1
            ORDER BY sequence_no ASC
            "#,
        )
        .bind(session_id)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;
        let mut records = Vec::with_capacity(rows.len());
        let mut last_sequence_no = None;
        let mut latest_outcome = None;
        for row in rows {
            last_sequence_no = Some(row.get::<i64, _>("sequence_no"));
            let value: serde_json::Value = row.get("payload_json");
            let record: SessionRecord =
                serde_json::from_value(value).map_err(|err| MemoryError::new(err.to_string()))?;
            if let SessionRecord::Outcome(outcome) = &record {
                latest_outcome = Some(outcome.clone());
            }
            records.push(record);
        }
        Ok(SessionSnapshot {
            session_id: session_id.to_string(),
            records,
            last_sequence_no,
            latest_outcome,
        })
    }

    async fn list_sessions(
        &self,
        query: SessionListQuery,
    ) -> Result<Vec<SessionSummary>, MemoryError> {
        let rows = sqlx::query(
            r#"
            SELECT session_id, MIN(occurred_at_ms) AS first_recorded_at_ms, MAX(occurred_at_ms) AS last_recorded_at_ms, COUNT(*) AS record_count
            FROM session_records
            WHERE ($1::BIGINT IS NULL OR occurred_at_ms >= $1)
              AND ($2::BIGINT IS NULL OR occurred_at_ms <= $2)
            GROUP BY session_id
            ORDER BY session_id
            OFFSET $3 LIMIT $4
            "#,
        )
        .bind(query.since_ms)
        .bind(query.until_ms)
        .bind(query.offset as i64)
        .bind(query.limit as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;
        Ok(rows
            .into_iter()
            .map(|row| SessionSummary {
                session_id: row.get("session_id"),
                first_recorded_at_ms: row.get("first_recorded_at_ms"),
                last_recorded_at_ms: row.get("last_recorded_at_ms"),
                record_count: row.get::<i64, _>("record_count") as usize,
            })
            .collect())
    }

    async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError> {
        let rows = sqlx::query(
            r#"
            SELECT sequence_no, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json
            FROM session_records
            WHERE session_id = $1
              AND ($2::BIGINT IS NULL OR occurred_at_ms >= $2)
              AND ($3::BIGINT IS NULL OR occurred_at_ms <= $3)
            ORDER BY sequence_no
            OFFSET $4 LIMIT $5
            "#,
        )
        .bind(&query.session_id)
        .bind(query.since_ms)
        .bind(query.until_ms)
        .bind(query.offset as i64)
        .bind(query.limit as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;
        let mut records = Vec::with_capacity(rows.len());
        for row in rows {
            let value: serde_json::Value = row.get("payload_json");
            let record: SessionRecord =
                serde_json::from_value(value).map_err(|err| MemoryError::new(err.to_string()))?;
            records.push(StoredSessionRecord {
                session_id: query.session_id.clone(),
                sequence_no: row.get("sequence_no"),
                occurred_at_ms: row.get("occurred_at_ms"),
                record_kind: SessionRecordKind::parse(row.get::<&str, _>("record_kind"))
                    .ok_or_else(|| MemoryError::new("unknown record kind"))?,
                trigger_id: row.get("trigger_id"),
                idempotency_key: row.get("idempotency_key"),
                record,
            });
        }
        let next_offset = (records.len() == query.limit).then_some(query.offset + records.len());
        Ok(RecordPage {
            session_id: query.session_id,
            records,
            next_offset,
        })
    }

    async fn find_outcome_by_idempotency_key(
        &self,
        session_id: &str,
        idempotency_key: &str,
    ) -> Result<Option<EngineOutcome>, MemoryError> {
        let row = sqlx::query(
            r#"
            SELECT payload_json
            FROM session_records
            WHERE session_id = $1
              AND idempotency_key = $2
              AND record_kind = $3
            ORDER BY sequence_no DESC
            LIMIT 1
            "#,
        )
        .bind(session_id)
        .bind(idempotency_key)
        .bind(SessionRecordKind::Outcome.as_str())
        .fetch_optional(&self.pool)
        .await
        .map_err(|err| MemoryError::new(err.to_string()))?;

        match row {
            Some(row) => {
                let value: serde_json::Value = row.get("payload_json");
                let record: SessionRecord = serde_json::from_value(value)
                    .map_err(|err| MemoryError::new(err.to_string()))?;
                match record {
                    SessionRecord::Outcome(outcome) => {
                        Ok(Some(EngineOutcome::from_record(outcome)))
                    }
                    _ => Ok(None),
                }
            }
            None => Ok(None),
        }
    }

    async fn find_pending_approval_by_resume_token(
        &self,
        session_id: &str,
        resume_token: &str,
    ) -> Result<Option<PendingApprovalRecord>, MemoryError> {
        let snapshot = <Self as MemoryStore>::load_session(self, session_id).await?;
        let mut pending = None::<PendingApprovalRecord>;
        for record in snapshot.records {
            match record {
                SessionRecord::PendingApproval(record)
                    if record.resume_token.as_str() == resume_token =>
                {
                    pending = Some(record);
                }
                SessionRecord::ApprovalResolution(record)
                    if record.resume_token.as_str() == resume_token =>
                {
                    pending = None;
                }
                _ => {}
            }
        }
        Ok(pending)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn lazy_connection_validates_configuration_shape() {
        let store = PgMemoryStore::connect_lazy("postgres://postgres:postgres@localhost/test")
            .expect("lazy pool");
        let _ = store;
    }
}