trailbase_refinery_macros/
lib.rs

1//! Contains Refinery macros that are used to import and embed migration files.
2#![recursion_limit = "128"]
3
4use heck::ToUpperCamelCase;
5use proc_macro::TokenStream;
6use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
7use quote::quote;
8use quote::ToTokens;
9use trailbase_refinery_core::{find_migration_files, MigrationType};
10use std::path::PathBuf;
11use std::{env, fs};
12use syn::{parse_macro_input, Ident, LitStr};
13
/// Returns the manifest directory of the crate being compiled.
///
/// The value comes from the `CARGO_MANIFEST_DIR` environment variable. When the
/// macro is expanded under Cargo, this is the directory of the crate that invokes
/// the macro, so relative migration locations are resolved against it.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset or not valid Unicode.
pub(crate) fn crate_root() -> PathBuf {
    env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("CARGO_MANIFEST_DIR environment variable not present")
}
19
20fn migration_fn_quoted<T: ToTokens>(_migrations: Vec<T>) -> TokenStream2 {
21    let result = quote! {
22        use trailbase_refinery_core::{Migration, Runner};
23        pub fn runner() -> Runner {
24            let quoted_migrations: Vec<(&str, String)> = vec![#(#_migrations),*];
25            let mut migrations: Vec<Migration> = Vec::new();
26            for module in quoted_migrations.into_iter() {
27                migrations.push(Migration::unapplied(module.0, &module.1).unwrap());
28            }
29            Runner::new(&migrations)
30        }
31    };
32    result
33}
34
35fn migration_enum_quoted(migration_names: &[impl AsRef<str>]) -> TokenStream2 {
36    if cfg!(feature = "enums") {
37        let mut variants = Vec::new();
38        let mut discriminants = Vec::new();
39
40        for m in migration_names {
41            let m = m.as_ref();
42            let (_, version, name) = trailbase_refinery_core::parse_migration_name(m)
43                .unwrap_or_else(|e| panic!("Couldn't parse migration filename '{}': {:?}", m, e));
44            let variant = Ident::new(name.to_upper_camel_case().as_str(), Span2::call_site());
45            variants.push(quote! { #variant(Migration) = #version });
46            discriminants.push(quote! { #version => Self::#variant(migration) });
47        }
48        discriminants.push(quote! { v => panic!("Invalid migration version '{}'", v) });
49
50        let result = quote! {
51            #[repr(i32)]
52            #[derive(Debug)]
53            pub enum EmbeddedMigration {
54                #(#variants),*
55            }
56
57            impl From<Migration> for EmbeddedMigration {
58                fn from(migration: Migration) -> Self {
59                    match migration.version() as i32 {
60                        #(#discriminants),*
61                    }
62                }
63            }
64        };
65        result
66    } else {
67        quote!()
68    }
69}
70
71/// Interpret Rust or SQL migrations and inserts a function called runner that when called returns a [`Runner`] instance with the collected migration modules.
72///
73/// When called without arguments `embed_migrations` searches for migration files on a directory called `migrations` at the root level of your crate.
74/// if you want to specify another directory call `embed_migrations!` with it's location relative to the root level of your crate.
75///
76/// To be a valid migration module, it has to be named in the format `V{1}__{2}.{3} ` where `{1}` represents the migration version and `{2}` the name and `{3} is "rs" or "sql".
77/// For the name alphanumeric characters plus "_" are supported.
78/// The Rust migration file must have a function named `migration()` that returns a [`std::string::String`].
79/// The SQL migration file must have valid sql instructions for the database you want it to run on.
80///
81/// [`Runner`]: https://docs.rs/refinery/latest/refinery/struct.Runner.html
82#[proc_macro]
83pub fn embed_migrations(input: TokenStream) -> TokenStream {
84    let location = if input.is_empty() {
85        crate_root().join("migrations")
86    } else {
87        let location: LitStr = parse_macro_input!(input);
88        crate_root().join(location.value())
89    };
90
91    let migration_files =
92        find_migration_files(location, MigrationType::All).expect("error getting migration files");
93
94    let mut migrations_mods = Vec::new();
95    let mut _migrations = Vec::new();
96    let mut migration_filenames = Vec::new();
97
98    for migration in migration_files {
99        // safe to call unwrap as find_migration_filenames returns canonical paths
100        let filename = migration
101            .file_stem()
102            .and_then(|file| file.to_os_string().into_string().ok())
103            .unwrap();
104        let path = migration.display().to_string();
105        let extension = migration.extension().unwrap();
106        migration_filenames.push(filename.clone());
107
108        if extension == "sql" {
109            _migrations.push(quote! {(#filename, include_str!(#path).to_string())});
110        } else if extension == "rs" {
111            let rs_content = fs::read_to_string(&path)
112                .unwrap()
113                .parse::<TokenStream2>()
114                .unwrap();
115            let ident = Ident::new(&filename, Span2::call_site());
116            let mig_mod = quote! {pub mod #ident {
117                #rs_content
118                // also include the file as str so we trigger recompilation if it changes
119                const _RECOMPILE_IF_CHANGED: &str = include_str!(#path);
120            }};
121            _migrations.push(quote! {(#filename, #ident::migration())});
122            migrations_mods.push(mig_mod);
123        }
124    }
125
126    let fnq = migration_fn_quoted(_migrations);
127    let enums = migration_enum_quoted(migration_filenames.as_slice());
128    (quote! {
129        pub mod migrations {
130            #(#migrations_mods)*
131            #fnq
132            #enums
133        }
134    })
135    .into()
136}
137
#[cfg(test)]
mod tests {
    use super::{migration_fn_quoted, quote};

    /// With the `enums` feature, each migration renders as one version-keyed
    /// variant, followed by a catch-all arm for unknown versions.
    #[test]
    #[cfg(feature = "enums")]
    fn test_enum_fn() {
        let rendered = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
        let expected = concat!(
            "# [repr (i32)] # [derive (Debug)] ",
            "pub enum EmbeddedMigration { ",
            "Foo (Migration) = 1i32 , ",
            "BarBaz (Migration) = 3i32 ",
            "} ",
            "impl From < Migration > for EmbeddedMigration { ",
            "fn from (migration : Migration) -> Self { ",
            "match migration . version () as i32 { ",
            "1i32 => Self :: Foo (migration) , ",
            "3i32 => Self :: BarBaz (migration) , ",
            "v => panic ! (\"Invalid migration version '{}'\" , v) ",
            "} } }"
        );
        assert_eq!(expected, rendered);
    }

    /// The generated `runner()` must splice the quoted migrations verbatim into
    /// its `vec![...]` literal.
    #[test]
    fn test_quote_fn() {
        let rendered =
            migration_fn_quoted(vec![quote!("V1__first", "valid_sql_file")]).to_string();
        let expected = concat!(
            "use trailbase_refinery_core :: { Migration , Runner } ; ",
            "pub fn runner () -> Runner { ",
            "let quoted_migrations : Vec < (& str , String) > = vec ! [\"V1__first\" , \"valid_sql_file\"] ; ",
            "let mut migrations : Vec < Migration > = Vec :: new () ; ",
            "for module in quoted_migrations . into_iter () { ",
            "migrations . push (Migration :: unapplied (module . 0 , & module . 1) . unwrap ()) ; ",
            "} ",
            "Runner :: new (& migrations) }"
        );
        assert_eq!(expected, rendered);
    }
}