rain_engine_store_sqlite/
lib.rs1use async_trait::async_trait;
4use rain_engine_core::{
5 EngineOutcome, MemoryError, MemoryStore, NewSessionRecord, PendingApprovalRecord, RecordPage,
6 RecordPageQuery, SessionListQuery, SessionRecord, SessionRecordKind, SessionSnapshot,
7 SessionSummary, StoredSessionRecord,
8};
9use serde_json::from_str;
10use sqlx::{Row, SqlitePool};
11
12#[derive(Clone)]
13pub struct SqliteMemoryStore {
14 pool: SqlitePool,
15}
16
17impl SqliteMemoryStore {
18 pub async fn connect(database_url: &str) -> Result<Self, MemoryError> {
19 let pool = SqlitePool::connect(database_url)
20 .await
21 .map_err(|err| MemoryError::new(err.to_string()))?;
22 sqlx::query(
23 r#"
24 CREATE TABLE IF NOT EXISTS session_records (
25 sequence_no INTEGER PRIMARY KEY AUTOINCREMENT,
26 session_id TEXT NOT NULL,
27 occurred_at_ms INTEGER NOT NULL,
28 record_kind TEXT NOT NULL,
29 trigger_id TEXT,
30 idempotency_key TEXT,
31 payload_json TEXT NOT NULL
32 )
33 "#,
34 )
35 .execute(&pool)
36 .await
37 .map_err(|err| MemoryError::new(err.to_string()))?;
38 sqlx::query(
39 "CREATE INDEX IF NOT EXISTS idx_session_records_session_id ON session_records(session_id)",
40 )
41 .execute(&pool)
42 .await
43 .map_err(|err| MemoryError::new(err.to_string()))?;
44
45 sqlx::query(
46 r#"
47 CREATE TABLE IF NOT EXISTS skills (
48 name TEXT PRIMARY KEY,
49 manifest_json TEXT NOT NULL,
50 wasm_bytes BLOB NOT NULL
51 )
52 "#,
53 )
54 .execute(&pool)
55 .await
56 .map_err(|err| MemoryError::new(err.to_string()))?;
57
58 Ok(Self { pool })
59 }
60
61 pub fn pool(&self) -> &SqlitePool {
62 &self.pool
63 }
64}
65
66#[async_trait]
67impl MemoryStore for SqliteMemoryStore {
68 async fn append_record(
69 &self,
70 record: NewSessionRecord,
71 ) -> Result<StoredSessionRecord, MemoryError> {
72 let payload_json = serde_json::to_string(&record.record)
73 .map_err(|err| MemoryError::new(err.to_string()))?;
74 let sequence_no: i64 = sqlx::query_scalar(
75 r#"
76 INSERT INTO session_records (session_id, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json)
77 VALUES (?, ?, ?, ?, ?, ?)
78 RETURNING sequence_no
79 "#,
80 )
81 .bind(&record.session_id)
82 .bind(record.occurred_at_ms)
83 .bind(record.record_kind.as_str())
84 .bind(&record.trigger_id)
85 .bind(&record.idempotency_key)
86 .bind(payload_json)
87 .fetch_one(&self.pool)
88 .await
89 .map_err(|err| MemoryError::new(err.to_string()))?;
90
91 Ok(StoredSessionRecord {
92 session_id: record.session_id,
93 sequence_no,
94 occurred_at_ms: record.occurred_at_ms,
95 record_kind: record.record_kind,
96 trigger_id: record.trigger_id,
97 idempotency_key: record.idempotency_key,
98 record: record.record,
99 })
100 }
101
102 async fn load_session(&self, session_id: &str) -> Result<SessionSnapshot, MemoryError> {
103 let rows = sqlx::query(
104 r#"
105 SELECT sequence_no, payload_json
106 FROM session_records
107 WHERE session_id = ?
108 ORDER BY sequence_no ASC
109 "#,
110 )
111 .bind(session_id)
112 .fetch_all(&self.pool)
113 .await
114 .map_err(|err| MemoryError::new(err.to_string()))?;
115
116 let mut records = Vec::with_capacity(rows.len());
117 let mut last_sequence_no = None;
118 let mut latest_outcome = None;
119 for row in rows {
120 last_sequence_no = Some(row.get::<i64, _>("sequence_no"));
121 let record: SessionRecord = from_str(row.get::<&str, _>("payload_json"))
122 .map_err(|err| MemoryError::new(err.to_string()))?;
123 if let SessionRecord::Outcome(outcome) = &record {
124 latest_outcome = Some(outcome.clone());
125 }
126 records.push(record);
127 }
128 Ok(SessionSnapshot {
129 session_id: session_id.to_string(),
130 records,
131 last_sequence_no,
132 latest_outcome,
133 })
134 }
135
136 async fn list_sessions(
137 &self,
138 query: SessionListQuery,
139 ) -> Result<Vec<SessionSummary>, MemoryError> {
140 let rows = sqlx::query(
141 r#"
142 SELECT session_id, MIN(occurred_at_ms) AS first_recorded_at_ms, MAX(occurred_at_ms) AS last_recorded_at_ms, COUNT(*) AS record_count
143 FROM session_records
144 WHERE (? IS NULL OR occurred_at_ms >= ?)
145 AND (? IS NULL OR occurred_at_ms <= ?)
146 GROUP BY session_id
147 ORDER BY session_id
148 LIMIT ? OFFSET ?
149 "#,
150 )
151 .bind(query.since_ms)
152 .bind(query.since_ms)
153 .bind(query.until_ms)
154 .bind(query.until_ms)
155 .bind(query.limit as i64)
156 .bind(query.offset as i64)
157 .fetch_all(&self.pool)
158 .await
159 .map_err(|err| MemoryError::new(err.to_string()))?;
160
161 Ok(rows
162 .into_iter()
163 .map(|row| SessionSummary {
164 session_id: row.get("session_id"),
165 first_recorded_at_ms: row.get("first_recorded_at_ms"),
166 last_recorded_at_ms: row.get("last_recorded_at_ms"),
167 record_count: row.get::<i64, _>("record_count") as usize,
168 })
169 .collect())
170 }
171
172 async fn list_records(&self, query: RecordPageQuery) -> Result<RecordPage, MemoryError> {
173 let rows = sqlx::query(
174 r#"
175 SELECT sequence_no, occurred_at_ms, record_kind, trigger_id, idempotency_key, payload_json
176 FROM session_records
177 WHERE session_id = ?
178 AND (? IS NULL OR occurred_at_ms >= ?)
179 AND (? IS NULL OR occurred_at_ms <= ?)
180 ORDER BY sequence_no
181 LIMIT ? OFFSET ?
182 "#,
183 )
184 .bind(&query.session_id)
185 .bind(query.since_ms)
186 .bind(query.since_ms)
187 .bind(query.until_ms)
188 .bind(query.until_ms)
189 .bind(query.limit as i64)
190 .bind(query.offset as i64)
191 .fetch_all(&self.pool)
192 .await
193 .map_err(|err| MemoryError::new(err.to_string()))?;
194
195 let mut records = Vec::with_capacity(rows.len());
196 for row in rows {
197 let record: SessionRecord = from_str(row.get::<&str, _>("payload_json"))
198 .map_err(|err| MemoryError::new(err.to_string()))?;
199 records.push(StoredSessionRecord {
200 session_id: query.session_id.clone(),
201 sequence_no: row.get("sequence_no"),
202 occurred_at_ms: row.get("occurred_at_ms"),
203 record_kind: SessionRecordKind::parse(row.get::<&str, _>("record_kind"))
204 .ok_or_else(|| MemoryError::new("unknown record kind"))?,
205 trigger_id: row.get("trigger_id"),
206 idempotency_key: row.get("idempotency_key"),
207 record,
208 });
209 }
210 let next_offset = (records.len() == query.limit).then_some(query.offset + records.len());
211 Ok(RecordPage {
212 session_id: query.session_id,
213 records,
214 next_offset,
215 })
216 }
217
218 async fn find_outcome_by_idempotency_key(
219 &self,
220 session_id: &str,
221 idempotency_key: &str,
222 ) -> Result<Option<EngineOutcome>, MemoryError> {
223 let row = sqlx::query(
224 r#"
225 SELECT payload_json
226 FROM session_records
227 WHERE session_id = ?
228 AND idempotency_key = ?
229 AND record_kind = ?
230 ORDER BY sequence_no DESC
231 LIMIT 1
232 "#,
233 )
234 .bind(session_id)
235 .bind(idempotency_key)
236 .bind(SessionRecordKind::Outcome.as_str())
237 .fetch_optional(&self.pool)
238 .await
239 .map_err(|err| MemoryError::new(err.to_string()))?;
240
241 match row {
242 Some(row) => {
243 let record: SessionRecord = from_str(row.get::<&str, _>("payload_json"))
244 .map_err(|err| MemoryError::new(err.to_string()))?;
245 match record {
246 SessionRecord::Outcome(outcome) => {
247 Ok(Some(EngineOutcome::from_record(outcome)))
248 }
249 _ => Ok(None),
250 }
251 }
252 None => Ok(None),
253 }
254 }
255
256 async fn find_pending_approval_by_resume_token(
257 &self,
258 session_id: &str,
259 resume_token: &str,
260 ) -> Result<Option<PendingApprovalRecord>, MemoryError> {
261 let snapshot = <Self as MemoryStore>::load_session(self, session_id).await?;
262 let mut pending = None::<PendingApprovalRecord>;
263 for record in snapshot.records {
264 match record {
265 SessionRecord::PendingApproval(record)
266 if record.resume_token.as_str() == resume_token =>
267 {
268 pending = Some(record);
269 }
270 SessionRecord::ApprovalResolution(record)
271 if record.resume_token.as_str() == resume_token =>
272 {
273 pending = None;
274 }
275 _ => {}
276 }
277 }
278 Ok(pending)
279 }
280}
281
282#[async_trait]
283impl rain_engine_core::SkillStore for SqliteMemoryStore {
284 async fn store_skill(
285 &self,
286 manifest: rain_engine_core::SkillManifest,
287 wasm_bytes: Vec<u8>,
288 ) -> Result<(), String> {
289 let manifest_json = serde_json::to_string(&manifest)
290 .map_err(|err| format!("Manifest serialization failed: {err}"))?;
291
292 sqlx::query(
293 r#"
294 INSERT INTO skills (name, manifest_json, wasm_bytes)
295 VALUES (?, ?, ?)
296 ON CONFLICT(name) DO UPDATE SET
297 manifest_json = excluded.manifest_json,
298 wasm_bytes = excluded.wasm_bytes
299 "#,
300 )
301 .bind(&manifest.name)
302 .bind(manifest_json)
303 .bind(wasm_bytes)
304 .execute(&self.pool)
305 .await
306 .map_err(|err| format!("Skill storage failed: {err}"))?;
307
308 Ok(())
309 }
310
311 async fn list_skills(&self) -> Result<Vec<(rain_engine_core::SkillManifest, Vec<u8>)>, String> {
312 let rows = sqlx::query(
313 r#"
314 SELECT manifest_json, wasm_bytes FROM skills
315 "#,
316 )
317 .fetch_all(&self.pool)
318 .await
319 .map_err(|err| format!("Skill retrieval failed: {err}"))?;
320
321 let mut skills = Vec::with_capacity(rows.len());
322 for row in rows {
323 let manifest_json: &str = row.get("manifest_json");
324 let manifest: rain_engine_core::SkillManifest = serde_json::from_str(manifest_json)
325 .map_err(|err| format!("Manifest deserialization failed: {err}"))?;
326 let wasm_bytes: Vec<u8> = row.get("wasm_bytes");
327 skills.push((manifest, wasm_bytes));
328 }
329 Ok(skills)
330 }
331
332 async fn remove_skill(&self, name: &str) -> Result<(), String> {
333 sqlx::query(
334 r#"
335 DELETE FROM skills WHERE name = ?
336 "#,
337 )
338 .bind(name)
339 .execute(&self.pool)
340 .await
341 .map_err(|err| format!("Skill removal failed: {err}"))?;
342
343 Ok(())
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use rain_engine_core::{
        AdvanceRequest, AgentAction, AgentEngine, AgentTrigger, ContinueRequest, EngineOutcome,
        MockLlmProvider, ProcessRequest,
    };
    use std::sync::Arc;

    /// Upper bound on `advance` calls before a test gives up, so that an engine
    /// regression that never yields an outcome fails the test instead of hanging it.
    const MAX_ADVANCE_STEPS: usize = 64;

    #[tokio::test]
    async fn sqlite_store_replays_in_order() {
        let store = Arc::new(
            SqliteMemoryStore::connect("sqlite::memory:")
                .await
                .expect("sqlite store"),
        );
        let llm = Arc::new(MockLlmProvider::scripted(vec![AgentAction::Respond {
            content: "ok".to_string(),
        }]));
        let engine = AgentEngine::new(llm, store.clone());

        run_until_terminal(
            &engine,
            ProcessRequest::new(
                "sqlite-session",
                AgentTrigger::Message {
                    user_id: "u".to_string(),
                    content: "hello".to_string(),
                    attachments: Vec::new(),
                },
            ),
        )
        .await
        .expect("outcome");

        let snapshot = store
            .load_session("sqlite-session")
            .await
            .expect("snapshot");
        assert!(matches!(
            snapshot.records.first(),
            Some(SessionRecord::Trigger(_))
        ));
        assert!(
            snapshot
                .records
                .iter()
                .any(|record| matches!(record, SessionRecord::Outcome(_)))
        );
        // The snapshot metadata must agree with the replayed records.
        assert!(snapshot.last_sequence_no.is_some());
        assert!(snapshot.latest_outcome.is_some());
    }

    /// Drives `request` through `engine`: first as a trigger, then with
    /// `Continue` requests that reuse its settings, until an outcome appears.
    ///
    /// # Panics
    ///
    /// Panics if no outcome is produced within `MAX_ADVANCE_STEPS` calls.
    async fn run_until_terminal(
        engine: &AgentEngine,
        request: ProcessRequest,
    ) -> Result<EngineOutcome, rain_engine_core::EngineError> {
        let mut next = AdvanceRequest::Trigger(request.clone());
        for _ in 0..MAX_ADVANCE_STEPS {
            let result = engine.advance(next).await?;
            if let Some(outcome) = result.outcome {
                return Ok(outcome);
            }
            next = AdvanceRequest::Continue(ContinueRequest {
                session_id: request.session_id.clone(),
                granted_scopes: request.granted_scopes.clone(),
                policy: request.policy.clone(),
                provider: request.provider.clone(),
                cancellation: request.cancellation.clone(),
            });
        }
        panic!("engine produced no outcome within {MAX_ADVANCE_STEPS} advance steps");
    }
}