1use std::{io::Write, path::Path};
26#[cfg(feature = "migrations")]
27use {
28 easy_sql_compilation_data::{TableData, TableDataVersion},
29 quote::quote,
30 std::collections::{HashMap, hash_map::Entry},
31 syn::LitInt,
32};
33
34use ::{
35 anyhow::{self, Context},
36 proc_macro2::LineColumn,
37 quote::ToTokens,
38 syn::{self, LitStr, punctuated::Punctuated},
39};
40
41#[cfg(any(feature = "check_duplicate_table_names", feature = "migrations"))]
42use ::{
43 convert_case::{Case, Casing},
44 syn::spanned::Spanned,
45};
46use easy_macros::{all_syntax_cases, always_context, context, get_attributes, has_attributes};
47use easy_sql_compilation_data::CompilationData;
48#[cfg(feature = "check_duplicate_table_names")]
49use {easy_sql_compilation_data::TableNameData, std::path::PathBuf};
50
/// Mutable state shared by every step of the source scan.
#[derive(Debug)]
struct SearchData {
    /// Set when a source file fails to parse; `build_result` then skips pruning of stored
    /// tables.
    errors_found: bool,
    /// Reset to `false` before each top-level item is visited.
    // NOTE(review): nothing visible in this file sets it to `true` or reads it — confirm it is
    // still needed.
    found: bool,

