1use proc_macro::TokenStream;
2use std::env;
3use std::num::ParseIntError;
4use std::path::PathBuf;
5use std::str::FromStr;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{LitStr};
9use syn::__private::TokenStream2;
10
11use flyway_sql_changelog::ChangelogFile;
12
/// A migration file discovered on disk whose name has the form `V<version>_<name>.sql`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct MigrationInfo {
    /// Numeric version: the digits between the leading `V` and the first `_`.
    version: u64,
    /// Bare file name (no directory component), e.g. `V1_init.sql`.
    filename: String,
    /// The part of the file name between the first `_` and the `.sql` suffix.
    name: String,
}
20
21#[proc_macro_attribute]
40pub fn migrations(args: TokenStream, input: TokenStream) -> TokenStream {
41 let input_clone = input.clone();
45 let input_struct = syn::parse_macro_input!(input_clone as syn::ItemStruct);
46 let path = if args.is_empty() {
49 map_to_crate_root(None)
50 } else {
51 let migrations_path = syn::parse_macro_input!(args as LitStr).value();
52 map_to_crate_root(Some(migrations_path.as_str()))
53 };
54
55 #[cfg(feature = "debug_mode")]
56 if cfg!(debug_assertions){
57 println!("migrations path: {:?}", path);
58 }
59
60
61 let migrations = get_migrations(&path)
62 .expect("Error while gathering migration file information.");
63 #[cfg(feature = "debug_mode")]
64 if cfg!(debug_assertions){
65 println!("migrations: {:?}", &migrations);
66 }
67
68 let migration_tokens: Vec<TokenStream2> = migrations.iter()
69 .map(|migration| {
70 let name = migration.name.as_str();
71 let version = migration.version;
72 let filename = migration.filename.as_str();
73 let file_path = path.clone().join(filename).display().to_string();
74 let content = std::fs::read_to_string(file_path.as_str())
75 .expect(format!("Could not read migration file: {}", file_path).as_str());
76
77 let _changelog = ChangelogFile::from_string(version, name,content.as_str())
79 .expect(format!("Migration file is not a valid SQL changelog file: {}", file_path).as_str());
80
81 quote! {
82 (#version, #name.to_string(), #content)
83 }
84 })
85 .collect();
86
87 let struct_name = syn::Ident::new(input_struct.ident.to_string().as_str(), Span::call_site());
88 let result = quote! {
90 impl flyway::MigrationStore for #struct_name {
91 fn changelogs(&self) -> Vec<flyway::ChangelogFile> {
92 use flyway::ChangelogFile;
93
94 let mut result: Vec<ChangelogFile> = [#(#migration_tokens),*].iter()
95 .map(|migration| {
96 ChangelogFile::from_string(migration.0,migration.1.to_string().as_str(), migration.2).unwrap()
97 })
98 .collect();
99 return result;
100 }
101 }
102 };
103 let input: TokenStream2 = input.into();
106 return quote! {
107 #input
108 #result
109 }.into();
110}
111
/// Resolves `path` against the root directory of the crate currently being compiled.
///
/// The root is taken from the `CARGO_MANIFEST_DIR` environment variable. With `None` the
/// root itself is returned; with `Some(path)` the result is `root.join(path)`, so an
/// absolute `path` replaces the root, following [`PathBuf::join`] semantics.
///
/// # Panics
/// Panics if `CARGO_MANIFEST_DIR` is unset or not valid Unicode.
fn map_to_crate_root(path: Option<&str>) -> PathBuf {
    let root = env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("Missing CARGO_MANIFEST_DIR environment variable. Cannot obtain crate root.");
    match path {
        Some(path) => root.join(path),
        None => root,
    }
}
123
124fn get_migrations(path: &PathBuf) -> Result<Vec<MigrationInfo>, std::io::Error> {
126 let result: Vec<MigrationInfo> = std::fs::read_dir(path)?
127 .filter(|entry| entry.is_ok())
128 .map(|entry| entry.unwrap().file_name().to_str().map(|v| v.to_string()))
129 .filter(|filename| filename.is_some())
130 .map(|filename| filename.unwrap())
131 .filter(|filename| filename.starts_with("V") && filename.ends_with(".sql"))
132 .map(|filename| {
133 let index = filename.find("_");
134 let mut version = "";
135 let mut name = "";
136 if let Some(index) = index {
137 if index > 1 && index < filename.len() - "V.sql".len() {
138 if filename[1..index].chars().all(|ch| ch >= '0' && ch <= '9') {
139 version = &filename[1..index];
140 name = &filename[(index + 1)..(filename.len() - ".sql".len())];
141 }
142 }
143 }
144
145 return if version.is_empty() {
146 None
147 } else {
148 let result: Result<Option<u64>, ParseIntError> = version.parse::<u64>()
149 .map(|version| Some(version))
150 .or(Ok(None));
151
152 let result = result.unwrap()
153 .map(|version| {
154 MigrationInfo {
155 version,
156 filename: filename.to_string(),
157 name: name.to_string()
158 }
159 });
160 return result
161 };
162 })
163 .filter(|info| info.is_some())
164 .map(|info| info.unwrap())
165 .collect();
166 return Ok(result);
167}
168
#[cfg(test)]
mod test {
    /// Loads the example migrations shipped with the crate and checks that exactly two are found.
    #[test]
    pub fn test_get_migrations() {
        let path = crate::map_to_crate_root(Some("examples/migrations"));
        let migrations = crate::get_migrations(&path)
            .unwrap_or_else(|err| panic!("Migration loading failed: {}", err));
        assert_eq!(
            migrations.len(),
            2,
            "Expected exactly two migrations in {:?}.",
            path
        );
    }
}