// dsync/parser.rs

1use heck::ToPascalCase;
2use syn::Ident;
3use syn::Item::Macro;
4
5use crate::{code, Error, GenerationConfig, Result};
6
/// Signature comment that marks a file as generated and managed by dsync.
pub const FILE_SIGNATURE: &str = "/* @generated and managed by dsync */";
9
// TODO: handle postgres array types
// TODO: handle postgres tuple/record types
12
13#[derive(Debug, Clone)]
14pub struct ParsedColumnMacro {
15    /// Rust type to use (like `String`)
16    pub ty: String,
17    /// Rust ident for the field name
18    pub name: Ident,
19    /// Actual column name, as parsed from the attributes, or the same as "name"
20    pub column_name: String,
21    pub is_nullable: bool,
22    pub is_unsigned: bool,
23    pub is_array: bool,
24}
25
26/// Struct to hold all information needed from a parsed `diesel::table!` macro
27#[derive(Debug, Clone)]
28pub struct ParsedTableMacro {
29    /// Name of the table
30    pub name: Ident,
31    /// Rust struct name to use
32    pub struct_name: String,
33    /// All parsed columns
34    pub columns: Vec<ParsedColumnMacro>,
35    /// All Primary key column idents
36    pub primary_key_columns: Vec<Ident>,
37    /// All foreign key relations (foreign_table_name, local_join_column)
38    pub foreign_keys: Vec<(
39        ForeignTableName,
40        JoinColumn, /* this is the column from this table which maps to the foreign table's primary key*/
41    )>,
42    /// Final generated code
43    pub generated_code: String,
44}
45
46impl ParsedTableMacro {
47    pub fn primary_key_column_names(&self) -> Vec<String> {
48        self.primary_key_columns
49            .iter()
50            .map(|i| i.to_string())
51            .collect()
52    }
53}
54
55type ForeignTableName = Ident;
56type JoinColumn = String;
57
58/// Parsed representation of the `diesel::joinable!` macro
59#[derive(Debug, Clone)]
60pub struct ParsedJoinMacro {
61    /// Table ident with the foreign key
62    pub table1: Ident,
63    /// Table ident for the foreign key
64    pub table2: Ident,
65    /// Column ident of for the foreign key
66    pub table1_columns: String,
67}
68
69/// Try to parse a diesel schema file contents
70pub fn parse_and_generate_code(
71    schema_file_contents: &str,
72    config: &GenerationConfig,
73) -> Result<Vec<ParsedTableMacro>> {
74    let schema_file = syn::parse_file(schema_file_contents).unwrap();
75
76    let mut tables: Vec<ParsedTableMacro> = vec![];
77
78    for item in schema_file.items {
79        if let Macro(macro_item) = item {
80            let macro_identifier = macro_item
81                .mac
82                .path
83                .segments
84                .last()
85                .ok_or(Error::other("could not read identifier for macro"))?
86                .ident
87                .to_string();
88
89            match macro_identifier.as_str() {
90                "table" => {
91                    let parsed_table = handle_table_macro(macro_item, config)?;
92
93                    // make sure the table isn't ignored
94                    let table_options = config.table(parsed_table.name.to_string().as_str());
95                    if !table_options.get_ignore() {
96                        tables.push(parsed_table);
97                    }
98                }
99                "joinable" => {
100                    let parsed_join = handle_joinable_macro(macro_item)?;
101
102                    for table in tables.iter_mut() {
103                        if parsed_join
104                            .table1
105                            .to_string()
106                            .eq(table.name.to_string().as_str())
107                        {
108                            table.foreign_keys.push((
109                                parsed_join.table2.clone(),
110                                parsed_join.table1_columns.clone(),
111                            ));
112                            break;
113                        }
114                    }
115                }
116                _ => {}
117            };
118        }
119    }
120
121    for table in tables.iter_mut() {
122        table.generated_code = code::generate_for_table(table, config);
123    }
124
125    Ok(tables)
126}
127
128fn handle_joinable_macro(macro_item: syn::ItemMacro) -> Result<ParsedJoinMacro> {
129    // println!("joinable! macro: {:#?}", macro_item);
130
131    let mut table1_name: Option<Ident> = None;
132    let mut table2_name: Option<Ident> = None;
133    let mut table2_join_column: Option<String> = None;
134
135    for item in macro_item.mac.tokens.into_iter() {
136        match item {
137            proc_macro2::TokenTree::Ident(ident) => {
138                if table1_name.is_none() {
139                    table1_name = Some(ident);
140                } else if table2_name.is_none() {
141                    table2_name = Some(ident);
142                }
143            }
144            proc_macro2::TokenTree::Group(group) => {
145                if table1_name.is_none() || table2_name.is_none() {
146                    return Err(Error::unsupported_schema_format(
147                        "encountered join column group too early",
148                    ));
149                } else {
150                    table2_join_column = Some(group.stream().to_string());
151                }
152            }
153            _ => {}
154        }
155    }
156
157    Ok(ParsedJoinMacro {
158        table1: table1_name.ok_or(Error::unsupported_schema_format(
159            "could not determine first join table name",
160        ))?,
161        table2: table2_name.ok_or(Error::unsupported_schema_format(
162            "could not determine second join table name",
163        ))?,
164        table1_columns: table2_join_column.ok_or(Error::unsupported_schema_format(
165            "could not determine join column name",
166        ))?,
167    })
168}
169
170/// Try to parse a `diesel::table!` macro
171fn handle_table_macro(
172    macro_item: syn::ItemMacro,
173    config: &GenerationConfig,
174) -> Result<ParsedTableMacro> {
175    let mut table_name_ident: Option<Ident> = None;
176    let mut table_primary_key_idents: Vec<Ident> = vec![];
177    let mut table_columns: Vec<ParsedColumnMacro> = vec![];
178
179    let mut skip_until_semicolon = false;
180    let mut skip_square_brackets = false;
181
182    for item in macro_item.mac.tokens.into_iter() {
183        if skip_until_semicolon {
184            if let proc_macro2::TokenTree::Punct(punct) = item {
185                if punct.as_char() == ';' {
186                    skip_until_semicolon = false;
187                }
188            }
189            continue;
190        }
191
192        match item {
193            proc_macro2::TokenTree::Punct(punct) => {
194                // skip any "#[]"
195                if punct.to_string().as_str() == "#" {
196                    skip_square_brackets = true;
197                    continue;
198                }
199            }
200            proc_macro2::TokenTree::Ident(ident) => {
201                // skip any "use" statements
202                if ident.to_string().eq("use") {
203                    skip_until_semicolon = true;
204                    continue;
205                }
206
207                table_name_ident = Some(ident);
208            }
209            proc_macro2::TokenTree::Group(group) => {
210                if skip_square_brackets {
211                    if group.delimiter() == proc_macro2::Delimiter::Bracket {
212                        skip_square_brackets = false;
213                    }
214                    continue;
215                }
216
217                if group.delimiter() == proc_macro2::Delimiter::Parenthesis {
218                    // primary keys group
219                    // println!("GROUP-keys {:#?}", group);
220                    for key_token in group.stream().into_iter() {
221                        if let proc_macro2::TokenTree::Ident(ident) = key_token {
222                            table_primary_key_idents.push(ident)
223                        }
224                    }
225                } else if group.delimiter() == proc_macro2::Delimiter::Brace {
226                    // columns group
227                    // println!("GROUP-cols {:#?}", group);
228
229                    // rust name parsed from the macro (the "HERE" in "HERE -> TYPE")
230                    let mut rust_column_name: Option<Ident> = None;
231                    // actual column name, parsed from the attribute value, if any ("#[sql_name = "test"]")
232                    let mut actual_column_name: Option<String> = None;
233                    let mut column_type: Option<Ident> = None;
234                    let mut column_nullable: bool = false;
235                    let mut column_unsigned: bool = false;
236                    let mut column_array: bool = false;
237                    // track if the last loop was a "#" (start of a attribute)
238                    let mut had_hashtag = false;
239
240                    for column_tokens in group.stream().into_iter() {
241                        // reset "had_hastag" but still make it available for checking
242                        let had_hashtag_last = had_hashtag;
243                        had_hashtag = false;
244                        match column_tokens {
245                            proc_macro2::TokenTree::Group(group) => {
246                                if had_hashtag_last {
247                                    // parse some extra information from the bracket group
248                                    // like the actual column name
249                                    if let Some((name, value)) = parse_diesel_attr_group(&group) {
250                                        if name == "sql_name" {
251                                            actual_column_name = Some(value);
252                                        }
253                                    }
254                                }
255
256                                continue;
257                            }
258                            proc_macro2::TokenTree::Ident(ident) => {
259                                if rust_column_name.is_none() {
260                                    rust_column_name = Some(ident);
261                                } else if ident.to_string().eq_ignore_ascii_case("Nullable") {
262                                    if column_array {
263                                        /*
264                                           If we've already identified this column as an array,
265                                           then we know for sure that the field inside is marked as nullable
266                                           (but this isn't the same as `NOT NULL`, rather, it is an implementation detail of postgres arrays).
267
268                                           Therefore, we can safely ignore this case of "Nullable".
269
270                                           For example:
271
272                                           ```rs
273                                           #[sql_name = "phone_numbers"]
274                                           phone_numbers -> Array<Nullable<Text>>,
275                                           ```
276
277                                           becomes:
278
279                                           ```rs
280                                           phone_numbers: Vec<Option<String>>,
281                                           ```
282
283                                           instead of the incorrect (which would be generated if we didn't have this column_array check):
284
285                                           ```rs
286                                           phone_numbers: Option<Vec<Option<String>>>,
287                                           ```
288                                        */
289                                    } else {
290                                        column_nullable = true;
291                                    }
292                                } else if ident.to_string().eq_ignore_ascii_case("Unsigned") {
293                                    column_unsigned = true;
294                                } else if ident.to_string().eq_ignore_ascii_case("Array") {
295                                    column_array = true;
296                                } else {
297                                    column_type = Some(ident);
298                                }
299                            }
300                            proc_macro2::TokenTree::Punct(punct) => {
301                                let char = punct.as_char();
302
303                                if char == '#' {
304                                    had_hashtag = true;
305                                    continue;
306                                } else if char == '-' || char == '>' {
307                                    // nothing for arrow or any additional #[]
308                                    continue;
309                                } else if char == ','
310                                    && rust_column_name.is_some()
311                                    && column_type.is_some()
312                                {
313                                    // end of column def!
314
315                                    let rust_column_name_checked = rust_column_name.ok_or(
316                                        Error::unsupported_schema_format(
317                                            "Invalid column name syntax",
318                                        ),
319                                    )?;
320                                    let column_name = actual_column_name
321                                        .unwrap_or(rust_column_name_checked.to_string());
322
323                                    // add the column
324                                    table_columns.push(ParsedColumnMacro {
325                                        name: rust_column_name_checked,
326                                        ty: schema_type_to_rust_type(
327                                            column_type
328                                                .ok_or(Error::unsupported_schema_format(
329                                                    "Invalid column type syntax",
330                                                ))?
331                                                .to_string(),
332                                            config,
333                                        )?,
334                                        is_nullable: column_nullable,
335                                        is_unsigned: column_unsigned,
336                                        is_array: column_array,
337                                        column_name,
338                                    });
339
340                                    // reset the properties
341                                    rust_column_name = None;
342                                    actual_column_name = None;
343                                    column_type = None;
344                                    column_unsigned = false;
345                                    column_nullable = false;
346                                    column_array = false;
347                                }
348                            }
349                            _ => {
350                                return Err(Error::unsupported_schema_format(
351                                    "Invalid column definition token in diesel table macro",
352                                ))
353                            }
354                        }
355                    }
356
357                    if rust_column_name.is_some()
358                        || column_type.is_some()
359                        || column_nullable
360                        || column_unsigned
361                    {
362                        // looks like a column was in the middle of being parsed, let's panic!
363                        return Err(Error::unsupported_schema_format(
364                            "It seems a column was partially defined",
365                        ));
366                    }
367                } else {
368                    return Err(Error::unsupported_schema_format(
369                        "Invalid delimiter in diesel table macro group",
370                    ));
371                }
372            }
373            _ => {
374                return Err(Error::unsupported_schema_format(
375                    "Invalid token tree item in diesel table macro",
376                ))
377            }
378        }
379    }
380
381    Ok(ParsedTableMacro {
382        name: table_name_ident
383            .clone()
384            .ok_or(Error::unsupported_schema_format(
385                "Could not extract table name from schema file",
386            ))?,
387        struct_name: table_name_ident.unwrap().to_string().to_pascal_case(),
388        columns: table_columns,
389        primary_key_columns: table_primary_key_idents,
390        foreign_keys: vec![],
391        generated_code: format!(
392            "{FILE_SIGNATURE}\n\nFATAL ERROR: nothing was generated; this shouldn't be possible."
393        ),
394    })
395}
396
397/// Parse a diesel schema attribute group
398/// ```rs
399/// #[attr = value]
400/// ```
401/// into (attr, value)
402fn parse_diesel_attr_group(group: &proc_macro2::Group) -> Option<(Ident, String)> {
403    // diesel only uses square brackets, so ignore other types
404    if group.delimiter() != proc_macro2::Delimiter::Bracket {
405        return None;
406    }
407
408    let mut token_stream = group.stream().into_iter();
409    // parse the attribute name, if it is anything else, return None
410    let attr_name = match token_stream.next()? {
411        proc_macro2::TokenTree::Ident(ident) => ident,
412        _ => return None,
413    };
414
415    // diesel always uses "=" for assignments
416    let punct = match token_stream.next()? {
417        proc_macro2::TokenTree::Punct(punct) => punct,
418        _ => return None,
419    };
420
421    if punct.as_char() != '=' {
422        return None;
423    }
424
425    // diesel print-schema only uses literals currently, if anything else is used, it should be added here
426    let value = match token_stream.next()? {
427        proc_macro2::TokenTree::Literal(literal) => literal,
428        _ => return None,
429    };
430
431    let mut value = value.to_string();
432
433    // remove the starting and ending quotes
434    if value.starts_with('"') && value.ends_with('"') {
435        value = String::from(&value[1..value.len() - 1]); // safe char boundaries because '"' is only one byte long
436    }
437
438    Some((attr_name, value))
439}
440
441// A function to translate diesel schema types into rust types
442//
443// reference: https://github.com/diesel-rs/diesel/blob/master/diesel/src/sql_types/mod.rs
444// exact reference; https://github.com/diesel-rs/diesel/blob/292ac5c0ed6474f96734ba2e99b95b442064f69c/diesel/src/mysql/types/mod.rs
445//
446// The docs page for sql_types is comprehensive but it hides some alias types like Int4, Float8, etc.:
447// https://docs.rs/diesel/latest/diesel/sql_types/index.html
448fn schema_type_to_rust_type(schema_type: String, config: &GenerationConfig) -> Result<String> {
449    Ok(match schema_type.to_lowercase().as_str() {
450        "unsigned" => return Err(Error::unsupported_type("Unsigned types are not yet supported, please open an issue if you need this feature!")), // TODO: deal with this later
451        "inet" => return Err(Error::unsupported_type("Unsigned types are not yet supported, please open an issue if you need this feature!")), // TODO: deal with this later
452        "cidr" => return Err(Error::unsupported_type("Unsigned types are not yet supported, please open an issue if you need this feature!")), // TODO: deal with this later
453
454        // boolean
455        "bool" => "bool",
456
457        // numbers
458        "tinyint" => "i8",
459        "smallint" => "i16",
460        "smallserial" => "i16",
461        "int2" => "i16",
462        "int4" => "i32",
463        "int4range" => "(std::collections::Bound<i32>, std::collections::Bound<i32>)",
464        "integer" => "i32",
465        "serial" => "i32",
466        "bigint" => "i64",
467        "bigserial" => "i64",
468        "int8" => "i64",
469        "int8range" => "(std::collections::Bound<i64>, std::collections::Bound<i64>)",
470        "float" => "f32",
471        "float4" => "f32",
472        "double" => "f64",
473        "float8" => "f64",
474        "numeric" => "bigdecimal::BigDecimal",
475        "numrange" => "(std::collections::Bound<bigdecimal::BigDecimal>, std::collections::Bound<bigdecimal::BigDecimal>)",
476        "decimal" => "bigdecimal::BigDecimal",
477
478        // string
479        "text" => "String",
480        "varchar" => "String",
481        "bpchar" => "String",
482        "char" => "String",
483        "tinytext" => "String",
484        "mediumtext" => "String",
485        "longtext" => "String",
486
487        // bytes
488        "binary" => "Vec<u8>",
489        "bytea" => "Vec<u8>",
490        "tinyblob" => "Vec<u8>",
491        "blob" => "Vec<u8>",
492        "mediumblob" => "Vec<u8>",
493        "longblob" => "Vec<u8>",
494        "varbinary" => "Vec<u8>",
495        "bit" => "Vec<u8>",
496
497        // date & time
498        "date" => "chrono::NaiveDate",
499        "daterange" => "(std::collections::Bound<chrono::NaiveDate>, std::collections::Bound<chrono::NaiveDate>)",
500        "datetime" => "chrono::NaiveDateTime",
501        "time" => "chrono::NaiveTime",
502        "timestamp" => "chrono::NaiveDateTime",
503        "tsrange" => "(std::collections::Bound<chrono::NaiveDateTime>, std::collections::Bound<chrono::NaiveDateTime>)",
504        "timestamptz" => "chrono::DateTime<chrono::Utc>",
505        "timestamptzsqlite" => "chrono::DateTime<chrono::Utc>",
506        "tstzrange" => "(std::collections::Bound<chrono::DateTime<chrono::Utc>>, std::collections::Bound<chrono::DateTime<chrono::Utc>>)",
507
508        // json
509        "json" => "serde_json::Value",
510        "jsonb" => "serde_json::Value",
511
512        // misc
513        "uuid" => "uuid::Uuid",
514        "interval" => "PgInterval",
515        "oid" => "u32",
516        "money" => "PgMoney",
517        "macaddr" => "[u8; 6]",
518        // "inet" => "either ipnetwork::IpNetwork or ipnet::IpNet (TODO)",
519        // "cidr" => "either ipnetwork::IpNetwork or ipnet::IpNet (TODO)",
520
521        /*
522            // panic if no type is found (this means generation is broken for this particular schema)
523            _ => panic!("Unknown type found '{schema_type}', please report this!")
524         */
525        _ => {
526            let schema_path = config.get_schema_path();
527            // return the schema type if no type is found (this means generation is broken for this particular schema)
528            let _type = format!("{schema_path}sql_types::{schema_type}");
529            return Ok(_type);
530        }
531    }.to_string())
532}