1#![recursion_limit = "128"]
3
4#[cfg(feature = "enums")]
5use heck::ToUpperCamelCase;
6use proc_macro::TokenStream;
7use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
8use quote::quote;
9use quote::ToTokens;
10use refinery_core::{find_migration_files, MigrationType};
11use std::path::PathBuf;
12use std::{env, fs};
13use syn::{parse_macro_input, Ident, LitStr};
14
/// Returns the root directory of the crate being compiled, as given by `CARGO_MANIFEST_DIR`.
///
/// `env::var_os` is used rather than `env::var` so that a crate located under a
/// non-UTF-8 path still resolves; `env::var` reports such a value as an error, which
/// would trigger the "not present" panic below with a misleading message.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is not set.
pub(crate) fn crate_root() -> PathBuf {
    env::var_os("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("CARGO_MANIFEST_DIR environment variable not present")
}
20
21fn migration_fn_quoted<T: ToTokens>(_migrations: Vec<T>) -> TokenStream2 {
22 let result = quote! {
23 use refinery::{Migration, Runner, SchemaVersion};
24 pub fn runner() -> Runner {
25 let quoted_migrations: Vec<(&str, String)> = vec![#(#_migrations),*];
26 let mut migrations: Vec<Migration> = Vec::new();
27 for module in quoted_migrations.into_iter() {
28 migrations.push(Migration::unapplied(module.0, &module.1).unwrap());
29 }
30 Runner::new(&migrations)
31 }
32 };
33 result
34}
35
/// Generates the `EmbeddedMigration` enum and its `From<Migration>` impl (`enums` feature).
///
/// Each entry of `migration_names` is a migration file stem (e.g. `V1__create_users`).
/// It is parsed with `refinery_core::parse_migration_name`, and its name part, converted to
/// UpperCamelCase, becomes a variant that carries the `Migration` and whose discriminant is
/// the migration's version. The `From` impl maps a `Migration` back to its variant by
/// version and panics on a version it does not know.
///
/// The enum is `#[repr(i64)]` with the `int8-versions` feature and `#[repr(i32)]`
/// otherwise, because explicit discriminants on variants with fields require a primitive
/// representation. When there are no migrations the `repr` attribute is omitted: Rust
/// rejects a primitive representation on a zero-variant enum (error E0084), so emitting it
/// would break the build of any crate with an empty migrations directory.
///
/// # Panics
///
/// Panics if a name cannot be parsed as a migration file name, or (inside `Ident::new`)
/// if its camel-cased form is not a valid Rust identifier.
#[cfg(feature = "enums")]
fn migration_enum_quoted(migration_names: &[impl AsRef<str>]) -> TokenStream2 {
    use refinery_core::SchemaVersion;

    // The integer type backing the enum; must agree with `SchemaVersion`'s width.
    #[cfg(feature = "int8-versions")]
    let repr_type = quote!(i64);
    #[cfg(not(feature = "int8-versions"))]
    let repr_type = quote!(i32);

    let mut variants = Vec::with_capacity(migration_names.len());
    // One arm per migration plus the trailing catch-all.
    let mut discriminants = Vec::with_capacity(migration_names.len() + 1);

    for m in migration_names {
        let m = m.as_ref();
        let (_, version, name) = refinery_core::parse_migration_name(m)
            .unwrap_or_else(|e| panic!("Couldn't parse migration filename '{}': {:?}", m, e));
        // Type-annotate the version as `SchemaVersion`; its literal suffix in the generated
        // code (`i32`/`i64`) follows from this type.
        let version: SchemaVersion = version;
        let variant = Ident::new(name.to_upper_camel_case().as_str(), Span2::call_site());
        variants.push(quote! { #variant(Migration) = #version });
        discriminants.push(quote! { #version => Self::#variant(migration) });
    }
    discriminants.push(quote! { v => panic!("Invalid migration version '{}'", v) });

    // `#[repr(iN)]` on an enum with no variants is a hard error (E0084), so only emit it
    // when at least one migration exists. Non-empty output is token-identical to before.
    let repr_attr = if variants.is_empty() {
        quote!()
    } else {
        quote!(#[repr(#repr_type)])
    };

    quote! {
        #repr_attr
        #[derive(Debug)]
        pub enum EmbeddedMigration {
            #(#variants),*
        }

        impl From<Migration> for EmbeddedMigration {
            fn from(migration: Migration) -> Self {
                match migration.version() as SchemaVersion {
                    #(#discriminants),*
                }
            }
        }
    }
}
85
86#[proc_macro]
98pub fn embed_migrations(input: TokenStream) -> TokenStream {
99 let location = if input.is_empty() {
100 crate_root().join("migrations")
101 } else {
102 let location: LitStr = parse_macro_input!(input);
103 crate_root().join(location.value())
104 };
105
106 let migration_files =
107 find_migration_files(location, MigrationType::All).expect("error getting migration files");
108
109 let mut migrations_mods = Vec::new();
110 let mut _migrations = Vec::new();
111 let mut migration_filenames = Vec::new();
112
113 for migration in migration_files {
114 let filename = migration
116 .file_stem()
117 .and_then(|file| file.to_os_string().into_string().ok())
118 .unwrap();
119 let path = migration.display().to_string();
120 let extension = migration.extension().unwrap();
121 migration_filenames.push(filename.clone());
122
123 if extension == "sql" {
124 _migrations.push(quote! {(#filename, include_str!(#path).to_string())});
125 } else if extension == "rs" {
126 let rs_content = fs::read_to_string(&path)
127 .unwrap()
128 .parse::<TokenStream2>()
129 .unwrap();
130 let ident = Ident::new(&filename, Span2::call_site());
131 let mig_mod = quote! {pub mod #ident {
132 #rs_content
133 const _RECOMPILE_IF_CHANGED: &str = include_str!(#path);
135 }};
136 _migrations.push(quote! {(#filename, #ident::migration())});
137 migrations_mods.push(mig_mod);
138 }
139 }
140
141 let fnq = migration_fn_quoted(_migrations);
142 #[cfg(feature = "enums")]
143 let enums = migration_enum_quoted(migration_filenames.as_slice());
144 #[cfg(not(feature = "enums"))]
145 let enums = quote!();
146
147 (quote! {
148 pub mod migrations {
149 #(#migrations_mods)*
150 #fnq
151 #enums
152 }
153 })
154 .into()
155}
156
// These tests pin the exact `to_string()` rendering of the generated token streams, so any
// change to the quoted code (even token spacing) must be mirrored in the expected strings.
#[cfg(test)]
mod tests {
    use super::{migration_fn_quoted, quote};

    // With `int8-versions`, versions are rendered as `i64` literals and the enum is
    // `repr(i64)`. NOTE(review): the `_i8` suffix refers to the `int8-versions` feature,
    // not an 8-bit type. `U3__barBAZ` checks that a `U` prefix parses and that the name
    // part is camel-cased (`barBAZ` -> `BarBaz`).
    #[cfg(all(feature = "enums", feature = "int8-versions"))]
    #[test]
    fn test_enum_fn_i8() {
        let expected = concat! {
            "# [repr (i64)] ",
            "# [derive (Debug)] ",
            "pub enum EmbeddedMigration { ",
            "Foo (Migration) = 1i64 , ",
            "BarBaz (Migration) = 3i64 ",
            "} ",
            "impl From < Migration > for EmbeddedMigration { ",
            "fn from (migration : Migration) -> Self { ",
            "match migration . version () as SchemaVersion { ",
            "1i64 => Self :: Foo (migration) , ",
            "3i64 => Self :: BarBaz (migration) , ",
            "v => panic ! (\"Invalid migration version '{}'\" , v) ",
            "} } }"
        };
        let enums = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
        assert_eq!(expected, enums);
    }

    // Same inputs as above, without `int8-versions`: versions render as `i32` literals
    // and the enum is `repr(i32)`.
    #[cfg(all(feature = "enums", not(feature = "int8-versions")))]
    #[test]
    fn test_enum_fn() {
        let expected = concat! {
            "# [repr (i32)] ",
            "# [derive (Debug)] ",
            "pub enum EmbeddedMigration { ",
            "Foo (Migration) = 1i32 , ",
            "BarBaz (Migration) = 3i32 ",
            "} ",
            "impl From < Migration > for EmbeddedMigration { ",
            "fn from (migration : Migration) -> Self { ",
            "match migration . version () as SchemaVersion { ",
            "1i32 => Self :: Foo (migration) , ",
            "3i32 => Self :: BarBaz (migration) , ",
            "v => panic ! (\"Invalid migration version '{}'\" , v) ",
            "} } }"
        };
        let enums = super::migration_enum_quoted(&["V1__foo", "U3__barBAZ"]).to_string();
        assert_eq!(expected, enums);
    }

    // `migration_fn_quoted` only splices each element into `vec![...]`, so the input here
    // is an arbitrary token list (two string literals), not a real `(name, sql)` tuple.
    #[test]
    fn test_quote_fn() {
        let migs = vec![quote!("V1__first", "valid_sql_file")];
        let expected = concat! {
            "use refinery :: { Migration , Runner , SchemaVersion } ; ",
            "pub fn runner () -> Runner { ",
            "let quoted_migrations : Vec < (& str , String) > = vec ! [\"V1__first\" , \"valid_sql_file\"] ; ",
            "let mut migrations : Vec < Migration > = Vec :: new () ; ",
            "for module in quoted_migrations . into_iter () { ",
            "migrations . push (Migration :: unapplied (module . 0 , & module . 1) . unwrap ()) ; ",
            "} ",
            "Runner :: new (& migrations) }"
        };
        assert_eq!(expected, migration_fn_quoted(migs).to_string());
    }
}