Skip to main content

appdb_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse_macro_input, Attribute, Data, DeriveInput, Error, Field, Fields, GenericArgument,
5    PathArguments, Type, TypePath,
6};
7
8#[proc_macro_derive(Sensitive, attributes(secure))]
9pub fn derive_sensitive(input: TokenStream) -> TokenStream {
10    match derive_sensitive_impl(parse_macro_input!(input as DeriveInput)) {
11        Ok(tokens) => tokens.into(),
12        Err(err) => err.to_compile_error().into(),
13    }
14}
15
16#[proc_macro_derive(Store, attributes(unique))]
17pub fn derive_store(input: TokenStream) -> TokenStream {
18    match derive_store_impl(parse_macro_input!(input as DeriveInput)) {
19        Ok(tokens) => tokens.into(),
20        Err(err) => err.to_compile_error().into(),
21    }
22}
23
24#[proc_macro_derive(Relation, attributes(relation))]
25pub fn derive_relation(input: TokenStream) -> TokenStream {
26    match derive_relation_impl(parse_macro_input!(input as DeriveInput)) {
27        Ok(tokens) => tokens.into(),
28        Err(err) => err.to_compile_error().into(),
29    }
30}
31
32fn derive_store_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
33    let struct_ident = input.ident;
34
35    let named_fields = match input.data {
36        Data::Struct(data) => match data.fields {
37            Fields::Named(fields) => fields.named,
38            _ => {
39                return Err(Error::new_spanned(
40                    struct_ident,
41                    "Store can only be derived for structs with named fields",
42                ))
43            }
44        },
45        _ => {
46            return Err(Error::new_spanned(
47                struct_ident,
48                "Store can only be derived for structs",
49            ))
50        }
51    };
52
53    let id_fields = named_fields
54        .iter()
55        .filter(|field| is_id_type(&field.ty))
56        .map(|field| field.ident.clone().expect("named field"))
57        .collect::<Vec<_>>();
58
59    let unique_fields = named_fields
60        .iter()
61        .filter(|field| has_unique_attr(&field.attrs))
62        .map(|field| field.ident.clone().expect("named field"))
63        .collect::<Vec<_>>();
64
65    if id_fields.len() > 1 {
66        return Err(Error::new_spanned(
67            struct_ident,
68            "Store supports at most one `Id` field for automatic HasId generation",
69        ));
70    }
71
72    let auto_has_id_impl = id_fields.first().map(|field| {
73        quote! {
74            impl ::appdb::model::meta::HasId for #struct_ident {
75                fn id(&self) -> ::surrealdb::types::RecordId {
76                    ::surrealdb::types::RecordId::new(
77                        <Self as ::appdb::model::meta::ModelMeta>::table_name(),
78                        self.#field.clone(),
79                    )
80                }
81            }
82        }
83    });
84
85    let resolve_record_id_impl = if let Some(field) = id_fields.first() {
86        quote! {
87            #[::async_trait::async_trait]
88            impl ::appdb::model::meta::ResolveRecordId for #struct_ident {
89                async fn resolve_record_id(&self) -> ::anyhow::Result<::surrealdb::types::RecordId> {
90                    Ok(::surrealdb::types::RecordId::new(
91                        <Self as ::appdb::model::meta::ModelMeta>::table_name(),
92                        self.#field.clone(),
93                    ))
94                }
95            }
96        }
97    } else {
98        quote! {
99            #[::async_trait::async_trait]
100            impl ::appdb::model::meta::ResolveRecordId for #struct_ident {
101                async fn resolve_record_id(&self) -> ::anyhow::Result<::surrealdb::types::RecordId> {
102                    ::appdb::repository::Repo::<Self>::find_unique_id_for(self).await
103                }
104            }
105        }
106    };
107
108    let unique_schema_impls = unique_fields.iter().map(|field| {
109        let field_name = field.to_string();
110        let index_name = format!("{}_{}_unique", to_snake_case(&struct_ident.to_string()), field_name);
111        let ddl = format!(
112            "DEFINE INDEX IF NOT EXISTS {index_name} ON {} FIELDS {field_name} UNIQUE;",
113            to_snake_case(&struct_ident.to_string())
114        );
115
116        quote! {
117            ::inventory::submit! {
118                ::appdb::model::schema::SchemaItem {
119                    ddl: #ddl,
120                }
121            }
122        }
123    });
124
125    let lookup_fields = if unique_fields.is_empty() {
126        named_fields
127            .iter()
128            .filter_map(|field| {
129                let ident = field.ident.as_ref()?;
130                if ident == "id" {
131                    None
132                } else {
133                    Some(ident.to_string())
134                }
135            })
136            .collect::<Vec<_>>()
137    } else {
138        unique_fields.iter().map(|field| field.to_string()).collect::<Vec<_>>()
139    };
140    let lookup_field_literals = lookup_fields.iter().map(|field| quote! { #field });
141
142    Ok(quote! {
143        impl ::appdb::model::meta::ModelMeta for #struct_ident {
144            fn table_name() -> &'static str {
145                static TABLE_NAME: ::std::sync::OnceLock<&'static str> = ::std::sync::OnceLock::new();
146                TABLE_NAME.get_or_init(|| {
147                    let table = ::appdb::model::meta::default_table_name(stringify!(#struct_ident));
148                    ::appdb::model::meta::register_table(stringify!(#struct_ident), table)
149                })
150            }
151        }
152
153        impl ::appdb::model::meta::UniqueLookupMeta for #struct_ident {
154            fn lookup_fields() -> &'static [&'static str] {
155                &[ #( #lookup_field_literals ),* ]
156            }
157        }
158
159        #auto_has_id_impl
160        #resolve_record_id_impl
161
162        #( #unique_schema_impls )*
163
164        impl ::appdb::repository::Crud for #struct_ident {}
165
166        impl #struct_ident {
167            pub async fn get<T>(id: T) -> ::anyhow::Result<Self>
168            where
169                ::surrealdb::types::RecordIdKey: From<T>,
170                T: Send,
171            {
172                ::appdb::repository::Repo::<Self>::get(id).await
173            }
174
175            pub async fn list() -> ::anyhow::Result<::std::vec::Vec<Self>> {
176                ::appdb::repository::Repo::<Self>::list().await
177            }
178
179            pub async fn list_limit(count: i64) -> ::anyhow::Result<::std::vec::Vec<Self>> {
180                ::appdb::repository::Repo::<Self>::list_limit(count).await
181            }
182
183            pub async fn delete_all() -> ::anyhow::Result<()> {
184                ::appdb::repository::Repo::<Self>::delete_all().await
185            }
186
187            pub async fn find_one_id(
188                k: &str,
189                v: &str,
190            ) -> ::anyhow::Result<::surrealdb::types::RecordId> {
191                ::appdb::repository::Repo::<Self>::find_one_id(k, v).await
192            }
193
194            pub async fn list_record_ids() -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>> {
195                ::appdb::repository::Repo::<Self>::list_record_ids().await
196            }
197
198            pub async fn create_at(
199                id: ::surrealdb::types::RecordId,
200                data: Self,
201            ) -> ::anyhow::Result<Self> {
202                ::appdb::repository::Repo::<Self>::create_at(id, data).await
203            }
204
205            pub async fn upsert_at(
206                id: ::surrealdb::types::RecordId,
207                data: Self,
208            ) -> ::anyhow::Result<Self> {
209                ::appdb::repository::Repo::<Self>::upsert_at(id, data).await
210            }
211
212            pub async fn update_at(
213                self,
214                id: ::surrealdb::types::RecordId,
215            ) -> ::anyhow::Result<Self> {
216                ::appdb::repository::Repo::<Self>::update_at(id, self).await
217            }
218
219            pub async fn delete<T>(id: T) -> ::anyhow::Result<()>
220            where
221                ::surrealdb::types::RecordIdKey: From<T>,
222                T: Send,
223            {
224                ::appdb::repository::Repo::<Self>::delete(id).await
225            }
226        }
227    })
228}
229
230fn derive_relation_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
231    let struct_ident = input.ident;
232    let relation_name = relation_name_override(&input.attrs)?
233        .unwrap_or_else(|| to_snake_case(&struct_ident.to_string()));
234
235    match input.data {
236        Data::Struct(data) => match data.fields {
237            Fields::Unit | Fields::Named(_) => {}
238            _ => {
239                return Err(Error::new_spanned(
240                    struct_ident,
241                    "Relation can only be derived for unit structs or structs with named fields",
242                ))
243            }
244        },
245        _ => {
246            return Err(Error::new_spanned(
247                struct_ident,
248                "Relation can only be derived for structs",
249            ))
250        }
251    }
252
253    Ok(quote! {
254        impl ::appdb::model::relation::RelationMeta for #struct_ident {
255            fn relation_name() -> &'static str {
256                static REL_NAME: ::std::sync::OnceLock<&'static str> = ::std::sync::OnceLock::new();
257                REL_NAME.get_or_init(|| ::appdb::model::relation::register_relation(#relation_name))
258            }
259        }
260
261        impl #struct_ident {
262            pub async fn relate<A, B>(a: &A, b: &B) -> ::anyhow::Result<()>
263            where
264                A: ::appdb::model::meta::ResolveRecordId + Send + Sync,
265                B: ::appdb::model::meta::ResolveRecordId + Send + Sync,
266            {
267                ::appdb::graph::relate_at(a.resolve_record_id().await?, b.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name()).await
268            }
269
270            pub async fn unrelate<A, B>(a: &A, b: &B) -> ::anyhow::Result<()>
271            where
272                A: ::appdb::model::meta::ResolveRecordId + Send + Sync,
273                B: ::appdb::model::meta::ResolveRecordId + Send + Sync,
274            {
275                ::appdb::graph::unrelate_at(a.resolve_record_id().await?, b.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name()).await
276            }
277
278            pub async fn out_ids<A>(a: &A, out_table: &str) -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>>
279            where
280                A: ::appdb::model::meta::ResolveRecordId + Send + Sync,
281            {
282                ::appdb::graph::out_ids(a.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name(), out_table).await
283            }
284
285            pub async fn in_ids<B>(b: &B, in_table: &str) -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>>
286            where
287                B: ::appdb::model::meta::ResolveRecordId + Send + Sync,
288            {
289                ::appdb::graph::in_ids(b.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name(), in_table).await
290            }
291        }
292    })
293}
294
295fn derive_sensitive_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
296    let struct_ident = input.ident;
297    let encrypted_ident = format_ident!("Encrypted{}", struct_ident);
298    let vis = input.vis;
299
300    let named_fields = match input.data {
301        Data::Struct(data) => match data.fields {
302            Fields::Named(fields) => fields.named,
303            _ => {
304                return Err(Error::new_spanned(
305                    struct_ident,
306                    "Sensitive can only be derived for structs with named fields",
307                ))
308            }
309        },
310        _ => {
311            return Err(Error::new_spanned(
312                struct_ident,
313                "Sensitive can only be derived for structs",
314            ))
315        }
316    };
317
318    let mut secure_field_count = 0usize;
319    let mut encrypted_fields = Vec::new();
320    let mut encrypt_assignments = Vec::new();
321    let mut decrypt_assignments = Vec::new();
322
323    for field in named_fields.iter() {
324        let ident = field.ident.clone().expect("named field");
325        let field_vis = field.vis.clone();
326        let secure = has_secure_attr(&field.attrs);
327
328        if secure {
329            secure_field_count += 1;
330            let secure_kind = secure_kind(field)?;
331            let encrypted_ty = secure_kind.encrypted_type();
332            let encrypt_expr = secure_kind.encrypt_expr(&ident);
333            let decrypt_expr = secure_kind.decrypt_expr(&ident);
334            encrypted_fields.push(quote! { #field_vis #ident: #encrypted_ty });
335            encrypt_assignments.push(quote! { #ident: #encrypt_expr });
336            decrypt_assignments.push(quote! { #ident: #decrypt_expr });
337        } else {
338            let ty = field.ty.clone();
339            encrypted_fields.push(quote! { #field_vis #ident: #ty });
340            encrypt_assignments.push(quote! { #ident: self.#ident.clone() });
341            decrypt_assignments.push(quote! { #ident: encrypted.#ident.clone() });
342        }
343    }
344
345    if secure_field_count == 0 {
346        return Err(Error::new_spanned(
347            struct_ident,
348            "Sensitive requires at least one #[secure] field",
349        ));
350    }
351
352    Ok(quote! {
353        #[derive(
354            Debug,
355            Clone,
356            ::serde::Serialize,
357            ::serde::Deserialize,
358            ::surrealdb::types::SurrealValue,
359        )]
360        #vis struct #encrypted_ident {
361            #( #encrypted_fields, )*
362        }
363
364        impl ::appdb::Sensitive for #struct_ident {
365            type Encrypted = #encrypted_ident;
366
367            fn encrypt(
368                &self,
369                context: &::appdb::crypto::CryptoContext,
370            ) -> ::std::result::Result<Self::Encrypted, ::appdb::crypto::CryptoError> {
371                ::std::result::Result::Ok(#encrypted_ident {
372                    #( #encrypt_assignments, )*
373                })
374            }
375
376            fn decrypt(
377                encrypted: &Self::Encrypted,
378                context: &::appdb::crypto::CryptoContext,
379            ) -> ::std::result::Result<Self, ::appdb::crypto::CryptoError> {
380                ::std::result::Result::Ok(Self {
381                    #( #decrypt_assignments, )*
382                })
383            }
384        }
385
386        impl #struct_ident {
387            pub fn encrypt(
388                &self,
389                context: &::appdb::crypto::CryptoContext,
390            ) -> ::std::result::Result<#encrypted_ident, ::appdb::crypto::CryptoError> {
391                <Self as ::appdb::Sensitive>::encrypt(self, context)
392            }
393        }
394
395        impl #encrypted_ident {
396            pub fn decrypt(
397                &self,
398                context: &::appdb::crypto::CryptoContext,
399            ) -> ::std::result::Result<#struct_ident, ::appdb::crypto::CryptoError> {
400                <#struct_ident as ::appdb::Sensitive>::decrypt(self, context)
401            }
402        }
403    })
404}
405
406fn has_secure_attr(attrs: &[Attribute]) -> bool {
407    attrs.iter().any(|attr| attr.path().is_ident("secure"))
408}
409
410fn has_unique_attr(attrs: &[Attribute]) -> bool {
411    attrs.iter().any(|attr| attr.path().is_ident("unique"))
412}
413
414fn relation_name_override(attrs: &[Attribute]) -> syn::Result<Option<String>> {
415    for attr in attrs {
416        if !attr.path().is_ident("relation") {
417            continue;
418        }
419
420        let mut name = None;
421        attr.parse_nested_meta(|meta| {
422            if meta.path.is_ident("name") {
423                let value = meta.value()?;
424                let literal: syn::LitStr = value.parse()?;
425                name = Some(literal.value());
426                Ok(())
427            } else {
428                Err(meta.error("unsupported relation attribute"))
429            }
430        })?;
431        return Ok(name);
432    }
433
434    Ok(None)
435}
436
/// Plaintext field types that `#[secure]` knows how to encrypt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SecureKind {
    /// A `String` field, stored encrypted as `Vec<u8>`.
    String,
    /// An `Option<String>` field, stored encrypted as `Option<Vec<u8>>`.
    OptionString,
}
441
442impl SecureKind {
443    fn encrypted_type(&self) -> proc_macro2::TokenStream {
444        match self {
445            SecureKind::String => quote! { ::std::vec::Vec<u8> },
446            SecureKind::OptionString => quote! { ::std::option::Option<::std::vec::Vec<u8>> },
447        }
448    }
449
450    fn encrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
451        match self {
452            SecureKind::String => {
453                quote! { ::appdb::crypto::encrypt_string(&self.#ident, context)? }
454            }
455            SecureKind::OptionString => {
456                quote! { ::appdb::crypto::encrypt_optional_string(&self.#ident, context)? }
457            }
458        }
459    }
460
461    fn decrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
462        match self {
463            SecureKind::String => {
464                quote! { ::appdb::crypto::decrypt_string(&encrypted.#ident, context)? }
465            }
466            SecureKind::OptionString => {
467                quote! { ::appdb::crypto::decrypt_optional_string(&encrypted.#ident, context)? }
468            }
469        }
470    }
471}
472
473fn secure_kind(field: &Field) -> syn::Result<SecureKind> {
474    if is_string_type(&field.ty) {
475        return Ok(SecureKind::String);
476    }
477
478    if let Some(inner) = option_inner_type(&field.ty) {
479        if is_string_type(inner) {
480            return Ok(SecureKind::OptionString);
481        }
482    }
483
484    Err(Error::new_spanned(
485        &field.ty,
486        "#[secure] currently supports only String and Option<String>",
487    ))
488}
489
490fn is_string_type(ty: &Type) -> bool {
491    match ty {
492        Type::Path(TypePath { path, .. }) => path.is_ident("String"),
493        _ => false,
494    }
495}
496
497fn is_id_type(ty: &Type) -> bool {
498    match ty {
499        Type::Path(TypePath { path, .. }) => path.segments.last().is_some_and(|segment| {
500            let ident = segment.ident.to_string();
501            ident == "Id"
502        }),
503        _ => false,
504    }
505}
506
507fn option_inner_type(ty: &Type) -> Option<&Type> {
508    let Type::Path(TypePath { path, .. }) = ty else {
509        return None;
510    };
511    let segment = path.segments.last()?;
512    if segment.ident != "Option" {
513        return None;
514    }
515    let PathArguments::AngleBracketed(args) = &segment.arguments else {
516        return None;
517    };
518    let GenericArgument::Type(inner) = args.args.first()? else {
519        return None;
520    };
521    Some(inner)
522}
523
/// Converts a CamelCase identifier to snake_case.
///
/// An underscore is inserted before an ASCII uppercase letter only when the
/// previous character was an ASCII lowercase letter or digit. Runs of capitals
/// therefore collapse (`HTTPServer` -> `httpserver`), while `UserID` becomes
/// `user_id`. All other characters are copied unchanged.
fn to_snake_case(input: &str) -> String {
    let mut snake = String::with_capacity(input.len() + 4);
    let mut after_lower_or_digit = false;

    for c in input.chars() {
        if c.is_ascii_uppercase() && after_lower_or_digit {
            snake.push('_');
        }
        // `to_ascii_lowercase` leaves everything except `A..=Z` untouched.
        snake.push(c.to_ascii_lowercase());
        after_lower_or_digit = c.is_ascii_lowercase() || c.is_ascii_digit();
    }

    snake
}