1use crate::{
2 store::{traits::EventStoreBackend, EventStoreConfig},
3 Event, EventData, EventMetadata, AggregateId, AggregateVersion, Result, EventualiError,
4};
5use async_trait::async_trait;
6use base64::{Engine as _, engine::general_purpose};
7use chrono::{DateTime, Utc};
8use serde_json;
9use sqlx::{postgres::PgPool, Row};
10use uuid::Uuid;
11
12pub struct PostgreSQLBackend {
13 pool: PgPool,
14 table_name: String,
15}
16
17impl PostgreSQLBackend {
18 pub async fn new(config: &EventStoreConfig) -> Result<Self> {
19 match config {
20 EventStoreConfig::PostgreSQL {
21 connection_string,
22 max_connections,
23 table_name,
24 } => {
25 let pool = sqlx::postgres::PgPoolOptions::new()
26 .max_connections(max_connections.unwrap_or(10))
27 .connect(connection_string)
28 .await?;
29
30 let table_name = table_name
31 .as_deref()
32 .unwrap_or("events")
33 .to_string();
34
35 let backend = Self { pool, table_name };
36 Ok(backend)
37 }
38 _ => Err(EventualiError::Configuration(
39 "Invalid configuration for PostgreSQL backend".to_string(),
40 )),
41 }
42 }
43
44 async fn create_tables(&self) -> Result<()> {
45 let create_events_table = format!(
46 r#"
47 CREATE TABLE IF NOT EXISTS {} (
48 id UUID PRIMARY KEY,
49 aggregate_id VARCHAR NOT NULL,
50 aggregate_type VARCHAR NOT NULL,
51 event_type VARCHAR NOT NULL,
52 event_version INTEGER NOT NULL,
53 aggregate_version BIGINT NOT NULL,
54 event_data JSONB NOT NULL,
55 event_data_type VARCHAR NOT NULL DEFAULT 'json',
56 metadata JSONB NOT NULL,
57 timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
58 UNIQUE(aggregate_id, aggregate_version)
59 );
60
61 CREATE INDEX IF NOT EXISTS idx_{}_aggregate_id ON {} (aggregate_id);
62 CREATE INDEX IF NOT EXISTS idx_{}_aggregate_type ON {} (aggregate_type);
63 CREATE INDEX IF NOT EXISTS idx_{}_timestamp ON {} (timestamp);
64 "#,
65 self.table_name,
66 self.table_name, self.table_name,
67 self.table_name, self.table_name,
68 self.table_name, self.table_name
69 );
70
71 sqlx::query(&create_events_table)
72 .execute(&self.pool)
73 .await?;
74
75 Ok(())
76 }
77}
78
79#[async_trait]
80impl EventStoreBackend for PostgreSQLBackend {
81 async fn initialize(&mut self) -> Result<()> {
82 self.create_tables().await
83 }
84
85 async fn save_events(&self, events: Vec<Event>) -> Result<()> {
86 if events.is_empty() {
87 return Ok(());
88 }
89
90 let mut tx = self.pool.begin().await?;
91
92 for event in events {
93 let (event_data_json, event_data_type) = match &event.data {
94 EventData::Json(value) => (value.clone(), "json"),
95 EventData::Protobuf(bytes) => {
96 let base64_data = general_purpose::STANDARD.encode(bytes);
98 (serde_json::json!({ "data": base64_data }), "protobuf")
99 }
100 };
101
102 let metadata_json = serde_json::to_value(&event.metadata)?;
103
104 let query = format!(
105 r#"
106 INSERT INTO {} (
107 id, aggregate_id, aggregate_type, event_type, event_version,
108 aggregate_version, event_data, event_data_type, metadata, timestamp
109 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
110 "#,
111 self.table_name
112 );
113
114 sqlx::query(&query)
115 .bind(event.id)
116 .bind(&event.aggregate_id)
117 .bind(&event.aggregate_type)
118 .bind(&event.event_type)
119 .bind(event.event_version)
120 .bind(event.aggregate_version)
121 .bind(&event_data_json)
122 .bind(event_data_type)
123 .bind(&metadata_json)
124 .bind(event.timestamp)
125 .execute(&mut *tx)
126 .await
127 .map_err(|e| match e {
128 sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
129 EventualiError::OptimisticConcurrency {
130 expected: event.aggregate_version,
131 actual: event.aggregate_version - 1,
132 }
133 }
134 _ => EventualiError::Database(e),
135 })?;
136 }
137
138 tx.commit().await?;
139 Ok(())
140 }
141
142 async fn load_events(
143 &self,
144 aggregate_id: &AggregateId,
145 from_version: Option<AggregateVersion>,
146 ) -> Result<Vec<Event>> {
147 let query = match from_version {
148 Some(_version) => format!(
149 r#"
150 SELECT id, aggregate_id, aggregate_type, event_type, event_version,
151 aggregate_version, event_data, event_data_type, metadata, timestamp
152 FROM {}
153 WHERE aggregate_id = $1 AND aggregate_version > $2
154 ORDER BY aggregate_version ASC
155 "#,
156 self.table_name
157 ),
158 None => format!(
159 r#"
160 SELECT id, aggregate_id, aggregate_type, event_type, event_version,
161 aggregate_version, event_data, event_data_type, metadata, timestamp
162 FROM {}
163 WHERE aggregate_id = $1
164 ORDER BY aggregate_version ASC
165 "#,
166 self.table_name
167 ),
168 };
169
170 let rows = if let Some(version) = from_version {
171 sqlx::query(&query)
172 .bind(aggregate_id)
173 .bind(version)
174 .fetch_all(&self.pool)
175 .await?
176 } else {
177 sqlx::query(&query)
178 .bind(aggregate_id)
179 .fetch_all(&self.pool)
180 .await?
181 };
182
183 let mut events = Vec::new();
184 for row in rows {
185 let event = self.row_to_event(row)?;
186 events.push(event);
187 }
188
189 Ok(events)
190 }
191
192 async fn load_events_by_type(
193 &self,
194 aggregate_type: &str,
195 from_version: Option<AggregateVersion>,
196 ) -> Result<Vec<Event>> {
197 let query = match from_version {
198 Some(_version) => format!(
199 r#"
200 SELECT id, aggregate_id, aggregate_type, event_type, event_version,
201 aggregate_version, event_data, event_data_type, metadata, timestamp
202 FROM {}
203 WHERE aggregate_type = $1 AND aggregate_version > $2
204 ORDER BY timestamp ASC
205 "#,
206 self.table_name
207 ),
208 None => format!(
209 r#"
210 SELECT id, aggregate_id, aggregate_type, event_type, event_version,
211 aggregate_version, event_data, event_data_type, metadata, timestamp
212 FROM {}
213 WHERE aggregate_type = $1
214 ORDER BY timestamp ASC
215 "#,
216 self.table_name
217 ),
218 };
219
220 let rows = if let Some(version) = from_version {
221 sqlx::query(&query)
222 .bind(aggregate_type)
223 .bind(version)
224 .fetch_all(&self.pool)
225 .await?
226 } else {
227 sqlx::query(&query)
228 .bind(aggregate_type)
229 .fetch_all(&self.pool)
230 .await?
231 };
232
233 let mut events = Vec::new();
234 for row in rows {
235 let event = self.row_to_event(row)?;
236 events.push(event);
237 }
238
239 Ok(events)
240 }
241
242 async fn get_aggregate_version(&self, aggregate_id: &AggregateId) -> Result<Option<AggregateVersion>> {
243 let query = format!(
244 "SELECT MAX(aggregate_version) FROM {} WHERE aggregate_id = $1",
245 self.table_name
246 );
247
248 let row = sqlx::query(&query)
249 .bind(aggregate_id)
250 .fetch_optional(&self.pool)
251 .await?;
252
253 if let Some(row) = row {
254 let version: Option<i64> = row.try_get(0)?;
255 Ok(version)
256 } else {
257 Ok(None)
258 }
259 }
260}
261
262impl PostgreSQLBackend {
263 fn row_to_event(&self, row: sqlx::postgres::PgRow) -> Result<Event> {
264 let id: Uuid = row.try_get("id")?;
265 let aggregate_id: String = row.try_get("aggregate_id")?;
266 let aggregate_type: String = row.try_get("aggregate_type")?;
267 let event_type: String = row.try_get("event_type")?;
268 let event_version: i32 = row.try_get("event_version")?;
269 let aggregate_version: i64 = row.try_get("aggregate_version")?;
270 let event_data_json: serde_json::Value = row.try_get("event_data")?;
271 let event_data_type: String = row.try_get("event_data_type")?;
272 let metadata_json: serde_json::Value = row.try_get("metadata")?;
273 let timestamp: DateTime<Utc> = row.try_get("timestamp")?;
274
275 let event_data = match event_data_type.as_str() {
276 "json" => EventData::Json(event_data_json),
277 "protobuf" => {
278 let base64_data = event_data_json
279 .get("data")
280 .and_then(|v| v.as_str())
281 .ok_or_else(|| {
282 EventualiError::InvalidEventData("Invalid protobuf data format".to_string())
283 })?;
284 let bytes = general_purpose::STANDARD.decode(base64_data).map_err(|_| {
285 EventualiError::InvalidEventData("Invalid base64 protobuf data".to_string())
286 })?;
287 EventData::Protobuf(bytes)
288 }
289 _ => {
290 return Err(EventualiError::InvalidEventData(format!(
291 "Unknown event data type: {event_data_type}"
292 )))
293 }
294 };
295
296 let metadata: EventMetadata = serde_json::from_value(metadata_json)?;
297
298 Ok(Event {
299 id,
300 aggregate_id,
301 aggregate_type,
302 event_type,
303 event_version,
304 aggregate_version,
305 data: event_data,
306 metadata,
307 timestamp,
308 })
309 }
310}