Skip to main content

easy_sql_compilation_data/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    path::PathBuf,
4    str::FromStr,
5};
6
7use anyhow::{self, Context};
8use quote::ToTokens;
9#[cfg(feature = "migrations")]
10use {
11    easy_macros::TokensBuilder,
12    proc_macro2::{Span, TokenStream},
13    quote::quote,
14};
15
16use easy_macros::{
17    always_context, get_attributes, has_attributes, token_stream_to_consistent_string,
18};
19use serde::{Deserialize, Serialize, Serializer};
20
21#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
22pub struct TableField {
23    pub name: String,
24    #[serde(default)]
25    pub ty_to_bytes: bool,
26    pub field_type: String,
27    ///Tokens converted to_string()
28    pub default: Option<String>,
29    pub is_unique: bool,
30}
31#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
32pub struct TableDataVersion {
33    pub table_name: String,
34    pub fields: Vec<TableField>,
35    pub primary_keys: Vec<String>,
36    pub auto_increment: bool,
37    ///key - table (struct) name
38    ///value - current field name
39    #[serde(serialize_with = "ordered_map")]
40    pub foreign_keys: HashMap<String, Vec<String>>,
41}
42
43fn ordered_map<S, K: Ord + Serialize, V: Serialize>(
44    value: &HashMap<K, V>,
45    serializer: S,
46) -> Result<S::Ok, S::Error>
47where
48    S: Serializer,
49{
50    let ordered: BTreeMap<_, _> = value.iter().collect();
51    ordered.serialize(serializer)
52}
53
54impl TableDataVersion {
55    pub fn from_struct(item: &syn::ItemStruct, table_name: String) -> anyhow::Result<Self> {
56        let fields = match &item.fields {
57            syn::Fields::Named(fields_named) => &fields_named.named,
58            _ => {
59                anyhow::bail!(
60                    "non named field type should be handled before `generate_table_data_from_struct` is called"
61                )
62            }
63        };
64
65        let mut fields_converted = Vec::new();
66        let mut primary_keys = Vec::new();
67        let mut foreign_keys = HashMap::new();
68
69        let mut auto_increment = false;
70
71        for field in fields.iter() {
72            let name = field.ident.as_ref().unwrap().to_string();
73
74            //Auto Increment Check
75            if has_attributes!(field, #[sql(auto_increment)]) {
76                if auto_increment {
77                    anyhow::bail!("Auto increment is only supported for single primary key");
78                }
79                auto_increment = true;
80            }
81
82            let ty_to_bytes = has_attributes!(field, #[sql(bytes)]);
83
84            let default = get_attributes!(field, #[sql(default = __unknown__)])
85                .into_iter()
86                .next()
87                .map(token_stream_to_consistent_string);
88
89            for foreign_key in get_attributes!(field, #[sql(foreign_key = __unknown__)])
90                .into_iter()
91                .map(token_stream_to_consistent_string)
92            {
93                let fields: &mut Vec<String> = foreign_keys
94                    .entry(foreign_key)
95                    .or_insert(Default::default());
96                fields.push(name.clone());
97            }
98
99            if has_attributes!(field, #[sql(primary_key)]) {
100                primary_keys.push(name.clone());
101            }
102
103            let is_unique = has_attributes!(field, #[sql(unique)]);
104
105            fields_converted.push(TableField {
106                name,
107                field_type: token_stream_to_consistent_string(field.ty.to_token_stream()),
108                default,
109                is_unique,
110                ty_to_bytes,
111            });
112        }
113
114        Ok(TableDataVersion {
115            table_name,
116            fields: fields_converted,
117            foreign_keys,
118            primary_keys,
119            auto_increment,
120        })
121    }
122}
123
124#[derive(Debug, Serialize, Deserialize)]
125pub struct TableData {
126    #[serde(serialize_with = "ordered_map")]
127    pub saved_versions: HashMap<i64, TableDataVersion>,
128    pub latest_version: i64,
129}
/// Record of where a table name was declared. Presumably used to report duplicate
/// table names (only compiled with the `check_duplicate_table_names` feature).
#[cfg(feature = "check_duplicate_table_names")]
#[derive(Debug, Deserialize, Serialize)]
pub struct TableNameData {
    /// Presumably the source file containing the declaring struct — TODO confirm.
    pub filename: String,
    /// Name of the Rust struct that declares the table.
    pub struct_name: String,
}
136
137#[derive(Debug, Serialize, Deserialize)]
138pub struct CompilationData {
139    ///Key - table id, generated by build macro, put in #[sql(unique_id = "...")] attribute on struct
140    #[serde(serialize_with = "ordered_map")]
141    pub tables: HashMap<String, TableData>,
142    #[cfg(feature = "check_duplicate_table_names")]
143    #[serde(serialize_with = "ordered_map")]
144    #[serde(default)]
145    ///Key - table name
146    pub used_table_names: HashMap<String, Vec<TableNameData>>,
147    #[serde(default)]
148    pub default_drivers: Vec<String>,
149}
150#[always_context]
151impl CompilationData {
152    pub fn data_location() -> anyhow::Result<PathBuf> {
153        let manifest_dir_str = std::env::var("CARGO_MANIFEST_DIR")?;
154        let current_dir = PathBuf::from_str(&manifest_dir_str)?;
155
156        Ok(current_dir.join("easy_sql.ron"))
157    }
158
159    #[cfg(feature = "build")]
160    pub fn load(
161        default_drivers: Vec<String>,
162        default_drivers_update: bool,
163    ) -> anyhow::Result<CompilationData> {
164        let data_path = Self::data_location()?;
165
166        let data: CompilationData = {
167            if !data_path.exists() {
168                CompilationData {
169                    tables: HashMap::new(),
170                    #[cfg(feature = "check_duplicate_table_names")]
171                    used_table_names: HashMap::new(),
172                    default_drivers,
173                }
174            } else {
175                let data = std::fs::read_to_string(&data_path)
176                    .context("Failed to read easy_sql.ron file")?;
177
178                let mut data: CompilationData =
179                    ron::de::from_str(&data).context("Failed to parse easy_sql.ron file")?;
180
181                if default_drivers_update && data.default_drivers != default_drivers {
182                    data.default_drivers = default_drivers;
183                    data.save()?;
184                }
185
186                data
187            }
188        };
189
190        Ok(data)
191    }
192
193    pub fn load_in_macro() -> anyhow::Result<CompilationData> {
194        let data_path = Self::data_location()?;
195
196        let data: CompilationData = {
197            {
198                if !data_path.exists() {
199                    return Ok(CompilationData {
200                        tables: HashMap::new(),
201                        #[cfg(feature = "check_duplicate_table_names")]
202                        used_table_names: HashMap::new(),
203                        default_drivers: Vec::new(),
204                    });
205                }
206
207                let data = std::fs::read_to_string(&data_path)
208                    .context("Failed to read easy_sql.ron file")?;
209
210                ron::de::from_str(&data).context("Failed to parse easy_sql.ron file")?
211            }
212        };
213
214        Ok(data)
215    }
216
217    #[cfg(feature = "build")]
218    pub fn save(&self) -> anyhow::Result<()> {
219        let data_path = Self::data_location()?;
220
221        let data =
222            ron::ser::to_string_pretty(self, ron::ser::PrettyConfig::new().struct_names(true))?;
223
224        std::fs::write(&data_path, &data).context("Failed to write easy_sql.ron file")?;
225
226        Ok(())
227    }
228
229    pub fn generate_unique_id(&self) -> String {
230        let mut generated = uuid::Uuid::new_v4().to_string();
231        let mut exists = true;
232
233        while exists {
234            exists = false;
235            for unique_id in self.tables.keys() {
236                if unique_id == &generated {
237                    exists = true;
238                    generated = uuid::Uuid::new_v4().to_string();
239                    break;
240                }
241            }
242        }
243        generated
244    }
245
246    pub fn is_duplicate_table_name(
247        &self,
248        current_unique_id: &str,
249        table_name: &str,
250    ) -> anyhow::Result<bool> {
251        if table_name == "easy_sql_tables" {
252            return Ok(true);
253        }
254        for (unique_id, table_data) in self.tables.iter() {
255            if unique_id == current_unique_id {
256                continue;
257            }
258            let latest_version_data =
259                match table_data.saved_versions.get(&table_data.latest_version) {
260                    Some(o) => o,
261                    None => anyhow::bail!(
262                        "Table data not found for latest version: {} | unique id: {:?}",
263                        table_data.latest_version,
264                        unique_id
265                    ),
266                };
267
268            if latest_version_data.table_name == table_name {
269                return Ok(true);
270            }
271        }
272
273        Ok(false)
274    }
275
276    #[cfg(feature = "migrations")]
277    pub fn generate_migrations(
278        &self,
279        current_unique_id: &str,
280        latest_version: &TableDataVersion,
281        latest_version_number: i64,
282        sql_crate: &TokenStream,
283        item_name: &TokenStream,
284    ) -> anyhow::Result<TokenStream> {
285        let macro_support = quote! { #sql_crate::macro_support };
286
287        let table_data = self
288            .tables
289            .get(current_unique_id)
290            .context("Table not found in Sql Compilation Data (easy_sql.ron)")?;
291
292        let mut result = TokensBuilder::default();
293
294        for (version_number, version_data) in table_data.saved_versions.iter() {
295            let mut changes_needed = Vec::new();
296            let mut rename_table = None;
297
298            if version_number == &latest_version_number {
299                continue;
300            }
301            //Primary Key Check (Must be equal)
302            if version_data.primary_keys != latest_version.primary_keys {
303                anyhow::bail!(
304                    "Primary key change is not supported (yet) -> Latest Version: {:?} ||| Version {}: {:?}",
305                    latest_version.primary_keys,
306                    version_number,
307                    version_data.primary_keys
308                );
309            }
310            //Foreign Key Check (Must be equal)
311            if version_data.foreign_keys != latest_version.foreign_keys {
312                anyhow::bail!(
313                    "Foreign key change is not supported (yet) -> Latest Version: {:?} ||| Version {}: {:?}",
314                    latest_version.foreign_keys,
315                    version_number,
316                    version_data.foreign_keys
317                );
318            }
319            //Auto increment check (Must be equal)
320            if version_data.auto_increment != latest_version.auto_increment {
321                anyhow::bail!(
322                    "Auto increment change is not supported (yet) -> Latest Version: {:?} ||| Version {}: {:?}",
323                    latest_version.auto_increment,
324                    version_number,
325                    version_data.auto_increment
326                );
327            }
328
329            // Table name change support
330            if version_data.table_name != latest_version.table_name {
331                let new_name = latest_version.table_name.as_str();
332
333                rename_table = Some(quote! {
334                    #sql_crate::driver::AlterTableSingle::RenameTable{
335                        new_table_name: #new_name,
336                    }
337                });
338            }
339            // Check for old column change
340            for (old_field, new_field) in
341                version_data.fields.iter().zip(latest_version.fields.iter())
342            {
343                //We can only rename old columns
344                if old_field.name != new_field.name {
345                    let old_name = old_field.name.as_str();
346                    let new_name = new_field.name.as_str();
347
348                    changes_needed.push(quote! {
349                        #sql_crate::driver::AlterTableSingle::RenameColumn{
350                            old_column_name: #old_name,
351                            new_column_name: #new_name,
352                        }
353                    });
354                }
355                //Everything else on old column is not supported
356                if old_field.field_type != new_field.field_type {
357                    anyhow::bail!(
358                        "Field type change is not supported (yet) (only rename) -> Latest Version: {:?} ||| Version {}: {:?}",
359                        latest_version.fields,
360                        version_number,
361                        version_data.fields
362                    );
363                }
364                if old_field.is_unique != new_field.is_unique {
365                    anyhow::bail!(
366                        "Field unique change is not supported (yet) (only rename) -> Latest Version: {:?} ||| Version {}: {:?}",
367                        latest_version.fields,
368                        version_number,
369                        version_data.fields
370                    );
371                }
372                if old_field.default != new_field.default {
373                    anyhow::bail!(
374                        "Field default value change is not supported (yet) (only rename) -> Latest Version: {:?} ||| Version {}: {:?}",
375                        latest_version.fields,
376                        version_number,
377                        version_data.fields
378                    );
379                }
380            }
381
382            //New Columns Check
383            for new_field in latest_version.fields.iter().skip(version_data.fields.len()) {
384                //New columns need default value
385                if new_field.default.is_none() && !new_field.field_type.starts_with("Option<") {
386                    anyhow::bail!(
387                        "New (not null) column without default value is not supported -> Latest Version: {:?} ||| Version {}: {:?}",
388                        latest_version.fields,
389                        version_number,
390                        version_data.fields
391                    );
392                }
393
394                let field_name = new_field.name.as_str();
395                let field_ident = syn::Ident::new(field_name, Span::call_site());
396                let data_type: syn::Type = syn::parse_str(new_field.field_type.as_str())?;
397                let is_not_null = !new_field.field_type.starts_with("Option<");
398                let is_unique = new_field.is_unique;
399
400                let default_value = if let Some(default_value) = new_field.default.as_deref() {
401                    let default_expr: syn::Expr = syn::parse_str(default_value)?;
402
403                    //For compatibility sake
404                    let default_value = default_expr;
405
406                    quote! {
407                        {
408                            //Check if default value has valid type for the current column
409                            let _= ||{
410                                let mut table_instance = #macro_support::never_any::<#item_name>();
411                                table_instance.#field_ident = #default_value;
412                            };
413
414                            Some(#sql_crate::ToDefault::to_default(#default_value))
415                        }
416                    }
417                } else {
418                    quote! {
419                        None
420                    }
421                };
422
423                //Create new field
424                changes_needed.push(quote! {
425                    #sql_crate::driver::AlterTableSingle::AddColumn{
426                        column: #sql_crate::driver::TableField {
427                            name: #field_name,
428                            data_type: {
429                                #macro_support::TypeInfo::name(
430                                    &<#data_type as #macro_support::Type<#macro_support::InternalDriver<_EasySqlMigrationDriver>>>::type_info(),
431                                )
432                                .to_owned()
433                            },
434                            is_unique: #is_unique,
435                            is_not_null: #is_not_null,
436                            default: #default_value,
437                            is_auto_increment: false,
438                        }
439                    }
440                });
441            }
442
443            if let Some(rename_table) = rename_table {
444                changes_needed.push(rename_table);
445            }
446
447            //Generate Migration (if needed)
448            if !changes_needed.is_empty() {
449                let version_number = *version_number;
450                let table_name = version_data.table_name.as_str();
451
452                result.add(quote! {
453                    if current_version_number == #version_number{
454                        #sql_crate::EasyExecutor::query_setup(conn, #sql_crate::driver::AlterTable{
455                            table_name: #table_name,
456                            alters: vec![#(#changes_needed),*],
457                        }).await?;
458                        #sql_crate::EasySqlTables_update_version!(_EasySqlMigrationDriver, *conn, #current_unique_id, #latest_version_number);
459                        return Ok(());
460                    }
461                });
462            }
463        }
464
465        Ok(result.finalize())
466    }
467}