db_up_codegen/
lib.rs

1use proc_macro::TokenStream;
2use std::env;
3use std::num::ParseIntError;
4use std::path::PathBuf;
5use std::str::FromStr;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{LitStr};
9use syn::__private::TokenStream2;
10
11use db_up_sql_changelog::ChangelogFile;
12
/// Metadata describing one migration file discovered in the migrations directory.
///
/// Instances are derived from file names following the `V<version>_<name>.sql` scheme.
#[derive(Clone, Debug)]
struct MigrationInfo {
    /// The numeric `<version>` part of the file name.
    version: u32,
    /// The complete file name (e.g. `V1_init.sql`), relative to the migrations directory.
    filename: String,
    /// The `<name>` part of the file name, without version prefix and `.sql` suffix.
    name: String,
}
20
21/// Attribute macro for automatically generating a `db_up::MigrationStore`
22///
23/// The macro takes one required literal string parameter representing the directory containing
24/// the migration files. Each file must be named like `V<version>_<name>.sql`, where `<version>`
25/// is a valid integer and `<name>` is some name describing what the migration does.
26///
27/// Example:
28/// ```ignore
29/// use db_up_codegen::migrations;
30///
31/// #[migrations("examples/migrations/")]
32/// struct Migrations {}
33///
34/// pub fn main() {
35///     let migration_store = Migrations {};
36///     println!("migrations: {:?}", migration_store.changelogs());
37/// }
38/// ```
39#[proc_macro_attribute]
40pub fn migrations(args: TokenStream, input: TokenStream) -> TokenStream {
41    // println!("metadata: {:?}", &args);
42    // println!("input:    {:?}", &input);
43
44    let input_clone = input.clone();
45    let input_struct = syn::parse_macro_input!(input_clone as syn::ItemStruct);
46    // println!("input struct: {:?}", &input_struct);
47
48    let path = if args.is_empty() {
49        map_to_crate_root(None)
50    } else {
51        let migrations_path = syn::parse_macro_input!(args as LitStr).value();
52        map_to_crate_root(Some(migrations_path.as_str()))
53    };
54    println!("migrations path: {:?}", path);
55
56    let migrations = get_migrations(&path)
57        .expect("Error while gathering migration file information.");
58    println!("migrations: {:?}", &migrations);
59
60    let migration_tokens: Vec<TokenStream2> = migrations.iter()
61        .map(|migration| {
62            let name = migration.name.as_str();
63            let version = migration.version;
64            let filename = migration.filename.as_str();
65            let file_path = path.clone().join(filename).display().to_string();
66            let content = std::fs::read_to_string(file_path.as_str())
67                .expect(format!("Could not read migration file: {}", file_path).as_str());
68
69            // just check if the changelog can be loaded correctly:
70            let _changelog = ChangelogFile::from_string(version.to_string().as_str(), content.as_str())
71                .expect(format!("Migration file is not a valid SQL changelog file: {}", file_path).as_str());
72
73            quote! {
74                (#version, #name.to_string(), #content)
75            }
76        })
77        .collect();
78
79    let struct_name = syn::Ident::new(input_struct.ident.to_string().as_str(), Span::call_site());
80    // println!("struct_name: {}", &struct_name);
81    let result = quote! {
82        impl db_up::MigrationStore for #struct_name {
83            fn changelogs(&self) -> Vec<db_up::ChangelogFile> {
84                use db_up::ChangelogFile;
85
86                let mut result: Vec<ChangelogFile> = [#(#migration_tokens),*].iter()
87                .map(|migration| {
88                    ChangelogFile::from_string(migration.0.to_string().as_str(), migration.2).unwrap()
89                })
90                .collect();
91                return result;
92            }
93        }
94    };
95    // println!("result: {}", result.to_string());
96
97    let input: TokenStream2 = input.into();
98    return quote! {
99        #input
100        #result
101    }.into();
102}
103
/// Resolve an optional path against the root directory of the crate being compiled.
///
/// The crate root is read from the `CARGO_MANIFEST_DIR` environment variable. With `None` the
/// crate root itself is returned; with `Some(path)` the path is joined onto it (an absolute
/// `path` replaces the root, following [`PathBuf::join`]).
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset or not valid Unicode.
fn map_to_crate_root(path: Option<&str>) -> PathBuf {
    let crate_root = PathBuf::from(
        env::var("CARGO_MANIFEST_DIR")
            .expect("Missing CARGO_MANIFEST_DIR environment variable. Cannot obtain crate root."),
    );

    match path {
        None => crate_root,
        Some(relative) => {
            let relative = PathBuf::from_str(relative).expect("Could not parse filename.");
            crate_root.join(relative)
        }
    }
}
115
116/// List migrations contained inside a directory
117fn get_migrations(path: &PathBuf) -> Result<Vec<MigrationInfo>, std::io::Error> {
118    let result: Vec<MigrationInfo> = std::fs::read_dir(path)?
119        .filter(|entry| entry.is_ok())
120        .map(|entry| entry.unwrap().file_name().to_str().map(|v| v.to_string()))
121        .filter(|filename| filename.is_some())
122        .map(|filename| filename.unwrap())
123        .filter(|filename| filename.starts_with("V") && filename.ends_with(".sql"))
124        .map(|filename| {
125            let index = filename.find("_");
126            let mut version = "";
127            let mut name = "";
128            if let Some(index) = index {
129                if index > 1 && index < filename.len() - "V.sql".len() {
130                    if filename[1..index].chars().all(|ch| ch >= '0' && ch <= '9') {
131                        version = &filename[1..index];
132                        name = &filename[(index + 1)..(filename.len() - ".sql".len())];
133                    }
134                }
135            }
136
137            return if version.is_empty() {
138                None
139            } else {
140                let result: Result<Option<u32>, ParseIntError> = version.parse::<u32>()
141                    .map(|version| Some(version))
142                    .or(Ok(None));
143
144                let result = result.unwrap()
145                    .map(|version| {
146                        MigrationInfo {
147                            version,
148                            filename: filename.to_string(),
149                            name: name.to_string()
150                        }
151                    });
152                return result
153            };
154        })
155        .filter(|info| info.is_some())
156        .map(|info| info.unwrap())
157        .collect();
158    return Ok(result);
159}
160
#[cfg(test)]
mod test {
    use crate::{get_migrations, map_to_crate_root};

    /// The example migrations directory shipped with the crate must yield exactly two migrations.
    #[test]
    pub fn test_get_migrations() {
        let migrations_dir = map_to_crate_root(Some("examples/migrations"));
        let migrations = get_migrations(&migrations_dir)
            .unwrap_or_else(|err| panic!("Migration loading failed: {}", err));
        assert_eq!(migrations.len(), 2, "Two migrations have been successfully loaded.");
    }
}