rustio_admin/
migrations.rs1use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9use crate::orm::Db;
10
/// A migration script found on disk by `discover`.
///
/// Files are expected to be named `<version>_<name>.sql`, e.g. `0003_add_users.sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationFile {
    /// Numeric prefix of the file name; pending migrations are applied in ascending
    /// order of this value.
    pub version: i64,
    /// The part of the file stem after the first `_` (extension not included).
    pub name: String,
    /// Path to the `.sql` file (the scanned directory joined with the file name).
    pub path: PathBuf,
}
16
/// Tuning knobs for `apply_with`.
///
/// `ApplyOptions::default()` turns every option off.
#[derive(Clone, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, an `info`-level log line is emitted before each pending migration runs.
    pub verbose: bool,
}
21
22pub async fn apply(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<String>> {
23 apply_with(db, dir, ApplyOptions::default()).await
24}
25
26pub async fn apply_with(db: &Db, dir: impl AsRef<Path>, opts: ApplyOptions) -> Result<Vec<String>> {
27 ensure_tracking_table(db).await?;
28
29 let files = discover(dir.as_ref())?;
30 let already = applied_versions(db).await?;
31 let mut newly = Vec::new();
32
33 for file in files {
34 if already.contains(&file.version) {
35 continue;
36 }
37 if opts.verbose {
38 log::info!("applying migration {:04}_{}", file.version, file.name);
39 }
40
41 let sql = fs::read_to_string(&file.path)?;
42 let statements = split_statements(&sql);
43
44 let mut tx = db
45 .pool()
46 .begin()
47 .await
48 .map_err(|e| Error::Internal(format!("begin tx: {e}")))?;
49
50 for stmt in &statements {
51 let trimmed = stmt.trim();
52 if trimmed.is_empty() {
53 continue;
54 }
55 sqlx::query(trimmed)
56 .execute(&mut *tx)
57 .await
58 .map_err(|e| Error::Internal(format!("migration {} failed: {e}", file.name)))?;
59 }
60
61 sqlx::query(
62 "INSERT INTO rustio_migrations (version, name, applied_at)
63 VALUES ($1, $2, NOW())",
64 )
65 .bind(file.version)
66 .bind(&file.name)
67 .execute(&mut *tx)
68 .await
69 .map_err(|e| Error::Internal(format!("tracking insert: {e}")))?;
70
71 tx.commit()
72 .await
73 .map_err(|e| Error::Internal(format!("commit: {e}")))?;
74
75 newly.push(file.name.clone());
76 }
77
78 Ok(newly)
79}
80
81pub async fn applied_versions(db: &Db) -> Result<Vec<i64>> {
82 ensure_tracking_table(db).await?;
83 let rows =
84 sqlx::query_scalar::<_, i64>("SELECT version FROM rustio_migrations ORDER BY version ASC")
85 .fetch_all(db.pool())
86 .await?;
87 Ok(rows)
88}
89
90pub async fn status(db: &Db, dir: impl AsRef<Path>) -> Result<Vec<(String, bool)>> {
91 let applied = applied_versions(db).await?;
92 let files = discover(dir.as_ref())?;
93 Ok(files
94 .into_iter()
95 .map(|f| {
96 (
97 format!("{:04}_{}", f.version, f.name),
98 applied.contains(&f.version),
99 )
100 })
101 .collect())
102}
103
104pub fn generate(dir: impl AsRef<Path>, name: &str) -> Result<PathBuf> {
105 let dir = dir.as_ref();
106 fs::create_dir_all(dir)?;
107 let existing = discover(dir).unwrap_or_default();
108 let next = existing.iter().map(|m| m.version).max().unwrap_or(0) + 1;
109 let filename = format!("{:04}_{}.sql", next, slugify(name));
110 let path = dir.join(filename);
111 fs::write(&path, format!("-- {}\n\n", name))?;
112 Ok(path)
113}
114
115fn discover(dir: &Path) -> Result<Vec<MigrationFile>> {
116 if !dir.exists() {
117 return Ok(Vec::new());
118 }
119 let mut out = Vec::new();
120 for entry in fs::read_dir(dir)? {
121 let entry = entry?;
122 let path = entry.path();
123 if path.extension().and_then(|s| s.to_str()) != Some("sql") {
124 continue;
125 }
126 let stem = match path.file_stem().and_then(|s| s.to_str()) {
127 Some(s) => s,
128 None => continue,
129 };
130 let (ver_part, name_part) = match stem.split_once('_') {
131 Some(p) => p,
132 None => continue,
133 };
134 let version: i64 = match ver_part.parse() {
135 Ok(n) => n,
136 Err(_) => continue,
137 };
138 out.push(MigrationFile {
139 version,
140 name: name_part.to_string(),
141 path,
142 });
143 }
144 out.sort_by_key(|m| m.version);
145 Ok(out)
146}
147
148async fn ensure_tracking_table(db: &Db) -> Result<()> {
149 sqlx::query(
150 "CREATE TABLE IF NOT EXISTS rustio_migrations (
151 version BIGINT PRIMARY KEY,
152 name TEXT NOT NULL,
153 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
154 )",
155 )
156 .execute(db.pool())
157 .await?;
158 Ok(())
159}
160
/// Splits a SQL script into individual statements on top-level `;` characters.
///
/// A `;` separates statements only when it is outside all of the following:
/// - single-quoted string literals (`'a;b'`, where `''` is an escaped quote);
/// - double-quoted identifiers (`"a;b"`, where `""` is an escaped quote);
/// - `--` line comments;
/// - `/* ... */` block comments, which may nest as in PostgreSQL;
/// - dollar-quoted strings (`$$ ... $$` or `$tag$ ... $tag$`).
///
/// A `$` glued to a preceding identifier character (e.g. `a$b`) and `$1`-style
/// positional parameters are plain text, not quote openers.
///
/// The separating `;` is dropped. All other text, including comments and whitespace,
/// is copied verbatim. Pieces between consecutive separators are returned even when
/// empty or whitespace-only (callers skip them). A whitespace-only remainder after the
/// last `;` is discarded. Unterminated quotes or comments run to the end of the input.
///
/// Known limitation: backslash escapes inside `E'...'` strings are not recognised.
fn split_statements(sql: &str) -> Vec<String> {
    /// If the text right after a `$` completes a dollar-quote opener (`tag$`, where the
    /// optional tag starts with a letter or `_`), returns that text including its
    /// closing `$`; otherwise `None`.
    fn dollar_tag_tail(mut rest: impl Iterator<Item = char>) -> Option<String> {
        let mut tail = String::new();
        loop {
            let c = rest.next()?;
            if c == '$' {
                tail.push(c);
                return Some(tail);
            }
            let allowed = if tail.is_empty() {
                c.is_alphabetic() || c == '_'
            } else {
                c.is_alphanumeric() || c == '_'
            };
            if !allowed {
                return None;
            }
            tail.push(c);
        }
    }

    let mut out = Vec::new();
    let mut current = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            '\'' | '"' => {
                // String literal or quoted identifier; a doubled quote is an escape.
                current.push(c);
                while let Some(nc) = chars.next() {
                    current.push(nc);
                    if nc == c {
                        match chars.next_if_eq(&c) {
                            Some(escaped) => current.push(escaped),
                            None => break,
                        }
                    }
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                // Line comment: copy through the terminating newline.
                current.push(c);
                for nc in chars.by_ref() {
                    current.push(nc);
                    if nc == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                // Block comment. The opening `*` is consumed here so that `/*/` is not
                // mistaken for an open-and-close. Nesting depth is tracked.
                current.push(c);
                chars.next();
                current.push('*');
                let mut depth = 1usize;
                while depth > 0 {
                    let nc = match chars.next() {
                        Some(nc) => nc,
                        None => break,
                    };
                    current.push(nc);
                    if nc == '/' && chars.next_if_eq(&'*').is_some() {
                        current.push('*');
                        depth += 1;
                    } else if nc == '*' && chars.next_if_eq(&'/').is_some() {
                        current.push('/');
                        depth -= 1;
                    }
                }
            }
            '$' => {
                // `$` directly after an identifier character belongs to that identifier.
                let glued = matches!(
                    current.chars().next_back(),
                    Some(p) if p.is_alphanumeric() || p == '_' || p == '$'
                );
                current.push(c);
                let tail = if glued { None } else { dollar_tag_tail(chars.clone()) };
                if let Some(tail) = tail {
                    // Copy the rest of the opening delimiter...
                    let tail_len = tail.chars().count();
                    for _ in 0..tail_len {
                        chars.next();
                    }
                    current.push_str(&tail);
                    // ...then the body, up to and including the matching closing delimiter.
                    while let Some(nc) = chars.next() {
                        current.push(nc);
                        if nc == '$' && chars.clone().take(tail_len).eq(tail.chars()) {
                            for _ in 0..tail_len {
                                chars.next();
                            }
                            current.push_str(&tail);
                            break;
                        }
                    }
                }
            }
            ';' => out.push(std::mem::take(&mut current)),
            other => current.push(other),
        }
    }

    if !current.trim().is_empty() {
        out.push(current);
    }
    out
}
268
/// Turns free-form text into a file-name slug.
///
/// Alphanumeric characters (Unicode-aware) are lowercased using full Unicode case
/// mapping. Every other character (spaces, punctuation, symbols) becomes `_`. Nothing
/// is collapsed or trimmed, so `"Add Users Table!"` becomes `"add_users_table_"`.
fn slugify(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    for c in name.chars() {
        if c.is_alphanumeric() {
            // `to_lowercase` (unlike `to_ascii_lowercase`) also lowercases non-ASCII
            // letters such as `Ä`, matching the Unicode-aware `is_alphanumeric` check.
            slug.extend(c.to_lowercase());
        } else {
            slug.push('_');
        }
    }
    slug
}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_ignores_semicolon_in_string() {
        let sql = "INSERT INTO t VALUES ('a;b'); SELECT 1;";
        let parts = split_statements(sql);
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn split_ignores_line_comments() {
        let sql = "SELECT 1; -- comment with ;\nSELECT 2;";
        let parts = split_statements(sql);
        assert_eq!(parts.len(), 2);
    }

    // A doubled single quote is an escaped quote, not the end of the literal.
    #[test]
    fn split_handles_doubled_quote_escape() {
        let parts = split_statements("INSERT INTO t VALUES ('it''s; fine'); SELECT 1;");
        assert_eq!(parts, vec!["INSERT INTO t VALUES ('it''s; fine')", " SELECT 1"]);
    }

    #[test]
    fn split_ignores_semicolon_in_block_comment() {
        let parts = split_statements("/* a; b */ SELECT 1;");
        assert_eq!(parts, vec!["/* a; b */ SELECT 1"]);
    }

    // `$1`-style parameters must not be mistaken for dollar-quote openers.
    #[test]
    fn split_treats_positional_params_as_plain_text() {
        let parts = split_statements("SELECT $1, $2; SELECT 3;");
        assert_eq!(parts, vec!["SELECT $1, $2", " SELECT 3"]);
    }

    // Empty pieces between `;;` are returned (callers skip them); a whitespace-only
    // tail after the final `;` is not.
    #[test]
    fn split_drops_whitespace_tail_but_keeps_empty_pieces() {
        assert_eq!(
            split_statements("SELECT 1;;SELECT 2;  \n"),
            vec!["SELECT 1", "", "SELECT 2"]
        );
        assert!(split_statements("").is_empty());
    }

    #[test]
    fn slugify_lowercases_and_replaces() {
        assert_eq!(slugify("Add Users Table!"), "add_users_table_");
    }
}