1use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, RwLock};
10
11use async_trait::async_trait;
12use eventcore::{
13 Checkpoint, EventId, EventMetadata, EventProcessor, EventStore, EventStoreError, EventToWrite,
14 EventVersion, ExpectedVersion, ReadOptions, StoredEvent, StreamData, StreamEvents, StreamId,
15 Subscription, SubscriptionError, SubscriptionName, SubscriptionOptions, SubscriptionPosition,
16 SubscriptionResult, Timestamp,
17};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use sqlx::{postgres::PgRow, Row};
21use tracing::{debug, error, instrument};
22use uuid::Uuid;
23
24use crate::{PostgresError, PostgresEventStore};
25
/// Convenience alias for results produced by this store implementation.
type EventStoreResult<T> = Result<T, EventStoreError>;
27
/// Raw representation of one row of the `events` table.
///
/// Mirrors the column layout used by the read queries in this module and is
/// converted into a typed [`StoredEvent`] via [`EventRow::to_stored_event`].
#[derive(Debug)]
#[allow(dead_code)] // some columns are fetched for completeness but not yet consumed
struct EventRow {
    event_id: Uuid,
    stream_id: String,
    event_version: i64,
    event_type: String,
    event_data: Value,
    metadata: Option<Value>,
    causation_id: Option<Uuid>,
    correlation_id: Option<String>,
    user_id: Option<String>,
    created_at: chrono::DateTime<chrono::Utc>,
}
43
44impl TryFrom<PgRow> for EventRow {
45 type Error = sqlx::Error;
46
47 fn try_from(row: PgRow) -> Result<Self, Self::Error> {
48 Ok(Self {
49 event_id: row.try_get("event_id")?,
50 stream_id: row.try_get("stream_id")?,
51 event_version: row.try_get("event_version")?,
52 event_type: row.try_get("event_type")?,
53 event_data: row.try_get("event_data")?,
54 metadata: row.try_get("metadata")?,
55 causation_id: row.try_get("causation_id")?,
56 correlation_id: row.try_get("correlation_id")?,
57 user_id: row.try_get("user_id")?,
58 created_at: row.try_get("created_at")?,
59 })
60 }
61}
62
63impl EventRow {
64 #[allow(clippy::wrong_self_convention)] fn to_stored_event<E>(self) -> EventStoreResult<StoredEvent<E>>
67 where
68 E: for<'de> Deserialize<'de> + PartialEq + Eq,
69 {
70 let event_id = EventId::try_new(self.event_id)
71 .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?;
72
73 let stream_id = StreamId::try_new(self.stream_id)
74 .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?;
75
76 let event_version = if self.event_version >= 0 {
77 let version_u64 = u64::try_from(self.event_version)?;
78 EventVersion::try_new(version_u64)
79 .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?
80 } else {
81 return Err(EventStoreError::SerializationFailed(
82 "Negative event version in database".to_string(),
83 ));
84 };
85
86 let timestamp = Timestamp::new(self.created_at);
87
88 let metadata = if let Some(metadata_json) = self.metadata {
89 let event_metadata: EventMetadata = serde_json::from_value(metadata_json)?;
90 Some(event_metadata)
91 } else {
92 None
93 };
94
95 let payload: E = serde_json::from_value(self.event_data)?;
97
98 Ok(StoredEvent::new(
99 event_id,
100 stream_id,
101 event_version,
102 timestamp,
103 payload,
104 metadata,
105 ))
106 }
107}
108
#[async_trait]
impl<E> EventStore for PostgresEventStore<E>
where
    E: Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug
        + Clone
        + PartialEq
        + Eq
        + 'static,
{
    type Event = E;

    /// Reads events for the given streams in a single query.
    ///
    /// Events come back ordered by `event_id` (a global order across the
    /// requested streams) and limited to `options.max_events`, falling back
    /// to the store's configured `read_batch_size`. Every requested stream
    /// gets an entry in the returned version map, defaulting to
    /// `EventVersion::initial()` when no events were read for it.
    #[instrument(skip(self), fields(streams = stream_ids.len()))]
    async fn read_streams(
        &self,
        stream_ids: &[StreamId],
        options: &ReadOptions,
    ) -> EventStoreResult<StreamData<Self::Event>> {
        if stream_ids.is_empty() {
            return Ok(StreamData::new(Vec::new(), HashMap::new()));
        }

        debug!(
            "Reading {} streams with options: {:?}",
            stream_ids.len(),
            options
        );

        // The query is built incrementally; `param_count` tracks the next
        // positional placeholder ($1 is the stream-id array). The `bind`
        // calls further down MUST mirror this append order exactly.
        let mut query = String::from(
            "SELECT event_id, stream_id, event_version, event_type, event_data, metadata,
                    causation_id, correlation_id, user_id, created_at
             FROM events
             WHERE stream_id = ANY($1)",
        );

        let mut param_count = 2;

        if let Some(_from_version) = options.from_version {
            use std::fmt::Write;
            write!(&mut query, " AND event_version >= ${param_count}").expect("Write to string");
            param_count += 1;
        }

        if let Some(_to_version) = options.to_version {
            use std::fmt::Write;
            write!(&mut query, " AND event_version <= ${param_count}").expect("Write to string");
            param_count += 1;
        }

        query.push_str(" ORDER BY event_id");

        let effective_limit = options.max_events.unwrap_or(self.config.read_batch_size);
        use std::fmt::Write;
        write!(&mut query, " LIMIT ${param_count}").expect("Write to string");

        let stream_id_strings: Vec<String> =
            stream_ids.iter().map(|s| s.as_ref().to_string()).collect();

        let mut sqlx_query = sqlx::query(&query).bind(&stream_id_strings);

        // Bind values in the same order the placeholders were appended.
        if let Some(from_version) = options.from_version {
            let version_value: u64 = from_version.into();
            let version_i64 = i64::try_from(version_value).map_err(|_| {
                EventStoreError::SerializationFailed("Version too large".to_string())
            })?;
            sqlx_query = sqlx_query.bind(version_i64);
        }

        if let Some(to_version) = options.to_version {
            let version_value: u64 = to_version.into();
            let version_i64 = i64::try_from(version_value).map_err(|_| {
                EventStoreError::SerializationFailed("Version too large".to_string())
            })?;
            sqlx_query = sqlx_query.bind(version_i64);
        }

        let limit_i64 = i64::try_from(effective_limit)
            .map_err(|_| EventStoreError::SerializationFailed("Limit too large".to_string()))?;
        sqlx_query = sqlx_query.bind(limit_i64);

        let rows = sqlx_query
            .fetch_all(self.pool.as_ref())
            .await
            .map_err(PostgresError::Connection)?;

        debug!("Retrieved {} events from database", rows.len());

        let mut events = Vec::new();
        let mut stream_versions = HashMap::new();

        // Seed every requested stream so callers always find an entry,
        // even for streams that returned no events.
        for stream_id in stream_ids {
            stream_versions.insert(stream_id.clone(), EventVersion::initial());
        }

        for row in rows {
            let event_row = EventRow::try_from(row)
                .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?;
            let stored_event = event_row.to_stored_event::<E>()?;

            // Track the highest version observed per stream.
            let initial_version = EventVersion::initial();
            let current_max = stream_versions
                .get(&stored_event.stream_id)
                .unwrap_or(&initial_version);
            if stored_event.event_version > *current_max {
                stream_versions.insert(stored_event.stream_id.clone(), stored_event.event_version);
            }

            events.push(stored_event);
        }

        Ok(StreamData::new(events, stream_versions))
    }

    /// Writes events to multiple streams, one stream at a time.
    ///
    /// NOTE(review): streams are written sequentially and not inside one
    /// transaction, so a failure part-way leaves earlier streams already
    /// written — confirm this matches the intended atomicity guarantees.
    #[instrument(skip(self), fields(streams = stream_events.len()))]
    async fn write_events_multi(
        &self,
        stream_events: Vec<StreamEvents<Self::Event>>,
    ) -> EventStoreResult<HashMap<StreamId, EventVersion>> {
        if stream_events.is_empty() {
            return Ok(HashMap::new());
        }

        debug!("Writing events to {} streams", stream_events.len());

        let mut result_versions = HashMap::new();

        for stream in stream_events {
            let stream_id = stream.stream_id.clone();
            let new_version = self.write_stream_events_direct(stream).await?;
            result_versions.insert(stream_id, new_version);
        }

        debug!(
            "Successfully wrote events to {} streams",
            result_versions.len()
        );
        Ok(result_versions)
    }

    /// Returns whether `stream_id` is present in the `event_streams` table.
    #[instrument(skip(self))]
    async fn stream_exists(&self, stream_id: &StreamId) -> EventStoreResult<bool> {
        let exists =
            sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM event_streams WHERE stream_id = $1)")
                .bind(stream_id.as_ref())
                .fetch_one(self.pool.as_ref())
                .await
                .map_err(PostgresError::Connection)?;

        Ok(exists)
    }

    /// Returns the highest event version stored for `stream_id`, or `None`
    /// when the stream has no events.
    async fn get_stream_version(
        &self,
        stream_id: &StreamId,
    ) -> EventStoreResult<Option<EventVersion>> {
        let max_version: Option<i64> =
            sqlx::query_scalar("SELECT MAX(event_version) FROM events WHERE stream_id = $1")
                .bind(stream_id.as_ref())
                .fetch_optional(self.pool.as_ref())
                .await
                .map_err(PostgresError::Connection)?
                .flatten();

        match max_version {
            Some(v) if v >= 0 => {
                let v_u64 = u64::try_from(v).map_err(|_| {
                    EventStoreError::SerializationFailed("Invalid version".to_string())
                })?;
                Ok(Some(EventVersion::try_new(v_u64).map_err(|e| {
                    EventStoreError::SerializationFailed(e.to_string())
                })?))
            }
            // NULL (no rows) and negative values both map to "no version".
            _ => Ok(None),
        }
    }

    /// Creates a polling subscription backed by this store.
    #[instrument(skip(self))]
    async fn subscribe(
        &self,
        options: SubscriptionOptions,
    ) -> EventStoreResult<Box<dyn Subscription<Event = Self::Event>>> {
        let subscription = PostgresSubscription::new(self.clone(), options);
        Ok(Box::new(subscription))
    }
}
308
impl<E> PostgresEventStore<E>
where
    E: Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug
        + Clone
        + PartialEq
        + Eq
        + 'static,
{
    /// Reads events for the given streams with keyset pagination on
    /// `event_id`.
    ///
    /// When `continuation_token` is `Some`, only events with a strictly
    /// greater `event_id` are returned. The second tuple element is the
    /// token for the next page; it is `Some` only when the page came back
    /// completely full (so more events may remain).
    pub async fn read_streams_paginated_impl(
        &self,
        stream_ids: &[StreamId],
        options: &ReadOptions,
        continuation_token: Option<EventId>,
    ) -> EventStoreResult<(Vec<StoredEvent<E>>, Option<EventId>)> {
        if stream_ids.is_empty() {
            return Ok((Vec::new(), None));
        }

        debug!(
            "Reading {} streams with pagination, continuation: {:?}",
            stream_ids.len(),
            continuation_token
        );

        // The query is built incrementally; the `bind` calls below must
        // mirror the order placeholders are appended here.
        let mut query = String::from(
            "SELECT event_id, stream_id, event_version, event_type, event_data, metadata,
                    causation_id, correlation_id, user_id, created_at
             FROM events
             WHERE stream_id = ANY($1)",
        );

        let mut param_count = 2;

        if continuation_token.is_some() {
            use std::fmt::Write;
            write!(&mut query, " AND event_id > ${param_count}").expect("Write to string");
            param_count += 1;
        }

        if let Some(_from_version) = options.from_version {
            use std::fmt::Write;
            write!(&mut query, " AND event_version >= ${param_count}").expect("Write to string");
            param_count += 1;
        }

        if let Some(_to_version) = options.to_version {
            use std::fmt::Write;
            write!(&mut query, " AND event_version <= ${param_count}").expect("Write to string");
            param_count += 1;
        }

        query.push_str(" ORDER BY event_id");

        let effective_limit = options.max_events.unwrap_or(self.config.read_batch_size);
        {
            use std::fmt::Write;
            write!(&mut query, " LIMIT ${param_count}").expect("Write to string");
        }

        let stream_id_strings: Vec<String> =
            stream_ids.iter().map(|s| s.as_ref().to_string()).collect();

        let mut sqlx_query = sqlx::query(&query).bind(&stream_id_strings);

        // Bind values in the same order the placeholders were appended.
        if let Some(token) = &continuation_token {
            sqlx_query = sqlx_query.bind(token.as_ref());
        }

        if let Some(from_version) = options.from_version {
            let version_value: u64 = from_version.into();
            let version_i64 = i64::try_from(version_value).map_err(|_| {
                EventStoreError::SerializationFailed("Version too large".to_string())
            })?;
            sqlx_query = sqlx_query.bind(version_i64);
        }

        if let Some(to_version) = options.to_version {
            let version_value: u64 = to_version.into();
            let version_i64 = i64::try_from(version_value).map_err(|_| {
                EventStoreError::SerializationFailed("Version too large".to_string())
            })?;
            sqlx_query = sqlx_query.bind(version_i64);
        }

        let limit_i64 = i64::try_from(effective_limit)
            .map_err(|_| EventStoreError::SerializationFailed("Limit too large".to_string()))?;
        sqlx_query = sqlx_query.bind(limit_i64);

        let rows = sqlx_query
            .fetch_all(self.pool.as_ref())
            .await
            .map_err(PostgresError::Connection)?;

        debug!("Retrieved {} events from database", rows.len());

        let mut events = Vec::new();
        let mut last_event_id = None;

        for row in rows {
            let event_row = EventRow::try_from(row)
                .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?;
            let stored_event = event_row.to_stored_event::<E>()?;
            last_event_id = Some(stored_event.event_id);
            events.push(stored_event);
        }

        // A completely full page implies more events may remain; an
        // underfull page is final.
        let continuation = if events.len() == effective_limit {
            last_event_id
        } else {
            None
        };

        Ok((events, continuation))
    }
}
439
/// A polling subscription over a [`PostgresEventStore`].
///
/// Control state lives behind `Arc`s so that the clone handed to the
/// background processing task observes the same stop/pause signals and
/// position as the handle retained by the caller.
pub struct PostgresSubscription<E>
where
    E: Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug
        + Clone
        + PartialEq
        + Eq
        + 'static,
{
    event_store: PostgresEventStore<E>,
    options: SubscriptionOptions,
    // Last position published by the processing loop; read by `get_position`.
    current_position: Arc<RwLock<Option<SubscriptionPosition>>>,
    is_running: Arc<AtomicBool>,
    is_paused: Arc<AtomicBool>,
    stop_signal: Arc<AtomicBool>,
}
460
461impl<E> PostgresSubscription<E>
462where
463 E: Serialize
464 + for<'de> Deserialize<'de>
465 + Send
466 + Sync
467 + std::fmt::Debug
468 + Clone
469 + PartialEq
470 + Eq
471 + 'static,
472{
473 pub fn new(event_store: PostgresEventStore<E>, options: SubscriptionOptions) -> Self {
475 Self {
476 event_store,
477 options,
478 current_position: Arc::new(RwLock::new(None)),
479 is_running: Arc::new(AtomicBool::new(false)),
480 is_paused: Arc::new(AtomicBool::new(false)),
481 stop_signal: Arc::new(AtomicBool::new(false)),
482 }
483 }
484
485 async fn process_events(
487 &self,
488 name: SubscriptionName,
489 mut processor: Box<dyn EventProcessor<Event = E>>,
490 ) -> SubscriptionResult<()>
491 where
492 E: PartialEq + Eq,
493 {
494 let starting_position = self.load_checkpoint(&name).await?;
496
497 loop {
498 if self.stop_signal.load(Ordering::Acquire) {
500 break;
501 }
502
503 if self.is_paused.load(Ordering::Acquire) {
505 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
506 continue;
507 }
508
509 let events = self
511 .get_events_for_processing(starting_position.as_ref())
512 .await?;
513
514 let mut current_pos = starting_position.clone();
515 let mut has_new_events = false;
516
517 for event in events {
518 if let Some(ref pos) = current_pos {
520 if event.event_id <= pos.last_event_id {
521 continue;
522 }
523 }
524
525 processor.process_event(event.clone()).await?;
527 has_new_events = true;
528
529 let new_checkpoint = Checkpoint::new(event.event_id, event.event_version.into());
531
532 current_pos = Some(if let Some(mut pos) = current_pos {
533 pos.last_event_id = event.event_id;
534 pos.update_checkpoint(event.stream_id.clone(), new_checkpoint);
535 pos
536 } else {
537 let mut pos = SubscriptionPosition::new(event.event_id);
538 pos.update_checkpoint(event.stream_id.clone(), new_checkpoint);
539 pos
540 });
541
542 {
544 let mut guard = self.current_position.write().map_err(|_| {
545 SubscriptionError::CheckpointSaveFailed(
546 "Failed to acquire position lock".to_string(),
547 )
548 })?;
549 (*guard).clone_from(¤t_pos);
550 }
551
552 if let Some(ref pos) = current_pos {
554 self.save_checkpoint_to_db(&name, pos.clone()).await?;
555 }
556 }
557
558 if !has_new_events && matches!(self.options, SubscriptionOptions::LiveOnly) {
560 processor.on_live().await?;
561 }
562
563 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
565 }
566
567 Ok(())
568 }
569
570 async fn get_events_for_processing(
572 &self,
573 starting_position: Option<&SubscriptionPosition>,
574 ) -> SubscriptionResult<Vec<StoredEvent<E>>> {
575 let (streams, from_position) = match &self.options {
576 SubscriptionOptions::CatchUpFromBeginning => (vec![], None),
577 SubscriptionOptions::CatchUpFromPosition(pos) => (vec![], Some(pos.last_event_id)),
578 SubscriptionOptions::LiveOnly => {
579 (vec![], starting_position.as_ref().map(|p| p.last_event_id))
581 }
582 SubscriptionOptions::SpecificStreamsFromBeginning(_mode) => {
583 (vec![], None)
585 }
586 SubscriptionOptions::SpecificStreamsFromPosition(_mode, pos) => {
587 (vec![], Some(pos.last_event_id))
588 }
589 SubscriptionOptions::AllStreams { from_position } => (vec![], *from_position),
590 SubscriptionOptions::SpecificStreams {
591 streams,
592 from_position,
593 } => (streams.clone(), *from_position),
594 };
595
596 if streams.is_empty() {
598 self.read_all_events_since(
599 from_position.or_else(|| starting_position.map(|p| p.last_event_id)),
600 )
601 .await
602 } else {
603 self.read_streams_events_since(
604 &streams,
605 from_position.or_else(|| starting_position.map(|p| p.last_event_id)),
606 )
607 .await
608 }
609 }
610
611 async fn read_all_events_since(
613 &self,
614 from_event_id: Option<EventId>,
615 ) -> SubscriptionResult<Vec<StoredEvent<E>>> {
616 let all_streams = self.get_all_stream_ids().await?;
619 if all_streams.is_empty() {
620 return Ok(vec![]);
621 }
622
623 let read_options = ReadOptions {
624 from_version: None,
625 to_version: None,
626 max_events: Some(self.event_store.config.read_batch_size),
627 };
628
629 let stream_data = self
630 .event_store
631 .read_streams(&all_streams, &read_options)
632 .await
633 .map_err(SubscriptionError::EventStore)?;
634
635 let filtered_events = if let Some(from_id) = from_event_id {
637 stream_data
638 .events
639 .into_iter()
640 .filter(|e| e.event_id > from_id)
641 .collect()
642 } else {
643 stream_data.events
644 };
645
646 Ok(filtered_events)
647 }
648
649 async fn get_all_stream_ids(&self) -> SubscriptionResult<Vec<StreamId>> {
651 let query_str = format!(
652 "SELECT DISTINCT stream_id FROM events LIMIT {}",
653 self.event_store.config.read_batch_size
654 );
655 let rows = sqlx::query(&query_str)
656 .fetch_all(self.event_store.pool.as_ref())
657 .await
658 .map_err(|e| {
659 SubscriptionError::EventStore(EventStoreError::Internal(format!(
660 "Failed to fetch stream IDs from database for subscription processing (query: '{query_str}'): {e}"
661 )))
662 })?;
663
664 let mut stream_ids = Vec::new();
665 for row in rows {
666 let stream_id_str: String = row.get("stream_id");
667 if let Ok(stream_id) = StreamId::try_new(stream_id_str) {
668 stream_ids.push(stream_id);
669 }
670 }
671
672 Ok(stream_ids)
673 }
674
675 async fn read_streams_events_since(
677 &self,
678 stream_ids: &[StreamId],
679 from_event_id: Option<EventId>,
680 ) -> SubscriptionResult<Vec<StoredEvent<E>>> {
681 let read_options = ReadOptions {
682 from_version: None,
683 to_version: None,
684 max_events: Some(self.event_store.config.read_batch_size),
685 };
686
687 let stream_data = self
688 .event_store
689 .read_streams(stream_ids, &read_options)
690 .await
691 .map_err(SubscriptionError::EventStore)?;
692
693 let filtered_events = if let Some(from_id) = from_event_id {
695 stream_data
696 .events
697 .into_iter()
698 .filter(|e| e.event_id > from_id)
699 .collect()
700 } else {
701 stream_data.events
702 };
703
704 Ok(filtered_events)
705 }
706
707 async fn save_checkpoint_to_db(
709 &self,
710 name: &SubscriptionName,
711 position: SubscriptionPosition,
712 ) -> SubscriptionResult<()> {
713 let position_json = serde_json::to_string(&position).map_err(|e| {
714 SubscriptionError::CheckpointSaveFailed(format!(
715 "Failed to serialize checkpoint position for subscription '{}': {e}",
716 name.as_ref()
717 ))
718 })?;
719
720 sqlx::query(
721 "INSERT INTO subscription_checkpoints (subscription_name, position_data, updated_at)
722 VALUES ($1, $2, NOW())
723 ON CONFLICT (subscription_name)
724 DO UPDATE SET position_data = $2, updated_at = NOW()",
725 )
726 .bind(name.as_ref())
727 .bind(position_json)
728 .execute(self.event_store.pool.as_ref())
729 .await
730 .map_err(|e| {
731 SubscriptionError::CheckpointSaveFailed(format!(
732 "Failed to save checkpoint for subscription '{}' to database: {e}",
733 name.as_ref()
734 ))
735 })?;
736
737 Ok(())
738 }
739
740 async fn load_checkpoint_from_db(
742 &self,
743 name: &SubscriptionName,
744 ) -> SubscriptionResult<Option<SubscriptionPosition>> {
745 let row = sqlx::query(
746 "SELECT position_data FROM subscription_checkpoints WHERE subscription_name = $1",
747 )
748 .bind(name.as_ref())
749 .fetch_optional(self.event_store.pool.as_ref())
750 .await
751 .map_err(|e| {
752 SubscriptionError::CheckpointLoadFailed(format!(
753 "Failed to load checkpoint for subscription '{}' from database: {e}",
754 name.as_ref()
755 ))
756 })?;
757
758 if let Some(row) = row {
759 let position_json: String = row.get("position_data");
760 let position = serde_json::from_str(&position_json).map_err(|e| {
761 SubscriptionError::CheckpointLoadFailed(format!(
762 "Failed to deserialize checkpoint position for subscription '{}': {e}",
763 name.as_ref()
764 ))
765 })?;
766 Ok(Some(position))
767 } else {
768 Ok(None)
769 }
770 }
771}
772
#[async_trait]
impl<E> Subscription for PostgresSubscription<E>
where
    E: Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug
        + Clone
        + PartialEq
        + Eq
        + 'static,
{
    type Event = E;

    /// Starts processing in a detached background task.
    ///
    /// Replaces the subscription's options with `options`, resets the
    /// control flags, and spawns the polling loop. Errors from the loop are
    /// only logged; the spawned task's handle is not awaited anywhere.
    async fn start(
        &mut self,
        name: SubscriptionName,
        options: SubscriptionOptions,
        processor: Box<dyn EventProcessor<Event = Self::Event>>,
    ) -> SubscriptionResult<()>
    where
        Self::Event: PartialEq + Eq,
    {
        self.options = options;

        self.is_running.store(true, Ordering::Release);
        self.stop_signal.store(false, Ordering::Release);
        self.is_paused.store(false, Ordering::Release);

        // The clone shares the Arc-backed flags/position with `self`, so
        // stop/pause on this handle control the spawned task.
        let subscription = self.clone();
        let name_copy = name;

        tokio::spawn(async move {
            if let Err(e) = subscription.process_events(name_copy, processor).await {
                error!("PostgreSQL subscription processing failed: {}", e);
            }
        });

        Ok(())
    }

    /// Signals the background loop to exit at its next check.
    async fn stop(&mut self) -> SubscriptionResult<()> {
        self.stop_signal.store(true, Ordering::Release);
        self.is_running.store(false, Ordering::Release);
        Ok(())
    }

    /// Pauses event processing; the loop keeps running but only sleeps.
    async fn pause(&mut self) -> SubscriptionResult<()> {
        self.is_paused.store(true, Ordering::Release);
        Ok(())
    }

    /// Resumes a previously paused subscription.
    async fn resume(&mut self) -> SubscriptionResult<()> {
        self.is_paused.store(false, Ordering::Release);
        Ok(())
    }

    /// Returns the last position published by the processing loop.
    async fn get_position(&self) -> SubscriptionResult<Option<SubscriptionPosition>> {
        let guard = self.current_position.read().map_err(|_| {
            SubscriptionError::CheckpointLoadFailed("Failed to acquire position lock".to_string())
        })?;
        Ok(guard.clone())
    }

    /// Overwrites the in-memory position only.
    ///
    /// NOTE(review): unlike the processing loop, this does not persist the
    /// checkpoint to the database — confirm that is intended.
    async fn save_checkpoint(&mut self, position: SubscriptionPosition) -> SubscriptionResult<()> {
        {
            let mut guard = self.current_position.write().map_err(|_| {
                SubscriptionError::CheckpointSaveFailed(
                    "Failed to acquire position lock".to_string(),
                )
            })?;
            *guard = Some(position);
        }
        Ok(())
    }

    /// Loads the persisted checkpoint for `name` from the database.
    async fn load_checkpoint(
        &self,
        name: &SubscriptionName,
    ) -> SubscriptionResult<Option<SubscriptionPosition>> {
        self.load_checkpoint_from_db(name).await
    }
}
861
862impl<E> Clone for PostgresSubscription<E>
864where
865 E: Serialize
866 + for<'de> Deserialize<'de>
867 + Send
868 + Sync
869 + std::fmt::Debug
870 + Clone
871 + PartialEq
872 + Eq
873 + 'static,
874{
875 fn clone(&self) -> Self {
876 Self {
877 event_store: self.event_store.clone(),
878 options: self.options.clone(),
879 current_position: Arc::clone(&self.current_position),
880 is_running: Arc::clone(&self.is_running),
881 is_paused: Arc::clone(&self.is_paused),
882 stop_signal: Arc::clone(&self.stop_signal),
883 }
884 }
885}
886
impl<E> PostgresEventStore<E>
where
    E: Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug
        + Clone
        + PartialEq
        + Eq
        + 'static,
{
    /// Writes a batch of events to a single stream and returns its new head
    /// version.
    ///
    /// The starting version is resolved from `expected_version` (querying
    /// the current max for `ExpectedVersion::Any`). Version conflicts
    /// surface either from the explicit check (empty-events case) or from
    /// database unique/serialization errors raised during insert.
    #[allow(dead_code)]
    async fn write_stream_events_direct(
        &self,
        stream_events: StreamEvents<E>,
    ) -> EventStoreResult<EventVersion>
    where
        E: serde::Serialize + Sync,
    {
        let StreamEvents {
            stream_id,
            expected_version,
            events,
        } = stream_events;

        // With nothing to write, just validate the expectation and report
        // the current version.
        if events.is_empty() {
            let current_version = self
                .verify_stream_version_direct(&stream_id, expected_version)
                .await?;
            return Ok(current_version);
        }

        let starting_version = match expected_version {
            ExpectedVersion::New => EventVersion::initial(),
            ExpectedVersion::Exact(v) => v,
            ExpectedVersion::Any => {
                // Look up the current head so new events append after it.
                let current: Option<i64> = sqlx::query_scalar(
                    "SELECT MAX(event_version) FROM events WHERE stream_id = $1",
                )
                .bind(stream_id.as_ref())
                .fetch_optional(self.pool.as_ref())
                .await
                .map_err(|e| EventStoreError::ConnectionFailed(e.to_string()))?
                .flatten();

                if let Some(v) = current {
                    Self::convert_version(v)?
                } else {
                    EventVersion::initial()
                }
            }
        };

        let starting_value: u64 = starting_version.into();
        let new_version = EventVersion::try_new(starting_value + events.len() as u64)
            .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?;

        // Translate insert failures that look like concurrency conflicts
        // into a typed `VersionConflict`.
        self.insert_events_batch_direct(&stream_id, starting_version, &events)
            .await
            .map_err(|e| {
                if let EventStoreError::ConnectionFailed(msg) = &e {
                    if msg.contains("Version conflict") || msg.contains("expected new stream") {
                        return EventStoreError::VersionConflict {
                            stream: stream_id.clone(),
                            expected: starting_version,
                            current: EventVersion::try_new(starting_value + 1)
                                .unwrap_or(starting_version),
                        };
                    }
                }
                e
            })?;

        Ok(new_version)
    }

    /// Checks an observed stream head (`current_version`, as read from the
    /// database) against the caller's `expected_version`, returning the
    /// effective version on success and `VersionConflict` otherwise.
    #[allow(clippy::cognitive_complexity)]
    fn verify_version_matches(
        stream_id: &StreamId,
        current_version: Option<i64>,
        expected_version: ExpectedVersion,
    ) -> EventStoreResult<EventVersion> {
        debug!(
            stream_id = %stream_id.as_ref(),
            current = ?current_version,
            expected = ?expected_version,
            "Verifying version match"
        );

        let result = match (current_version, expected_version) {
            // Stream absent and expected to be new, or absent with no
            // expectation: OK, start from the initial version.
            (None, ExpectedVersion::New) => Ok(EventVersion::initial()),
            // Stream absent but a specific version was expected: conflict.
            (None, ExpectedVersion::Exact(expected)) => Err(EventStoreError::VersionConflict {
                stream: stream_id.clone(),
                expected,
                current: EventVersion::initial(),
            }),
            (None, ExpectedVersion::Any) => Ok(EventVersion::initial()),
            // Stream exists but was expected to be new: conflict.
            (Some(actual), ExpectedVersion::New) => {
                let actual_version = Self::convert_version(actual)?;
                Err(EventStoreError::VersionConflict {
                    stream: stream_id.clone(),
                    expected: EventVersion::initial(),
                    current: actual_version,
                })
            }
            // Stream exists; it must match the expected version exactly.
            (Some(actual), ExpectedVersion::Exact(expected)) => {
                let actual_version = Self::convert_version(actual)?;
                if actual_version == expected {
                    Ok(actual_version)
                } else {
                    Err(EventStoreError::VersionConflict {
                        stream: stream_id.clone(),
                        expected,
                        current: actual_version,
                    })
                }
            }
            // Stream exists and anything is acceptable.
            (Some(actual), ExpectedVersion::Any) => {
                let actual_version = Self::convert_version(actual)?;
                Ok(actual_version)
            }
        };

        debug!(result = ?result, "Version verification complete");
        result
    }

    /// Converts a database `event_version` (i64) into an `EventVersion`,
    /// rejecting negative values.
    fn convert_version(version: i64) -> EventStoreResult<EventVersion> {
        if version >= 0 {
            let version_u64 = u64::try_from(version)
                .map_err(|_| EventStoreError::SerializationFailed("Invalid version".to_string()))?;
            EventVersion::try_new(version_u64)
                .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))
        } else {
            Err(EventStoreError::SerializationFailed(
                "Negative version in database".to_string(),
            ))
        }
    }

    /// Reads the stream's current max version and validates it against
    /// `expected_version`.
    async fn verify_stream_version_direct(
        &self,
        stream_id: &StreamId,
        expected_version: ExpectedVersion,
    ) -> EventStoreResult<EventVersion> {
        let current_version: Option<i64> =
            sqlx::query_scalar("SELECT MAX(event_version) FROM events WHERE stream_id = $1")
                .bind(stream_id.as_ref())
                .fetch_optional(self.pool.as_ref())
                .await
                .map_err(|e| EventStoreError::ConnectionFailed(e.to_string()))?
                .flatten();

        Self::verify_version_matches(stream_id, current_version, expected_version)
    }

    /// Inserts events as multi-row `INSERT` statements, in chunks of up to
    /// 1000 events (8 bind parameters per row).
    ///
    /// Placeholder numbering restarts for each chunk, and the `bind` loop
    /// must keep the same column order as the VALUES tuples. SQLSTATE 40001
    /// (serialization failure) and unique-constraint violations are mapped
    /// to `VersionConflict`.
    ///
    /// NOTE(review): chunks execute as separate statements with no
    /// enclosing transaction visible here, so a mid-batch failure can leave
    /// a partial write — confirm intended.
    #[allow(clippy::too_many_lines)]
    async fn insert_events_batch_direct(
        &self,
        stream_id: &StreamId,
        starting_version: EventVersion,
        events: &[EventToWrite<E>],
    ) -> EventStoreResult<()>
    where
        E: serde::Serialize + Sync,
    {
        // Caps the parameter count per statement (8 params per row).
        const MAX_EVENTS_PER_BATCH: usize = 1000;

        if events.is_empty() {
            return Ok(());
        }

        for (batch_idx, batch) in events.chunks(MAX_EVENTS_PER_BATCH).enumerate() {
            let mut query = String::from(
                "INSERT INTO events
                 (stream_id, event_version, event_type, event_data, metadata, causation_id, correlation_id, user_id)
                 VALUES "
            );

            let mut values = Vec::new();
            let starting_version_u64: u64 = starting_version.into();
            // Versions continue monotonically across chunks.
            let batch_starting_version =
                starting_version_u64 + (batch_idx * MAX_EVENTS_PER_BATCH) as u64;

            // Column-wise staging buffers; index i in each corresponds to
            // row i of the VALUES list.
            let mut stream_ids = Vec::with_capacity(batch.len());
            let mut versions = Vec::with_capacity(batch.len());
            let mut event_types = Vec::with_capacity(batch.len());
            let mut event_data_values = Vec::with_capacity(batch.len());
            let mut metadata_values = Vec::with_capacity(batch.len());
            let mut causation_ids = Vec::with_capacity(batch.len());
            let mut correlation_ids = Vec::with_capacity(batch.len());
            let mut user_ids = Vec::with_capacity(batch.len());

            for (i, event) in batch.iter().enumerate() {
                // Versions are 1-based relative to the stream head.
                let event_version = EventVersion::try_new(batch_starting_version + i as u64 + 1)
                    .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?;

                let metadata_json = if let Some(metadata) = &event.metadata {
                    Some(
                        serde_json::to_value(metadata)
                            .map_err(|e| EventStoreError::SerializationFailed(e.to_string()))?,
                    )
                } else {
                    None
                };

                // Denormalize selected metadata fields into their own columns.
                let (causation_id, correlation_id, user_id) =
                    event
                        .metadata
                        .as_ref()
                        .map_or((None, None, None), |metadata| {
                            (
                                metadata.causation_id.as_ref().map(|id| **id),
                                Some(metadata.correlation_id.to_string()),
                                metadata
                                    .user_id
                                    .as_ref()
                                    .map(|uid| uid.as_ref().to_string()),
                            )
                        });

                let event_data = serde_json::to_value(&event.payload).map_err(|e| {
                    EventStoreError::SerializationFailed(format!(
                        "Failed to serialize event data: {e}"
                    ))
                })?;

                // Each row consumes 8 consecutive placeholders.
                let param_offset = i * 8;
                values.push(format!(
                    "(${}, ${}, ${}, ${}, ${}, ${}, ${}, ${})",
                    param_offset + 1,
                    param_offset + 2,
                    param_offset + 3,
                    param_offset + 4,
                    param_offset + 5,
                    param_offset + 6,
                    param_offset + 7,
                    param_offset + 8
                ));

                stream_ids.push(stream_id.as_ref().to_string());
                versions.push({
                    let version_value: u64 = event_version.into();
                    i64::try_from(version_value).map_err(|_| {
                        EventStoreError::SerializationFailed(
                            "Version too large for database".to_string(),
                        )
                    })?
                });
                // NOTE(review): event_type is hard-coded to "generic" —
                // confirm whether a real discriminator is expected here.
                event_types.push("generic".to_string());
                event_data_values.push(event_data);
                metadata_values.push(metadata_json);
                causation_ids.push(causation_id);
                correlation_ids.push(correlation_id);
                user_ids.push(user_id);
            }

            query.push_str(&values.join(", "));

            let mut sqlx_query = sqlx::query(&query);

            // Bind in the same column order as the VALUES tuples above.
            for i in 0..batch.len() {
                sqlx_query = sqlx_query
                    .bind(&stream_ids[i])
                    .bind(versions[i])
                    .bind(&event_types[i])
                    .bind(&event_data_values[i])
                    .bind(&metadata_values[i])
                    .bind(causation_ids[i])
                    .bind(&correlation_ids[i])
                    .bind(&user_ids[i]);
            }

            // Map database failures to typed errors: SQLSTATE 40001 and
            // unique violations become `VersionConflict`; everything else is
            // surfaced as `ConnectionFailed`.
            sqlx_query.execute(self.pool.as_ref()).await.map_err(|e| {
                e.as_database_error().map_or_else(
                    || EventStoreError::ConnectionFailed(e.to_string()),
                    |db_err| {
                        db_err.code().map_or_else(
                            || EventStoreError::ConnectionFailed(e.to_string()),
                            |code| {
                                if code == "40001" {
                                    EventStoreError::VersionConflict {
                                        stream: stream_id.clone(),
                                        expected: starting_version,
                                        current: EventVersion::initial(),
                                    }
                                } else if db_err.is_unique_violation() {
                                    EventStoreError::VersionConflict {
                                        stream: stream_id.clone(),
                                        expected: starting_version,
                                        current: EventVersion::try_new(
                                            u64::from(starting_version) + 1,
                                        )
                                        .unwrap_or(starting_version),
                                    }
                                } else {
                                    EventStoreError::ConnectionFailed(e.to_string())
                                }
                            },
                        )
                    },
                )
            })?;
        }

        Ok(())
    }
}
1227
#[cfg(test)]
mod tests {
    use super::*;
    use eventcore::{EventToWrite, ExpectedVersion, StreamEvents};
    use serde::{Deserialize, Serialize};

