1use std::time::Duration;
2
3use eventcore_types::{
4 Event, EventFilter, EventPage, EventReader, EventStore, EventStoreError, EventStreamReader,
5 EventStreamSlice, Operation, StreamId, StreamPosition, StreamWriteEntry, StreamWrites,
6};
7use nutype::nutype;
8use serde_json::{Value, json};
9use sqlx::types::Json;
10use sqlx::{Pool, Postgres, Row, postgres::PgPoolOptions, query};
11use thiserror::Error;
12use tracing::{error, info, instrument, warn};
13use uuid::Uuid;
14
/// Errors that can occur while constructing a [`PostgresEventStore`].
#[derive(Debug, Error)]
pub enum PostgresEventStoreError {
    /// The sqlx connection pool could not be established; the underlying
    /// driver error is retained as the `source` for diagnostics.
    #[error("failed to create postgres connection pool")]
    ConnectionFailed(#[source] sqlx::Error),
}
20
/// Maximum number of pooled Postgres connections; non-zero by construction
/// (the inner type is `NonZeroU32`, so a zero-sized pool is unrepresentable).
#[nutype(derive(Debug, Clone, Copy, PartialEq, Eq, Display, AsRef, Into))]
pub struct MaxConnections(std::num::NonZeroU32);
41
/// Tuning knobs for the sqlx connection pool backing [`PostgresEventStore`].
#[derive(Debug, Clone)]
pub struct PostgresConfig {
    /// Upper bound on simultaneously open connections
    /// (forwarded to `PgPoolOptions::max_connections`).
    pub max_connections: MaxConnections,
    /// How long to wait for a connection
    /// (forwarded to `PgPoolOptions::acquire_timeout`).
    pub acquire_timeout: Duration,
    /// How long an unused connection may idle in the pool
    /// (forwarded to `PgPoolOptions::idle_timeout`).
    pub idle_timeout: Duration,
}
52
53impl Default for PostgresConfig {
54 fn default() -> Self {
55 const DEFAULT_MAX_CONNECTIONS: std::num::NonZeroU32 = match std::num::NonZeroU32::new(10) {
56 Some(v) => v,
57 None => unreachable!(),
58 };
59
60 Self {
61 max_connections: MaxConnections::new(DEFAULT_MAX_CONNECTIONS),
62 acquire_timeout: Duration::from_secs(30),
63 idle_timeout: Duration::from_secs(600), }
65 }
66}
67
/// Postgres-backed event store; implements [`EventStore`] and [`EventReader`].
///
/// Cheap to clone: the only field is an sqlx pool handle, which is itself
/// reference-counted.
#[derive(Debug, Clone)]
pub struct PostgresEventStore {
    pool: Pool<Postgres>,
}
72
73impl PostgresEventStore {
74 pub async fn new<S: Into<String>>(
76 connection_string: S,
77 ) -> Result<Self, PostgresEventStoreError> {
78 Self::with_config(connection_string, PostgresConfig::default()).await
79 }
80
81 pub async fn with_config<S: Into<String>>(
83 connection_string: S,
84 config: PostgresConfig,
85 ) -> Result<Self, PostgresEventStoreError> {
86 let connection_string = connection_string.into();
87 let max_connections: std::num::NonZeroU32 = config.max_connections.into();
88 let pool = PgPoolOptions::new()
89 .max_connections(max_connections.get())
90 .acquire_timeout(config.acquire_timeout)
91 .idle_timeout(config.idle_timeout)
92 .connect(&connection_string)
93 .await
94 .map_err(PostgresEventStoreError::ConnectionFailed)?;
95 Ok(Self { pool })
96 }
97
98 pub fn from_pool(pool: Pool<Postgres>) -> Self {
103 Self { pool }
104 }
105
106 #[cfg_attr(test, mutants::skip)] pub async fn ping(&self) {
108 query("SELECT 1")
109 .execute(&self.pool)
110 .await
111 .expect("postgres ping failed");
112 }
113
114 #[cfg_attr(test, mutants::skip)] pub async fn migrate(&self) {
116 sqlx::migrate!("./migrations")
117 .run(&self.pool)
118 .await
119 .expect("postgres migration failed");
120 }
121}
122
123impl EventStore for PostgresEventStore {
124 #[instrument(name = "postgres.read_stream", skip(self))]
125 async fn read_stream<E: Event>(
126 &self,
127 stream_id: StreamId,
128 ) -> Result<EventStreamReader<E>, EventStoreError> {
129 info!(
130 stream = %stream_id,
131 "[postgres.read_stream] reading events from postgres"
132 );
133
134 let rows = query(
135 "SELECT event_data FROM eventcore_events WHERE stream_id = $1 ORDER BY stream_version ASC",
136 )
137 .bind(stream_id.as_ref())
138 .fetch_all(&self.pool)
139 .await
140 .map_err(|error| map_sqlx_error(error, Operation::ReadStream))?;
141
142 let mut events = Vec::with_capacity(rows.len());
143 for row in rows {
144 let payload: Value = row
145 .try_get("event_data")
146 .map_err(|error| map_sqlx_error(error, Operation::ReadStream))?;
147 let event = serde_json::from_value(payload).map_err(|error| {
148 EventStoreError::DeserializationFailed {
149 stream_id: stream_id.clone(),
150 detail: error.to_string(),
151 }
152 })?;
153 events.push(event);
154 }
155
156 Ok(EventStreamReader::new(events))
157 }
158
159 #[instrument(name = "postgres.append_events", skip(self, writes))]
160 async fn append_events(
161 &self,
162 writes: StreamWrites,
163 ) -> Result<EventStreamSlice, EventStoreError> {
164 let expected_versions = writes.expected_versions().clone();
165 let entries = writes.into_entries();
166
167 if entries.is_empty() {
168 return Ok(EventStreamSlice);
169 }
170
171 info!(
172 stream_count = expected_versions.len(),
173 event_count = entries.len(),
174 "[postgres.append_events] appending events to postgres"
175 );
176
177 let expected_versions_json: Value = expected_versions
179 .iter()
180 .map(|(stream_id, version)| {
181 (stream_id.as_ref().to_string(), json!(version.into_inner()))
182 })
183 .collect();
184
185 let mut tx = self
186 .pool
187 .begin()
188 .await
189 .map_err(|error| map_sqlx_error(error, Operation::BeginTransaction))?;
190
191 query("SELECT set_config('eventcore.expected_versions', $1, true)")
193 .bind(expected_versions_json.to_string())
194 .execute(&mut *tx)
195 .await
196 .map_err(|error| map_sqlx_error(error, Operation::SetExpectedVersions))?;
197
198 for entry in entries {
200 let StreamWriteEntry {
201 stream_id,
202 event_type,
203 event_data,
204 ..
205 } = entry;
206
207 let event_id = Uuid::now_v7();
208 query(
209 "INSERT INTO eventcore_events (event_id, stream_id, event_type, event_data, metadata)
210 VALUES ($1, $2, $3, $4, $5)",
211 )
212 .bind(event_id)
213 .bind(stream_id.as_ref())
214 .bind(event_type)
215 .bind(Json(event_data))
216 .bind(Json(json!({})))
217 .execute(&mut *tx)
218 .await
219 .map_err(|error| map_sqlx_error(error, Operation::AppendEvents))?;
220 }
221
222 tx.commit()
223 .await
224 .map_err(|error| map_sqlx_error(error, Operation::CommitTransaction))?;
225
226 Ok(EventStreamSlice)
227 }
228}
229
impl EventReader for PostgresEventStore {
    type Error = EventStoreError;

