Skip to main content

appdb_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse_macro_input, Attribute, Data, DeriveInput, Error, Field, Fields, GenericArgument,
5    PathArguments, Type, TypePath,
6};
7
8#[proc_macro_derive(Sensitive, attributes(secure))]
9pub fn derive_sensitive(input: TokenStream) -> TokenStream {
10    match derive_sensitive_impl(parse_macro_input!(input as DeriveInput)) {
11        Ok(tokens) => tokens.into(),
12        Err(err) => err.to_compile_error().into(),
13    }
14}
15
16#[proc_macro_derive(Store, attributes(unique))]
17pub fn derive_store(input: TokenStream) -> TokenStream {
18    match derive_store_impl(parse_macro_input!(input as DeriveInput)) {
19        Ok(tokens) => tokens.into(),
20        Err(err) => err.to_compile_error().into(),
21    }
22}
23
24fn derive_store_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
25    let struct_ident = input.ident;
26
27    let named_fields = match input.data {
28        Data::Struct(data) => match data.fields {
29            Fields::Named(fields) => fields.named,
30            _ => {
31                return Err(Error::new_spanned(
32                    struct_ident,
33                    "Store can only be derived for structs with named fields",
34                ))
35            }
36        },
37        _ => {
38            return Err(Error::new_spanned(
39                struct_ident,
40                "Store can only be derived for structs",
41            ))
42        }
43    };
44
45    let id_fields = named_fields
46        .iter()
47        .filter(|field| is_id_type(&field.ty))
48        .map(|field| field.ident.clone().expect("named field"))
49        .collect::<Vec<_>>();
50
51    let unique_fields = named_fields
52        .iter()
53        .filter(|field| has_unique_attr(&field.attrs))
54        .map(|field| field.ident.clone().expect("named field"))
55        .collect::<Vec<_>>();
56
57    if id_fields.len() > 1 {
58        return Err(Error::new_spanned(
59            struct_ident,
60            "Store supports at most one `Id` field for automatic HasId generation",
61        ));
62    }
63
64    let auto_has_id_impl = id_fields.first().map(|field| {
65        quote! {
66            impl ::appdb::model::meta::HasId for #struct_ident {
67                fn id(&self) -> ::surrealdb::types::RecordId {
68                    ::surrealdb::types::RecordId::new(
69                        <Self as ::appdb::model::meta::ModelMeta>::table_name(),
70                        self.#field.clone(),
71                    )
72                }
73            }
74        }
75    });
76
77    let unique_schema_impls = unique_fields.iter().map(|field| {
78        let field_name = field.to_string();
79        let index_name = format!("{}_{}_unique", to_snake_case(&struct_ident.to_string()), field_name);
80        let ddl = format!(
81            "DEFINE INDEX IF NOT EXISTS {index_name} ON {} FIELDS {field_name} UNIQUE;",
82            to_snake_case(&struct_ident.to_string())
83        );
84
85        quote! {
86            ::inventory::submit! {
87                ::appdb::model::schema::SchemaItem {
88                    ddl: #ddl,
89                }
90            }
91        }
92    });
93
94    Ok(quote! {
95        impl ::appdb::model::meta::ModelMeta for #struct_ident {
96            fn table_name() -> &'static str {
97                static TABLE_NAME: ::std::sync::OnceLock<&'static str> = ::std::sync::OnceLock::new();
98                TABLE_NAME.get_or_init(|| {
99                    let table = ::appdb::model::meta::default_table_name(stringify!(#struct_ident));
100                    ::appdb::model::meta::register_table(stringify!(#struct_ident), table)
101                })
102            }
103        }
104
105        #auto_has_id_impl
106
107        #( #unique_schema_impls )*
108
109        impl ::appdb::repository::Crud for #struct_ident {}
110
111        impl #struct_ident {
112            pub async fn get<T>(id: T) -> ::anyhow::Result<Self>
113            where
114                ::surrealdb::types::RecordIdKey: From<T>,
115                T: Send,
116            {
117                ::appdb::repository::Repo::<Self>::get(id).await
118            }
119
120            pub async fn list() -> ::anyhow::Result<::std::vec::Vec<Self>> {
121                ::appdb::repository::Repo::<Self>::list().await
122            }
123
124            pub async fn list_limit(count: i64) -> ::anyhow::Result<::std::vec::Vec<Self>> {
125                ::appdb::repository::Repo::<Self>::list_limit(count).await
126            }
127
128            pub async fn delete_all() -> ::anyhow::Result<()> {
129                ::appdb::repository::Repo::<Self>::delete_all().await
130            }
131
132            pub async fn find_one_id(
133                k: &str,
134                v: &str,
135            ) -> ::anyhow::Result<::surrealdb::types::RecordId> {
136                ::appdb::repository::Repo::<Self>::find_one_id(k, v).await
137            }
138
139            pub async fn list_record_ids() -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>> {
140                ::appdb::repository::Repo::<Self>::list_record_ids().await
141            }
142
143            pub async fn create_at(
144                id: ::surrealdb::types::RecordId,
145                data: Self,
146            ) -> ::anyhow::Result<Self> {
147                ::appdb::repository::Repo::<Self>::create_at(id, data).await
148            }
149
150            pub async fn upsert_at(
151                id: ::surrealdb::types::RecordId,
152                data: Self,
153            ) -> ::anyhow::Result<Self> {
154                ::appdb::repository::Repo::<Self>::upsert_at(id, data).await
155            }
156
157            pub async fn update_at(
158                self,
159                id: ::surrealdb::types::RecordId,
160            ) -> ::anyhow::Result<Self> {
161                ::appdb::repository::Repo::<Self>::update_at(id, self).await
162            }
163
164            pub async fn delete<T>(id: T) -> ::anyhow::Result<()>
165            where
166                ::surrealdb::types::RecordIdKey: From<T>,
167                T: Send,
168            {
169                ::appdb::repository::Repo::<Self>::delete(id).await
170            }
171        }
172    })
173}
174
175fn derive_sensitive_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
176    let struct_ident = input.ident;
177    let encrypted_ident = format_ident!("Encrypted{}", struct_ident);
178    let vis = input.vis;
179
180    let named_fields = match input.data {
181        Data::Struct(data) => match data.fields {
182            Fields::Named(fields) => fields.named,
183            _ => {
184                return Err(Error::new_spanned(
185                    struct_ident,
186                    "Sensitive can only be derived for structs with named fields",
187                ))
188            }
189        },
190        _ => {
191            return Err(Error::new_spanned(
192                struct_ident,
193                "Sensitive can only be derived for structs",
194            ))
195        }
196    };
197
198    let mut secure_field_count = 0usize;
199    let mut encrypted_fields = Vec::new();
200    let mut encrypt_assignments = Vec::new();
201    let mut decrypt_assignments = Vec::new();
202
203    for field in named_fields.iter() {
204        let ident = field.ident.clone().expect("named field");
205        let field_vis = field.vis.clone();
206        let secure = has_secure_attr(&field.attrs);
207
208        if secure {
209            secure_field_count += 1;
210            let secure_kind = secure_kind(field)?;
211            let encrypted_ty = secure_kind.encrypted_type();
212            let encrypt_expr = secure_kind.encrypt_expr(&ident);
213            let decrypt_expr = secure_kind.decrypt_expr(&ident);
214            encrypted_fields.push(quote! { #field_vis #ident: #encrypted_ty });
215            encrypt_assignments.push(quote! { #ident: #encrypt_expr });
216            decrypt_assignments.push(quote! { #ident: #decrypt_expr });
217        } else {
218            let ty = field.ty.clone();
219            encrypted_fields.push(quote! { #field_vis #ident: #ty });
220            encrypt_assignments.push(quote! { #ident: self.#ident.clone() });
221            decrypt_assignments.push(quote! { #ident: encrypted.#ident.clone() });
222        }
223    }
224
225    if secure_field_count == 0 {
226        return Err(Error::new_spanned(
227            struct_ident,
228            "Sensitive requires at least one #[secure] field",
229        ));
230    }
231
232    Ok(quote! {
233        #[derive(
234            Debug,
235            Clone,
236            ::serde::Serialize,
237            ::serde::Deserialize,
238            ::surrealdb::types::SurrealValue,
239        )]
240        #vis struct #encrypted_ident {
241            #( #encrypted_fields, )*
242        }
243
244        impl ::appdb::Sensitive for #struct_ident {
245            type Encrypted = #encrypted_ident;
246
247            fn encrypt(
248                &self,
249                context: &::appdb::crypto::CryptoContext,
250            ) -> ::std::result::Result<Self::Encrypted, ::appdb::crypto::CryptoError> {
251                ::std::result::Result::Ok(#encrypted_ident {
252                    #( #encrypt_assignments, )*
253                })
254            }
255
256            fn decrypt(
257                encrypted: &Self::Encrypted,
258                context: &::appdb::crypto::CryptoContext,
259            ) -> ::std::result::Result<Self, ::appdb::crypto::CryptoError> {
260                ::std::result::Result::Ok(Self {
261                    #( #decrypt_assignments, )*
262                })
263            }
264        }
265
266        impl #struct_ident {
267            pub fn encrypt(
268                &self,
269                context: &::appdb::crypto::CryptoContext,
270            ) -> ::std::result::Result<#encrypted_ident, ::appdb::crypto::CryptoError> {
271                <Self as ::appdb::Sensitive>::encrypt(self, context)
272            }
273        }
274
275        impl #encrypted_ident {
276            pub fn decrypt(
277                &self,
278                context: &::appdb::crypto::CryptoContext,
279            ) -> ::std::result::Result<#struct_ident, ::appdb::crypto::CryptoError> {
280                <#struct_ident as ::appdb::Sensitive>::decrypt(self, context)
281            }
282        }
283    })
284}
285
286fn has_secure_attr(attrs: &[Attribute]) -> bool {
287    attrs.iter().any(|attr| attr.path().is_ident("secure"))
288}
289
290fn has_unique_attr(attrs: &[Attribute]) -> bool {
291    attrs.iter().any(|attr| attr.path().is_ident("unique"))
292}
293
294enum SecureKind {
295    String,
296    OptionString,
297}
298
299impl SecureKind {
300    fn encrypted_type(&self) -> proc_macro2::TokenStream {
301        match self {
302            SecureKind::String => quote! { ::std::vec::Vec<u8> },
303            SecureKind::OptionString => quote! { ::std::option::Option<::std::vec::Vec<u8>> },
304        }
305    }
306
307    fn encrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
308        match self {
309            SecureKind::String => {
310                quote! { ::appdb::crypto::encrypt_string(&self.#ident, context)? }
311            }
312            SecureKind::OptionString => {
313                quote! { ::appdb::crypto::encrypt_optional_string(&self.#ident, context)? }
314            }
315        }
316    }
317
318    fn decrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
319        match self {
320            SecureKind::String => {
321                quote! { ::appdb::crypto::decrypt_string(&encrypted.#ident, context)? }
322            }
323            SecureKind::OptionString => {
324                quote! { ::appdb::crypto::decrypt_optional_string(&encrypted.#ident, context)? }
325            }
326        }
327    }
328}
329
330fn secure_kind(field: &Field) -> syn::Result<SecureKind> {
331    if is_string_type(&field.ty) {
332        return Ok(SecureKind::String);
333    }
334
335    if let Some(inner) = option_inner_type(&field.ty) {
336        if is_string_type(inner) {
337            return Ok(SecureKind::OptionString);
338        }
339    }
340
341    Err(Error::new_spanned(
342        &field.ty,
343        "#[secure] currently supports only String and Option<String>",
344    ))
345}
346
347fn is_string_type(ty: &Type) -> bool {
348    match ty {
349        Type::Path(TypePath { path, .. }) => path.is_ident("String"),
350        _ => false,
351    }
352}
353
354fn is_id_type(ty: &Type) -> bool {
355    match ty {
356        Type::Path(TypePath { path, .. }) => path.segments.last().is_some_and(|segment| {
357            let ident = segment.ident.to_string();
358            ident == "Id"
359        }),
360        _ => false,
361    }
362}
363
364fn option_inner_type(ty: &Type) -> Option<&Type> {
365    let Type::Path(TypePath { path, .. }) = ty else {
366        return None;
367    };
368    let segment = path.segments.last()?;
369    if segment.ident != "Option" {
370        return None;
371    }
372    let PathArguments::AngleBracketed(args) = &segment.arguments else {
373        return None;
374    };
375    let GenericArgument::Type(inner) = args.args.first()? else {
376        return None;
377    };
378    Some(inner)
379}
380
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an ASCII uppercase letter only when the previous
/// character was an ASCII lowercase letter or an ASCII digit, so runs of capitals
/// collapse without separators (`HTTPServer` -> `httpserver`). Only ASCII uppercase
/// letters are lowercased; every other character is copied through unchanged.
fn to_snake_case(input: &str) -> String {
    let mut snake = String::with_capacity(input.len() + 4);
    let mut boundary_possible = false;

    for c in input.chars() {
        if c.is_ascii_uppercase() && boundary_possible {
            snake.push('_');
        }
        // `to_ascii_lowercase` leaves non-uppercase and non-ASCII characters untouched.
        snake.push(c.to_ascii_lowercase());
        boundary_possible = c.is_ascii_lowercase() || c.is_ascii_digit();
    }

    snake
}