1use crate::types::BackendId;
4use alembic_core::{TypeName, Uid};
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::BTreeMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10
11use std::sync::Arc;
12
13#[derive(Debug, Default, Clone, Serialize, Deserialize)]
15pub struct StateData {
16 #[serde(default)]
17 pub mappings: BTreeMap<TypeName, BTreeMap<Uid, BackendId>>,
18}
19
/// TLS policy used when the Postgres state backend opens a connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresTlsMode {
    /// Connect in plaintext (`tokio_postgres::NoTls`).
    Disable,
    /// Connect over TLS using a default-configured `native_tls` connector.
    Require,
}
26
27#[async_trait::async_trait]
29pub trait StateBackend: Send + Sync + std::fmt::Debug {
30 async fn load(&self) -> Result<StateData>;
31 async fn save(&self, data: &StateData) -> Result<()>;
32}
33
34#[derive(Debug, Clone)]
36pub struct StateStore {
37 backend: Option<Arc<dyn StateBackend>>,
38 data: StateData,
39}
40
41impl StateStore {
42 pub fn new(backend: Option<Arc<dyn StateBackend>>, data: StateData) -> Self {
44 Self { backend, data }
45 }
46
47 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
49 let path = path.as_ref().to_path_buf();
50 let data = if path.exists() {
51 let raw = fs::read_to_string(&path)
52 .with_context(|| format!("read state: {}", path.display()))?;
53 serde_json::from_str::<StateData>(&raw)
54 .with_context(|| format!("parse state: {}", path.display()))?
55 } else {
56 StateData::default()
57 };
58 Ok(Self::new(Some(Arc::new(LocalBackend { path })), data))
59 }
60
61 pub async fn load_postgres(
63 url: impl Into<String>,
64 key: impl Into<String>,
65 tls_mode: PostgresTlsMode,
66 ) -> Result<Self> {
67 let backend: Arc<dyn StateBackend> = Arc::new(PostgresBackend {
68 url: url.into(),
69 key: key.into(),
70 tls_mode,
71 });
72 let data = backend.load().await?;
73 Ok(Self::new(Some(backend), data))
74 }
75
76 pub async fn load_async(&mut self) -> Result<()> {
78 if let Some(backend) = &self.backend {
79 self.data = backend.load().await?;
80 }
81 Ok(())
82 }
83
84 pub async fn save_async(&self) -> Result<()> {
86 if let Some(backend) = &self.backend {
87 backend.save(&self.data).await?;
88 }
89 Ok(())
90 }
91
92 pub fn backend_id(&self, type_name: TypeName, uid: Uid) -> Option<BackendId> {
94 self.data
95 .mappings
96 .get(&type_name)
97 .and_then(|map| map.get(&uid).cloned())
98 }
99
100 pub fn set_backend_id(&mut self, type_name: TypeName, uid: Uid, backend_id: BackendId) {
102 self.data
103 .mappings
104 .entry(type_name)
105 .or_default()
106 .insert(uid, backend_id);
107 }
108
109 pub fn remove_backend_id(&mut self, type_name: TypeName, uid: Uid) {
111 if let Some(type_map) = self.data.mappings.get_mut(&type_name) {
112 type_map.remove(&uid);
113 }
114 }
115
116 pub fn all_mappings(&self) -> &BTreeMap<TypeName, BTreeMap<Uid, BackendId>> {
118 &self.data.mappings
119 }
120}
121
/// [`StateBackend`] that keeps [`StateData`] as pretty-printed JSON in a local file.
#[derive(Debug)]
struct LocalBackend {
    /// Location of the JSON state file.
    path: PathBuf,
}
126
127#[async_trait::async_trait]
128impl StateBackend for LocalBackend {
129 async fn load(&self) -> Result<StateData> {
130 if self.path.exists() {
131 let raw = fs::read_to_string(&self.path)
132 .with_context(|| format!("read state: {}", self.path.display()))?;
133 let data = serde_json::from_str::<StateData>(&raw)
134 .with_context(|| format!("parse state: {}", self.path.display()))?;
135 Ok(data)
136 } else {
137 Ok(StateData::default())
138 }
139 }
140
141 async fn save(&self, data: &StateData) -> Result<()> {
142 if let Some(parent) = self.path.parent() {
143 fs::create_dir_all(parent)
144 .with_context(|| format!("create state dir: {}", parent.display()))?;
145 }
146 let raw = serde_json::to_string_pretty(data)?;
147 let tmp = self.path.with_extension("json.tmp");
148 fs::write(&tmp, &raw).with_context(|| format!("write state tmp: {}", tmp.display()))?;
149 fs::rename(&tmp, &self.path)
150 .with_context(|| format!("write state: {}", self.path.display()))?;
151 Ok(())
152 }
153}
154
155#[derive(Debug)]
156struct PostgresBackend {
157 url: String,
158 key: String,
159 tls_mode: PostgresTlsMode,
160}
161
162impl PostgresBackend {
163 async fn connect(&self) -> Result<tokio_postgres::Client> {
164 match self.tls_mode {
165 PostgresTlsMode::Disable => {
166 let (client, connection) =
167 tokio_postgres::connect(&self.url, tokio_postgres::NoTls)
168 .await
169 .with_context(|| "connect postgres state backend")?;
170 tokio::spawn(async move {
171 if let Err(err) = connection.await {
172 tracing::warn!("postgres state backend connection error: {err}");
173 }
174 });
175 Ok(client)
176 }
177 PostgresTlsMode::Require => {
178 let connector = native_tls::TlsConnector::builder()
179 .build()
180 .with_context(|| "build postgres TLS connector")?;
181 let connector = postgres_native_tls::MakeTlsConnector::new(connector);
182 let (client, connection) = tokio_postgres::connect(&self.url, connector)
183 .await
184 .with_context(|| "connect postgres state backend")?;
185 tokio::spawn(async move {
186 if let Err(err) = connection.await {
187 tracing::warn!("postgres state backend connection error: {err}");
188 }
189 });
190 Ok(client)
191 }
192 }
193 }
194
195 }
197
198#[async_trait::async_trait]
199impl StateBackend for PostgresBackend {
200 async fn load(&self) -> Result<StateData> {
201 let client = self.connect().await?;
202
203 let row = client
204 .query_opt(
205 "SELECT payload::text FROM alembic_state WHERE state_key = $1",
206 &[&self.key],
207 )
208 .await
209 .with_context(|| "load postgres state payload")?;
210
211 let Some(row) = row else {
212 return Ok(StateData::default());
213 };
214
215 let raw: String = row
216 .try_get(0)
217 .with_context(|| "decode postgres state payload")?;
218 serde_json::from_str::<StateData>(&raw).with_context(|| "parse postgres state payload")
219 }
220
221 async fn save(&self, data: &StateData) -> Result<()> {
222 let client = self.connect().await?;
223
224 let payload = serde_json::to_string(data)?;
225 client
226 .execute(
227 "INSERT INTO alembic_state (state_key, payload, updated_at)
228 VALUES ($1, CAST($2 AS TEXT)::jsonb, NOW())
229 ON CONFLICT (state_key)
230 DO UPDATE SET payload = EXCLUDED.payload, updated_at = NOW()",
231 &[&self.key, &payload],
232 )
233 .await
234 .with_context(|| "save postgres state payload")?;
235 Ok(())
236 }
237}