1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse_macro_input, Attribute, Data, DeriveInput, Error, Field, Fields, GenericArgument,
5 PathArguments, Type, TypePath,
6};
7
8#[proc_macro_derive(Sensitive, attributes(secure))]
9pub fn derive_sensitive(input: TokenStream) -> TokenStream {
10 match derive_sensitive_impl(parse_macro_input!(input as DeriveInput)) {
11 Ok(tokens) => tokens.into(),
12 Err(err) => err.to_compile_error().into(),
13 }
14}
15
16#[proc_macro_derive(Store, attributes(unique))]
17pub fn derive_store(input: TokenStream) -> TokenStream {
18 match derive_store_impl(parse_macro_input!(input as DeriveInput)) {
19 Ok(tokens) => tokens.into(),
20 Err(err) => err.to_compile_error().into(),
21 }
22}
23
24fn derive_store_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
25 let struct_ident = input.ident;
26
27 let named_fields = match input.data {
28 Data::Struct(data) => match data.fields {
29 Fields::Named(fields) => fields.named,
30 _ => {
31 return Err(Error::new_spanned(
32 struct_ident,
33 "Store can only be derived for structs with named fields",
34 ))
35 }
36 },
37 _ => {
38 return Err(Error::new_spanned(
39 struct_ident,
40 "Store can only be derived for structs",
41 ))
42 }
43 };
44
45 let id_fields = named_fields
46 .iter()
47 .filter(|field| is_id_type(&field.ty))
48 .map(|field| field.ident.clone().expect("named field"))
49 .collect::<Vec<_>>();
50
51 let unique_fields = named_fields
52 .iter()
53 .filter(|field| has_unique_attr(&field.attrs))
54 .map(|field| field.ident.clone().expect("named field"))
55 .collect::<Vec<_>>();
56
57 if id_fields.len() > 1 {
58 return Err(Error::new_spanned(
59 struct_ident,
60 "Store supports at most one `Id` field for automatic HasId generation",
61 ));
62 }
63
64 let auto_has_id_impl = id_fields.first().map(|field| {
65 quote! {
66 impl ::appdb::model::meta::HasId for #struct_ident {
67 fn id(&self) -> ::surrealdb::types::RecordId {
68 ::surrealdb::types::RecordId::new(
69 <Self as ::appdb::model::meta::ModelMeta>::table_name(),
70 self.#field.clone(),
71 )
72 }
73 }
74 }
75 });
76
77 let unique_schema_impls = unique_fields.iter().map(|field| {
78 let field_name = field.to_string();
79 let index_name = format!("{}_{}_unique", to_snake_case(&struct_ident.to_string()), field_name);
80 let ddl = format!(
81 "DEFINE INDEX IF NOT EXISTS {index_name} ON {} FIELDS {field_name} UNIQUE;",
82 to_snake_case(&struct_ident.to_string())
83 );
84
85 quote! {
86 ::inventory::submit! {
87 ::appdb::model::schema::SchemaItem {
88 ddl: #ddl,
89 }
90 }
91 }
92 });
93
94 Ok(quote! {
95 impl ::appdb::model::meta::ModelMeta for #struct_ident {
96 fn table_name() -> &'static str {
97 static TABLE_NAME: ::std::sync::OnceLock<&'static str> = ::std::sync::OnceLock::new();
98 TABLE_NAME.get_or_init(|| {
99 let table = ::appdb::model::meta::default_table_name(stringify!(#struct_ident));
100 ::appdb::model::meta::register_table(stringify!(#struct_ident), table)
101 })
102 }
103 }
104
105 #auto_has_id_impl
106
107 #( #unique_schema_impls )*
108
109 impl ::appdb::repository::Crud for #struct_ident {}
110
111 impl #struct_ident {
112 pub async fn get<T>(id: T) -> ::anyhow::Result<Self>
113 where
114 ::surrealdb::types::RecordIdKey: From<T>,
115 T: Send,
116 {
117 ::appdb::repository::Repo::<Self>::get(id).await
118 }
119
120 pub async fn list() -> ::anyhow::Result<::std::vec::Vec<Self>> {
121 ::appdb::repository::Repo::<Self>::list().await
122 }
123
124 pub async fn list_limit(count: i64) -> ::anyhow::Result<::std::vec::Vec<Self>> {
125 ::appdb::repository::Repo::<Self>::list_limit(count).await
126 }
127
128 pub async fn delete_all() -> ::anyhow::Result<()> {
129 ::appdb::repository::Repo::<Self>::delete_all().await
130 }
131
132 pub async fn find_one_id(
133 k: &str,
134 v: &str,
135 ) -> ::anyhow::Result<::surrealdb::types::RecordId> {
136 ::appdb::repository::Repo::<Self>::find_one_id(k, v).await
137 }
138
139 pub async fn list_record_ids() -> ::anyhow::Result<::std::vec::Vec<::surrealdb::types::RecordId>> {
140 ::appdb::repository::Repo::<Self>::list_record_ids().await
141 }
142
143 pub async fn create_at(
144 id: ::surrealdb::types::RecordId,
145 data: Self,
146 ) -> ::anyhow::Result<Self> {
147 ::appdb::repository::Repo::<Self>::create_at(id, data).await
148 }
149
150 pub async fn upsert_at(
151 id: ::surrealdb::types::RecordId,
152 data: Self,
153 ) -> ::anyhow::Result<Self> {
154 ::appdb::repository::Repo::<Self>::upsert_at(id, data).await
155 }
156
157 pub async fn update_at(
158 self,
159 id: ::surrealdb::types::RecordId,
160 ) -> ::anyhow::Result<Self> {
161 ::appdb::repository::Repo::<Self>::update_at(id, self).await
162 }
163
164 pub async fn delete<T>(id: T) -> ::anyhow::Result<()>
165 where
166 ::surrealdb::types::RecordIdKey: From<T>,
167 T: Send,
168 {
169 ::appdb::repository::Repo::<Self>::delete(id).await
170 }
171 }
172 })
173}
174
175fn derive_sensitive_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
176 let struct_ident = input.ident;
177 let encrypted_ident = format_ident!("Encrypted{}", struct_ident);
178 let vis = input.vis;
179
180 let named_fields = match input.data {
181 Data::Struct(data) => match data.fields {
182 Fields::Named(fields) => fields.named,
183 _ => {
184 return Err(Error::new_spanned(
185 struct_ident,
186 "Sensitive can only be derived for structs with named fields",
187 ))
188 }
189 },
190 _ => {
191 return Err(Error::new_spanned(
192 struct_ident,
193 "Sensitive can only be derived for structs",
194 ))
195 }
196 };
197
198 let mut secure_field_count = 0usize;
199 let mut encrypted_fields = Vec::new();
200 let mut encrypt_assignments = Vec::new();
201 let mut decrypt_assignments = Vec::new();
202
203 for field in named_fields.iter() {
204 let ident = field.ident.clone().expect("named field");
205 let field_vis = field.vis.clone();
206 let secure = has_secure_attr(&field.attrs);
207
208 if secure {
209 secure_field_count += 1;
210 let secure_kind = secure_kind(field)?;
211 let encrypted_ty = secure_kind.encrypted_type();
212 let encrypt_expr = secure_kind.encrypt_expr(&ident);
213 let decrypt_expr = secure_kind.decrypt_expr(&ident);
214 encrypted_fields.push(quote! { #field_vis #ident: #encrypted_ty });
215 encrypt_assignments.push(quote! { #ident: #encrypt_expr });
216 decrypt_assignments.push(quote! { #ident: #decrypt_expr });
217 } else {
218 let ty = field.ty.clone();
219 encrypted_fields.push(quote! { #field_vis #ident: #ty });
220 encrypt_assignments.push(quote! { #ident: self.#ident.clone() });
221 decrypt_assignments.push(quote! { #ident: encrypted.#ident.clone() });
222 }
223 }
224
225 if secure_field_count == 0 {
226 return Err(Error::new_spanned(
227 struct_ident,
228 "Sensitive requires at least one #[secure] field",
229 ));
230 }
231
232 Ok(quote! {
233 #[derive(
234 Debug,
235 Clone,
236 ::serde::Serialize,
237 ::serde::Deserialize,
238 ::surrealdb::types::SurrealValue,
239 )]
240 #vis struct #encrypted_ident {
241 #( #encrypted_fields, )*
242 }
243
244 impl ::appdb::Sensitive for #struct_ident {
245 type Encrypted = #encrypted_ident;
246
247 fn encrypt(
248 &self,
249 context: &::appdb::crypto::CryptoContext,
250 ) -> ::std::result::Result<Self::Encrypted, ::appdb::crypto::CryptoError> {
251 ::std::result::Result::Ok(#encrypted_ident {
252 #( #encrypt_assignments, )*
253 })
254 }
255
256 fn decrypt(
257 encrypted: &Self::Encrypted,
258 context: &::appdb::crypto::CryptoContext,
259 ) -> ::std::result::Result<Self, ::appdb::crypto::CryptoError> {
260 ::std::result::Result::Ok(Self {
261 #( #decrypt_assignments, )*
262 })
263 }
264 }
265
266 impl #struct_ident {
267 pub fn encrypt(
268 &self,
269 context: &::appdb::crypto::CryptoContext,
270 ) -> ::std::result::Result<#encrypted_ident, ::appdb::crypto::CryptoError> {
271 <Self as ::appdb::Sensitive>::encrypt(self, context)
272 }
273 }
274
275 impl #encrypted_ident {
276 pub fn decrypt(
277 &self,
278 context: &::appdb::crypto::CryptoContext,
279 ) -> ::std::result::Result<#struct_ident, ::appdb::crypto::CryptoError> {
280 <#struct_ident as ::appdb::Sensitive>::decrypt(self, context)
281 }
282 }
283 })
284}
285
286fn has_secure_attr(attrs: &[Attribute]) -> bool {
287 attrs.iter().any(|attr| attr.path().is_ident("secure"))
288}
289
290fn has_unique_attr(attrs: &[Attribute]) -> bool {
291 attrs.iter().any(|attr| attr.path().is_ident("unique"))
292}
293
/// Field shapes accepted under `#[secure]`.
///
/// The variant selects the ciphertext type stored in the generated struct and the
/// `::appdb::crypto` helpers used to encrypt and decrypt the field.
enum SecureKind {
    /// The field is declared as `String`.
    String,
    /// The field is declared as `Option<String>`.
    OptionString,
}
298
299impl SecureKind {
300 fn encrypted_type(&self) -> proc_macro2::TokenStream {
301 match self {
302 SecureKind::String => quote! { ::std::vec::Vec<u8> },
303 SecureKind::OptionString => quote! { ::std::option::Option<::std::vec::Vec<u8>> },
304 }
305 }
306
307 fn encrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
308 match self {
309 SecureKind::String => {
310 quote! { ::appdb::crypto::encrypt_string(&self.#ident, context)? }
311 }
312 SecureKind::OptionString => {
313 quote! { ::appdb::crypto::encrypt_optional_string(&self.#ident, context)? }
314 }
315 }
316 }
317
318 fn decrypt_expr(&self, ident: &syn::Ident) -> proc_macro2::TokenStream {
319 match self {
320 SecureKind::String => {
321 quote! { ::appdb::crypto::decrypt_string(&encrypted.#ident, context)? }
322 }
323 SecureKind::OptionString => {
324 quote! { ::appdb::crypto::decrypt_optional_string(&encrypted.#ident, context)? }
325 }
326 }
327 }
328}
329
330fn secure_kind(field: &Field) -> syn::Result<SecureKind> {
331 if is_string_type(&field.ty) {
332 return Ok(SecureKind::String);
333 }
334
335 if let Some(inner) = option_inner_type(&field.ty) {
336 if is_string_type(inner) {
337 return Ok(SecureKind::OptionString);
338 }
339 }
340
341 Err(Error::new_spanned(
342 &field.ty,
343 "#[secure] currently supports only String and Option<String>",
344 ))
345}
346
347fn is_string_type(ty: &Type) -> bool {
348 match ty {
349 Type::Path(TypePath { path, .. }) => path.is_ident("String"),
350 _ => false,
351 }
352}
353
354fn is_id_type(ty: &Type) -> bool {
355 match ty {
356 Type::Path(TypePath { path, .. }) => path.segments.last().is_some_and(|segment| {
357 let ident = segment.ident.to_string();
358 ident == "Id"
359 }),
360 _ => false,
361 }
362}
363
364fn option_inner_type(ty: &Type) -> Option<&Type> {
365 let Type::Path(TypePath { path, .. }) = ty else {
366 return None;
367 };
368 let segment = path.segments.last()?;
369 if segment.ident != "Option" {
370 return None;
371 }
372 let PathArguments::AngleBracketed(args) = &segment.arguments else {
373 return None;
374 };
375 let GenericArgument::Type(inner) = args.args.first()? else {
376 return None;
377 };
378 Some(inner)
379}
380
/// Converts an ASCII `CamelCase` identifier into `snake_case`.
///
/// An underscore is inserted before an uppercase ASCII letter only when the previous
/// character was a lowercase ASCII letter or an ASCII digit. As a result, runs of capitals
/// collapse without separators (`HTTPServer` -> `httpserver`). All other characters are
/// copied unchanged.
fn to_snake_case(input: &str) -> String {
    let mut snake = String::with_capacity(input.len() + 4);
    let mut after_word_char = false;

    for c in input.chars() {
        if c.is_ascii_uppercase() && after_word_char {
            snake.push('_');
        }
        // `to_ascii_lowercase` is the identity for every character except `A`..=`Z`.
        snake.push(c.to_ascii_lowercase());
        after_word_char = matches!(c, 'a'..='z' | '0'..='9');
    }

    snake
}