    /// Reads up to `page.limit()` events ordered by `event_id` (UUIDv7, so
    /// roughly insertion-time order), optionally restricted to streams whose
    /// id starts with `filter.stream_prefix()` and to events strictly after
    /// the `page.after_position()` cursor.
    async fn read_events<E: Event>(
        &self,
        filter: EventFilter,
        page: EventPage,
    ) -> Result<Vec<(E, StreamPosition)>, Self::Error> {
        let after_event_id: Option<Uuid> = page.after_position().map(|p| p.into_inner());
        let limit: i64 = page.limit().into_inner() as i64;

        // One SQL variant per combination of (prefix filter?, cursor?), since
        // positional bind parameters cannot be made optional in one statement.
        let rows = if let Some(prefix) = filter.stream_prefix() {
            let prefix_str = prefix.as_ref();
            // NOTE(review): `LIKE $N || '%'` treats `%` and `_` inside the
            // prefix as wildcards — confirm stream prefixes can never contain
            // those characters.

            if let Some(after_id) = after_event_id {
                let query_str = r#"
                SELECT event_id, event_data, stream_id
                FROM eventcore_events
                WHERE event_id > $1
                AND stream_id LIKE $2 || '%'
                ORDER BY event_id
                LIMIT $3
            "#;
                query(query_str)
                    .bind(after_id)
                    .bind(prefix_str)
                    .bind(limit)
                    .fetch_all(&self.pool)
                    .await
            } else {
                let query_str = r#"
                SELECT event_id, event_data, stream_id
                FROM eventcore_events
                WHERE stream_id LIKE $1 || '%'
                ORDER BY event_id
                LIMIT $2
            "#;
                query(query_str)
                    .bind(prefix_str)
                    .bind(limit)
                    .fetch_all(&self.pool)
                    .await
            }
        } else if let Some(after_id) = after_event_id {
            let query_str = r#"
                SELECT event_id, event_data, stream_id
                FROM eventcore_events
                WHERE event_id > $1
                ORDER BY event_id
                LIMIT $2
            "#;
            query(query_str)
                .bind(after_id)
                .bind(limit)
                .fetch_all(&self.pool)
                .await
        } else {
            let query_str = r#"
                SELECT event_id, event_data, stream_id
                FROM eventcore_events
                ORDER BY event_id
                LIMIT $1
            "#;
            query(query_str).bind(limit).fetch_all(&self.pool).await
        }
        .map_err(|error| map_sqlx_error(error, Operation::ReadStream))?;

        // Rows whose payload fails to deserialize as `E` are silently skipped
        // via `.ok()`, unlike `read_stream` which fails the whole read.
        // NOTE(review): presumably intentional so callers can read a typed
        // subset out of heterogeneous streams — confirm with the trait docs.
        // `row.get` (vs `try_get`) panics on a missing/mistyped column; both
        // columns are selected above, so a panic here would indicate a bug.
        let events: Vec<(E, StreamPosition)> = rows
            .into_iter()
            .filter_map(|row| {
                let event_data: Json<Value> = row.get("event_data");
                let event_id: Uuid = row.get("event_id");
                serde_json::from_value::<E>(event_data.0)
                    .ok()
                    .map(|e| (e, StreamPosition::new(event_id)))
            })
            .collect();

        Ok(events)
    }
}
313
314fn map_sqlx_error(error: sqlx::Error, operation: Operation) -> EventStoreError {
315 if let sqlx::Error::Database(db_error) = &error {
316 let code = db_error.code();
317 let code_str = code.as_deref();
318 if code_str == Some("P0001") || code_str == Some("23505") {
321 warn!(
322 error = %db_error,
323 "[postgres.version_conflict] optimistic concurrency check failed"
324 );
325 return EventStoreError::VersionConflict;
326 }
327 }
328
329 error!(
330 error = %error,
331 operation = %operation,
332 "[postgres.database_error] database operation failed"
333 );
334 EventStoreError::StoreFailure { operation }
335}
336
#[cfg(test)]
mod tests {
    //! Integration tests backed by one shared Postgres testcontainer.
    //! The container is named and reused (`ReuseDirective::Always`) so that
    //! parallel test binaries and repeated runs share a single instance.
    use super::*;
    use sqlx::{Executor, postgres::PgPoolOptions};
    use std::env;
    use std::sync::OnceLock;
    use testcontainers::{Container, ImageExt, ReuseDirective, runners::SyncRunner};
    use testcontainers_modules::postgres::Postgres as PgContainer;
    #[allow(unused_imports)]
    use tokio::test;
    use uuid::Uuid;

