use std::env;
use std::num::ParseIntError;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::LitStr;

use db_up_sql_changelog::ChangelogFile;
12
/// A migration file discovered in the migrations directory, decoded from its
/// `V<version>_<name>.sql` file name.
#[derive(Clone, Debug)]
struct MigrationInfo {
    /// Numeric version taken from the digits between the leading `V` and the first `_`.
    version: u32,
    /// The file name itself, relative to the migrations directory.
    filename: String,
    /// Descriptive part of the file name between the first `_` and the `.sql` suffix.
    name: String,
}
20
21#[proc_macro_attribute]
40pub fn migrations(args: TokenStream, input: TokenStream) -> TokenStream {
41 let input_clone = input.clone();
45 let input_struct = syn::parse_macro_input!(input_clone as syn::ItemStruct);
46 let path = if args.is_empty() {
49 map_to_crate_root(None)
50 } else {
51 let migrations_path = syn::parse_macro_input!(args as LitStr).value();
52 map_to_crate_root(Some(migrations_path.as_str()))
53 };
54 println!("migrations path: {:?}", path);
55
56 let migrations = get_migrations(&path)
57 .expect("Error while gathering migration file information.");
58 println!("migrations: {:?}", &migrations);
59
60 let migration_tokens: Vec<TokenStream2> = migrations.iter()
61 .map(|migration| {
62 let name = migration.name.as_str();
63 let version = migration.version;
64 let filename = migration.filename.as_str();
65 let file_path = path.clone().join(filename).display().to_string();
66 let content = std::fs::read_to_string(file_path.as_str())
67 .expect(format!("Could not read migration file: {}", file_path).as_str());
68
69 let _changelog = ChangelogFile::from_string(version.to_string().as_str(), content.as_str())
71 .expect(format!("Migration file is not a valid SQL changelog file: {}", file_path).as_str());
72
73 quote! {
74 (#version, #name.to_string(), #content)
75 }
76 })
77 .collect();
78
79 let struct_name = syn::Ident::new(input_struct.ident.to_string().as_str(), Span::call_site());
80 let result = quote! {
82 impl db_up::MigrationStore for #struct_name {
83 fn changelogs(&self) -> Vec<db_up::ChangelogFile> {
84 use db_up::ChangelogFile;
85
86 let mut result: Vec<ChangelogFile> = [#(#migration_tokens),*].iter()
87 .map(|migration| {
88 ChangelogFile::from_string(migration.0.to_string().as_str(), migration.2).unwrap()
89 })
90 .collect();
91 return result;
92 }
93 }
94 };
95 let input: TokenStream2 = input.into();
98 return quote! {
99 #input
100 #result
101 }.into();
102}
103
/// Resolves `path` against the crate root taken from `CARGO_MANIFEST_DIR`.
///
/// With `None` the crate root itself is returned. A `Some` path is appended with
/// [`PathBuf::join`], so an absolute path replaces the root entirely.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset or not valid Unicode.
fn map_to_crate_root(path: Option<&str>) -> PathBuf {
    let crate_root = env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("Missing CARGO_MANIFEST_DIR environment variable. Cannot obtain crate root.");
    match path {
        Some(relative) => {
            let relative = PathBuf::from_str(relative).expect("Could not parse filename.");
            crate_root.join(relative)
        }
        None => crate_root,
    }
}
115
116fn get_migrations(path: &PathBuf) -> Result<Vec<MigrationInfo>, std::io::Error> {
118 let result: Vec<MigrationInfo> = std::fs::read_dir(path)?
119 .filter(|entry| entry.is_ok())
120 .map(|entry| entry.unwrap().file_name().to_str().map(|v| v.to_string()))
121 .filter(|filename| filename.is_some())
122 .map(|filename| filename.unwrap())
123 .filter(|filename| filename.starts_with("V") && filename.ends_with(".sql"))
124 .map(|filename| {
125 let index = filename.find("_");
126 let mut version = "";
127 let mut name = "";
128 if let Some(index) = index {
129 if index > 1 && index < filename.len() - "V.sql".len() {
130 if filename[1..index].chars().all(|ch| ch >= '0' && ch <= '9') {
131 version = &filename[1..index];
132 name = &filename[(index + 1)..(filename.len() - ".sql".len())];
133 }
134 }
135 }
136
137 return if version.is_empty() {
138 None
139 } else {
140 let result: Result<Option<u32>, ParseIntError> = version.parse::<u32>()
141 .map(|version| Some(version))
142 .or(Ok(None));
143
144 let result = result.unwrap()
145 .map(|version| {
146 MigrationInfo {
147 version,
148 filename: filename.to_string(),
149 name: name.to_string()
150 }
151 });
152 return result
153 };
154 })
155 .filter(|info| info.is_some())
156 .map(|info| info.unwrap())
157 .collect();
158 return Ok(result);
159}
160
#[cfg(test)]
mod test {
    use crate::{get_migrations, map_to_crate_root};

    /// The bundled `examples/migrations` directory is expected to yield exactly two
    /// migrations.
    #[test]
    pub fn test_get_migrations() {
        let path = map_to_crate_root(Some("examples/migrations"));
        let migrations = get_migrations(&path)
            .unwrap_or_else(|err| panic!("Migration loading failed: {}", err));
        assert_eq!(migrations.len(), 2, "Two migrations have been successfully loaded.");
    }
}