1use crate::error::{Error, Kind};
2use crate::runner::Type;
3use crate::Migration;
4use regex::Regex;
5use std::ffi::OsStr;
6use std::path::{Path, PathBuf};
7use std::sync::OnceLock;
8use walkdir::{DirEntry, WalkDir};
9
/// Pattern for the start of a migration name: a `U`/`V` prefix, a version (digits with an
/// optional `.digits` part), `__`, then a name of word characters.
///
/// Anchored at the start only; callers append an extension pattern when they need a full
/// file-name match. The prefix class is `[UV]` — writing `[U|V]` would also accept a literal
/// `|`, since `|` is not an alternation operator inside a character class.
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";
11
12fn file_stem_re() -> &'static Regex {
14 static RE: OnceLock<Regex> = OnceLock::new();
15 RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
16}
17
18fn file_re_sql() -> &'static Regex {
20 static RE: OnceLock<Regex> = OnceLock::new();
21 RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
22}
23
24fn file_re_all() -> &'static Regex {
26 static RE: OnceLock<Regex> = OnceLock::new();
27 RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
28}
29
/// Selects which migration files `find_migration_files` collects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationType {
    /// Both Rust (`.rs`) and SQL (`.sql`) migration files.
    All,
    /// Only SQL (`.sql`) migration files.
    Sql,
}
36
37impl MigrationType {
38 fn file_match_re(&self) -> &'static Regex {
39 match self {
40 MigrationType::All => file_re_all(),
41 MigrationType::Sql => file_re_sql(),
42 }
43 }
44}
45
46pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> {
48 let captures = file_stem_re()
49 .captures(name)
50 .filter(|caps| caps.len() == 4)
51 .ok_or_else(|| Error::new(Kind::InvalidName, None))?;
52 let version: i32 = captures[2]
53 .parse()
54 .map_err(|_| Error::new(Kind::InvalidVersion, None))?;
55
56 let name: String = (&captures[3]).into();
57 let prefix = match &captures[1] {
58 "V" => Type::Versioned,
59 "U" => Type::Unversioned,
60 _ => unreachable!(),
61 };
62
63 Ok((prefix, version, name))
64}
65
66pub fn find_migration_files(
68 location: impl AsRef<Path>,
69 migration_type: MigrationType,
70) -> Result<impl Iterator<Item = PathBuf>, Error> {
71 let location: &Path = location.as_ref();
72 let location = location.canonicalize().map_err(|err| {
73 Error::new(
74 Kind::InvalidMigrationPath(location.to_path_buf(), err),
75 None,
76 )
77 })?;
78
79 let re = migration_type.file_match_re();
80 let file_paths = WalkDir::new(location)
81 .into_iter()
82 .filter_map(Result::ok)
83 .map(DirEntry::into_path)
84 .filter(
86 move |entry| match entry.file_name().and_then(OsStr::to_str) {
87 Some(file_name) if re.is_match(file_name) => true,
88 Some(file_name) => {
89 log::warn!(
90 "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
91 file_name
92 );
93 false
94 }
95 None => false,
96 },
97 );
98
99 Ok(file_paths)
100}
101
102pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
105 let migration_files = find_migration_files(location, MigrationType::Sql)?;
106
107 let mut migrations = vec![];
108
109 for path in migration_files {
110 let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
111 let path = path.to_owned();
112 let kind = match e.kind() {
113 std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
114 _ => Kind::InvalidMigrationFile(path, e),
115 };
116
117 Error::new(kind, None)
118 })?;
119
120 let filename = path
122 .file_stem()
123 .and_then(|file| file.to_os_string().into_string().ok())
124 .unwrap();
125
126 let migration = Migration::unapplied(&filename, &sql)?;
127 migrations.push(migration);
128 }
129
130 migrations.sort();
131 Ok(migrations)
132}
133
#[cfg(test)]
mod tests {
    use super::{find_migration_files, load_sql_migrations, parse_migration_name, MigrationType};
    use crate::runner::Type;
    use std::fs;
    use std::path::PathBuf;
    use tempfile::TempDir;

    #[test]
    fn finds_mod_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let rs1 = migrations_dir.join("V1__first.rs");
        fs::File::create(&rs1).unwrap();
        let rs2 = migrations_dir.join("V2__second.rs");
        fs::File::create(&rs2).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(rs1.canonicalize().unwrap(), mods[0]);
        assert_eq!(rs2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn ignores_mod_files_without_migration_regex_match() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let rs1 = migrations_dir.join("V1first.rs");
        fs::File::create(rs1).unwrap();
        let rs2 = migrations_dir.join("V2second.rs");
        fs::File::create(rs2).unwrap();

        let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn finds_sql_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("V2__second.sql");
        fs::File::create(&sql2).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::Sql)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn sql_search_excludes_rust_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.sql");
        fs::File::create(&sql1).unwrap();
        let rs2 = migrations_dir.join("V2__second.rs");
        fs::File::create(rs2).unwrap();

        let mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::Sql)
            .unwrap()
            .collect();
        assert_eq!(mods, vec![sql1.canonicalize().unwrap()]);
    }

    #[test]
    fn finds_migrations_in_subdirectories() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        let nested_dir = migrations_dir.join("nested");
        fs::create_dir_all(&nested_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = nested_dir.join("V2__second.sql");
        fs::File::create(&sql2).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(
            mods,
            vec![sql1.canonicalize().unwrap(), sql2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn finds_unversioned_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("U1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("U2__second.sql");
        fs::File::create(&sql2).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn ignores_sql_files_without_migration_regex_match() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1first.sql");
        fs::File::create(sql1).unwrap();
        let sql2 = migrations_dir.join("V2second.sql");
        fs::File::create(sql2).unwrap();

        let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn loads_migrations_from_path() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("V2__second.sql");
        fs::File::create(&sql2).unwrap();
        let rs3 = migrations_dir.join("V3__third.rs");
        fs::File::create(&rs3).unwrap();

        let migrations = load_sql_migrations(migrations_dir).unwrap();
        assert_eq!(migrations.len(), 2);
        assert_eq!(&migrations[0].to_string(), "V1__first");
        assert_eq!(&migrations[1].to_string(), "V2__second");
    }

    #[test]
    fn parses_migration_names() {
        let (prefix, version, name) = parse_migration_name("V1__first").unwrap();
        assert!(matches!(prefix, Type::Versioned));
        assert_eq!(version, 1);
        assert_eq!(name, "first");

        let (prefix, version, name) = parse_migration_name("U2__second").unwrap();
        assert!(matches!(prefix, Type::Unversioned));
        assert_eq!(version, 2);
        assert_eq!(name, "second");
    }

    #[test]
    fn rejects_invalid_migration_names() {
        assert!(parse_migration_name("first").is_err());
        assert!(parse_migration_name("V1first").is_err());
        // The pattern accepts a dotted version, but it does not parse as an integer.
        assert!(parse_migration_name("V1.5__first").is_err());
    }
}