1use heck::ToPascalCase;
2use syn::Ident;
3use syn::Item::Macro;
4
5use crate::{code, Error, GenerationConfig, Result};
6
/// Signature comment identifying code generated by dsync.
///
/// Within this module it prefixes the placeholder `generated_code` that `handle_table_macro`
/// stores before real code is generated.
// NOTE(review): presumably the code generator also writes this at the top of every output
// file so dsync can recognise files it manages — confirm against the writer.
pub const FILE_SIGNATURE: &str = "/* @generated and managed by dsync */";
9
/// One column parsed from the brace-delimited body of a diesel `table!` macro.
#[derive(Debug, Clone)]
pub struct ParsedColumnMacro {
    /// Rust type for the column's innermost SQL type, as returned by `schema_type_to_rust_type`.
    /// The nullable/unsigned/array flags below are recorded separately and are not part of
    /// this string.
    pub ty: String,
    /// Rust identifier of the column: the first identifier of its definition in the macro.
    pub name: Ident,
    /// SQL column name: the value of a preceding `#[sql_name = "..."]` attribute, or `name`
    /// converted to a string when no such attribute is present.
    pub column_name: String,
    /// Whether a `Nullable` wrapper appeared before any `Array` wrapper; a `Nullable` that
    /// appears after `Array` does not set this flag.
    pub is_nullable: bool,
    /// Whether an `Unsigned` wrapper appeared in the column type.
    pub is_unsigned: bool,
    /// Whether an `Array` wrapper appeared in the column type.
    pub is_array: bool,
}
25
/// A table parsed from a diesel `table!` macro, together with the foreign keys attached to it
/// and the code generated for it.
#[derive(Debug, Clone)]
pub struct ParsedTableMacro {
    /// Table identifier: the last top-level identifier of the macro that is not part of a
    /// `use` statement.
    pub name: Ident,
    /// `name` converted to PascalCase (e.g. `user_posts` becomes `UserPosts`).
    pub struct_name: String,
    /// Columns in the order they appear in the macro body.
    pub columns: Vec<ParsedColumnMacro>,
    /// Identifiers found in the parenthesized group of the macro (e.g. `(id)`).
    pub primary_key_columns: Vec<Ident>,
    /// `(other table, join column)` pairs taken from `joinable!` macros whose first table is
    /// this table.
    pub foreign_keys: Vec<(
        ForeignTableName,
        JoinColumn,
    )>,
    /// Output of `code::generate_for_table`; holds a placeholder error text until
    /// `parse_and_generate_code` fills it in.
    pub generated_code: String,
}
45
46impl ParsedTableMacro {
47 pub fn primary_key_column_names(&self) -> Vec<String> {
48 self.primary_key_columns
49 .iter()
50 .map(|i| i.to_string())
51 .collect()
52 }
53}
54
/// Name of the table on the other side of a `joinable!` relation (the macro's second table).
type ForeignTableName = Ident;
/// Join column text taken from the delimited group of a `joinable!` macro.
type JoinColumn = String;
57
/// The parts of a diesel `joinable!` invocation, as extracted by `handle_joinable_macro`.
#[derive(Debug, Clone)]
pub struct ParsedJoinMacro {
    /// First table identifier in the macro; the table that receives the foreign key entry.
    pub table1: Ident,
    /// Second table identifier in the macro; stored as the foreign table name.
    pub table2: Ident,
    /// Stringified contents of the macro's delimited group (the join column).
    // NOTE(review): in diesel's `joinable!(a -> b (col))` this is presumably the foreign-key
    // column on `table1` — confirm.
    pub table1_columns: String,
}
68
69pub fn parse_and_generate_code(
71 schema_file_contents: &str,
72 config: &GenerationConfig,
73) -> Result<Vec<ParsedTableMacro>> {
74 let schema_file = syn::parse_file(schema_file_contents).unwrap();
75
76 let mut tables: Vec<ParsedTableMacro> = vec![];
77
78 for item in schema_file.items {
79 if let Macro(macro_item) = item {
80 let macro_identifier = macro_item
81 .mac
82 .path
83 .segments
84 .last()
85 .ok_or(Error::other("could not read identifier for macro"))?
86 .ident
87 .to_string();
88
89 match macro_identifier.as_str() {
90 "table" => {
91 let parsed_table = handle_table_macro(macro_item, config)?;
92
93 let table_options = config.table(parsed_table.name.to_string().as_str());
95 if !table_options.get_ignore() {
96 tables.push(parsed_table);
97 }
98 }
99 "joinable" => {
100 let parsed_join = handle_joinable_macro(macro_item)?;
101
102 for table in tables.iter_mut() {
103 if parsed_join
104 .table1
105 .to_string()
106 .eq(table.name.to_string().as_str())
107 {
108 table.foreign_keys.push((
109 parsed_join.table2.clone(),
110 parsed_join.table1_columns.clone(),
111 ));
112 break;
113 }
114 }
115 }
116 _ => {}
117 };
118 }
119 }
120
121 for table in tables.iter_mut() {
122 table.generated_code = code::generate_for_table(table, config);
123 }
124
125 Ok(tables)
126}
127
128fn handle_joinable_macro(macro_item: syn::ItemMacro) -> Result<ParsedJoinMacro> {
129 let mut table1_name: Option<Ident> = None;
132 let mut table2_name: Option<Ident> = None;
133 let mut table2_join_column: Option<String> = None;
134
135 for item in macro_item.mac.tokens.into_iter() {
136 match item {
137 proc_macro2::TokenTree::Ident(ident) => {
138 if table1_name.is_none() {
139 table1_name = Some(ident);
140 } else if table2_name.is_none() {
141 table2_name = Some(ident);
142 }
143 }
144 proc_macro2::TokenTree::Group(group) => {
145 if table1_name.is_none() || table2_name.is_none() {
146 return Err(Error::unsupported_schema_format(
147 "encountered join column group too early",
148 ));
149 } else {
150 table2_join_column = Some(group.stream().to_string());
151 }
152 }
153 _ => {}
154 }
155 }
156
157 Ok(ParsedJoinMacro {
158 table1: table1_name.ok_or(Error::unsupported_schema_format(
159 "could not determine first join table name",
160 ))?,
161 table2: table2_name.ok_or(Error::unsupported_schema_format(
162 "could not determine second join table name",
163 ))?,
164 table1_columns: table2_join_column.ok_or(Error::unsupported_schema_format(
165 "could not determine join column name",
166 ))?,
167 })
168}
169
170fn handle_table_macro(
172 macro_item: syn::ItemMacro,
173 config: &GenerationConfig,
174) -> Result<ParsedTableMacro> {
175 let mut table_name_ident: Option<Ident> = None;
176 let mut table_primary_key_idents: Vec<Ident> = vec![];
177 let mut table_columns: Vec<ParsedColumnMacro> = vec![];
178
179 let mut skip_until_semicolon = false;
180 let mut skip_square_brackets = false;
181
182 for item in macro_item.mac.tokens.into_iter() {
183 if skip_until_semicolon {
184 if let proc_macro2::TokenTree::Punct(punct) = item {
185 if punct.as_char() == ';' {
186 skip_until_semicolon = false;
187 }
188 }
189 continue;
190 }
191
192 match item {
193 proc_macro2::TokenTree::Punct(punct) => {
194 if punct.to_string().as_str() == "#" {
196 skip_square_brackets = true;
197 continue;
198 }
199 }
200 proc_macro2::TokenTree::Ident(ident) => {
201 if ident.to_string().eq("use") {
203 skip_until_semicolon = true;
204 continue;
205 }
206
207 table_name_ident = Some(ident);
208 }
209 proc_macro2::TokenTree::Group(group) => {
210 if skip_square_brackets {
211 if group.delimiter() == proc_macro2::Delimiter::Bracket {
212 skip_square_brackets = false;
213 }
214 continue;
215 }
216
217 if group.delimiter() == proc_macro2::Delimiter::Parenthesis {
218 for key_token in group.stream().into_iter() {
221 if let proc_macro2::TokenTree::Ident(ident) = key_token {
222 table_primary_key_idents.push(ident)
223 }
224 }
225 } else if group.delimiter() == proc_macro2::Delimiter::Brace {
226 let mut rust_column_name: Option<Ident> = None;
231 let mut actual_column_name: Option<String> = None;
233 let mut column_type: Option<Ident> = None;
234 let mut column_nullable: bool = false;
235 let mut column_unsigned: bool = false;
236 let mut column_array: bool = false;
237 let mut had_hashtag = false;
239
240 for column_tokens in group.stream().into_iter() {
241 let had_hashtag_last = had_hashtag;
243 had_hashtag = false;
244 match column_tokens {
245 proc_macro2::TokenTree::Group(group) => {
246 if had_hashtag_last {
247 if let Some((name, value)) = parse_diesel_attr_group(&group) {
250 if name == "sql_name" {
251 actual_column_name = Some(value);
252 }
253 }
254 }
255
256 continue;
257 }
258 proc_macro2::TokenTree::Ident(ident) => {
259 if rust_column_name.is_none() {
260 rust_column_name = Some(ident);
261 } else if ident.to_string().eq_ignore_ascii_case("Nullable") {
262 if column_array {
263 } else {
290 column_nullable = true;
291 }
292 } else if ident.to_string().eq_ignore_ascii_case("Unsigned") {
293 column_unsigned = true;
294 } else if ident.to_string().eq_ignore_ascii_case("Array") {
295 column_array = true;
296 } else {
297 column_type = Some(ident);
298 }
299 }
300 proc_macro2::TokenTree::Punct(punct) => {
301 let char = punct.as_char();
302
303 if char == '#' {
304 had_hashtag = true;
305 continue;
306 } else if char == '-' || char == '>' {
307 continue;
309 } else if char == ','
310 && rust_column_name.is_some()
311 && column_type.is_some()
312 {
313 let rust_column_name_checked = rust_column_name.ok_or(
316 Error::unsupported_schema_format(
317 "Invalid column name syntax",
318 ),
319 )?;
320 let column_name = actual_column_name
321 .unwrap_or(rust_column_name_checked.to_string());
322
323 table_columns.push(ParsedColumnMacro {
325 name: rust_column_name_checked,
326 ty: schema_type_to_rust_type(
327 column_type
328 .ok_or(Error::unsupported_schema_format(
329 "Invalid column type syntax",
330 ))?
331 .to_string(),
332 config,
333 )?,
334 is_nullable: column_nullable,
335 is_unsigned: column_unsigned,
336 is_array: column_array,
337 column_name,
338 });
339
340 rust_column_name = None;
342 actual_column_name = None;
343 column_type = None;
344 column_unsigned = false;
345 column_nullable = false;
346 column_array = false;
347 }
348 }
349 _ => {
350 return Err(Error::unsupported_schema_format(
351 "Invalid column definition token in diesel table macro",
352 ))
353 }
354 }
355 }
356
357 if rust_column_name.is_some()
358 || column_type.is_some()
359 || column_nullable
360 || column_unsigned
361 {
362 return Err(Error::unsupported_schema_format(
364 "It seems a column was partially defined",
365 ));
366 }
367 } else {
368 return Err(Error::unsupported_schema_format(
369 "Invalid delimiter in diesel table macro group",
370 ));
371 }
372 }
373 _ => {
374 return Err(Error::unsupported_schema_format(
375 "Invalid token tree item in diesel table macro",
376 ))
377 }
378 }
379 }
380
381 Ok(ParsedTableMacro {
382 name: table_name_ident
383 .clone()
384 .ok_or(Error::unsupported_schema_format(
385 "Could not extract table name from schema file",
386 ))?,
387 struct_name: table_name_ident.unwrap().to_string().to_pascal_case(),
388 columns: table_columns,
389 primary_key_columns: table_primary_key_idents,
390 foreign_keys: vec![],
391 generated_code: format!(
392 "{FILE_SIGNATURE}\n\nFATAL ERROR: nothing was generated; this shouldn't be possible."
393 ),
394 })
395}
396
397fn parse_diesel_attr_group(group: &proc_macro2::Group) -> Option<(Ident, String)> {
403 if group.delimiter() != proc_macro2::Delimiter::Bracket {
405 return None;
406 }
407
408 let mut token_stream = group.stream().into_iter();
409 let attr_name = match token_stream.next()? {
411 proc_macro2::TokenTree::Ident(ident) => ident,
412 _ => return None,
413 };
414
415 let punct = match token_stream.next()? {
417 proc_macro2::TokenTree::Punct(punct) => punct,
418 _ => return None,
419 };
420
421 if punct.as_char() != '=' {
422 return None;
423 }
424
425 let value = match token_stream.next()? {
427 proc_macro2::TokenTree::Literal(literal) => literal,
428 _ => return None,
429 };
430
431 let mut value = value.to_string();
432
433 if value.starts_with('"') && value.ends_with('"') {
435 value = String::from(&value[1..value.len() - 1]); }
437
438 Some((attr_name, value))
439}
440
441fn schema_type_to_rust_type(schema_type: String, config: &GenerationConfig) -> Result<String> {
449 Ok(match schema_type.to_lowercase().as_str() {
450 "unsigned" => return Err(Error::unsupported_type("Unsigned types are not yet supported, please open an issue if you need this feature!")), "inet" => return Err(Error::unsupported_type("Unsigned types are not yet supported, please open an issue if you need this feature!")), "cidr" => return Err(Error::unsupported_type("Unsigned types are not yet supported, please open an issue if you need this feature!")), "bool" => "bool",
456
457 "tinyint" => "i8",
459 "smallint" => "i16",
460 "smallserial" => "i16",
461 "int2" => "i16",
462 "int4" => "i32",
463 "int4range" => "(std::collections::Bound<i32>, std::collections::Bound<i32>)",
464 "integer" => "i32",
465 "serial" => "i32",
466 "bigint" => "i64",
467 "bigserial" => "i64",
468 "int8" => "i64",
469 "int8range" => "(std::collections::Bound<i64>, std::collections::Bound<i64>)",
470 "float" => "f32",
471 "float4" => "f32",
472 "double" => "f64",
473 "float8" => "f64",
474 "numeric" => "bigdecimal::BigDecimal",
475 "numrange" => "(std::collections::Bound<bigdecimal::BigDecimal>, std::collections::Bound<bigdecimal::BigDecimal>)",
476 "decimal" => "bigdecimal::BigDecimal",
477
478 "text" => "String",
480 "varchar" => "String",
481 "bpchar" => "String",
482 "char" => "String",
483 "tinytext" => "String",
484 "mediumtext" => "String",
485 "longtext" => "String",
486
487 "binary" => "Vec<u8>",
489 "bytea" => "Vec<u8>",
490 "tinyblob" => "Vec<u8>",
491 "blob" => "Vec<u8>",
492 "mediumblob" => "Vec<u8>",
493 "longblob" => "Vec<u8>",
494 "varbinary" => "Vec<u8>",
495 "bit" => "Vec<u8>",
496
497 "date" => "chrono::NaiveDate",
499 "daterange" => "(std::collections::Bound<chrono::NaiveDate>, std::collections::Bound<chrono::NaiveDate>)",
500 "datetime" => "chrono::NaiveDateTime",
501 "time" => "chrono::NaiveTime",
502 "timestamp" => "chrono::NaiveDateTime",
503 "tsrange" => "(std::collections::Bound<chrono::NaiveDateTime>, std::collections::Bound<chrono::NaiveDateTime>)",
504 "timestamptz" => "chrono::DateTime<chrono::Utc>",
505 "timestamptzsqlite" => "chrono::DateTime<chrono::Utc>",
506 "tstzrange" => "(std::collections::Bound<chrono::DateTime<chrono::Utc>>, std::collections::Bound<chrono::DateTime<chrono::Utc>>)",
507
508 "json" => "serde_json::Value",
510 "jsonb" => "serde_json::Value",
511
512 "uuid" => "uuid::Uuid",
514 "interval" => "PgInterval",
515 "oid" => "u32",
516 "money" => "PgMoney",
517 "macaddr" => "[u8; 6]",
518 _ => {
526 let schema_path = config.get_schema_path();
527 let _type = format!("{schema_path}sql_types::{schema_type}");
529 return Ok(_type);
530 }
531 }.to_string())
532}