1use async_trait::async_trait;
7use rain_engine_core::{
8 EngineOutcome, MemoryError, MemoryStore, NewSessionRecord, PendingApprovalRecord, RecordPage,
9 RecordPageQuery, SessionListQuery, SessionRecord, SessionRecordKind, SessionSnapshot,
10 SessionSummary, StoredSessionRecord,
11};
12use sqlx::{PgPool, Row};
13
14#[derive(Clone)]
15pub struct PgMemoryStore {
16 pool: PgPool,
17}
18
19impl PgMemoryStore {
20 pub async fn connect(database_url: &str) -> Result<Self, MemoryError> {
21 let pool = PgPool::connect(database_url)
22 .await
23 .map_err(|err| MemoryError::new(err.to_string()))?;
24 sqlx::query(
25 r#"
26 CREATE TABLE IF NOT EXISTS session_records (
27 sequence_no BIGSERIAL PRIMARY KEY,
28 session_id TEXT NOT NULL,
29 occurred_at_ms BIGINT NOT NULL,
30 record_kind TEXT NOT NULL,
31 trigger_id TEXT,
32 idempotency_key TEXT,
33 payload_json JSONB NOT NULL
34 )
35 "#,
36 )
37 .execute(&pool)
38 .await
39 .map_err(|err| MemoryError::new(err.to_string()))?;
40 sqlx::query(
41 "CREATE INDEX IF NOT EXISTS idx_session_records_session_id ON session_records(session_id)",
42 )
43 .execute(&pool)
44 .await
45 .map_err(|err| MemoryError::new(err.to_string()))?;
46 Ok(Self { pool })
47 }
48
49 pub fn connect_lazy(database_url: &str) -> Result<Self, MemoryError> {
50 let pool =
51 PgPool::connect_lazy(database_url).map_err(|err| MemoryError::new(err.to_string()))?;
52 Ok(Self { pool })
53 }
54}
55
/// Alternative, spelled-out name for [`PgMemoryStore`]; both names refer to the same type.
pub type PostgresMemoryStore = PgMemoryStore;
57
58#[async_trait]
59impl MemoryStore for PgMemoryStore {
60 async fn append_record(
61 &self,
62 record: NewSessionRecord,
63 ) -> Result<StoredSessionRecord, MemoryError> {
64 let payload_json = serde_json::to_value(&record.record)
65 .map_err(|err| MemoryError::new(err.to_string()))?;
66 let sequence_no: i64 = sqlx::query_scalar(
67 r#"
68 INSERT INTO session_records (session_id, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json)
69 VALUES ($1, $2, $3, $4, $5, $6)
70 RETURNING sequence_no
71 "#,
72 )
73 .bind(&record.session_id)
74 .bind(record.occurred_at_ms)
75 .bind(record.record_kind.as_str())
76 .bind(&record.trigger_id)
77 .bind(&record.idempotency_key)
78 .bind(payload_json)
79 .fetch_one(&self.pool)
80 .await
81 .map_err(|err| MemoryError::new(err.to_string()))?;
82
83 Ok(StoredSessionRecord {
84 session_id: record.session_id,
85 sequence_no,
86 occurred_at_ms: record.occurred_at_ms,
87 record_kind: record.record_kind,
88 trigger_id: record.trigger_id,
89 idempotency_key: record.idempotency_key,
90 record: record.record,
91 })
92 }
93
94 async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError> {
95 let rows = sqlx::query(
96 r#"
97 SELECT sequence_no, payload_json
98 FROM session_records
99 WHERE session_id = $1
100 ORDER BY sequence_no ASC
101 "#,
102 )
103 .bind(session_id)
104 .fetch_all(&self.pool)
105 .await
106 .map_err(|err| MemoryError::new(err.to_string()))?;
107 let mut records = Vec::with_capacity(rows.len());
108 let mut last_sequence_no = None;
109 let mut latest_outcome = None;
110 for row in rows {
111 last_sequence_no = Some(row.get::<i64, _>("sequence_no"));
112 let value: serde_json::Value = row.get("payload_json");
113 let record: SessionRecord =
114 serde_json::from_value(value).map_err(|err| MemoryError::new(err.to_string()))?;
115 if let SessionRecord::Outcome(outcome) = &record {
116 latest_outcome = Some(outcome.clone());
117 }
118 records.push(record);
119 }
120 Ok(SessionSnapshot {
121 session_id: session_id.to_string(),
122 records,
123 last_sequence_no,
124 latest_outcome,
125 })
126 }
127
128 async fn list_sessions(
129 &self,
130 query: SessionListQuery,
131 ) -> Result<Vec<SessionSummary>, MemoryError> {
132 let rows = sqlx::query(
133 r#"
134 SELECT session_id, MIN(occurred_at_ms) AS first_recorded_at_ms, MAX(occurred_at_ms) AS last_recorded_at_ms, COUNT(*) AS record_count
135 FROM session_records
136 WHERE ($1::BIGINT IS NULL OR occurred_at_ms >= $1)
137 AND ($2::BIGINT IS NULL OR occurred_at_ms <= $2)
138 GROUP BY session_id
139 ORDER BY session_id
140 OFFSET $3 LIMIT $4
141 "#,
142 )
143 .bind(query.since_ms)
144 .bind(query.until_ms)
145 .bind(query.offset as i64)
146 .bind(query.limit as i64)
147 .fetch_all(&self.pool)
148 .await
149 .map_err(|err| MemoryError::new(err.to_string()))?;
150 Ok(rows
151 .into_iter()
152 .map(|row| SessionSummary {
153 session_id: row.get("session_id"),
154 first_recorded_at_ms: row.get("first_recorded_at_ms"),
155 last_recorded_at_ms: row.get("last_recorded_at_ms"),
156 record_count: row.get::<i64, _>("record_count") as usize,
157 })
158 .collect())
159 }
160
161 async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError> {
162 let rows = sqlx::query(
163 r#"
164 SELECT sequence_no, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json
165 FROM session_records
166 WHERE session_id = $1
167 AND ($2::BIGINT IS NULL OR occurred_at_ms >= $2)
168 AND ($3::BIGINT IS NULL OR occurred_at_ms <= $3)
169 ORDER BY sequence_no
170 OFFSET $4 LIMIT $5
171 "#,
172 )
173 .bind(&query.session_id)
174 .bind(query.since_ms)
175 .bind(query.until_ms)
176 .bind(query.offset as i64)
177 .bind(query.limit as i64)
178 .fetch_all(&self.pool)
179 .await
180 .map_err(|err| MemoryError::new(err.to_string()))?;
181 let mut records = Vec::with_capacity(rows.len());
182 for row in rows {
183 let value: serde_json::Value = row.get("payload_json");
184 let record: SessionRecord =
185 serde_json::from_value(value).map_err(|err| MemoryError::new(err.to_string()))?;
186 records.push(StoredSessionRecord {
187 session_id: query.session_id.clone(),
188 sequence_no: row.get("sequence_no"),
189 occurred_at_ms: row.get("occurred_at_ms"),
190 record_kind: SessionRecordKind::parse(row.get::<&str, _>("record_kind"))
191 .ok_or_else(|| MemoryError::new("unknown record kind"))?,
192 trigger_id: row.get("trigger_id"),
193 idempotency_key: row.get("idempotency_key"),
194 record,
195 });
196 }
197 let next_offset = (records.len() == query.limit).then_some(query.offset + records.len());
198 Ok(RecordPage {
199 session_id: query.session_id,
200 records,
201 next_offset,
202 })
203 }
204
205 async fn find_outcome_by_idempotency_key(
206 &self,
207 session_id: &str,
208 idempotency_key: &str,
209 ) -> Result<Option<EngineOutcome>, MemoryError> {
210 let row = sqlx::query(
211 r#"
212 SELECT payload_json
213 FROM session_records
214 WHERE session_id = $1
215 AND idempotency_key = $2
216 AND record_kind = $3
217 ORDER BY sequence_no DESC
218 LIMIT 1
219 "#,
220 )
221 .bind(session_id)
222 .bind(idempotency_key)
223 .bind(SessionRecordKind::Outcome.as_str())
224 .fetch_optional(&self.pool)
225 .await
226 .map_err(|err| MemoryError::new(err.to_string()))?;
227
228 match row {
229 Some(row) => {
230 let value: serde_json::Value = row.get("payload_json");
231 let record: SessionRecord = serde_json::from_value(value)
232 .map_err(|err| MemoryError::new(err.to_string()))?;
233 match record {
234 SessionRecord::Outcome(outcome) => {
235 Ok(Some(EngineOutcome::from_record(outcome)))
236 }
237 _ => Ok(None),
238 }
239 }
240 None => Ok(None),
241 }
242 }
243
244 async fn find_pending_approval_by_resume_token(
245 &self,
246 session_id: &str,
247 resume_token: &str,
248 ) -> Result<Option<PendingApprovalRecord>, MemoryError> {
249 let snapshot = <Self as MemoryStore>::load_session(self, session_id).await?;
250 let mut pending = None::<PendingApprovalRecord>;
251 for record in snapshot.records {
252 match record {
253 SessionRecord::PendingApproval(record)
254 if record.resume_token.as_str() == resume_token =>
255 {
256 pending = Some(record);
257 }
258 SessionRecord::ApprovalResolution(record)
259 if record.resume_token.as_str() == resume_token =>
260 {
261 pending = None;
262 }
263 _ => {}
264 }
265 }
266 Ok(pending)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a well-formed Postgres URL is accepted by `connect_lazy`.
    ///
    /// `connect_lazy` does not await a connection, so this is expected to pass
    /// without a reachable server.
    #[tokio::test]
    async fn lazy_connection_validates_configuration_shape() {
        let database_url = "postgres://postgres:postgres@localhost/test";
        let _lazy_store = PgMemoryStore::connect_lazy(database_url).expect("lazy pool");
    }
}