eventuali_core/store/
postgres.rs

1use crate::{
2    store::{traits::EventStoreBackend, EventStoreConfig},
3    Event, EventData, EventMetadata, AggregateId, AggregateVersion, Result, EventualiError,
4};
5use async_trait::async_trait;
6use base64::{Engine as _, engine::general_purpose};
7use chrono::{DateTime, Utc};
8use serde_json;
9use sqlx::{postgres::PgPool, Row};
10use uuid::Uuid;
11
12pub struct PostgreSQLBackend {
13    pool: PgPool,
14    table_name: String,
15}
16
17impl PostgreSQLBackend {
18    pub async fn new(config: &EventStoreConfig) -> Result<Self> {
19        match config {
20            EventStoreConfig::PostgreSQL {
21                connection_string,
22                max_connections,
23                table_name,
24            } => {
25                let pool = sqlx::postgres::PgPoolOptions::new()
26                    .max_connections(max_connections.unwrap_or(10))
27                    .connect(connection_string)
28                    .await?;
29
30                let table_name = table_name
31                    .as_deref()
32                    .unwrap_or("events")
33                    .to_string();
34
35                let backend = Self { pool, table_name };
36                Ok(backend)
37            }
38            _ => Err(EventualiError::Configuration(
39                "Invalid configuration for PostgreSQL backend".to_string(),
40            )),
41        }
42    }
43
44    async fn create_tables(&self) -> Result<()> {
45        let create_events_table = format!(
46            r#"
47            CREATE TABLE IF NOT EXISTS {} (
48                id UUID PRIMARY KEY,
49                aggregate_id VARCHAR NOT NULL,
50                aggregate_type VARCHAR NOT NULL,
51                event_type VARCHAR NOT NULL,
52                event_version INTEGER NOT NULL,
53                aggregate_version BIGINT NOT NULL,
54                event_data JSONB NOT NULL,
55                event_data_type VARCHAR NOT NULL DEFAULT 'json',
56                metadata JSONB NOT NULL,
57                timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
58                UNIQUE(aggregate_id, aggregate_version)
59            );
60            
61            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_id ON {} (aggregate_id);
62            CREATE INDEX IF NOT EXISTS idx_{}_aggregate_type ON {} (aggregate_type);
63            CREATE INDEX IF NOT EXISTS idx_{}_timestamp ON {} (timestamp);
64            "#,
65            self.table_name, 
66            self.table_name, self.table_name,
67            self.table_name, self.table_name,
68            self.table_name, self.table_name
69        );
70
71        sqlx::query(&create_events_table)
72            .execute(&self.pool)
73            .await?;
74
75        Ok(())
76    }
77}
78
79#[async_trait]
80impl EventStoreBackend for PostgreSQLBackend {
81    async fn initialize(&mut self) -> Result<()> {
82        self.create_tables().await
83    }
84
85    async fn save_events(&self, events: Vec<Event>) -> Result<()> {
86        if events.is_empty() {
87            return Ok(());
88        }
89
90        let mut tx = self.pool.begin().await?;
91
92        for event in events {
93            let (event_data_json, event_data_type) = match &event.data {
94                EventData::Json(value) => (value.clone(), "json"),
95                EventData::Protobuf(bytes) => {
96                    // Store protobuf as base64 encoded JSON for PostgreSQL
97                    let base64_data = general_purpose::STANDARD.encode(bytes);
98                    (serde_json::json!({ "data": base64_data }), "protobuf")
99                }
100            };
101
102            let metadata_json = serde_json::to_value(&event.metadata)?;
103
104            let query = format!(
105                r#"
106                INSERT INTO {} (
107                    id, aggregate_id, aggregate_type, event_type, event_version,
108                    aggregate_version, event_data, event_data_type, metadata, timestamp
109                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
110                "#,
111                self.table_name
112            );
113
114            sqlx::query(&query)
115                .bind(event.id)
116                .bind(&event.aggregate_id)
117                .bind(&event.aggregate_type)
118                .bind(&event.event_type)
119                .bind(event.event_version)
120                .bind(event.aggregate_version)
121                .bind(&event_data_json)
122                .bind(event_data_type)
123                .bind(&metadata_json)
124                .bind(event.timestamp)
125                .execute(&mut *tx)
126                .await
127                .map_err(|e| match e {
128                    sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
129                        EventualiError::OptimisticConcurrency {
130                            expected: event.aggregate_version,
131                            actual: event.aggregate_version - 1,
132                        }
133                    }
134                    _ => EventualiError::Database(e),
135                })?;
136        }
137
138        tx.commit().await?;
139        Ok(())
140    }
141
142    async fn load_events(
143        &self,
144        aggregate_id: &AggregateId,
145        from_version: Option<AggregateVersion>,
146    ) -> Result<Vec<Event>> {
147        let query = match from_version {
148            Some(_version) => format!(
149                r#"
150                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
151                       aggregate_version, event_data, event_data_type, metadata, timestamp
152                FROM {} 
153                WHERE aggregate_id = $1 AND aggregate_version > $2
154                ORDER BY aggregate_version ASC
155                "#,
156                self.table_name
157            ),
158            None => format!(
159                r#"
160                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
161                       aggregate_version, event_data, event_data_type, metadata, timestamp
162                FROM {} 
163                WHERE aggregate_id = $1
164                ORDER BY aggregate_version ASC
165                "#,
166                self.table_name
167            ),
168        };
169
170        let rows = if let Some(version) = from_version {
171            sqlx::query(&query)
172                .bind(aggregate_id)
173                .bind(version)
174                .fetch_all(&self.pool)
175                .await?
176        } else {
177            sqlx::query(&query)
178                .bind(aggregate_id)
179                .fetch_all(&self.pool)
180                .await?
181        };
182
183        let mut events = Vec::new();
184        for row in rows {
185            let event = self.row_to_event(row)?;
186            events.push(event);
187        }
188
189        Ok(events)
190    }
191
192    async fn load_events_by_type(
193        &self,
194        aggregate_type: &str,
195        from_version: Option<AggregateVersion>,
196    ) -> Result<Vec<Event>> {
197        let query = match from_version {
198            Some(_version) => format!(
199                r#"
200                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
201                       aggregate_version, event_data, event_data_type, metadata, timestamp
202                FROM {} 
203                WHERE aggregate_type = $1 AND aggregate_version > $2
204                ORDER BY timestamp ASC
205                "#,
206                self.table_name
207            ),
208            None => format!(
209                r#"
210                SELECT id, aggregate_id, aggregate_type, event_type, event_version,
211                       aggregate_version, event_data, event_data_type, metadata, timestamp
212                FROM {} 
213                WHERE aggregate_type = $1
214                ORDER BY timestamp ASC
215                "#,
216                self.table_name
217            ),
218        };
219
220        let rows = if let Some(version) = from_version {
221            sqlx::query(&query)
222                .bind(aggregate_type)
223                .bind(version)
224                .fetch_all(&self.pool)
225                .await?
226        } else {
227            sqlx::query(&query)
228                .bind(aggregate_type)
229                .fetch_all(&self.pool)
230                .await?
231        };
232
233        let mut events = Vec::new();
234        for row in rows {
235            let event = self.row_to_event(row)?;
236            events.push(event);
237        }
238
239        Ok(events)
240    }
241
242    async fn get_aggregate_version(&self, aggregate_id: &AggregateId) -> Result<Option<AggregateVersion>> {
243        let query = format!(
244            "SELECT MAX(aggregate_version) FROM {} WHERE aggregate_id = $1",
245            self.table_name
246        );
247
248        let row = sqlx::query(&query)
249            .bind(aggregate_id)
250            .fetch_optional(&self.pool)
251            .await?;
252
253        if let Some(row) = row {
254            let version: Option<i64> = row.try_get(0)?;
255            Ok(version)
256        } else {
257            Ok(None)
258        }
259    }
260}
261
262impl PostgreSQLBackend {
263    fn row_to_event(&self, row: sqlx::postgres::PgRow) -> Result<Event> {
264        let id: Uuid = row.try_get("id")?;
265        let aggregate_id: String = row.try_get("aggregate_id")?;
266        let aggregate_type: String = row.try_get("aggregate_type")?;
267        let event_type: String = row.try_get("event_type")?;
268        let event_version: i32 = row.try_get("event_version")?;
269        let aggregate_version: i64 = row.try_get("aggregate_version")?;
270        let event_data_json: serde_json::Value = row.try_get("event_data")?;
271        let event_data_type: String = row.try_get("event_data_type")?;
272        let metadata_json: serde_json::Value = row.try_get("metadata")?;
273        let timestamp: DateTime<Utc> = row.try_get("timestamp")?;
274
275        let event_data = match event_data_type.as_str() {
276            "json" => EventData::Json(event_data_json),
277            "protobuf" => {
278                let base64_data = event_data_json
279                    .get("data")
280                    .and_then(|v| v.as_str())
281                    .ok_or_else(|| {
282                        EventualiError::InvalidEventData("Invalid protobuf data format".to_string())
283                    })?;
284                let bytes = general_purpose::STANDARD.decode(base64_data).map_err(|_| {
285                    EventualiError::InvalidEventData("Invalid base64 protobuf data".to_string())
286                })?;
287                EventData::Protobuf(bytes)
288            }
289            _ => {
290                return Err(EventualiError::InvalidEventData(format!(
291                    "Unknown event data type: {event_data_type}"
292                )))
293            }
294        };
295
296        let metadata: EventMetadata = serde_json::from_value(metadata_json)?;
297
298        Ok(Event {
299            id,
300            aggregate_id,
301            aggregate_type,
302            event_type,
303            event_version,
304            aggregate_version,
305            data: event_data,
306            metadata,
307            timestamp,
308        })
309    }
310}