1use crate::error::{Error, Kind};
2use crate::runner::Type;
3use crate::Migration;
4use regex::Regex;
5use std::ffi::OsStr;
6use std::path::{Path, PathBuf};
7use std::sync::OnceLock;
8use walkdir::{DirEntry, WalkDir};
9
/// Integer type used to store migration versions.
///
/// With the `int8-versions` feature enabled this is a 64-bit integer.
#[cfg(feature = "int8-versions")]
pub type SchemaVersion = i64;
/// Integer type used to store migration versions.
///
/// Without the `int8-versions` feature this is a 32-bit integer.
#[cfg(not(feature = "int8-versions"))]
pub type SchemaVersion = i32;
14
/// Regex matching the start of a migration name, e.g. `V1__initial` or `U3__seed`.
///
/// Capture groups:
/// 1. the prefix: `V` or `U`;
/// 2. the version: digits, optionally followed by a `.` and more digits (a fractional
///    version matches here but is rejected later when parsed as a [`SchemaVersion`]);
/// 3. the name: one or more word characters after the `__` separator.
///
/// The prefix is written as `[UV]`, not `[U|V]`: inside a character class `|` is a
/// literal, so `[U|V]` would also accept names starting with `|`.
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";
16
17fn file_stem_re() -> &'static Regex {
19 static RE: OnceLock<Regex> = OnceLock::new();
20 RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
21}
22
23fn file_re_sql() -> &'static Regex {
25 static RE: OnceLock<Regex> = OnceLock::new();
26 RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
27}
28
29fn file_re_all() -> &'static Regex {
31 static RE: OnceLock<Regex> = OnceLock::new();
32 RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
33}
34
/// Selects which migration files [`find_migration_files`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationType {
    /// Both Rust (`.rs`) and SQL (`.sql`) migration files.
    All,
    /// Only SQL (`.sql`) migration files.
    Sql,
}
41
42impl MigrationType {
43 fn file_match_re(&self) -> &'static Regex {
44 match self {
45 MigrationType::All => file_re_all(),
46 MigrationType::Sql => file_re_sql(),
47 }
48 }
49}
50
51pub fn parse_migration_name(name: &str) -> Result<(Type, SchemaVersion, String), Error> {
53 let captures = file_stem_re()
54 .captures(name)
55 .filter(|caps| caps.len() == 4)
56 .ok_or_else(|| Error::new(Kind::InvalidName, None))?;
57 let version: SchemaVersion = captures[2]
58 .parse()
59 .map_err(|_| Error::new(Kind::InvalidVersion, None))?;
60
61 let name: String = (&captures[3]).into();
62 let prefix = match &captures[1] {
63 "V" => Type::Versioned,
64 "U" => Type::Unversioned,
65 _ => unreachable!(),
66 };
67
68 Ok((prefix, version, name))
69}
70
71pub fn find_migration_files(
73 location: impl AsRef<Path>,
74 migration_type: MigrationType,
75) -> Result<impl Iterator<Item = PathBuf>, Error> {
76 let location: &Path = location.as_ref();
77 let location = location.canonicalize().map_err(|err| {
78 Error::new(
79 Kind::InvalidMigrationPath(location.to_path_buf(), err),
80 None,
81 )
82 })?;
83
84 let re = migration_type.file_match_re();
85 let file_paths = WalkDir::new(location)
86 .into_iter()
87 .filter_map(Result::ok)
88 .map(DirEntry::into_path)
89 .filter(move |entry| {
91 if entry.is_dir() {
92 return false;
93 }
94
95 match entry.file_name().and_then(OsStr::to_str) {
96 Some(file_name) if re.is_match(file_name) => true,
97 Some(file_name) => {
98 log::warn!(
99 "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
100 file_name
101 );
102 false
103 }
104 None => false,
105 }
106 });
107
108 Ok(file_paths)
109}
110
111pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
114 let migration_files = find_migration_files(location, MigrationType::Sql)?;
115
116 let mut migrations = vec![];
117
118 for path in migration_files {
119 let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
120 let path = path.to_owned();
121 let kind = match e.kind() {
122 std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
123 _ => Kind::InvalidMigrationFile(path, e),
124 };
125
126 Error::new(kind, None)
127 })?;
128
129 let filename = path
131 .file_stem()
132 .and_then(|file| file.to_os_string().into_string().ok())
133 .unwrap();
134
135 let migration = Migration::unapplied(&filename, &sql)?;
136 migrations.push(migration);
137 }
138
139 migrations.sort();
140 Ok(migrations)
141}
142
// Tests for migration discovery and loading, run against temporary directories.
#[cfg(test)]
mod tests {
    use super::{find_migration_files, load_sql_migrations, MigrationType};
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Creates an empty `migrations` directory inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive while the directory is in use.
    fn migrations_dir() -> (TempDir, PathBuf) {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        (tmp_dir, migrations_dir)
    }

