Skip to main content

refinery_core/
util.rs

1use crate::error::{Error, Kind};
2use crate::runner::Type;
3use crate::Migration;
4use regex::Regex;
5use std::ffi::OsStr;
6use std::path::{Path, PathBuf};
7use std::sync::OnceLock;
8use walkdir::{DirEntry, WalkDir};
9
/// Integer type used to store migration versions.
///
/// Widened to `i64` because the `int8-versions` feature is enabled.
#[cfg(feature = "int8-versions")]
pub type SchemaVersion = i64;
/// Integer type used to store migration versions.
///
/// Defaults to `i32`; enable the `int8-versions` feature to widen it to `i64`.
#[cfg(not(feature = "int8-versions"))]
pub type SchemaVersion = i32;
14
/// Pattern matching the stem of a migration file name, e.g. `V1__create_users`.
///
/// Capture groups: 1 = prefix (`U` or `V`), 2 = version digits (an optional fractional part is
/// matched here but will fail to parse as a [`SchemaVersion`]), 3 = name (word characters).
///
/// The prefix class is `[UV]`, not `[U|V]`: inside a character class `|` is a literal, so the
/// latter would also accept file names starting with `|`.
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";
16
17/// Matches the stem of a migration file.
18fn file_stem_re() -> &'static Regex {
19    static RE: OnceLock<Regex> = OnceLock::new();
20    RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
21}
22
23/// Matches the stem + extension of a SQL migration file.
24fn file_re_sql() -> &'static Regex {
25    static RE: OnceLock<Regex> = OnceLock::new();
26    RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
27}
28
29/// Matches the stem + extension of any migration file.
30fn file_re_all() -> &'static Regex {
31    static RE: OnceLock<Regex> = OnceLock::new();
32    RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
33}
34
/// Which kinds of migration files [`find_migration_files`] should look for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationType {
    /// Both `.sql` and `.rs` migration files.
    All,
    /// Only `.sql` migration files.
    Sql,
}
41
42impl MigrationType {
43    fn file_match_re(&self) -> &'static Regex {
44        match self {
45            MigrationType::All => file_re_all(),
46            MigrationType::Sql => file_re_sql(),
47        }
48    }
49}
50
51/// Parse a migration filename stem into a prefix, version, and name.
52pub fn parse_migration_name(name: &str) -> Result<(Type, SchemaVersion, String), Error> {
53    let captures = file_stem_re()
54        .captures(name)
55        .filter(|caps| caps.len() == 4)
56        .ok_or_else(|| Error::new(Kind::InvalidName, None))?;
57    let version: SchemaVersion = captures[2]
58        .parse()
59        .map_err(|_| Error::new(Kind::InvalidVersion, None))?;
60
61    let name: String = (&captures[3]).into();
62    let prefix = match &captures[1] {
63        "V" => Type::Versioned,
64        "U" => Type::Unversioned,
65        _ => unreachable!(),
66    };
67
68    Ok((prefix, version, name))
69}
70
71/// find migrations on file system recursively across directories given a location and [MigrationType]
72pub fn find_migration_files(
73    location: impl AsRef<Path>,
74    migration_type: MigrationType,
75) -> Result<impl Iterator<Item = PathBuf>, Error> {
76    let location: &Path = location.as_ref();
77    let location = location.canonicalize().map_err(|err| {
78        Error::new(
79            Kind::InvalidMigrationPath(location.to_path_buf(), err),
80            None,
81        )
82    })?;
83
84    let re = migration_type.file_match_re();
85    let file_paths = WalkDir::new(location)
86        .into_iter()
87        .filter_map(Result::ok)
88        .map(DirEntry::into_path)
89        // filter by migration file regex
90        .filter(move |entry| {
91            if entry.is_dir() {
92                return false;
93            }
94
95            match entry.file_name().and_then(OsStr::to_str) {
96                Some(file_name) if re.is_match(file_name) => true,
97                Some(file_name) => {
98                    log::warn!(
99                        "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
100                        file_name
101                    );
102                    false
103                }
104                None => false,
105            }
106        });
107
108    Ok(file_paths)
109}
110
111/// Loads SQL migrations from a path. This enables dynamic migration discovery, as opposed to
112/// embedding. The resulting collection is ordered by version.
113pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
114    let migration_files = find_migration_files(location, MigrationType::Sql)?;
115
116    let mut migrations = vec![];
117
118    for path in migration_files {
119        let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
120            let path = path.to_owned();
121            let kind = match e.kind() {
122                std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
123                _ => Kind::InvalidMigrationFile(path, e),
124            };
125
126            Error::new(kind, None)
127        })?;
128
129        //safe to call unwrap as find_migration_filenames returns canonical paths
130        let filename = path
131            .file_stem()
132            .and_then(|file| file.to_os_string().into_string().ok())
133            .unwrap();
134
135        let migration = Migration::unapplied(&filename, &sql)?;
136        migrations.push(migration);
137    }
138
139    migrations.sort();
140    Ok(migrations)
141}
142
#[cfg(test)]
mod tests {
    use super::{find_migration_files, load_sql_migrations, parse_migration_name, MigrationType};
    use crate::runner::Type;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Creates an empty `migrations` directory inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the path; keep it bound for the whole test so
    /// the temporary directory stays alive (`TempDir` removes it when dropped).
    fn migrations_dir() -> (TempDir, PathBuf) {
        let tmp_dir = TempDir::new().unwrap();
        let dir = tmp_dir.path().join("migrations");
        fs::create_dir(&dir).unwrap();
        (tmp_dir, dir)
    }

