refinery_core/
util.rs

1use crate::error::{Error, Kind};
2use crate::runner::Type;
3use crate::Migration;
4use regex::Regex;
5use std::ffi::OsStr;
6use std::path::{Path, PathBuf};
7use std::sync::OnceLock;
8use walkdir::{DirEntry, WalkDir};
9
/// Regex source matching the stem of a migration file name.
///
/// Capture groups: 1 = prefix (`U` or `V`), 2 = version (digits, optionally with a
/// fractional part), 3 = name (word characters). The prefix is a plain character class:
/// writing it as `[U|V]` would also accept a literal `|`, because `|` is not alternation
/// inside a class.
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";

12/// Matches the stem of a migration file.
13fn file_stem_re() -> &'static Regex {
14    static RE: OnceLock<Regex> = OnceLock::new();
15    RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
16}
17
18/// Matches the stem + extension of a SQL migration file.
19fn file_re_sql() -> &'static Regex {
20    static RE: OnceLock<Regex> = OnceLock::new();
21    RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
22}
23
24/// Matches the stem + extension of any migration file.
25fn file_re_all() -> &'static Regex {
26    static RE: OnceLock<Regex> = OnceLock::new();
27    RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
28}
29
/// Selects which migration files a directory search should return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationType {
    /// Both `.sql` and `.rs` migration files.
    All,
    /// Only `.sql` migration files.
    Sql,
}

37impl MigrationType {
38    fn file_match_re(&self) -> &'static Regex {
39        match self {
40            MigrationType::All => file_re_all(),
41            MigrationType::Sql => file_re_sql(),
42        }
43    }
44}
45
46/// Parse a migration filename stem into a prefix, version, and name.
47pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> {
48    let captures = file_stem_re()
49        .captures(name)
50        .filter(|caps| caps.len() == 4)
51        .ok_or_else(|| Error::new(Kind::InvalidName, None))?;
52    let version: i32 = captures[2]
53        .parse()
54        .map_err(|_| Error::new(Kind::InvalidVersion, None))?;
55
56    let name: String = (&captures[3]).into();
57    let prefix = match &captures[1] {
58        "V" => Type::Versioned,
59        "U" => Type::Unversioned,
60        _ => unreachable!(),
61    };
62
63    Ok((prefix, version, name))
64}
65
66/// find migrations on file system recursively across directories given a location and [MigrationType]
67pub fn find_migration_files(
68    location: impl AsRef<Path>,
69    migration_type: MigrationType,
70) -> Result<impl Iterator<Item = PathBuf>, Error> {
71    let location: &Path = location.as_ref();
72    let location = location.canonicalize().map_err(|err| {
73        Error::new(
74            Kind::InvalidMigrationPath(location.to_path_buf(), err),
75            None,
76        )
77    })?;
78
79    let re = migration_type.file_match_re();
80    let file_paths = WalkDir::new(location)
81        .into_iter()
82        .filter_map(Result::ok)
83        .map(DirEntry::into_path)
84        // filter by migration file regex
85        .filter(
86            move |entry| match entry.file_name().and_then(OsStr::to_str) {
87                Some(file_name) if re.is_match(file_name) => true,
88                Some(file_name) => {
89                    log::warn!(
90                        "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
91                        file_name
92                    );
93                    false
94                }
95                None => false,
96            },
97        );
98
99    Ok(file_paths)
100}
101
102/// Loads SQL migrations from a path. This enables dynamic migration discovery, as opposed to
103/// embedding. The resulting collection is ordered by version.
104pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
105    let migration_files = find_migration_files(location, MigrationType::Sql)?;
106
107    let mut migrations = vec![];
108
109    for path in migration_files {
110        let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
111            let path = path.to_owned();
112            let kind = match e.kind() {
113                std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
114                _ => Kind::InvalidMigrationFile(path, e),
115            };
116
117            Error::new(kind, None)
118        })?;
119
120        //safe to call unwrap as find_migration_filenames returns canonical paths
121        let filename = path
122            .file_stem()
123            .and_then(|file| file.to_os_string().into_string().ok())
124            .unwrap();
125
126        let migration = Migration::unapplied(&filename, &sql)?;
127        migrations.push(migration);
128    }
129
130    migrations.sort();
131    Ok(migrations)
132}
133
#[cfg(test)]
mod tests {
    use super::{find_migration_files, load_sql_migrations, parse_migration_name, MigrationType};
    use crate::runner::Type;
    use std::fs;
    use std::path::PathBuf;
    use tempfile::TempDir;

    #[test]
    fn finds_mod_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.rs");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("V2__second.rs");
        fs::File::create(&sql2).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn ignores_mod_files_without_migration_regex_match() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1first.rs");
        fs::File::create(sql1).unwrap();
        let sql2 = migrations_dir.join("V2second.rs");
        fs::File::create(sql2).unwrap();

        let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn finds_sql_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("V2__second.sql");
        fs::File::create(&sql2).unwrap();
        // A Rust migration must be excluded when only SQL migrations are requested.
        let rs3 = migrations_dir.join("V3__third.rs");
        fs::File::create(rs3).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::Sql)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(mods.len(), 2);
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn finds_unversioned_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("U1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("U2__second.sql");
        fs::File::create(&sql2).unwrap();

        let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
            .unwrap()
            .collect();
        mods.sort();
        assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
        assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
    }

    #[test]
    fn ignores_sql_files_without_migration_regex_match() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1first.sql");
        fs::File::create(sql1).unwrap();
        let sql2 = migrations_dir.join("V2second.sql");
        fs::File::create(sql2).unwrap();

        let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap();
        assert!(mods.next().is_none());
    }

    #[test]
    fn loads_migrations_from_path() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = tmp_dir.path().join("migrations");
        fs::create_dir(&migrations_dir).unwrap();
        let sql1 = migrations_dir.join("V1__first.sql");
        fs::File::create(&sql1).unwrap();
        let sql2 = migrations_dir.join("V2__second.sql");
        fs::File::create(&sql2).unwrap();
        let rs3 = migrations_dir.join("V3__third.rs");
        fs::File::create(&rs3).unwrap();

        let migrations = load_sql_migrations(migrations_dir).unwrap();
        assert_eq!(migrations.len(), 2);
        assert_eq!(&migrations[0].to_string(), "V1__first");
        assert_eq!(&migrations[1].to_string(), "V2__second");
    }

    #[test]
    fn parses_versioned_and_unversioned_names() {
        let (prefix, version, name) = parse_migration_name("V1__first").unwrap();
        assert!(matches!(prefix, Type::Versioned));
        assert_eq!(version, 1);
        assert_eq!(name, "first");

        let (prefix, version, name) = parse_migration_name("U42__second_one").unwrap();
        assert!(matches!(prefix, Type::Unversioned));
        assert_eq!(version, 42);
        assert_eq!(name, "second_one");
    }

    #[test]
    fn rejects_invalid_migration_names() {
        assert!(parse_migration_name("first").is_err());
        assert!(parse_migration_name("V1first").is_err());
        // Fractional versions match the stem pattern but are not valid `i32` versions.
        assert!(parse_migration_name("V1.5__first").is_err());
    }
}