1use super::{AggregateSnapshot, SnapshotStore, SnapshotConfig, SnapshotCompression};
2use crate::{AggregateId, AggregateVersion, Result, EventualiError};
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use serde_json;
6use sqlx::{sqlite::SqlitePool, Row};
7use uuid::Uuid;
8
9pub struct SqliteSnapshotStore {
10 pool: SqlitePool,
11 table_name: String,
12}
13
14impl SqliteSnapshotStore {
15 pub fn new(pool: SqlitePool, table_name: Option<String>) -> Self {
16 Self {
17 pool,
18 table_name: table_name.unwrap_or_else(|| "aggregate_snapshots".to_string()),
19 }
20 }
21
22 pub async fn initialize(&self) -> Result<()> {
23 let create_table = format!(
24 r#"
25 CREATE TABLE IF NOT EXISTS {} (
26 snapshot_id TEXT PRIMARY KEY,
27 aggregate_id TEXT NOT NULL,
28 aggregate_type TEXT NOT NULL,
29 aggregate_version INTEGER NOT NULL,
30 state_data BLOB NOT NULL,
31 compression TEXT NOT NULL,
32 metadata TEXT NOT NULL,
33 created_at TEXT NOT NULL,
34 UNIQUE(aggregate_id, aggregate_version)
35 );
36
37 CREATE INDEX IF NOT EXISTS idx_{}_aggregate_id ON {} (aggregate_id);
38 CREATE INDEX IF NOT EXISTS idx_{}_aggregate_type ON {} (aggregate_type);
39 CREATE INDEX IF NOT EXISTS idx_{}_created_at ON {} (created_at);
40 CREATE INDEX IF NOT EXISTS idx_{}_aggregate_version ON {} (aggregate_id, aggregate_version DESC);
41 "#,
42 self.table_name,
43 self.table_name, self.table_name,
44 self.table_name, self.table_name,
45 self.table_name, self.table_name,
46 self.table_name, self.table_name
47 );
48
49 sqlx::query(&create_table)
50 .execute(&self.pool)
51 .await?;
52
53 Ok(())
54 }
55}
56
57#[async_trait]
58impl SnapshotStore for SqliteSnapshotStore {
59 async fn save_snapshot(&self, snapshot: AggregateSnapshot) -> Result<()> {
60 let compression_str = match snapshot.compression {
61 SnapshotCompression::None => "none",
62 SnapshotCompression::Gzip => "gzip",
63 SnapshotCompression::Lz4 => "lz4",
64 };
65
66 let metadata_json = serde_json::to_string(&snapshot.metadata)?;
67
68 let query = format!(
69 r#"
70 INSERT INTO {} (
71 snapshot_id, aggregate_id, aggregate_type, aggregate_version,
72 state_data, compression, metadata, created_at
73 ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
74 "#,
75 self.table_name
76 );
77
78 sqlx::query(&query)
79 .bind(snapshot.snapshot_id.to_string())
80 .bind(&snapshot.aggregate_id)
81 .bind(&snapshot.aggregate_type)
82 .bind(snapshot.aggregate_version)
83 .bind(&snapshot.state_data)
84 .bind(compression_str)
85 .bind(&metadata_json)
86 .bind(snapshot.created_at.to_rfc3339())
87 .execute(&self.pool)
88 .await
89 .map_err(|e| match e {
90 sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
91 EventualiError::Configuration(format!(
92 "Snapshot already exists for aggregate {} at version {}",
93 snapshot.aggregate_id, snapshot.aggregate_version
94 ))
95 }
96 _ => EventualiError::Database(e),
97 })?;
98
99 Ok(())
100 }
101
102 async fn load_latest_snapshot(&self, aggregate_id: &AggregateId) -> Result<Option<AggregateSnapshot>> {
103 let query = format!(
104 r#"
105 SELECT snapshot_id, aggregate_id, aggregate_type, aggregate_version,
106 state_data, compression, metadata, created_at
107 FROM {}
108 WHERE aggregate_id = ?
109 ORDER BY aggregate_version DESC
110 LIMIT 1
111 "#,
112 self.table_name
113 );
114
115 let row = sqlx::query(&query)
116 .bind(aggregate_id)
117 .fetch_optional(&self.pool)
118 .await?;
119
120 if let Some(row) = row {
121 Ok(Some(self.row_to_snapshot(row)?))
122 } else {
123 Ok(None)
124 }
125 }
126
127 async fn load_snapshot(&self, snapshot_id: Uuid) -> Result<Option<AggregateSnapshot>> {
128 let query = format!(
129 r#"
130 SELECT snapshot_id, aggregate_id, aggregate_type, aggregate_version,
131 state_data, compression, metadata, created_at
132 FROM {}
133 WHERE snapshot_id = ?
134 "#,
135 self.table_name
136 );
137
138 let row = sqlx::query(&query)
139 .bind(snapshot_id.to_string())
140 .fetch_optional(&self.pool)
141 .await?;
142
143 if let Some(row) = row {
144 Ok(Some(self.row_to_snapshot(row)?))
145 } else {
146 Ok(None)
147 }
148 }
149
150 async fn list_snapshots(&self, aggregate_id: &AggregateId) -> Result<Vec<AggregateSnapshot>> {
151 let query = format!(
152 r#"
153 SELECT snapshot_id, aggregate_id, aggregate_type, aggregate_version,
154 state_data, compression, metadata, created_at
155 FROM {}
156 WHERE aggregate_id = ?
157 ORDER BY aggregate_version DESC
158 "#,
159 self.table_name
160 );
161
162 let rows = sqlx::query(&query)
163 .bind(aggregate_id)
164 .fetch_all(&self.pool)
165 .await?;
166
167 let mut snapshots = Vec::new();
168 for row in rows {
169 snapshots.push(self.row_to_snapshot(row)?);
170 }
171
172 Ok(snapshots)
173 }
174
175 async fn delete_snapshot(&self, snapshot_id: Uuid) -> Result<()> {
176 let query = format!("DELETE FROM {} WHERE snapshot_id = ?", self.table_name);
177
178 sqlx::query(&query)
179 .bind(snapshot_id.to_string())
180 .execute(&self.pool)
181 .await?;
182
183 Ok(())
184 }
185
186 async fn cleanup_old_snapshots(&self, config: &SnapshotConfig) -> Result<u64> {
187 if !config.auto_cleanup {
188 return Ok(0);
189 }
190
191 let cutoff_time = Utc::now() - chrono::Duration::hours(config.max_snapshot_age_hours as i64);
192
193 let query = format!(
194 "DELETE FROM {} WHERE created_at < ?",
195 self.table_name
196 );
197
198 let result = sqlx::query(&query)
199 .bind(cutoff_time.to_rfc3339())
200 .execute(&self.pool)
201 .await?;
202
203 Ok(result.rows_affected())
204 }
205
206 async fn should_take_snapshot(
207 &self,
208 aggregate_id: &AggregateId,
209 current_version: AggregateVersion,
210 config: &SnapshotConfig,
211 ) -> Result<bool> {
212 if current_version % config.snapshot_frequency != 0 {
214 return Ok(false);
215 }
216
217 let query = format!(
219 "SELECT COUNT(*) FROM {} WHERE aggregate_id = ? AND aggregate_version = ?",
220 self.table_name
221 );
222
223 let row = sqlx::query(&query)
224 .bind(aggregate_id)
225 .bind(current_version)
226 .fetch_one(&self.pool)
227 .await?;
228
229 let count: i64 = row.try_get(0)?;
230 Ok(count == 0)
231 }
232}
233
234impl SqliteSnapshotStore {
235 fn row_to_snapshot(&self, row: sqlx::sqlite::SqliteRow) -> Result<AggregateSnapshot> {
236 let snapshot_id_str: String = row.try_get("snapshot_id")?;
237 let snapshot_id = Uuid::parse_str(&snapshot_id_str)
238 .map_err(|_| EventualiError::InvalidEventData("Invalid snapshot UUID format".to_string()))?;
239
240 let aggregate_id: String = row.try_get("aggregate_id")?;
241 let aggregate_type: String = row.try_get("aggregate_type")?;
242 let aggregate_version: i64 = row.try_get("aggregate_version")?;
243 let state_data: Vec<u8> = row.try_get("state_data")?;
244 let compression_str: String = row.try_get("compression")?;
245 let metadata_json: String = row.try_get("metadata")?;
246 let created_at_str: String = row.try_get("created_at")?;
247
248 let compression = match compression_str.as_str() {
249 "none" => SnapshotCompression::None,
250 "gzip" => SnapshotCompression::Gzip,
251 "lz4" => SnapshotCompression::Lz4,
252 _ => return Err(EventualiError::InvalidEventData(format!(
253 "Unknown compression type: {compression_str}"
254 ))),
255 };
256
257 let metadata = serde_json::from_str(&metadata_json)?;
258
259 let created_at: DateTime<Utc> = DateTime::parse_from_rfc3339(&created_at_str)
260 .map_err(|_| EventualiError::InvalidEventData("Invalid timestamp format".to_string()))?
261 .with_timezone(&Utc);
262
263 Ok(AggregateSnapshot {
264 snapshot_id,
265 aggregate_id,
266 aggregate_type,
267 aggregate_version,
268 state_data,
269 compression,
270 metadata,
271 created_at,
272 })
273 }
274}