refinery_macros/
lib.rs

1//! Contains Refinery macros that are used to import and embed migration files.
2#![recursion_limit = "128"]
3
4#[cfg(feature = "enums")]
5use heck::ToUpperCamelCase;
6use proc_macro::TokenStream;
7use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
8use quote::quote;
9use quote::ToTokens;
10use refinery_core::{find_migration_files, MigrationType};
11use std::path::PathBuf;
12use std::{env, fs};
13use syn::{parse_macro_input, Ident, LitStr};
14
/// Returns the manifest directory of the crate currently being compiled, taken from the
/// `CARGO_MANIFEST_DIR` environment variable.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is not set (or is not valid Unicode).
pub(crate) fn crate_root() -> PathBuf {
    env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("CARGO_MANIFEST_DIR environment variable not present")
}
20
21fn migration_fn_quoted<T: ToTokens>(_migrations: Vec<T>) -> TokenStream2 {
22    let result = quote! {
23        use refinery::{Migration, Runner, SchemaVersion};
24        pub fn runner() -> Runner {
25            let quoted_migrations: Vec<(&str, String)> = vec![#(#_migrations),*];
26            let mut migrations: Vec<Migration> = Vec::new();
27            for module in quoted_migrations.into_iter() {
28                migrations.push(Migration::unapplied(module.0, &module.1).unwrap());
29            }
30            Runner::new(&migrations)
31        }
32    };
33    result
34}
35
/// Generates the `EmbeddedMigration` enum and a `From<Migration>` impl for it.
///
/// Each entry of `migration_names` is a migration file stem such as `V1__foo`. It becomes
/// a variant named after the camel-cased migration name (`Foo`). The variant carries the
/// `Migration`, and its discriminant is the migration version. The enum's `repr` follows
/// `SchemaVersion`: `i64` with the `int8-versions` feature, `i32` otherwise.
/// `From<Migration>` maps a migration back to its variant by version and panics on an
/// unknown version.
///
/// # Panics
///
/// Panics, which fails macro expansion, in any of these cases:
/// - a name cannot be parsed as a migration;
/// - a name does not yield a usable variant identifier;
/// - two migrations map to the same variant name or share a version.
///
/// The last two would otherwise show up as confusing compile errors inside the
/// generated code.
#[cfg(feature = "enums")]
fn migration_enum_quoted(migration_names: &[impl AsRef<str>]) -> TokenStream2 {
    use refinery_core::SchemaVersion;
    use std::collections::HashMap;

    // Discriminant type of the generated enum; must agree with `SchemaVersion`.
    #[cfg(feature = "int8-versions")]
    let repr_ty = Ident::new("i64", Span2::call_site());
    #[cfg(not(feature = "int8-versions"))]
    let repr_ty = Ident::new("i32", Span2::call_site());

    let mut variants = Vec::with_capacity(migration_names.len());
    // One arm per migration plus the trailing catch-all arm.
    let mut discriminants = Vec::with_capacity(migration_names.len() + 1);
    // Remember which migration claimed each variant name / version so that clashes
    // can be reported with both file names.
    let mut variant_owners: HashMap<String, &str> = HashMap::new();
    let mut version_owners: HashMap<SchemaVersion, &str> = HashMap::new();

    for m in migration_names {
        let m = m.as_ref();
        let (_, version, name) = refinery_core::parse_migration_name(m)
            .unwrap_or_else(|e| panic!("Couldn't parse migration filename '{}': {:?}", m, e));
        let version: SchemaVersion = version;
        let variant = migration_variant_ident(m, &name);

        if let Some(previous) = variant_owners.insert(variant.to_string(), m) {
            panic!(
                "Migrations '{}' and '{}' both map to the enum variant '{}'",
                previous, m, variant
            );
        }
        if let Some(previous) = version_owners.insert(version, m) {
            panic!(
                "Migrations '{}' and '{}' share the version {}",
                previous, m, version
            );
        }

        variants.push(quote! { #variant(Migration) = #version });
        discriminants.push(quote! { #version => Self::#variant(migration) });
    }
    discriminants.push(quote! { v => panic!("Invalid migration version '{}'", v) });

    quote! {
        #[repr(#repr_ty)]
        #[derive(Debug)]
        pub enum EmbeddedMigration {
            #(#variants),*
        }

        impl From<Migration> for EmbeddedMigration {
            fn from(migration: Migration) -> Self {
                match migration.version() as SchemaVersion {
                    #(#discriminants),*
                }
            }
        }
    }
}

/// Converts the `name` part of `migration` into the identifier of its enum variant.
///
/// Panics with a message naming `migration` when the camel-cased name cannot serve as a
/// variant: it is empty, does not start with a letter or `_` (e.g. `V1__2fa`), or is the
/// keyword `Self`. This replaces the less helpful panic from `Ident::new`, or a compile
/// error inside the generated enum.
#[cfg(feature = "enums")]
fn migration_variant_ident(migration: &str, name: &str) -> Ident {
    let variant = name.to_upper_camel_case();
    let starts_like_ident = matches!(
        variant.chars().next(),
        Some(c) if c.is_alphabetic() || c == '_'
    );
    if !starts_like_ident || variant == "Self" {
        panic!(
            "Migration '{}' cannot be embedded as an enum variant: its name converts to '{}', which is not a valid Rust identifier",
            migration, variant
        );
    }
    Ident::new(&variant, Span2::call_site())
}
85
86/// Interpret Rust or SQL migrations and inserts a function called runner that when called returns a [`Runner`] instance with the collected migration modules.
87///
88/// When called without arguments `embed_migrations` searches for migration files on a directory called `migrations` at the root level of your crate.
89/// if you want to specify another directory call `embed_migrations!` with it's location relative to the root level of your crate.
90///
91/// To be a valid migration module, it has to be named in the format `V{1}__{2}.{3} ` where `{1}` represents the migration version and `{2}` the name and `{3} is "rs" or "sql".
92/// For the name alphanumeric characters plus "_" are supported.
93/// The Rust migration file must have a function named `migration()` that returns a [`std::string::String`].
94/// The SQL migration file must have valid sql instructions for the database you want it to run on.
95///
96/// [`Runner`]: https://docs.rs/refinery/latest/refinery/struct.Runner.html
97#[proc_macro]
98pub fn embed_migrations(input: TokenStream) -> TokenStream {
99    let location = if input.is_empty() {
100        crate_root().join("migrations")
101    } else {
102        let location: LitStr = parse_macro_input!(input);
103        crate_root().join(location.value())
104    };
105
106    let migration_files =
107        find_migration_files(location, MigrationType::All).expect("error getting migration files");
108
109    let mut migrations_mods = Vec::new();
110    let mut _migrations = Vec::new();
111    let mut migration_filenames = Vec::new();
112
113    for migration in migration_files {
114        // safe to call unwrap as find_migration_filenames returns canonical paths
115        let filename = migration
116            .file_stem()
117            .and_then(|file| file.to_os_string().into_string().ok())
118            .unwrap();
119        let path = migration.display().to_string();
120        let extension = migration.extension().unwrap();
121        migration_filenames.push(filename.clone());
122
123        if extension == "sql" {
124            _migrations.push(quote! {(#filename, include_str!(#path).to_string())});
125        } else if extension == "rs" {
126            let rs_content = fs::read_to_string(&path)
127                .unwrap()
128                .parse::<TokenStream2>()
129                .unwrap();
130            let ident = Ident::new(&filename, Span2::call_site());
131            let mig_mod = quote! {pub mod #ident {
132                #rs_content
133                // also include the file as str so we trigger recompilation if it changes
134                const _RECOMPILE_IF_CHANGED: &str = include_str!(#path);
135            }};
136            _migrations.push(quote! {(#filename, #ident::migration())});
137            migrations_mods.push(mig_mod);
138        }
139    }
140
141    let fnq = migration_fn_quoted(_migrations);
142    #[cfg(feature = "enums")]
143    let enums = migration_enum_quoted(migration_filenames.as_slice());
144    #[cfg(not(feature = "enums"))]
145    let enums = quote!();
146
147    (quote! {
148        pub mod migrations {
149            #(#migrations_mods)*
150            #fnq
151            #enums
152        }
153    })
154    .into()
155}
156
#[cfg(test)]
mod tests {
    use super::{migration_fn_quoted, quote};

    /// Expected token string of the generated enum for `V1__foo` and `U3__barBAZ`.
    /// `@INT@` stands in for the discriminant integer type (`i32` or `i64`).
    #[cfg(feature = "enums")]
    const ENUM_TEMPLATE: &str = concat!(
        "# [repr (@INT@)] ",
        "# [derive (Debug)] ",
        "pub enum EmbeddedMigration { ",
        "Foo (Migration) = 1@INT@ , ",
        "BarBaz (Migration) = 3@INT@ ",
        "} ",
        "impl From < Migration > for EmbeddedMigration { ",
        "fn from (migration : Migration) -> Self { ",
        "match migration . version () as SchemaVersion { ",
        "1@INT@ => Self :: Foo (migration) , ",
        "3@INT@ => Self :: BarBaz (migration) , ",
        "v => panic ! (\"Invalid migration version '{}'\" , v) ",
        "} } }",
    );

    /// Renders `ENUM_TEMPLATE` for the given discriminant type name.
    #[cfg(feature = "enums")]
    fn expected_enum(int_ty: &str) -> String {
        ENUM_TEMPLATE.replace("@INT@", int_ty)
    }

    #[cfg(all(feature = "enums", feature = "int8-versions"))]
    #[test]
    fn test_enum_fn_i8() {
        let generated = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
        assert_eq!(expected_enum("i64"), generated);
    }

    #[cfg(all(feature = "enums", not(feature = "int8-versions")))]
    #[test]
    fn test_enum_fn() {
        let generated = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
        assert_eq!(expected_enum("i32"), generated);
    }

    #[test]
    fn test_quote_fn() {
        let entries = vec![quote!("V1__first", "valid_sql_file")];
        let expected = concat!(
            "use refinery :: { Migration , Runner , SchemaVersion } ; ",
            "pub fn runner () -> Runner { ",
            "let quoted_migrations : Vec < (& str , String) > = vec ! [\"V1__first\" , \"valid_sql_file\"] ; ",
            "let mut migrations : Vec < Migration > = Vec :: new () ; ",
            "for module in quoted_migrations . into_iter () { ",
            "migrations . push (Migration :: unapplied (module . 0 , & module . 1) . unwrap ()) ; ",
            "} ",
            "Runner :: new (& migrations) }",
        );
        let generated = migration_fn_quoted(entries).to_string();
        assert_eq!(expected, generated);
    }
}
220}