1mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25use syn::PathArguments;
26
27use crate::enrich::CodeEnrichment;
28use crate::enrich::ToEntityGraphTokens;
29use crate::enrich::ToRustCodeTokens;
30use convert_case::{Case, Casing};
31use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
32use quote::quote;
33use syn::parse::{Parse, ParseStream};
34use syn::punctuated::Punctuated;
35use syn::spanned::Spanned;
36use syn::{
37 Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type, parse_quote,
38};
39
40use crate::ToSqlConfig;
41
42use super::UsedType;
43
/// Default parameter identifiers for the generated wrapper functions.
///
/// They are assigned positionally to entries of the `Args` / `OrderedSetArgs` associated
/// types that do not carry a custom name. `PgAggregate::new` takes a prefix of this array
/// sized by the number of arguments, so an aggregate cannot have more arguments of either
/// kind than this array has entries.
const ARG_NAMES: [&str; 32] = [
    "arg_one",
    "arg_two",
    "arg_three",
    "arg_four",
    "arg_five",
    "arg_six",
    "arg_seven",
    "arg_eight",
    "arg_nine",
    "arg_ten",
    "arg_eleven",
    "arg_twelve",
    "arg_thirteen",
    "arg_fourteen",
    "arg_fifteen",
    "arg_sixteen",
    "arg_seventeen",
    "arg_eighteen",
    "arg_nineteen",
    "arg_twenty",
    "arg_twenty_one",
    "arg_twenty_two",
    "arg_twenty_three",
    "arg_twenty_four",
    "arg_twenty_five",
    "arg_twenty_six",
    "arg_twenty_seven",
    "arg_twenty_eight",
    "arg_twenty_nine",
    "arg_thirty",
    "arg_thirty_one",
    "arg_thirty_two",
];
79
/// Parsed form of an `impl Aggregate<T> for Target` block annotated with `#[pg_aggregate]`,
/// produced by `PgAggregate::new`.
#[derive(Debug, Clone)]
pub struct PgAggregate {
    /// The user's impl block, extended with default items for every associated type or
    /// function the user did not provide.
    item_impl: ItemImpl,
    /// Expression naming the SQL aggregate: `<T as ::pgrx::aggregate::ToAggregateName>::NAME`.
    name: Expr,
    /// Identifier of the type the `Aggregate` trait is implemented for.
    target_ident: Ident,
    /// snake_case `{target}_{generic}` prefix used for generated function names and the
    /// schema entry name.
    snake_case_target_ident: Ident,
    /// Generated wrapper functions, one per trait function the user implemented.
    pg_externs: Vec<ItemFn>,
    /// Parsed `Args` associated type.
    type_args: AggregateTypeList,
    /// Parsed `OrderedSetArgs` associated type (the direct arguments), if the user declared it.
    type_ordered_set_args: Option<AggregateTypeList>,
    /// Parsed `MovingState` associated type, if the user declared it.
    type_moving_state: Option<UsedType>,
    /// State type (`State`, defaulting to `Self`) with `Self` remapped to the target type.
    type_stype: AggregateType,
    /// Value obtained from the `ORDERED_SET` constant; `false` when absent.
    const_ordered_set: bool,
    /// Unevaluated expression of the `PARALLEL` constant, if present.
    const_parallel: Option<syn::Expr>,
    /// Unevaluated expression of the `FINALIZE_MODIFY` constant, if present.
    const_finalize_modify: Option<syn::Expr>,
    /// Unevaluated expression of the `MOVING_FINALIZE_MODIFY` constant, if present.
    const_moving_finalize_modify: Option<syn::Expr>,
    /// String value of the `INITIAL_CONDITION` constant, if present.
    const_initial_condition: Option<String>,
    /// String value of the `SORT_OPERATOR` constant, if present.
    const_sort_operator: Option<String>,
    /// String value of the `MOVING_INITIAL_CONDITION` constant, if present.
    /// (The field name misspells "initial"; kept as-is because other code refers to it.)
    const_moving_intial_condition: Option<String>,
    /// Name of the generated `state` wrapper; `state` is mandatory.
    fn_state: Ident,
    /// Name of the generated `finalize` wrapper; `None` if the user did not implement it.
    fn_finalize: Option<Ident>,
    /// Name of the generated `combine` wrapper; `None` if the user did not implement it.
    fn_combine: Option<Ident>,
    /// Name of the generated `serial` wrapper; `None` if the user did not implement it.
    fn_serial: Option<Ident>,
    /// Name of the generated `deserial` wrapper; `None` if the user did not implement it.
    fn_deserial: Option<Ident>,
    /// Name of the generated `moving_state` wrapper; `None` if the user did not implement it.
    fn_moving_state: Option<Ident>,
    /// Name of the generated `moving_state_inverse` wrapper; `None` if not implemented.
    fn_moving_state_inverse: Option<Ident>,
    /// Name of the generated `moving_finalize` wrapper; `None` if not implemented.
    fn_moving_finalize: Option<Ident>,
    /// Value of the `HYPOTHETICAL` boolean literal constant; `false` when absent.
    hypothetical: bool,
    /// `ToSqlConfig` parsed from the impl block's attributes (the default when none is given).
    to_sql_config: ToSqlConfig,
}
112
113fn extract_generic_from_trait(item_impl: &ItemImpl) -> Result<&Type, syn::Error> {
114 let (_, path, _) = item_impl.trait_.as_ref().ok_or_else(|| {
115 syn::Error::new_spanned(
116 item_impl,
117 "`#[pg_aggregate]` can only be used on `impl` blocks for a trait.",
118 )
119 })?;
120
121 let last_segment = path
122 .segments
123 .last()
124 .ok_or_else(|| syn::Error::new_spanned(path, "Trait path is empty or malformed."))?;
125
126 if last_segment.ident != "Aggregate" {
127 return Err(syn::Error::new_spanned(
128 last_segment.ident.clone(),
129 "`#[pg_aggregate]` only works with the `Aggregate` trait.",
130 ));
131 }
132
133 let args = match &last_segment.arguments {
134 PathArguments::AngleBracketed(args) => args,
135 _ => {
136 return Err(syn::Error::new_spanned(
137 last_segment.ident.clone(),
138 "`Aggregate` trait must have angle-bracketed generic arguments (e.g., `Aggregate<T>`). Missing generic argument.",
139 ));
140 }
141 };
142
143 let generic_arg = args.args.first().ok_or_else(|| {
144 syn::Error::new_spanned(
145 args,
146 "`Aggregate` trait requires at least one generic argument (e.g., `Aggregate<T>`).",
147 )
148 })?;
149
150 if let syn::GenericArgument::Type(ty) = generic_arg {
151 Ok(ty)
152 } else {
153 Err(syn::Error::new_spanned(
154 generic_arg,
155 "Expected a type as the generic argument for `Aggregate` (e.g., `Aggregate<MyType>`).",
156 ))
157 }
158}
159
160fn get_generic_type_name(ty: &syn::Type) -> Result<String, syn::Error> {
161 if let Type::Path(type_path) = ty
162 && let Some(ident) = type_path.path.segments.last().map(|s| &s.ident)
163 {
164 let ident = ident.to_string();
165
166 match ident.as_str() {
167 "!" => Ok("never".to_string()),
168 "()" => Ok("unit".to_string()),
169 _ => Ok(ident),
170 }
171 } else {
172 Err(syn::Error::new_spanned(ty, "Generic type path is empty or malformed."))
173 }
174}
175
176impl PgAggregate {
177 pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
178 let to_sql_config =
179 ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
180 let target_path = get_target_path(&item_impl)?;
181 let target_ident = get_target_ident(&target_path)?;
182
183 let mut pg_externs = Vec::default();
184 let item_impl_snapshot = item_impl.clone();
187
188 let generic_type = extract_generic_from_trait(&item_impl)?.clone();
189 let generic_type_name = get_generic_type_name(&generic_type)?;
190
191 let snake_case_target_ident =
192 format!("{target_ident}_{generic_type_name}").to_case(Case::Snake);
193 let snake_case_target_ident = Ident::new(&snake_case_target_ident, target_ident.span());
194 crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
195
196 let name = parse_quote! {
197 <#generic_type as ::pgrx::aggregate::ToAggregateName>::NAME
198 };
199
200 let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
202 let _type_state_value = type_state.map(|v| v.ty.clone());
203
204 let type_state_without_self = if let Some(inner) = type_state {
205 let mut remapped = inner.ty.clone();
206 remap_self_to_target(&mut remapped, &target_ident);
207 remapped
208 } else {
209 item_impl.items.push(parse_quote! {
210 type State = Self;
211 });
212 let mut remapped = parse_quote!(Self);
213 remap_self_to_target(&mut remapped, &target_ident);
214 remapped
215 };
216 let type_stype = AggregateType {
217 used_ty: UsedType::new(type_state_without_self.clone())?,
218 name: Some("state".into()),
219 };
220
221 let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
223 let type_moving_state;
224 let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
225 type_moving_state = impl_type_moving_state.ty.clone();
226 Some(UsedType::new(type_moving_state.clone())?)
227 } else {
228 item_impl.items.push(parse_quote! {
229 type MovingState = ();
230 });
231 type_moving_state = parse_quote! { () };
232 None
233 };
234
235 let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
237 let type_ordered_set_args_value =
238 type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
239 if type_ordered_set_args.is_none() {
240 item_impl.items.push(parse_quote! {
241 type OrderedSetArgs = ();
242 })
243 }
244 let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
245 type_ordered_set_args_value
246 {
247 let direct_args = order_by_direct_args
248 .found
249 .iter()
250 .map(|x| {
251 (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
252 })
253 .collect::<Vec<_>>();
254 let direct_arg_names = ARG_NAMES[0..direct_args.len()]
255 .iter()
256 .zip(direct_args.iter())
257 .map(|(default_name, (custom_name, _ty, _orig))| {
258 Ident::new(
259 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
260 Span::mixed_site(),
261 )
262 })
263 .collect::<Vec<_>>();
264 let direct_args_with_names = direct_args
265 .iter()
266 .zip(direct_arg_names.iter())
267 .map(|(arg, name)| {
268 let arg_ty = &arg.2; parse_quote! {
270 #name: #arg_ty
271 }
272 })
273 .collect::<Vec<syn::FnArg>>();
274 (direct_args_with_names, direct_arg_names)
275 } else {
276 (Vec::default(), Vec::default())
277 };
278
279 let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
281 syn::Error::new(
282 item_impl_snapshot.span(),
283 "`#[pg_aggregate]` requires the `Args` type defined.",
284 )
285 })?;
286 let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
287 let args = type_args_value
288 .found
289 .iter()
290 .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
291 .collect::<Vec<_>>();
292 let arg_names = ARG_NAMES[0..args.len()]
293 .iter()
294 .zip(args.iter())
295 .map(|(default_name, (custom_name, ty))| {
296 Ident::new(
297 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
298 ty.span(),
299 )
300 })
301 .collect::<Vec<_>>();
302 let args_with_names = args
303 .iter()
304 .zip(arg_names.iter())
305 .map(|(arg, name)| {
306 let arg_ty = &arg.1;
307 quote! {
308 #name: #arg_ty
309 }
310 })
311 .collect::<Vec<_>>();
312
313 let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
315 let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
316 type_finalize.ty.clone()
317 } else {
318 item_impl.items.push(parse_quote! {
319 type Finalize = ();
320 });
321 parse_quote! { () }
322 };
323
324 let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
325
326 let fn_state_name = if let Some(found) = fn_state {
327 let fn_name =
328 Ident::new(&format!("{snake_case_target_ident}_state"), found.sig.ident.span());
329 let pg_extern_attr = pg_extern_attr(found);
330
331 pg_externs.push(parse_quote! {
332 #[allow(non_snake_case, clippy::too_many_arguments)]
333 #pg_extern_attr
334 fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
335 unsafe {
336 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
337 fcinfo,
338 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::state(this, (#(#arg_names),*), fcinfo)
339 )
340 }
341 }
342 });
343 fn_name
344 } else {
345 return Err(syn::Error::new(
346 item_impl.span(),
347 "Aggregate implementation must include state function.",
348 ));
349 };
350
351 let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
352 let fn_combine_name = if let Some(found) = fn_combine {
353 let fn_name =
354 Ident::new(&format!("{snake_case_target_ident}_combine"), found.sig.ident.span());
355 let pg_extern_attr = pg_extern_attr(found);
356 pg_externs.push(parse_quote! {
357 #[allow(non_snake_case, clippy::too_many_arguments)]
358 #pg_extern_attr
359 fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
360 unsafe {
361 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
362 fcinfo,
363 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::combine(this, v, fcinfo)
364 )
365 }
366 }
367 });
368 Some(fn_name)
369 } else {
370 item_impl.items.push(parse_quote! {
371 fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
372 unimplemented!("Call to combine on an aggregate which does not support it.")
373 }
374 });
375 None
376 };
377
378 let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
379 let fn_finalize_name = if let Some(found) = fn_finalize {
380 let fn_name =
381 Ident::new(&format!("{snake_case_target_ident}_finalize"), found.sig.ident.span());
382 let pg_extern_attr = pg_extern_attr(found);
383
384 if !direct_args_with_names.is_empty() {
385 pg_externs.push(parse_quote! {
386 #[allow(non_snake_case, clippy::too_many_arguments)]
387 #pg_extern_attr
388 fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
389 unsafe {
390 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
391 fcinfo,
392 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (#(#direct_arg_names),*), fcinfo)
393 )
394 }
395 }
396 });
397 } else {
398 pg_externs.push(parse_quote! {
399 #[allow(non_snake_case, clippy::too_many_arguments)]
400 #pg_extern_attr
401 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
402 unsafe {
403 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
404 fcinfo,
405 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (), fcinfo)
406 )
407 }
408 }
409 });
410 };
411 Some(fn_name)
412 } else {
413 item_impl.items.push(parse_quote! {
414 fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
415 unimplemented!("Call to finalize on an aggregate which does not support it.")
416 }
417 });
418 None
419 };
420
421 let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
422 let fn_serial_name = if let Some(found) = fn_serial {
423 let fn_name =
424 Ident::new(&format!("{snake_case_target_ident}_serial"), found.sig.ident.span());
425 let pg_extern_attr = pg_extern_attr(found);
426 pg_externs.push(parse_quote! {
427 #[allow(non_snake_case, clippy::too_many_arguments)]
428 #pg_extern_attr
429 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
430 unsafe {
431 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
432 fcinfo,
433 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::serial(this, fcinfo)
434 )
435 }
436 }
437 });
438 Some(fn_name)
439 } else {
440 item_impl.items.push(parse_quote! {
441 fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
442 unimplemented!("Call to serial on an aggregate which does not support it.")
443 }
444 });
445 None
446 };
447
448 let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
449 let fn_deserial_name = if let Some(found) = fn_deserial {
450 let fn_name =
451 Ident::new(&format!("{snake_case_target_ident}_deserial"), found.sig.ident.span());
452 let pg_extern_attr = pg_extern_attr(found);
453 pg_externs.push(parse_quote! {
454 #[allow(non_snake_case, clippy::too_many_arguments)]
455 #pg_extern_attr
456 fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
457 unsafe {
458 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
459 fcinfo,
460 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::deserial(this, buf, internal, fcinfo)
461 )
462 }
463 }
464 });
465 Some(fn_name)
466 } else {
467 item_impl.items.push(parse_quote! {
468 fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
469 unimplemented!("Call to deserial on an aggregate which does not support it.")
470 }
471 });
472 None
473 };
474
475 let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
476 let fn_moving_state_name = if let Some(found) = fn_moving_state {
477 let fn_name = Ident::new(
478 &format!("{snake_case_target_ident}_moving_state"),
479 found.sig.ident.span(),
480 );
481 let pg_extern_attr = pg_extern_attr(found);
482
483 pg_externs.push(parse_quote! {
484 #[allow(non_snake_case, clippy::too_many_arguments)]
485 #pg_extern_attr
486 fn #fn_name(
487 mstate: #type_moving_state,
488 #(#args_with_names),*,
489 fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
490 ) -> #type_moving_state {
491 unsafe {
492 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
493 fcinfo,
494 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state(mstate, (#(#arg_names),*), fcinfo)
495 )
496 }
497 }
498 });
499 Some(fn_name)
500 } else {
501 item_impl.items.push(parse_quote! {
502 fn moving_state(
503 _mstate: <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState,
504 _v: Self::Args,
505 _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
506 ) -> <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState {
507 unimplemented!("Call to moving_state on an aggregate which does not support it.")
508 }
509 });
510 None
511 };
512
513 let fn_moving_state_inverse =
514 get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
515 let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
516 let fn_name = Ident::new(
517 &format!("{snake_case_target_ident}_moving_state_inverse"),
518 found.sig.ident.span(),
519 );
520 let pg_extern_attr = pg_extern_attr(found);
521 pg_externs.push(parse_quote! {
522 #[allow(non_snake_case, clippy::too_many_arguments)]
523 #pg_extern_attr
524 fn #fn_name(
525 mstate: #type_moving_state,
526 #(#args_with_names),*,
527 fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
528 ) -> #type_moving_state {
529 unsafe {
530 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
531 fcinfo,
532 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
533 )
534 }
535 }
536 });
537 Some(fn_name)
538 } else {
539 item_impl.items.push(parse_quote! {
540 fn moving_state_inverse(
541 _mstate: #type_moving_state,
542 _v: Self::Args,
543 _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
544 ) -> #type_moving_state {
545 unimplemented!("Call to moving_state on an aggregate which does not support it.")
546 }
547 });
548 None
549 };
550
551 let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
552 let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
553 let fn_name = Ident::new(
554 &format!("{snake_case_target_ident}_moving_finalize"),
555 found.sig.ident.span(),
556 );
557 let pg_extern_attr = pg_extern_attr(found);
558 let maybe_comma: Option<syn::Token![,]> =
559 if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
560
561 pg_externs.push(parse_quote! {
562 #[allow(non_snake_case, clippy::too_many_arguments)]
563 #pg_extern_attr
564 fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
565 unsafe {
566 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
567 fcinfo,
568 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
569 )
570 }
571 }
572 });
573 Some(fn_name)
574 } else {
575 item_impl.items.push(parse_quote! {
576 fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
577 unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
578 }
579 });
580 None
581 };
582
583 Ok(CodeEnrichment(Self {
584 item_impl,
585 target_ident,
586 pg_externs,
587 name,
588 snake_case_target_ident,
589 type_args: type_args_value,
590 type_ordered_set_args: type_ordered_set_args_value,
591 type_moving_state: type_moving_state_value,
592 type_stype,
593 const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
594 .map(|x| x.expr.clone()),
595 const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
596 .map(|x| x.expr.clone()),
597 const_moving_finalize_modify: get_impl_const_by_name(
598 &item_impl_snapshot,
599 "MOVING_FINALIZE_MODIFY",
600 )
601 .map(|x| x.expr.clone()),
602 const_initial_condition: get_impl_const_by_name(
603 &item_impl_snapshot,
604 "INITIAL_CONDITION",
605 )
606 .and_then(|e| get_const_litstr(e).transpose())
607 .transpose()?,
608 const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
609 .and_then(get_const_litbool)
610 .unwrap_or(false),
611 const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
612 .and_then(|e| get_const_litstr(e).transpose())
613 .transpose()?,
614 const_moving_intial_condition: get_impl_const_by_name(
615 &item_impl_snapshot,
616 "MOVING_INITIAL_CONDITION",
617 )
618 .and_then(|e| get_const_litstr(e).transpose())
619 .transpose()?,
620 fn_state: fn_state_name,
621 fn_finalize: fn_finalize_name,
622 fn_combine: fn_combine_name,
623 fn_serial: fn_serial_name,
624 fn_deserial: fn_deserial_name,
625 fn_moving_state: fn_moving_state_name,
626 fn_moving_state_inverse: fn_moving_state_inverse_name,
627 fn_moving_finalize: fn_moving_finalize_name,
628 hypothetical: if let Some(value) =
629 get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
630 {
631 match &value.expr {
632 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
633 syn::Lit::Bool(lit) => lit.value,
634 _ => {
635 return Err(syn::Error::new(
636 value.span(),
637 "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.",
638 ));
639 }
640 },
641 _ => {
642 return Err(syn::Error::new(
643 value.span(),
644 "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.",
645 ));
646 }
647 }
648 } else {
649 false
650 },
651 to_sql_config,
652 }))
653 }
654}
655
656impl ToEntityGraphTokens for PgAggregate {
657 fn to_entity_graph_tokens(&self) -> TokenStream2 {
658 let target_ident = &self.target_ident;
659 let sql_graph_entity_fn_name = syn::Ident::new(
660 &format!("__pgrx_schema_aggregate_{}", self.snake_case_target_ident),
661 target_ident.span(),
662 );
663
664 let name = &self.name;
665 let const_ordered_set = self.const_ordered_set;
666 let hypothetical = self.hypothetical;
667 let fn_state = &self.fn_state;
668 let to_sql_config = &self.to_sql_config;
669 let to_sql_config_len = to_sql_config.section_len_tokens();
670 let type_args_len = self.type_args.section_len_tokens();
671 let direct_args_len = self
672 .type_ordered_set_args
673 .as_ref()
674 .map(|value| {
675 let inner = value.section_len_tokens();
676 quote! {
677 ::pgrx::pgrx_sql_entity_graph::section::bool_len() + (#inner)
678 }
679 })
680 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
681 let stype_len = self.type_stype.section_len_tokens();
682 let moving_state_len = self
683 .type_moving_state
684 .as_ref()
685 .map(|value| {
686 let inner = value.section_len_tokens();
687 quote! {
688 ::pgrx::pgrx_sql_entity_graph::section::bool_len() + (#inner)
689 }
690 })
691 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
692 let finalfunc_len = self
693 .fn_finalize
694 .as_ref()
695 .map(|value| {
696 quote! {
697 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
698 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
699 }
700 })
701 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
702 let combinefunc_len = self
703 .fn_combine
704 .as_ref()
705 .map(|value| {
706 quote! {
707 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
708 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
709 }
710 })
711 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
712 let serialfunc_len = self
713 .fn_serial
714 .as_ref()
715 .map(|value| {
716 quote! {
717 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
718 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
719 }
720 })
721 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
722 let deserialfunc_len = self
723 .fn_deserial
724 .as_ref()
725 .map(|value| {
726 quote! {
727 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
728 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
729 }
730 })
731 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
732 let initcond_len = self
733 .const_initial_condition
734 .as_ref()
735 .map(|value| {
736 quote! {
737 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
738 + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
739 }
740 })
741 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
742 let msfunc_len = self
743 .fn_moving_state
744 .as_ref()
745 .map(|value| {
746 quote! {
747 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
748 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
749 }
750 })
751 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
752 let minvfunc_len = self
753 .fn_moving_state_inverse
754 .as_ref()
755 .map(|value| {
756 quote! {
757 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
758 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
759 }
760 })
761 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
762 let mfinalfunc_len = self
763 .fn_moving_finalize
764 .as_ref()
765 .map(|value| {
766 quote! {
767 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
768 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
769 }
770 })
771 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
772 let minitcond_len = self
773 .const_moving_intial_condition
774 .as_ref()
775 .map(|value| {
776 quote! {
777 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
778 + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
779 }
780 })
781 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
782 let sortop_len = self
783 .const_sort_operator
784 .as_ref()
785 .map(|value| {
786 quote! {
787 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
788 + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
789 }
790 })
791 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
792 let finalize_modify_len = self
793 .const_finalize_modify
794 .as_ref()
795 .map(|value| {
796 quote! {
797 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
798 + match #value {
799 Some(_) => ::pgrx::pgrx_sql_entity_graph::section::u8_len(),
800 None => 0,
801 }
802 }
803 })
804 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
805 let moving_finalize_modify_len = self
806 .const_moving_finalize_modify
807 .as_ref()
808 .map(|value| {
809 quote! {
810 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
811 + match #value {
812 Some(_) => ::pgrx::pgrx_sql_entity_graph::section::u8_len(),
813 None => 0,
814 }
815 }
816 })
817 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
818 let parallel_len = self
819 .const_parallel
820 .as_ref()
821 .map(|value| {
822 quote! {
823 ::pgrx::pgrx_sql_entity_graph::section::bool_len()
824 + match #value {
825 Some(_) => ::pgrx::pgrx_sql_entity_graph::section::u8_len(),
826 None => 0,
827 }
828 }
829 })
830 .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
831 let payload_len = quote! {
832 ::pgrx::pgrx_sql_entity_graph::section::u8_len()
833 + ::pgrx::pgrx_sql_entity_graph::section::str_len(concat!(module_path!(), "::", stringify!(#target_ident)))
834 + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
835 + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
836 + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
837 + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
838 + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
839 + (#type_args_len)
840 + (#direct_args_len)
841 + (#stype_len)
842 + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#fn_state))
843 + (#finalfunc_len)
844 + (#finalize_modify_len)
845 + (#combinefunc_len)
846 + (#serialfunc_len)
847 + (#deserialfunc_len)
848 + (#initcond_len)
849 + (#msfunc_len)
850 + (#minvfunc_len)
851 + (#moving_state_len)
852 + (#mfinalfunc_len)
853 + (#moving_finalize_modify_len)
854 + (#minitcond_len)
855 + (#sortop_len)
856 + (#parallel_len)
857 + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
858 + (#to_sql_config_len)
859 };
860 let total_len = quote! {
861 ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
862 };
863
864 let direct_args_writer = self
865 .type_ordered_set_args
866 .as_ref()
867 .map(|value| value.section_writer_tokens(quote! { writer.bool(true) }))
868 .unwrap_or_else(|| quote! { writer.bool(false) });
869 let moving_state_writer = self
870 .type_moving_state
871 .as_ref()
872 .map(|value| value.section_writer_tokens(quote! { writer.bool(true) }))
873 .unwrap_or_else(|| quote! { writer.bool(false) });
874 let finalfunc_writer = self
875 .fn_finalize
876 .as_ref()
877 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
878 .unwrap_or_else(|| quote! { writer.bool(false) });
879 let combinefunc_writer = self
880 .fn_combine
881 .as_ref()
882 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
883 .unwrap_or_else(|| quote! { writer.bool(false) });
884 let serialfunc_writer = self
885 .fn_serial
886 .as_ref()
887 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
888 .unwrap_or_else(|| quote! { writer.bool(false) });
889 let deserialfunc_writer = self
890 .fn_deserial
891 .as_ref()
892 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
893 .unwrap_or_else(|| quote! { writer.bool(false) });
894 let initcond_writer = self
895 .const_initial_condition
896 .as_ref()
897 .map(|value| quote! { writer.bool(true).str(#value) })
898 .unwrap_or_else(|| quote! { writer.bool(false) });
899 let msfunc_writer = self
900 .fn_moving_state
901 .as_ref()
902 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
903 .unwrap_or_else(|| quote! { writer.bool(false) });
904 let minvfunc_writer = self
905 .fn_moving_state_inverse
906 .as_ref()
907 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
908 .unwrap_or_else(|| quote! { writer.bool(false) });
909 let mfinalfunc_writer = self
910 .fn_moving_finalize
911 .as_ref()
912 .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
913 .unwrap_or_else(|| quote! { writer.bool(false) });
914 let minitcond_writer = self
915 .const_moving_intial_condition
916 .as_ref()
917 .map(|value| quote! { writer.bool(true).str(#value) })
918 .unwrap_or_else(|| quote! { writer.bool(false) });
919 let sortop_writer = self
920 .const_sort_operator
921 .as_ref()
922 .map(|value| quote! { writer.bool(true).str(#value) })
923 .unwrap_or_else(|| quote! { writer.bool(false) });
924 let finalize_modify_writer = self
925 .const_finalize_modify
926 .as_ref()
927 .map(|value| quote! { match #value {
928 Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadOnly) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_ONLY),
929 Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::Shareable) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_SHAREABLE),
930 Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadWrite) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_WRITE),
931 None => writer.bool(false),
932 } })
933 .unwrap_or_else(|| quote! { writer.bool(false) });
934 let moving_finalize_modify_writer = self
935 .const_moving_finalize_modify
936 .as_ref()
937 .map(|value| quote! { match #value {
938 Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadOnly) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_ONLY),
939 Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::Shareable) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_SHAREABLE),
940 Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadWrite) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_WRITE),
941 None => writer.bool(false),
942 } })
943 .unwrap_or_else(|| quote! { writer.bool(false) });
944 let parallel_writer = self
945 .const_parallel
946 .as_ref()
947 .map(|value| quote! { match #value {
948 Some(::pgrx::pgrx_sql_entity_graph::ParallelOption::Safe) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_PARALLEL_SAFE),
949 Some(::pgrx::pgrx_sql_entity_graph::ParallelOption::Restricted) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_PARALLEL_RESTRICTED),
950 Some(::pgrx::pgrx_sql_entity_graph::ParallelOption::Unsafe) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_PARALLEL_UNSAFE),
951 None => writer.bool(false),
952 } })
953 .unwrap_or_else(|| quote! { writer.bool(false) });
954 let args_writer = self.type_args.section_writer_tokens(quote! { writer });
955 let stype_writer = self.type_stype.section_writer_tokens(quote! { writer });
956 let to_sql_config_writer = to_sql_config.section_writer_tokens(quote! { writer });
957
958 quote! {
959 ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
960 #sql_graph_entity_fn_name,
961 #total_len,
962 {
963 let writer = ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
964 .u32((#payload_len) as u32)
965 .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_AGGREGATE)
966 .str(concat!(module_path!(), "::", stringify!(#target_ident)))
967 .str(module_path!())
968 .str(file!())
969 .u32(line!())
970 .str(#name)
971 .bool(#const_ordered_set);
972 let writer = { #args_writer };
973 let writer = { #direct_args_writer };
974 let writer = { #stype_writer };
975 let writer = writer.str(stringify!(#fn_state));
976 let writer = { #finalfunc_writer };
977 let writer = { #finalize_modify_writer };
978 let writer = { #combinefunc_writer };
979 let writer = { #serialfunc_writer };
980 let writer = { #deserialfunc_writer };
981 let writer = { #initcond_writer };
982 let writer = { #msfunc_writer };
983 let writer = { #minvfunc_writer };
984 let writer = { #moving_state_writer };
985 let writer = { #mfinalfunc_writer };
986 let writer = { #moving_finalize_modify_writer };
987 let writer = { #minitcond_writer };
988 let writer = { #sortop_writer };
989 let writer = { #parallel_writer };
990 let writer = writer.bool(#hypothetical);
991 let writer = { #to_sql_config_writer };
992 writer.finish()
993 }
994 );
995 }
996 }
997}
998
999impl ToRustCodeTokens for PgAggregate {
1000 fn to_rust_code_tokens(&self) -> TokenStream2 {
1001 let impl_item = &self.item_impl;
1002 let pg_externs = self.pg_externs.iter();
1003
1004 quote! {
1005 #impl_item
1006 #(#pg_externs)*
1007 }
1008 }
1009}
1010
1011impl Parse for CodeEnrichment<PgAggregate> {
1012 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
1013 PgAggregate::new(input.parse()?)
1014 }
1015}
1016
1017fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
1018 let last = path.segments.last().ok_or_else(|| {
1019 syn::Error::new(
1020 path.span(),
1021 "`#[pg_aggregate]` only works with types whose path have a final segment.",
1022 )
1023 })?;
1024 Ok(last.ident.clone())
1025}
1026
1027fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
1028 let target_ident = match &*item_impl.self_ty {
1029 syn::Type::Path(type_path) => {
1030 let last_segment = type_path.path.segments.last().ok_or_else(|| {
1031 syn::Error::new(
1032 type_path.span(),
1033 "`#[pg_aggregate]` only works with types whose path have a final segment.",
1034 )
1035 })?;
1036 if last_segment.ident == "PgVarlena" {
1037 match &last_segment.arguments {
1038 syn::PathArguments::AngleBracketed(angled) => {
1039 let first = angled.args.first().ok_or_else(|| syn::Error::new(
1040 type_path.span(),
1041 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
1042 ))?;
1043 match &first {
1044 syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
1045 _ => {
1046 return Err(syn::Error::new(
1047 type_path.span(),
1048 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
1049 ));
1050 }
1051 }
1052 }
1053 _ => {
1054 return Err(syn::Error::new(
1055 type_path.span(),
1056 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
1057 ));
1058 }
1059 }
1060 } else {
1061 type_path.path.clone()
1062 }
1063 }
1064 something_else => {
1065 return Err(syn::Error::new(
1066 something_else.span(),
1067 "`#[pg_aggregate]` only works with types.",
1068 ));
1069 }
1070 };
1071 Ok(target_ident)
1072}
1073
1074fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
1075 let mut found = None;
1076 for attr in item.attrs.iter() {
1077 match attr.path().segments.last() {
1078 Some(segment) if segment.ident == "pgrx" => {
1079 found = Some(attr);
1080 break;
1081 }
1082 _ => (),
1083 };
1084 }
1085
1086 let attrs = if let Some(attr) = found {
1087 let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
1088 let attrs = attr.parse_args_with(parser);
1089 attrs.ok()
1090 } else {
1091 None
1092 };
1093
1094 match attrs {
1095 Some(args) => parse_quote! {
1096 #[::pgrx::pg_extern(#args)]
1097 },
1098 None => parse_quote! {
1099 #[::pgrx::pg_extern]
1100 },
1101 }
1102}
1103
1104fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
1105 let mut needle = None;
1106 for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
1107 syn::ImplItem::Type(iitype) => Some(iitype),
1108 _ => None,
1109 }) {
1110 let ident_string = impl_item_type.ident.to_string();
1111 if ident_string == name {
1112 needle = Some(impl_item_type);
1113 }
1114 }
1115 needle
1116}
1117
1118fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
1119 let mut needle = None;
1120 for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
1121 syn::ImplItem::Fn(iifn) => Some(iifn),
1122 _ => None,
1123 }) {
1124 let ident_string = impl_item_fn.sig.ident.to_string();
1125 if ident_string == name {
1126 needle = Some(impl_item_fn);
1127 }
1128 }
1129 needle
1130}
1131
1132fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
1133 let mut needle = None;
1134 for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
1135 syn::ImplItem::Const(iiconst) => Some(iiconst),
1136 _ => None,
1137 }) {
1138 let ident_string = impl_item_const.ident.to_string();
1139 if ident_string == name {
1140 needle = Some(impl_item_const);
1141 }
1142 }
1143 needle
1144}
1145
1146fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
1147 match &item.expr {
1148 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
1149 syn::Lit::Bool(lit) => Some(lit.value()),
1150 _ => None,
1151 },
1152 _ => None,
1153 }
1154}
1155
1156fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
1157 match &item.expr {
1158 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
1159 syn::Lit::Str(lit) => Ok(Some(lit.value())),
1160 _ => Ok(None),
1161 },
1162 syn::Expr::Call(expr_call) => match &*expr_call.func {
1163 syn::Expr::Path(expr_path) => {
1164 let Some(last) = expr_path.path.segments.last() else {
1165 return Ok(None);
1166 };
1167 if last.ident == "Some" {
1168 match expr_call.args.first() {
1169 Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
1170 syn::Lit::Str(lit) => Ok(Some(lit.value())),
1171 _ => Ok(None),
1172 },
1173 _ => Ok(None),
1174 }
1175 } else {
1176 Ok(None)
1177 }
1178 }
1179 _ => Ok(None),
1180 },
1181 ex => Err(syn::Error::new(ex.span(), "")),
1182 }
1183}
1184
1185fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
1186 if let Type::Path(ty_path) = ty {
1187 for segment in ty_path.path.segments.iter_mut() {
1188 if segment.ident == "Self" {
1189 segment.ident = target.clone()
1190 }
1191 use syn::{GenericArgument, PathArguments};
1192 match segment.arguments {
1193 PathArguments::AngleBracketed(ref mut angle_args) => {
1194 for arg in angle_args.args.iter_mut() {
1195 if let GenericArgument::Type(inner_ty) = arg {
1196 remap_self_to_target(inner_ty, target)
1197 }
1198 }
1199 }
1200 PathArguments::Parenthesized(_) => (),
1201 PathArguments::None => (),
1202 }
1203 }
1204 }
1205}
1206
1207fn get_pgrx_attr_macro(attr_name: &str, ty: &syn::Type) -> Option<TokenStream2> {
1208 match &ty {
1209 syn::Type::Macro(ty_macro) => {
1210 let mut found_pgrx = false;
1211 let mut found_attr = false;
1212 for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
1214 match segment.ident.to_string().as_str() {
1215 "pgrx" if idx == 0 => found_pgrx = true,
1216 attr if attr == attr_name => found_attr = true,
1217 _ => (),
1218 }
1219 }
1220 if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
1221 Some(ty_macro.mac.tokens.clone())
1222 } else {
1223 None
1224 }
1225 }
1226 _ => None,
1227 }
1228}
1229
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{ItemImpl, parse_quote};

    /// An aggregate declaring only `State`, `Args`, and `state` expands to exactly one
    /// generated extern, named after the target type, the aggregate name, and `state`.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens).expect("a minimal aggregate should be accepted");

        let externs = &agg.0.pg_externs;
        assert_eq!(externs.len(), 1);
        assert_eq!(externs[0].sig.ident.to_string(), "demo_agg_demo_name_state");

        // Rendering the enriched aggregate must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// Declaring every optional type, constant, and function yields eight generated externs,
    /// with the `state` wrapper first.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens).expect("a fully-specified aggregate should be accepted");

        let externs = &agg.0.pg_externs;
        assert_eq!(externs.len(), 8);
        assert_eq!(externs[0].sig.ident.to_string(), "demo_agg_demo_name_state");

        // Rendering the enriched aggregate must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An empty `impl Aggregate` block (no `State`, `Args`, or `state`) must be rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(PgAggregate::new(tokens).is_err());
        Ok(())
    }
}