Skip to main content

alembic_engine/
state.rs

1//! local uid -> backend id state store.
2
3use crate::types::BackendId;
4use alembic_core::{TypeName, Uid};
5use anyhow::{anyhow, Context, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12use tokio_postgres::Client;
13
14/// on-disk state schema.
15#[derive(Debug, Default, Clone, Serialize, Deserialize)]
16pub struct StateData {
17    #[serde(default)]
18    pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
19}
20
/// how the postgres state backend secures its connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// connect without TLS.
    Disable,
    /// connect through a TLS connector.
    Require,
}
27
28/// trait for pluggable state backends.
29#[async_trait::async_trait]
30pub trait StateBackend: Send + Sync + std::fmt::Debug {
31    async fn load(&mut self) -> Result<StateData>;
32    async fn save(&mut self, data: &StateData) -> Result<()>;
33}
34
35/// state store wrapper with load/save helpers.
36#[derive(Debug, Clone)]
37pub struct StateStore {
38    backend: Option<Arc<Mutex<dyn StateBackend>>>,
39    data: StateData,
40}
41
42impl StateStore {
43    /// create a new state store with an optional backend.
44    pub fn new(backend: Option<Arc<Mutex<dyn StateBackend>>>, data: StateData) -> Self {
45        Self { backend, data }
46    }
47
48    /// load state from a file path.
49    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
50        let path = path.as_ref().to_path_buf();
51        // always create a backend so we can save to the same path later.
52        // the backend's load() method handles missing files gracefully.
53        let backend: Option<Arc<Mutex<dyn StateBackend>>> =
54            Some(Arc::new(Mutex::new(LocalBackend { path: path.clone() }))
55                as Arc<Mutex<dyn StateBackend>>);
56        let data = if path.exists() {
57            let raw = fs::read_to_string(&path)
58                .with_context(|| format!("read state: {}", path.display()))?;
59            serde_json::from_str::<StateData>(&raw)
60                .with_context(|| format!("parse state: {}", path.display()))?
61        } else {
62            StateData::default()
63        };
64        Ok(Self::new(backend, data))
65    }
66
67    /// load state from a postgres backend.
68    pub async fn load_postgres(
69        url: impl Into<String>,
70        key: impl Into<String>,
71        tls_mode: PostgresTlsMode,
72    ) -> Result<Self> {
73        let mut postgres_backend = PostgresBackend {
74            url: url.into(),
75            key: key.into(),
76            tls_mode,
77            loaded_version: None,
78            table_ensured: false,
79        };
80        let data = postgres_backend.load().await?;
81        let backend: Arc<Mutex<dyn StateBackend>> = Arc::new(Mutex::new(postgres_backend));
82        Ok(Self::new(Some(backend), data))
83    }
84
85    /// load state from the configured backend.
86    pub async fn load_async(&mut self) -> Result<()> {
87        if let Some(backend) = &self.backend {
88            self.data = backend.lock().await.load().await?;
89        }
90        Ok(())
91    }
92
93    /// persist state to the configured backend.
94    pub async fn save_async(&self) -> Result<()> {
95        if let Some(backend) = &self.backend {
96            backend.lock().await.save(&self.data).await?;
97        }
98        Ok(())
99    }
100
101    /// lookup a backend id by type + uid.
102    pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
103        self.data
104            .mappings
105            .get(&type_name)
106            .and_then(|map| map.get(&uid).cloned())
107    }
108
109    /// set a backend id mapping.
110    pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
111        self.data
112            .mappings
113            .entry(type_name)
114            .or_default()
115            .insert(uid, backend_id);
116    }
117
118    /// remove a backend id mapping.
119    pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
120        if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
121            type_map.remove(&uid);
122        }
123    }
124
125    /// return all mappings for external use.
126    pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
127        &self.data.mappings
128    }
129}
130
/// state backend that keeps the state as a JSON file on the local filesystem.
#[derive(Debug)]
struct LocalBackend {
    /// location of the state file.
    path: PathBuf,
}
135
136#[async_trait::async_trait]
137impl StateBackend for LocalBackend {
138    async fn load(&mut self) -> Result<StateData> {
139        if self.path.exists() {
140            let raw = fs::read_to_string(&self.path)
141                .with_context(|| format!("read state: {}", self.path.display()))?;
142            let data = serde_json::from_str::<StateData>(&raw)
143                .with_context(|| format!("parse state: {}", self.path.display()))?;
144            Ok(data)
145        } else {
146            Ok(StateData::default())
147        }
148    }
149
150    async fn save(&mut self, data: &StateData) -> Result<()> {
151        if let Some(parent) = self.path.parent() {
152            fs::create_dir_all(parent)
153                .with_context(|| format!("create state dir: {}", parent.display()))?;
154        }
155        let raw = serde_json::to_string_pretty(data)?;
156        let tmp = self.path.with_extension("json.tmp");
157        fs::write(&tmp, &raw).with_context(|| format!("write state tmp: {}", tmp.display()))?;
158        fs::rename(&tmp, &self.path)
159            .with_context(|| format!("write state: {}", self.path.display()))?;
160        Ok(())
161    }
162}
163
164#[derive(Debug)]
165struct PostgresBackend {
166    url: String,
167    key: String,
168    tls_mode: PostgresTlsMode,
169    loaded_version: Option<i32>,
170    table_ensured: bool,
171}
172
173#[async_trait::async_trait]
174impl StateBackend for PostgresBackend {
175    async fn load(&mut self) -> Result<StateData> {
176        let client = self.connect().await?;
177
178        let row = client
179            .query_opt(
180                "SELECT payload::text, loaded_version FROM alembic_state WHERE state_key = $1",
181                &[&self.key],
182            )
183            .await
184            .with_context(|| "load postgres state payload")?;
185
186        let Some(row) = row else {
187            self.loaded_version = Some(0);
188            return Ok(StateData::default());
189        };
190
191        self.loaded_version = Some(row.try_get::<_, i32>("loaded_version")?);
192
193        let raw: String = row
194            .try_get(0)
195            .with_context(|| "decode postgres state payload")?;
196        serde_json::from_str::<StateData>(&raw).with_context(|| "parse postgres state payload")
197    }
198
199    async fn save(&mut self, data: &StateData) -> Result<()> {
200        let client = self.connect().await?;
201        let Some(loaded_version) = self.loaded_version else {
202            return Err(anyhow!("must load state before saving"));
203        };
204
205        let payload = serde_json::to_string(data)?;
206        let rows_modified = client
207            .execute(
208                "INSERT INTO alembic_state (state_key, payload, updated_at)
209                 VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())
210                 ON CONFLICT (state_key)
211                 DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW(), loaded_version = $3+1
212                 WHERE alembic_state.loaded_version = $3",
213                &[&self.key, &payload, &loaded_version],
214            )
215            .await
216            .with_context(|| "save postgres state payload")?;
217
218        if rows_modified == 0 {
219            return Err(anyhow!(
220                "failed to save postgres state for key '{}': optimistic lock failed (expected loaded_version={})",
221                self.key,
222                loaded_version
223            ));
224        }
225
226        self.loaded_version = Some(loaded_version + 1);
227
228        Ok(())
229    }
230}
231
232impl PostgresBackend {
233    async fn ensure_table(&mut self, client: &Client) -> Result<()> {
234        if self.table_ensured {
235            return Ok(());
236        }
237
238        client
239            .execute(
240                "CREATE TABLE IF NOT EXISTS alembic_state (\
241                state_key TEXT PRIMARY KEY, \
242                payload JSONB NOT NULL, \
243                updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), \
244                loaded_version INTEGER NOT NULL DEFAULT 1\
245                )",
246                &[],
247            )
248            .await
249            .with_context(|| "ensure postgres alembic_state table exists")?;
250
251        // if the table already existed, CREATE TABLE IF NOT EXISTS won't add new columns.
252        client
253            .execute(
254                "ALTER TABLE alembic_state ADD COLUMN IF NOT EXISTS loaded_version INTEGER NOT NULL DEFAULT 1",
255                &[],
256            )
257            .await
258            .with_context(|| "ensure postgres alembic_state.loaded_version column exists")?;
259
260        self.table_ensured = true;
261        Ok(())
262    }
263
264    async fn connect(&mut self) -> Result<tokio_postgres::Client> {
265        match self.tls_mode {
266            PostgresTlsMode::Disable => {
267                let (client, connection) =
268                    tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
269                        .await
270                        .with_context(|| "connect postgres state backend")?;
271                tokio::spawn(async move {
272                    if let Err(err) = connection.await {
273                        tracing::warn!("postgres state backend connection error: {err}");
274                    }
275                });
276                self.ensure_table(&client).await?;
277                Ok(client)
278            }
279            PostgresTlsMode::Require => {
280                let connector = native_tls::TlsConnector::builder()
281                    .build()
282                    .with_context(|| "build postgres TLS connector")?;
283                let connector = postgres_native_tls::MakeTlsConnector::new(connector);
284                let (client, connection) = tokio_postgres::connect(&self.url, connector)
285                    .await
286                    .with_context(|| "connect postgres state backend")?;
287                tokio::spawn(async move {
288                    if let Err(err) = connection.await {
289                        tracing::warn!("postgres state backend connection error: {err}");
290                    }
291                });
292                self.ensure_table(&client).await?;
293                Ok(client)
294            }
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// shorthand for building a type name.
    fn t(s: &str) -> TypeName {
        TypeName::new(s)
    }

    /// shorthand for building a uid from an integer.
    fn uid(n: u128) -> Uid {
        Uid::from_u128(n)
    }

    /// state containing exactly one mapping: `site / uid(1) -> Int(42)`.
    fn site_fixture() -> StateData {
        let mut data = StateData::default();
        data.mappings
            .insert(t("site"), BTreeMap::from([(uid(1), BackendId::Int(42))]));
        data
    }

    /// backend-less store with no mappings.
    fn empty_store() -> StateStore {
        StateStore::new(None, StateData::default())
    }

    #[test]
    fn state_data_default_is_empty() {
        assert!(StateData::default().mappings.is_empty());
    }

    #[test]
    fn backend_id_returns_none_for_missing_type() {
        assert_eq!(empty_store().backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn backend_id_returns_none_for_missing_uid() {
        let store = StateStore::new(None, site_fixture());
        assert_eq!(store.backend_id(t("site"), uid(2)), None);
    }

    #[test]
    fn backend_id_returns_value_for_existing_mapping() {
        let store = StateStore::new(None, site_fixture());
        let found = store.backend_id(t("site"), uid(1));
        assert_eq!(found, Some(BackendId::Int(42)));
    }

    #[test]
    fn set_backend_id_creates_mapping() {
        let mut store = empty_store();
        store.set_backend_id(t("site"), uid(1), BackendId::Int(42));
        let found = store.backend_id(t("site"), uid(1));
        assert_eq!(found, Some(BackendId::Int(42)));
    }

    #[test]
    fn set_backend_id_overwrites_existing() {
        let mut store = StateStore::new(None, site_fixture());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(99));
        let found = store.backend_id(t("site"), uid(1));
        assert_eq!(found, Some(BackendId::Int(99)));
    }

    #[test]
    fn remove_backend_id_removes_mapping() {
        let mut store = StateStore::new(None, site_fixture());
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn remove_backend_id_noop_for_missing() {
        // passing means the call returned without panicking.
        let mut store = empty_store();
        store.remove_backend_id(t("site"), uid(1));
    }

    #[test]
    fn all_mappings_returns_internal_reference() {
        assert!(empty_store().all_mappings().is_empty());
    }

    #[tokio::test]
    async fn local_backend_load_missing_returns_empty() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("nope.json");
        let mut backend = LocalBackend { path };
        let data = backend.load().await.unwrap();
        assert!(data.mappings.is_empty());
    }

    #[tokio::test]
    async fn local_backend_save_load_round_trip() {
        let dir = TempDir::new().unwrap();
        // nested path: save must create the `sub` directory itself.
        let path = dir.path().join("sub").join("state.json");
        let mut backend = LocalBackend { path: path.clone() };

        let mut data = StateData::default();
        data.mappings.insert(
            t("site"),
            BTreeMap::from([(uid(10), BackendId::String("site-001".into()))]),
        );

        backend.save(&data).await.unwrap();
        assert!(path.exists());

        let loaded = backend.load().await.unwrap();
        let expected = BackendId::String("site-001".into());
        assert_eq!(loaded.mappings[&t("site")][&uid(10)], expected);
    }

    #[tokio::test]
    async fn store_save_without_backend_is_noop() {
        empty_store().save_async().await.unwrap();
    }

    #[tokio::test]
    async fn store_load_async_without_backend_is_noop() {
        let mut store = empty_store();
        store.set_backend_id(t("x"), uid(1), BackendId::Int(1));
        store.load_async().await.unwrap();
        // the in-memory mapping must survive a backend-less reload.
        assert_eq!(store.backend_id(t("x"), uid(1)), Some(BackendId::Int(1)));
    }
}