// flyway_codegen/lib.rs

use proc_macro::TokenStream;
use std::env;
use std::num::ParseIntError;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::LitStr;

use flyway_sql_changelog::ChangelogFile;

/// Describes a single migration file found in the migrations directory.
///
/// The fields are extracted from a file name of the form `V<version>_<name>.sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct MigrationInfo {
    /// Version number parsed from the digits between the leading `V` and the first `_`.
    version: u64,
    /// File name of the migration, relative to the migrations directory.
    filename: String,
    /// Descriptive part of the file name, between the first `_` and the `.sql` suffix.
    name: String,
}

21/// Attribute macro for automatically generating a `flyway::MigrationStore`
22///
23/// The macro takes one required literal string parameter representing the directory containing
24/// the migration files. Each file must be named like `V<version>_<name>.sql`, where `<version>`
25/// is a valid integer and `<name>` is some name describing what the migration does.
26///
27/// Example:
28/// ```ignore
29/// use flyway_codegen::migrations;
30///
31/// #[migrations("examples/migrations/")]
32/// struct Migrations {}
33///
34/// pub fn main() {
35///     let migration_store = Migrations {};
36///     println!("migrations: {:?}", migration_store.changelogs());
37/// }
38/// ```
39#[proc_macro_attribute]
40pub fn migrations(args: TokenStream, input: TokenStream) -> TokenStream {
41    // println!("metadata: {:?}", &args);
42    // println!("input:    {:?}", &input);
43
44    let input_clone = input.clone();
45    let input_struct = syn::parse_macro_input!(input_clone as syn::ItemStruct);
46    // println!("input struct: {:?}", &input_struct);
47
48    let path = if args.is_empty() {
49        map_to_crate_root(None)
50    } else {
51        let migrations_path = syn::parse_macro_input!(args as LitStr).value();
52        map_to_crate_root(Some(migrations_path.as_str()))
53    };
54
55    #[cfg(feature = "debug_mode")]
56    if cfg!(debug_assertions){
57        println!("migrations path: {:?}", path);
58    }
59
60
61    let migrations = get_migrations(&path)
62        .expect("Error while gathering migration file information.");
63    #[cfg(feature = "debug_mode")]
64    if cfg!(debug_assertions){
65        println!("migrations: {:?}", &migrations);
66    }
67
68    let migration_tokens: Vec<TokenStream2> = migrations.iter()
69        .map(|migration| {
70            let name = migration.name.as_str();
71            let version = migration.version;
72            let filename = migration.filename.as_str();
73            let file_path = path.clone().join(filename).display().to_string();
74            let content = std::fs::read_to_string(file_path.as_str())
75                .expect(format!("Could not read migration file: {}", file_path).as_str());
76
77            // just check if the changelog can be loaded correctly:
78            let _changelog = ChangelogFile::from_string(version, name,content.as_str())
79                .expect(format!("Migration file is not a valid SQL changelog file: {}", file_path).as_str());
80
81            quote! {
82                (#version, #name.to_string(), #content)
83            }
84        })
85        .collect();
86
87    let struct_name = syn::Ident::new(input_struct.ident.to_string().as_str(), Span::call_site());
88    // println!("struct_name: {}", &struct_name);
89    let result = quote! {
90        impl flyway::MigrationStore for #struct_name {
91            fn changelogs(&self) -> Vec<flyway::ChangelogFile> {
92                use flyway::ChangelogFile;
93
94                let mut result: Vec<ChangelogFile> = [#(#migration_tokens),*].iter()
95                .map(|migration| {
96                    ChangelogFile::from_string(migration.0,migration.1.to_string().as_str(), migration.2).unwrap()
97                })
98                .collect();
99                return result;
100            }
101        }
102    };
103    // println!("result: {}", result.to_string());
104
105    let input: TokenStream2 = input.into();
106    return quote! {
107        #input
108        #result
109    }.into();
110}
111
/// Resolve `path` against the root directory of the crate being compiled.
///
/// The root is read from the `CARGO_MANIFEST_DIR` environment variable. With `None`, the crate
/// root itself is returned. A relative `path` is appended to the root; an absolute `path`
/// replaces it, following [`PathBuf::join`] semantics.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is not set.
fn map_to_crate_root(path: Option<&str>) -> PathBuf {
    // `var_os` rather than `var`: the crate may live under a directory whose name is not valid
    // UTF-8, which `env::var` would reject with a misleading "missing variable" panic.
    let root = PathBuf::from(
        env::var_os("CARGO_MANIFEST_DIR")
            .expect("Missing CARGO_MANIFEST_DIR environment variable. Cannot obtain crate root."),
    );
    match path {
        Some(path) => root.join(path),
        None => root,
    }
}

124/// List migrations contained inside a directory
125fn get_migrations(path: &PathBuf) -> Result<Vec<MigrationInfo>, std::io::Error> {
126    let result: Vec<MigrationInfo> = std::fs::read_dir(path)?
127        .filter(|entry| entry.is_ok())
128        .map(|entry| entry.unwrap().file_name().to_str().map(|v| v.to_string()))
129        .filter(|filename| filename.is_some())
130        .map(|filename| filename.unwrap())
131        .filter(|filename| filename.starts_with("V") && filename.ends_with(".sql"))
132        .map(|filename| {
133            let index = filename.find("_");
134            let mut version = "";
135            let mut name = "";
136            if let Some(index) = index {
137                if index > 1 && index < filename.len() - "V.sql".len() {
138                    if filename[1..index].chars().all(|ch| ch >= '0' && ch <= '9') {
139                        version = &filename[1..index];
140                        name = &filename[(index + 1)..(filename.len() - ".sql".len())];
141                    }
142                }
143            }
144
145            return if version.is_empty() {
146                None
147            } else {
148                let result: Result<Option<u64>, ParseIntError> = version.parse::<u64>()
149                    .map(|version| Some(version))
150                    .or(Ok(None));
151
152                let result = result.unwrap()
153                    .map(|version| {
154                        MigrationInfo {
155                            version,
156                            filename: filename.to_string(),
157                            name: name.to_string()
158                        }
159                    });
160                return result
161            };
162        })
163        .filter(|info| info.is_some())
164        .map(|info| info.unwrap())
165        .collect();
166    return Ok(result);
167}
168
#[cfg(test)]
mod test {
    /// The example migrations directory shipped with the crate holds exactly two migrations.
    #[test]
    fn test_get_migrations() {
        let path = crate::map_to_crate_root(Some("examples/migrations"));
        let migrations = crate::get_migrations(&path)
            .unwrap_or_else(|err| panic!("Migration loading failed: {}", err));
        assert_eq!(
            migrations.len(),
            2,
            "Expected exactly two migrations in {}, found {:?}",
            path.display(),
            migrations
        );
    }
}