    /// Unique ids generated while scanning the current file, each paired with the position of
    /// the struct's `struct` token; `handle_file` drains them to insert the attributes.
    created_unique_ids: Vec<(String, LineColumn)>,
    /// Compilation data loaded at the start of the build and saved at its end.
    compilation_data: CompilationData,
    /// Every `unique_id` seen on a table struct during the scan (explicit or generated).
    found_existing_tables_ids: Vec<String>,
    /// Set when ids were generated, a table version was stored or bumped, or stale tables
    /// were pruned.
    tables_updated: bool,
    /// Directory that scanned file paths are made relative to in `handle_file`.
    #[cfg(feature = "check_duplicate_table_names")]
    base_dir: PathBuf,
    /// Path of the file currently being scanned (relative to `base_dir` when it is a prefix);
    /// `None` between files.
    #[cfg(feature = "check_duplicate_table_names")]
    current_file_relative: Option<String>,
    /// Errors collected from table structs of the file currently being scanned.
    unsorted_errors: Vec<anyhow::Error>,
    /// Collected errors grouped per file as `(displayed file path, errors)`.
    file_matched_errors: Vec<(String, Vec<anyhow::Error>)>,
}
72
73impl SearchData {
74 #[cfg(feature = "check_duplicate_table_names")]
75 fn new(compilation_data: CompilationData, base_dir: PathBuf) -> Self {
76 SearchData {
77 errors_found: false,
78 found: false,
79 created_unique_ids: Vec::new(),
80 compilation_data,
81 found_existing_tables_ids: Vec::new(),
82 tables_updated: false,
83 #[cfg(feature = "check_duplicate_table_names")]
84 base_dir,
85 #[cfg(feature = "check_duplicate_table_names")]
86 current_file_relative: None,
87 unsorted_errors: Vec::new(),
88 file_matched_errors: Vec::new(),
89 }
90 }
91
92 #[cfg(not(feature = "check_duplicate_table_names"))]
93 fn new(compilation_data: CompilationData) -> Self {
94 SearchData {
95 errors_found: false,
96 found: false,
97 created_unique_ids: Vec::new(),
98 compilation_data,
99 found_existing_tables_ids: Vec::new(),
100 tables_updated: false,
101 unsorted_errors: Vec::new(),
102 file_matched_errors: Vec::new(),
103 }
104 }
105}
106
// Generates the syntax-tree walker functions used by this file, named with the
// `macro_search` prefix (`handle_item` calls `macro_search_item_handle`). Each generated
// function takes a `&mut SearchData` as additional input.
// NOTE(review): the `default_cases` entry presumably makes the walker call
// `struct_table_handle_wrapper` for every struct item it reaches — confirm against the
// `all_syntax_cases!` documentation.
all_syntax_cases! {
    setup=>{
        generated_fn_prefix:"macro_search",
        additional_input_type:&mut SearchData
    }
    default_cases=>{
        fn struct_table_handle_wrapper(item: &mut syn::ItemStruct, context_info: &mut SearchData);
    }
    special_cases=>{
    }
}
118
/// The comma-separated paths found inside a `#[derive(...)]` attribute.
struct DeriveInsides {
    /// Derived paths in source order, e.g. `Debug`, `easy_sql::Table`.
    list: Punctuated<syn::Path, syn::Token![,]>,
}
122
123impl syn::parse::Parse for DeriveInsides {
124 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
125 let list = Punctuated::<syn::Path, syn::Token![,]>::parse_terminated(input)?;
126 Ok(DeriveInsides { list })
127 }
128}
129
130#[always_context]
132fn struct_table_handle(
133 item: &mut syn::ItemStruct,
134 context_info: &mut SearchData,
135) -> anyhow::Result<()> {
136 if let Some(attr) = get_attributes!(item, #[derive(__unknown__)])
137 .into_iter()
138 .next()
139 {
140 let parsed = match syn::parse2::<DeriveInsides>(attr) {
141 Ok(parsed) => parsed.list,
142 Err(_) => {
143 return Ok(());
145 }
146 };
147 let mut is_sql_table = false;
148 for path in parsed.iter() {
149 let path_str = path
150 .to_token_stream()
151 .to_string()
152 .replace(|c: char| c.is_whitespace(), "");
153 match path_str.as_str() {
154 "Table" | "easy_sql::Table" | "TableDebug" | "easy_sql::TableDebug" => {
155 is_sql_table = true;
156 }
157 _ => {}
158 }
159 }
160 if !is_sql_table {
161 return Ok(());
163 }
164 } else {
165 return Ok(());
167 }
168
169 let _no_version = has_attributes!(item, #[sql(no_version)]);
171
172 #[cfg(feature = "migrations")]
173 let has_version_attr = !get_attributes!(item, #[sql(version = __unknown__)]).is_empty();
174
175 #[cfg(feature = "migrations")]
177 let skip_migrations = _no_version || !has_version_attr;
178
179 #[cfg(feature = "migrations")]
180 let mut version_test: Option<LitInt> = None;
181
182 #[cfg(feature = "migrations")]
183 for attr_data in get_attributes!(item, #[sql(version_test = __unknown__)]) {
184 if version_test.is_some() {
185 anyhow::bail!("Only one version_test attribute is allowed");
186 }
187
188 let parsed: LitInt =
189 syn::parse2(attr_data.clone()).context("Expected version_test to be an integer")?;
190 version_test = Some(parsed);
191 }
192
193 let mut unique_id = None;
195 for attr_data in get_attributes!(item, #[sql(unique_id = __unknown__)]) {
196 if unique_id.is_some() {
197 anyhow::bail!(
199 "Multiple unique_id attributes found, struct: {}",
200 item.to_token_stream()
201 );
202 }
203 let lit_str: LitStr = syn::parse2(attr_data.clone())?;
204 unique_id = Some(lit_str.value());
205 }
206 #[cfg(feature = "migrations")]
207 if version_test.is_some() && unique_id.is_none() {
208 anyhow::bail!("#[sql(unique_id = ...)] is required when using #[sql(version_test = ...)]");
209 }
210 #[cfg(feature = "migrations")]
212 let newly_created =
213 if unique_id.is_none() && version_test.is_none() && !_no_version && has_version_attr {
214 let generated = context_info.compilation_data.generate_unique_id();
216 context_info
217 .created_unique_ids
218 .push((generated.clone(), item.struct_token.span().start()));
219
220 unique_id = Some(generated);
221 true
222 } else {
223 false
224 };
225
226 if let Some(unique_id) = unique_id.clone() {
227 context_info
228 .found_existing_tables_ids
229 .push(unique_id.clone());
230 }
231
232 match &item.fields {
233 syn::Fields::Named(_) => {}
234 _ => {
235 anyhow::bail!("non named fields, struct: {}", item.to_token_stream());
237 }
238 }
239 #[cfg(any(feature = "check_duplicate_table_names", feature = "migrations"))]
240 let mut table_name = item.ident.to_string().to_case(Case::Snake);
241 #[cfg(any(feature = "check_duplicate_table_names", feature = "migrations"))]
243 if let Some(attr_data) = get_attributes!(item, #[sql(table_name = __unknown__)]).first() {
244 let lit_str: LitStr = syn::parse2(attr_data.clone())?;
245 table_name = lit_str.value();
246 }
247
248 #[cfg(feature = "check_duplicate_table_names")]
249 {
250 #[cfg(feature = "migrations")]
251 let is_version_test = version_test.is_some();
252 #[cfg(not(feature = "migrations"))]
253 let is_version_test = false;
254
255 if !is_version_test {
256 let file_name = context_info
257 .current_file_relative
258 .clone()
259 .unwrap_or_else(|| "<unknown file>".to_string());
260
261 context_info
262 .compilation_data
263 .used_table_names
264 .entry(table_name.clone())
265 .or_insert_with(Vec::new)
266 .push(TableNameData {
267 filename: file_name,
268 struct_name: item.ident.to_string(),
269 });
270 }
271 }
272
273 #[cfg(feature = "migrations")]
274 {
275 if skip_migrations {
276 return Ok(());
277 }
278
279 let unique_id = unique_id.unwrap();
280
281 let mut version = None;
283 for attr_data in get_attributes!(item, #[sql(version = __unknown__)]) {
284 let lit_int: LitInt = syn::parse2(attr_data.clone())?;
285 version = Some(lit_int.base10_parse::<i64>()?);
286 }
287
288 if version_test.is_some() && version.is_some() {
289 anyhow::bail!(
290 "#[sql(version_test = ...)] replaces #[sql(version = ...)] and they cannot be used together"
291 );
292 }
293
294 let version = match (version, version_test.as_ref()) {
295 (Some(version), None) => Some(version),
296 (None, Some(version_test)) => Some(version_test.base10_parse::<i64>()?),
297 (None, None) => None,
298 (Some(_), Some(_)) => None,
299 };
300
301 #[no_context]
303 let version = version.context("Version attribute should exist")?;
304
305 let version_data = TableDataVersion::from_struct(item, table_name.clone())?;
307
308 let is_version_test = version_test.is_some();
309
310 if !newly_created && !is_version_test {
312 context_info
313 .compilation_data
314 .generate_migrations(&unique_id, &version_data, version, "e! {}, "e! {})
315 .with_context(|| {
316 format!("Compilation data: {:?}", context_info.compilation_data)
317 })?;
318 }
319
320 match context_info.compilation_data.tables.entry(unique_id) {
321 Entry::Occupied(occupied_entry) => {
322 let table_data = occupied_entry.into_mut();
323 if let Some(existing) = table_data.saved_versions.get(&version) {
324 if existing != &version_data {
325 anyhow::bail!(
326 "Version data mismatch for version {} in compilation data",
327 version
328 );
329 }
330 } else {
331 table_data.saved_versions.insert(version, version_data);
332 context_info.tables_updated = true;
333 }
334
335 if table_data.latest_version < version {
336 table_data.latest_version = version;
337 context_info.tables_updated = true;
338 }
339 }
340 Entry::Vacant(vacant_entry) => {
341 let mut saved_versions = HashMap::new();
342 saved_versions.insert(version, version_data);
343
344 let table_data = TableData {
345 latest_version: version,
346 saved_versions,
347 };
348
349 vacant_entry.insert(table_data);
350 }
351 }
352 }
353
354 Ok(())
355}
356
357fn struct_table_handle_wrapper(item: &mut syn::ItemStruct, context_info: &mut SearchData) {
358 match struct_table_handle(item, context_info) {
359 Ok(_) => {}
360 Err(err) => {
361 context_info.unsorted_errors.push(err);
362 }
363 }
364}
365
/// Visits one top-level item with the generated `macro_search_item_handle` walker, passing
/// `updates` along as the walker's additional input.
///
/// Always returns `Ok(())`.
// NOTE(review): failures of individual table structs are presumably collected into
// `updates.unsorted_errors` by `struct_table_handle_wrapper` rather than returned here —
// confirm against the walker generated by `all_syntax_cases!`.
#[always_context]
fn handle_item(item: &mut syn::Item, updates: &mut SearchData) -> anyhow::Result<()> {
    macro_search_item_handle(item, updates);
    Ok(())
}
371#[always_context]
374fn line_pos(haystack: &str, line: usize) -> anyhow::Result<usize> {
375 let mut regex_str = "^".to_string();
376 for _ in 0..line {
377 regex_str.push_str(r".*((\r\n)|\r|\n)");
378 }
379 let regex = regex::Regex::new(®ex_str)?;
380
381 let found = regex
382 .find_at(haystack, 0)
383 .with_context(context!("Finding line failed! | Regex: {:?}", regex))?;
384
385 Ok(found.end())
386}
387
388#[always_context]
389fn handle_file(file_path: impl AsRef<Path>, search_data: &mut SearchData) -> anyhow::Result<()> {
390 let file_path = file_path.as_ref();
391 match file_path.extension() {
393 Some(ext) if ext == "rs" => {}
394 _ => return Ok(()),
395 }
396
397 #[cfg(feature = "check_duplicate_table_names")]
398 {
399 let file_relative = file_path
400 .strip_prefix(&search_data.base_dir)
401 .unwrap_or(file_path)
402 .to_string_lossy()
403 .to_string();
404 search_data.current_file_relative = Some(file_relative);
405 }
406
407 let mut contents = std::fs::read_to_string(file_path)?;
409 let file = match syn::parse_file(&contents) {
411 Ok(file) => file,
412 Err(_) => {
413 search_data.errors_found = true;
415 return Ok(());
417 }
418 };
419
420 for mut item in file.items.into_iter() {
421 search_data.found = false;
422 handle_item(
423 #[context(tokens)]
424 &mut item,
425 search_data,
426 )?;
427 }
428
429 if !search_data.created_unique_ids.is_empty() {
431 search_data.tables_updated = true;
432
433 let mut updates = search_data.created_unique_ids.drain(..).collect::<Vec<_>>();
434 updates.sort_by(|a, b| b.1.line.cmp(&a.1.line));
436
437 for (unique_id, start_pos) in updates.into_iter() {
439 let line = start_pos.line;
441 let line_bytes_end = line_pos(&contents, line - 1)?;
443
444 contents.insert_str(
445 line_bytes_end,
446 &format!("#[sql(unique_id = \"{}\")]\r\n", unique_id),
447 );
448 }
449
450 let mut file = std::fs::File::create(file_path).unwrap();
451 file.write_all(contents.as_bytes()).unwrap();
452 }
453
454 if !search_data.unsorted_errors.is_empty() {
456 search_data.file_matched_errors.push((
457 file_path.display().to_string(),
458 search_data.unsorted_errors.drain(..).collect(),
459 ));
460 }
461
462 #[cfg(feature = "check_duplicate_table_names")]
463 {
464 search_data.current_file_relative = None;
465 }
466
467 Ok(())
468}
469
470#[always_context]
471fn handle_dir(
472 dir: impl AsRef<Path>,
473 ignore_list: &[regex::Regex],
474 base_path_len_bytes: usize,
475 search_data: &mut SearchData,
476) -> anyhow::Result<()> {
477 let files = std::fs::read_dir(dir.as_ref())?;
479 'entries: for entry in files {
481 #[no_context_inputs]
482 let entry = entry.context("Directory Entry")?;
483
484 let entry_path = entry.path();
486
487 for r in ignore_list.iter() {
489 let path_str = entry_path.display().to_string();
490
491 if r.is_match(&path_str) {
492 continue 'entries;
494 }
495 }
496
497 let file_type = entry.file_type()?;
498 if file_type.is_file() {
499 handle_file(&entry_path, search_data)?;
500 } else if file_type.is_dir() {
501 handle_dir(&entry_path, ignore_list, base_path_len_bytes, search_data)?;
503 }
504 }
505
506 Ok(())
507}
508
509#[always_context]
510fn build_result(ignore_list: &[regex::Regex], default_drivers: &[&str]) -> anyhow::Result<()> {
515 let current_dir = std::env::current_dir()?;
517 let base_path_len_bytes = current_dir.display().to_string().len();
518 let src_dir = current_dir.join("src");
520
521 let default_drivers_mapped = default_drivers
522 .iter()
523 .map(|s| s.to_string())
524 .collect::<Vec<_>>();
525
526 #[cfg(feature = "check_duplicate_table_names")]
527 let mut search_data = SearchData::new(
528 CompilationData::load(default_drivers_mapped.clone(), true)?,
529 current_dir.clone(),
530 );
531
532 #[cfg(not(feature = "check_duplicate_table_names"))]
533 let mut search_data =
534 SearchData::new(CompilationData::load(default_drivers_mapped.clone(), true)?);
535
536 #[cfg(feature = "check_duplicate_table_names")]
537 {
538 search_data.compilation_data.used_table_names.clear();
539 }
540
541 handle_dir(&src_dir, ignore_list, base_path_len_bytes, &mut search_data)?;
542
543 if !search_data.file_matched_errors.is_empty() {
545 let log_folder = current_dir.join("easy_sql_logs");
546 if !log_folder.exists() {
547 let result = std::fs::create_dir_all(&log_folder);
548 if let Err(e) = &result
549 && let std::io::ErrorKind::ReadOnlyFilesystem = e.kind()
550 {
551 return Ok(());
553 }
554 result.context("Creating easy_sql_logs folder failed")?;
555 }
556 let current_date = chrono::Utc::now();
557 let log_file = log_folder.join(format!("{}.txt", current_date.format("%Y-%m-%d")));
558
559 let errors = search_data
560 .file_matched_errors
561 .iter()
562 .map(|(file_path, errors)| {
563 let mut error_str =
564 format!("==========\r\nFile: {}\r\n==========\r\n\r\n", file_path);
565 for err in errors.iter() {
566 error_str.push_str(&format!("{:?}\r\n\r\n", err));
567 }
568 error_str
569 })
570 .collect::<Vec<_>>()
571 .join("\n");
572
573 let log_header = format!(
574 "==================\r\n[[[{} - Build Log]]]\r\n==================\r\n\r\n{}\r\n\r\n",
575 current_date.format("%H:%M:%S"),
576 errors
577 );
578 let mut log_file = std::fs::OpenOptions::new()
579 .create(true)
580 .append(true)
581 .open(&log_file)?;
582 log_file.write_all(log_header.as_bytes())?;
583 }
584
585 if !search_data.errors_found
588 && search_data.compilation_data.tables.len() != search_data.found_existing_tables_ids.len()
589 {
590 search_data.tables_updated = true;
591
592 search_data.compilation_data.tables.retain(|key, _| {
593 if search_data.found_existing_tables_ids.contains(key) {
594 return true;
595 }
596 false
598 });
599 }
600
601 search_data.compilation_data.save()?;
603
604 Ok(())
605}
606pub fn build(ignore_list: &[regex::Regex], default_drivers: &[&str]) {
639 if let Err(err) = build_result(ignore_list, default_drivers) {
640 panic!(
641 "Always Context Build Error: {}\r\n\r\nDebug Info:\r\n\r\n{:?}",
642 err, err
643 );
644 }
645}