1extern crate proc_macro;
8
9use syn::spanned::Spanned;
10
#[cfg(test)]
mod tests {
    //! Unit tests for the schema parsing helpers. They run outside a
    //! procedural-macro context, relying on `syn`/`proc_macro2`'s fallback
    //! token implementation.

    use super::*;

    /// Parses `src` as a Rust type for use as test input.
    fn ty(src: &str) -> syn::Type {
        syn::parse_str(src).expect("test type should parse")
    }

    #[test]
    fn parse_keytype_recognizes_key_forms() {
        let foo = quote::format_ident!("Foo");
        assert_eq!(
            parse_keytype(&ty("Key<Foo>")).unwrap(),
            Some(KeyType::Key(foo.clone()))
        );
        assert_eq!(
            parse_keytype(&ty("Option<Key<Foo>>")).unwrap(),
            Some(KeyType::OptionKey(foo.clone()))
        );
        assert_eq!(
            parse_keytype(&ty("KeySet<Foo>")).unwrap(),
            Some(KeyType::KeySet(foo))
        );
    }

    #[test]
    fn parse_keytype_ignores_plain_data() {
        assert_eq!(parse_keytype(&ty("u32")).unwrap(), None);
        assert_eq!(parse_keytype(&ty("Vec<u8>")).unwrap(), None);
        assert_eq!(parse_keytype(&ty("std::string::String")).unwrap(), None);
    }

    #[test]
    fn parse_keytype_rejects_malformed_keys() {
        assert!(parse_keytype(&ty("Key")).is_err());
        assert!(parse_keytype(&ty("Key<A, B>")).is_err());
        assert!(parse_keytype(&ty("KeySet<Vec<u8>>")).is_err());
    }