    // Fixed name lets concurrent processes attach to the same container.
    const CONTAINER_NAME: &str = "eventcore-test-postgres";

    // Process-wide singleton: container handle plus its connection string.
    static SHARED_CONTAINER: OnceLock<SharedPostgres> = OnceLock::new();

    struct SharedPostgres {
        connection_string: String,
        // Held only to keep the container alive for the process lifetime.
        #[allow(dead_code)]
        container: Container<PgContainer>,
    }

    /// Image tag to run; override with `POSTGRES_VERSION`, defaults to "17".
    fn postgres_version() -> String {
        env::var("POSTGRES_VERSION").unwrap_or_else(|_| "17".to_string())
    }

    /// Starts (or attaches to) the named container, retrying while another
    /// process races us to create a container with the same name.
    fn start_container_with_retry() -> Container<PgContainer> {
        let version = postgres_version();
        let max_retries = 10;
        let retry_delay = std::time::Duration::from_millis(500);

        for attempt in 0..max_retries {
            match PgContainer::default()
                .with_tag(&version)
                .with_container_name(CONTAINER_NAME)
                .with_reuse(ReuseDirective::Always)
                .start()
            {
                Ok(container) => return container,
                Err(e) => {
                    let error_str = e.to_string();
                    // "already in use" means another process won the creation
                    // race but the container is not yet reusable — back off.
                    if error_str.contains("already in use") && attempt < max_retries - 1 {
                        std::thread::sleep(retry_delay);
                        continue;
                    }
                    panic!("should start postgres container: {}", e);
                }
            }
        }
        // Only reachable if every attempt hit the retryable branch above.
        panic!(
            "failed to start postgres container after {} retries",
            max_retries
        );
    }

