eventuali_core/store/
sqlite.rs

1use crate::{
2    store::{traits::EventStoreBackend, EventStoreConfig},
3    Event, EventData, EventMetadata, AggregateId, AggregateVersion, Result, EventualiError,
4};
5use async_trait::async_trait;
6use base64::{Engine as _, engine::general_purpose};
7use chrono::{DateTime, Utc};
8use serde_json;
9use sqlx::{sqlite::{SqlitePool, SqliteConnectOptions, SqliteJournalMode}, Row};
10use std::str::FromStr;
11use uuid::Uuid;
12
13pub struct SQLiteBackend {
14    pool: SqlitePool,
15    table_name: String,
16}
17
18impl SQLiteBackend {
19    pub async fn new(config: &EventStoreConfig) -> Result<Self> {
20        match config {
21            EventStoreConfig::SQLite {
22                database_path,
23                max_connections,
24                table_name,
25            } => {
26                let pool = if database_path == ":memory:" {
27                    // For in-memory databases, use the simple connection string
28                    sqlx::sqlite::SqlitePoolOptions::new()
29                        .max_connections(max_connections.unwrap_or(10))
30                        .connect("sqlite://:memory:")
31                        .await?
32                } else {
33                    // For file-based SQLite, use SqliteConnectOptions with create_if_missing
34                    let path = std::path::Path::new(database_path);
35                    let full_path = if path.is_absolute() {
36                        database_path.clone()
37                    } else {
38                        // Convert relative path to absolute path
39                        std::env::current_dir()
40                            .map_err(|e| EventualiError::Configuration(format!("Cannot get current directory: {e}")))?
41                            .join(path)
42                            .to_string_lossy()
43                            .to_string()
44                    };
45                    
46                    // Create parent directories if they don't exist
47                    let db_path = std::path::Path::new(&full_path);
48                    if let Some(parent) = db_path.parent() {
49                        if !parent.exists() {
50                            std::fs::create_dir_all(parent)
51                                .map_err(|e| EventualiError::Configuration(format!("Cannot create directory {}: {}", parent.display(), e)))?;
52                        }
53                    }
54                    
55                    
56                    // Use SqliteConnectOptions for proper file database creation
57                    let connect_options = SqliteConnectOptions::from_str(&full_path)
58                        .map_err(|e| EventualiError::Configuration(format!("Invalid SQLite path {full_path}: {e}")))?
59                        .create_if_missing(true)
60                        .journal_mode(SqliteJournalMode::Wal);
61                    
62                    sqlx::sqlite::SqlitePoolOptions::new()
63                        .max_connections(max_connections.unwrap_or(10))
64                        .connect_with(connect_options)
65                        .await?
66                };
67
68                let table_name = table_name
69                    .as_deref()
70                    .unwrap_or("events")
71                    .to_string();
72
73                let backend = Self { pool, table_name };
74                Ok(backend)
75            }
76            _ => Err(EventualiError::Configuration(
77                "Invalid configuration for SQLite backend".to_string(),
78            )),
79        }
80    }
81
82    async fn create_tables(&self) -> Result<()> {
83        // Enable foreign keys (WAL mode is set in connection options)
84        sqlx::query("PRAGMA foreign_keys = ON")
85            .execute(&self.pool)
86            .await?;
87
88        let create_events_table = format!(
89            r#"
90            CREATE TABLE IF NOT EXISTS {} (
91                id TEXT PRIMARY KEY,
92                aggregate_id TEXT NOT NULL,
93                aggregate_type TEXT NOT NULL,
94                event_type TEXT NOT NULL,
95                event_version INTEGER NOT NULL,
96                aggregate_version INTEGER NOT NULL,
97                event_data TEXT NOT NULL,
98                event_data_type TEXT NOT NULL DEFAULT 'json',
99                metadata TEXT NOT NULL,
100                timestamp TEXT NOT NULL,
101                UNIQUE(aggregate_id, aggregate_version)
102            );
103            
104            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_id ON {} (aggregate_id);
105            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_type ON {} (aggregate_type);
106            CREATE INDEX IF NOT EXISTS idx_{}_timestamp ON {} (timestamp);
107            "#,
108            self.table_name,
109            self.table_name, self.table_name,
110            self.table_name, self.table_name,
111            self.table_name, self.table_name
112        );
113
114        sqlx::query(&create_events_table)
115            .execute(&self.pool)
116            .await?;
117
118        Ok(())
119    }
120}
121
122#[async_trait]
123impl EventStoreBackend for SQLiteBackend {
124    async fn initialize(&mut self) -> Result<()> {
125        self.create_tables().await
126    }
127
128    async fn save_events(&self, events: Vec<Event>) -> Result<()> {
129        if events.is_empty() {
130            return Ok(());
131        }
132
133        let mut tx = self.pool.begin().await?;
134
135        for event in events {
136            let (event_data_text, event_data_type) = match &event.data {
137                EventData::Json(value) => (serde_json::to_string(value)?, "json"),
138                EventData::Protobuf(bytes) => {
139                    // Store protobuf as base64 for SQLite
140                    let base64_data = general_purpose::STANDARD.encode(bytes);
141                    (base64_data, "protobuf")
142                }
143            };
144
145            let metadata_text = serde_json::to_string(&event.metadata)?;
146            let timestamp_text = event.timestamp.to_rfc3339();
147
148            let query = format!(
149                r#"
150                INSERT INTO {} (
151                    id, aggregate_id, aggregate_type, event_type, event_version,
152                    aggregate_version, event_data, event_data_type, metadata, timestamp
153                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
154                "#,
155                self.table_name
156            );
157
158            sqlx::query(&query)
159                .bind(event.id.to_string())
160                .bind(&event.aggregate_id)
161                .bind(&event.aggregate_type)
162                .bind(&event.event_type)
163                .bind(event.event_version)
164                .bind(event.aggregate_version)
165                .bind(&event_data_text)
166                .bind(event_data_type)
167                .bind(&metadata_text)
168                .bind(&timestamp_text)
169                .execute(&mut *tx)
170                .await
171                .map_err(|e| match e {
172                    sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
173                        EventualiError::OptimisticConcurrency {
174                            expected: event.aggregate_version,
175                            actual: event.aggregate_version - 1,
176                        }
177                    }
178                    _ => EventualiError::Database(e),
179                })?;
180        }
181
182        tx.commit().await?;
183        Ok(())
184    }
185
186    async fn load_events(
187        &self,
188        aggregate_id: &AggregateId,
189        from_version: Option<AggregateVersion>,
190    ) -> Result<Vec<Event>> {
191        let query = match from_version {
192            Some(_version) => format!(
193                r#"
194                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
195                       aggregate_version, event_data, event_data_type, metadata, timestamp
196                FROM {} 
197                WHERE aggregate_id = ? AND aggregate_version > ?
198                ORDER BY aggregate_version ASC
199                "#,
200                self.table_name
201            ),
202            None => format!(
203                r#"
204                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
205                       aggregate_version, event_data, event_data_type, metadata, timestamp
206                FROM {} 
207                WHERE aggregate_id = ?
208                ORDER BY aggregate_version ASC
209                "#,
210                self.table_name
211            ),
212        };
213
214        let rows = if let Some(version) = from_version {
215            sqlx::query(&query)
216                .bind(aggregate_id)
217                .bind(version)
218                .fetch_all(&self.pool)
219                .await?
220        } else {
221            sqlx::query(&query)
222                .bind(aggregate_id)
223                .fetch_all(&self.pool)
224                .await?
225        };
226
227        let mut events = Vec::new();
228        for row in rows {
229            let event = self.row_to_event(row)?;
230            events.push(event);
231        }
232
233        Ok(events)
234    }
235
236    async fn load_events_by_type(
237        &self,
238        aggregate_type: &str,
239        from_version: Option<AggregateVersion>,
240    ) -> Result<Vec<Event>> {
241        let query = match from_version {
242            Some(_version) => format!(
243                r#"
244                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
245                       aggregate_version, event_data, event_data_type, metadata, timestamp
246                FROM {} 
247                WHERE aggregate_type = ? AND aggregate_version > ?
248                ORDER BY timestamp ASC
249                "#,
250                self.table_name
251            ),
252            None => format!(
253                r#"
254                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
255                       aggregate_version, event_data, event_data_type, metadata, timestamp
256                FROM {} 
257                WHERE aggregate_type = ?
258                ORDER BY timestamp ASC
259                "#,
260                self.table_name
261            ),
262        };
263
264        let rows = if let Some(version) = from_version {
265            sqlx::query(&query)
266                .bind(aggregate_type)
267                .bind(version)
268                .fetch_all(&self.pool)
269                .await?
270        } else {
271            sqlx::query(&query)
272                .bind(aggregate_type)
273                .fetch_all(&self.pool)
274                .await?
275        };
276
277        let mut events = Vec::new();
278        for row in rows {
279            let event = self.row_to_event(row)?;
280            events.push(event);
281        }
282
283        Ok(events)
284    }
285
286    async fn get_aggregate_version(&self, aggregate_id: &AggregateId) -> Result<Option<AggregateVersion>> {
287        let query = format!(
288            "SELECT MAX(aggregate_version) FROM {} WHERE aggregate_id = ?",
289            self.table_name
290        );
291
292        let row = sqlx::query(&query)
293            .bind(aggregate_id)
294            .fetch_optional(&self.pool)
295            .await?;
296
297        if let Some(row) = row {
298            let version: Option<i64> = row.try_get(0)?;
299            Ok(version)
300        } else {
301            Ok(None)
302        }
303    }
304}
305
306impl SQLiteBackend {
307    fn row_to_event(&self, row: sqlx::sqlite::SqliteRow) -> Result<Event> {
308        let id_str: String = row.try_get("id")?;
309        let id = Uuid::parse_str(&id_str)
310            .map_err(|_| EventualiError::InvalidEventData("Invalid UUID format".to_string()))?;
311        
312        let aggregate_id: String = row.try_get("aggregate_id")?;
313        let aggregate_type: String = row.try_get("aggregate_type")?;
314        let event_type: String = row.try_get("event_type")?;
315        let event_version: i32 = row.try_get("event_version")?;
316        let aggregate_version: i64 = row.try_get("aggregate_version")?;
317        let event_data_text: String = row.try_get("event_data")?;
318        let event_data_type: String = row.try_get("event_data_type")?;
319        let metadata_text: String = row.try_get("metadata")?;
320        let timestamp_text: String = row.try_get("timestamp")?;
321
322        let event_data = match event_data_type.as_str() {
323            "json" => {
324                let json_value: serde_json::Value = serde_json::from_str(&event_data_text)?;
325                EventData::Json(json_value)
326            }
327            "protobuf" => {
328                let bytes = general_purpose::STANDARD.decode(&event_data_text).map_err(|_| {
329                    EventualiError::InvalidEventData("Invalid base64 protobuf data".to_string())
330                })?;
331                EventData::Protobuf(bytes)
332            }
333            _ => {
334                return Err(EventualiError::InvalidEventData(format!(
335                    "Unknown event data type: {event_data_type}"
336                )))
337            }
338        };
339
340        let metadata: EventMetadata = serde_json::from_str(&metadata_text)?;
341        let timestamp: DateTime<Utc> = DateTime::parse_from_rfc3339(&timestamp_text)
342            .map_err(|_| EventualiError::InvalidEventData("Invalid timestamp format".to_string()))?
343            .with_timezone(&Utc);
344
345        Ok(Event {
346            id,
347            aggregate_id,
348            aggregate_type,
349            event_type,
350            event_version,
351            aggregate_version,
352            data: event_data,
353            metadata,
354            timestamp,
355        })
356    }
357}