1mod analyze;
2mod expand;
3mod from_glue_value;
4mod into_glue_expr;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::{quote as q, ToTokens};
9use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, Ident, Type};
10
11pub(crate) use gluesql_core::ast::DataType as SqlType;
12use SqlType::*;
13
14#[proc_macro_derive(Storage, attributes(pkey, unique))]
16pub fn table_derive(input: TokenStream) -> TokenStream {
17 let ast = parse_macro_input!(input as DeriveInput);
18 let struct_ident = ast.ident;
19 let table_name = struct_ident.to_string();
20
21 let fields = match ast.data {
23 Data::Struct(DataStruct {
24 fields: Fields::Named(it),
25 ..
26 }) => it,
27 _ => panic!("Expected a `struct` with named fields"),
28 };
29
30 let mut columns: Vec<Column> = fields.named.into_iter().map(analyze::from_field).collect();
32
33 match columns.iter().filter(|c| c.pkey).count() {
35 0 => {
36 columns[0].pkey = true;
37 columns[0].unique = true;
38 }
39 1 => {}
40 _ => panic!("Storage macro doesn't support more than one pkey at the moment"),
41 };
42
43 TokenStream::from(expand::impl_table(struct_ident, table_name, columns))
45}
46
47struct Column {
48 field_name: Ident,
49 field_name_str: String,
50 full_type: Type,
51 full_type_str: String,
52 inner_type: Type,
54 sql_type: SqlType,
56 pkey: bool,
58 optional: bool,
60 list: bool,
62 unique: bool,
64 serialized: bool,
66}
67
68impl Column {
69 fn value_transform(&self) -> ValueTransform {
70 match self.sql_type {
71 Uuid => ValueTransform::UuidU128,
72 _ if self.serialized => ValueTransform::SerDe,
73 _ => ValueTransform::None,
74 }
75 }
76
77 fn value_variant(&self) -> &str {
78 match self.sql_type {
79 Uuid => "Uuid",
80 Text => "Str",
81 Timestamp => "Timestamp",
82 Boolean => "Bool",
83 Uint128 => "U128",
84 Uint64 => "U64",
85 Uint32 => "U32",
86 Uint16 => "U16",
87 Uint8 => "U8",
88 Int128 => "I128",
89 Int => "I64",
90 Int32 => "I32",
91 Int16 => "I16",
92 Int8 => "I8",
93 Float32 => "F32",
94 Float => "F64",
95 _ if self.serialized => "Bytea",
96 _ => panic!("value variant for {:?} is not supported", self.sql_type),
97 }
98 }
99}
100
/// How a column's value is transformed for storage, as chosen by `Column::value_transform`.
enum ValueTransform {
    /// No transformation.
    None,
    /// Chosen for `Uuid` columns (name suggests a `u128` representation; confirm in `expand`).
    UuidU128,
    /// Chosen for non-`Uuid` columns flagged as serialized.
    SerDe,
}
106
107fn ident(name: &str) -> Ident {
108 Ident::new(name, Span::call_site())
109}
110
/// Classification queries about a SQL type, used when generating code for a column.
///
/// Implemented for `SqlType` directly and for `Column`, which forwards to its `sql_type`.
trait TypeProps {
    /// Narrow integer types (see the `SqlType` impl for the exact set).
    fn int_or_smaller(&self) -> bool;
    /// Any integer type.
    fn integer(&self) -> bool;
    /// Any integer or floating-point type.
    fn numeric(&self) -> bool;
    /// Numeric types plus the date/time types.
    fn comparable(&self) -> bool;
    /// Types treated as directly usable as expression nodes (see the `SqlType` impl).
    fn impl_into_exprnode(&self) -> bool;
    /// Types that are neither `impl_into_exprnode` nor numeric in the `SqlType` impl.
    fn quoted(&self) -> bool;
}
119
120impl TypeProps for SqlType {
121 fn impl_into_exprnode(&self) -> bool {
122 matches!(self, Boolean | Int)
123 }
124 fn int_or_smaller(&self) -> bool {
125 matches!(self, Uint32 | Uint16 | Uint8 | Int | Int32 | Int16 | Int8)
126 }
127 fn integer(&self) -> bool {
128 self.int_or_smaller() || matches!(self, Uint128 | Uint64 | Int128)
129 }
130 fn numeric(&self) -> bool {
131 self.integer() || matches!(self, Float | Float32)
132 }
133 fn comparable(&self) -> bool {
134 self.numeric() || matches!(self, Timestamp | Date | Time)
135 }
136 fn quoted(&self) -> bool {
137 !self.impl_into_exprnode() && !self.numeric()
138 }
139}
140
141impl TypeProps for Column {
142 fn impl_into_exprnode(&self) -> bool {
143 self.sql_type.impl_into_exprnode()
144 }
145 fn int_or_smaller(&self) -> bool {
146 self.sql_type.int_or_smaller()
147 }
148 fn integer(&self) -> bool {
149 self.sql_type.integer()
150 }
151 fn numeric(&self) -> bool {
152 self.sql_type.numeric()
153 }
154 fn comparable(&self) -> bool {
155 self.sql_type.comparable()
156 }
157 fn quoted(&self) -> bool {
158 self.sql_type.quoted()
159 }
160}
161
162fn column_schema(col: &Column) -> proc_macro2::TokenStream {
163 let Column {
164 field_name_str,
165 full_type_str,
166 sql_type,
167 pkey,
168 unique,
169 list,
170 optional,
171 serialized,
172 ..
173 } = col;
174 let numeric = sql_type.numeric();
175 let comparable = sql_type.comparable();
176 let sql_type_str = sql_type.to_string();
177 let (first_char, rest) = sql_type_str.split_at(1);
178 let sql_type = ident(&(first_char.to_owned() + &rest.to_ascii_lowercase()));
179 q! {
180 FieldSchema {
181 name: #field_name_str,
182 rust_type: #full_type_str,
183 sql_type: prest::sql::DataType::#sql_type,
184 unique: #unique,
185 pkey: #pkey,
186 list: #list,
187 optional: #optional,
188 serialized: #serialized,
189 numeric: #numeric,
190 comparable: #comparable,
191 }
192 }
193}