// airnest_macros/lib.rs

//! Attribute macro `#[persistent]` — injects an `id` field, generates `new()`, typed
//! query builders and replace/upsert helpers, and implements the `Persistent` trait.
//!
//! # Architecture patterns for large codebases
//!
//! `#[persistent]` is designed to be used sparingly — only on **aggregates** and
//! **domain entities** that represent state worth saving. Child structs nested
//! inside a persistent root should be plain `Serialize + Deserialize` values.

10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14    DeriveInput, Token,
15    parse::{Parse, ParseStream},
16    parse_macro_input,
17};
18
19/// Marks a struct as persistable.
20///
21/// Injects an `id: AirId<Self>` field, generates a `new(...)` constructor, and
22/// implements `Persistent`. If `Serialize` or `Deserialize` are not already
23/// present in a `#[derive(...)]` attribute, they are auto-injected.
24///
25/// # Example
26///
27/// ```ignore
28/// #[persistent]
29/// pub struct ChatSession {
30///     pub messages: Vec<Message>,
31/// }
32/// ```
33///
34/// # Index columns
35///
36/// ```ignore
37/// #[persistent(index(status, priority))]
38/// pub struct Job {
39///     pub status: String,
40///     pub priority: i32,
41/// }
42/// ```
43///
44/// # JSON columns
45///
46/// ```ignore
47/// #[persistent(index(session_uuid))]
48/// pub struct StoredMessage {
49///     pub session_uuid: String,
50///
51///     #[stored(json)]
52///     pub content: Message,
53/// }
54/// ```
55/// The `content` field is automatically JSON-serialized and stored in a
56/// `content_json` TEXT column.
57#[proc_macro_attribute]
58pub fn persistent(args: TokenStream, input: TokenStream) -> TokenStream {
59    let args = if args.is_empty() {
60        PersistentArgs { indexes: vec![] }
61    } else {
62        parse_macro_input!(args as PersistentArgs)
63    };
64    let input = parse_macro_input!(input as DeriveInput);
65    expand_persistent(args, input)
66        .unwrap_or_else(|e| e.to_compile_error())
67        .into()
68}
69
70fn has_serde_derives(attrs: &[syn::Attribute]) -> (bool, bool) {
71    let mut has_serialize = false;
72    let mut has_deserialize = false;
73
74    for attr in attrs {
75        if attr.path().is_ident("derive") {
76            let Ok(paths) = attr.parse_args_with(
77                syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
78            ) else {
79                continue;
80            };
81            for path in paths {
82                let segs: Vec<String> = path.segments.iter().map(|s| s.ident.to_string()).collect();
83                match segs.as_slice() {
84                    [s] if s == "Serialize" => has_serialize = true,
85                    [s] if s == "Deserialize" => has_deserialize = true,
86                    [a, b] if a == "serde" && b == "Serialize" => has_serialize = true,
87                    [a, b] if a == "serde" && b == "Deserialize" => has_deserialize = true,
88                    _ => {}
89                }
90            }
91        }
92    }
93    (has_serialize, has_deserialize)
94}
95
96fn is_stored_json(attr: &syn::Attribute) -> bool {
97    if !attr.path().is_ident("stored") {
98        return false;
99    }
100    match &attr.meta {
101        syn::Meta::List(list) => list
102            .parse_args::<syn::Ident>()
103            .map(|i| i == "json")
104            .unwrap_or(false),
105        _ => false,
106    }
107}
108
109fn expand_persistent(
110    args: PersistentArgs,
111    input: DeriveInput,
112) -> syn::Result<proc_macro2::TokenStream> {
113    let ident = &input.ident;
114    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
115    let vis = &input.vis;
116    let table_name = ident.to_string();
117
118    // ── validate struct ───────────────────────────────────────────────────────
119    let named = match &input.data {
120        syn::Data::Struct(s) => match &s.fields {
121            syn::Fields::Named(n) => n,
122            _ => {
123                return Err(syn::Error::new_spanned(
124                    ident,
125                    "`#[persistent]` only supports structs with named fields",
126                ));
127            }
128        },
129        _ => {
130            return Err(syn::Error::new_spanned(
131                ident,
132                "`#[persistent]` only supports structs",
133            ));
134        }
135    };
136
137    // ── reject existing `id` field ────────────────────────────────────────────
138    if named
139        .named
140        .iter()
141        .any(|f| f.ident.as_ref().map(|i| i == "id").unwrap_or(false))
142    {
143        return Err(syn::Error::new_spanned(
144            ident,
145            "`#[persistent]` manages its own `id` field — remove the manual `id` field",
146        ));
147    }
148
149    // ── validate index fields exist ───────────────────────────────────────────
150    for idx in &args.indexes {
151        let idx_ident = syn::Ident::new(idx, Span::call_site());
152        if !named
153            .named
154            .iter()
155            .any(|f| f.ident.as_ref().map(|i| i == &idx_ident).unwrap_or(false))
156        {
157            return Err(syn::Error::new_spanned(
158                ident,
159                format!("index field `{idx}` not found in `{ident}`"),
160            ));
161        }
162    }
163
164    // ── process fields ────────────────────────────────────────────────────────
165    let mut regular_fields = Vec::new();
166    let mut json_field_idents = Vec::new();
167    let mut index_field_idents = Vec::new();
168    let mut index_column_names = Vec::new();
169
170    for field in &named.named {
171        let field_ident = field.ident.as_ref().unwrap();
172        let field_ty = &field.ty;
173        let mut new_attrs = Vec::new();
174        let mut is_json = false;
175
176        for attr in &field.attrs {
177            if is_stored_json(attr) {
178                is_json = true;
179            } else {
180                new_attrs.push(attr.clone());
181            }
182        }
183
184        if is_json {
185            new_attrs.push(syn::parse_quote! {
186                #[serde(serialize_with = "::airnest::json_ser", deserialize_with = "::airnest::json_de")]
187            });
188            json_field_idents.push(field_ident.clone());
189        }
190
191        let is_index = args.indexes.iter().any(|i| i == &field_ident.to_string());
192
193        if !is_json && is_index {
194            index_field_idents.push(field_ident.clone());
195            index_column_names.push(field_ident.to_string());
196        }
197
198        regular_fields.push(quote! {
199            #(#new_attrs)*
200            pub #field_ident: #field_ty,
201        });
202    }
203
204    // Combine index columns and json columns for the schema
205    let mut all_column_names = index_column_names.clone();
206    let mut all_value_exprs: Vec<proc_macro2::TokenStream> = index_field_idents
207        .iter()
208        .map(|ident| {
209            quote! {
210                ::airnest::ToIndexValue::to_index_value(&self.#ident)
211            }
212        })
213        .collect();
214
215    for json_ident in &json_field_idents {
216        let col_name = format!("{}_json", json_ident);
217        all_column_names.push(col_name);
218        all_value_exprs.push(quote! {
219            ::airnest::json_string(&self.#json_ident)
220        });
221    }
222
223    // ── constructor fields ────────────────────────────────────────────────────
224    let field_names: Vec<&syn::Ident> = named
225        .named
226        .iter()
227        .map(|f| f.ident.as_ref().unwrap())
228        .collect();
229    let field_types: Vec<&syn::Type> = named.named.iter().map(|f| &f.ty).collect();
230
231    // ── auto-inject missing serde derives ─────────────────────────────────────
232    let (has_serialize, has_deserialize) = has_serde_derives(&input.attrs);
233    let extra_derive = match (has_serialize, has_deserialize) {
234        (false, false) => Some(quote! { #[derive(::serde::Serialize, ::serde::Deserialize)] }),
235        (false, true) => Some(quote! { #[derive(::serde::Serialize)] }),
236        (true, false) => Some(quote! { #[derive(::serde::Deserialize)] }),
237        (true, true) => None,
238    };
239
240    // ── typed query builder ───────────────────────────────────────────────────
241    let query_struct_name = syn::Ident::new(&format!("{}Query", ident), Span::call_site());
242    let query_methods: Vec<_> = index_field_idents
243        .iter()
244        .zip(&index_column_names)
245        .map(|(ident, col_name)| {
246            quote! {
247                pub fn #ident(self, value: impl ::airnest::ToIndexValue) -> Self {
248                    Self { query: self.query.eq(#col_name, value) }
249                }
250            }
251        })
252        .collect();
253
254    // ── replace builder ───────────────────────────────────────────────────────
255    let replace_struct_name =
256        syn::Ident::new(&format!("{}ReplaceBuilder", ident), Span::call_site());
257    let replace_methods: Vec<_> = index_field_idents
258        .iter()
259        .zip(&index_column_names)
260        .map(|(ident, col_name)| {
261            quote! {
262                pub fn #ident(self, value: impl ::airnest::ToIndexValue) -> Self {
263                    Self(self.0.eq(#col_name, value))
264                }
265            }
266        })
267        .collect();
268
269    // ── upsert builder ────────────────────────────────────────────────────────
270    let upsert_struct_name = syn::Ident::new(&format!("{}UpsertBuilder", ident), Span::call_site());
271    let upsert_methods: Vec<_> = index_field_idents
272        .iter()
273        .zip(&index_column_names)
274        .map(|(ident, col_name)| {
275            quote! {
276                pub fn #ident(self, value: impl ::airnest::ToIndexValue) -> Self {
277                    Self(self.0.eq(#col_name, value))
278                }
279            }
280        })
281        .collect();
282
283    // ── count-by helpers ──────────────────────────────────────────────────────
284    let count_by_methods: Vec<_> = index_field_idents
285        .iter()
286        .zip(&index_column_names)
287        .map(|(field_ident, col_name)| {
288            let method_name = syn::Ident::new(&format!("count_by_{}", field_ident), Span::call_site());
289            quote! {
290                pub async fn #method_name(store: &::airnest::Store) -> ::std::result::Result<::std::collections::HashMap<::std::string::String, i64>, ::airnest::StoreError> {
291                    store.count_grouped_by::<#ident>(#col_name).await
292                }
293            }
294        })
295        .collect();
296
297    let attrs = &input.attrs;
298
299    Ok(quote! {
300        #(#attrs)*
301        #extra_derive
302        #vis struct #ident #impl_generics #where_clause {
303            /// Auto-generated UUIDv7 id. Managed by `#[persistent]`.
304            pub id: ::airnest::AirId<#ident>,
305            #(#regular_fields)*
306        }
307
308        impl #impl_generics #ident #ty_generics #where_clause {
309            /// Create a new instance with a fresh UUIDv7 id.
310            pub fn new(#(#field_names: #field_types),*) -> Self {
311                Self {
312                    id: ::airnest::AirId::new(),
313                    #(#field_names),*
314                }
315            }
316
317            /// Return the auto-generated id.
318            pub fn id(&self) -> ::airnest::AirId<Self> {
319                self.id
320            }
321
322            /// Start a typed query for this type.
323            pub fn find(store: &::airnest::Store) -> #query_struct_name<'_> {
324                #query_struct_name { query: store.find::<#ident>() }
325            }
326
327            /// Start a replace operation for this type.
328            pub fn replace_for(store: &::airnest::Store) -> #replace_struct_name<'_> {
329                #replace_struct_name::new(store)
330            }
331
332            /// Start an upsert operation for this type.
333            pub fn upsert(store: &::airnest::Store) -> #upsert_struct_name<'_> {
334                #upsert_struct_name::new(store)
335            }
336
337            #(#count_by_methods)*
338        }
339
340        // ── typed query builder ───────────────────────────────────────────────
341        #vis struct #query_struct_name<'a> {
342            query: ::airnest::Query<'a, #ident>,
343        }
344
345        impl<'a> #query_struct_name<'a> {
346            #(#query_methods)*
347
348            pub fn order_by(self, column: &str, order: ::airnest::Order) -> Self {
349                Self { query: self.query.order_by(column, order) }
350            }
351
352            pub fn limit(self, n: usize) -> Self {
353                Self { query: self.query.limit(n) }
354            }
355
356            pub async fn all(self) -> ::std::result::Result<::std::vec::Vec<#ident>, ::airnest::StoreError> {
357                self.query.all().await
358            }
359
360            pub async fn first(self) -> ::std::result::Result<::std::option::Option<#ident>, ::airnest::StoreError> {
361                self.query.first().await
362            }
363
364            pub async fn count(self) -> ::std::result::Result<i64, ::airnest::StoreError> {
365                self.query.count().await
366            }
367        }
368
369        // ── replace builder wrapper ───────────────────────────────────────────
370        #vis struct #replace_struct_name<'a>(::airnest::ReplaceBuilder<'a, #ident>);
371
372        impl<'a> #replace_struct_name<'a> {
373            fn new(store: &'a ::airnest::Store) -> Self {
374                Self(::airnest::ReplaceBuilder::new(store))
375            }
376
377            #(#replace_methods)*
378
379            pub async fn items(self, items: ::std::vec::Vec<#ident>) -> ::std::result::Result<(), ::airnest::StoreError> {
380                self.0.items(items).await
381            }
382        }
383
384        // ── upsert builder wrapper ────────────────────────────────────────────
385        #vis struct #upsert_struct_name<'a>(::airnest::UpsertBuilder<'a, #ident>);
386
387        impl<'a> #upsert_struct_name<'a> {
388            fn new(store: &'a ::airnest::Store) -> Self {
389                Self(::airnest::UpsertBuilder::new(store))
390            }
391
392            #(#upsert_methods)*
393
394            pub fn modify<F: FnOnce(&mut #ident)>(self, f: F) -> ::airnest::UpsertModifyBuilder<'a, #ident, F> {
395                self.0.modify(f)
396            }
397        }
398
399        impl #impl_generics ::airnest::Persistent for #ident #ty_generics #where_clause {
400            fn id(&self) -> ::airnest::AirId<Self> {
401                self.id
402            }
403
404            const TABLE: &'static str = #table_name;
405
406            fn index_columns() -> &'static [&'static str] {
407                &[#(#all_column_names),*]
408            }
409
410            fn index_values(&self) -> ::std::vec::Vec<::std::string::String> {
411                ::std::vec![#(#all_value_exprs),*]
412            }
413        }
414    })
415}
416
417// ── attribute parser ──────────────────────────────────────────────────────────
418
419/// Parsed form of `#[persistent(index(a, b, c))]`.
420struct PersistentArgs {
421    indexes: Vec<String>,
422}
423
424impl Parse for PersistentArgs {
425    fn parse(input: ParseStream) -> syn::Result<Self> {
426        let kw: syn::Ident = input.parse()?;
427        if kw != "index" {
428            return Err(syn::Error::new_spanned(
429                kw,
430                "expected `index(field, ...)`. Bare `#[persistent]` needs no arguments.",
431            ));
432        }
433        let content;
434        syn::parenthesized!(content in input);
435        let mut indexes = Vec::new();
436        while !content.is_empty() {
437            let ident: syn::Ident = content.parse()?;
438            indexes.push(ident.to_string());
439            if content.peek(Token![,]) {
440                content.parse::<Token![,]>()?;
441            }
442        }
443        Ok(Self { indexes })
444    }
445}