1use crate::{Cnn, CnnPool, Error};
4use async_stream::stream;
5use bb8_postgres::{bb8::Pool, PostgresConnectionManager};
6use bytes::Bytes;
7use eventsourced::event_log::EventLog;
8use futures::{Stream, StreamExt, TryStreamExt};
9use serde::{Deserialize, Serialize};
10use std::{
11 error::Error as StdError,
12 fmt::{self, Debug, Formatter},
13 marker::PhantomData,
14 num::{NonZeroU64, NonZeroUsize},
15 time::Duration,
16};
17use tokio::time::sleep;
18use tokio_postgres::{types::ToSql, NoTls};
19use tracing::{debug, instrument};
20
21#[derive(Clone)]
23pub struct PostgresEventLog<I> {
24 poll_interval: Duration,
25 cnn_pool: CnnPool<NoTls>,
26 _id: PhantomData<I>,
27}
28
29impl<I> PostgresEventLog<I>
30where
31 I: ToSql + Sync,
32{
33 #[allow(missing_docs)]
34 pub async fn new(config: Config) -> Result<Self, Error> {
35 debug!(?config, "creating PostgresEventLog");
36
37 let tls = NoTls;
39 let cnn_manager = PostgresConnectionManager::new_from_stringlike(config.cnn_config(), tls)
40 .map_err(|error| {
41 Error::Postgres("cannot create connection manager".to_string(), error)
42 })?;
43 let cnn_pool = Pool::builder()
44 .build(cnn_manager)
45 .await
46 .map_err(|error| Error::Postgres("cannot create connection pool".to_string(), error))?;
47
48 if config.setup {
50 cnn_pool
51 .get()
52 .await
53 .map_err(Error::GetConnection)?
54 .batch_execute(
55 &include_str!("create_event_log.sql").replace("events", &config.events_table),
56 )
57 .await
58 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?;
59 }
60
61 Ok(Self {
62 poll_interval: config.poll_interval,
63 cnn_pool,
64 _id: PhantomData,
65 })
66 }
67
68 async fn cnn(&self) -> Result<Cnn<NoTls>, Error> {
69 self.cnn_pool.get().await.map_err(Error::GetConnection)
70 }
71
72 async fn next_events_by_id<E, FromBytes, FromBytesError>(
73 &self,
74 id: &I,
75 seq_no: i64,
76 from_bytes: FromBytes,
77 ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Error>> + Send, Error>
78 where
79 E: Send,
80 FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Send,
81 FromBytesError: StdError + Send + Sync + 'static,
82 {
83 debug!(?id, ?seq_no, "querying events");
84 let params: [&(dyn ToSql + Sync); 2] = [&id, &seq_no];
85 let events = self
86 .cnn()
87 .await?
88 .query_raw(
89 "SELECT seq_no, event FROM events WHERE id = $1 AND seq_no >= $2",
90 params,
91 )
92 .await
93 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?
94 .map_err(|error| Error::Postgres("cannot get next row".to_string(), error))
95 .map(move |row| {
96 row.and_then(|row| {
97 let seq_no = (row.get::<_, i64>(0) as u64)
98 .try_into()
99 .map_err(|_| Error::ZeroNonZeroU64)?;
100 let bytes = row.get::<_, &[u8]>(1);
101 let bytes = Bytes::copy_from_slice(bytes);
102 from_bytes(bytes)
103 .map_err(|source| Error::FromBytes(Box::new(source)))
104 .map(|event| (seq_no, event))
105 })
106 });
107
108 Ok(events)
109 }
110
111 async fn next_events_by_type<E, FromBytes, FromBytesError>(
112 &self,
113 type_name: &str,
114 seq_no: i64,
115 from_bytes: FromBytes,
116 ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Error>> + Send, Error>
117 where
118 E: Send,
119 FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Send,
120 FromBytesError: StdError + Send + Sync + 'static,
121 {
122 debug!(%type_name, seq_no, "querying events");
123
124 let params: [&(dyn ToSql + Sync); 2] = [&type_name, &seq_no];
125 let events = self
126 .cnn()
127 .await?
128 .query_raw(
129 "SELECT seq_no, event FROM events WHERE type = $1 AND seq_no >= $2",
130 params,
131 )
132 .await
133 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?
134 .map_err(|error| Error::Postgres("cannot get next row".to_string(), error))
135 .map(move |row| {
136 row.and_then(|row| {
137 let seq_no = (row.get::<_, i64>(0) as u64)
138 .try_into()
139 .map_err(|_| Error::ZeroNonZeroU64)?;
140 let bytes = row.get::<_, &[u8]>(1);
141 let bytes = Bytes::copy_from_slice(bytes);
142 from_bytes(bytes)
143 .map_err(|source| Error::FromBytes(Box::new(source)))
144 .map(|event| (seq_no, event))
145 })
146 });
147
148 Ok(events)
149 }
150
151 async fn last_seq_no_by_type(&self, type_name: &str) -> Result<Option<NonZeroU64>, Error> {
152 self.cnn()
153 .await?
154 .query_one(
155 "SELECT MAX(seq_no) FROM events WHERE type = $1",
156 &[&type_name],
157 )
158 .await
159 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
160 .and_then(|row| {
161 row.try_get::<_, i64>(0)
163 .ok()
164 .map(|seq_no| {
165 (seq_no as u64)
166 .try_into()
167 .map_err(|_| Error::ZeroNonZeroU64)
168 })
169 .transpose()
170 })
171 }
172}
173
174impl<I> Debug for PostgresEventLog<I> {
175 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
176 f.debug_struct("PostgresEventLog").finish()
177 }
178}
179
180impl<I> EventLog for PostgresEventLog<I>
181where
182 I: Clone + ToSql + Send + Sync + 'static,
183{
184 type Id = I;
185
186 type Error = Error;
187
188 const MAX_SEQ_NO: NonZeroU64 = unsafe { NonZeroU64::new_unchecked(i64::MAX as u64) };
191
192 #[instrument(skip(self, event, to_bytes))]
193 async fn persist<E, ToBytes, ToBytesError>(
194 &mut self,
195 type_name: &'static str,
196 id: &Self::Id,
197 last_seq_no: Option<NonZeroU64>,
198 event: &E,
199 to_bytes: &ToBytes,
200 ) -> Result<NonZeroU64, Self::Error>
201 where
202 ToBytes: Fn(&E) -> Result<Bytes, ToBytesError> + Sync,
203 ToBytesError: StdError + Send + Sync + 'static,
204 {
205 let seq_no = last_seq_no.map(|n| n.get() as i64).unwrap_or_default() + 1;
206
207 let bytes = to_bytes(event).map_err(|error| Error::ToBytes(Box::new(error)))?;
208
209 self.cnn()
210 .await?
211 .query_one(
212 "INSERT INTO events (seq_no, type, id, event) VALUES ($1, $2, $3, $4) RETURNING seq_no",
213 &[&seq_no, &type_name, &id, &bytes.as_ref()],
214 )
215 .await
216 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
217 .and_then(|row| {
218 (row.get::<_, i64>(0) as u64)
219 .try_into()
220 .map_err(|_| Error::ZeroNonZeroU64)
221 })
222 }
223
224 #[instrument(skip(self))]
225 async fn last_seq_no(
226 &self,
227 type_name: &'static str,
228 id: &Self::Id,
229 ) -> Result<Option<NonZeroU64>, Self::Error> {
230 self.cnn()
231 .await?
232 .query_one("SELECT MAX(seq_no) FROM events WHERE id = $1", &[&id])
233 .await
234 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
235 .and_then(|row| {
236 row.try_get::<_, i64>(0)
238 .ok()
239 .map(|seq_no| {
240 (seq_no as u64)
241 .try_into()
242 .map_err(|_| Error::ZeroNonZeroU64)
243 })
244 .transpose()
245 })
246 }
247
248 #[instrument(skip(self, from_bytes))]
249 async fn events_by_id<E, FromBytes, FromBytesError>(
250 &self,
251 type_name: &'static str,
252 id: &Self::Id,
253 seq_no: NonZeroU64,
254 from_bytes: FromBytes,
255 ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Self::Error>> + Send, Self::Error>
256 where
257 E: Send,
258 FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Copy + Send + Sync + 'static,
259 FromBytesError: StdError + Send + Sync + 'static,
260 {
261 let last_seq_no = self
262 .last_seq_no(type_name, id)
263 .await?
264 .map(|n| n.get() as i64)
265 .unwrap_or_default();
266
267 let mut current_seq_no = seq_no.get() as i64;
268 let events = stream! {
269 'outer: loop {
270 let events = self
271 .next_events_by_id(id, current_seq_no, from_bytes)
272 .await?;
273
274 for await event in events {
275 match event {
276 Ok(event @ (seq_no, _)) => {
277 current_seq_no += seq_no.get() as i64 + 1;
278 yield Ok(event);
279 }
280
281 Err(error) => {
282 yield Err(error);
283 break 'outer;
284 }
285 }
286 }
287
288 if current_seq_no >= last_seq_no {
290 sleep(self.poll_interval).await;
291 }
292 }
293 };
294
295 Ok(events)
296 }
297
298 #[instrument(skip(self, from_bytes))]
299 async fn events_by_type<E, FromBytes, FromBytesError>(
300 &self,
301 type_name: &'static str,
302 seq_no: NonZeroU64,
303 from_bytes: FromBytes,
304 ) -> Result<impl Stream<Item = Result<(NonZeroU64, E), Self::Error>> + Send, Self::Error>
305 where
306 E: Send,
307 FromBytes: Fn(Bytes) -> Result<E, FromBytesError> + Copy + Send + Sync + 'static,
308 FromBytesError: StdError + Send + Sync + 'static,
309 {
310 debug!(type_name, seq_no, "building events by type stream");
311
312 let last_seq_no = self
313 .last_seq_no_by_type(type_name)
314 .await?
315 .map(|n| n.get() as i64)
316 .unwrap_or_default();
317
318 let mut current_seq_no = seq_no.get() as i64;
319 let events = stream! {
320 'outer: loop {
321 let events = self
322 .next_events_by_type(type_name, current_seq_no, from_bytes)
323 .await?;
324
325 for await event in events {
326 match event {
327 Ok(event @ (seq_no, _)) => {
328 current_seq_no = seq_no.get() as i64 + 1;
329 yield Ok(event);
330 }
331
332 Err(error) => {
333 yield Err(error);
334 break 'outer;
335 }
336 }
337 }
338
339 if current_seq_no >= last_seq_no {
341 sleep(self.poll_interval).await;
342 }
343 }
344 };
345
346 Ok(events)
347 }
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize)]
352#[serde(rename_all = "kebab-case")]
353pub struct Config {
354 pub host: String,
355
356 pub port: u16,
357
358 pub user: String,
359
360 pub password: String,
361
362 pub dbname: String,
363
364 pub sslmode: String,
365
366 #[serde(default = "events_table_default")]
367 pub events_table: String,
368
369 #[serde(default = "poll_interval_default", with = "humantime_serde")]
370 pub poll_interval: Duration,
371
372 #[serde(default = "id_broadcast_capacity_default")]
373 pub id_broadcast_capacity: NonZeroUsize,
374
375 #[serde(default)]
376 pub setup: bool,
377}
378
379impl Config {
380 fn cnn_config(&self) -> String {
381 format!(
382 "host={} port={} user={} password={} dbname={} sslmode={}",
383 self.host, self.port, self.user, self.password, self.dbname, self.sslmode
384 )
385 }
386}
387
388impl Default for Config {
389 fn default() -> Self {
391 Self {
392 host: "localhost".to_string(),
393 port: 5432,
394 user: "postgres".to_string(),
395 password: "".to_string(),
396 dbname: "postgres".to_string(),
397 sslmode: "prefer".to_string(),
398 events_table: events_table_default(),
399 poll_interval: poll_interval_default(),
400 id_broadcast_capacity: id_broadcast_capacity_default(),
401 setup: false,
402 }
403 }
404}
405
/// Default name of the events table.
fn events_table_default() -> String {
    String::from("events")
}
409
/// Default polling interval: two seconds.
const fn poll_interval_default() -> Duration {
    Duration::from_millis(2_000)
}
413
/// Default id broadcast capacity: one.
const fn id_broadcast_capacity_default() -> NonZeroUsize {
    match NonZeroUsize::new(1) {
        Some(one) => one,
        None => panic!("1 is non-zero"),
    }
}
417
#[cfg(test)]
mod tests {
    use crate::{PostgresEventLog, PostgresEventLogConfig};
    use error_ext::BoxError;
    use eventsourced::{binarize, event_log::EventLog};
    use futures::{StreamExt, TryStreamExt};
    use std::{future, num::NonZeroU64};
    use testcontainers::clients::Cli;
    use testcontainers_modules::postgres::Postgres;
    use uuid::Uuid;