    /// Creates an empty file called `name` inside `dir` and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::File::create(&path).unwrap();
        path
    }

    /// Runs [`find_migration_files`] and sorts the result so assertions do not depend on
    /// directory traversal order.
    fn find_sorted(dir: &Path, migration_type: MigrationType) -> Vec<PathBuf> {
        let mut found: Vec<PathBuf> = find_migration_files(dir, migration_type)
            .unwrap()
            .collect();
        found.sort();
        found
    }

    #[test]
    fn finds_mod_migrations() {
        let (_tmp, dir) = migrations_dir();
        let rs1 = touch(&dir, "V1__first.rs");
        let rs2 = touch(&dir, "V2__second.rs");

        let mods = find_sorted(&dir, MigrationType::All);
        assert_eq!(mods.len(), 2);
        assert_eq!(rs1.canonicalize().unwrap(), mods[0]);
        assert_eq!(rs2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn ignores_mod_files_without_migration_regex_match() {
        let (_tmp, dir) = migrations_dir();
        touch(&dir, "V1first.rs");
        touch(&dir, "V2second.rs");

        let mut mods = find_migration_files(&dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn finds_sql_migrations() {
        let (_tmp, dir) = migrations_dir();
        let sql1 = touch(&dir, "V1__first.sql");
        let sql2 = touch(&dir, "V2__second.sql");

        let mods = find_sorted(&dir, MigrationType::All);
        assert_eq!(mods.len(), 2);
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn finds_sql_migrations_in_nested_directories() {
        let (_tmp, dir) = migrations_dir();
        let nested_dir = dir.join("foo");
        fs::create_dir(&nested_dir).unwrap();

        let sql1 = touch(&nested_dir, "V1__first.sql");
        let sql2 = touch(&nested_dir, "V2__second.sql");

        let mods = find_sorted(&dir, MigrationType::All);
        assert_eq!(mods.len(), 2);
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn finds_unversioned_migrations() {
        let (_tmp, dir) = migrations_dir();
        let sql1 = touch(&dir, "U1__first.sql");
        let sql2 = touch(&dir, "U2__second.sql");

        let mods = find_sorted(&dir, MigrationType::All);
        assert_eq!(mods.len(), 2);
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn ignores_sql_files_without_migration_regex_match() {
        let (_tmp, dir) = migrations_dir();
        touch(&dir, "V1first.sql");
        touch(&dir, "V2second.sql");

        let mut mods = find_migration_files(&dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn loads_migrations_from_path() {
        let (_tmp, dir) = migrations_dir();
        touch(&dir, "V1__first.sql");
        touch(&dir, "V2__second.sql");
        touch(&dir, "V3__third.rs");

        let migrations = load_sql_migrations(&dir).unwrap();
        assert_eq!(migrations.len(), 2);
        assert_eq!(&migrations[0].to_string(), "V1__first");
        assert_eq!(&migrations[1].to_string(), "V2__second");
    }

    #[test]
    fn parses_migration_names() {
        let (prefix, version, name) = parse_migration_name("V1__first").unwrap();
        assert!(matches!(prefix, Type::Versioned));
        assert_eq!(version, 1);
        assert_eq!(name, "first");

        let (prefix, version, name) = parse_migration_name("U22__second_one").unwrap();
        assert!(matches!(prefix, Type::Unversioned));
        assert_eq!(version, 22);
        assert_eq!(name, "second_one");

        // Missing `__` separator, unknown prefix, and non-integer version are all rejected.
        assert!(parse_migration_name("V1first").is_err());
        assert!(parse_migration_name("X1__first").is_err());
        assert!(parse_migration_name("V1.5__first").is_err());
    }
}