Skip to main content

alembic_engine/
state.rs

//! local uid -> backend id state store, persisted through a pluggable backend (local JSON file or postgres).
2
3use crate::types::BackendId;
4use alembic_core::{TypeName, Uid};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
/// on-disk state schema.
///
/// serialized as JSON by both backends in this file (pretty-printed to a
/// local file, or stored as a jsonb payload in postgres). `BTreeMap` keeps the
/// serialized key order sorted, so repeated saves of the same data are stable.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct StateData {
    /// type name -> (local uid -> backend id).
    // `serde(default)` lets a payload without a `mappings` key load as empty.
    #[serde(default)]
    pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
}
18
/// selects how the postgres state backend secures its connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// plain, unencrypted connection.
    Disable,
    /// connect through the platform's native TLS stack; failures to build the
    /// connector or to connect are returned as errors.
    Require,
}
25
/// trait for pluggable state backends.
///
/// a backend loads and saves the whole [`StateData`] snapshot at once; there
/// is no per-mapping read or write. this file provides a local JSON file
/// backend and a postgres backend.
#[async_trait::async_trait]
pub trait StateBackend: Send + Sync + std::fmt::Debug {
    /// load the full state snapshot.
    ///
    /// the backends in this file return `StateData::default()` when nothing
    /// has been saved yet.
    async fn load(&self) -> Result<StateData>;
    /// persist the full state snapshot, replacing what was previously stored.
    async fn save(&self, data: &StateData) -> Result<()>;
}
32
/// state store wrapper with load/save helpers.
///
/// holds the in-memory [`StateData`] plus an optional backend. the mapping
/// mutators only touch memory; `save_async` is what persists them. with no
/// backend, `load_async` / `save_async` do nothing.
#[derive(Debug, Clone)]
pub struct StateStore {
    // `Arc` so cloning the store shares the same backend handle.
    backend: Option<Arc<dyn StateBackend>>,
    data: StateData,
}
39
40impl StateStore {
41    /// create a new state store with an optional backend.
42    pub fn new(backend: Option<Arc<dyn StateBackend>>, data: StateData) -> Self {
43        Self { backend, data }
44    }
45
46    /// load state from a file path.
47    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
48        let path = path.as_ref().to_path_buf();
49        // Always create a backend so we can save to the same path later.
50        // The backend's load() method handles missing files gracefully.
51        let backend: Option<Arc<dyn StateBackend>> =
52            Some(Arc::new(LocalBackend { path: path.clone() }) as Arc<dyn StateBackend>);
53        let data = if path.exists() {
54            let raw = fs::read_to_string(&path)
55                .with_context(|| format!("read state: {}", path.display()))?;
56            serde_json::from_str::<StateData>(&raw)
57                .with_context(|| format!("parse state: {}", path.display()))?
58        } else {
59            StateData::default()
60        };
61        Ok(Self::new(backend, data))
62    }
63
64    /// load state from a postgres backend.
65    pub async fn load_postgres(
66        url: impl Into<String>,
67        key: impl Into<String>,
68        tls_mode: PostgresTlsMode,
69    ) -> Result<Self> {
70        let backend: Arc<dyn StateBackend> = Arc::new(PostgresBackend {
71            url: url.into(),
72            key: key.into(),
73            tls_mode,
74        });
75        let data = backend.load().await?;
76        Ok(Self::new(Some(backend), data))
77    }
78
79    /// load state from the configured backend.
80    pub async fn load_async(&mut self) -> Result<()> {
81        if let Some(backend) = &self.backend {
82            self.data = backend.load().await?;
83        }
84        Ok(())
85    }
86
87    /// persist state to the configured backend.
88    pub async fn save_async(&self) -> Result<()> {
89        if let Some(backend) = &self.backend {
90            backend.save(&self.data).await?;
91        }
92        Ok(())
93    }
94
95    /// lookup a backend id by type + uid.
96    pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
97        self.data
98            .mappings
99            .get(&type_name)
100            .and_then(|map| map.get(&uid).cloned())
101    }
102
103    /// set a backend id mapping.
104    pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
105        self.data
106            .mappings
107            .entry(type_name)
108            .or_default()
109            .insert(uid, backend_id);
110    }
111
112    /// remove a backend id mapping.
113    pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
114        if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
115            type_map.remove(&uid);
116        }
117    }
118
119    /// return all mappings for external use.
120    pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
121        &self.data.mappings
122    }
123}
124
/// state backend that stores [`StateData`] as pretty-printed JSON in a local file.
#[derive(Debug)]
struct LocalBackend {
    /// location of the state file; its parent directory is created on save if missing.
    path: PathBuf,
}
129
130#[async_trait::async_trait]
131impl StateBackend for LocalBackend {
132    async fn load(&self) -> Result<StateData> {
133        if self.path.exists() {
134            let raw = fs::read_to_string(&self.path)
135                .with_context(|| format!("read state: {}", self.path.display()))?;
136            let data = serde_json::from_str::<StateData>(&raw)
137                .with_context(|| format!("parse state: {}", self.path.display()))?;
138            Ok(data)
139        } else {
140            Ok(StateData::default())
141        }
142    }
143
144    async fn save(&self, data: &StateData) -> Result<()> {
145        if let Some(parent) = self.path.parent() {
146            fs::create_dir_all(parent)
147                .with_context(|| format!("create state dir: {}", parent.display()))?;
148        }
149        let raw = serde_json::to_string_pretty(data)?;
150        let tmp = self.path.with_extension("json.tmp");
151        fs::write(&tmp, &raw).with_context(|| format!("write state tmp: {}", tmp.display()))?;
152        fs::rename(&tmp, &self.path)
153            .with_context(|| format!("write state: {}", self.path.display()))?;
154        Ok(())
155    }
156}
157
158#[derive(Debug)]
159struct PostgresBackend {
160    url: String,
161    key: String,
162    tls_mode: PostgresTlsMode,
163}
164
165#[async_trait::async_trait]
166impl StateBackend for PostgresBackend {
167    async fn load(&self) -> Result<StateData> {
168        let client = self.connect().await?;
169
170        let row = client
171            .query_opt(
172                "SELECT payload::text FROM alembic_state WHERE state_key = $1",
173                &[&self.key],
174            )
175            .await
176            .with_context(|| "load postgres state payload")?;
177
178        let Some(row) = row else {
179            return Ok(StateData::default());
180        };
181
182        let raw: String = row
183            .try_get(0)
184            .with_context(|| "decode postgres state payload")?;
185        serde_json::from_str::<StateData>(&raw).with_context(|| "parse postgres state payload")
186    }
187
188    async fn save(&self, data: &StateData) -> Result<()> {
189        let client = self.connect().await?;
190
191        let payload = serde_json::to_string(data)?;
192        client
193            .execute(
194                "INSERT INTO alembic_state (state_key, payload, updated_at)
195                 VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())
196                 ON CONFLICT (state_key)
197                 DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW()",
198                &[&self.key, &payload],
199            )
200            .await
201            .with_context(|| "save postgres state payload")?;
202        Ok(())
203    }
204}
205
206impl PostgresBackend {
207    async fn connect(&self) -> Result<tokio_postgres::Client> {
208        match self.tls_mode {
209            PostgresTlsMode::Disable => {
210                let (client, connection) =
211                    tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
212                        .await
213                        .with_context(|| "connect postgres state backend")?;
214                tokio::spawn(async move {
215                    if let Err(err) = connection.await {
216                        tracing::warn!("postgres state backend connection error: {err}");
217                    }
218                });
219                Ok(client)
220            }
221            PostgresTlsMode::Require => {
222                let connector = native_tls::TlsConnector::builder()
223                    .build()
224                    .with_context(|| "build postgres TLS connector")?;
225                let connector = postgres_native_tls::MakeTlsConnector::new(connector);
226                let (client, connection) = tokio_postgres::connect(&self.url, connector)
227                    .await
228                    .with_context(|| "connect postgres state backend")?;
229                tokio::spawn(async move {
230                    if let Err(err) = connection.await {
231                        tracing::warn!("postgres state backend connection error: {err}");
232                    }
233                });
234                Ok(client)
235            }
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    // unit tests for the in-memory mapping API and the local file backend;
    // nothing in this module exercises `PostgresBackend`.
    use super::*;
    use tempfile::TempDir;

    /// shorthand for building a `TypeName` from a string.
    fn t(s: &str) -> TypeName {
        TypeName::new(s)
    }

    /// shorthand for a deterministic `Uid` built from an integer.
    fn uid(n: u128) -> Uid {
        Uid::from_u128(n)
    }

    #[test]
    fn state_data_default_is_empty() {
        let data = StateData::default();
        assert!(data.mappings.is_empty());
    }

    #[test]
    fn backend_id_returns_none_for_missing_type() {
        let store = StateStore::new(None, StateData::default());
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn backend_id_returns_none_for_missing_uid() {
        // the type exists but only holds uid(1); uid(2) must still miss.
        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(1), BackendId::Int(42));
        let store = StateStore::new(None, data);
        assert_eq!(store.backend_id(t("site"), uid(2)), None);
    }

    #[test]
    fn backend_id_returns_value_for_existing_mapping() {
        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(1), BackendId::Int(42));
        let store = StateStore::new(None, data);
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(42))
        );
    }

    #[test]
    fn set_backend_id_creates_mapping() {
        // starts with no per-type map at all, so `set_backend_id` must create it.
        let mut store = StateStore::new(None, StateData::default());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(42));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(42))
        );
    }

    #[test]
    fn set_backend_id_overwrites_existing() {
        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(1), BackendId::Int(42));
        let mut store = StateStore::new(None, data);
        store.set_backend_id(t("site"), uid(1), BackendId::Int(99));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(99))
        );
    }

    #[test]
    fn remove_backend_id_removes_mapping() {
        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(1), BackendId::Int(42));
        let mut store = StateStore::new(None, data);
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn remove_backend_id_noop_for_missing() {
        let mut store = StateStore::new(None, StateData::default());
        store.remove_backend_id(t("site"), uid(1));
        // Should not panic
    }

    #[test]
    fn all_mappings_returns_internal_reference() {
        let store = StateStore::new(None, StateData::default());
        assert!(store.all_mappings().is_empty());
    }

    #[tokio::test]
    async fn local_backend_load_missing_returns_empty() {
        // a path that was never written must load as empty state, not an error.
        let dir = TempDir::new().unwrap();
        let backend = LocalBackend {
            path: dir.path().join("nope.json"),
        };
        let data = backend.load().await.unwrap();
        assert!(data.mappings.is_empty());
    }

    #[tokio::test]
    async fn local_backend_save_load_round_trip() {
        // "sub" does not exist yet, so this also covers parent-dir creation in `save`.
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("sub").join("state.json");
        let backend = LocalBackend { path: path.clone() };

        let mut data = StateData::default();
        data.mappings
            .entry(t("site"))
            .or_default()
            .insert(uid(10), BackendId::String("site-001".into()));

        backend.save(&data).await.unwrap();
        assert!(path.exists());

        let loaded = backend.load().await.unwrap();
        assert_eq!(
            loaded.mappings[&t("site")][&uid(10)],
            BackendId::String("site-001".into())
        );
    }

    #[tokio::test]
    async fn store_save_without_backend_is_noop() {
        let store = StateStore::new(None, StateData::default());
        store.save_async().await.unwrap();
    }

    #[tokio::test]
    async fn store_load_async_without_backend_is_noop() {
        // with no backend, load_async must keep the in-memory mapping intact.
        let mut store = StateStore::new(None, StateData::default());
        store.set_backend_id(t("x"), uid(1), BackendId::Int(1));
        store.load_async().await.unwrap();
        assert_eq!(store.backend_id(t("x"), uid(1)), Some(BackendId::Int(1)));
    }
}
385}