// easy_sql_build/lib.rs

1//! Build-time helper for easy-sql compilation data.
2//!
3//! Scans your crate's `src/` directory for [`easy_sql::Table`](https://docs.rs/easy-sql/latest/easy_sql/derive.Table.html)
4//! definitions and keeps `easy_sql.ron` in sync with your schema metadata.
5//!
6//! ## What it does
7//! - Generates missing `#[sql(unique_id = "...")]` attributes for tables.
8//! - Updates migration metadata in `easy_sql.ron` (feature `migrations`).
9//! - Writes build errors to `easy_sql_logs/YYYY-MM-DD.txt` when parsing fails.
10//!
11//! ## Example (build script)
12//! ```rust,ignore
13//! use regex::Regex;
14//!
15//! fn main() {
//!     easy_sql_build::build(
17//!         &[Regex::new(r"example_all\.rs").unwrap()],
18//!         &["crate::Sqlite", "crate::Postgres"],
19//!     );
20//! }
21//! ```
22//!
23//! See [`build`] for details on default drivers and ignore patterns.
24
25use std::{io::Write, path::Path};
26#[cfg(feature = "migrations")]
27use {
28    easy_sql_compilation_data::{TableData, TableDataVersion},
29    quote::quote,
30    std::collections::{HashMap, hash_map::Entry},
31    syn::LitInt,
32};
33
34use ::{
35    anyhow::{self, Context},
36    proc_macro2::LineColumn,
37    quote::ToTokens,
38    syn::{self, LitStr, punctuated::Punctuated},
39};
40
41#[cfg(any(feature = "check_duplicate_table_names", feature = "migrations"))]
42use ::{
43    convert_case::{Case, Casing},
44    syn::spanned::Spanned,
45};
46use easy_macros::{all_syntax_cases, always_context, context, get_attributes, has_attributes};
47use easy_sql_compilation_data::CompilationData;
48#[cfg(feature = "check_duplicate_table_names")]
49use {easy_sql_compilation_data::TableNameData, std::path::PathBuf};
50
#[derive(Debug)]
///Mutable state threaded through the whole `src/` scan.
struct SearchData {
    ///When parsing rust files
    ///`true` once any file failed `syn::parse_file`; suppresses the
    ///"remove deleted tables" step at the end of `build_result`.
    errors_found: bool,
    ///Per-item scratch flag, reset before each item in `handle_file`.
    /// NOTE(review): never read in this file — presumably consumed by the
    /// generated `macro_search_*` functions; confirm.
    found: bool,

