trailbase_refinery_macros/
lib.rs1#![recursion_limit = "128"]
3
4use heck::ToUpperCamelCase;
5use proc_macro::TokenStream;
6use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
7use quote::quote;
8use quote::ToTokens;
9use trailbase_refinery_core::{find_migration_files, MigrationType};
10use std::path::PathBuf;
11use std::{env, fs};
12use syn::{parse_macro_input, Ident, LitStr};
13
/// Returns the root directory of the crate that is currently being compiled.
///
/// The path is taken verbatim from the `CARGO_MANIFEST_DIR` environment variable,
/// which Cargo sets to the directory containing the invoking crate's `Cargo.toml`.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset or not valid Unicode.
pub(crate) fn crate_root() -> PathBuf {
    env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("CARGO_MANIFEST_DIR environment variable not present")
}
19
20fn migration_fn_quoted<T: ToTokens>(_migrations: Vec<T>) -> TokenStream2 {
21 let result = quote! {
22 use trailbase_refinery_core::{Migration, Runner};
23 pub fn runner() -> Runner {
24 let quoted_migrations: Vec<(&str, String)> = vec![#(#_migrations),*];
25 let mut migrations: Vec<Migration> = Vec::new();
26 for module in quoted_migrations.into_iter() {
27 migrations.push(Migration::unapplied(module.0, &module.1).unwrap());
28 }
29 Runner::new(&migrations)
30 }
31 };
32 result
33}
34
35fn migration_enum_quoted(migration_names: &[impl AsRef<str>]) -> TokenStream2 {
36 if cfg!(feature = "enums") {
37 let mut variants = Vec::new();
38 let mut discriminants = Vec::new();
39
40 for m in migration_names {
41 let m = m.as_ref();
42 let (_, version, name) = trailbase_refinery_core::parse_migration_name(m)
43 .unwrap_or_else(|e| panic!("Couldn't parse migration filename '{}': {:?}", m, e));
44 let variant = Ident::new(name.to_upper_camel_case().as_str(), Span2::call_site());
45 variants.push(quote! { #variant(Migration) = #version });
46 discriminants.push(quote! { #version => Self::#variant(migration) });
47 }
48 discriminants.push(quote! { v => panic!("Invalid migration version '{}'", v) });
49
50 let result = quote! {
51 #[repr(i32)]
52 #[derive(Debug)]
53 pub enum EmbeddedMigration {
54 #(#variants),*
55 }
56
57 impl From<Migration> for EmbeddedMigration {
58 fn from(migration: Migration) -> Self {
59 match migration.version() as i32 {
60 #(#discriminants),*
61 }
62 }
63 }
64 };
65 result
66 } else {
67 quote!()
68 }
69}
70
71#[proc_macro]
83pub fn embed_migrations(input: TokenStream) -> TokenStream {
84 let location = if input.is_empty() {
85 crate_root().join("migrations")
86 } else {
87 let location: LitStr = parse_macro_input!(input);
88 crate_root().join(location.value())
89 };
90
91 let migration_files =
92 find_migration_files(location, MigrationType::All).expect("error getting migration files");
93
94 let mut migrations_mods = Vec::new();
95 let mut _migrations = Vec::new();
96 let mut migration_filenames = Vec::new();
97
98 for migration in migration_files {
99 let filename = migration
101 .file_stem()
102 .and_then(|file| file.to_os_string().into_string().ok())
103 .unwrap();
104 let path = migration.display().to_string();
105 let extension = migration.extension().unwrap();
106 migration_filenames.push(filename.clone());
107
108 if extension == "sql" {
109 _migrations.push(quote! {(#filename, include_str!(#path).to_string())});
110 } else if extension == "rs" {
111 let rs_content = fs::read_to_string(&path)
112 .unwrap()
113 .parse::<TokenStream2>()
114 .unwrap();
115 let ident = Ident::new(&filename, Span2::call_site());
116 let mig_mod = quote! {pub mod #ident {
117 #rs_content
118 const _RECOMPILE_IF_CHANGED: &str = include_str!(#path);
120 }};
121 _migrations.push(quote! {(#filename, #ident::migration())});
122 migrations_mods.push(mig_mod);
123 }
124 }
125
126 let fnq = migration_fn_quoted(_migrations);
127 let enums = migration_enum_quoted(migration_filenames.as_slice());
128 (quote! {
129 pub mod migrations {
130 #(#migrations_mods)*
131 #fnq
132 #enums
133 }
134 })
135 .into()
136}
137
#[cfg(test)]
mod tests {
    use super::{migration_fn_quoted, quote};

    // Pins the exact token rendering of the enum generated for two migrations.
    #[test]
    #[cfg(feature = "enums")]
    fn test_enum_fn() {
        let rendered = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
        let expected = [
            "# [repr (i32)] # [derive (Debug)] ",
            "pub enum EmbeddedMigration { ",
            "Foo (Migration) = 1i32 , ",
            "BarBaz (Migration) = 3i32 ",
            "} ",
            "impl From < Migration > for EmbeddedMigration { ",
            "fn from (migration : Migration) -> Self { ",
            "match migration . version () as i32 { ",
            "1i32 => Self :: Foo (migration) , ",
            "3i32 => Self :: BarBaz (migration) , ",
            "v => panic ! (\"Invalid migration version '{}'\" , v) ",
            "} } }",
        ]
        .concat();
        assert_eq!(expected, rendered);
    }

    // Pins the exact token rendering of the generated `runner()` function.
    #[test]
    fn test_quote_fn() {
        let entries = vec![quote!("V1__first", "valid_sql_file")];
        let expected = [
            "use trailbase_refinery_core :: { Migration , Runner } ; ",
            "pub fn runner () -> Runner { ",
            "let quoted_migrations : Vec < (& str , String) > = vec ! [\"V1__first\" , \"valid_sql_file\"] ; ",
            "let mut migrations : Vec < Migration > = Vec :: new () ; ",
            "for module in quoted_migrations . into_iter () { ",
            "migrations . push (Migration :: unapplied (module . 0 , & module . 1) . unwrap ()) ; ",
            "} ",
            "Runner :: new (& migrations) }",
        ]
        .concat();
        assert_eq!(expected, migration_fn_quoted(entries).to_string());
    }
}
178}