Skip to main content

rain_engine_store_pg/
lib.rs

1//! Postgres ledger store for RainEngine sessions.
2//!
3//! Records are stored append-only as typed JSONB payloads with sequence numbers
4//! used for deterministic replay and projection.
5
6use async_trait::async_trait;
7use rain_engine_core::{
8    EngineOutcome, MemoryError, MemoryStore, NewSessionRecord, PendingApprovalRecord, RecordPage,
9    RecordPageQuery, SessionListQuery, SessionRecord, SessionRecordKind, SessionSnapshot,
10    SessionSummary, StoredSessionRecord,
11};
12use sqlx::{PgPool, Row};
13
14#[derive(Clone)]
15pub struct PgMemoryStore {
16    pool: PgPool,
17}
18
19impl PgMemoryStore {
20    pub async fn connect(database_url: &str) -> Result<Self, MemoryError> {
21        let pool = PgPool::connect(database_url)
22            .await
23            .map_err(|err| MemoryError::new(err.to_string()))?;
24        sqlx::query(
25            r#"
26            CREATE TABLE IF NOT EXISTS session_records (
27                sequence_no BIGSERIAL PRIMARY KEY,
28                session_id TEXT NOT NULL,
29                occurred_at_ms BIGINT NOT NULL,
30                record_kind TEXT NOT NULL,
31                trigger_id TEXT,
32                idempotency_key TEXT,
33                payload_json JSONB NOT NULL
34            )
35            "#,
36        )
37        .execute(&pool)
38        .await
39        .map_err(|err| MemoryError::new(err.to_string()))?;
40        sqlx::query(
41            "CREATE INDEX IF NOT EXISTS idx_session_records_session_id ON session_records(session_id)",
42        )
43        .execute(&pool)
44        .await
45        .map_err(|err| MemoryError::new(err.to_string()))?;
46        Ok(Self { pool })
47    }
48
49    pub fn connect_lazy(database_url: &str) -> Result<Self, MemoryError> {
50        let pool =
51            PgPool::connect_lazy(database_url).map_err(|err| MemoryError::new(err.to_string()))?;
52        Ok(Self { pool })
53    }
54}
55
56pub type PostgresMemoryStore = PgMemoryStore;
57
58#[async_trait]
59impl MemoryStore for PgMemoryStore {
60    async fn append_record(
61        &self,
62        record: NewSessionRecord,
63    ) -> Result<StoredSessionRecord, MemoryError> {
64        let payload_json = serde_json::to_value(&record.record)
65            .map_err(|err| MemoryError::new(err.to_string()))?;
66        let sequence_no: i64 = sqlx::query_scalar(
67            r#"
68            INSERT INTO session_records (session_id, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json)
69            VALUES ($1, $2, $3, $4, $5, $6)
70            RETURNING sequence_no
71            "#,
72        )
73        .bind(&record.session_id)
74        .bind(record.occurred_at_ms)
75        .bind(record.record_kind.as_str())
76        .bind(&record.trigger_id)
77        .bind(&record.idempotency_key)
78        .bind(payload_json)
79        .fetch_one(&self.pool)
80        .await
81        .map_err(|err| MemoryError::new(err.to_string()))?;
82
83        Ok(StoredSessionRecord {
84            session_id: record.session_id,
85            sequence_no,
86            occurred_at_ms: record.occurred_at_ms,
87            record_kind: record.record_kind,
88            trigger_id: record.trigger_id,
89            idempotency_key: record.idempotency_key,
90            record: record.record,
91        })
92    }
93
94    async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError> {
95        let rows = sqlx::query(
96            r#"
97            SELECT sequence_no, payload_json
98            FROM session_records
99            WHERE session_id = $1
100            ORDER BY sequence_no ASC
101            "#,
102        )
103        .bind(session_id)
104        .fetch_all(&self.pool)
105        .await
106        .map_err(|err| MemoryError::new(err.to_string()))?;
107        let mut records = Vec::with_capacity(rows.len());
108        let mut last_sequence_no = None;
109        let mut latest_outcome = None;
110        for row in rows {
111            last_sequence_no = Some(row.get::<i64, _>("sequence_no"));
112            let value: serde_json::Value = row.get("payload_json");
113            let record: SessionRecord =
114                serde_json::from_value(value).map_err(|err| MemoryError::new(err.to_string()))?;
115            if let SessionRecord::Outcome(outcome) = &record {
116                latest_outcome = Some(outcome.clone());
117            }
118            records.push(record);
119        }
120        Ok(SessionSnapshot {
121            session_id: session_id.to_string(),
122            records,
123            last_sequence_no,
124            latest_outcome,
125        })
126    }
127
128    async fn list_sessions(
129        &self,
130        query: SessionListQuery,
131    ) -> Result<Vec<SessionSummary>, MemoryError> {
132        let rows = sqlx::query(
133            r#"
134            SELECT session_id, MIN(occurred_at_ms) AS first_recorded_at_ms, MAX(occurred_at_ms) AS last_recorded_at_ms, COUNT(*) AS record_count
135            FROM session_records
136            WHERE ($1::BIGINT IS NULL OR occurred_at_ms >= $1)
137              AND ($2::BIGINT IS NULL OR occurred_at_ms <= $2)
138            GROUP BY session_id
139            ORDER BY session_id
140            OFFSET $3 LIMIT $4
141            "#,
142        )
143        .bind(query.since_ms)
144        .bind(query.until_ms)
145        .bind(query.offset as i64)
146        .bind(query.limit as i64)
147        .fetch_all(&self.pool)
148        .await
149        .map_err(|err| MemoryError::new(err.to_string()))?;
150        Ok(rows
151            .into_iter()
152            .map(|row| SessionSummary {
153                session_id: row.get("session_id"),
154                first_recorded_at_ms: row.get("first_recorded_at_ms"),
155                last_recorded_at_ms: row.get("last_recorded_at_ms"),
156                record_count: row.get::<i64, _>("record_count") as usize,
157            })
158            .collect())
159    }
160
161    async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError> {
162        let rows = sqlx::query(
163            r#"
164            SELECT sequence_no, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json
165            FROM session_records
166            WHERE session_id = $1
167              AND ($2::BIGINT IS NULL OR occurred_at_ms >= $2)
168              AND ($3::BIGINT IS NULL OR occurred_at_ms <= $3)
169            ORDER BY sequence_no
170            OFFSET $4 LIMIT $5
171            "#,
172        )
173        .bind(&query.session_id)
174        .bind(query.since_ms)
175        .bind(query.until_ms)
176        .bind(query.offset as i64)
177        .bind(query.limit as i64)
178        .fetch_all(&self.pool)
179        .await
180        .map_err(|err| MemoryError::new(err.to_string()))?;
181        let mut records = Vec::with_capacity(rows.len());
182        for row in rows {
183            let value: serde_json::Value = row.get("payload_json");
184            let record: SessionRecord =
185                serde_json::from_value(value).map_err(|err| MemoryError::new(err.to_string()))?;
186            records.push(StoredSessionRecord {
187                session_id: query.session_id.clone(),
188                sequence_no: row.get("sequence_no"),
189                occurred_at_ms: row.get("occurred_at_ms"),
190                record_kind: SessionRecordKind::parse(row.get::<&str, _>("record_kind"))
191                    .ok_or_else(|| MemoryError::new("unknown record kind"))?,
192                trigger_id: row.get("trigger_id"),
193                idempotency_key: row.get("idempotency_key"),
194                record,
195            });
196        }
197        let next_offset = (records.len() == query.limit).then_some(query.offset + records.len());
198        Ok(RecordPage {
199            session_id: query.session_id,
200            records,
201            next_offset,
202        })
203    }
204
205    async fn find_outcome_by_idempotency_key(
206        &self,
207        session_id: &str,
208        idempotency_key: &str,
209    ) -> Result<Option<EngineOutcome>, MemoryError> {
210        let row = sqlx::query(
211            r#"
212            SELECT payload_json
213            FROM session_records
214            WHERE session_id = $1
215              AND idempotency_key = $2
216              AND record_kind = $3
217            ORDER BY sequence_no DESC
218            LIMIT 1
219            "#,
220        )
221        .bind(session_id)
222        .bind(idempotency_key)
223        .bind(SessionRecordKind::Outcome.as_str())
224        .fetch_optional(&self.pool)
225        .await
226        .map_err(|err| MemoryError::new(err.to_string()))?;
227
228        match row {
229            Some(row) => {
230                let value: serde_json::Value = row.get("payload_json");
231                let record: SessionRecord = serde_json::from_value(value)
232                    .map_err(|err| MemoryError::new(err.to_string()))?;
233                match record {
234                    SessionRecord::Outcome(outcome) => {
235                        Ok(Some(EngineOutcome::from_record(outcome)))
236                    }
237                    _ => Ok(None),
238                }
239            }
240            None => Ok(None),
241        }
242    }
243
244    async fn find_pending_approval_by_resume_token(
245        &self,
246        session_id: &str,
247        resume_token: &str,
248    ) -> Result<Option<PendingApprovalRecord>, MemoryError> {
249        let snapshot = <Self as MemoryStore>::load_session(self, session_id).await?;
250        let mut pending = None::<PendingApprovalRecord>;
251        for record in snapshot.records {
252            match record {
253                SessionRecord::PendingApproval(record)
254                    if record.resume_token.as_str() == resume_token =>
255                {
256                    pending = Some(record);
257                }
258                SessionRecord::ApprovalResolution(record)
259                    if record.resume_token.as_str() == resume_token =>
260                {
261                    pending = None;
262                }
263                _ => {}
264            }
265        }
266        Ok(pending)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed Postgres URL is accepted by the lazy constructor. The pool is lazy, so
    /// this is expected not to require a running server.
    #[tokio::test]
    async fn lazy_connection_validates_configuration_shape() {
        PgMemoryStore::connect_lazy("postgres://postgres:postgres@localhost/test")
            .expect("lazy pool");
    }

    /// A string that is not a URL at all is rejected up front by the lazy constructor.
    #[tokio::test]
    async fn lazy_connection_rejects_malformed_url() {
        assert!(PgMemoryStore::connect_lazy("not a database url").is_err());
    }
}