Skip to main content

rain_engine_store_sqlite/
lib.rs

1//! SQLite ledger store for local RainEngine development and tests.
2
3use async_trait::async_trait;
4use rain_engine_core::{
5    EngineOutcome, MemoryError, MemoryStore, NewSessionRecord, PendingApprovalRecord, RecordPage,
6    RecordPageQuery, SessionListQuery, SessionRecord, SessionRecordKind, SessionSnapshot,
7    SessionSummary, StoredSessionRecord,
8};
9use serde_json::from_str;
10use sqlx::{Row, SqlitePool};
11
12#[derive(Clone)]
13pub struct SqliteMemoryStore {
14    pool: SqlitePool,
15}
16
17impl SqliteMemoryStore {
18    pub async fn connect(database_url: &str) -> Result<Self, MemoryError> {
19        let pool = SqlitePool::connect(database_url)
20            .await
21            .map_err(|err| MemoryError::new(err.to_string()))?;
22        sqlx::query(
23            r#"
24            CREATE TABLE IF NOT EXISTS session_records (
25                sequence_no INTEGER PRIMARY KEY AUTOINCREMENT,
26                session_id TEXT NOT NULL,
27                occurred_at_ms INTEGER NOT NULL,
28                record_kind TEXT NOT NULL,
29                trigger_id TEXT,
30                idempotency_key TEXT,
31                payload_json TEXT NOT NULL
32            )
33            "#,
34        )
35        .execute(&pool)
36        .await
37        .map_err(|err| MemoryError::new(err.to_string()))?;
38        sqlx::query(
39            "CREATE INDEX IF NOT EXISTS idx_session_records_session_id ON session_records(session_id)",
40        )
41        .execute(&pool)
42        .await
43        .map_err(|err| MemoryError::new(err.to_string()))?;
44
45        sqlx::query(
46            r#"
47            CREATE TABLE IF NOT EXISTS skills (
48                name TEXT PRIMARY KEY,
49                manifest_json TEXT NOT NULL,
50                wasm_bytes BLOB NOT NULL
51            )
52            "#,
53        )
54        .execute(&pool)
55        .await
56        .map_err(|err| MemoryError::new(err.to_string()))?;
57
58        Ok(Self { pool })
59    }
60
61    pub fn pool(&self) -> &SqlitePool {
62        &self.pool
63    }
64}
65
66#[async_trait]
67impl MemoryStore for SqliteMemoryStore {
68    async fn append_record(
69        &self,
70        record: NewSessionRecord,
71    ) -> Result<StoredSessionRecord, MemoryError> {
72        let payload_json = serde_json::to_string(&record.record)
73            .map_err(|err| MemoryError::new(err.to_string()))?;
74        let sequence_no: i64 = sqlx::query_scalar(
75            r#"
76            INSERT INTO session_records (session_id, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json)
77            VALUES (?, ?, ?, ?, ?, ?)
78            RETURNING sequence_no
79            "#,
80        )
81        .bind(&record.session_id)
82        .bind(record.occurred_at_ms)
83        .bind(record.record_kind.as_str())
84        .bind(&record.trigger_id)
85        .bind(&record.idempotency_key)
86        .bind(payload_json)
87        .fetch_one(&self.pool)
88        .await
89        .map_err(|err| MemoryError::new(err.to_string()))?;
90
91        Ok(StoredSessionRecord {
92            session_id: record.session_id,
93            sequence_no,
94            occurred_at_ms: record.occurred_at_ms,
95            record_kind: record.record_kind,
96            trigger_id: record.trigger_id,
97            idempotency_key: record.idempotency_key,
98            record: record.record,
99        })
100    }
101
102    async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError> {
103        let rows = sqlx::query(
104            r#"
105            SELECT sequence_no, payload_json
106            FROM session_records
107            WHERE session_id = ?
108            ORDER BY sequence_no ASC
109            "#,
110        )
111        .bind(session_id)
112        .fetch_all(&self.pool)
113        .await
114        .map_err(|err| MemoryError::new(err.to_string()))?;
115
116        let mut records = Vec::with_capacity(rows.len());
117        let mut last_sequence_no = None;
118        let mut latest_outcome = None;
119        for row in rows {
120            last_sequence_no = Some(row.get::<i64, _>("sequence_no"));
121            let record: SessionRecord = from_str(row.get::<&str, _>("payload_json"))
122                .map_err(|err| MemoryError::new(err.to_string()))?;
123            if let SessionRecord::Outcome(outcome) = &record {
124                latest_outcome = Some(outcome.clone());
125            }
126            records.push(record);
127        }
128        Ok(SessionSnapshot {
129            session_id: session_id.to_string(),
130            records,
131            last_sequence_no,
132            latest_outcome,
133        })
134    }
135
136    async fn list_sessions(
137        &self,
138        query: SessionListQuery,
139    ) -> Result<Vec<SessionSummary>, MemoryError> {
140        let rows = sqlx::query(
141            r#"
142            SELECT session_id, MIN(occurred_at_ms) AS first_recorded_at_ms, MAX(occurred_at_ms) AS last_recorded_at_ms, COUNT(*) AS record_count
143            FROM session_records
144            WHERE (? IS NULL OR occurred_at_ms >= ?)
145              AND (? IS NULL OR occurred_at_ms <= ?)
146            GROUP BY session_id
147            ORDER BY session_id
148            LIMIT ? OFFSET ?
149            "#,
150        )
151        .bind(query.since_ms)
152        .bind(query.since_ms)
153        .bind(query.until_ms)
154        .bind(query.until_ms)
155        .bind(query.limit as i64)
156        .bind(query.offset as i64)
157        .fetch_all(&self.pool)
158        .await
159        .map_err(|err| MemoryError::new(err.to_string()))?;
160
161        Ok(rows
162            .into_iter()
163            .map(|row| SessionSummary {
164                session_id: row.get("session_id"),
165                first_recorded_at_ms: row.get("first_recorded_at_ms"),
166                last_recorded_at_ms: row.get("last_recorded_at_ms"),
167                record_count: row.get::<i64, _>("record_count") as usize,
168            })
169            .collect())
170    }
171
172    async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError> {
173        let rows = sqlx::query(
174            r#"
175            SELECT sequence_no, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json
176            FROM session_records
177            WHERE session_id = ?
178              AND (? IS NULL OR occurred_at_ms >= ?)
179              AND (? IS NULL OR occurred_at_ms <= ?)
180            ORDER BY sequence_no
181            LIMIT ? OFFSET ?
182            "#,
183        )
184        .bind(&query.session_id)
185        .bind(query.since_ms)
186        .bind(query.since_ms)
187        .bind(query.until_ms)
188        .bind(query.until_ms)
189        .bind(query.limit as i64)
190        .bind(query.offset as i64)
191        .fetch_all(&self.pool)
192        .await
193        .map_err(|err| MemoryError::new(err.to_string()))?;
194
195        let mut records = Vec::with_capacity(rows.len());
196        for row in rows {
197            let record: SessionRecord = from_str(row.get::<&str, _>("payload_json"))
198                .map_err(|err| MemoryError::new(err.to_string()))?;
199            records.push(StoredSessionRecord {
200                session_id: query.session_id.clone(),
201                sequence_no: row.get("sequence_no"),
202                occurred_at_ms: row.get("occurred_at_ms"),
203                record_kind: SessionRecordKind::parse(row.get::<&str, _>("record_kind"))
204                    .ok_or_else(|| MemoryError::new("unknown record kind"))?,
205                trigger_id: row.get("trigger_id"),
206                idempotency_key: row.get("idempotency_key"),
207                record,
208            });
209        }
210        let next_offset = (records.len() == query.limit).then_some(query.offset + records.len());
211        Ok(RecordPage {
212            session_id: query.session_id,
213            records,
214            next_offset,
215        })
216    }
217
218    async fn find_outcome_by_idempotency_key(
219        &self,
220        session_id: &str,
221        idempotency_key: &str,
222    ) -> Result<Option<EngineOutcome>, MemoryError> {
223        let row = sqlx::query(
224            r#"
225            SELECT payload_json
226            FROM session_records
227            WHERE session_id = ?
228              AND idempotency_key = ?
229              AND record_kind = ?
230            ORDER BY sequence_no DESC
231            LIMIT 1
232            "#,
233        )
234        .bind(session_id)
235        .bind(idempotency_key)
236        .bind(SessionRecordKind::Outcome.as_str())
237        .fetch_optional(&self.pool)
238        .await
239        .map_err(|err| MemoryError::new(err.to_string()))?;
240
241        match row {
242            Some(row) => {
243                let record: SessionRecord = from_str(row.get::<&str, _>("payload_json"))
244                    .map_err(|err| MemoryError::new(err.to_string()))?;
245                match record {
246                    SessionRecord::Outcome(outcome) => {
247                        Ok(Some(EngineOutcome::from_record(outcome)))
248                    }
249                    _ => Ok(None),
250                }
251            }
252            None => Ok(None),
253        }
254    }
255
256    async fn find_pending_approval_by_resume_token(
257        &self,
258        session_id: &str,
259        resume_token: &str,
260    ) -> Result<Option<PendingApprovalRecord>, MemoryError> {
261        let snapshot = <Self as MemoryStore>::load_session(self, session_id).await?;
262        let mut pending = None::<PendingApprovalRecord>;
263        for record in snapshot.records {
264            match record {
265                SessionRecord::PendingApproval(record)
266                    if record.resume_token.as_str() == resume_token =>
267                {
268                    pending = Some(record);
269                }
270                SessionRecord::ApprovalResolution(record)
271                    if record.resume_token.as_str() == resume_token =>
272                {
273                    pending = None;
274                }
275                _ => {}
276            }
277        }
278        Ok(pending)
279    }
280}
281
282#[async_trait]
283impl rain_engine_core::SkillStore for SqliteMemoryStore {
284    async fn store_skill(
285        &self,
286        manifest: rain_engine_core::SkillManifest,
287        wasm_bytes: Vec<u8>,
288    ) -> Result<(), String> {
289        let manifest_json = serde_json::to_string(&manifest)
290            .map_err(|err| format!("Manifest serialization failed: {err}"))?;
291
292        sqlx::query(
293            r#"
294            INSERT INTO skills (name, manifest_json, wasm_bytes)
295            VALUES (?, ?, ?)
296            ON CONFLICT(name) DO UPDATE SET
297                manifest_json = excluded.manifest_json,
298                wasm_bytes = excluded.wasm_bytes
299            "#,
300        )
301        .bind(&manifest.name)
302        .bind(manifest_json)
303        .bind(wasm_bytes)
304        .execute(&self.pool)
305        .await
306        .map_err(|err| format!("Skill storage failed: {err}"))?;
307
308        Ok(())
309    }
310
311    async fn list_skills(&self) -> Result<Vec<(rain_engine_core::SkillManifest, Vec<u8>)>, String> {
312        let rows = sqlx::query(
313            r#"
314            SELECT manifest_json, wasm_bytes FROM skills
315            "#,
316        )
317        .fetch_all(&self.pool)
318        .await
319        .map_err(|err| format!("Skill retrieval failed: {err}"))?;
320
321        let mut skills = Vec::with_capacity(rows.len());
322        for row in rows {
323            let manifest_json: &str = row.get("manifest_json");
324            let manifest: rain_engine_core::SkillManifest = serde_json::from_str(manifest_json)
325                .map_err(|err| format!("Manifest deserialization failed: {err}"))?;
326            let wasm_bytes: Vec<u8> = row.get("wasm_bytes");
327            skills.push((manifest, wasm_bytes));
328        }
329        Ok(skills)
330    }
331
332    async fn remove_skill(&self, name: &str) -> Result<(), String> {
333        sqlx::query(
334            r#"
335            DELETE FROM skills WHERE name = ?
336            "#,
337        )
338        .bind(name)
339        .execute(&self.pool)
340        .await
341        .map_err(|err| format!("Skill removal failed: {err}"))?;
342
343        Ok(())
344    }
345}
346
347#[cfg(test)]
348mod tests {
349    use super::*;
350    use rain_engine_core::{
351        AdvanceRequest, AgentAction, AgentEngine, AgentTrigger, ContinueRequest, EngineOutcome,
352        MockLlmProvider, ProcessRequest,
353    };
354    use std::sync::Arc;
355
356    #[tokio::test]
357    async fn sqlite_store_replays_in_order() {
358        let store = Arc::new(
359            SqliteMemoryStore::connect("sqlite::memory:")
360                .await
361                .expect("sqlite store"),
362        );
363        let llm = Arc::new(MockLlmProvider::scripted(vec![AgentAction::Respond {
364            content: "ok".to_string(),
365        }]));
366        let engine = AgentEngine::new(llm, store.clone());
367
368        run_until_terminal(
369            &engine,
370            ProcessRequest::new(
371                "sqlite-session",
372                AgentTrigger::Message {
373                    user_id: "u".to_string(),
374                    content: "hello".to_string(),
375                    attachments: Vec::new(),
376                },
377            ),
378        )
379        .await
380        .expect("outcome");
381
382        let snapshot = store
383            .load_session("sqlite-session")
384            .await
385            .expect("snapshot");
386        assert!(matches!(
387            snapshot.records.first(),
388            Some(SessionRecord::Trigger(_))
389        ));
390        assert!(
391            snapshot
392                .records
393                .iter()
394                .any(|record| matches!(record, SessionRecord::Outcome(_)))
395        );
396    }
397
398    async fn run_until_terminal(
399        engine: &AgentEngine,
400        request: ProcessRequest,
401    ) -> Result<EngineOutcome, rain_engine_core::EngineError> {
402        let mut next = AdvanceRequest::Trigger(request.clone());
403        loop {
404            let result = engine.advance(next).await?;
405            if let Some(outcome) = result.outcome {
406                return Ok(outcome);
407            }
408            next = AdvanceRequest::Continue(ContinueRequest {
409                session_id: request.session_id.clone(),
410                granted_scopes: request.granted_scopes.clone(),
411                policy: request.policy.clone(),
412                provider: request.provider.clone(),
413                cancellation: request.cancellation.clone(),
414            });
415        }
416    }
417}