Skip to main content

easy_sql_compilation_data/
lib.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    path::PathBuf,
4    str::FromStr,
5};
6
7use anyhow::{self, Context};
8use quote::ToTokens;
9#[cfg(feature = "migrations")]
10use {
11    easy_macros::TokensBuilder,
12    proc_macro2::{Span, TokenStream},
13    quote::quote,
14};
15
16use easy_macros::{
17    always_context, get_attributes, has_attributes, token_stream_to_consistent_string,
18};
19use serde::{Deserialize, Serialize, Serializer};
20
21#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
22pub struct TableField {
23    pub name: String,
24    #[serde(default)]
25    pub ty_to_bytes: bool,
26    pub field_type: String,
27    ///Tokens converted to_string()
28    pub default: Option<String>,
29    pub is_unique: bool,
30}
31#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
32pub struct TableDataVersion {
33    pub table_name: String,
34    pub fields: Vec<TableField>,
35    pub primary_keys: Vec<String>,
36    pub auto_increment: bool,
37    ///key - table (struct) name
38    ///value - current field name
39    #[serde(serialize_with = "ordered_map")]
40    pub foreign_keys: HashMap<String, Vec<String>>,
41}
42
43fn ordered_map<S, K: Ord + Serialize, V: Serialize>(
44    value: &HashMap<K, V>,
45    serializer: S,
46) -> Result<S::Ok, S::Error>
47where
48    S: Serializer,
49{
50    let ordered: BTreeMap<_, _> = value.iter().collect();
51    ordered.serialize(serializer)
52}
53
54impl TableDataVersion {
55    pub fn from_struct(item: &syn::ItemStruct, table_name: String) -> anyhow::Result<Self> {
56        let fields = match &item.fields {
57            syn::Fields::Named(fields_named) => &fields_named.named,
58            _ => {
59                anyhow::bail!(
60                    "non named field type should be handled before `generate_table_data_from_struct` is called"
61                )
62            }
63        };
64
65        let mut fields_converted = Vec::new();
66        let mut primary_keys = Vec::new();
67        let mut foreign_keys = HashMap::new();
68
69        let mut auto_increment = false;
70
71        for field in fields.iter() {
72            let name = field.ident.as_ref().unwrap().to_string();
73
74            //Auto Increment Check
75            if has_attributes!(field, #[sql(auto_increment)]) {
76                if auto_increment {
77                    anyhow::bail!("Auto increment is only supported for single primary key");
78                }
79                auto_increment = true;
80            }
81
82            let ty_to_bytes = has_attributes!(field, #[sql(bytes)]);
83
84            let default = get_attributes!(field, #[sql(default = __unknown__)])
85                .into_iter()
86                .next()
87                .map(token_stream_to_consistent_string);
88
89            for foreign_key in get_attributes!(field, #[sql(foreign_key = __unknown__)])
90                .into_iter()
91                .map(token_stream_to_consistent_string)
92            {
93                let fields: &mut Vec<String> = foreign_keys
94                    .entry(foreign_key)
95                    .or_insert(Default::default());
96                fields.push(name.clone());
97            }
98
99            if has_attributes!(field, #[sql(primary_key)]) {
100                primary_keys.push(name.clone());
101            }
102
103            let is_unique = has_attributes!(field, #[sql(unique)]);
104
105            fields_converted.push(TableField {
106                name,
107                field_type: token_stream_to_consistent_string(field.ty.to_token_stream()),
108                default,
109                is_unique,
110                ty_to_bytes,
111            });
112        }
113
114        Ok(TableDataVersion {
115            table_name,
116            fields: fields_converted,
117            foreign_keys,
118            primary_keys,
119            auto_increment,
120        })
121    }
122}
123
124#[derive(Debug, Serialize, Deserialize)]
125pub struct TableData {
126    #[serde(serialize_with = "ordered_map")]
127    pub saved_versions: HashMap<i64, TableDataVersion>,
128    pub latest_version: i64,
129}
/// Where a struct that uses a given table name is defined.
#[cfg(feature = "check_duplicate_table_names")]
#[derive(Debug, Serialize, Deserialize)]
pub struct TableNameData {
    /// File containing the struct.
    pub filename: String,
    /// Name of the struct.
    pub struct_name: String,
}

137#[derive(Debug, Serialize, Deserialize)]
138pub struct CompilationData {
139    ///Key - table id, generated by build macro, put in #[sql(unique_id = "...")] attribute on struct
140    #[serde(serialize_with = "ordered_map")]
141    pub tables: HashMap<String, TableData>,
142    #[cfg(feature = "check_duplicate_table_names")]
143    #[serde(serialize_with = "ordered_map")]
144    #[serde(default)]
145    ///Key - table name
146    pub used_table_names: HashMap<String, Vec<TableNameData>>,
147    #[serde(default)]
148    pub default_drivers: Vec<String>,
149}
150#[always_context]
151impl CompilationData {
152    pub fn data_location() -> anyhow::Result<PathBuf> {
153        let manifest_dir_str = std::env::var("CARGO_MANIFEST_DIR")?;
154        let current_dir = PathBuf::from_str(&manifest_dir_str)?;
155
156        Ok(current_dir.join("easy_sql.ron"))
157    }
158
159    #[cfg(feature = "build")]
160    pub fn load(
161        default_drivers: Vec<String>,
162        default_drivers_update: bool,
163    ) -> anyhow::Result<CompilationData> {
164        let data_path = Self::data_location()?;
165
166        let data: CompilationData = {
167            if !data_path.exists() {
168                CompilationData {
169                    tables: HashMap::new(),
170                    #[cfg(feature = "check_duplicate_table_names")]
171                    used_table_names: HashMap::new(),
172                    default_drivers,
173                }
174            } else {
175                let data = std::fs::read_to_string(&data_path)
176                    .context("Failed to read easy_sql.ron file")?;
177
178                let mut data: CompilationData =
179                    ron::de::from_str(&data).context("Failed to parse easy_sql.ron file")?;
180
181                if default_drivers_update && data.default_drivers != default_drivers {
182                    data.default_drivers = default_drivers;
183                    data.save()?;
184                }
185
186                data
187            }
188        };
189
190        Ok(data)
191    }
192
193    pub fn load_in_macro() -> anyhow::Result<CompilationData> {
194        let data_path = Self::data_location()?;
195
196        let data: CompilationData = {
197            {
198                if !data_path.exists() {
199                    return Ok(CompilationData {
200                        tables: HashMap::new(),
201                        #[cfg(feature = "check_duplicate_table_names")]
202                        used_table_names: HashMap::new(),
203                        default_drivers: Vec::new(),
204                    });
205                }
206
207                let data = std::fs::read_to_string(&data_path)
208                    .context("Failed to read easy_sql.ron file")?;
209
210                ron::de::from_str(&data).context("Failed to parse easy_sql.ron file")?
211            }
212        };
213
214        Ok(data)
215    }
216
217    #[cfg(feature = "build")]
218    pub fn save(&self) -> anyhow::Result<()> {
219        let data_path = Self::data_location()?;
220
221        let data =
222            ron::ser::to_string_pretty(self, ron::ser::PrettyConfig::new().struct_names(true))?;
223
224        let result = std::fs::write(&data_path, &data);
225
226        if let Err(e) = &result
227            && let std::io::ErrorKind::ReadOnlyFilesystem = e.kind()
228        {
229            return Ok(());
230        }
231
232        result.context("Failed to write easy_sql.ron file")?;
233
234        Ok(())
235    }
236
237    pub fn generate_unique_id(&self) -> String {
238        let mut generated = uuid::Uuid::new_v4().to_string();
239        let mut exists = true;
240
241        while exists {
242            exists = false;
243            for unique_id in self.tables.keys() {
244                if unique_id == &generated {
245                    exists = true;
246                    generated = uuid::Uuid::new_v4().to_string();
247                    break;
248                }
249            }
250        }
251        generated
252    }
253
254    pub fn is_duplicate_table_name(
255        &self,
256        current_unique_id: &str,
257        table_name: &str,
258    ) -> anyhow::Result<bool> {
259        if table_name == "easy_sql_tables" {
260            return Ok(true);
261        }
262        for (unique_id, table_data) in self.tables.iter() {
263            if unique_id == current_unique_id {
264                continue;
265            }
266            let latest_version_data =
267                match table_data.saved_versions.get(&table_data.latest_version) {
268                    Some(o) => o,
269                    None => anyhow::bail!(
270                        "Table data not found for latest version: {} | unique id: {:?}",
271                        table_data.latest_version,
272                        unique_id
273                    ),
274                };
275
276            if latest_version_data.table_name == table_name {
277                return Ok(true);
278            }
279        }
280
281        Ok(false)
282    }
283
284    #[cfg(feature = "migrations")]
285    pub fn generate_migrations(
286        &self,
287        current_unique_id: &str,
288        latest_version: &TableDataVersion,
289        latest_version_number: i64,
290        sql_crate: &TokenStream,
291        item_name: &TokenStream,
292    ) -> anyhow::Result<TokenStream> {
293        let macro_support = quote! { #sql_crate::macro_support };
294
295        let table_data = self
296            .tables
297            .get(current_unique_id)
298            .context("Table not found in Sql Compilation Data (easy_sql.ron)")?;
299
300        let mut result = TokensBuilder::default();
301
302        for (version_number, version_data) in table_data.saved_versions.iter() {
303            let mut changes_needed = Vec::new();
304            let mut rename_table = None;
305
306            if version_number == &latest_version_number {
307                continue;
308            }
309            //Primary Key Check (Must be equal)
310            if version_data.primary_keys != latest_version.primary_keys {
311                anyhow::bail!(
312                    "Primary key change is not supported (yet) -> Latest Version: {:?} ||| Version {}: {:?}",
313                    latest_version.primary_keys,
314                    version_number,
315                    version_data.primary_keys
316                );
317            }
318            //Foreign Key Check (Must be equal)
319            if version_data.foreign_keys != latest_version.foreign_keys {
320                anyhow::bail!(
321                    "Foreign key change is not supported (yet) -> Latest Version: {:?} ||| Version {}: {:?}",
322                    latest_version.foreign_keys,
323                    version_number,
324                    version_data.foreign_keys
325                );
326            }
327            //Auto increment check (Must be equal)
328            if version_data.auto_increment != latest_version.auto_increment {
329                anyhow::bail!(
330                    "Auto increment change is not supported (yet) -> Latest Version: {:?} ||| Version {}: {:?}",
331                    latest_version.auto_increment,
332                    version_number,
333                    version_data.auto_increment
334                );
335            }
336
337            // Table name change support
338            if version_data.table_name != latest_version.table_name {
339                let new_name = latest_version.table_name.as_str();
340
341                rename_table = Some(quote! {
342                    #sql_crate::driver::AlterTableSingle::RenameTable{
343                        new_table_name: #new_name,
344                    }
345                });
346            }
347            // Check for old column change
348            for (old_field, new_field) in
349                version_data.fields.iter().zip(latest_version.fields.iter())
350            {
351                //We can only rename old columns
352                if old_field.name != new_field.name {
353                    let old_name = old_field.name.as_str();
354                    let new_name = new_field.name.as_str();
355
356                    changes_needed.push(quote! {
357                        #sql_crate::driver::AlterTableSingle::RenameColumn{
358                            old_column_name: #old_name,
359                            new_column_name: #new_name,
360                        }
361                    });
362                }
363                //Everything else on old column is not supported
364                if old_field.field_type != new_field.field_type {
365                    anyhow::bail!(
366                        "Field type change is not supported (yet) (only rename) -> Latest Version: {:?} ||| Version {}: {:?}",
367                        latest_version.fields,
368                        version_number,
369                        version_data.fields
370                    );
371                }
372                if old_field.is_unique != new_field.is_unique {
373                    anyhow::bail!(
374                        "Field unique change is not supported (yet) (only rename) -> Latest Version: {:?} ||| Version {}: {:?}",
375                        latest_version.fields,
376                        version_number,
377                        version_data.fields
378                    );
379                }
380                if old_field.default != new_field.default {
381                    anyhow::bail!(
382                        "Field default value change is not supported (yet) (only rename) -> Latest Version: {:?} ||| Version {}: {:?}",
383                        latest_version.fields,
384                        version_number,
385                        version_data.fields
386                    );
387                }
388            }
389
390            //New Columns Check
391            for new_field in latest_version.fields.iter().skip(version_data.fields.len()) {
392                //New columns need default value
393                if new_field.default.is_none() && !new_field.field_type.starts_with("Option<") {
394                    anyhow::bail!(
395                        "New (not null) column without default value is not supported -> Latest Version: {:?} ||| Version {}: {:?}",
396                        latest_version.fields,
397                        version_number,
398                        version_data.fields
399                    );
400                }
401
402                let field_name = new_field.name.as_str();
403                let field_ident = syn::Ident::new(field_name, Span::call_site());
404                let data_type: syn::Type = syn::parse_str(new_field.field_type.as_str())?;
405                let is_not_null = !new_field.field_type.starts_with("Option<");
406                let is_unique = new_field.is_unique;
407
408                let default_value = if let Some(default_value) = new_field.default.as_deref() {
409                    let default_expr: syn::Expr = syn::parse_str(default_value)?;
410
411                    //For compatibility sake
412                    let default_value = default_expr;
413
414                    quote! {
415                        {
416                            //Check if default value has valid type for the current column
417                            let _= ||{
418                                let mut table_instance = #macro_support::never_any::<#item_name>();
419                                table_instance.#field_ident = #default_value;
420                            };
421
422                            Some(#sql_crate::ToDefault::to_default(#default_value))
423                        }
424                    }
425                } else {
426                    quote! {
427                        None
428                    }
429                };
430
431                //Create new field
432                changes_needed.push(quote! {
433                    #sql_crate::driver::AlterTableSingle::AddColumn{
434                        column: #sql_crate::driver::TableField {
435                            name: #field_name,
436                            data_type: {
437                                #macro_support::TypeInfo::name(
438                                    &<#data_type as #macro_support::Type<#macro_support::InternalDriver<_EasySqlMigrationDriver>>>::type_info(),
439                                )
440                                .to_owned()
441                            },
442                            is_unique: #is_unique,
443                            is_not_null: #is_not_null,
444                            default: #default_value,
445                            is_auto_increment: false,
446                        }
447                    }
448                });
449            }
450
451            if let Some(rename_table) = rename_table {
452                changes_needed.push(rename_table);
453            }
454
455            //Generate Migration (if needed)
456            if !changes_needed.is_empty() {
457                let version_number = *version_number;
458                let table_name = version_data.table_name.as_str();
459
460                result.add(quote! {
461                    if current_version_number == #version_number{
462                        #sql_crate::EasyExecutor::query_setup(conn, #sql_crate::driver::AlterTable{
463                            table_name: #table_name,
464                            alters: vec![#(#changes_needed),*],
465                        }).await?;
466                        #sql_crate::EasySqlTables_update_version!(_EasySqlMigrationDriver, *conn, #current_unique_id, #latest_version_number);
467                        return Ok(());
468                    }
469                });
470            }
471        }
472
473        Ok(result.finalize())
474    }
475}