use crate::types::BackendId;
use alembic_core::{TypeName, Uid};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Persisted state: for each object type, the mapping from Alembic [`Uid`] to
/// the identifier the target backend assigned to that object.
///
/// A document that omits `mappings` deserializes to an empty table.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct StateData {
    /// `type name -> (uid -> backend id)`.
    #[serde(default)]
    pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
}
/// Transport security used when connecting to the Postgres state backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// Plain-text connection (no TLS).
    Disable,
    /// Connect over TLS using the platform's native TLS connector.
    Require,
}
/// Storage location for [`StateData`].
///
/// The implementations in this module return an empty [`StateData`] from
/// `load` when nothing has been persisted yet.
#[async_trait::async_trait]
pub trait StateBackend: std::fmt::Debug + Send + Sync {
    /// Reads the currently persisted state.
    async fn load(&self) -> Result<StateData>;
    /// Persists `data`, replacing whatever was stored before.
    async fn save(&self, data: &StateData) -> Result<()>;
}
/// In-memory view of [`StateData`], optionally bound to a [`StateBackend`]
/// that it can be reloaded from and saved to.
#[derive(Clone, Debug)]
pub struct StateStore {
    /// Where state is loaded from / saved to; `None` means in-memory only.
    backend: Option<Arc<dyn StateBackend>>,
    /// The working copy of the uid -> backend id mappings.
    data: StateData,
}
impl StateStore {
    /// Creates a store from an optional persistence backend and its initial data.
    ///
    /// With `backend == None` the store is purely in-memory: [`Self::load_async`]
    /// and [`Self::save_async`] do nothing.
    pub fn new(backend: Option<Arc<dyn StateBackend>>, data: StateData) -> Self {
        Self { backend, data }
    }