    /// Simple event payload used to exercise construction/serialization.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    enum TestEvent {
        Created { name: String },
        Updated { value: i32 },
    }

    /// Smoke test: an `EventRow` can be constructed field-by-field and is
    /// `Debug`-printable (no database involved).
    #[test]
    fn test_event_row_conversion() {
        let event_id = Uuid::nil();
        let stream_id = "test-stream".to_string();
        let event_version = 0i64;
        let event_type = "TestEvent".to_string();
        let event_data = serde_json::json!({"test": true});
        let metadata = None;
        let causation_id = None;
        let correlation_id = None;
        let user_id = None;
        let created_at = chrono::Utc::now();

        let event_row = EventRow {
            event_id,
            stream_id,
            event_version,
            event_type,
            event_data,
            metadata,
            causation_id,
            correlation_id,
            user_id,
            created_at,
        };

        assert!(!format!("{event_row:?}").is_empty());
    }

    /// Sanity-checks equality semantics of the `ExpectedVersion` variants.
    #[test]
    fn test_expected_version_logic() {
        let new_version = ExpectedVersion::New;
        let exact_version = ExpectedVersion::Exact(EventVersion::try_new(5).unwrap());
        let any_version = ExpectedVersion::Any;

        assert_eq!(new_version, ExpectedVersion::New);
        assert_eq!(
            exact_version,
            ExpectedVersion::Exact(EventVersion::try_new(5).unwrap())
        );
        assert_eq!(any_version, ExpectedVersion::Any);
    }

    /// Round-trips `EventMetadata` through `serde_json`.
    #[test]
    fn test_metadata_serialization() {
        use eventcore::{CorrelationId, UserId};

        let metadata = EventMetadata::new()
            .with_correlation_id(CorrelationId::new())
            .with_user_id(Some(UserId::try_new("test-user").unwrap()));

        let json_value = serde_json::to_value(&metadata).unwrap();
        let deserialized: EventMetadata = serde_json::from_value(json_value).unwrap();

        assert_eq!(metadata, deserialized);
    }

    /// Verifies `StreamEvents` construction preserves its inputs.
    #[test]
    fn test_stream_events_construction() {
        let stream_id = StreamId::try_new("test-stream").unwrap();
        let event_id = EventId::new();
        let payload = TestEvent::Created {
            name: "test".to_string(),
        };

        let event = EventToWrite::new(event_id, payload);
        let stream_events = StreamEvents::new(stream_id.clone(), ExpectedVersion::New, vec![event]);

        assert_eq!(stream_events.stream_id, stream_id);
        assert_eq!(stream_events.expected_version, ExpectedVersion::New);
        assert_eq!(stream_events.events.len(), 1);
    }
}