    //Table handling
    ///Unique ids generated this run, paired with the span start of the owning
    ///`struct` keyword so `handle_file` can insert the attribute in-place.
    created_unique_ids: Vec<(String, LineColumn)>,
    ///Loaded compilation data; mutated during the scan and saved at the end.
    compilation_data: CompilationData,
    ///Every unique id seen (pre-existing or freshly generated) — compared
    ///against `compilation_data.tables` to detect deleted tables.
    found_existing_tables_ids: Vec<String>,
    //Use also created_unique_ids to check if tables were updated
    tables_updated: bool,
    ///Crate root, used to report file paths relative to it.
    #[cfg(feature = "check_duplicate_table_names")]
    base_dir: PathBuf,
    ///Relative path of the file currently being scanned (set in `handle_file`).
    #[cfg(feature = "check_duplicate_table_names")]
    current_file_relative: Option<String>,
    ///Will be added to logs
    /// Also add to logs when there were no errors
    unsorted_errors: Vec<anyhow::Error>,
    ///Errors grouped by the file that produced them, for the daily log file.
    file_matched_errors: Vec<(String, Vec<anyhow::Error>)>,
}
72
73impl SearchData {
74    #[cfg(feature = "check_duplicate_table_names")]
75    fn new(compilation_data: CompilationData, base_dir: PathBuf) -> Self {
76        SearchData {
77            errors_found: false,
78            found: false,
79            created_unique_ids: Vec::new(),
80            compilation_data,
81            found_existing_tables_ids: Vec::new(),
82            tables_updated: false,
83            #[cfg(feature = "check_duplicate_table_names")]
84            base_dir,
85            #[cfg(feature = "check_duplicate_table_names")]
86            current_file_relative: None,
87            unsorted_errors: Vec::new(),
88            file_matched_errors: Vec::new(),
89        }
90    }
91
92    #[cfg(not(feature = "check_duplicate_table_names"))]
93    fn new(compilation_data: CompilationData) -> Self {
94        SearchData {
95            errors_found: false,
96            found: false,
97            created_unique_ids: Vec::new(),
98            compilation_data,
99            found_existing_tables_ids: Vec::new(),
100            tables_updated: false,
101            unsorted_errors: Vec::new(),
102            file_matched_errors: Vec::new(),
103        }
104    }
105}
106
// Generated dispatcher: `all_syntax_cases!` (from `easy_macros`) expands to
// `macro_search_*` functions (see `macro_search_item_handle` used by
// `handle_item`) that take `&mut SearchData` as the extra argument and route
// every `syn::ItemStruct` to `struct_table_handle_wrapper`.
// NOTE(review): presumably the generated code also walks nested items —
// confirm against `easy_macros` docs.
all_syntax_cases! {
    setup=>{
        generated_fn_prefix:"macro_search",
        additional_input_type:&mut SearchData
    }
    default_cases=>{
        fn struct_table_handle_wrapper(item: &mut syn::ItemStruct, context_info: &mut SearchData);
    }
    special_cases=>{
    }
}
118
/// Parsed interior of a `#[derive(...)]` attribute: the comma-separated
/// list of derive paths (e.g. `Debug, easy_sql::Table, Clone`).
struct DeriveInsides {
    // The derive paths, in source order.
    list: Punctuated<syn::Path, syn::Token![,]>,
}
122
123impl syn::parse::Parse for DeriveInsides {
124    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
125        let list = Punctuated::<syn::Path, syn::Token![,]>::parse_terminated(input)?;
126        Ok(DeriveInsides { list })
127    }
128}
129
///Table handling
///
/// Processes one `struct` item found during the source scan:
/// 1. Returns early unless the struct derives `Table`/`TableDebug`
///    (bare or `easy_sql::`-prefixed).
/// 2. Reads the `#[sql(...)]` helper attributes: `no_version`, `version`,
///    `version_test`, `unique_id`, `table_name`.
/// 3. Generates a missing `unique_id` and queues it in
///    `context_info.created_unique_ids` so `handle_file` can patch the
///    source text afterwards.
/// 4. Feature-gated: records the table name for duplicate detection and
///    updates migration/version metadata in `compilation_data`.
///
/// Errors returned here are collected per-file by the caller instead of
/// aborting the scan (see `struct_table_handle_wrapper`).
#[always_context]
fn struct_table_handle(
    item: &mut syn::ItemStruct,
    context_info: &mut SearchData,
) -> anyhow::Result<()> {
    // Look at the (first) #[derive(...)] attribute to decide whether this
    // struct is an easy-sql table at all.
    if let Some(attr) = get_attributes!(item, #[derive(__unknown__)])
        .into_iter()
        .next()
    {
        let parsed = match syn::parse2::<DeriveInsides>(attr) {
            Ok(parsed) => parsed.list,
            Err(_) => {
                //Ignore invalid attributes, error should be shown by derive macro
                return Ok(());
            }
        };
        let mut is_sql_table = false;
        for path in parsed.iter() {
            // Normalize token-stream spacing, e.g. `easy_sql :: Table`
            // becomes `easy_sql::Table`, before comparing.
            let path_str = path
                .to_token_stream()
                .to_string()
                .replace(|c: char| c.is_whitespace(), "");
            match path_str.as_str() {
                "Table" | "easy_sql::Table" | "TableDebug" | "easy_sql::TableDebug" => {
                    is_sql_table = true;
                }
                _ => {}
            }
        }
        if !is_sql_table {
            //No Sql Table derive
            return Ok(());
        }
    } else {
        //No Sql Table derive
        return Ok(());
    }

    // Check for no_version attribute
    // (underscore-prefixed: only read when the `migrations` feature is on)
    let _no_version = has_attributes!(item, #[sql(no_version)]);

    #[cfg(feature = "migrations")]
    let has_version_attr = !get_attributes!(item, #[sql(version = __unknown__)]).is_empty();

    // Skip migrations if no_version is set
    // NOTE(review): this also skips when only #[sql(version_test = ...)] is
    // present without #[sql(version = ...)] — confirm that is intended.
    #[cfg(feature = "migrations")]
    let skip_migrations = _no_version || !has_version_attr;

    #[cfg(feature = "migrations")]
    let mut version_test: Option<LitInt> = None;

    #[cfg(feature = "migrations")]
    for attr_data in get_attributes!(item, #[sql(version_test = __unknown__)]) {
        if version_test.is_some() {
            anyhow::bail!("Only one version_test attribute is allowed");
        }

        let parsed: LitInt =
            syn::parse2(attr_data.clone()).context("Expected version_test to be an integer")?;
        version_test = Some(parsed);
    }

    //Check is unique_id present
    let mut unique_id = None;
    for attr_data in get_attributes!(item, #[sql(unique_id = __unknown__)]) {
        if unique_id.is_some() {
            //Ignore multiple unique_id attributes, error should be shown by derive macro
            anyhow::bail!(
                "Multiple unique_id attributes found, struct: {}",
                item.to_token_stream()
            );
        }
        let lit_str: LitStr = syn::parse2(attr_data.clone())?;
        unique_id = Some(lit_str.value());
    }
    #[cfg(feature = "migrations")]
    if version_test.is_some() && unique_id.is_none() {
        anyhow::bail!("#[sql(unique_id = ...)] is required when using #[sql(version_test = ...)]");
    }
    //Unique Id
    // Generate one only for "real" versioned tables that don't have it yet;
    // the generated id is queued so the source file gets patched on disk.
    #[cfg(feature = "migrations")]
    let newly_created =
        if unique_id.is_none() && version_test.is_none() && !_no_version && has_version_attr {
            //Create unique_id
            let generated = context_info.compilation_data.generate_unique_id();
            context_info
                .created_unique_ids
                .push((generated.clone(), item.struct_token.span().start()));

            unique_id = Some(generated);
            true
        } else {
            false
        };

    // Record the id (old or new) so deleted tables can be detected later.
    if let Some(unique_id) = unique_id.clone() {
        context_info
            .found_existing_tables_ids
            .push(unique_id.clone());
    }

    match &item.fields {
        syn::Fields::Named(_) => {}
        _ => {
            //Ignore unnamed and unit structs, error should be shown by derive macro, leave debug info
            anyhow::bail!("non named fields, struct: {}", item.to_token_stream());
        }
    }
    // Default table name: snake_case of the struct identifier.
    #[cfg(any(feature = "check_duplicate_table_names", feature = "migrations"))]
    let mut table_name = item.ident.to_string().to_case(Case::Snake);
    //Check if table_name was set manually
    #[cfg(any(feature = "check_duplicate_table_names", feature = "migrations"))]
    if let Some(attr_data) = get_attributes!(item, #[sql(table_name = __unknown__)]).first() {
        let lit_str: LitStr = syn::parse2(attr_data.clone())?;
        table_name = lit_str.value();
    }

    #[cfg(feature = "check_duplicate_table_names")]
    {
        // version_test structs are excluded from duplicate-name bookkeeping.
        #[cfg(feature = "migrations")]
        let is_version_test = version_test.is_some();
        #[cfg(not(feature = "migrations"))]
        let is_version_test = false;

        if !is_version_test {
            let file_name = context_info
                .current_file_relative
                .clone()
                .unwrap_or_else(|| "<unknown file>".to_string());

            context_info
                .compilation_data
                .used_table_names
                .entry(table_name.clone())
                .or_insert_with(Vec::new)
                .push(TableNameData {
                    filename: file_name,
                    struct_name: item.ident.to_string(),
                });
        }
    }

    #[cfg(feature = "migrations")]
    {
        if skip_migrations {
            return Ok(());
        }

        // Safe: a versioned, non-skipped table either already had a
        // unique_id, got one generated above, or bailed earlier
        // (version_test without unique_id).
        let unique_id = unique_id.unwrap();

        //Check if table version has changed
        let mut version = None;
        for attr_data in get_attributes!(item, #[sql(version = __unknown__)]) {
            let lit_int: LitInt = syn::parse2(attr_data.clone())?;
            version = Some(lit_int.base10_parse::<i64>()?);
        }

        if version_test.is_some() && version.is_some() {
            anyhow::bail!(
                "#[sql(version_test = ...)] replaces #[sql(version = ...)] and they cannot be used together"
            );
        }

        // NOTE(review): since `skip_migrations` above requires a version
        // attribute, the (None, Some(..)) arm looks unreachable — confirm.
        let version = match (version, version_test.as_ref()) {
            (Some(version), None) => Some(version),
            (None, Some(version_test)) => Some(version_test.base10_parse::<i64>()?),
            (None, None) => None,
            (Some(_), Some(_)) => None,
        };

        //Version attribute should exist. if it doesn't error by derive macro should be shown
        #[no_context]
        let version = version.context("Version attribute should exist")?;

        //Generate table version data
        let version_data = TableDataVersion::from_struct(item, table_name.clone())?;

        let is_version_test = version_test.is_some();

        //Migration check if data exists before
        if !newly_created && !is_version_test {
            context_info
                .compilation_data
                .generate_migrations(&unique_id, &version_data, version, &quote! {}, &quote! {})
                .with_context(|| {
                    format!("Compilation data: {:?}", context_info.compilation_data)
                })?;
        }

        // Record this version, flagging `tables_updated` when anything new
        // was stored.
        // NOTE(review): the vacant (brand-new table) branch does not set
        // `tables_updated` — confirm whether that flag matters, since
        // `save()` runs unconditionally in `build_result`.
        match context_info.compilation_data.tables.entry(unique_id) {
            Entry::Occupied(occupied_entry) => {
                let table_data = occupied_entry.into_mut();
                if let Some(existing) = table_data.saved_versions.get(&version) {
                    if existing != &version_data {
                        anyhow::bail!(
                            "Version data mismatch for version {} in compilation data",
                            version
                        );
                    }
                } else {
                    table_data.saved_versions.insert(version, version_data);
                    context_info.tables_updated = true;
                }

                if table_data.latest_version < version {
                    table_data.latest_version = version;
                    context_info.tables_updated = true;
                }
            }
            Entry::Vacant(vacant_entry) => {
                let mut saved_versions = HashMap::new();
                saved_versions.insert(version, version_data);

                let table_data = TableData {
                    latest_version: version,
                    saved_versions,
                };

                vacant_entry.insert(table_data);
            }
        }
    }

    Ok(())
}
356
357fn struct_table_handle_wrapper(item: &mut syn::ItemStruct, context_info: &mut SearchData) {
358    match struct_table_handle(item, context_info) {
359        Ok(_) => {}
360        Err(err) => {
361            context_info.unsorted_errors.push(err);
362        }
363    }
364}
365
/// Dispatches one top-level item to the generated `macro_search_*` walker
/// (produced by the `all_syntax_cases!` invocation above), which routes
/// structs to `struct_table_handle_wrapper`.
#[always_context]
fn handle_item(item: &mut syn::Item, updates: &mut SearchData) -> anyhow::Result<()> {
    macro_search_item_handle(item, updates);
    Ok(())
}
371/// # Inputs
372/// `line` - 0 indexed
373#[always_context]
374fn line_pos(haystack: &str, line: usize) -> anyhow::Result<usize> {
375    let mut regex_str = "^".to_string();
376    for _ in 0..line {
377        regex_str.push_str(r".*((\r\n)|\r|\n)");
378    }
379    let regex = regex::Regex::new(&regex_str)?;
380
381    let found = regex
382        .find_at(haystack, 0)
383        .with_context(context!("Finding line failed! | Regex: {:?}", regex))?;
384
385    Ok(found.end())
386}
387
388#[always_context]
389fn handle_file(file_path: impl AsRef<Path>, search_data: &mut SearchData) -> anyhow::Result<()> {
390    let file_path = file_path.as_ref();
391    // Check if the file is a rust file
392    match file_path.extension() {
393        Some(ext) if ext == "rs" => {}
394        _ => return Ok(()),
395    }
396
397    #[cfg(feature = "check_duplicate_table_names")]
398    {
399        let file_relative = file_path
400            .strip_prefix(&search_data.base_dir)
401            .unwrap_or(file_path)
402            .to_string_lossy()
403            .to_string();
404        search_data.current_file_relative = Some(file_relative);
405    }
406
407    // Read the file
408    let mut contents = std::fs::read_to_string(file_path)?;
409    //Operate on syn::File
410    let file = match syn::parse_file(&contents) {
411        Ok(file) => file,
412        Err(_) => {
413            //Don't delete tables if at least one file has errors
414            search_data.errors_found = true;
415            //Ignore files with errors
416            return Ok(());
417        }
418    };
419
420    for mut item in file.items.into_iter() {
421        search_data.found = false;
422        handle_item(
423            #[context(tokens)]
424            &mut item,
425            search_data,
426        )?;
427    }
428
429    //Create unique ids in the file (if needed)
430    if !search_data.created_unique_ids.is_empty() {
431        search_data.tables_updated = true;
432
433        let mut updates = search_data.created_unique_ids.drain(..).collect::<Vec<_>>();
434        //Sort our lines (reverse order)
435        updates.sort_by(|a, b| b.1.line.cmp(&a.1.line));
436
437        //Uses span position info to add #[sql(unique_id = ...)] to every item on the list
438        for (unique_id, start_pos) in updates.into_iter() {
439            //1 indexed
440            let line = start_pos.line;
441            //Find position based on line
442            let line_bytes_end = line_pos(&contents, line - 1)?;
443
444            contents.insert_str(
445                line_bytes_end,
446                &format!("#[sql(unique_id = \"{}\")]\r\n", unique_id),
447            );
448        }
449
450        let mut file = std::fs::File::create(file_path).unwrap();
451        file.write_all(contents.as_bytes()).unwrap();
452    }
453
454    //File match errors
455    if !search_data.unsorted_errors.is_empty() {
456        search_data.file_matched_errors.push((
457            file_path.display().to_string(),
458            search_data.unsorted_errors.drain(..).collect(),
459        ));
460    }
461
462    #[cfg(feature = "check_duplicate_table_names")]
463    {
464        search_data.current_file_relative = None;
465    }
466
467    Ok(())
468}
469
470#[always_context]
471fn handle_dir(
472    dir: impl AsRef<Path>,
473    ignore_list: &[regex::Regex],
474    base_path_len_bytes: usize,
475    search_data: &mut SearchData,
476) -> anyhow::Result<()> {
477    // Get all files in the src directory
478    let files = std::fs::read_dir(dir.as_ref())?;
479    // Iterate over all files
480    'entries: for entry in files {
481        #[no_context_inputs]
482        let entry = entry.context("Directory Entry")?;
483
484        // Get the file path
485        let entry_path = entry.path();
486
487        //Ignore list check
488        for r in ignore_list.iter() {
489            let path_str = entry_path.display().to_string();
490
491            if r.is_match(&path_str) {
492                // Ignore this entry
493                continue 'entries;
494            }
495        }
496
497        let file_type = entry.file_type()?;
498        if file_type.is_file() {
499            handle_file(&entry_path, search_data)?;
500        } else if file_type.is_dir() {
501            // If the file is a directory, call this function recursively
502            handle_dir(&entry_path, ignore_list, base_path_len_bytes, search_data)?;
503        }
504    }
505
506    Ok(())
507}
508
#[always_context]
/// Fallible core of [`build`]: scans `<cwd>/src`, patches generated
/// `unique_id` attributes, writes an error log if anything failed, prunes
/// deleted tables, and saves the compilation data.
///
/// # Arguments
///
/// `ignore_list` - A list of regex patterns to ignore. The patterns are used on the file path. Path is ignored if match found.
///
fn build_result(ignore_list: &[regex::Regex], default_drivers: &[&str]) -> anyhow::Result<()> {
    // Get the current directory
    let current_dir = std::env::current_dir()?;
    // NOTE(review): passed down to handle_dir but never used there — confirm.
    let base_path_len_bytes = current_dir.display().to_string().len();
    // Get the src directory
    let src_dir = current_dir.join("src");

    // Drivers used by query macros when no <Driver> is given at a call site.
    let default_drivers_mapped = default_drivers
        .iter()
        .map(|s| s.to_string())
        .collect::<Vec<_>>();

    #[cfg(feature = "check_duplicate_table_names")]
    let mut search_data = SearchData::new(
        CompilationData::load(default_drivers_mapped.clone(), true)?,
        current_dir.clone(),
    );

    #[cfg(not(feature = "check_duplicate_table_names"))]
    let mut search_data =
        SearchData::new(CompilationData::load(default_drivers_mapped.clone(), true)?);

    #[cfg(feature = "check_duplicate_table_names")]
    {
        // Name usage is rebuilt from scratch on every run.
        search_data.compilation_data.used_table_names.clear();
    }

    handle_dir(&src_dir, ignore_list, base_path_len_bytes, &mut search_data)?;

    //Write into log file (if needed)
    if !search_data.file_matched_errors.is_empty() {
        let log_folder = current_dir.join("easy_sql_logs");
        if !log_folder.exists() {
            let result = std::fs::create_dir_all(&log_folder);
            if let Err(e) = &result
                && let std::io::ErrorKind::ReadOnlyFilesystem = e.kind()
            {
                //If we can't create log folder, just skip logging
                // NOTE(review): this early return also skips the deleted-table
                // cleanup and `save()` below — confirm that is intended.
                return Ok(());
            }
            result.context("Creating easy_sql_logs folder failed")?;
        }
        // One log file per UTC day; entries are appended.
        let current_date = chrono::Utc::now();
        let log_file = log_folder.join(format!("{}.txt", current_date.format("%Y-%m-%d")));

        let errors = search_data
            .file_matched_errors
            .iter()
            .map(|(file_path, errors)| {
                let mut error_str =
                    format!("==========\r\nFile: {}\r\n==========\r\n\r\n", file_path);
                for err in errors.iter() {
                    error_str.push_str(&format!("{:?}\r\n\r\n", err));
                }
                error_str
            })
            .collect::<Vec<_>>()
            .join("\n");

        let log_header = format!(
            "==================\r\n[[[{} - Build Log]]]\r\n==================\r\n\r\n{}\r\n\r\n",
            current_date.format("%H:%M:%S"),
            errors
        );
        let mut log_file = std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&log_file)?;
        log_file.write_all(log_header.as_bytes())?;
    }

    //Remove deleted tables
    // If no errors when parsing rust files were found
    // (length mismatch means at least one saved table was not seen this run)
    if !search_data.errors_found
        && search_data.compilation_data.tables.len() != search_data.found_existing_tables_ids.len()
    {
        search_data.tables_updated = true;

        search_data.compilation_data.tables.retain(|key, _| {
            if search_data.found_existing_tables_ids.contains(key) {
                return true;
            }
            //Table was deleted
            false
        });
    }

    //Update compilation data (if needed)
    search_data.compilation_data.save()?;

    Ok(())
}
606/// Runs the easy-sql build step and panics on failure.
607///
608/// Scans the `src/` directory for table definitions, updates `easy_sql.ron`, and
609/// generates missing `#[sql(unique_id = "...")]` attributes for new tables.
610///
611/// When parsing fails, detailed error logs are written to
612/// `easy_sql_logs/YYYY-MM-DD.txt`.
613///
614/// The `default_drivers` list is used by query macros when no `<Driver>` is
615/// specified at the call site. This mirrors the values you would normally pass
616/// from the build script.
617///
618/// ## Arguments
619///
620/// `ignore_list` - Regex patterns used to skip files or directories by path.
621///
622/// `default_drivers` - Driver paths (for example, `"crate::Sqlite"`).
623///
624/// ## Panics
625/// Panics with detailed diagnostics if parsing fails or compilation data is invalid.
626///
627/// ## Example
628/// ```rust,ignore
629/// use regex::Regex;
630///
631/// fn main() {
///     easy_sql_build::build(
633///         &[Regex::new(r"example_all\.rs").unwrap()],
634///         &["crate::Sqlite"],
635///     );
636/// }
637/// ```
638pub fn build(ignore_list: &[regex::Regex], default_drivers: &[&str]) {
639    if let Err(err) = build_result(ignore_list, default_drivers) {
640        panic!(
641            "Always Context Build Error: {}\r\n\r\nDebug Info:\r\n\r\n{:?}",
642            err, err
643        );
644    }
645}