1use crate::types::BackendId;
4use alembic_core::{TypeName, Uid};
5use anyhow::{anyhow, Context, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12use tokio_postgres::Client;
13
14#[derive(Debug, Default, Clone, Serialize, Deserialize)]
16pub struct StateData {
17 #[serde(default)]
18 pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
19}
20
/// Transport security used when connecting to the Postgres state backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// Plain-text connection (`tokio_postgres::NoTls`).
    Disable,
    /// TLS through a `native_tls` connector built with its default settings.
    Require,
}
27
28#[async_trait::async_trait]
30pub trait StateBackend: Send + Sync + std::fmt::Debug {
31 async fn load(&mut self) -> Result<StateData>;
32 async fn save(&mut self, data: &StateData) -> Result<()>;
33}
34
35#[derive(Debug, Clone)]
37pub struct StateStore {
38 backend: Option<Arc<Mutex<dyn StateBackend>>>,
39 data: StateData,
40}
41
42impl StateStore {
43 pub fn new(backend: Option<Arc<Mutex<dyn StateBackend>>>, data: StateData) -> Self {
45 Self { backend, data }
46 }
47
48 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
50 let path = path.as_ref().to_path_buf();
51 let backend: Option<Arc<Mutex<dyn StateBackend>>> =
54 Some(Arc::new(Mutex::new(LocalBackend { path: path.clone() }))
55 as Arc<Mutex<dyn StateBackend>>);
56 let data = if path.exists() {
57 let raw = fs::read_to_string(&path)
58 .with_context(|| format!("read state: {}", path.display()))?;
59 serde_json::from_str::<StateData>(&raw)
60 .with_context(|| format!("parse state: {}", path.display()))?
61 } else {
62 StateData::default()
63 };
64 Ok(Self::new(backend, data))
65 }
66
67 pub async fn load_postgres(
69 url: impl Into<String>,
70 key: impl Into<String>,
71 tls_mode: PostgresTlsMode,
72 ) -> Result<Self> {
73 let mut postgres_backend = PostgresBackend {
74 url: url.into(),
75 key: key.into(),
76 tls_mode,
77 loaded_version: None,
78 table_ensured: false,
79 };
80 let data = postgres_backend.load().await?;
81 let backend: Arc<Mutex<dyn StateBackend>> = Arc::new(Mutex::new(postgres_backend));
82 Ok(Self::new(Some(backend), data))
83 }
84
85 pub async fn load_async(&mut self) -> Result<()> {
87 if let Some(backend) = &self.backend {
88 self.data = backend.lock().await.load().await?;
89 }
90 Ok(())
91 }
92
93 pub async fn save_async(&self) -> Result<()> {
95 if let Some(backend) = &self.backend {
96 backend.lock().await.save(&self.data).await?;
97 }
98 Ok(())
99 }
100
101 pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
103 self.data
104 .mappings
105 .get(&type_name)
106 .and_then(|map| map.get(&uid).cloned())
107 }
108
109 pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
111 self.data
112 .mappings
113 .entry(type_name)
114 .or_default()
115 .insert(uid, backend_id);
116 }
117
118 pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
120 if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
121 type_map.remove(&uid);
122 }
123 }
124
125 pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
127 &self.data.mappings
128 }
129}
130
/// A [`StateBackend`] that keeps the state as pretty-printed JSON in one file.
#[derive(Debug)]
struct LocalBackend {
    /// Location of the JSON state file.
    path: PathBuf,
}
135
136#[async_trait::async_trait]
137impl StateBackend for LocalBackend {
138 async fn load(&mut self) -> Result<StateData> {
139 if self.path.exists() {
140 let raw = fs::read_to_string(&self.path)
141 .with_context(|| format!("read state: {}", self.path.display()))?;
142 let data = serde_json::from_str::<StateData>(&raw)
143 .with_context(|| format!("parse state: {}", self.path.display()))?;
144 Ok(data)
145 } else {
146 Ok(StateData::default())
147 }
148 }
149
150 async fn save(&mut self, data: &StateData) -> Result<()> {
151 if let Some(parent) = self.path.parent() {
152 fs::create_dir_all(parent)
153 .with_context(|| format!("create state dir: {}", parent.display()))?;
154 }
155 let raw = serde_json::to_string_pretty(data)?;
156 let tmp = self.path.with_extension("json.tmp");
157 fs::write(&tmp, &raw).with_context(|| format!("write state tmp: {}", tmp.display()))?;
158 fs::rename(&tmp, &self.path)
159 .with_context(|| format!("write state: {}", self.path.display()))?;
160 Ok(())
161 }
162}
163
164#[derive(Debug)]
165struct PostgresBackend {
166 url: String,
167 key: String,
168 tls_mode: PostgresTlsMode,
169 loaded_version: Option<i32>,
170 table_ensured: bool,
171}
172
173#[async_trait::async_trait]
174impl StateBackend for PostgresBackend {
175 async fn load(&mut self) -> Result<StateData> {
176 let client = self.connect().await?;
177
178 let row = client
179 .query_opt(
180 "SELECT payload::text, loaded_version FROM alembic_state WHERE state_key = $1",
181 &[&self.key],
182 )
183 .await
184 .with_context(|| "load postgres state payload")?;
185
186 let Some(row) = row else {
187 self.loaded_version = Some(0);
188 return Ok(StateData::default());
189 };
190
191 self.loaded_version = Some(row.try_get::<_, i32>("loaded_version")?);
192
193 let raw: String = row
194 .try_get(0)
195 .with_context(|| "decode postgres state payload")?;
196 serde_json::from_str::<StateData>(&raw).with_context(|| "parse postgres state payload")
197 }
198
199 async fn save(&mut self, data: &StateData) -> Result<()> {
200 let client = self.connect().await?;
201 let Some(loaded_version) = self.loaded_version else {
202 return Err(anyhow!("must load state before saving"));
203 };
204
205 let payload = serde_json::to_string(data)?;
206 let rows_modified = client
207 .execute(
208 "INSERT INTO alembic_state (state_key, payload, updated_at)
209 VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())
210 ON CONFLICT (state_key)
211 DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW(), loaded_version = $3+1
212 WHERE alembic_state.loaded_version = $3",
213 &[&self.key, &payload, &loaded_version],
214 )
215 .await
216 .with_context(|| "save postgres state payload")?;
217
218 if rows_modified == 0 {
219 return Err(anyhow!(
220 "failed to save postgres state for key '{}': optimistic lock failed (expected loaded_version={})",
221 self.key,
222 loaded_version
223 ));
224 }
225
226 self.loaded_version = Some(loaded_version + 1);
227
228 Ok(())
229 }
230}
231
232impl PostgresBackend {
233 async fn ensure_table(&mut self, client: &Client) -> Result<()> {
234 if self.table_ensured {
235 return Ok(());
236 }
237
238 client
239 .execute(
240 "CREATE TABLE IF NOT EXISTS alembic_state (\
241 state_key TEXT PRIMARY KEY, \
242 payload JSONB NOT NULL, \
243 updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), \
244 loaded_version INTEGER NOT NULL DEFAULT 1\
245 )",
246 &[],
247 )
248 .await
249 .with_context(|| "ensure postgres alembic_state table exists")?;
250
251 client
253 .execute(
254 "ALTER TABLE alembic_state ADD COLUMN IF NOT EXISTS loaded_version INTEGER NOT NULL DEFAULT 1",
255 &[],
256 )
257 .await
258 .with_context(|| "ensure postgres alembic_state.loaded_version column exists")?;
259
260 self.table_ensured = true;
261 Ok(())
262 }
263
264 async fn connect(&mut self) -> Result<tokio_postgres::Client> {
265 match self.tls_mode {
266 PostgresTlsMode::Disable => {
267 let (client, connection) =
268 tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
269 .await
270 .with_context(|| "connect postgres state backend")?;
271 tokio::spawn(async move {
272 if let Err(err) = connection.await {
273 tracing::warn!("postgres state backend connection error: {err}");
274 }
275 });
276 self.ensure_table(&client).await?;
277 Ok(client)
278 }
279 PostgresTlsMode::Require => {
280 let connector = native_tls::TlsConnector::builder()
281 .build()
282 .with_context(|| "build postgres TLS connector")?;
283 let connector = postgres_native_tls::MakeTlsConnector::new(connector);
284 let (client, connection) = tokio_postgres::connect(&self.url, connector)
285 .await
286 .with_context(|| "connect postgres state backend")?;
287 tokio::spawn(async move {
288 if let Err(err) = connection.await {
289 tracing::warn!("postgres state backend connection error: {err}");
290 }
291 });
292 self.ensure_table(&client).await?;
293 Ok(client)
294 }
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn t(s: &str) -> TypeName {
        TypeName::new(s)
    }

    fn uid(n: u128) -> Uid {
        Uid::from_u128(n)
    }

    /// State holding the single mapping `site / uid(1) -> Int(42)`.
    fn site_42() -> StateData {
        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(1), BackendId::Int(42));
        data
    }

    #[test]
    fn state_data_default_is_empty() {
        assert!(StateData::default().mappings.is_empty());
    }

    #[test]
    fn backend_id_returns_none_for_missing_type() {
        let store = StateStore::new(None, StateData::default());
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn backend_id_returns_none_for_missing_uid() {
        let store = StateStore::new(None, site_42());
        assert_eq!(store.backend_id(t("site"), uid(2)), None);
    }

    #[test]
    fn backend_id_returns_value_for_existing_mapping() {
        let store = StateStore::new(None, site_42());
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(42))
        );
    }

    #[test]
    fn set_backend_id_creates_mapping() {
        let mut store = StateStore::new(None, StateData::default());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(42));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(42))
        );
    }

    #[test]
    fn set_backend_id_overwrites_existing() {
        let mut store = StateStore::new(None, site_42());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(99));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(99))
        );
    }

    #[test]
    fn remove_backend_id_removes_mapping() {
        let mut store = StateStore::new(None, site_42());
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn remove_backend_id_noop_for_missing() {
        let mut store = StateStore::new(None, StateData::default());
        store.remove_backend_id(t("site"), uid(1));
        // Removing from an unknown type must not create an entry for it.
        assert!(store.all_mappings().is_empty());
    }

    #[test]
    fn remove_backend_id_keeps_other_uids() {
        let mut store = StateStore::new(None, site_42());
        store.set_backend_id(t("site"), uid(2), BackendId::Int(43));
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
        assert_eq!(
            store.backend_id(t("site"), uid(2)),
            Some(BackendId::Int(43))
        );
    }

    #[test]
    fn all_mappings_returns_internal_reference() {
        let mut store = StateStore::new(None, StateData::default());
        assert!(store.all_mappings().is_empty());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(42));
        assert_eq!(
            store.all_mappings()[&t("site")][&uid(1)],
            BackendId::Int(42)
        );
    }

    #[tokio::test]
    async fn local_backend_load_missing_returns_empty() {
        let dir = TempDir::new().unwrap();
        let mut backend = LocalBackend {
            path: dir.path().join("nope.json"),
        };
        let data = backend.load().await.unwrap();
        assert!(data.mappings.is_empty());
    }

    #[tokio::test]
    async fn local_backend_load_rejects_invalid_json() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, "not json").unwrap();
        let mut backend = LocalBackend { path };
        assert!(backend.load().await.is_err());
    }

    #[tokio::test]
    async fn local_backend_save_load_round_trip() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("sub").join("state.json");
        let mut backend = LocalBackend { path: path.clone() };

        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(10), BackendId::String("site-001".into()));

        backend.save(&data).await.unwrap();
        assert!(path.exists());

        let loaded = backend.load().await.unwrap();
        assert_eq!(
            loaded.mappings[&t("site")][&uid(10)],
            BackendId::String("site-001".into())
        );
    }

    #[test]
    fn store_load_missing_file_is_empty() {
        let dir = TempDir::new().unwrap();
        let store = StateStore::load(dir.path().join("missing.json")).unwrap();
        assert!(store.all_mappings().is_empty());
    }

    #[test]
    fn store_load_rejects_invalid_json() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, "not json").unwrap();
        assert!(StateStore::load(&path).is_err());
    }

    #[tokio::test]
    async fn store_save_and_reload_round_trip() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");

        let mut store = StateStore::load(&path).unwrap();
        store.set_backend_id(t("site"), uid(7), BackendId::Int(7));
        store.save_async().await.unwrap();

        let reloaded = StateStore::load(&path).unwrap();
        assert_eq!(
            reloaded.backend_id(t("site"), uid(7)),
            Some(BackendId::Int(7))
        );
    }

    #[tokio::test]
    async fn store_save_without_backend_is_noop() {
        let store = StateStore::new(None, StateData::default());
        store.save_async().await.unwrap();
    }

    #[tokio::test]
    async fn store_load_async_without_backend_is_noop() {
        let mut store = StateStore::new(None, StateData::default());
        store.set_backend_id(t("x"), uid(1), BackendId::Int(1));
        store.load_async().await.unwrap();
        assert_eq!(store.backend_id(t("x"), uid(1)), Some(BackendId::Int(1)));
    }
}