    /// Synchronously loads state from the JSON file at `path`. The returned store
    /// is bound to a file backend at that same path.
    ///
    /// A missing file yields empty state.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - the file exists but cannot be read (including permission errors);
    /// - the file's contents do not parse as [`StateData`] JSON.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let data = Self::read_local_state(&path)?;
        let backend: Arc<dyn StateBackend> = Arc::new(LocalBackend { path });
        Ok(Self::new(Some(backend), data))
    }

    /// Reads and parses a JSON state file. Only `NotFound` is treated as "no state yet".
    ///
    /// `Path::exists` is deliberately avoided. It returns `false` on any metadata
    /// error (e.g. permission denied), and that would silently replace existing
    /// state with an empty table.
    fn read_local_state(path: &Path) -> Result<StateData> {
        let raw = match fs::read_to_string(path) {
            Ok(raw) => raw,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                return Ok(StateData::default());
            }
            Err(err) => {
                return Err(err).with_context(|| format!("read state: {}", path.display()));
            }
        };
        serde_json::from_str::<StateData>(&raw)
            .with_context(|| format!("parse state: {}", path.display()))
    }

    /// Connects to Postgres and loads the state row stored under `key`.
    ///
    /// The returned store keeps the backend, so later [`Self::save_async`] calls
    /// write back to the same row.
    pub async fn load_postgres(
        url: impl Into<String>,
        key: impl Into<String>,
        tls_mode: PostgresTlsMode,
    ) -> Result<Self> {
        let backend: Arc<dyn StateBackend> = Arc::new(PostgresBackend {
            url: url.into(),
            key: key.into(),
            tls_mode,
        });
        let data = backend.load().await?;
        Ok(Self::new(Some(backend), data))
    }

    /// Replaces the in-memory data with a fresh read from the backend.
    ///
    /// This is a no-op when the store has no backend.
    pub async fn load_async(&mut self) -> Result<()> {
        if let Some(backend) = &self.backend {
            self.data = backend.load().await?;
        }
        Ok(())
    }

    /// Persists the in-memory data through the backend.
    ///
    /// This is a no-op when the store has no backend.
    pub async fn save_async(&self) -> Result<()> {
        if let Some(backend) = &self.backend {
            backend.save(&self.data).await?;
        }
        Ok(())
    }

    /// Returns a clone of the backend id recorded for `uid` under `type_name`, if any.
    pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
        self.data
            .mappings
            .get(&type_name)
            .and_then(|map| map.get(&uid).cloned())
    }

    /// Records `backend_id` for `uid` under `type_name`, overwriting any previous
    /// value. The per-type map is created on first use.
    pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
        self.data
            .mappings
            .entry(type_name)
            .or_default()
            .insert(uid, backend_id);
    }

    /// Removes the mapping for `uid` under `type_name`. Missing entries are ignored.
    ///
    /// When the last uid of a type is removed, the now-empty per-type map is
    /// dropped too. This keeps persisted state from accumulating empty `{}`
    /// entries, and keeps [`Self::all_mappings`] from listing types with nothing
    /// mapped.
    pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
        if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
            type_map.remove(&uid);
            if type_map.is_empty() {
                self.data.mappings.remove(&type_name);
            }
        }
    }

    /// Borrows the full `type name -> (uid -> backend id)` table.
    pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
        &self.data.mappings
    }
}
/// State backend that keeps [`StateData`] as a JSON file on the local filesystem.
#[derive(Debug)]
struct LocalBackend {
    /// Location of the JSON state file.
    path: PathBuf,
}
#[async_trait::async_trait]
impl StateBackend for LocalBackend {
    /// Reads and parses the JSON state file.
    ///
    /// A missing file yields an empty [`StateData`]. Any other I/O failure
    /// (e.g. permission denied) is reported instead of being treated as
    /// "no state". `Path::exists` would hide such failures by returning `false`.
    ///
    /// NOTE(review): uses blocking `std::fs` inside an async fn. This is acceptable
    /// for small state files; consider `spawn_blocking` if they grow.
    async fn load(&self) -> Result<StateData> {
        let raw = match fs::read_to_string(&self.path) {
            Ok(raw) => raw,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                return Ok(StateData::default());
            }
            Err(err) => {
                return Err(err)
                    .with_context(|| format!("read state: {}", self.path.display()));
            }
        };
        serde_json::from_str::<StateData>(&raw)
            .with_context(|| format!("parse state: {}", self.path.display()))
    }

    /// Replaces the state file with pretty-printed JSON for `data`.
    ///
    /// Steps:
    /// 1. Create any missing parent directories.
    /// 2. Write the JSON to a sibling temp file (`path` with its extension
    ///    replaced by `json.tmp`).
    /// 3. Flush the temp file to stable storage.
    /// 4. Rename it over `path`.
    ///
    /// On filesystems with atomic rename, readers therefore never see a
    /// half-written file. The parent directory itself is not fsynced. If the
    /// write or rename fails, the temp file is removed on a best-effort basis.
    async fn save(&self, data: &StateData) -> Result<()> {
        if let Some(parent) = self.path.parent() {
            fs::create_dir_all(parent)
                .with_context(|| format!("create state dir: {}", parent.display()))?;
        }
        let raw = serde_json::to_string_pretty(data).context("serialize state")?;
        let tmp = self.path.with_extension("json.tmp");

        let write_tmp = || -> std::io::Result<()> {
            let mut file = fs::File::create(&tmp)?;
            std::io::Write::write_all(&mut file, raw.as_bytes())?;
            // Without this, a crash shortly after the rename can leave `path`
            // pointing at an empty or truncated file on some filesystems.
            file.sync_all()
        };
        if let Err(err) = write_tmp() {
            let _ = fs::remove_file(&tmp);
            return Err(err).with_context(|| format!("write state tmp: {}", tmp.display()));
        }

        if let Err(err) = fs::rename(&tmp, &self.path) {
            let _ = fs::remove_file(&tmp);
            return Err(err).with_context(|| format!("write state: {}", self.path.display()));
        }
        Ok(())
    }
}
/// State backend that stores the serialized [`StateData`] as one JSON row in the
/// `alembic_state` table, identified by `key`.
struct PostgresBackend {
    /// Connection string handed to `tokio_postgres::connect`; may embed credentials.
    url: String,
    /// Value of the `state_key` column identifying this store's row.
    key: String,
    /// Whether the connection is made over TLS.
    tls_mode: PostgresTlsMode,
}

// Written by hand instead of derived so that `{:?}` never prints the connection
// string, which may contain a password. This also covers formatting a
// `StateStore` that holds this backend.
impl std::fmt::Debug for PostgresBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresBackend")
            .field("url", &"<redacted>")
            .field("key", &self.key)
            .field("tls_mode", &self.tls_mode)
            .finish()
    }
}
#[async_trait::async_trait]
impl StateBackend for PostgresBackend {
    /// Fetches the JSON payload stored under `self.key` in `alembic_state`.
    ///
    /// A missing row means nothing has been saved yet and yields empty state.
    async fn load(&self) -> Result<StateData> {
        let client = self.connect().await?;
        let maybe_row = client
            .query_opt(
                "SELECT payload::text FROM alembic_state WHERE state_key = $1",
                &[&self.key],
            )
            .await
            .context("load postgres state payload")?;
        match maybe_row {
            None => Ok(StateData::default()),
            Some(row) => {
                // The query casts `payload` to text, so column 0 decodes as a String.
                let raw: String = row.try_get(0).context("decode postgres state payload")?;
                serde_json::from_str::<StateData>(&raw).context("parse postgres state payload")
            }
        }
    }

