1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{quote, quote_spanned, ToTokens};
6use syn::{
7 parse::Parse, parse_macro_input, spanned::Spanned, Data, DeriveInput, Expr, ExprAssign,
8 ExprLit, ExprPath, Field, Ident, Lit, LitStr, PathSegment, Result, Type, TypePath,
9};
10
11#[proc_macro_derive(Table, attributes(rizzle))]
12pub fn table(s: TokenStream) -> TokenStream {
13 let input = parse_macro_input!(s as DeriveInput);
14 match table_macro(input) {
15 Ok(s) => s.to_token_stream().into(),
16 Err(e) => e.to_compile_error().into(),
17 }
18}
19
20enum Rel {
21 One(LitStr),
22 Many(LitStr),
23}
24
25#[derive(Default)]
26struct RizzleAttr {
27 table_name: Option<LitStr>,
28 primary_key: bool,
29 not_null: bool,
30 default_value: Option<LitStr>,
31 columns: Option<LitStr>,
32 references: Option<LitStr>,
33 from: Option<LitStr>,
34 to: Option<LitStr>,
35 rel: Option<Rel>,
36}
37
38impl Parse for RizzleAttr {
39 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
40 let mut rizzle_attr = RizzleAttr::default();
41 let args_parsed =
42 syn::punctuated::Punctuated::<Expr, syn::Token![,]>::parse_terminated(input)?;
43 for expr in args_parsed.iter() {
44 match expr {
45 Expr::Assign(ExprAssign { left, right, .. }) => match (&**left, &**right) {
46 (Expr::Path(ExprPath { path, .. }), Expr::Lit(ExprLit { lit, .. })) => {
47 if let (Some(PathSegment { ident, .. }), Lit::Str(lit_str)) =
48 (path.segments.last(), lit)
49 {
50 match ident.to_string().as_ref() {
51 "table" => {
52 rizzle_attr.table_name = Some(lit_str.clone());
53 }
54 "r#default" => {
55 rizzle_attr.default_value = Some(lit_str.clone());
56 }
57 "columns" => {
58 rizzle_attr.columns = Some(lit_str.clone());
59 }
60 "references" => {
61 rizzle_attr.references = Some(lit_str.clone());
62 }
63 "many" => {
64 rizzle_attr.rel = Some(Rel::Many(lit_str.clone()));
65 }
66 "from" => {
67 rizzle_attr.from = Some(lit_str.clone());
68 }
69 "to" => {
70 rizzle_attr.to = Some(lit_str.clone());
71 }
72 "one" => {
73 rizzle_attr.rel = Some(Rel::One(lit_str.clone()));
74 }
75 _ => unimplemented!(),
76 }
77 }
78 }
79 _ => unimplemented!(),
80 },
81 Expr::Path(path) => match path.path.segments.len() {
82 1 => match path
83 .path
84 .segments
85 .first()
86 .unwrap()
87 .ident
88 .to_string()
89 .as_ref()
90 {
91 "not_null" => rizzle_attr.not_null = true,
92 "primary_key" => rizzle_attr.primary_key = true,
93 _ => {}
94 },
95 _ => {}
96 },
97 _ => {}
98 }
99 }
100
101 Ok(rizzle_attr)
102 }
103}
104
105struct RizzleField {
106 ident_name: String,
107 ident: Ident,
108 field: Field,
109 attrs: Vec<RizzleAttr>,
110 type_string: String,
111}
112
113fn table_macro(input: DeriveInput) -> Result<TokenStream2> {
114 let table_str = input
115 .attrs
116 .iter()
117 .filter_map(|attr| attr.parse_args::<RizzleAttr>().ok())
118 .last()
119 .expect("define #![rizzle(table = \"your table name here\")] on struct")
120 .table_name
121 .unwrap();
122 let struct_name = input.ident;
123 let table_name = table_str.value();
124 let rizzle_fields = match input.data {
125 syn::Data::Struct(ref data) => data
126 .fields
127 .iter()
128 .map(|field| {
129 let ident = field
130 .ident
131 .as_ref()
132 .expect("Struct fields should have names");
133 RizzleField {
134 ident: ident.clone(),
135 ident_name: ident.to_string(),
136 field: field.clone(),
137 attrs: field
138 .attrs
139 .iter()
140 .filter_map(|attr| attr.parse_args::<RizzleAttr>().ok())
141 .collect::<Vec<_>>(),
142 type_string: string_type_from_field(&field),
143 }
144 })
145 .collect::<Vec<_>>(),
146 _ => unimplemented!(),
147 };
148 let columns = columns(&table_name, &rizzle_fields);
149 let attrs = struct_attrs(&table_name, &rizzle_fields);
150 let indexes = indexes(&table_name, &rizzle_fields);
151 let references = references(&table_name, &rizzle_fields);
152
153 Ok(quote! {
154 impl Table for #struct_name {
155 fn new() -> Self {
156 Self { #(#attrs,)* }
157 }
158
159 fn name(&self) -> String {
160 String::from(#table_str)
161 }
162
163 fn columns(&self) -> Vec<Column> {
164 vec![#(#columns,)*]
165 }
166
167 fn indexes(&self) -> Vec<Index> {
168 vec![#(#indexes,)*]
169 }
170
171 fn references(&self) -> Vec<Reference> {
172 vec![#(#references,)*]
173 }
174
175 fn create_sql(&self) -> String {
176 let columns_sql = self.columns()
177 .iter()
178 .map(|c| c.definition_sql())
179 .collect::<Vec<_>>()
180 .join(", ");
181 format!("create table {} ({})", self.name(), columns_sql)
182 }
183 }
184 })
185}
186
187fn references(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
188 fields
189 .iter()
190 .filter(|field| match field.type_string.as_str() {
191 "Real" | "Integer" | "Text" | "Blob" | "Many" => true,
192 _ => false,
193 })
194 .filter(|field| match field.attrs.last() {
195 Some(attr) => attr.references.is_some(),
196 None => false,
197 })
198 .map(|field| {
199 let RizzleAttr { references, .. } = field.attrs.last().unwrap();
200 let many = field.type_string == "Many";
201 quote! {
202 Reference {
203 table: #table_name.to_owned(),
204 clause: #references.to_owned(),
205 many: #many,
206 ..Default::default()
207 }
208 }
209 })
210 .collect()
211}
212
213fn string_type_from_field(field: &Field) -> String {
214 match &field.ty {
215 syn::Type::Path(TypePath { path, .. }) => match path.segments.last() {
216 Some(PathSegment { ident, .. }) => ident.to_string(),
217 None => unimplemented!(),
218 },
219 _ => unimplemented!(),
220 }
221}
222
223fn struct_attrs(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
224 fields
225 .iter()
226 .map(|f| {
227 let ident = &f
228 .field
229 .ident
230 .as_ref()
231 .expect("Struct fields should have names");
232 let value = format!("{}.{}", table_name, ident.to_string());
233 quote! {
234 #ident: #value
235 }
236 })
237 .collect::<Vec<_>>()
238}
239
240fn data_type(string_type: &String) -> TokenStream2 {
241 match string_type.as_str() {
242 "Real" => quote! { sqlite::DataType::Real },
243 "Integer" => quote! { sqlite::DataType::Integer },
244 "Text" => quote! { sqlite::DataType::Text },
245 _ => quote! { sqlite::DataType::Blob },
246 }
247}
248
249fn columns(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
250 fields
251 .iter()
252 .filter(|field| match field.type_string.as_ref() {
253 "Real" | "Integer" | "Text" | "Blob" => true,
254 _ => false,
255 })
256 .map(|field| {
257 let ty = data_type(&field.type_string);
258 let ident = &field.ident_name;
259 if let Some(RizzleAttr {
260 primary_key,
261 not_null,
262 default_value,
263 references,
264 ..
265 }) = field.attrs.last()
266 {
267 let default_value = match default_value {
268 Some(default) => quote! { Some(#default.to_owned()) },
269 None => quote! { None },
270 };
271 let references = match references {
272 Some(references) => quote! { Some(#references.to_owned()) },
273 None => quote! { None },
274 };
275 quote! {
276 Column {
277 table_name: #table_name.to_string(),
278 name: #ident.to_string(),
279 data_type: #ty,
280 primary_key: #primary_key,
281 not_null: #not_null,
282 default_value: #default_value,
283 references: #references,
284 ..Default::default()
285 }
286 }
287 } else {
288 quote! {
289 Column {
290 table_name: #table_name.to_string(),
291 name: #ident.to_string(),
292 data_type: #ty,
293 ..Default::default()
294 }
295 }
296 }
297 })
298 .collect::<Vec<_>>()
299}
300
301fn indexes(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
302 let column_names = fields
303 .iter()
304 .filter(|f| match f.type_string.as_ref() {
305 "Text" | "Integer" | "Real" | "Blob" => true,
306 _ => false,
307 })
308 .map(|f| f.ident_name.clone())
309 .collect::<HashSet<_>>();
310 fields
311 .iter()
312 .filter(|f| match f.type_string.as_ref() {
313 "Index" | "UniqueIndex" => true,
314 _ => false,
315 })
316 .filter(|f| match f.attrs.last() {
317 Some(attr) => attr.columns.is_some(),
318 None => false,
319 })
320 .map(|f| {
321 let name = &f.ident_name;
322 let attr = f.attrs.last().unwrap();
323 let RizzleAttr { columns, .. } = attr;
324 let attr_column_names = match columns {
325 Some(lit_str) => lit_str.value(),
326 None => String::default(),
327 };
328 let names = attr_column_names
329 .split(",")
330 .map(|x| x.to_owned())
331 .collect::<HashSet<_>>();
332 let diff = &names.difference(&column_names).collect::<Vec<_>>();
333 let column_names_list = column_names
334 .iter()
335 .map(|x| format!("- {}", x))
336 .collect::<Vec<_>>()
337 .join("\n");
338 let compiler_error = format!(
339 "index {:?} on {:?} in table {:?} which only declares \n{}",
340 name, diff, table_name, column_names_list
341 );
342 if diff.len() != 0 {
343 quote_spanned! {
344 columns.span() => compile_error!(#compiler_error)
345 }
346 } else {
347 let column_names = match columns {
348 Some(lit_str) => quote! { #lit_str.to_string() },
349 None => quote! { "".to_string() },
350 };
351 let ty = match f.type_string.as_ref() {
352 "Index" => quote! { sqlite::IndexType::Plain },
353 "UniqueIndex" => quote! { sqlite::IndexType::Unique },
354 _ => unimplemented!(),
355 };
356 quote! {
357 Index {
358 table_name: #table_name.to_string(),
359 name: #name.to_string(),
360 index_type: #ty,
361 column_names: #column_names
362 }
363 }
364 }
365 })
366 .collect::<Vec<_>>()
367}
368
369#[proc_macro_derive(Select, attributes(rizzle))]
370pub fn select(s: TokenStream) -> TokenStream {
371 let input = parse_macro_input!(s as DeriveInput);
372 match select_macro(input) {
373 Ok(s) => s.to_token_stream().into(),
374 Err(e) => e.to_compile_error().into(),
375 }
376}
377
378fn select_macro(input: DeriveInput) -> Result<TokenStream2> {
379 let struct_name = input.ident;
380 let rizzle_fields = rizzle_fields(&input.data);
381 let column_names = rizzle_fields
382 .iter()
383 .filter(|RizzleField { attrs, .. }| attrs.is_empty())
384 .map(|RizzleField { ident, .. }| ident.to_string())
385 .collect::<Vec<_>>()
386 .join(", ");
387 let fk_column_streams = rizzle_fields
388 .iter()
389 .filter(|RizzleField { attrs, .. }| !attrs.is_empty())
390 .map(|RizzleField { attrs, field, .. }| {
391 let ty = last_segment_from_type(&field.ty).expect("");
392 let attr = attrs.last().unwrap();
393 let ty_name = &ty.to_string();
394 match &attr.rel {
395 Some(Rel::One(table_name)) => {
396 quote! {
397 #ty::column_names_sql().split(",").map(|col| format!("{}.{} as '{}_{}'", #table_name, col.trim(), #ty_name, col.trim())).collect::<Vec<_>>().join(", ")
398 }
399 },
400 Some(Rel::Many(_)) => todo!(),
401 None => todo!(),
402 }
403 })
404 .collect::<Vec<_>>();
405 let sets = &rizzle_fields
406 .iter()
407 .map(
408 |RizzleField {
409 ident,
410 attrs,
411 field,
412 ..
413 }| {
414 if let (Some(ty_ident), Some(attr)) =
415 (last_segment_from_type(&field.ty), attrs.last())
416 {
417 match &attr.rel {
418 Some(Rel::One(_)) => quote! {
419 #ident: #ty_ident::from_row(row)?
420 },
421 Some(Rel::Many(_)) => todo!(),
422 _ => todo!(),
423 }
424 } else {
425 let lit_str = ident.to_string();
426 let struct_name_string = struct_name.to_string();
427 let fk_name = format!("{}_{}", struct_name_string, lit_str);
428 quote! {
429 #ident: match row.try_get(#lit_str) {
430 Ok(val) => val,
431 Err(_) => row.try_get(#fk_name)?
432 }
433 }
434 }
435 },
436 )
437 .collect::<Vec<_>>();
438 Ok(quote! {
439 impl Select for #struct_name {
440 fn column_names_sql() -> String {
441 let prefixed_vec: Vec<String> = vec![#(#fk_column_streams,)*];
442 let prefixed: String = prefixed_vec.join(", ");
443 if prefixed.is_empty() {
444 format!("{}", #column_names)
445 } else {
446 format!("{}, {}", #column_names, prefixed)
447 }
448 }
449
450 fn columns_sql(&self) -> String {
451 Self::column_names_sql()
452 }
453 }
454
455 impl<'r> FromRow<'r, sqlite::SqliteRow> for #struct_name {
456 fn from_row(row: &'r sqlite::SqliteRow) -> Result<Self, SqlxError> {
457 Ok(#struct_name {
458 #(#sets,)*
459 })
460 }
461 }
462 })
463}
464
465#[proc_macro_derive(Insert, attributes(rizzle))]
466pub fn insert(s: TokenStream) -> TokenStream {
467 let input = parse_macro_input!(s as DeriveInput);
468 match insert_macro(input) {
469 Ok(s) => s.to_token_stream().into(),
470 Err(e) => e.to_compile_error().into(),
471 }
472}
473
474fn insert_macro(input: DeriveInput) -> Result<TokenStream2> {
475 let struct_name = input.ident;
476 let fields = match input.data {
477 syn::Data::Struct(ref data) => data
478 .fields
479 .iter()
480 .filter(|field| field.attrs.is_empty())
481 .map(|field| {
482 let ident = field
483 .ident
484 .as_ref()
485 .expect("Struct fields should have names");
486 let ty = &field.ty;
487 (ident, ty)
488 })
489 .collect::<Vec<_>>(),
490 _ => unimplemented!(),
491 };
492 let column_names = fields
493 .iter()
494 .map(|(ident, _)| ident.to_string())
495 .collect::<Vec<_>>()
496 .join(", ");
497 let sql_placeholders = &fields.iter().map(|_| "?").collect::<Vec<_>>();
498 let placeholders = &fields.iter().map(|_| "{}").collect::<Vec<_>>().join(", ");
499 let data_values = &fields
500 .iter()
501 .map(|(ident, _)| quote! { self.#ident.clone().into() })
502 .collect::<Vec<_>>();
503 Ok(quote! {
504 impl Insert for #struct_name {
505 fn insert_values(&self) -> Vec<DataValue> {
506 vec![
507 #(#data_values,)*
508 ]
509 }
510
511 fn insert_sql(&self) -> String {
512 let values_sql = format!(#placeholders, #(#sql_placeholders,)*);
513 format!("({}) values ({})", #column_names, values_sql)
514 }
515 }
516 })
517}
518
519#[proc_macro_derive(New, attributes(rizzle))]
520pub fn new(s: TokenStream) -> TokenStream {
521 let input = parse_macro_input!(s as DeriveInput);
522 match new_macro(input) {
523 Ok(s) => s.to_token_stream().into(),
524 Err(e) => e.to_compile_error().into(),
525 }
526}
527
528fn new_macro(input: DeriveInput) -> Result<TokenStream2> {
529 let fields = match input.data {
530 syn::Data::Struct(ref data) => data
531 .fields
532 .iter()
533 .map(|field| {
534 let ident = field
535 .ident
536 .as_ref()
537 .expect("Struct fields should have names");
538 let ty = &field.ty;
539 (ident, ty)
540 })
541 .collect::<Vec<_>>(),
542 _ => unimplemented!(),
543 };
544 let attrs = &fields
545 .iter()
546 .map(|(ident, ty)| {
547 quote! {
548 #ident: #ty::default()
549 }
550 })
551 .collect::<Vec<_>>();
552 let struct_name = input.ident;
553 Ok(quote! {
554 impl New for #struct_name {
555 fn new() -> Self {
556 Self { #(#attrs,)* }
557 }
558 }
559 })
560}
561
562#[proc_macro_derive(Update, attributes(rizzle))]
563pub fn update(s: TokenStream) -> TokenStream {
564 let input = parse_macro_input!(s as DeriveInput);
565 match update_macro(input) {
566 Ok(s) => s.to_token_stream().into(),
567 Err(e) => e.to_compile_error().into(),
568 }
569}
570
571fn update_macro(input: DeriveInput) -> Result<TokenStream2> {
572 let struct_name = input.ident;
573 let fields = match input.data {
574 syn::Data::Struct(ref data) => data
575 .fields
576 .iter()
577 .filter(|field| field.attrs.is_empty())
578 .map(|field| {
579 let ident = field
580 .ident
581 .as_ref()
582 .expect("Struct fields should have names");
583 let ty = &field.ty;
584 (ident, ty)
585 })
586 .collect::<Vec<_>>(),
587 _ => unimplemented!(),
588 };
589 let placeholders = &fields
590 .iter()
591 .map(|(ident, _)| format!("{} = ?", ident))
592 .collect::<Vec<_>>()
593 .join(", ");
594 let data_values = &fields
595 .iter()
596 .map(|(ident, _)| quote! { self.#ident.clone().into() })
597 .collect::<Vec<_>>();
598 Ok(quote! {
599 impl Update for #struct_name {
600 fn update_values(&self) -> Vec<DataValue> {
601 vec![
602 #(#data_values,)*
603 ]
604 }
605
606 fn update_sql(&self) -> String {
607 format!("set {}", #placeholders)
608 }
609 }
610 })
611}
612
613#[proc_macro_derive(Row, attributes(rizzle))]
614pub fn row(s: TokenStream) -> TokenStream {
615 let input = parse_macro_input!(s as DeriveInput);
616 match row_macro(input) {
617 Ok(s) => s.to_token_stream().into(),
618 Err(e) => e.to_compile_error().into(),
619 }
620}
621
622fn row_macro(input: DeriveInput) -> Result<TokenStream2> {
623 let insert_token_stream = insert_macro(input.clone())?;
624 let update_token_stream = update_macro(input.clone())?;
625 let select_token_stream = select_macro(input.clone())?;
626 Ok(quote! {
627 #insert_token_stream
628 #update_token_stream
629 #select_token_stream
630 })
631}
632
633fn rizzle_fields(data: &Data) -> Vec<RizzleField> {
634 match data {
635 syn::Data::Struct(ref data) => data
636 .fields
637 .iter()
638 .map(|field| {
639 let ident = field
640 .ident
641 .as_ref()
642 .expect("Struct fields should have names");
643 RizzleField {
644 ident: ident.clone(),
645 ident_name: ident.to_string(),
646 field: field.clone(),
647 attrs: field
648 .attrs
649 .iter()
650 .filter_map(|attr| attr.parse_args::<RizzleAttr>().ok())
651 .collect::<Vec<_>>(),
652 type_string: string_type_from_field(&field),
653 }
654 })
655 .collect::<Vec<_>>(),
656 _ => unimplemented!(),
657 }
658}
659
660fn last_segment_from_type(ty: &Type) -> Option<Ident> {
661 match ty {
662 Type::Path(TypePath { path, .. }) => Some(path.segments.last()?.ident.clone()),
663 _ => None,
664 }
665}
666
667#[proc_macro_derive(RizzleSchema, attributes(rizzle))]
668pub fn rizzle_schema(s: TokenStream) -> TokenStream {
669 let input = parse_macro_input!(s as DeriveInput);
670 match rizzle_schema_macro(input) {
671 Ok(s) => s.to_token_stream().into(),
672 Err(e) => e.to_compile_error().into(),
673 }
674}
675
676fn rizzle_schema_macro(input: DeriveInput) -> Result<TokenStream2> {
677 let struct_name = input.ident;
678 let struct_fields = rizzle_fields(&input.data);
679 let new_fields = struct_fields
680 .iter()
681 .map(|field| {
682 let ident = field
683 .field
684 .ident
685 .as_ref()
686 .expect("Struct fields should have ");
687 let ty = &field.field.ty;
688 quote! { #ident: #ty::new() }
689 })
690 .collect::<Vec<_>>();
691 let tables = struct_fields
692 .iter()
693 .map(|field| {
694 let ident = &field.ident;
695 quote! { &self.#ident }
696 })
697 .collect::<Vec<_>>();
698 Ok(quote! {
699 impl RizzleSchema for #struct_name {
700 fn new() -> Self {
701 Self { #(#new_fields,)* }
702 }
703
704 fn sql(&self) -> String {
705 "".to_owned()
706 }
707
708 fn tables<'a>(&'a self) -> Vec<&'a dyn Table> {
709 vec![#(#tables,)*]
710 }
711 }
712 })
713}