Skip to main content

alembic_engine/
state.rs

1//! local uid -> backend id state store.
2
3use crate::types::BackendId;
4use alembic_core::{TypeName, Uid};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10
11use std::sync::Arc;
12
13/// on-disk state schema.
14#[derive(Debug, Default, Clone, Serialize, Deserialize)]
15pub struct StateData {
16    #[serde(default)]
17    pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
18}
19
/// Whether connections to the postgres state backend are wrapped in TLS.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// Plain, unencrypted connection (`tokio_postgres::NoTls`).
    Disable,
    /// TLS via `native-tls`, using a connector built with default settings.
    Require,
}
26
27/// trait for pluggable state backends.
28#[async_trait::async_trait]
29pub trait StateBackend: Send + Sync + std::fmt::Debug {
30    async fn load(&self) -> Result<StateData>;
31    async fn save(&self, data: &StateData) -> Result<()>;
32}
33
34/// state store wrapper with load/save helpers.
35#[derive(Debug, Clone)]
36pub struct StateStore {
37    backend: Option<Arc<dyn StateBackend>>,
38    data: StateData,
39}
40
41impl StateStore {
42    /// create a new state store with an optional backend.
43    pub fn new(backend: Option<Arc<dyn StateBackend>>, data: StateData) -> Self {
44        Self { backend, data }
45    }
46
47    /// load state from a file path.
48    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
49        let path = path.as_ref().to_path_buf();
50        let data = if path.exists() {
51            let raw = fs::read_to_string(&path)
52                .with_context(|| format!("read state: {}", path.display()))?;
53            serde_json::from_str::<StateData>(&raw)
54                .with_context(|| format!("parse state: {}", path.display()))?
55        } else {
56            StateData::default()
57        };
58        Ok(Self::new(Some(Arc::new(LocalBackend { path })), data))
59    }
60
61    /// load state from a postgres backend.
62    pub async fn load_postgres(
63        url: impl Into<String>,
64        key: impl Into<String>,
65        tls_mode: PostgresTlsMode,
66    ) -> Result<Self> {
67        let backend: Arc<dyn StateBackend> = Arc::new(PostgresBackend {
68            url: url.into(),
69            key: key.into(),
70            tls_mode,
71        });
72        let data = backend.load().await?;
73        Ok(Self::new(Some(backend), data))
74    }
75
76    /// load state from the configured backend.
77    pub async fn load_async(&mut self) -> Result<()> {
78        if let Some(backend) = &self.backend {
79            self.data = backend.load().await?;
80        }
81        Ok(())
82    }
83
84    /// persist state to the configured backend.
85    pub async fn save_async(&self) -> Result<()> {
86        if let Some(backend) = &self.backend {
87            backend.save(&self.data).await?;
88        }
89        Ok(())
90    }
91
92    /// lookup a backend id by type + uid.
93    pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
94        self.data
95            .mappings
96            .get(&type_name)
97            .and_then(|map| map.get(&uid).cloned())
98    }
99
100    /// set a backend id mapping.
101    pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
102        self.data
103            .mappings
104            .entry(type_name)
105            .or_default()
106            .insert(uid, backend_id);
107    }
108
109    /// remove a backend id mapping.
110    pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
111        if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
112            type_map.remove(&uid);
113        }
114    }
115
116    /// return all mappings for external use.
117    pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
118        &self.data.mappings
119    }
120}
121
/// [`StateBackend`] that keeps the state as a JSON file on the local filesystem.
#[derive(Debug)]
struct LocalBackend {
    /// Location of the JSON state file.
    path: PathBuf,
}
126
127#[async_trait::async_trait]
128impl StateBackend for LocalBackend {
129    async fn load(&self) -> Result<StateData> {
130        if self.path.exists() {
131            let raw = fs::read_to_string(&self.path)
132                .with_context(|| format!("read state: {}", self.path.display()))?;
133            let data = serde_json::from_str::<StateData>(&raw)
134                .with_context(|| format!("parse state: {}", self.path.display()))?;
135            Ok(data)
136        } else {
137            Ok(StateData::default())
138        }
139    }
140
141    async fn save(&self, data: &StateData) -> Result<()> {
142        if let Some(parent) = self.path.parent() {
143            fs::create_dir_all(parent)
144                .with_context(|| format!("create state dir: {}", parent.display()))?;
145        }
146        let raw = serde_json::to_string_pretty(data)?;
147        let tmp = self.path.with_extension("json.tmp");
148        fs::write(&tmp, &raw).with_context(|| format!("write state tmp: {}", tmp.display()))?;
149        fs::rename(&tmp, &self.path)
150            .with_context(|| format!("write state: {}", self.path.display()))?;
151        Ok(())
152    }
153}
154
155#[derive(Debug)]
156struct PostgresBackend {
157    url: String,
158    key: String,
159    tls_mode: PostgresTlsMode,
160}
161
162impl PostgresBackend {
163    async fn connect(&self) -> Result<tokio_postgres::Client> {
164        match self.tls_mode {
165            PostgresTlsMode::Disable => {
166                let (client, connection) =
167                    tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
168                        .await
169                        .with_context(|| "connect postgres state backend")?;
170                tokio::spawn(async move {
171                    if let Err(err) = connection.await {
172                        tracing::warn!("postgres state backend connection error: {err}");
173                    }
174                });
175                Ok(client)
176            }
177            PostgresTlsMode::Require => {
178                let connector = native_tls::TlsConnector::builder()
179                    .build()
180                    .with_context(|| "build postgres TLS connector")?;
181                let connector = postgres_native_tls::MakeTlsConnector::new(connector);
182                let (client, connection) = tokio_postgres::connect(&self.url, connector)
183                    .await
184                    .with_context(|| "connect postgres state backend")?;
185                tokio::spawn(async move {
186                    if let Err(err) = connection.await {
187                        tracing::warn!("postgres state backend connection error: {err}");
188                    }
189                });
190                Ok(client)
191            }
192        }
193    }
194
195    // The postgres table is expected to be pre-provisioned.
196}
197
198#[async_trait::async_trait]
199impl StateBackend for PostgresBackend {
200    async fn load(&self) -> Result<StateData> {
201        let client = self.connect().await?;
202
203        let row = client
204            .query_opt(
205                "SELECT payload::text FROM alembic_state WHERE state_key = $1",
206                &[&self.key],
207            )
208            .await
209            .with_context(|| "load postgres state payload")?;
210
211        let Some(row) = row else {
212            return Ok(StateData::default());
213        };
214
215        let raw: String = row
216            .try_get(0)
217            .with_context(|| "decode postgres state payload")?;
218        serde_json::from_str::<StateData>(&raw).with_context(|| "parse postgres state payload")
219    }
220
221    async fn save(&self, data: &StateData) -> Result<()> {
222        let client = self.connect().await?;
223
224        let payload = serde_json::to_string(data)?;
225        client
226            .execute(
227                "INSERT INTO alembic_state (state_key, payload, updated_at)
228                 VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())
229                 ON CONFLICT (state_key)
230                 DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW()",
231                &[&self.key, &payload],
232            )
233            .await
234            .with_context(|| "save postgres state payload")?;
235        Ok(())
236    }
237}