prolangkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse_macro_input, Data, DeriveInput, GenericArgument, PathArguments, PathSegment, Type,
5};
6
7#[proc_macro_derive(DbAccess)]
8pub fn db_access(tokens: TokenStream) -> TokenStream {
9    let input: DeriveInput = parse_macro_input!(tokens);
10
11    let Data::Struct(data) = input.data else {
12        panic!("must be a struct")
13    };
14
15    let name = input.ident;
16
17    let mut impls = vec![];
18
19    for field in data.fields {
20        let Some(field_name) = field.ident else {
21            panic!("must be a non-tuple struct")
22        };
23        let Type::Path(ty) = field.ty else { continue };
24        let Some(end) = ty.path.segments.last() else {
25            continue;
26        };
27        match &end.ident.to_string()[..] {
28            "Arena" => {
29                let Some(ty) = get_gen_ty(end) else { continue };
30
31                impls.push(quote! {
32                    impl DbAccess for Id<#ty> {
33                        type Input<'a> = #ty;
34                        type Output = #ty;
35
36                        fn get(self, db: &#name) -> &#ty {
37                            db.#field_name.get(self)
38                        }
39
40                        fn get_mut(self, db: &mut #name) -> &mut #ty {
41                            db.#field_name.get_mut(self)
42                        }
43
44                        fn insert(value: #ty, db: &mut #name) -> Id<#ty> {
45                            db.#field_name.insert(value)
46                        }
47                    }
48
49                    impl<'a> DbInput<'a> for #ty {
50                        type Id = Id<#ty>;
51                    }
52                });
53            }
54            "Mapping" => {
55                let Some((k, v)) = get_gen_tys(end) else { continue };
56
57                impls.push(quote! {
58                    impl DbAssocAccess for #v {
59                        type Key = #k;
60
61                        fn get_assoc(key: Id<Self::Key>, db: &#name) -> Option<&Self> {
62                            db.#field_name.get(key)
63                        }
64
65                        fn get_assoc_mut(key: Id<Self::Key>, db: &mut #name) -> Option<&mut Self> {
66                            db.#field_name.get_mut(key)
67                        }
68
69                        fn insert(key: Id<Self::Key>, value: Self, db: &mut #name) -> &mut #v {
70                            db.#field_name.insert(key, value)
71                        }
72                    }
73                });
74            }
75            "StringStore" => {
76                impls.push(quote! {
77                    impl DbAccess for Id<str> {
78                        type Input<'a> = &'a str;
79                        type Output = str;
80
81                        fn get(self, db: &#name) -> &str {
82                            db.#field_name.get(self)
83                        }
84
85                        fn get_mut(self, db: &mut #name) -> &mut str {
86                            unimplemented!()
87                        }
88
89                        fn insert(value: &str, db: &mut #name) -> Id<str> {
90                            db.#field_name.insert(value)
91                        }
92                    }
93
94                    impl<'a> DbInput<'a> for &'a str {
95                        type Id = Id<str>;
96                    }
97                })
98            }
99            _ => (),
100        }
101    }
102
103    quote! {
104        impl #name {
105            pub fn get<I: DbAccess>(&self, id: I) -> &I::Output {
106                id.get(self)
107            }
108
109            pub fn get_mut<I: DbAccess>(&mut self, id: I) -> &mut I::Output {
110                id.get_mut(self)
111            }
112
113            pub fn insert<'v, V: DbInput<'v>>(&mut self, value: V) -> Id<<V::Id as DbAccess>::Output> {
114                <V::Id as DbAccess>::insert(value, self)
115            }
116
117            pub fn get_assoc<V: DbAssocAccess>(&self, id: Id<V::Key>) -> Option<&V> {
118                V::get_assoc(id, self)
119            }
120
121            pub fn get_assoc_mut<V: DbAssocAccess>(&mut self, id: Id<V::Key>) -> Option<&mut V> {
122                V::get_assoc_mut(id, self)
123            }
124
125            pub fn assoc<V: DbAssocAccess>(&mut self, id: Id<V::Key>, value: V) -> &mut V {
126                V::insert(id, value, self)
127            }
128        }
129
130        pub trait DbAccess {
131            type Input<'a>;
132            type Output: ?Sized;
133
134            fn get(self, db: &#name) -> &Self::Output;
135
136            fn get_mut(self, db: &mut #name) -> &mut Self::Output;
137
138            fn insert(value: Self::Input<'_>, db: &mut #name) -> Id<Self::Output>;
139        }
140
141        pub trait DbInput<'a> {
142            type Id: DbAccess<Input<'a> = Self>;
143        }
144
145        pub trait DbAssocAccess {
146            type Key: ?Sized;
147
148            fn get_assoc(key: Id<Self::Key>, db: &#name) -> Option<&Self>;
149
150            fn get_assoc_mut(key: Id<Self::Key>, db: &mut #name) -> Option<&mut Self>;
151
152            fn insert(key: Id<Self::Key>, value: Self, db: &mut #name) -> &mut Self;
153        }
154
155        #(#impls)*
156    }
157    .into()
158}
159
160fn get_gen_ty(segment: &PathSegment) -> Option<&Type> {
161    let PathArguments::AngleBracketed(args) = &segment.arguments else {
162        return None;
163    };
164    if args.args.len() != 1 {
165        return None;
166    }
167    let GenericArgument::Type(ty) = args.args.first().unwrap() else {
168        return None;
169    };
170    Some(ty)
171}
172
173fn get_gen_tys(segment: &PathSegment) -> Option<(&Type, &Type)> {
174    let PathArguments::AngleBracketed(args) = &segment.arguments else {
175        return None;
176    };
177    if args.args.len() != 2 {
178        return None;
179    }
180    let GenericArgument::Type(ty1) = args.args.get(0).unwrap() else {
181        return None;
182    };
183    let GenericArgument::Type(ty2) = args.args.get(1).unwrap() else {
184        return None;
185    };
186    Some((ty1, ty2))
187}