    /// Lazily starts the shared container and runs migrations exactly once.
    ///
    /// Setup runs on a dedicated OS thread with its own Tokio runtime so it
    /// can be called from inside `#[tokio::test]` without nesting runtimes.
    fn get_shared_postgres() -> &'static SharedPostgres {
        SHARED_CONTAINER.get_or_init(|| {
            std::thread::spawn(|| {
                let container = start_container_with_retry();

                let host_port = container
                    .get_host_port_ipv4(5432)
                    .expect("should get postgres port");

                let connection_string = format!(
                    "postgres://postgres:postgres@127.0.0.1:{}/postgres",
                    host_port
                );

                let rt = tokio::runtime::Runtime::new()
                    .expect("should create tokio runtime for migrations");
                rt.block_on(async {
                    // Postgres can accept TCP before it is query-ready, so
                    // retry the first connection for up to ~15 seconds.
                    let max_conn_retries = 30;
                    let conn_retry_delay = std::time::Duration::from_millis(500);
                    let mut pool = None;

                    for attempt in 0..max_conn_retries {
                        match PgPoolOptions::new()
                            .max_connections(1)
                            .connect(&connection_string)
                            .await
                        {
                            Ok(p) => {
                                pool = Some(p);
                                break;
                            }
                            Err(e) => {
                                if attempt < max_conn_retries - 1 {
                                    tokio::time::sleep(conn_retry_delay).await;
                                    continue;
                                }
                                panic!(
                                    "should connect to test database after {} retries: {}",
                                    max_conn_retries, e
                                );
                            }
                        }
                    }

                    let pool = pool.expect("pool should be set");
                    sqlx::migrate!("./migrations")
                        .run(&pool)
                        .await
                        .expect("migrations should succeed");
                });

                SharedPostgres {
                    connection_string,
                    container,
                }
            })
            .join()
            .expect("container setup thread should complete")
        })
    }

    /// Opens a fresh single-connection pool against the shared container.
    async fn get_test_pool() -> Pool<Postgres> {
        let shared = get_shared_postgres();
        PgPoolOptions::new()
            .max_connections(1)
            .connect(&shared.connection_string)
            .await
            .expect("should connect to shared postgres container")
    }

    /// UUIDv7-suffixed stream id so tests never collide on stream state.
    fn unique_stream_id(prefix: &str) -> String {
        format!("{}-{}", prefix, Uuid::now_v7())
    }

    #[tokio::test]
    async fn trigger_assigns_sequential_versions() {
        let pool = get_test_pool().await;
        let stream_id = unique_stream_id("trigger-test");

        // Mimic `append_events`: publish the expected version through the
        // transaction-local setting that the insert trigger validates.
        let config_query = format!(
            "SELECT set_config('eventcore.expected_versions', '{{\"{}\":0}}', true)",
            stream_id
        );
        sqlx::query(&config_query)
            .execute(&pool)
            .await
            .expect("should set expected versions");

        // `stream_version` is not bound here — the trigger must assign it.
        let result = sqlx::query(
            "INSERT INTO eventcore_events (event_id, stream_id, event_type, event_data, metadata)
             VALUES ($1, $2, $3, $4, $5) RETURNING stream_version",
        )
        .bind(Uuid::now_v7())
        .bind(&stream_id)
        .bind("TestEvent")
        .bind(serde_json::json!({"n": 1}))
        .bind(serde_json::json!({}))
        .fetch_one(&pool)
        .await;

        match &result {
            Ok(row) => {
                let version: i64 = row.get("stream_version");
                assert_eq!(version, 1, "first event should have version 1");
            }
            Err(e) => panic!("insert failed: {}", e),
        }
    }

    #[tokio::test]
    async fn map_sqlx_error_translates_unique_constraint_violations() {
        let pool = get_test_pool().await;
        // A throwaway, uniquely-named table lets us provoke a real 23505
        // error without touching the event store schema.
        let table_name = format!("map_sqlx_error_test_{}", Uuid::now_v7().simple());
        let create_statement = format!("CREATE TABLE {table_name} (event_id UUID PRIMARY KEY)");
        pool.execute(create_statement.as_str())
            .await
            .expect("should create temporary table for unique constraint test");

        let insert_statement = format!("INSERT INTO {table_name} (event_id) VALUES ($1)");
        let event_id = Uuid::now_v7();
        sqlx::query(insert_statement.as_str())
            .bind(event_id)
            .execute(&pool)
            .await
            .expect("initial insert should succeed");

        // Inserting the same key again must fail with a unique violation.
        let duplicate_error = sqlx::query(insert_statement.as_str())
            .bind(event_id)
            .execute(&pool)
            .await
            .expect_err("duplicate insert should trigger unique constraint");

        // Drop before asserting so a failed assertion doesn't leak the table.
        let drop_statement = format!("DROP TABLE IF EXISTS {table_name}");
        pool.execute(drop_statement.as_str())
            .await
            .expect("should drop temporary table after unique constraint test");

        let mapped_error = map_sqlx_error(duplicate_error, Operation::AppendEvents);

        assert!(
            matches!(mapped_error, EventStoreError::VersionConflict),
            "unique constraint violations should map to version conflict"
        );
    }
}