eventcore_postgres/
lib.rs

1use std::time::Duration;
2
3use eventcore_types::{
4    Event, EventFilter, EventPage, EventReader, EventStore, EventStoreError, EventStreamReader,
5    EventStreamSlice, Operation, StreamId, StreamPosition, StreamWriteEntry, StreamWrites,
6};
7use serde_json::{Value, json};
8use sqlx::types::Json;
9use sqlx::{Pool, Postgres, Row, postgres::PgPoolOptions, query};
10use thiserror::Error;
11use tracing::{error, info, instrument, warn};
12use uuid::Uuid;
13
14#[derive(Debug, Error)]
15pub enum PostgresEventStoreError {
16    #[error("failed to create postgres connection pool")]
17    ConnectionFailed(#[source] sqlx::Error),
18}
19
/// Connection-pool settings consumed by `PostgresEventStore::with_config`.
///
/// The [`Default`] implementation yields 10 connections, a 30 second acquire
/// timeout and a 10 minute idle timeout.
#[derive(Clone, Debug)]
pub struct PostgresConfig {
    /// Upper bound on the number of pooled connections (default: 10).
    pub max_connections: u32,
    /// How long to wait when acquiring a connection from the pool (default: 30 seconds).
    pub acquire_timeout: Duration,
    /// Idle timeout applied to pooled connections (default: 10 minutes).
    pub idle_timeout: Duration,
}

impl Default for PostgresConfig {
    fn default() -> Self {
        const MAX_CONNECTIONS: u32 = 10;
        const ACQUIRE_TIMEOUT_SECS: u64 = 30;
        const IDLE_TIMEOUT_SECS: u64 = 10 * 60;

        Self {
            max_connections: MAX_CONNECTIONS,
            acquire_timeout: Duration::from_secs(ACQUIRE_TIMEOUT_SECS),
            idle_timeout: Duration::from_secs(IDLE_TIMEOUT_SECS),
        }
    }
}
40
41#[derive(Debug, Clone)]
42pub struct PostgresEventStore {
43    pool: Pool<Postgres>,
44}
45
46impl PostgresEventStore {
47    /// Create a new PostgresEventStore with default configuration.
48    pub async fn new<S: Into<String>>(
49        connection_string: S,
50    ) -> Result<Self, PostgresEventStoreError> {
51        Self::with_config(connection_string, PostgresConfig::default()).await
52    }
53
54    /// Create a new PostgresEventStore with custom configuration.
55    pub async fn with_config<S: Into<String>>(
56        connection_string: S,
57        config: PostgresConfig,
58    ) -> Result<Self, PostgresEventStoreError> {
59        let connection_string = connection_string.into();
60        let pool = PgPoolOptions::new()
61            .max_connections(config.max_connections)
62            .acquire_timeout(config.acquire_timeout)
63            .idle_timeout(config.idle_timeout)
64            .connect(&connection_string)
65            .await
66            .map_err(PostgresEventStoreError::ConnectionFailed)?;
67        Ok(Self { pool })
68    }
69
70    /// Create a PostgresEventStore from an existing connection pool.
71    ///
72    /// Use this when you need full control over pool configuration or want to
73    /// share a pool across multiple components.
74    pub fn from_pool(pool: Pool<Postgres>) -> Self {
75        Self { pool }
76    }
77
78    #[cfg_attr(test, mutants::skip)] // infallible: panics on failure
79    pub async fn ping(&self) {
80        query("SELECT 1")
81            .execute(&self.pool)
82            .await
83            .expect("postgres ping failed");
84    }
85
86    #[cfg_attr(test, mutants::skip)] // infallible: panics on failure
87    pub async fn migrate(&self) {
88        sqlx::migrate!("./migrations")
89            .run(&self.pool)
90            .await
91            .expect("postgres migration failed");
92    }
93}
94
95impl EventStore for PostgresEventStore {
96    #[instrument(name = "postgres.read_stream", skip(self))]
97    async fn read_stream<E: Event>(
98        &self,
99        stream_id: StreamId,
100    ) -> Result<EventStreamReader<E>, EventStoreError> {
101        info!(
102            stream = %stream_id,
103            "[postgres.read_stream] reading events from postgres"
104        );
105
106        let rows = query(
107            "SELECT event_data FROM eventcore_events WHERE stream_id = $1 ORDER BY stream_version ASC",
108        )
109        .bind(stream_id.as_ref())
110        .fetch_all(&self.pool)
111        .await
112        .map_err(|error| map_sqlx_error(error, Operation::ReadStream))?;
113
114        let mut events = Vec::with_capacity(rows.len());
115        for row in rows {
116            let payload: Value = row
117                .try_get("event_data")
118                .map_err(|error| map_sqlx_error(error, Operation::ReadStream))?;
119            let event = serde_json::from_value(payload).map_err(|error| {
120                EventStoreError::DeserializationFailed {
121                    stream_id: stream_id.clone(),
122                    detail: error.to_string(),
123                }
124            })?;
125            events.push(event);
126        }
127
128        Ok(EventStreamReader::new(events))
129    }
130
131    #[instrument(name = "postgres.append_events", skip(self, writes))]
132    async fn append_events(
133        &self,
134        writes: StreamWrites,
135    ) -> Result<EventStreamSlice, EventStoreError> {
136        let expected_versions = writes.expected_versions().clone();
137        let entries = writes.into_entries();
138
139        if entries.is_empty() {
140            return Ok(EventStreamSlice);
141        }
142
143        info!(
144            stream_count = expected_versions.len(),
145            event_count = entries.len(),
146            "[postgres.append_events] appending events to postgres"
147        );
148
149        // Build expected versions JSON for the trigger
150        let expected_versions_json: Value = expected_versions
151            .iter()
152            .map(|(stream_id, version)| {
153                (stream_id.as_ref().to_string(), json!(version.into_inner()))
154            })
155            .collect();
156
157        let mut tx = self
158            .pool
159            .begin()
160            .await
161            .map_err(|error| map_sqlx_error(error, Operation::BeginTransaction))?;
162
163        // Set expected versions in session config for trigger validation
164        query("SELECT set_config('eventcore.expected_versions', $1, true)")
165            .bind(expected_versions_json.to_string())
166            .execute(&mut *tx)
167            .await
168            .map_err(|error| map_sqlx_error(error, Operation::SetExpectedVersions))?;
169
170        // Insert all events - trigger handles version assignment and validation
171        for entry in entries {
172            let StreamWriteEntry {
173                stream_id,
174                event_type,
175                event_data,
176                ..
177            } = entry;
178
179            let event_id = Uuid::now_v7();
180            query(
181                "INSERT INTO eventcore_events (event_id, stream_id, event_type, event_data, metadata)
182                 VALUES ($1, $2, $3, $4, $5)",
183            )
184            .bind(event_id)
185            .bind(stream_id.as_ref())
186            .bind(event_type)
187            .bind(Json(event_data))
188            .bind(Json(json!({})))
189            .execute(&mut *tx)
190            .await
191            .map_err(|error| map_sqlx_error(error, Operation::AppendEvents))?;
192        }
193
194        tx.commit()
195            .await
196            .map_err(|error| map_sqlx_error(error, Operation::CommitTransaction))?;
197
198        Ok(EventStreamSlice)
199    }
200}
201
202impl EventReader for PostgresEventStore {
203    type Error = EventStoreError;
204
205    async fn read_events<E: Event>(
206        &self,
207        filter: EventFilter,
208        page: EventPage,
209    ) -> Result<Vec<(E, StreamPosition)>, Self::Error> {
210        // Query events ordered by event_id (UUID7, monotonically increasing).
211        // Use event_id directly as the global position - no need for ROW_NUMBER.
212        let after_event_id: Option<Uuid> = page.after_position().map(|p| p.into_inner());
213        let limit: i64 = page.limit().into_inner() as i64;
214
215        let rows = if let Some(prefix) = filter.stream_prefix() {
216            let prefix_str = prefix.as_ref();
217
218            if let Some(after_id) = after_event_id {
219                let query_str = r#"
220                    SELECT event_id, event_data, stream_id
221                    FROM eventcore_events
222                    WHERE event_id > $1
223                      AND stream_id LIKE $2 || '%'
224                    ORDER BY event_id
225                    LIMIT $3
226                "#;
227                query(query_str)
228                    .bind(after_id)
229                    .bind(prefix_str)
230                    .bind(limit)
231                    .fetch_all(&self.pool)
232                    .await
233            } else {
234                let query_str = r#"
235                    SELECT event_id, event_data, stream_id
236                    FROM eventcore_events
237                    WHERE stream_id LIKE $1 || '%'
238                    ORDER BY event_id
239                    LIMIT $2
240                "#;
241                query(query_str)
242                    .bind(prefix_str)
243                    .bind(limit)
244                    .fetch_all(&self.pool)
245                    .await
246            }
247        } else if let Some(after_id) = after_event_id {
248            let query_str = r#"
249                SELECT event_id, event_data, stream_id
250                FROM eventcore_events
251                WHERE event_id > $1
252                ORDER BY event_id
253                LIMIT $2
254            "#;
255            query(query_str)
256                .bind(after_id)
257                .bind(limit)
258                .fetch_all(&self.pool)
259                .await
260        } else {
261            let query_str = r#"
262                SELECT event_id, event_data, stream_id
263                FROM eventcore_events
264                ORDER BY event_id
265                LIMIT $1
266            "#;
267            query(query_str).bind(limit).fetch_all(&self.pool).await
268        }
269        .map_err(|error| map_sqlx_error(error, Operation::ReadStream))?;
270
271        let events: Vec<(E, StreamPosition)> = rows
272            .into_iter()
273            .filter_map(|row| {
274                let event_data: Json<Value> = row.get("event_data");
275                let event_id: Uuid = row.get("event_id");
276                serde_json::from_value::<E>(event_data.0)
277                    .ok()
278                    .map(|e| (e, StreamPosition::new(event_id)))
279            })
280            .collect();
281
282        Ok(events)
283    }
284}
285
286fn map_sqlx_error(error: sqlx::Error, operation: Operation) -> EventStoreError {
287    if let sqlx::Error::Database(db_error) = &error {
288        let code = db_error.code();
289        let code_str = code.as_deref();
290        // P0001: Custom error from trigger (version_conflict)
291        // 23505: Unique constraint violation (fallback for version conflict)
292        if code_str == Some("P0001") || code_str == Some("23505") {
293            warn!(
294                error = %db_error,
295                "[postgres.version_conflict] optimistic concurrency check failed"
296            );
297            return EventStoreError::VersionConflict;
298        }
299    }
300
301    error!(
302        error = %error,
303        operation = %operation,
304        "[postgres.database_error] database operation failed"
305    );
306    EventStoreError::StoreFailure { operation }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::{Executor, postgres::PgPoolOptions};
    use std::env;
    use std::sync::OnceLock;
    use testcontainers::{Container, ImageExt, ReuseDirective, runners::SyncRunner};
    use testcontainers_modules::postgres::Postgres as PgContainer;
    // Not referenced directly (tests use the fully qualified `#[tokio::test]`),
    // hence the lint allowance.
    #[allow(unused_imports)]
    use tokio::test;
    use uuid::Uuid;

    /// Container name for the shared reusable Postgres instance.
    const CONTAINER_NAME: &str = "eventcore-test-postgres";

    /// Shared container and connection string for all unit tests.
    /// The container persists between test runs for faster iteration.
    static SHARED_CONTAINER: OnceLock<SharedPostgres> = OnceLock::new();

    struct SharedPostgres {
        connection_string: String,
        // Never read; stored so the container handle stays owned by the static.
        #[allow(dead_code)]
        container: Container<PgContainer>,
    }

    /// Get the Postgres version to use for tests (`POSTGRES_VERSION`, default "17").
    fn postgres_version() -> String {
        env::var("POSTGRES_VERSION").unwrap_or_else(|_| "17".to_string())
    }

    /// Start a reusable container with retry logic for cross-process races.
    ///
    /// When nextest runs test binaries in parallel, multiple processes may try to
    /// create the same named container simultaneously. This retries on "name already
    /// in use" errors, allowing the other process to finish creation.
    fn start_container_with_retry() -> Container<PgContainer> {
        let version = postgres_version();
        let max_retries = 10;
        let retry_delay = std::time::Duration::from_millis(500);

        for attempt in 0..max_retries {
            match PgContainer::default()
                .with_tag(&version)
                .with_container_name(CONTAINER_NAME)
                .with_reuse(ReuseDirective::Always)
                .start()
            {
                Ok(container) => return container,
                Err(e) => {
                    let error_str = e.to_string();
                    if error_str.contains("already in use") && attempt < max_retries - 1 {
                        // Another process is creating the container, wait and retry
                        std::thread::sleep(retry_delay);
                        continue;
                    }
                    panic!("should start postgres container: {}", e);
                }
            }
        }
        panic!(
            "failed to start postgres container after {} retries",
            max_retries
        );
    }

    /// Lazily starts the shared container and runs migrations exactly once per process.
    fn get_shared_postgres() -> &'static SharedPostgres {
        SHARED_CONTAINER.get_or_init(|| {
            // Run container setup in a separate thread to avoid tokio runtime conflicts
            std::thread::spawn(|| {
                let container = start_container_with_retry();

                let host_port = container
                    .get_host_port_ipv4(5432)
                    .expect("should get postgres port");

                let connection_string = format!(
                    "postgres://postgres:postgres@127.0.0.1:{}/postgres",
                    host_port
                );

                // Run migrations using a temporary runtime
                // Retry connection in case postgres is still starting up
                let rt = tokio::runtime::Runtime::new()
                    .expect("should create tokio runtime for migrations");
                rt.block_on(async {
                    let max_conn_retries = 30;
                    let conn_retry_delay = std::time::Duration::from_millis(500);
                    let mut pool = None;

                    for attempt in 0..max_conn_retries {
                        match PgPoolOptions::new()
                            .max_connections(1)
                            .connect(&connection_string)
                            .await
                        {
                            Ok(p) => {
                                pool = Some(p);
                                break;
                            }
                            Err(e) => {
                                if attempt < max_conn_retries - 1 {
                                    tokio::time::sleep(conn_retry_delay).await;
                                    continue;
                                }
                                panic!(
                                    "should connect to test database after {} retries: {}",
                                    max_conn_retries, e
                                );
                            }
                        }
                    }

                    let pool = pool.expect("pool should be set");
                    sqlx::migrate!("./migrations")
                        .run(&pool)
                        .await
                        .expect("migrations should succeed");
                });

                SharedPostgres {
                    connection_string,
                    container,
                }
            })
            .join()
            .expect("container setup thread should complete")
        })
    }

    /// Opens a fresh single-connection pool against the shared container.
    async fn get_test_pool() -> Pool<Postgres> {
        let shared = get_shared_postgres();
        PgPoolOptions::new()
            .max_connections(1)
            .connect(&shared.connection_string)
            .await
            .expect("should connect to shared postgres container")
    }

    /// Returns `prefix` suffixed with a fresh UUIDv7 so tests never share a stream.
    fn unique_stream_id(prefix: &str) -> String {
        format!("{}-{}", prefix, Uuid::now_v7())
    }

    #[tokio::test]
    async fn trigger_assigns_sequential_versions() {
        let pool = get_test_pool().await;
        let stream_id = unique_stream_id("trigger-test");

        // `set_config(..., true)` is transaction-local: run outside an explicit
        // transaction, the setting would be discarded as soon as the SELECT
        // finished and the INSERT below would never see it. Both statements
        // therefore share one transaction, mirroring `append_events`.
        let mut tx = pool.begin().await.expect("should begin transaction");

        // Set expected version via session config
        let config_query = format!(
            "SELECT set_config('eventcore.expected_versions', '{{\"{}\":0}}', true)",
            stream_id
        );
        sqlx::query(&config_query)
            .execute(&mut *tx)
            .await
            .expect("should set expected versions");

        // Insert first event
        let result = sqlx::query(
            "INSERT INTO eventcore_events (event_id, stream_id, event_type, event_data, metadata)
             VALUES ($1, $2, $3, $4, $5) RETURNING stream_version",
        )
        .bind(Uuid::now_v7())
        .bind(&stream_id)
        .bind("TestEvent")
        .bind(serde_json::json!({"n": 1}))
        .bind(serde_json::json!({}))
        .fetch_one(&mut *tx)
        .await;

        match &result {
            Ok(row) => {
                let version: i64 = row.get("stream_version");
                assert_eq!(version, 1, "first event should have version 1");
            }
            Err(e) => panic!("insert failed: {}", e),
        }

        tx.commit().await.expect("should commit transaction");
    }

    #[tokio::test]
    async fn map_sqlx_error_translates_unique_constraint_violations() {
        // Given: Developer has a table with a unique constraint to trigger duplicates
        let pool = get_test_pool().await;
        let table_name = format!("map_sqlx_error_test_{}", Uuid::now_v7().simple());
        let create_statement = format!("CREATE TABLE {table_name} (event_id UUID PRIMARY KEY)");
        pool.execute(create_statement.as_str())
            .await
            .expect("should create temporary table for unique constraint test");

        let insert_statement = format!("INSERT INTO {table_name} (event_id) VALUES ($1)");
        let event_id = Uuid::now_v7();
        sqlx::query(insert_statement.as_str())
            .bind(event_id)
            .execute(&pool)
            .await
            .expect("initial insert should succeed");

        let duplicate_error = sqlx::query(insert_statement.as_str())
            .bind(event_id)
            .execute(&pool)
            .await
            .expect_err("duplicate insert should trigger unique constraint");

        let drop_statement = format!("DROP TABLE IF EXISTS {table_name}");
        pool.execute(drop_statement.as_str())
            .await
            .expect("should drop temporary table after unique constraint test");

        // When: Developer maps the sqlx duplicate error
        let mapped_error = map_sqlx_error(duplicate_error, Operation::AppendEvents);

        // Then: Developer sees version conflict error for 23505 violations
        assert!(
            matches!(mapped_error, EventStoreError::VersionConflict),
            "unique constraint violations should map to version conflict"
        );
    }
}