1use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14 DeriveInput, Token,
15 parse::{Parse, ParseStream},
16 parse_macro_input,
17};
18
19#[proc_macro_attribute]
58pub fn persistent(args: TokenStream, input: TokenStream) -> TokenStream {
59 let args = if args.is_empty() {
60 PersistentArgs { indexes: vec![] }
61 } else {
62 parse_macro_input!(args as PersistentArgs)
63 };
64 let input = parse_macro_input!(input as DeriveInput);
65 expand_persistent(args, input)
66 .unwrap_or_else(|e| e.to_compile_error())
67 .into()
68}
69
70fn has_serde_derives(attrs: &[syn::Attribute]) -> (bool, bool) {
71 let mut has_serialize = false;
72 let mut has_deserialize = false;
73
74 for attr in attrs {
75 if attr.path().is_ident("derive") {
76 let Ok(paths) = attr.parse_args_with(
77 syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
78 ) else {
79 continue;
80 };
81 for path in paths {
82 let segs: Vec<String> = path.segments.iter().map(|s| s.ident.to_string()).collect();
83 match segs.as_slice() {
84 [s] if s == "Serialize" => has_serialize = true,
85 [s] if s == "Deserialize" => has_deserialize = true,
86 [a, b] if a == "serde" && b == "Serialize" => has_serialize = true,
87 [a, b] if a == "serde" && b == "Deserialize" => has_deserialize = true,
88 _ => {}
89 }
90 }
91 }
92 }
93 (has_serialize, has_deserialize)
94}
95
96fn is_stored_json(attr: &syn::Attribute) -> bool {
97 if !attr.path().is_ident("stored") {
98 return false;
99 }
100 match &attr.meta {
101 syn::Meta::List(list) => list
102 .parse_args::<syn::Ident>()
103 .map(|i| i == "json")
104 .unwrap_or(false),
105 _ => false,
106 }
107}
108
109fn expand_persistent(
110 args: PersistentArgs,
111 input: DeriveInput,
112) -> syn::Result<proc_macro2::TokenStream> {
113 let ident = &input.ident;
114 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
115 let vis = &input.vis;
116 let table_name = ident.to_string();
117
118 let named = match &input.data {
120 syn::Data::Struct(s) => match &s.fields {
121 syn::Fields::Named(n) => n,
122 _ => {
123 return Err(syn::Error::new_spanned(
124 ident,
125 "`#[persistent]` only supports structs with named fields",
126 ));
127 }
128 },
129 _ => {
130 return Err(syn::Error::new_spanned(
131 ident,
132 "`#[persistent]` only supports structs",
133 ));
134 }
135 };
136
137 if named
139 .named
140 .iter()
141 .any(|f| f.ident.as_ref().map(|i| i == "id").unwrap_or(false))
142 {
143 return Err(syn::Error::new_spanned(
144 ident,
145 "`#[persistent]` manages its own `id` field — remove the manual `id` field",
146 ));
147 }
148
149 for idx in &args.indexes {
151 let idx_ident = syn::Ident::new(idx, Span::call_site());
152 if !named
153 .named
154 .iter()
155 .any(|f| f.ident.as_ref().map(|i| i == &idx_ident).unwrap_or(false))
156 {
157 return Err(syn::Error::new_spanned(
158 ident,
159 format!("index field `{idx}` not found in `{ident}`"),
160 ));
161 }
162 }
163
164 let mut regular_fields = Vec::new();
166 let mut json_field_idents = Vec::new();
167 let mut index_field_idents = Vec::new();
168 let mut index_column_names = Vec::new();
169
170 for field in &named.named {
171 let field_ident = field.ident.as_ref().unwrap();
172 let field_ty = &field.ty;
173 let mut new_attrs = Vec::new();
174 let mut is_json = false;
175
176 for attr in &field.attrs {
177 if is_stored_json(attr) {
178 is_json = true;
179 } else {
180 new_attrs.push(attr.clone());
181 }
182 }
183
184 if is_json {
185 new_attrs.push(syn::parse_quote! {
186 #[serde(serialize_with = "::airnest::json_ser", deserialize_with = "::airnest::json_de")]
187 });
188 json_field_idents.push(field_ident.clone());
189 }
190
191 let is_index = args.indexes.iter().any(|i| i == &field_ident.to_string());
192
193 if !is_json && is_index {
194 index_field_idents.push(field_ident.clone());
195 index_column_names.push(field_ident.to_string());
196 }
197
198 regular_fields.push(quote! {
199 #(#new_attrs)*
200 pub #field_ident: #field_ty,
201 });
202 }
203
204 let mut all_column_names = index_column_names.clone();
206 let mut all_value_exprs: Vec<proc_macro2::TokenStream> = index_field_idents
207 .iter()
208 .map(|ident| {
209 quote! {
210 ::airnest::ToIndexValue::to_index_value(&self.#ident)
211 }
212 })
213 .collect();
214
215 for json_ident in &json_field_idents {
216 let col_name = format!("{}_json", json_ident);
217 all_column_names.push(col_name);
218 all_value_exprs.push(quote! {
219 ::airnest::json_string(&self.#json_ident)
220 });
221 }
222
223 let field_names: Vec<&syn::Ident> = named
225 .named
226 .iter()
227 .map(|f| f.ident.as_ref().unwrap())
228 .collect();
229 let field_types: Vec<&syn::Type> = named.named.iter().map(|f| &f.ty).collect();
230
231 let (has_serialize, has_deserialize) = has_serde_derives(&input.attrs);
233 let extra_derive = match (has_serialize, has_deserialize) {
234 (false, false) => Some(quote! { #[derive(::serde::Serialize, ::serde::Deserialize)] }),
235 (false, true) => Some(quote! { #[derive(::serde::Serialize)] }),
236 (true, false) => Some(quote! { #[derive(::serde::Deserialize)] }),
237 (true, true) => None,
238 };
239
240 let query_struct_name = syn::Ident::new(&format!("{}Query", ident), Span::call_site());
242 let query_methods: Vec<_> = index_field_idents
243 .iter()
244 .zip(&index_column_names)
245 .map(|(ident, col_name)| {
246 quote! {
247 pub fn #ident(self, value: impl ::airnest::ToIndexValue) -> Self {
248 Self { query: self.query.eq(#col_name, value) }
249 }
250 }
251 })
252 .collect();
253
254 let replace_struct_name =
256 syn::Ident::new(&format!("{}ReplaceBuilder", ident), Span::call_site());
257 let replace_methods: Vec<_> = index_field_idents
258 .iter()
259 .zip(&index_column_names)
260 .map(|(ident, col_name)| {
261 quote! {
262 pub fn #ident(self, value: impl ::airnest::ToIndexValue) -> Self {
263 Self(self.0.eq(#col_name, value))
264 }
265 }
266 })
267 .collect();
268
269 let upsert_struct_name = syn::Ident::new(&format!("{}UpsertBuilder", ident), Span::call_site());
271 let upsert_methods: Vec<_> = index_field_idents
272 .iter()
273 .zip(&index_column_names)
274 .map(|(ident, col_name)| {
275 quote! {
276 pub fn #ident(self, value: impl ::airnest::ToIndexValue) -> Self {
277 Self(self.0.eq(#col_name, value))
278 }
279 }
280 })
281 .collect();
282
283 let count_by_methods: Vec<_> = index_field_idents
285 .iter()
286 .zip(&index_column_names)
287 .map(|(field_ident, col_name)| {
288 let method_name = syn::Ident::new(&format!("count_by_{}", field_ident), Span::call_site());
289 quote! {
290 pub async fn #method_name(store: &::airnest::Store) -> ::std::result::Result<::std::collections::HashMap<::std::string::String, i64>, ::airnest::StoreError> {
291 store.count_grouped_by::<#ident>(#col_name).await
292 }
293 }
294 })
295 .collect();
296
297 let attrs = &input.attrs;
298
299 Ok(quote! {
300 #(#attrs)*
301 #extra_derive
302 #vis struct #ident #impl_generics #where_clause {
303 pub id: ::airnest::AirId<#ident>,
305 #(#regular_fields)*
306 }
307
308 impl #impl_generics #ident #ty_generics #where_clause {
309 pub fn new(#(#field_names: #field_types),*) -> Self {
311 Self {
312 id: ::airnest::AirId::new(),
313 #(#field_names),*
314 }
315 }
316
317 pub fn id(&self) -> ::airnest::AirId<Self> {
319 self.id
320 }
321
322 pub fn find(store: &::airnest::Store) -> #query_struct_name<'_> {
324 #query_struct_name { query: store.find::<#ident>() }
325 }
326
327 pub fn replace_for(store: &::airnest::Store) -> #replace_struct_name<'_> {
329 #replace_struct_name::new(store)
330 }
331
332 pub fn upsert(store: &::airnest::Store) -> #upsert_struct_name<'_> {
334 #upsert_struct_name::new(store)
335 }
336
337 #(#count_by_methods)*
338 }
339
340 #vis struct #query_struct_name<'a> {
342 query: ::airnest::Query<'a, #ident>,
343 }
344
345 impl<'a> #query_struct_name<'a> {
346 #(#query_methods)*
347
348 pub fn order_by(self, column: &str, order: ::airnest::Order) -> Self {
349 Self { query: self.query.order_by(column, order) }
350 }
351
352 pub fn limit(self, n: usize) -> Self {
353 Self { query: self.query.limit(n) }
354 }
355
356 pub async fn all(self) -> ::std::result::Result<::std::vec::Vec<#ident>, ::airnest::StoreError> {
357 self.query.all().await
358 }
359
360 pub async fn first(self) -> ::std::result::Result<::std::option::Option<#ident>, ::airnest::StoreError> {
361 self.query.first().await
362 }
363
364 pub async fn count(self) -> ::std::result::Result<i64, ::airnest::StoreError> {
365 self.query.count().await
366 }
367 }
368
369 #vis struct #replace_struct_name<'a>(::airnest::ReplaceBuilder<'a, #ident>);
371
372 impl<'a> #replace_struct_name<'a> {
373 fn new(store: &'a ::airnest::Store) -> Self {
374 Self(::airnest::ReplaceBuilder::new(store))
375 }
376
377 #(#replace_methods)*
378
379 pub async fn items(self, items: ::std::vec::Vec<#ident>) -> ::std::result::Result<(), ::airnest::StoreError> {
380 self.0.items(items).await
381 }
382 }
383
384 #vis struct #upsert_struct_name<'a>(::airnest::UpsertBuilder<'a, #ident>);
386
387 impl<'a> #upsert_struct_name<'a> {
388 fn new(store: &'a ::airnest::Store) -> Self {
389 Self(::airnest::UpsertBuilder::new(store))
390 }
391
392 #(#upsert_methods)*
393
394 pub fn modify<F: FnOnce(&mut #ident)>(self, f: F) -> ::airnest::UpsertModifyBuilder<'a, #ident, F> {
395 self.0.modify(f)
396 }
397 }
398
399 impl #impl_generics ::airnest::Persistent for #ident #ty_generics #where_clause {
400 fn id(&self) -> ::airnest::AirId<Self> {
401 self.id
402 }
403
404 const TABLE: &'static str = #table_name;
405
406 fn index_columns() -> &'static [&'static str] {
407 &[#(#all_column_names),*]
408 }
409
410 fn index_values(&self) -> ::std::vec::Vec<::std::string::String> {
411 ::std::vec![#(#all_value_exprs),*]
412 }
413 }
414 })
415}
416
417struct PersistentArgs {
421 indexes: Vec<String>,
422}
423
424impl Parse for PersistentArgs {
425 fn parse(input: ParseStream) -> syn::Result<Self> {
426 let kw: syn::Ident = input.parse()?;
427 if kw != "index" {
428 return Err(syn::Error::new_spanned(
429 kw,
430 "expected `index(field, ...)`. Bare `#[persistent]` needs no arguments.",
431 ));
432 }
433 let content;
434 syn::parenthesized!(content in input);
435 let mut indexes = Vec::new();
436 while !content.is_empty() {
437 let ident: syn::Ident = content.parse()?;
438 indexes.push(ident.to_string());
439 if content.peek(Token![,]) {
440 content.parse::<Token![,]>()?;
441 }
442 }
443 Ok(Self { indexes })
444 }
445}