1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse_macro_input, Attribute, Data, DeriveInput, Error, Field, Fields, GenericArgument,
5 PathArguments, Type, TypePath,
6};
7
8#[proc_macro_derive(Sensitive, attributes(secure))]
9pub fn derive_sensitive(input: TokenStream) -> TokenStream {
10 match derive_sensitive_impl(parse_macro_input!(input as DeriveInput)) {
11 Ok(tokens) => tokens.into(),
12 Err(err) => err.to_compile_error().into(),
13 }
14}
15
16#[proc_macro_derive(Store, attributes(unique))]
17pub fn derive_store(input: TokenStream) -> TokenStream {
18 match derive_store_impl(parse_macro_input!(input as DeriveInput)) {
19 Ok(tokens) => tokens.into(),
20 Err(err) => err.to_compile_error().into(),
21 }
22}
23
24#[proc_macro_derive(Relation, attributes(relation))]
25pub fn derive_relation(input: TokenStream) -> TokenStream {
26 match derive_relation_impl(parse_macro_input!(input as DeriveInput)) {
27 Ok(tokens) => tokens.into(),
28 Err(err) => err.to_compile_error().into(),
29 }
30}
31
32fn derive_store_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
33 let struct_ident = input.ident;
34
35 let named_fields = match input.data {
36 Data::Struct(data) => match data.fields {
37 Fields::Named(fields) => fields.named,
38 _ => {
39 return Err(Error::new_spanned(
40 struct_ident,
41 "Store can only be derived for structs with named fields",
42 ))
43 }
44 },
45 _ => {
46 return Err(Error::new_spanned(
47 struct_ident,
48 "Store can only be derived for structs",
49 ))
50 }
51 };
52
53 let id_fields = named_fields
54 .iter()
55 .filter(|field| is_id_type(&field.ty))
56 .map(|field| field.ident.clone().expect("named field"))
57 .collect::<Vec<_>>();
58
59 let unique_fields = named_fields
60 .iter()
61 .filter(|field| has_unique_attr(&field.attrs))
62 .map(|field| field.ident.clone().expect("named field"))
63 .collect::<Vec<_>>();
64
65 if id_fields.len() > 1 {
66 return Err(Error::new_spanned(
67 struct_ident,
68 "Store supports at most one `Id` field for automatic HasId generation",
69 ));
70 }
71
72 let auto_has_id_impl = id_fields.first().map(|field| {
73 quote! {
74 impl ::appdb::model::meta::HasId for #struct_ident {
75 fn id(&self) -> ::surrealdb::types::RecordId {
76 ::surrealdb::types::RecordId::new(
77 <Self as ::appdb::model::meta::ModelMeta>::table_name(),
78 self.#field.clone(),
79 )
80 }
81 }
82 }
83 });
84
85 let resolve_record_id_impl = if let Some(field) = id_fields.first() {
86 quote! {
87 #[::async_trait::async_trait]
88 impl ::appdb::model::meta::ResolveRecordId for #struct_ident {
89 async fn resolve_record_id(&self) -> ::anyhow::Result<::surrealdb::types::RecordId> {
90 Ok(::surrealdb::types::RecordId::new(
91 <Self as ::appdb::model::meta::ModelMeta>::table_name(),
92 self.#field.clone(),
93 ))
94 }
95 }
96 }
97 } else {
98 quote! {
99 #[::async_trait::async_trait]
100 impl ::appdb::model::meta::ResolveRecordId for #struct_ident {
101 async fn resolve_record_id(&self) -> ::anyhow::Result<::surrealdb::types::RecordId> {
102 ::appdb::repository::Repo::<Self>::find_unique_id_for(self).await
103 }
104 }
105 }
106 };
107
108 let unique_schema_impls = unique_fields.iter().map(|field| {
109 let field_name = field.to_string();
110 let index_name = format!("{}_{}_unique", to_snake_case(&struct_ident.to_string()), field_name);
111 let ddl = format!(
112 "DEFINE INDEX IF NOT EXISTS {index_name} ON {} FIELDS {field_name} UNIQUE;",
113 to_snake_case(&struct_ident.to_string())
114 );
115
116 quote! {
117 ::inventory::submit! {
118 ::appdb::model::schema::SchemaItem {
119 ddl: #ddl,
120 }
121 }
122 }
123 });
124
125 let lookup_fields = if unique_fields.is_empty() {
126 named_fields
127 .iter()
128 .filter_map(|field| {
129 let ident = field.ident.as_ref()?;
130 if ident == "id" {
131 None
132 } else {
133 Some(ident.to_string())
134 }
135 })
136 .collect::<Vec<_>>()
137 } else {
138 unique_fields.iter().map(|field| field.to_string()).collect::<Vec<_>>()
139 };
140 let lookup_field_literals = lookup_fields.iter().map(|field| quote! { #field });
141
142 Ok(quote! {
143 impl ::appdb::model::meta::ModelMeta for #struct_ident {
144 fn table_name() -> &'static str {
145 static TABLE_NAME: ::std::sync::OnceLock<&'static str> = ::std::sync::OnceLock::new();
146 TABLE_NAME.get_or_init(|| {
147 let table = ::appdb::model::meta::default_table_name(stringify!(#struct_ident));
148 ::appdb::model::meta::register_table(stringify!(#struct_ident), table)
149 })
150 }
151 }
152
153 impl ::appdb::model::meta::UniqueLookupMeta for #struct_ident {
154 fn lookup_fields() -> &'static [&'static str] {
155 &[ #( #lookup_field_literals ),* ]
156 }
157 }
158
159 #auto_has_id_impl
160 #resolve_record_id_impl
161
162 #( #unique_schema_impls )*
163
164 impl ::appdb::repository::Crud for #struct_ident {}
165
166 impl #struct_ident {
167 pub async fn get<T>(id: T) -> ::anyhow::Result<Self>
168 where
169 ::surrealdb::types::RecordIdKey: From<T>,
170 T: Send,
171 {
172 ::appdb::repository::Repo::<Self>::get(id).await
173 }
174
175 pub async fn list() -> ::anyhow::Result<::std::vec::Vec<Self>> {
176 ::appdb::repository::Repo::<Self>::list().await
177 }
178
179 pub async fn list_limit(count: i64) -> ::anyhow::Result<::std::vec::Vec<Self>> {
180 ::appdb::repository::Repo::<Self>::list_limit(count).await
181 }
182
183 pub async fn delete_all() -> ::anyhow::Result<()> {
184 ::appdb::repository::Repo::<Self>::delete_all().await
185 }
186
187 pub async fn find_one_id(
188 k: &str,
189 v: &str,
190 ) -> ::anyhow::Result<::surrealdb::types::RecordId> {
191 ::appdb::repository::Repo::<Self>::find_one_id(k, v).await
192 }
193
194 pub async fn list_record_ids() -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>> {
195 ::appdb::repository::Repo::<Self>::list_record_ids().await
196 }
197
198 pub async fn create_at(
199 id: ::surrealdb::types::RecordId,
200 data: Self,
201 ) -> ::anyhow::Result<Self> {
202 ::appdb::repository::Repo::<Self>::create_at(id, data).await
203 }
204
205 pub async fn upsert_at(
206 id: ::surrealdb::types::RecordId,
207 data: Self,
208 ) -> ::anyhow::Result<Self> {
209 ::appdb::repository::Repo::<Self>::upsert_at(id, data).await
210 }
211
212 pub async fn update_at(
213 self,
214 id: ::surrealdb::types::RecordId,
215 ) -> ::anyhow::Result<Self> {
216 ::appdb::repository::Repo::<Self>::update_at(id, self).await
217 }
218
219 pub async fn delete<T>(id: T) -> ::anyhow::Result<()>
220 where
221 ::surrealdb::types::RecordIdKey: From<T>,
222 T: Send,
223 {
224 ::appdb::repository::Repo::<Self>::delete(id).await
225 }
226 }
227 })
228}
229
230fn derive_relation_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
231 let struct_ident = input.ident;
232 let relation_name = relation_name_override(&input.attrs)?
233 .unwrap_or_else(|| to_snake_case(&struct_ident.to_string()));
234
235 match input.data {
236 Data::Struct(data) => match data.fields {
237 Fields::Unit | Fields::Named(_) => {}
238 _ => {
239 return Err(Error::new_spanned(
240 struct_ident,
241 "Relation can only be derived for unit structs or structs with named fields",
242 ))
243 }
244 },
245 _ => {
246 return Err(Error::new_spanned(
247 struct_ident,
248 "Relation can only be derived for structs",
249 ))
250 }
251 }
252
253 Ok(quote! {
254 impl ::appdb::model::relation::RelationMeta for #struct_ident {
255 fn relation_name() -> &'static str {
256 static REL_NAME: ::std::sync::OnceLock<&'static str> = ::std::sync::OnceLock::new();
257 REL_NAME.get_or_init(|| ::appdb::model::relation::register_relation(#relation_name))
258 }
259 }
260
261 impl #struct_ident {
262 pub async fn relate<A, B>(a: &A, b: &B) -> ::anyhow::Result<()>
263 where
264 A: ::appdb::model::meta::ResolveRecordId + Send + Sync,
265 B: ::appdb::model::meta::ResolveRecordId + Send + Sync,
266 {
267 ::appdb::graph::relate_at(a.resolve_record_id().await?, b.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name()).await
268 }
269
270 pub async fn unrelate<A, B>(a: &A, b: &B) -> ::anyhow::Result<()>
271 where
272 A: ::appdb::model::meta::ResolveRecordId + Send + Sync,
273 B: ::appdb::model::meta::ResolveRecordId + Send + Sync,
274 {
275 ::appdb::graph::unrelate_at(a.resolve_record_id().await?, b.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name()).await
276 }
277
278 pub async fn out_ids<A>(a: &A, out_table: &str) -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>>
279 where
280 A: ::appdb::model::meta::ResolveRecordId + Send + Sync,
281 {
282 ::appdb::graph::out_ids(a.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name(), out_table).await
283 }
284
285 pub async fn in_ids<B>(b: &B, in_table: &str) -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>>
286 where
287 B: ::appdb::model::meta::ResolveRecordId + Send + Sync,
288 {
289 ::appdb::graph::in_ids(b.resolve_record_id().await?, <Self as ::appdb::model::relation::RelationMeta>::relation_name(), in_table).await
290 }
291 }
292 })
293}
294
295fn derive_sensitive_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
296 let struct_ident = input.ident;
297 let encrypted_ident = format_ident!("Encrypted{}", struct_ident);
298 let vis = input.vis;
299
300 let named_fields = match input.data {
301 Data::Struct(data) => match data.fields {
302 Fields::Named(fields) => fields.named,
303 _ => {
304 return Err(Error::new_spanned(
305 struct_ident,
306 "Sensitive can only be derived for structs with named fields",
307 ))
308 }
309 },
310 _ => {
311 return Err(Error::new_spanned(
312 struct_ident,
313 "Sensitive can only be derived for structs",
314 ))
315 }
316 };
317
318 let mut secure_field_count = 0usize;
319 let mut encrypted_fields = Vec::new();
320 let mut encrypt_assignments = Vec::new();
321 let mut decrypt_assignments = Vec::new();
322
323 for field in named_fields.iter() {
324 let ident = field.ident.clone().expect("named field");
325 let field_vis = field.vis.clone();
326 let secure = has_secure_attr(&field.attrs);
327
328 if secure {
329 secure_field_count += 1;
330 let secure_kind = secure_kind(field)?;
331 let encrypted_ty = secure_kind.encrypted_type();
332 let encrypt_expr = secure_kind.encrypt_expr(&ident);
333 let decrypt_expr = secure_kind.decrypt_expr(&ident);
334 encrypted_fields.push(quote! { #field_vis #ident: #encrypted_ty });
335 encrypt_assignments.push(quote! { #ident: #encrypt_expr });
336 decrypt_assignments.push(quote! { #ident: #decrypt_expr });
337 } else {
338 let ty = field.ty.clone();
339 encrypted_fields.push(quote! { #field_vis #ident: #ty });
340 encrypt_assignments.push(quote! { #ident: self.#ident.clone() });
341 decrypt_assignments.push(quote! { #ident: encrypted.#ident.clone() });
342 }
343 }
344
345 if secure_field_count == 0 {
346 return Err(Error::new_spanned(
347 struct_ident,
348 "Sensitive requires at least one #[secure] field",
349 ));
350 }
351
352 Ok(quote! {
353 #[derive(
354 Debug,
355 Clone,
356 ::serde::Serialize,
357 ::serde::Deserialize,
358 ::surrealdb::types::SurrealValue,
359 )]
360 #vis struct #encrypted_ident {
361 #( #encrypted_fields, )*
362 }
363
364 impl ::appdb::Sensitive for #struct_ident {
365 type Encrypted = #encrypted_ident;
366
367 fn encrypt(
368 &self,
369 context: &::appdb::crypto::CryptoContext,
370 ) -> ::std::result::Result<Self::Encrypted, ::appdb::crypto::CryptoError> {
371 ::std::result::Result::Ok(#encrypted_ident {
372 #( #encrypt_assignments, )*
373 })
374 }
375
376 fn decrypt(
377 encrypted: &Self::Encrypted,
378 context: &::appdb::crypto::CryptoContext,
379 ) -> ::std::result::Result<Self, ::appdb::crypto::CryptoError> {
380 ::std::result::Result::Ok(Self {
381 #( #decrypt_assignments, )*
382 })
383 }
384 }
385
386 impl #struct_ident {
387 pub fn encrypt(
388 &self,
389 context: &::appdb::crypto::CryptoContext,
390 ) -> ::std::result::Result<#encrypted_ident, ::appdb::crypto::CryptoError> {
391 <Self as ::appdb::Sensitive>::encrypt(self, context)
392 }
393 }
394
395 impl #encrypted_ident {
396 pub fn decrypt(
397 &self,
398 context: &::appdb::crypto::CryptoContext,
399 ) -> ::std::result::Result<#struct_ident, ::appdb::crypto::CryptoError> {
400 <#struct_ident as ::appdb::Sensitive>::decrypt(self, context)
401 }
402 }
403 })
404}
405
406fn has_secure_attr(attrs: &[Attribute]) -> bool {
407 attrs.iter().any(|attr| attr.path().is_ident("secure"))
408}
409
410fn has_unique_attr(attrs: &[Attribute]) -> bool {
411 attrs.iter().any(|attr| attr.path().is_ident("unique"))
412}
413
414fn relation_name_override(attrs: &[Attribute]) -> syn::Result<Option<String>> {
415 for attr in attrs {
416 if !attr.path().is_ident("relation") {
417 continue;
418 }
419
420 let mut name = None;
421 attr.parse_nested_meta(|meta| {
422 if meta.path.is_ident("name") {
423 let value = meta.value()?;
424 let literal: syn::LitStr = value.parse()?;
425 name = Some(literal.value());
426 Ok(())
427 } else {
428 Err(meta.error("unsupported relation attribute"))
429 }
430 })?;
431 return Ok(name);
432 }
433
434 Ok(None)
435}
436
437enum SecureKind {
438 String,
439 OptionString,
440}
441
442impl SecureKind {
443 fn encrypted_type(&self) -> proc_macro2::TokenStream {
444 match self {
445 SecureKind::String => quote! { ::std::vec::Vec<u8> },
446 SecureKind::OptionString => quote! { ::std::option::Option<::std::vec::Vec<u8>> },
447 }
448 }
449
450 fn encrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
451 match self {
452 SecureKind::String => {
453 quote! { ::appdb::crypto::encrypt_string(&self.#ident, context)? }
454 }
455 SecureKind::OptionString => {
456 quote! { ::appdb::crypto::encrypt_optional_string(&self.#ident, context)? }
457 }
458 }
459 }
460
461 fn decrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
462 match self {
463 SecureKind::String => {
464 quote! { ::appdb::crypto::decrypt_string(&encrypted.#ident, context)? }
465 }
466 SecureKind::OptionString => {
467 quote! { ::appdb::crypto::decrypt_optional_string(&encrypted.#ident, context)? }
468 }
469 }
470 }
471}
472
473fn secure_kind(field: &Field) -> syn::Result<SecureKind> {
474 if is_string_type(&field.ty) {
475 return Ok(SecureKind::String);
476 }
477
478 if let Some(inner) = option_inner_type(&field.ty) {
479 if is_string_type(inner) {
480 return Ok(SecureKind::OptionString);
481 }
482 }
483
484 Err(Error::new_spanned(
485 &field.ty,
486 "#[secure] currently supports only String and Option<String>",
487 ))
488}
489
490fn is_string_type(ty: &Type) -> bool {
491 match ty {
492 Type::Path(TypePath { path, .. }) => path.is_ident("String"),
493 _ => false,
494 }
495}
496
497fn is_id_type(ty: &Type) -> bool {
498 match ty {
499 Type::Path(TypePath { path, .. }) => path.segments.last().is_some_and(|segment| {
500 let ident = segment.ident.to_string();
501 ident == "Id"
502 }),
503 _ => false,
504 }
505}
506
507fn option_inner_type(ty: &Type) -> Option<&Type> {
508 let Type::Path(TypePath { path, .. }) = ty else {
509 return None;
510 };
511 let segment = path.segments.last()?;
512 if segment.ident != "Option" {
513 return None;
514 }
515 let PathArguments::AngleBracketed(args) = &segment.arguments else {
516 return None;
517 };
518 let GenericArgument::Type(inner) = args.args.first()? else {
519 return None;
520 };
521 Some(inner)
522}
523
/// Converts an identifier such as `UserProfile` to `user_profile`.
///
/// Every ASCII uppercase letter is lowercased, and an underscore is inserted in
/// front of it only when the preceding character is an ASCII lowercase letter or
/// digit. Runs of uppercase letters therefore stay joined (`HTTPServer` becomes
/// `httpserver`). All other characters are copied through unchanged.
fn to_snake_case(input: &str) -> String {
    let mut snake = String::with_capacity(input.len() + 4);
    let mut previous: Option<char> = None;

    for current in input.chars() {
        let follows_word_char =
            previous.is_some_and(|p| p.is_ascii_lowercase() || p.is_ascii_digit());
        if follows_word_char && current.is_ascii_uppercase() {
            snake.push('_');
        }
        // `to_ascii_lowercase` leaves everything except `A..=Z` untouched.
        snake.push(current.to_ascii_lowercase());
        previous = Some(current);
    }

    snake
}