    /// Creates an empty file named `name` inside `dir` and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::File::create(&path).unwrap();
        path
    }

    /// Runs `find_migration_files` and returns all results in sorted order.
    fn find_sorted(dir: &Path, migration_type: MigrationType) -> Vec<PathBuf> {
        let mut found: Vec<PathBuf> = find_migration_files(dir, migration_type)
            .unwrap()
            .collect();
        found.sort();
        found
    }

    #[test]
    fn finds_mod_migrations() {
        let (_tmp_dir, dir) = migrations_dir();
        let rs1 = touch(&dir, "V1__first.rs");
        let rs2 = touch(&dir, "V2__second.rs");

        let mods = find_sorted(&dir, MigrationType::All);
        assert_eq!(
            mods,
            vec![rs1.canonicalize().unwrap(), rs2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn ignores_mod_files_without_migration_regex_match() {
        let (_tmp_dir, dir) = migrations_dir();
        touch(&dir, "V1first.rs");
        touch(&dir, "V2second.rs");

        let mut mods = find_migration_files(&dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn finds_sql_migrations() {
        let (_tmp_dir, dir) = migrations_dir();
        let sql1 = touch(&dir, "V1__first.sql");
        let sql2 = touch(&dir, "V2__second.sql");

        let mods = find_sorted(&dir, MigrationType::Sql);
        assert_eq!(
            mods,
            vec![sql1.canonicalize().unwrap(), sql2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn sql_migration_type_skips_rust_migrations() {
        let (_tmp_dir, dir) = migrations_dir();
        let sql1 = touch(&dir, "V1__first.sql");
        touch(&dir, "V2__second.rs");

        let mods = find_sorted(&dir, MigrationType::Sql);
        assert_eq!(mods, vec![sql1.canonicalize().unwrap()]);
    }

    #[test]
    fn finds_sql_migrations_in_nested_directories() {
        let (_tmp_dir, dir) = migrations_dir();
        let nested_dir = dir.join("foo");
        fs::create_dir(&nested_dir).unwrap();

        let sql1 = touch(&nested_dir, "V1__first.sql");
        let sql2 = touch(&nested_dir, "V2__second.sql");

        let mods = find_sorted(&dir, MigrationType::Sql);
        assert_eq!(
            mods,
            vec![sql1.canonicalize().unwrap(), sql2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn finds_unversioned_migrations() {
        let (_tmp_dir, dir) = migrations_dir();
        let sql1 = touch(&dir, "U1__first.sql");
        let sql2 = touch(&dir, "U2__second.sql");

        let mods = find_sorted(&dir, MigrationType::All);
        assert_eq!(
            mods,
            vec![sql1.canonicalize().unwrap(), sql2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn ignores_sql_files_without_migration_regex_match() {
        let (_tmp_dir, dir) = migrations_dir();
        touch(&dir, "V1first.sql");
        touch(&dir, "V2second.sql");

        let mut mods = find_migration_files(&dir, MigrationType::Sql).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn loads_migrations_from_path() {
        let (_tmp_dir, dir) = migrations_dir();
        touch(&dir, "V1__first.sql");
        touch(&dir, "V2__second.sql");
        touch(&dir, "V3__third.rs");

        let migrations = load_sql_migrations(&dir).unwrap();
        assert_eq!(migrations.len(), 2);
        assert_eq!(&migrations[0].to_string(), "V1__first");
        assert_eq!(&migrations[1].to_string(), "V2__second");
    }
}