1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse_macro_input, Data, DeriveInput, GenericArgument, PathArguments, PathSegment, Type,
5};
6
7#[proc_macro_derive(DbAccess)]
8pub fn db_access(tokens: TokenStream) -> TokenStream {
9 let input: DeriveInput = parse_macro_input!(tokens);
10
11 let Data::Struct(data) = input.data else {
12 panic!("must be a struct")
13 };
14
15 let name = input.ident;
16
17 let mut impls = vec![];
18
19 for field in data.fields {
20 let Some(field_name) = field.ident else {
21 panic!("must be a non-tuple struct")
22 };
23 let Type::Path(ty) = field.ty else { continue };
24 let Some(end) = ty.path.segments.last() else {
25 continue;
26 };
27 match &end.ident.to_string()[..] {
28 "Arena" => {
29 let Some(ty) = get_gen_ty(end) else { continue };
30
31 impls.push(quote! {
32 impl DbAccess for Id<#ty> {
33 type Input<'a> = #ty;
34 type Output = #ty;
35
36 fn get(self, db: &#name) -> &#ty {
37 db.#field_name.get(self)
38 }
39
40 fn get_mut(self, db: &mut #name) -> &mut #ty {
41 db.#field_name.get_mut(self)
42 }
43
44 fn insert(value: #ty, db: &mut #name) -> Id<#ty> {
45 db.#field_name.insert(value)
46 }
47 }
48
49 impl<'a> DbInput<'a> for #ty {
50 type Id = Id<#ty>;
51 }
52 });
53 }
54 "Mapping" => {
55 let Some((k, v)) = get_gen_tys(end) else { continue };
56
57 impls.push(quote! {
58 impl DbAssocAccess for #v {
59 type Key = #k;
60
61 fn get_assoc(key: Id<Self::Key>, db: &#name) -> Option<&Self> {
62 db.#field_name.get(key)
63 }
64
65 fn get_assoc_mut(key: Id<Self::Key>, db: &mut #name) -> Option<&mut Self> {
66 db.#field_name.get_mut(key)
67 }
68
69 fn insert(key: Id<Self::Key>, value: Self, db: &mut #name) -> &mut #v {
70 db.#field_name.insert(key, value)
71 }
72 }
73 });
74 }
75 "StringStore" => {
76 impls.push(quote! {
77 impl DbAccess for Id<str> {
78 type Input<'a> = &'a str;
79 type Output = str;
80
81 fn get(self, db: &#name) -> &str {
82 db.#field_name.get(self)
83 }
84
85 fn get_mut(self, db: &mut #name) -> &mut str {
86 unimplemented!()
87 }
88
89 fn insert(value: &str, db: &mut #name) -> Id<str> {
90 db.#field_name.insert(value)
91 }
92 }
93
94 impl<'a> DbInput<'a> for &'a str {
95 type Id = Id<str>;
96 }
97 })
98 }
99 _ => (),
100 }
101 }
102
103 quote! {
104 impl #name {
105 pub fn get<I: DbAccess>(&self, id: I) -> &I::Output {
106 id.get(self)
107 }
108
109 pub fn get_mut<I: DbAccess>(&mut self, id: I) -> &mut I::Output {
110 id.get_mut(self)
111 }
112
113 pub fn insert<'v, V: DbInput<'v>>(&mut self, value: V) -> Id<<V::Id as DbAccess>::Output> {
114 <V::Id as DbAccess>::insert(value, self)
115 }
116
117 pub fn get_assoc<V: DbAssocAccess>(&self, id: Id<V::Key>) -> Option<&V> {
118 V::get_assoc(id, self)
119 }
120
121 pub fn get_assoc_mut<V: DbAssocAccess>(&mut self, id: Id<V::Key>) -> Option<&mut V> {
122 V::get_assoc_mut(id, self)
123 }
124
125 pub fn assoc<V: DbAssocAccess>(&mut self, id: Id<V::Key>, value: V) -> &mut V {
126 V::insert(id, value, self)
127 }
128 }
129
130 pub trait DbAccess {
131 type Input<'a>;
132 type Output: ?Sized;
133
134 fn get(self, db: &#name) -> &Self::Output;
135
136 fn get_mut(self, db: &mut #name) -> &mut Self::Output;
137
138 fn insert(value: Self::Input<'_>, db: &mut #name) -> Id<Self::Output>;
139 }
140
141 pub trait DbInput<'a> {
142 type Id: DbAccess<Input<'a> = Self>;
143 }
144
145 pub trait DbAssocAccess {
146 type Key: ?Sized;
147
148 fn get_assoc(key: Id<Self::Key>, db: &#name) -> Option<&Self>;
149
150 fn get_assoc_mut(key: Id<Self::Key>, db: &mut #name) -> Option<&mut Self>;
151
152 fn insert(key: Id<Self::Key>, value: Self, db: &mut #name) -> &mut Self;
153 }
154
155 #(#impls)*
156 }
157 .into()
158}
159
160fn get_gen_ty(segment: &PathSegment) -> Option<&Type> {
161 let PathArguments::AngleBracketed(args) = &segment.arguments else {
162 return None;
163 };
164 if args.args.len() != 1 {
165 return None;
166 }
167 let GenericArgument::Type(ty) = args.args.first().unwrap() else {
168 return None;
169 };
170 Some(ty)
171}
172
173fn get_gen_tys(segment: &PathSegment) -> Option<(&Type, &Type)> {
174 let PathArguments::AngleBracketed(args) = &segment.arguments else {
175 return None;
176 };
177 if args.args.len() != 2 {
178 return None;
179 }
180 let GenericArgument::Type(ty1) = args.args.get(0).unwrap() else {
181 return None;
182 };
183 let GenericArgument::Type(ty2) = args.args.get(1).unwrap() else {
184 return None;
185 };
186 Some((ty1, ty2))
187}