eventuali_core/snapshot/sqlite_store.rs

1use super::{AggregateSnapshot, SnapshotStore, SnapshotConfig, SnapshotCompression};
2use crate::{AggregateId, AggregateVersion, Result, EventualiError};
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use serde_json;
6use sqlx::{sqlite::SqlitePool, Row};
7use uuid::Uuid;
8
9pub struct SqliteSnapshotStore {
10    pool: SqlitePool,
11    table_name: String,
12}
13
14impl SqliteSnapshotStore {
15    pub fn new(pool: SqlitePool, table_name: Option<String>) -> Self {
16        Self {
17            pool,
18            table_name: table_name.unwrap_or_else(|| "aggregate_snapshots".to_string()),
19        }
20    }
21
22    pub async fn initialize(&self) -> Result<()> {
23        let create_table = format!(
24            r#"
25            CREATE TABLE IF NOT EXISTS {} (
26                snapshot_id TEXT PRIMARY KEY,
27                aggregate_id TEXT NOT NULL,
28                aggregate_type TEXT NOT NULL,
29                aggregate_version INTEGER NOT NULL,
30                state_data BLOB NOT NULL,
31                compression TEXT NOT NULL,
32                metadata TEXT NOT NULL,
33                created_at TEXT NOT NULL,
34                UNIQUE(aggregate_id, aggregate_version)
35            );
36            
37            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_id ON {} (aggregate_id);
38            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_type ON {} (aggregate_type);
39            CREATE INDEX IF NOT EXISTS idx_{}_created_at ON {} (created_at);
40            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_version ON {} (aggregate_id, aggregate_version DESC);
41            "#,
42            self.table_name,
43            self.table_name, self.table_name,
44            self.table_name, self.table_name,
45            self.table_name, self.table_name,
46            self.table_name, self.table_name
47        );
48
49        sqlx::query(&create_table)
50            .execute(&self.pool)
51            .await?;
52
53        Ok(())
54    }
55}
56
57#[async_trait]
58impl SnapshotStore for SqliteSnapshotStore {
59    async fn save_snapshot(&self, snapshot: AggregateSnapshot) -> Result<()> {
60        let compression_str = match snapshot.compression {
61            SnapshotCompression::None => "none",
62            SnapshotCompression::Gzip => "gzip",
63            SnapshotCompression::Lz4 => "lz4",
64        };
65
66        let metadata_json = serde_json::to_string(&snapshot.metadata)?;
67
68        let query = format!(
69            r#"
70            INSERT INTO {} (
71                snapshot_id, aggregate_id, aggregate_type, aggregate_version,
72                state_data, compression, metadata, created_at
73            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
74            "#,
75            self.table_name
76        );
77
78        sqlx::query(&query)
79            .bind(snapshot.snapshot_id.to_string())
80            .bind(&snapshot.aggregate_id)
81            .bind(&snapshot.aggregate_type)
82            .bind(snapshot.aggregate_version)
83            .bind(&snapshot.state_data)
84            .bind(compression_str)
85            .bind(&metadata_json)
86            .bind(snapshot.created_at.to_rfc3339())
87            .execute(&self.pool)
88            .await
89            .map_err(|e| match e {
90                sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
91                    EventualiError::Configuration(format!(
92                        "Snapshot already exists for aggregate {} at version {}",
93                        snapshot.aggregate_id, snapshot.aggregate_version
94                    ))
95                }
96                _ => EventualiError::Database(e),
97            })?;
98
99        Ok(())
100    }
101
102    async fn load_latest_snapshot(&self, aggregate_id: &AggregateId) -> Result<Option<AggregateSnapshot>> {
103        let query = format!(
104            r#"
105            SELECT snapshot_id, aggregate_id, aggregate_type, aggregate_version,
106                   state_data, compression, metadata, created_at
107            FROM {}
108            WHERE aggregate_id = ?
109            ORDER BY aggregate_version DESC
110            LIMIT 1
111            "#,
112            self.table_name
113        );
114
115        let row = sqlx::query(&query)
116            .bind(aggregate_id)
117            .fetch_optional(&self.pool)
118            .await?;
119
120        if let Some(row) = row {
121            Ok(Some(self.row_to_snapshot(row)?))
122        } else {
123            Ok(None)
124        }
125    }
126
127    async fn load_snapshot(&self, snapshot_id: Uuid) -> Result<Option<AggregateSnapshot>> {
128        let query = format!(
129            r#"
130            SELECT snapshot_id, aggregate_id, aggregate_type, aggregate_version,
131                   state_data, compression, metadata, created_at
132            FROM {}
133            WHERE snapshot_id = ?
134            "#,
135            self.table_name
136        );
137
138        let row = sqlx::query(&query)
139            .bind(snapshot_id.to_string())
140            .fetch_optional(&self.pool)
141            .await?;
142
143        if let Some(row) = row {
144            Ok(Some(self.row_to_snapshot(row)?))
145        } else {
146            Ok(None)
147        }
148    }
149
150    async fn list_snapshots(&self, aggregate_id: &AggregateId) -> Result<Vec<AggregateSnapshot>> {
151        let query = format!(
152            r#"
153            SELECT snapshot_id, aggregate_id, aggregate_type, aggregate_version,
154                   state_data, compression, metadata, created_at
155            FROM {}
156            WHERE aggregate_id = ?
157            ORDER BY aggregate_version DESC
158            "#,
159            self.table_name
160        );
161
162        let rows = sqlx::query(&query)
163            .bind(aggregate_id)
164            .fetch_all(&self.pool)
165            .await?;
166
167        let mut snapshots = Vec::new();
168        for row in rows {
169            snapshots.push(self.row_to_snapshot(row)?);
170        }
171
172        Ok(snapshots)
173    }
174
175    async fn delete_snapshot(&self, snapshot_id: Uuid) -> Result<()> {
176        let query = format!("DELETE FROM {} WHERE snapshot_id = ?", self.table_name);
177
178        sqlx::query(&query)
179            .bind(snapshot_id.to_string())
180            .execute(&self.pool)
181            .await?;
182
183        Ok(())
184    }
185
186    async fn cleanup_old_snapshots(&self, config: &SnapshotConfig) -> Result<u64> {
187        if !config.auto_cleanup {
188            return Ok(0);
189        }
190
191        let cutoff_time = Utc::now() - chrono::Duration::hours(config.max_snapshot_age_hours as i64);
192
193        let query = format!(
194            "DELETE FROM {} WHERE created_at < ?",
195            self.table_name
196        );
197
198        let result = sqlx::query(&query)
199            .bind(cutoff_time.to_rfc3339())
200            .execute(&self.pool)
201            .await?;
202
203        Ok(result.rows_affected())
204    }
205
206    async fn should_take_snapshot(
207        &self,
208        aggregate_id: &AggregateId,
209        current_version: AggregateVersion,
210        config: &SnapshotConfig,
211    ) -> Result<bool> {
212        // Check if we should take a snapshot based on frequency
213        if current_version % config.snapshot_frequency != 0 {
214            return Ok(false);
215        }
216
217        // Check if we already have a snapshot at this version
218        let query = format!(
219            "SELECT COUNT(*) FROM {} WHERE aggregate_id = ? AND aggregate_version = ?",
220            self.table_name
221        );
222
223        let row = sqlx::query(&query)
224            .bind(aggregate_id)
225            .bind(current_version)
226            .fetch_one(&self.pool)
227            .await?;
228
229        let count: i64 = row.try_get(0)?;
230        Ok(count == 0)
231    }
232}
233
234impl SqliteSnapshotStore {
235    fn row_to_snapshot(&self, row: sqlx::sqlite::SqliteRow) -> Result<AggregateSnapshot> {
236        let snapshot_id_str: String = row.try_get("snapshot_id")?;
237        let snapshot_id = Uuid::parse_str(&snapshot_id_str)
238            .map_err(|_| EventualiError::InvalidEventData("Invalid snapshot UUID format".to_string()))?;
239
240        let aggregate_id: String = row.try_get("aggregate_id")?;
241        let aggregate_type: String = row.try_get("aggregate_type")?;
242        let aggregate_version: i64 = row.try_get("aggregate_version")?;
243        let state_data: Vec<u8> = row.try_get("state_data")?;
244        let compression_str: String = row.try_get("compression")?;
245        let metadata_json: String = row.try_get("metadata")?;
246        let created_at_str: String = row.try_get("created_at")?;
247
248        let compression = match compression_str.as_str() {
249            "none" => SnapshotCompression::None,
250            "gzip" => SnapshotCompression::Gzip,
251            "lz4" => SnapshotCompression::Lz4,
252            _ => return Err(EventualiError::InvalidEventData(format!(
253                "Unknown compression type: {compression_str}"
254            ))),
255        };
256
257        let metadata = serde_json::from_str(&metadata_json)?;
258
259        let created_at: DateTime<Utc> = DateTime::parse_from_rfc3339(&created_at_str)
260            .map_err(|_| EventualiError::InvalidEventData("Invalid timestamp format".to_string()))?
261            .with_timezone(&Utc);
262
263        Ok(AggregateSnapshot {
264            snapshot_id,
265            aggregate_id,
266            aggregate_type,
267            aggregate_version,
268            state_data,
269            compression,
270            metadata,
271            created_at,
272        })
273    }
274}