1use crate::{Cnn, CnnPool, Error};
4use bb8_postgres::{bb8::Pool, PostgresConnectionManager};
5use bytes::Bytes;
6use eventsourced::snapshot_store::{Snapshot, SnapshotStore};
7use serde::{Deserialize, Serialize};
8use std::{
9 error::Error as StdError,
10 fmt::{self, Debug, Formatter},
11 marker::PhantomData,
12 num::NonZeroU64,
13};
14use tokio_postgres::{types::ToSql, NoTls};
15use tracing::debug;
16
17#[derive(Clone)]
19pub struct PostgresSnapshotStore<I> {
20 cnn_pool: CnnPool<NoTls>,
21 _id: PhantomData<I>,
22}
23
24impl<I> PostgresSnapshotStore<I> {
25 #[allow(missing_docs)]
26 pub async fn new(config: Config) -> Result<Self, Error> {
27 debug!(?config, "creating PostgresSnapshotStore");
28
29 let tls = NoTls;
31 let cnn_manager = PostgresConnectionManager::new_from_stringlike(config.cnn_config(), tls)
32 .map_err(|error| {
33 Error::Postgres("cannot create connection manager".to_string(), error)
34 })?;
35 let cnn_pool = Pool::builder()
36 .build(cnn_manager)
37 .await
38 .map_err(|error| Error::Postgres("cannot create connection pool".to_string(), error))?;
39
40 if config.setup {
42 cnn_pool
43 .get()
44 .await
45 .map_err(Error::GetConnection)?
46 .execute(
47 &include_str!("create_snapshot_store.sql")
48 .replace("snapshots", &config.snapshots_table),
49 &[],
50 )
51 .await
52 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?;
53 }
54
55 Ok(Self {
56 cnn_pool,
57 _id: PhantomData,
58 })
59 }
60
61 async fn cnn(&self) -> Result<Cnn<NoTls>, Error> {
62 self.cnn_pool.get().await.map_err(Error::GetConnection)
63 }
64}
65
66impl<I> Debug for PostgresSnapshotStore<I> {
67 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68 f.debug_struct("PostgresSnapshotStore").finish()
69 }
70}
71
72impl<I> SnapshotStore for PostgresSnapshotStore<I>
73where
74 I: Debug + Clone + ToSql + Send + Sync + 'static,
75{
76 type Id = I;
77
78 type Error = Error;
79
80 async fn save<S, ToBytes, ToBytesError>(
81 &mut self,
82 id: &Self::Id,
83 seq_no: NonZeroU64,
84 state: &S,
85 to_bytes: &ToBytes,
86 ) -> Result<(), Self::Error>
87 where
88 S: Send,
89 ToBytes: Fn(&S) -> Result<Bytes, ToBytesError> + Sync,
90 ToBytesError: StdError + Send + Sync + 'static,
91 {
92 debug!(?id, %seq_no, "saving snapshot");
93
94 let bytes = to_bytes(state).map_err(|source| Error::ToBytes(Box::new(source)))?;
95 self.cnn()
96 .await?
97 .execute(
98 "INSERT INTO snapshots VALUES ($1, $2, $3)",
99 &[&id, &(seq_no.get() as i64), &bytes.as_ref()],
100 )
101 .await
102 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))
103 .map(|_| ())
104 }
105
106 async fn load<S, FromBytes, FromBytesError>(
107 &self,
108 id: &Self::Id,
109 from_bytes: FromBytes,
110 ) -> Result<Option<Snapshot<S>>, Self::Error>
111 where
112 FromBytes: Fn(Bytes) -> Result<S, FromBytesError> + Send,
113 FromBytesError: StdError + Send + Sync + 'static,
114 {
115 debug!(?id, "loading snapshot");
116
117 self.cnn()
118 .await?
119 .query_opt(
120 "SELECT seq_no, state FROM snapshots
121 WHERE id = $1
122 AND seq_no = (select max(seq_no) from snapshots where id = $1)",
123 &[&id],
124 )
125 .await
126 .map_err(|error| Error::Postgres("cannot execute query".to_string(), error))?
127 .map(move |row| {
128 let seq_no = (row.get::<_, i64>(0) as u64)
129 .try_into()
130 .map_err(|_| Error::ZeroNonZeroU64)?;
131 let bytes = row.get::<_, &[u8]>(1);
132 let bytes = Bytes::copy_from_slice(bytes);
133 from_bytes(bytes)
134 .map_err(|source| Error::FromBytes(Box::new(source)))
135 .map(|state| Snapshot::new(seq_no, state))
136 })
137 .transpose()
138 }
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
143#[serde(rename_all = "kebab-case")]
144pub struct Config {
145 pub host: String,
146
147 pub port: u16,
148
149 pub user: String,
150
151 pub password: String,
152
153 pub dbname: String,
154
155 pub sslmode: String,
156
157 #[serde(default = "snapshots_table_default")]
158 pub snapshots_table: String,
159
160 #[serde(default)]
161 pub setup: bool,
162}
163
164impl Config {
165 fn cnn_config(&self) -> String {
166 format!(
167 "host={} port={} user={} password={} dbname={} sslmode={}",
168 self.host, self.port, self.user, self.password, self.dbname, self.sslmode
169 )
170 }
171}
172
173impl Default for Config {
174 fn default() -> Self {
176 Self {
177 host: "localhost".to_string(),
178 port: 5432,
179 user: "postgres".to_string(),
180 password: "".to_string(),
181 dbname: "postgres".to_string(),
182 sslmode: "prefer".to_string(),
183 snapshots_table: snapshots_table_default(),
184 setup: false,
185 }
186 }
187}
188
/// Default name of the snapshots table. Serde uses it when `snapshots-table` is absent, and
/// `Config::default` uses it as well.
fn snapshots_table_default() -> String {
    String::from("snapshots")
}
192
#[cfg(test)]
mod tests {
    use crate::{PostgresSnapshotStore, PostgresSnapshotStoreConfig};
    use error_ext::BoxError;
    use eventsourced::{binarize, snapshot_store::SnapshotStore};
    use testcontainers::clients::Cli;
    use testcontainers_modules::postgres::Postgres;
    use uuid::Uuid;

    /// Round-trips a snapshot through a PostgreSQL container: loading before any save yields
    /// nothing; loading after a save yields the saved sequence number and state.
    #[tokio::test]
    async fn test_snapshot_store() -> Result<(), BoxError> {
        // The container must stay bound until the end of the test, otherwise it is stopped.
        let docker = Cli::default();
        let postgres = docker.run(Postgres::default().with_host_auth());

        let config = PostgresSnapshotStoreConfig {
            port: postgres.get_host_port_ipv4(5432),
            setup: true,
            ..Default::default()
        };
        let mut store = PostgresSnapshotStore::<Uuid>::new(config).await?;

        let id = Uuid::now_v7();

        // Nothing has been saved for this id yet.
        let loaded = store
            .load::<i32, _, _>(&id, &binarize::serde_json::from_bytes)
            .await?;
        assert!(loaded.is_none());

        let seq_no = 42.try_into().unwrap();
        let state = 666;
        store
            .save(&id, seq_no, &state, &binarize::serde_json::to_bytes)
            .await?;

        let loaded = store
            .load::<i32, _, _>(&id, &binarize::serde_json::from_bytes)
            .await?
            .expect("snapshot saved above must be loadable");
        assert_eq!(loaded.seq_no, seq_no);
        assert_eq!(loaded.state, state);

        Ok(())
    }
}