eventsourced_postgres/
snapshot_store.rs

//! A [SnapshotStore] implementation based on [PostgreSQL](https://www.postgresql.org/).
2
3use crate::{Cnn, CnnPool, Error};
4use bb8_postgres::{bb8::Pool, PostgresConnectionManager};
5use bytes::Bytes;
6use eventsourced::snapshot_store::{Snapshot, SnapshotStore};
7use serde::{Deserialize, Serialize};
8use std::{
9    error::Error as StdError,
10    fmt::{self, Debug, Formatter},
11    marker::PhantomData,
12    num::NonZeroU64,
13};
14use tokio_postgres::{types::ToSql, NoTls};
15use tracing::debug;
16
17/// A [SnapshotStore] implementation based on [PostgreSQL](https://www.postgresql.org/).
18#[derive(Clone)]
19pub struct PostgresSnapshotStore<I> {
20    cnn_pool: CnnPool<NoTls>,
21    _id: PhantomData<I>,
22}
23
24impl<I> PostgresSnapshotStore<I> {
25    #[allow(missing_docs)]
26    pub async fn new(config: Config) -> Result<Self, Error> {
27        debug!(?config, "creating PostgresSnapshotStore");
28
29        // Create connection pool.
30        let tls = NoTls;
31        let cnn_manager = PostgresConnectionManager::new_from_stringlike(config.cnn_config(), tls)
32            .map_err(|error| {
33                Error::Postgres("cannot create connection manager".to_string(), error)
34            })?;
35        let cnn_pool = Pool::builder()
36            .build(cnn_manager)
37            .await
38            .map_err(|error| Error::Postgres("cannot create connection pool".to_string(), error))?;
39
40        // Setup tables.
41        if config.setup {
42            cnn_pool
43                .get()
44                .await
45                .map_err(Error::GetConnection)?
46                .execute(
47                    &include_str!("create_snapshot_store.sql")
48                        .replace("snapshots", &config.snapshots_table),
49                    &[],
50                )
51                .await
52                .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?;
53        }
54
55        Ok(Self {
56            cnn_pool,
57            _id: PhantomData,
58        })
59    }
60
61    async fn cnn(&self) -> Result<Cnn<NoTls>, Error> {
62        self.cnn_pool.get().await.map_err(Error::GetConnection)
63    }
64}
65
66impl<I> Debug for PostgresSnapshotStore<I> {
67    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68        f.debug_struct("PostgresSnapshotStore").finish()
69    }
70}
71
72impl<I> SnapshotStore for PostgresSnapshotStore<I>
73where
74    I: Debug + Clone + ToSql + Send + Sync + 'static,
75{
76    type Id = I;
77
78    type Error = Error;
79
80    async fn save<S, ToBytes, ToBytesError>(
81        &mut self,
82        id: &Self::Id,
83        seq_no: NonZeroU64,
84        state: &S,
85        to_bytes: &ToBytes,
86    ) -> Result<(), Self::Error>
87    where
88        S: Send,
89        ToBytes: Fn(&S) -> Result<Bytes, ToBytesError> + Sync,
90        ToBytesError: StdError + Send + Sync + 'static,
91    {
92        debug!(?id, %seq_no, "saving snapshot");
93
94        let bytes = to_bytes(state).map_err(|source| Error::ToBytes(Box::new(source)))?;
95        self.cnn()
96            .await?
97            .execute(
98                "INSERT INTO snapshots VALUES ($1, $2, $3)",
99                &[&id, &(seq_no.get() as i64), &bytes.as_ref()],
100            )
101            .await
102            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
103            .map(|_| ())
104    }
105
106    async fn load<S, FromBytes, FromBytesError>(
107        &self,
108        id: &Self::Id,
109        from_bytes: FromBytes,
110    ) -> Result<Option<Snapshot<S>>, Self::Error>
111    where
112        FromBytes: Fn(Bytes) -> Result<S, FromBytesError> + Send,
113        FromBytesError: StdError + Send + Sync + 'static,
114    {
115        debug!(?id, "loading snapshot");
116
117        self.cnn()
118            .await?
119            .query_opt(
120                "SELECT seq_no, state FROM snapshots
121                 WHERE id = $1
122                 AND seq_no = (select max(seq_no) from snapshots where id = $1)",
123                &[&id],
124            )
125            .await
126            .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?
127            .map(move |row| {
128                let seq_no = (row.get::<_, i64>(0) as u64)
129                    .try_into()
130                    .map_err(|_| Error::ZeroNonZeroU64)?;
131                let bytes = row.get::<_, &[u8]>(1);
132                let bytes = Bytes::copy_from_slice(bytes);
133                from_bytes(bytes)
134                    .map_err(|source| Error::FromBytes(Box::new(source)))
135                    .map(|state| Snapshot::new(seq_no, state))
136            })
137            .transpose()
138    }
139}
140
141/// Configuration for the [PostgresSnapshotStore].
142#[derive(Debug, Clone, Serialize, Deserialize)]
143#[serde(rename_all = "kebab-case")]
144pub struct Config {
145    pub host: String,
146
147    pub port: u16,
148
149    pub user: String,
150
151    pub password: String,
152
153    pub dbname: String,
154
155    pub sslmode: String,
156
157    #[serde(default = "snapshots_table_default")]
158    pub snapshots_table: String,
159
160    #[serde(default)]
161    pub setup: bool,
162}
163
164impl Config {
165    fn cnn_config(&self) -> String {
166        format!(
167            "host={} port={} user={} password={} dbname={} sslmode={}",
168            self.host, self.port, self.user, self.password, self.dbname, self.sslmode
169        )
170    }
171}
172
173impl Default for Config {
174    /// Default values suitable for local testing only.
175    fn default() -> Self {
176        Self {
177            host: "localhost".to_string(),
178            port: 5432,
179            user: "postgres".to_string(),
180            password: "".to_string(),
181            dbname: "postgres".to_string(),
182            sslmode: "prefer".to_string(),
183            snapshots_table: snapshots_table_default(),
184            setup: false,
185        }
186    }
187}
188
/// Default name of the snapshots table, used by `Config::default` and by serde when the
/// `snapshots-table` key is absent.
fn snapshots_table_default() -> String {
    String::from("snapshots")
}
192
#[cfg(test)]
mod tests {
    use crate::{PostgresSnapshotStore, PostgresSnapshotStoreConfig};
    use error_ext::BoxError;
    use eventsourced::{binarize, snapshot_store::SnapshotStore};
    use testcontainers::clients::Cli;
    use testcontainers_modules::postgres::Postgres;
    use uuid::Uuid;

    /// Round-trips a snapshot through a PostgreSQL container: loading before saving yields no
    /// snapshot, loading after saving yields the saved sequence number and state.
    #[tokio::test]
    async fn test_snapshot_store() -> Result<(), BoxError> {
        let cli = Cli::default();
        let postgres = cli.run(Postgres::default().with_host_auth());

        let mut snapshot_store = PostgresSnapshotStore::<Uuid>::new(PostgresSnapshotStoreConfig {
            port: postgres.get_host_port_ipv4(5432),
            setup: true,
            ..Default::default()
        })
        .await?;

        let id = Uuid::now_v7();
        let (seq_no, state) = (42.try_into().unwrap(), 666);

        let before = snapshot_store
            .load::<i32, _, _>(&id, &binarize::serde_json::from_bytes)
            .await?;
        assert!(before.is_none());

        snapshot_store
            .save(&id, seq_no, &state, &binarize::serde_json::to_bytes)
            .await?;

        let after = snapshot_store
            .load::<i32, _, _>(&id, &binarize::serde_json::from_bytes)
            .await?
            .expect("snapshot has been saved");
        assert_eq!(after.seq_no, seq_no);
        assert_eq!(after.state, state);

        Ok(())
    }
}