    /// Upserts the serialized state into the row keyed by `self.key` and sets
    /// `updated_at` to the server's `NOW()`.
    async fn save(&self, data: &StateData) -> Result<()> {
        let client = self.connect().await?;
        let payload = serde_json::to_string(data)?;
        // `$2` is bound as TEXT and converted to jsonb on the server side.
        let upsert = "INSERT INTO alembic_state (state_key, payload, updated_at)\n\
                      VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())\n\
                      ON CONFLICT (state_key)\n\
                      DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW()";
        client
            .execute(upsert, &[&self.key, &payload])
            .await
            .context("save postgres state payload")?;
        Ok(())
    }
}
impl PostgresBackend {
async fn connect(&self) -> Result<tokio_postgres::Client> {
match self.tls_mode {
PostgresTlsMode::Disable => {
let (client, connection) =
tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
.await
.with_context(|| "connect postgres state backend")?;
tokio::spawn(async move {
if let Err(err) = connection.await {
tracing::warn!("postgres state backend connection error: {err}");
}
});
Ok(client)
}
PostgresTlsMode::Require => {
let connector = native_tls::TlsConnector::builder()
.build()
.with_context(|| "build postgres TLS connector")?;
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
let (client, connection) = tokio_postgres::connect(&self.url, connector)
.await
.with_context(|| "connect postgres state backend")?;
tokio::spawn(async move {
if let Err(err) = connection.await {
tracing::warn!("postgres state backend connection error: {err}");
}
});
Ok(client)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Shorthand for building a [`TypeName`].
    fn t(s: &str) -> TypeName {
        TypeName::new(s)
    }

    /// Shorthand for building a deterministic [`Uid`].
    fn uid(n: u128) -> Uid {
        Uid::from_u128(n)
    }

    /// State containing exactly one `type_name` / `id` -> `backend_id` entry.
    fn seeded(type_name: &str, id: u128, backend_id: BackendId) -> StateData {
        let mut data = StateData::default();
        data.mappings
            .entry(t(type_name))
            .or_default()
            .insert(uid(id), backend_id);
        data
    }

    /// Wraps `data` in a store that has no persistence backend.
    fn detached(data: StateData) -> StateStore {
        StateStore::new(None, data)
    }

    #[test]
    fn state_data_default_is_empty() {
        assert!(StateData::default().mappings.is_empty());
    }

    #[test]
    fn backend_id_returns_none_for_missing_type() {
        let store = detached(StateData::default());
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn backend_id_returns_none_for_missing_uid() {
        let store = detached(seeded("site", 1, BackendId::Int(42)));
        assert_eq!(store.backend_id(t("site"), uid(2)), None);
    }

    #[test]
    fn backend_id_returns_value_for_existing_mapping() {
        let store = detached(seeded("site", 1, BackendId::Int(42)));
        assert_eq!(store.backend_id(t("site"), uid(1)), Some(BackendId::Int(42)));
    }

    #[test]
    fn set_backend_id_creates_mapping() {
        let mut store = detached(StateData::default());
        store.set_backend_id(t("site"), uid(1), BackendId::Int(42));
        assert_eq!(store.backend_id(t("site"), uid(1)), Some(BackendId::Int(42)));
    }

    #[test]
    fn set_backend_id_overwrites_existing() {
        let mut store = detached(seeded("site", 1, BackendId::Int(42)));
        store.set_backend_id(t("site"), uid(1), BackendId::Int(99));
        assert_eq!(store.backend_id(t("site"), uid(1)), Some(BackendId::Int(99)));
    }

    #[test]
    fn remove_backend_id_removes_mapping() {
        let mut store = detached(seeded("site", 1, BackendId::Int(42)));
        store.remove_backend_id(t("site"), uid(1));
        assert_eq!(store.backend_id(t("site"), uid(1)), None);
    }

    #[test]
    fn remove_backend_id_noop_for_missing() {
        // Must not panic when neither the type nor the uid is present.
        let mut store = detached(StateData::default());
        store.remove_backend_id(t("site"), uid(1));
    }

    #[test]
    fn all_mappings_returns_internal_reference() {
        let store = detached(StateData::default());
        assert!(store.all_mappings().is_empty());
    }

    #[tokio::test]
    async fn local_backend_load_missing_returns_empty() {
        let dir = TempDir::new().unwrap();
        let backend = LocalBackend {
            path: dir.path().join("nope.json"),
        };
        let loaded = backend.load().await.unwrap();
        assert!(loaded.mappings.is_empty());
    }

    #[tokio::test]
    async fn local_backend_save_load_round_trip() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("sub").join("state.json");
        let backend = LocalBackend { path: path.clone() };
        let original = seeded("site", 10, BackendId::String("site-001".into()));

        backend.save(&original).await.unwrap();
        assert!(path.exists());

        let loaded = backend.load().await.unwrap();
        assert_eq!(
            loaded.mappings[&t("site")][&uid(10)],
            BackendId::String("site-001".into())
        );
    }

    #[tokio::test]
    async fn store_save_without_backend_is_noop() {
        detached(StateData::default()).save_async().await.unwrap();
    }

    #[tokio::test]
    async fn store_load_async_without_backend_is_noop() {
        let mut store = detached(StateData::default());
        store.set_backend_id(t("x"), uid(1), BackendId::Int(1));
        store.load_async().await.unwrap();
        // With no backend, the in-memory mapping must survive the "reload".
        assert_eq!(store.backend_id(t("x"), uid(1)), Some(BackendId::Int(1)));
    }
}