    #[test]
    fn schema_input_collects_items_and_rejects_generics() {
        let parsed: SchemaInput =
            syn::parse_str("type Db; struct A { x: u32 } pub enum E { X }").unwrap();
        assert_eq!(parsed.name, "Db");
        assert_eq!(parsed.structs.len(), 1);
        assert_eq!(parsed.enums.len(), 1);
        assert!(syn::parse_str::<SchemaInput>("type Db; struct A<T> { x: T }").is_err());
    }
}
18
19enum Item {
20 Struct(syn::ItemStruct),
21 Enum(syn::ItemEnum),
22}
23
24impl syn::parse::Parse for Item {
25 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
26 let mut attrs = input.call(syn::Attribute::parse_outer)?;
27 let ahead = input.fork();
28 let vis: syn::Visibility = ahead.parse()?;
29
30 let lookahead = ahead.lookahead1();
31 let mut item = if lookahead.peek(syn::Token![struct]) {
32 input.parse().map(Item::Struct)
33 } else if lookahead.peek(syn::Token![enum]) {
34 input.parse().map(Item::Enum)
35 } else {
36 Err(lookahead.error())
37 }?;
38
39 {
40 let (item_vis, item_attrs, generics) = match &mut item {
41 Item::Struct(item) => (&mut item.vis, &mut item.attrs, &item.generics),
42 Item::Enum(item) => (&mut item.vis, &mut item.attrs, &item.generics),
43 };
44 if generics.params.len() > 0 {
45 return Err(syn::Error::new_spanned(
46 generics,
47 "schema! does not support generic types.",
48 ));
49 }
50 attrs.extend(item_attrs.drain(..));
51 *item_attrs = attrs;
52 *item_vis = vis;
53 }
54
55 Ok(item)
56 }
57}
58
59#[derive(Debug)]
60struct SchemaInput {
61 name: syn::Ident,
62 structs: Vec<syn::ItemStruct>,
63 enums: Vec<syn::ItemEnum>,
64}
65
66impl syn::parse::Parse for SchemaInput {
67 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
68 input.parse::<syn::Token![type]>()?;
69 let name: syn::Ident = input.parse()?;
70 input.parse::<syn::Token![;]>()?;
71 let mut structs = Vec::new();
72 let mut enums = Vec::new();
73 while !input.is_empty() {
74 match input.parse()? {
75 Item::Struct(i) => structs.push(i),
76 Item::Enum(i) => enums.push(i),
77 }
78 }
79 Ok(SchemaInput {
80 name,
81 structs,
82 enums,
83 })
84 }
85}
86
/// Classified schema produced by `SchemaInput::process`, consumed by the
/// `schema!` code generator.
#[derive(Debug)]
struct SchemaOutput {
    /// Name of the generated database struct (from the `type Name;` header).
    name: syn::Ident,
    /// Structs with no key-typed fields, including all tuple and unit structs.
    pod_structs: Vec<syn::ItemStruct>,
    /// Every enum in the schema, re-declared with `pub` visibility.
    pod_enums: Vec<syn::ItemEnum>,
    /// Named-field structs with at least one `Key`, `Option<Key>` or `KeySet` field.
    key_structs: Vec<syn::ItemStruct>,
    /// Parallel to `key_structs`: each struct's key fields mapped to their `KeyType`.
    key_struct_maps: Vec<std::collections::HashMap<syn::Ident, KeyType>>,
    /// `process` currently always leaves this empty.
    key_enums: Vec<syn::ItemEnum>,
}
96
97#[derive(Debug, Eq, PartialEq)]
117enum KeyType {
118 Key(syn::Ident),
119 OptionKey(syn::Ident),
120 KeySet(syn::Ident),
121}
122
123impl KeyType {
124 fn key_to(&self) -> syn::Ident {
125 match self {
126 KeyType::Key(i) => i.clone(),
127 KeyType::OptionKey(i) => i.clone(),
128 KeyType::KeySet(i) => i.clone(),
129 }
130 }
131}
132
133fn first_of_type(t: &syn::Type) -> Option<(syn::Ident, syn::Type)> {
134 let p = if let syn::Type::Path(p) = t {
135 p
136 } else {
137 return None;
138 };
139 let path_count = p.path.segments.len();
140 if path_count != 1 {
141 return None;
142 }
143 let ident = p.path.segments.last().unwrap().clone().ident;
144 let path_only = p.path.segments.last().unwrap();
145 let args = if let syn::PathArguments::AngleBracketed(args) = &path_only.arguments {
146 args
147 } else {
148 return None;
149 };
150 if args.args.len() != 1 {
151 return None;
152 }
153 use syn::GenericArgument;
154 let t = if let GenericArgument::Type(t) = args.args.first()? {
155 t
156 } else {
157 return None;
158 };
159 Some((ident, t.clone()))
160}
161
162fn type_is_just_ident(t: &syn::Type) -> Option<syn::Ident> {
163 let p = if let syn::Type::Path(p) = t {
164 p
165 } else {
166 return None;
167 };
168 let path_count = p.path.segments.len();
169 if path_count != 1 {
170 return None;
171 }
172 let ident = p.path.segments.last().unwrap().clone().ident;
173 let path_only = p.path.segments.last().unwrap();
174 if path_only.arguments != syn::PathArguments::None {
175 return None;
176 }
177 Some(ident)
178}
179
180fn parse_keytype(t: &syn::Type) -> Result<Option<KeyType>, syn::Error> {
181 if let Some((key, t)) = first_of_type(&t) {
182 if key.to_string() == "Option" {
183 if let Some((key, t)) = first_of_type(&t) {
184 if key.to_string() == "Key" {
185 if let Some(i) = type_is_just_ident(&t) {
186 return Ok(Some(KeyType::OptionKey(i)));
187 } else {
188 return Err(syn::Error::new_spanned(
189 t,
190 "Key type should be a simple table name",
191 ));
192 }
193 }
194 }
195 } else if key.to_string() == "KeySet" {
196 if let Some(i) = type_is_just_ident(&t) {
197 return Ok(Some(KeyType::KeySet(i)));
198 } else {
199 return Err(syn::Error::new_spanned(
200 t,
201 "Key type should be a simple table name",
202 ));
203 }
204 }
205 }
206 if let syn::Type::Path(p) = t {
207 let path_count = p.path.segments.len();
208 if path_count == 1 {
211 let ident = p.path.segments.last().unwrap().clone().ident;
212 let path_only = p.path.segments.last().unwrap();
213 let name = ident.to_string();
214 if name == "Option" {
216 let args = path_only.clone().arguments;
217 println!("args are {:#?}", args);
218 unimplemented!()
219 } else {
220 if name == "Key" {
221 if let syn::PathArguments::AngleBracketed(args) = &path_only.arguments {
222 if args.args.len() != 1 {
223 return Err(syn::Error::new_spanned(
224 t,
225 "Key should have just one type argument",
226 ));
227 }
228 use syn::{GenericArgument, Type};
229 if let GenericArgument::Type(Type::Path(ap)) = args.args.first().unwrap() {
230 if ap.path.segments.len() != 1 {
231 return Err(syn::Error::new_spanned(
232 t,
233 "Key should have a simple type argument",
234 ));
235 }
236 let tp = ap.path.segments.first().unwrap();
237 if !tp.arguments.is_empty() {
238 Err(syn::Error::new_spanned(
239 tp.arguments.clone(),
240 "Key type should be a simple table name",
241 ))
242 } else {
243 let i = tp.ident.clone();
244 Ok(Some(KeyType::Key(i)))
249 }
250 } else {
251 Err(syn::Error::new_spanned(
252 t,
253 "Key should have a simple type argument",
254 ))
255 }
256 } else {
257 Err(syn::Error::new_spanned(t, "Key should be Key<ATableType>"))
258 }
259 } else {
260 Ok(None)
261 }
262 }
263 } else {
264 Ok(None)
265 }
266 } else {
267 Ok(None)
268 }
269}
270
271fn parse_fields(
272 f: &syn::FieldsNamed,
273) -> Result<std::collections::HashMap<syn::Ident, KeyType>, syn::Error> {
274 let mut keymap = std::collections::HashMap::new();
275 for n in f.named.iter() {
276 if let Some(kt) = parse_keytype(&n.ty)? {
277 keymap.insert(n.ident.clone().unwrap(), kt);
278 }
279 }
280 Ok(keymap)
281}
282
283impl SchemaInput {
284 fn process(&self) -> Result<SchemaOutput, syn::Error> {
285 let mut tables = std::collections::HashSet::new();
286 tables.extend(self.structs.iter().map(|x| x.ident.clone()));
287 tables.extend(self.enums.iter().map(|x| x.ident.clone()));
288
289 let mut pod_structs = Vec::new();
290 let mut key_structs = Vec::new();
291 let mut key_struct_maps = Vec::new();
292
293 for x in self.structs.iter().cloned() {
294 match &x.fields {
295 syn::Fields::Named(n) => {
296 let keymap = parse_fields(n)?;
297 if keymap.len() > 0 {
298 key_struct_maps.push(keymap);
299 key_structs.push(x);
300 } else {
301 pod_structs.push(x);
302 }
303 }
304 syn::Fields::Unnamed(_) => {
305 pod_structs.push(x);
306 }
307 syn::Fields::Unit => {
308 pod_structs.push(x);
309 }
310 }
311 }
312
313 let pod_enums: Vec<_> = self
314 .enums
315 .iter()
316 .map(|x| {
317 let mut x = x.clone();
318 x.vis = syn::Visibility::Public(syn::VisPublic {
319 pub_token: syn::Token!(pub)(x.span()),
320 });
321 x
322 })
323 .collect();
324 Ok(SchemaOutput {
325 name: self.name.clone(),
326 pod_structs,
327 key_structs,
328 key_struct_maps,
329 key_enums: Vec::new(),
330 pod_enums,
331 })
332 }
333}
334
335#[proc_macro]
336pub fn schema(raw_input: proc_macro::TokenStream) -> proc_macro::TokenStream {
337 use heck::SnakeCase;
339
340 let input: SchemaInput = syn::parse_macro_input!(raw_input as SchemaInput);
341 let output = match input.process() {
343 Err(e) => {
344 return e.to_compile_error().into();
345 }
346 Ok(v) => v,
347 };
348
349 let pod_structs = &output.pod_structs;
357 let key_structs = &output.key_structs;
358
359 let key_names: Vec<_> = key_structs
360 .iter()
361 .map(|x| quote::format_ident!("{}", x.ident.to_string().to_snake_case()))
362 .collect();
363
364 let mut reverse_references = std::collections::HashMap::new();
365 for (map, t) in output.key_struct_maps.iter().zip(key_structs.iter()) {
366 for (k, v) in map.iter() {
368 let kt = v.key_to();
369 if !reverse_references.contains_key(&kt) {
370 reverse_references.insert(kt.clone(), Vec::new());
371 }
372 reverse_references
373 .get_mut(&kt)
374 .unwrap()
375 .push((t.ident.clone(), k.clone()));
376 }
377 }
378 let mut pod_query_backrefs: Vec<Vec<(syn::Ident, syn::Ident)>> = Vec::new();
381 let pod_query_structs: Vec<syn::ItemStruct> = pod_structs
382 .iter()
383 .cloned()
384 .map(|mut x| {
385 let i = x.ident.clone();
386 let mut backrefs = Vec::new();
387 let mut backrefs_code = Vec::new();
388 if let Some(v) = reverse_references.get(&x.ident) {
389 for r in v.iter() {
390 let field = quote::format_ident!("{}_of", r.1.to_string().to_snake_case());
391 let t = &r.0;
392 backrefs.push((t.clone(), field.clone()));
393 let code = quote::quote! {
394 pub #field: KeySet<#t>,
395 };
396 backrefs_code.push(code);
398 }
399 }
400 pod_query_backrefs.push(backrefs);
401 x.ident = quote::format_ident!("{}Query", x.ident);
402 x.fields = syn::Fields::Named(syn::parse_quote! {{
403 __data: #i,
404 #(#backrefs_code)*
405 }});
406 x
407 })
408 .collect();
409 let pod_query_types: Vec<syn::PathSegment> = pod_query_structs
410 .iter()
411 .map(|x| {
412 let i = x.ident.clone();
413 syn::parse_quote! {#i}
414 })
415 .collect();
416 let pod_query_new: Vec<_> = pod_query_structs
417 .iter()
418 .zip(pod_query_backrefs.iter())
419 .map(|(x, br)| {
420 let i = &x.ident;
421 let backcode = br.iter().map(|(t, f)| {
422 quote::quote! {
423 #f: KeySet::<#t>::new(),
424 }
425 });
426 quote::quote! {
427 #i {
428 __data: value,
429 #(#backcode)*
430 }
431 }
432 })
433 .collect();
434
435 let pod_names: Vec<_> = pod_structs
436 .iter()
437 .map(|x| quote::format_ident!("{}", x.ident.to_string().to_snake_case()))
438 .collect();
439 let pod_inserts: Vec<_> = pod_structs
440 .iter()
441 .map(|x| quote::format_ident!("insert_{}", x.ident.to_string().to_snake_case()))
442 .collect();
443 let pod_lookups: Vec<_> = pod_structs
444 .iter()
445 .filter(|x| x.generics.params.len() == 0)
447 .map(|x| quote::format_ident!("lookup_{}", x.ident.to_string().to_snake_case()))
448 .collect();
449 let pod_lookup_hashes: Vec<_> = pod_structs
450 .iter()
451 .filter(|x| x.generics.params.len() == 0)
453 .map(|x| quote::format_ident!("hash_{}", x.ident.to_string().to_snake_case()))
454 .collect();
455 let pod_types: Vec<syn::PathSegment> = pod_structs
456 .iter()
457 .map(|x| {
458 let i = x.ident.clone();
459 syn::parse_quote! {#i}
460 })
461 .collect();
462
463 let mut key_query_backrefs: Vec<Vec<(syn::Ident, syn::Ident)>> = Vec::new();
464 let key_query_structs: Vec<_> = key_structs
465 .iter()
466 .cloned()
467 .map(|mut x| {
468 let i = x.ident.clone();
469 let mut backrefs = Vec::new();
470 let mut backrefs_code = Vec::new();
471 if let Some(v) = reverse_references.get(&x.ident) {
472 for r in v.iter() {
473 let field = quote::format_ident!("{}_of", r.1.to_string().to_snake_case());
474 let t = &r.0;
475 backrefs.push((t.clone(), field.clone()));
476 let code = quote::quote! {
477 pub #field: KeySet<#t>,
478 };
479 backrefs_code.push(code);
481 }
482 }
483 key_query_backrefs.push(backrefs);
484 x.ident = quote::format_ident!("{}Query", x.ident);
485 x.fields = syn::Fields::Named(syn::parse_quote! {{
486 __data: #i,
487 #(#backrefs_code)*
488 }});
489 x
490 })
491 .collect();
492 let key_query_types: Vec<syn::PathSegment> = key_query_structs
493 .iter()
494 .map(|x| {
495 let i = x.ident.clone();
496 let g = x.generics.clone();
497 syn::parse_quote! {#i#g}
498 })
499 .collect();
500
501 let key_inserts: Vec<_> = key_structs
502 .iter()
503 .map(|x| quote::format_ident!("insert_{}", x.ident.to_string().to_snake_case()))
504 .collect();
505 let key_insert_backrefs: Vec<_> = output
506 .key_struct_maps
507 .iter()
508 .enumerate()
509 .map(|(i, map)| {
510 let myname = &key_names[i];
511 let mut code = Vec::new();
512 let mut keys_and_types = map.iter().collect::<Vec<_>>();
517 keys_and_types.sort_by_key(|a| a.0);
518 for (k, v) in keys_and_types.into_iter() {
519 match v {
521 KeyType::Key(t) => {
522 let field = quote::format_ident!("{}", t.to_string().to_snake_case());
523 let rev = quote::format_ident!("{}_of", k.to_string().to_snake_case());
524 code.push(quote::quote! {
525 self.#field[self.#myname[idx].#k.0].#rev.insert(k);
526 });
527 }
528 KeyType::OptionKey(t) => {
529 let field = quote::format_ident!("{}", t.to_string().to_snake_case());
530 let rev = quote::format_ident!("{}_of", k.to_string().to_snake_case());
531 code.push(quote::quote! {
532 if let Some(idxk) = self.#myname[idx].#k {
533 self.#field[idxk.0].#rev.insert(k);
534 }
535 });
536 }
537 KeyType::KeySet(t) => {
538 let field = quote::format_ident!("{}", t.to_string().to_snake_case());
539 let rev = quote::format_ident!("{}_of", k.to_string().to_snake_case());
540 code.push(quote::quote! {
541 for idxk in self.#myname[idx].#k.iter() {
542 self.#field[idxk.0].#rev.insert(k);
543 }
544 });
545 }
546 }
547 }
548 quote::quote! {
549 #(#code)*
550 }
551 })
552 .collect();
553 let key_sets: Vec<_> = key_structs
554 .iter()
555 .map(|x| quote::format_ident!("set_{}", x.ident.to_string().to_snake_case()))
556 .collect();
557 let key_types: Vec<syn::PathSegment> = key_structs
558 .iter()
559 .map(|x| {
560 let i = x.ident.clone();
561 let g = x.generics.clone();
562 syn::parse_quote! {#i#g}
563 })
564 .collect();
565
566 let table_enums = output.pod_enums.iter();
568 let name = &input.name;
571 let output = quote::quote! {
573 trait Query: std::ops::Deref {
574 fn new(val: Self::Target) -> Self;
575 }
576 trait HasQuery {
577 type Query: Query<Target=Self>;
578 }
579 #(
580 #[repr(C)]
581 #[derive(Eq,PartialEq,Hash,Clone)]
582 #pod_structs
583 #[repr(C)]
584 #[derive(Eq,PartialEq,Hash,Clone)]
585 #pod_query_structs
587
588 impl std::ops::Deref for #pod_query_types {
589 type Target = #pod_types;
590 fn deref(&self) -> &Self::Target {
591 &self.__data
592 }
593 }
594 impl Query for #pod_query_types {
595 fn new(value: Self::Target) -> Self {
596 #pod_query_new
601 }
605 }
606 impl HasQuery for #pod_types {
607 type Query = #pod_query_types;
608 }
609 )*
610 #(
611 #[repr(C)]
612 #[derive(Clone)]
613 #key_structs
614 #[repr(C)]
615 #[derive(Clone)]
616 #key_query_structs
618
619 impl std::ops::Deref for #key_query_types {
620 type Target = #key_types;
621 fn deref(&self) -> &Self::Target {
622 unsafe { &*(self as *const Self as *const Self::Target) }
623 }
624 }
625 impl Query for #key_query_types {
626 fn new(value: Self::Target) -> Self {
627 let x = (value,
633 [0u8; std::mem::size_of::<Self>() - std::mem::size_of::<Self::Target>()]);
634 unsafe { std::mem::transmute(x) }
635 }
636 }
637 impl HasQuery for #key_types {
638 type Query = #key_query_types;
639 }
640 )*
641 #(
642 #[derive(Eq,PartialEq,Hash,Clone)]
643 #table_enums
644 )*
645
646 pub struct #name {
647 #(
648 pub #pod_names: Vec<#pod_query_types>,
649 )*
650 #(
651 pub #key_names: Vec<#key_query_types>,
652 )*
653 #(
654 pub #pod_lookup_hashes: std::collections::HashMap<#pod_types, usize>,
655 )*
656 }
657 impl #name {
658 pub fn new() -> Self {
660 #name {
661 #( #pod_names: Vec::new(), )*
662 #( #key_names: Vec::new(), )*
663 #(
664 #pod_lookup_hashes: std::collections::HashMap::new(),
665 )*
666 }
667 }
668 }
669
670 type Set64<K> = tinyset::Set64<K>;
671 type KeySet<T> = Set64<Key<T>>;
672
673 #[derive(Eq,PartialEq,Hash)]
674 pub struct Key<T>(usize, std::marker::PhantomData<T>);
675 impl<T> Copy for Key<T> {}
676 impl<T> Clone for Key<T> {
677 fn clone(&self) -> Self {
678 Key(self.0, self.1)
679 }
680 }
681 impl<T> tinyset::Fits64 for Key<T> {
682 unsafe fn from_u64(x: u64) -> Self {
683 Key(x as usize, std::marker::PhantomData)
684 }
685 fn to_u64(self) -> u64 {
686 self.0.to_u64()
687 }
688 }
689
690 impl #name {
691 #(
692 pub fn #pod_inserts(&mut self, datum: #pod_types) -> Key<#pod_types> {
693 let idx = self.#pod_names.len();
694 self.#pod_names.push(#pod_query_types::new(datum.clone()));
695 self.#pod_lookup_hashes.insert(datum, idx);
696 Key(idx, std::marker::PhantomData)
697 }
698 )*
699 #(
700 pub fn #key_inserts(&mut self, datum: #key_types) -> Key<#key_types> {
701 let idx = self.#key_names.len();
702 self.#key_names.push(#key_query_types::new(datum.clone()));
703 let k = Key(idx, std::marker::PhantomData);
704 #key_insert_backrefs
705 k
706 }
707 pub fn #key_sets(&mut self, k: Key<#key_types>, datum: #key_types) {
708 let old = std::mem::replace(&mut self.#key_names[k.0], #key_query_types::new(datum));
709 }
711 )*
712 #(
713 pub fn #pod_lookups(&self, datum: &#pod_types) -> Option<Key<#pod_types>> {
714 self.#pod_lookup_hashes.get(datum)
715 .map(|&i| Key(i, std::marker::PhantomData))
716 }
721 )*
722 }
723
724 #(
725 impl Key<#pod_types> {
726 pub fn d<'a,'b>(&'a self, database: &'b #name) -> &'b #pod_query_types {
727 &database.#pod_names[self.0]
728 }
729 }
730 )*
731 #(
732 impl Key<#key_types> {
733 pub fn d<'a,'b>(&'a self, database: &'b #name) -> &'b #key_query_types {
734 &database.#key_names[self.0]
735 }
736 }
737 )*
738 #(
739 impl std::ops::Index<Key<#key_types>> for #name {
740 type Output = #key_query_types;
741 fn index(&self, index: Key<#key_types>) -> &Self::Output {
742 &self.#key_names[index.0]
743 }
744 }
745 )*
746 #(
747 impl std::ops::Index<Key<#pod_types>> for #name {
748 type Output = #pod_query_types;
749 fn index(&self, index: Key<#pod_types>) -> &Self::Output {
750 &self.#pod_names[index.0]
751 }
752 }
753 )*
754 };
755 output.into()
757}