    /// End-to-end check against a throwaway Postgres container. It covers:
    /// - persisting with explicit sequence numbers;
    /// - rejecting a reused sequence number;
    /// - querying the last sequence number;
    /// - streaming events by id and by type.
    #[tokio::test]
    async fn test_event_log() -> Result<(), BoxError> {
        // `postgres` must stay bound until the end: dropping it stops the container.
        let docker = Cli::default();
        let postgres = docker.run(Postgres::default().with_host_auth());
        let config = PostgresEventLogConfig {
            port: postgres.get_host_port_ipv4(5432),
            setup: true,
            ..Default::default()
        };
        let mut log = PostgresEventLog::<Uuid>::new(config).await?;
        let id = Uuid::now_v7();

        // A fresh id has no events yet.
        assert_eq!(log.last_seq_no("counter", &id).await?, None);

        // Events 1 and 2.
        let first = log
            .persist("counter", &id, None, &1, &binarize::serde_json::to_bytes)
            .await?;
        assert_eq!(first.get(), 1);
        log.persist(
            "counter",
            &id,
            Some(first),
            &2,
            &binarize::serde_json::to_bytes,
        )
        .await?;

        // Persisting after `first` again would reuse seq_no 2 and must fail.
        let conflict = log
            .persist(
                "counter",
                &id,
                Some(first),
                &3,
                &binarize::serde_json::to_bytes,
            )
            .await;
        assert!(conflict.is_err());

        // Event 3.
        let second = first.checked_add(1).expect("overflow");
        log.persist(
            "counter",
            &id,
            Some(second),
            &3,
            &binarize::serde_json::to_bytes,
        )
        .await?;
        let last = log.last_seq_no("counter", &id).await?;
        assert_eq!(last, Some(3.try_into()?));

        // Replaying by id from seq_no 2 delivers events 2 and 3.
        let by_id = log
            .events_by_id::<u32, _, _>(
                "counter",
                &id,
                2.try_into()?,
                binarize::serde_json::from_bytes,
            )
            .await?;
        let sum = by_id
            .take(2)
            .try_fold(0u32, |acc, (_, n)| future::ready(Ok(acc + n)))
            .await?;
        assert_eq!(sum, 5);

        // Events 4 and 5 are persisted after this stream is created; it must still deliver them.
        let by_type = log
            .events_by_type::<u32, _, _>(
                "counter",
                NonZeroU64::MIN,
                binarize::serde_json::from_bytes,
            )
            .await?;

        // `by_type` borrows `log`, so the mutating `persist` calls go through clones.
        let fourth = log
            .clone()
            .persist("counter", &id, last, &4, &binarize::serde_json::to_bytes)
            .await?;
        log.clone()
            .persist(
                "counter",
                &id,
                Some(fourth),
                &5,
                &binarize::serde_json::to_bytes,
            )
            .await?;
        assert_eq!(log.last_seq_no("counter", &id).await?, Some(5.try_into()?));

        let sum = by_type
            .take(5)
            .try_fold(0u32, |acc, (_, n)| future::ready(Ok(acc + n)))
            .await?;
        assert_eq!(sum, 15);

        Ok(())
    }
}