1use crate::types::BackendId;
4use alembic_core::{TypeName, Uid};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12#[derive(Debug, Default, Clone, Serialize, Deserialize)]
14pub struct StateData {
15 #[serde(default)]
16 pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
17}
18
/// Transport-security choice for the Postgres state backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// Connect without TLS (`tokio_postgres::NoTls`).
    Disable,
    /// Connect through a `native-tls` connector.
    Require,
}
25
26#[async_trait::async_trait]
28pub trait StateBackend: Send + Sync + std::fmt::Debug {
29 async fn load(&self) -> Result<StateData>;
30 async fn save(&self, data: &StateData) -> Result<()>;
31}
32
33#[derive(Debug, Clone)]
35pub struct StateStore {
36 backend: Option<Arc<dyn StateBackend>>,
37 data: StateData,
38}
39
40impl StateStore {
41 pub fn new(backend: Option<Arc<dyn StateBackend>>, data: StateData) -> Self {
43 Self { backend, data }
44 }
45
46 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
48 let path = path.as_ref().to_path_buf();
49 let backend: Option<Arc<dyn StateBackend>> =
52 Some(Arc::new(LocalBackend { path: path.clone() }) as Arc<dyn StateBackend>);
53 let data = if path.exists() {
54 let raw = fs::read_to_string(&path)
55 .with_context(|| format!("read state: {}", path.display()))?;
56 serde_json::from_str::<StateData>(&raw)
57 .with_context(|| format!("parse state: {}", path.display()))?
58 } else {
59 StateData::default()
60 };
61 Ok(Self::new(backend, data))
62 }
63
64 pub async fn load_postgres(
66 url: impl Into<String>,
67 key: impl Into<String>,
68 tls_mode: PostgresTlsMode,
69 ) -> Result<Self> {
70 let backend: Arc<dyn StateBackend> = Arc::new(PostgresBackend {
71 url: url.into(),
72 key: key.into(),
73 tls_mode,
74 });
75 let data = backend.load().await?;
76 Ok(Self::new(Some(backend), data))
77 }
78
79 pub async fn load_async(&mut self) -> Result<()> {
81 if let Some(backend) = &self.backend {
82 self.data = backend.load().await?;
83 }
84 Ok(())
85 }
86
87 pub async fn save_async(&self) -> Result<()> {
89 if let Some(backend) = &self.backend {
90 backend.save(&self.data).await?;
91 }
92 Ok(())
93 }
94
95 pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
97 self.data
98 .mappings
99 .get(&type_name)
100 .and_then(|map| map.get(&uid).cloned())
101 }
102
103 pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
105 self.data
106 .mappings
107 .entry(type_name)
108 .or_default()
109 .insert(uid, backend_id);
110 }
111
112 pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
114 if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
115 type_map.remove(&uid);
116 }
117 }
118
119 pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
121 &self.data.mappings
122 }
123}
124
/// State backend that keeps [`StateData`] as pretty-printed JSON in a local file.
#[derive(Debug)]
struct LocalBackend {
    /// Location of the JSON state file; parent directories are created on save.
    path: PathBuf,
}
129
130#[async_trait::async_trait]
131impl StateBackend for LocalBackend {
132 async fn load(&self) -> Result<StateData> {
133 if self.path.exists() {
134 let raw = fs::read_to_string(&self.path)
135 .with_context(|| format!("read state: {}", self.path.display()))?;
136 let data = serde_json::from_str::<StateData>(&raw)
137 .with_context(|| format!("parse state: {}", self.path.display()))?;
138 Ok(data)
139 } else {
140 Ok(StateData::default())
141 }
142 }
143
144 async fn save(&self, data: &StateData) -> Result<()> {
145 if let Some(parent) = self.path.parent() {
146 fs::create_dir_all(parent)
147 .with_context(|| format!("create state dir: {}", parent.display()))?;
148 }
149 let raw = serde_json::to_string_pretty(data)?;
150 let tmp = self.path.with_extension("json.tmp");
151 fs::write(&tmp, &raw).with_context(|| format!("write state tmp: {}", tmp.display()))?;
152 fs::rename(&tmp, &self.path)
153 .with_context(|| format!("write state: {}", self.path.display()))?;
154 Ok(())
155 }
156}
157
158#[derive(Debug)]
159struct PostgresBackend {
160 url: String,
161 key: String,
162 tls_mode: PostgresTlsMode,
163}
164
165#[async_trait::async_trait]
166impl StateBackend for PostgresBackend {
167 async fn load(&self) -> Result<StateData> {
168 let client = self.connect().await?;
169
170 let row = client
171 .query_opt(
172 "SELECT payload::text FROM alembic_state WHERE state_key = $1",
173 &[&self.key],
174 )
175 .await
176 .with_context(|| "load postgres state payload")?;
177
178 let Some(row) = row else {
179 return Ok(StateData::default());
180 };
181
182 let raw: String = row
183 .try_get(0)
184 .with_context(|| "decode postgres state payload")?;
185 serde_json::from_str::<StateData>(&raw).with_context(|| "parse postgres state payload")
186 }
187
188 async fn save(&self, data: &StateData) -> Result<()> {
189 let client = self.connect().await?;
190
191 let payload = serde_json::to_string(data)?;
192 client
193 .execute(
194 "INSERT INTO alembic_state (state_key, payload, updated_at)
195 VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())
196 ON CONFLICT (state_key)
197 DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW()",
198 &[&self.key, &payload],
199 )
200 .await
201 .with_context(|| "save postgres state payload")?;
202 Ok(())
203 }
204}
205
206impl PostgresBackend {
207 async fn connect(&self) -> Result<tokio_postgres::Client> {
208 match self.tls_mode {
209 PostgresTlsMode::Disable => {
210 let (client, connection) =
211 tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
212 .await
213 .with_context(|| "connect postgres state backend")?;
214 tokio::spawn(async move {
215 if let Err(err) = connection.await {
216 tracing::warn!("postgres state backend connection error: {err}");
217 }
218 });
219 Ok(client)
220 }
221 PostgresTlsMode::Require => {
222 let connector = native_tls::TlsConnector::builder()
223 .build()
224 .with_context(|| "build postgres TLS connector")?;
225 let connector = postgres_native_tls::MakeTlsConnector::new(connector);
226 let (client, connection) = tokio_postgres::connect(&self.url, connector)
227 .await
228 .with_context(|| "connect postgres state backend")?;
229 tokio::spawn(async move {
230 if let Err(err) = connection.await {
231 tracing::warn!("postgres state backend connection error: {err}");
232 }
233 });
234 Ok(client)
235 }
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn t(s: &str) -> TypeName {
        TypeName::new(s)
    }

    fn uid(n: u128) -> Uid {
        Uid::from_u128(n)
    }

    /// Builds state containing exactly one `type_name -> uid(n) -> id` mapping.
    fn single_mapping(type_name: &str, n: u128, id: BackendId) -> StateData {
        let mut data = StateData::default();
        data.mappings
            .entry(t(type_name))
            .or_default()
            .insert(uid(n), id);
        data
    }

    #[test]
    fn state_data_default_is_empty() {
        let data = StateData::default();
        assert!(data.mappings.is_empty());
    }

    #[test]
    fn backend_id_returns_none_for_missing_type() {
        let store = StateStore::new(None, StateData::default());
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn backend_id_returns_none_for_missing_uid() {
        let store = StateStore::new(None, single_mapping("site", 1, BackendId::Int(42)));
        assert_eq!(store.backend_id(t("site"), uid(2)), None);
    }

    #[test]
    fn backend_id_returns_value_for_existing_mapping() {
        let store = StateStore::new(None, single_mapping("site", 1, BackendId::Int(42)));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(42))
        );
    }

    #[test]
    fn set_backend_id_creates_mapping() {
        let mut store = StateStore::new(None, StateData::default());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(42));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(42))
        );
    }

    #[test]
    fn set_backend_id_overwrites_existing() {
        let mut store = StateStore::new(None, single_mapping("site", 1, BackendId::Int(42)));
        store.set_backend_id(t("site"), uid(1), BackendId::Int(99));
        assert_eq!(
            store.backend_id(t("site"), uid(1)),
            Some(BackendId::Int(99))
        );
    }

    #[test]
    fn remove_backend_id_removes_mapping() {
        let mut store = StateStore::new(None, single_mapping("site", 1, BackendId::Int(42)));
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn remove_backend_id_keeps_other_uids_of_same_type() {
        let mut store = StateStore::new(None, single_mapping("site", 1, BackendId::Int(42)));
        store.set_backend_id(t("site"), uid(2), BackendId::Int(43));
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
        assert_eq!(
            store.backend_id(t("site"), uid(2)),
            Some(BackendId::Int(43))
        );
    }

    #[test]
    fn remove_backend_id_noop_for_missing() {
        let mut store = StateStore::new(None, StateData::default());
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
        assert!(store.all_mappings().is_empty());
    }

    #[test]
    fn all_mappings_returns_internal_reference() {
        let store = StateStore::new(None, StateData::default());
        assert!(store.all_mappings().is_empty());
    }

    #[tokio::test]
    async fn local_backend_load_missing_returns_empty() {
        let dir = TempDir::new().unwrap();
        let backend = LocalBackend {
            path: dir.path().join("nope.json"),
        };
        let data = backend.load().await.unwrap();
        assert!(data.mappings.is_empty());
    }

    #[tokio::test]
    async fn local_backend_save_load_round_trip() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("sub").join("state.json");
        let backend = LocalBackend { path: path.clone() };

        let data = single_mapping("site", 10, BackendId::String("site-001".into()));

        backend.save(&data).await.unwrap();
        assert!(path.exists());

        let loaded = backend.load().await.unwrap();
        assert_eq!(
            loaded.mappings[&t("site")][&uid(10)],
            BackendId::String("site-001".into())
        );
    }

    #[tokio::test]
    async fn local_backend_save_replaces_previous_state_and_cleans_up_tmp() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");
        let backend = LocalBackend { path: path.clone() };

        backend
            .save(&single_mapping("site", 1, BackendId::Int(1)))
            .await
            .unwrap();
        backend
            .save(&single_mapping("device", 2, BackendId::Int(2)))
            .await
            .unwrap();

        let loaded = backend.load().await.unwrap();
        assert!(!loaded.mappings.contains_key(&t("site")));
        assert_eq!(loaded.mappings[&t("device")][&uid(2)], BackendId::Int(2));
        assert!(!path.with_extension("json.tmp").exists());
    }

    #[tokio::test]
    async fn local_backend_load_rejects_malformed_json() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, "{ not json").unwrap();
        let backend = LocalBackend { path };
        assert!(backend.load().await.is_err());
    }

    #[test]
    fn store_load_missing_file_is_empty_and_does_not_create_it() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");
        let store = StateStore::load(&path).unwrap();
        assert!(store.all_mappings().is_empty());
        assert!(!path.exists());
    }

    #[test]
    fn store_load_rejects_malformed_json() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(&path, "{ not json").unwrap();
        assert!(StateStore::load(&path).is_err());
    }

    #[tokio::test]
    async fn store_load_then_save_round_trips_through_file() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("sub").join("state.json");

        let mut store = StateStore::load(&path).unwrap();
        store.set_backend_id(t("site"), uid(7), BackendId::Int(70));
        store.save_async().await.unwrap();

        let reloaded = StateStore::load(&path).unwrap();
        assert_eq!(
            reloaded.backend_id(t("site"), uid(7)),
            Some(BackendId::Int(70))
        );
    }

    #[tokio::test]
    async fn store_save_without_backend_is_noop() {
        let store = StateStore::new(None, StateData::default());
        store.save_async().await.unwrap();
    }

    #[tokio::test]
    async fn store_load_async_without_backend_is_noop() {
        let mut store = StateStore::new(None, StateData::default());
        store.set_backend_id(t("x"), uid(1), BackendId::Int(1));
        store.load_async().await.unwrap();
        assert_eq!(store.backend_id(t("x"), uid(1)), Some(BackendId::Int(1)));
    }
}