1mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25use syn::PathArguments;
26
27use crate::enrich::CodeEnrichment;
28use crate::enrich::ToEntityGraphTokens;
29use crate::enrich::ToRustCodeTokens;
30use convert_case::{Case, Casing};
31use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
32use quote::quote;
33use syn::parse::{Parse, ParseStream};
34use syn::punctuated::Punctuated;
35use syn::spanned::Spanned;
36use syn::{
37 parse_quote, Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type,
38};
39
40use crate::ToSqlConfig;
41
42use super::UsedType;
43
/// Fallback parameter names for the generated `#[pg_extern]` wrappers.
///
/// Entry `i` names the `i`-th aggregate argument (or ordered-set direct argument)
/// whenever that argument does not carry a custom name of its own. The table has
/// 32 entries, so at most 32 arguments of each kind can be given a default name.
const ARG_NAMES: [&str; 32] = [
    // 1 ..= 10
    "arg_one", "arg_two", "arg_three", "arg_four", "arg_five",
    "arg_six", "arg_seven", "arg_eight", "arg_nine", "arg_ten",
    // 11 ..= 20
    "arg_eleven", "arg_twelve", "arg_thirteen", "arg_fourteen", "arg_fifteen",
    "arg_sixteen", "arg_seventeen", "arg_eighteen", "arg_nineteen", "arg_twenty",
    // 21 ..= 30
    "arg_twenty_one", "arg_twenty_two", "arg_twenty_three", "arg_twenty_four", "arg_twenty_five",
    "arg_twenty_six", "arg_twenty_seven", "arg_twenty_eight", "arg_twenty_nine", "arg_thirty",
    // 31 ..= 32
    "arg_thirty_one", "arg_thirty_two",
];
79
80#[derive(Debug, Clone)]
83pub struct PgAggregate {
84 item_impl: ItemImpl,
85 name: Expr,
86 target_ident: Ident,
87 snake_case_target_ident: Ident,
88 pg_externs: Vec<ItemFn>,
89 type_args: AggregateTypeList,
91 type_ordered_set_args: Option<AggregateTypeList>,
92 type_moving_state: Option<UsedType>,
93 type_stype: AggregateType,
94 const_ordered_set: bool,
95 const_parallel: Option<syn::Expr>,
96 const_finalize_modify: Option<syn::Expr>,
97 const_moving_finalize_modify: Option<syn::Expr>,
98 const_initial_condition: Option<String>,
99 const_sort_operator: Option<String>,
100 const_moving_intial_condition: Option<String>,
101 fn_state: Ident,
102 fn_finalize: Option<Ident>,
103 fn_combine: Option<Ident>,
104 fn_serial: Option<Ident>,
105 fn_deserial: Option<Ident>,
106 fn_moving_state: Option<Ident>,
107 fn_moving_state_inverse: Option<Ident>,
108 fn_moving_finalize: Option<Ident>,
109 hypothetical: bool,
110 to_sql_config: ToSqlConfig,
111}
112
113fn extract_generic_from_trait(item_impl: &ItemImpl) -> Result<&Type, syn::Error> {
114 let (_, path, _) = item_impl.trait_.as_ref().ok_or_else(|| {
115 syn::Error::new_spanned(
116 item_impl,
117 "`#[pg_aggregate]` can only be used on `impl` blocks for a trait.",
118 )
119 })?;
120
121 let last_segment = path
122 .segments
123 .last()
124 .ok_or_else(|| syn::Error::new_spanned(path, "Trait path is empty or malformed."))?;
125
126 if last_segment.ident != "Aggregate" {
127 return Err(syn::Error::new_spanned(
128 last_segment.ident.clone(),
129 "`#[pg_aggregate]` only works with the `Aggregate` trait.",
130 ));
131 }
132
133 let args = match &last_segment.arguments {
134 PathArguments::AngleBracketed(args) => args,
135 _ => {
136 return Err(syn::Error::new_spanned(
137 last_segment.ident.clone(),
138 "`Aggregate` trait must have angle-bracketed generic arguments (e.g., `Aggregate<T>`). Missing generic argument.",
139 ));
140 }
141 };
142
143 let generic_arg = args.args.first().ok_or_else(|| {
144 syn::Error::new_spanned(
145 args,
146 "`Aggregate` trait requires at least one generic argument (e.g., `Aggregate<T>`).",
147 )
148 })?;
149
150 if let syn::GenericArgument::Type(ty) = generic_arg {
151 Ok(ty)
152 } else {
153 Err(syn::Error::new_spanned(
154 generic_arg,
155 "Expected a type as the generic argument for `Aggregate` (e.g., `Aggregate<MyType>`).",
156 ))
157 }
158}
159
160fn get_generic_type_name(ty: &syn::Type) -> Result<String, syn::Error> {
161 if let Type::Path(type_path) = ty {
162 if let Some(ident) = type_path.path.segments.last().map(|s| &s.ident) {
163 let ident = ident.to_string();
164
165 match ident.as_str() {
166 "!" => Ok("never".to_string()),
167 "()" => Ok("unit".to_string()),
168 _ => Ok(ident),
169 }
170 } else {
171 Err(syn::Error::new_spanned(ty, "Generic type path is empty or malformed."))
172 }
173 } else {
174 Err(syn::Error::new_spanned(ty, "Expected a path type for the generic argument."))
175 }
176}
177
178impl PgAggregate {
179 pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
180 let to_sql_config =
181 ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
182 let target_path = get_target_path(&item_impl)?;
183 let target_ident = get_target_ident(&target_path)?;
184
185 let mut pg_externs = Vec::default();
186 let item_impl_snapshot = item_impl.clone();
189
190 let generic_type = extract_generic_from_trait(&item_impl)?.clone();
191 let generic_type_name = get_generic_type_name(&generic_type)?;
192
193 let snake_case_target_ident =
194 format!("{target_ident}_{generic_type_name}").to_case(Case::Snake);
195 let snake_case_target_ident = Ident::new(&snake_case_target_ident, target_ident.span());
196 crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
197
198 let name = parse_quote! {
199 <#generic_type as ::pgrx::aggregate::ToAggregateName>::NAME
200 };
201
202 let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
204 let _type_state_value = type_state.map(|v| v.ty.clone());
205
206 let type_state_without_self = if let Some(inner) = type_state {
207 let mut remapped = inner.ty.clone();
208 remap_self_to_target(&mut remapped, &target_ident);
209 remapped
210 } else {
211 item_impl.items.push(parse_quote! {
212 type State = Self;
213 });
214 let mut remapped = parse_quote!(Self);
215 remap_self_to_target(&mut remapped, &target_ident);
216 remapped
217 };
218 let type_stype = AggregateType {
219 used_ty: UsedType::new(type_state_without_self.clone())?,
220 name: Some("state".into()),
221 };
222
223 let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
225 let type_moving_state;
226 let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
227 type_moving_state = impl_type_moving_state.ty.clone();
228 Some(UsedType::new(type_moving_state.clone())?)
229 } else {
230 item_impl.items.push(parse_quote! {
231 type MovingState = ();
232 });
233 type_moving_state = parse_quote! { () };
234 None
235 };
236
237 let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
239 let type_ordered_set_args_value =
240 type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
241 if type_ordered_set_args.is_none() {
242 item_impl.items.push(parse_quote! {
243 type OrderedSetArgs = ();
244 })
245 }
246 let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
247 type_ordered_set_args_value
248 {
249 let direct_args = order_by_direct_args
250 .found
251 .iter()
252 .map(|x| {
253 (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
254 })
255 .collect::<Vec<_>>();
256 let direct_arg_names = ARG_NAMES[0..direct_args.len()]
257 .iter()
258 .zip(direct_args.iter())
259 .map(|(default_name, (custom_name, _ty, _orig))| {
260 Ident::new(
261 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
262 Span::mixed_site(),
263 )
264 })
265 .collect::<Vec<_>>();
266 let direct_args_with_names = direct_args
267 .iter()
268 .zip(direct_arg_names.iter())
269 .map(|(arg, name)| {
270 let arg_ty = &arg.2; parse_quote! {
272 #name: #arg_ty
273 }
274 })
275 .collect::<Vec<syn::FnArg>>();
276 (direct_args_with_names, direct_arg_names)
277 } else {
278 (Vec::default(), Vec::default())
279 };
280
281 let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
283 syn::Error::new(
284 item_impl_snapshot.span(),
285 "`#[pg_aggregate]` requires the `Args` type defined.",
286 )
287 })?;
288 let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
289 let args = type_args_value
290 .found
291 .iter()
292 .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
293 .collect::<Vec<_>>();
294 let arg_names = ARG_NAMES[0..args.len()]
295 .iter()
296 .zip(args.iter())
297 .map(|(default_name, (custom_name, ty))| {
298 Ident::new(
299 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
300 ty.span(),
301 )
302 })
303 .collect::<Vec<_>>();
304 let args_with_names = args
305 .iter()
306 .zip(arg_names.iter())
307 .map(|(arg, name)| {
308 let arg_ty = &arg.1;
309 quote! {
310 #name: #arg_ty
311 }
312 })
313 .collect::<Vec<_>>();
314
315 let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
317 let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
318 type_finalize.ty.clone()
319 } else {
320 item_impl.items.push(parse_quote! {
321 type Finalize = ();
322 });
323 parse_quote! { () }
324 };
325
326 let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
327
328 let fn_state_name = if let Some(found) = fn_state {
329 let fn_name =
330 Ident::new(&format!("{snake_case_target_ident}_state"), found.sig.ident.span());
331 let pg_extern_attr = pg_extern_attr(found);
332
333 pg_externs.push(parse_quote! {
334 #[allow(non_snake_case, clippy::too_many_arguments)]
335 #pg_extern_attr
336 fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
337 unsafe {
338 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
339 fcinfo,
340 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::state(this, (#(#arg_names),*), fcinfo)
341 )
342 }
343 }
344 });
345 fn_name
346 } else {
347 return Err(syn::Error::new(
348 item_impl.span(),
349 "Aggregate implementation must include state function.",
350 ));
351 };
352
353 let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
354 let fn_combine_name = if let Some(found) = fn_combine {
355 let fn_name =
356 Ident::new(&format!("{snake_case_target_ident}_combine"), found.sig.ident.span());
357 let pg_extern_attr = pg_extern_attr(found);
358 pg_externs.push(parse_quote! {
359 #[allow(non_snake_case, clippy::too_many_arguments)]
360 #pg_extern_attr
361 fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
362 unsafe {
363 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
364 fcinfo,
365 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::combine(this, v, fcinfo)
366 )
367 }
368 }
369 });
370 Some(fn_name)
371 } else {
372 item_impl.items.push(parse_quote! {
373 fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
374 unimplemented!("Call to combine on an aggregate which does not support it.")
375 }
376 });
377 None
378 };
379
380 let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
381 let fn_finalize_name = if let Some(found) = fn_finalize {
382 let fn_name =
383 Ident::new(&format!("{snake_case_target_ident}_finalize"), found.sig.ident.span());
384 let pg_extern_attr = pg_extern_attr(found);
385
386 if !direct_args_with_names.is_empty() {
387 pg_externs.push(parse_quote! {
388 #[allow(non_snake_case, clippy::too_many_arguments)]
389 #pg_extern_attr
390 fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
391 unsafe {
392 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
393 fcinfo,
394 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (#(#direct_arg_names),*), fcinfo)
395 )
396 }
397 }
398 });
399 } else {
400 pg_externs.push(parse_quote! {
401 #[allow(non_snake_case, clippy::too_many_arguments)]
402 #pg_extern_attr
403 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
404 unsafe {
405 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
406 fcinfo,
407 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (), fcinfo)
408 )
409 }
410 }
411 });
412 };
413 Some(fn_name)
414 } else {
415 item_impl.items.push(parse_quote! {
416 fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
417 unimplemented!("Call to finalize on an aggregate which does not support it.")
418 }
419 });
420 None
421 };
422
423 let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
424 let fn_serial_name = if let Some(found) = fn_serial {
425 let fn_name =
426 Ident::new(&format!("{snake_case_target_ident}_serial"), found.sig.ident.span());
427 let pg_extern_attr = pg_extern_attr(found);
428 pg_externs.push(parse_quote! {
429 #[allow(non_snake_case, clippy::too_many_arguments)]
430 #pg_extern_attr
431 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
432 unsafe {
433 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
434 fcinfo,
435 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::serial(this, fcinfo)
436 )
437 }
438 }
439 });
440 Some(fn_name)
441 } else {
442 item_impl.items.push(parse_quote! {
443 fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
444 unimplemented!("Call to serial on an aggregate which does not support it.")
445 }
446 });
447 None
448 };
449
450 let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
451 let fn_deserial_name = if let Some(found) = fn_deserial {
452 let fn_name =
453 Ident::new(&format!("{snake_case_target_ident}_deserial"), found.sig.ident.span());
454 let pg_extern_attr = pg_extern_attr(found);
455 pg_externs.push(parse_quote! {
456 #[allow(non_snake_case, clippy::too_many_arguments)]
457 #pg_extern_attr
458 fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
459 unsafe {
460 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
461 fcinfo,
462 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::deserial(this, buf, internal, fcinfo)
463 )
464 }
465 }
466 });
467 Some(fn_name)
468 } else {
469 item_impl.items.push(parse_quote! {
470 fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
471 unimplemented!("Call to deserial on an aggregate which does not support it.")
472 }
473 });
474 None
475 };
476
477 let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
478 let fn_moving_state_name = if let Some(found) = fn_moving_state {
479 let fn_name = Ident::new(
480 &format!("{snake_case_target_ident}_moving_state"),
481 found.sig.ident.span(),
482 );
483 let pg_extern_attr = pg_extern_attr(found);
484
485 pg_externs.push(parse_quote! {
486 #[allow(non_snake_case, clippy::too_many_arguments)]
487 #pg_extern_attr
488 fn #fn_name(
489 mstate: #type_moving_state,
490 #(#args_with_names),*,
491 fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
492 ) -> #type_moving_state {
493 unsafe {
494 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
495 fcinfo,
496 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state(mstate, (#(#arg_names),*), fcinfo)
497 )
498 }
499 }
500 });
501 Some(fn_name)
502 } else {
503 item_impl.items.push(parse_quote! {
504 fn moving_state(
505 _mstate: <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState,
506 _v: Self::Args,
507 _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
508 ) -> <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState {
509 unimplemented!("Call to moving_state on an aggregate which does not support it.")
510 }
511 });
512 None
513 };
514
515 let fn_moving_state_inverse =
516 get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
517 let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
518 let fn_name = Ident::new(
519 &format!("{snake_case_target_ident}_moving_state_inverse"),
520 found.sig.ident.span(),
521 );
522 let pg_extern_attr = pg_extern_attr(found);
523 pg_externs.push(parse_quote! {
524 #[allow(non_snake_case, clippy::too_many_arguments)]
525 #pg_extern_attr
526 fn #fn_name(
527 mstate: #type_moving_state,
528 #(#args_with_names),*,
529 fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
530 ) -> #type_moving_state {
531 unsafe {
532 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
533 fcinfo,
534 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
535 )
536 }
537 }
538 });
539 Some(fn_name)
540 } else {
541 item_impl.items.push(parse_quote! {
542 fn moving_state_inverse(
543 _mstate: #type_moving_state,
544 _v: Self::Args,
545 _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
546 ) -> #type_moving_state {
547 unimplemented!("Call to moving_state on an aggregate which does not support it.")
548 }
549 });
550 None
551 };
552
553 let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
554 let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
555 let fn_name = Ident::new(
556 &format!("{snake_case_target_ident}_moving_finalize"),
557 found.sig.ident.span(),
558 );
559 let pg_extern_attr = pg_extern_attr(found);
560 let maybe_comma: Option<syn::Token![,]> =
561 if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
562
563 pg_externs.push(parse_quote! {
564 #[allow(non_snake_case, clippy::too_many_arguments)]
565 #pg_extern_attr
566 fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
567 unsafe {
568 <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
569 fcinfo,
570 move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
571 )
572 }
573 }
574 });
575 Some(fn_name)
576 } else {
577 item_impl.items.push(parse_quote! {
578 fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
579 unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
580 }
581 });
582 None
583 };
584
585 Ok(CodeEnrichment(Self {
586 item_impl,
587 target_ident,
588 pg_externs,
589 name,
590 snake_case_target_ident,
591 type_args: type_args_value,
592 type_ordered_set_args: type_ordered_set_args_value,
593 type_moving_state: type_moving_state_value,
594 type_stype,
595 const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
596 .map(|x| x.expr.clone()),
597 const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
598 .map(|x| x.expr.clone()),
599 const_moving_finalize_modify: get_impl_const_by_name(
600 &item_impl_snapshot,
601 "MOVING_FINALIZE_MODIFY",
602 )
603 .map(|x| x.expr.clone()),
604 const_initial_condition: get_impl_const_by_name(
605 &item_impl_snapshot,
606 "INITIAL_CONDITION",
607 )
608 .and_then(|e| get_const_litstr(e).transpose())
609 .transpose()?,
610 const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
611 .and_then(get_const_litbool)
612 .unwrap_or(false),
613 const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
614 .and_then(|e| get_const_litstr(e).transpose())
615 .transpose()?,
616 const_moving_intial_condition: get_impl_const_by_name(
617 &item_impl_snapshot,
618 "MOVING_INITIAL_CONDITION",
619 )
620 .and_then(|e| get_const_litstr(e).transpose())
621 .transpose()?,
622 fn_state: fn_state_name,
623 fn_finalize: fn_finalize_name,
624 fn_combine: fn_combine_name,
625 fn_serial: fn_serial_name,
626 fn_deserial: fn_deserial_name,
627 fn_moving_state: fn_moving_state_name,
628 fn_moving_state_inverse: fn_moving_state_inverse_name,
629 fn_moving_finalize: fn_moving_finalize_name,
630 hypothetical: if let Some(value) =
631 get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
632 {
633 match &value.expr {
634 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
635 syn::Lit::Bool(lit) => lit.value,
636 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
637 },
638 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
639 }
640 } else {
641 false
642 },
643 to_sql_config,
644 }))
645 }
646}
647
648impl ToEntityGraphTokens for PgAggregate {
649 fn to_entity_graph_tokens(&self) -> TokenStream2 {
650 let target_ident = &self.target_ident;
651 let sql_graph_entity_fn_name = syn::Ident::new(
652 &format!("__pgrx_internals_aggregate_{}", self.snake_case_target_ident),
653 target_ident.span(),
654 );
655
656 let name = &self.name;
657 let type_args_iter = &self.type_args.entity_tokens();
658 let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
659
660 let type_moving_state_entity_tokens =
661 self.type_moving_state.clone().map(|v| v.entity_tokens());
662 let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
663 let type_stype = self.type_stype.entity_tokens();
664 let const_ordered_set = self.const_ordered_set;
665 let const_parallel_iter = self.const_parallel.iter();
666 let const_finalize_modify_iter = self.const_finalize_modify.iter();
667 let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
668 let const_initial_condition_iter = self.const_initial_condition.iter();
669 let const_sort_operator_iter = self.const_sort_operator.iter();
670 let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
671 let hypothetical = self.hypothetical;
672 let fn_state = &self.fn_state;
673 let fn_finalize_iter = self.fn_finalize.iter();
674 let fn_combine_iter = self.fn_combine.iter();
675 let fn_serial_iter = self.fn_serial.iter();
676 let fn_deserial_iter = self.fn_deserial.iter();
677 let fn_moving_state_iter = self.fn_moving_state.iter();
678 let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
679 let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
680 let to_sql_config = &self.to_sql_config;
681
682 quote! {
683 #[no_mangle]
684 #[doc(hidden)]
685 #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
686 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
687 let submission = ::pgrx::pgrx_sql_entity_graph::PgAggregateEntity {
688 full_path: ::core::any::type_name::<#target_ident>(),
689 module_path: module_path!(),
690 file: file!(),
691 line: line!(),
692 name: #name,
693 ordered_set: #const_ordered_set,
694 ty_id: ::core::any::TypeId::of::<#target_ident>(),
695 args: #type_args_iter,
696 direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
697 stype: #type_stype,
698 sfunc: stringify!(#fn_state),
699 combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
700 finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
701 finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
702 initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
703 serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
704 deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
705 msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
706 minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
707 mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
708 mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
709 mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
710 minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
711 sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
712 parallel: None #( .unwrap_or(#const_parallel_iter) )*,
713 hypothetical: #hypothetical,
714 to_sql_config: #to_sql_config,
715 };
716 ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
717 }
718 }
719 }
720}
721
722impl ToRustCodeTokens for PgAggregate {
723 fn to_rust_code_tokens(&self) -> TokenStream2 {
724 let impl_item = &self.item_impl;
725 let pg_externs = self.pg_externs.iter();
726
727 quote! {
728 #impl_item
729 #(#pg_externs)*
730 }
731 }
732}
733
734impl Parse for CodeEnrichment<PgAggregate> {
735 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
736 PgAggregate::new(input.parse()?)
737 }
738}
739
740fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
741 let last = path.segments.last().ok_or_else(|| {
742 syn::Error::new(
743 path.span(),
744 "`#[pg_aggregate]` only works with types whose path have a final segment.",
745 )
746 })?;
747 Ok(last.ident.clone())
748}
749
750fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
751 let target_ident = match &*item_impl.self_ty {
752 syn::Type::Path(ref type_path) => {
753 let last_segment = type_path.path.segments.last().ok_or_else(|| {
754 syn::Error::new(
755 type_path.span(),
756 "`#[pg_aggregate]` only works with types whose path have a final segment.",
757 )
758 })?;
759 if last_segment.ident == "PgVarlena" {
760 match &last_segment.arguments {
761 syn::PathArguments::AngleBracketed(angled) => {
762 let first = angled.args.first().ok_or_else(|| syn::Error::new(
763 type_path.span(),
764 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
765 ))?;
766 match &first {
767 syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
768 _ => return Err(syn::Error::new(
769 type_path.span(),
770 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
771 )),
772 }
773 },
774 _ => return Err(syn::Error::new(
775 type_path.span(),
776 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
777 )),
778 }
779 } else {
780 type_path.path.clone()
781 }
782 }
783 something_else => {
784 return Err(syn::Error::new(
785 something_else.span(),
786 "`#[pg_aggregate]` only works with types.",
787 ))
788 }
789 };
790 Ok(target_ident)
791}
792
793fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
794 let mut found = None;
795 for attr in item.attrs.iter() {
796 match attr.path().segments.last() {
797 Some(segment) if segment.ident == "pgrx" => {
798 found = Some(attr);
799 break;
800 }
801 _ => (),
802 };
803 }
804
805 let attrs = if let Some(attr) = found {
806 let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
807 let attrs = attr.parse_args_with(parser);
808 attrs.ok()
809 } else {
810 None
811 };
812
813 match attrs {
814 Some(args) => parse_quote! {
815 #[::pgrx::pg_extern(#args)]
816 },
817 None => parse_quote! {
818 #[::pgrx::pg_extern]
819 },
820 }
821}
822
823fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
824 let mut needle = None;
825 for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
826 syn::ImplItem::Type(iitype) => Some(iitype),
827 _ => None,
828 }) {
829 let ident_string = impl_item_type.ident.to_string();
830 if ident_string == name {
831 needle = Some(impl_item_type);
832 }
833 }
834 needle
835}
836
837fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
838 let mut needle = None;
839 for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
840 syn::ImplItem::Fn(iifn) => Some(iifn),
841 _ => None,
842 }) {
843 let ident_string = impl_item_fn.sig.ident.to_string();
844 if ident_string == name {
845 needle = Some(impl_item_fn);
846 }
847 }
848 needle
849}
850
851fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
852 let mut needle = None;
853 for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
854 syn::ImplItem::Const(iiconst) => Some(iiconst),
855 _ => None,
856 }) {
857 let ident_string = impl_item_const.ident.to_string();
858 if ident_string == name {
859 needle = Some(impl_item_const);
860 }
861 }
862 needle
863}
864
865fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
866 match &item.expr {
867 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
868 syn::Lit::Bool(lit) => Some(lit.value()),
869 _ => None,
870 },
871 _ => None,
872 }
873}
874
875fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
876 match &item.expr {
877 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
878 syn::Lit::Str(lit) => Ok(Some(lit.value())),
879 _ => Ok(None),
880 },
881 syn::Expr::Call(expr_call) => match &*expr_call.func {
882 syn::Expr::Path(expr_path) => {
883 let Some(last) = expr_path.path.segments.last() else {
884 return Ok(None);
885 };
886 if last.ident == "Some" {
887 match expr_call.args.first() {
888 Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
889 syn::Lit::Str(lit) => Ok(Some(lit.value())),
890 _ => Ok(None),
891 },
892 _ => Ok(None),
893 }
894 } else {
895 Ok(None)
896 }
897 }
898 _ => Ok(None),
899 },
900 ex => Err(syn::Error::new(ex.span(), "")),
901 }
902}
903
904fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
905 if let Type::Path(ref mut ty_path) = ty {
906 for segment in ty_path.path.segments.iter_mut() {
907 if segment.ident == "Self" {
908 segment.ident = target.clone()
909 }
910 use syn::{GenericArgument, PathArguments};
911 match segment.arguments {
912 PathArguments::AngleBracketed(ref mut angle_args) => {
913 for arg in angle_args.args.iter_mut() {
914 if let GenericArgument::Type(inner_ty) = arg {
915 remap_self_to_target(inner_ty, target)
916 }
917 }
918 }
919 PathArguments::Parenthesized(_) => (),
920 PathArguments::None => (),
921 }
922 }
923 }
924}
925
926fn get_pgrx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
927 match &ty {
928 syn::Type::Macro(ty_macro) => {
929 let mut found_pgrx = false;
930 let mut found_attr = false;
931 for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
933 match segment.ident.to_string().as_str() {
934 "pgrx" if idx == 0 => found_pgrx = true,
935 attr if attr == attr_name.as_ref() => found_attr = true,
936 _ => (),
937 }
938 }
939 if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
940 Some(ty_macro.mac.tokens.clone())
941 } else {
942 None
943 }
944 }
945 _ => None,
946 }
947}
948
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{parse_quote, ItemImpl};

    /// Only `State`, `Args` and `state` given: exactly one wrapper is generated.
    #[test]
    fn agg_required_only() -> Result<()> {
        let input: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };

        let enriched = PgAggregate::new(input).expect("a minimal aggregate should parse");
        let aggregate = &enriched.0;
        assert_eq!(aggregate.pg_externs.len(), 1);
        assert_eq!(aggregate.pg_externs[0].sig.ident.to_string(), "demo_agg_demo_name_state");

        let _ = enriched.to_token_stream();
        Ok(())
    }

    /// Every support function given: one wrapper per function, `state` first.
    #[test]
    fn agg_all_options() -> Result<()> {
        let input: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };

        let enriched = PgAggregate::new(input).expect("a fully specified aggregate should parse");
        let aggregate = &enriched.0;
        assert_eq!(aggregate.pg_externs.len(), 8);
        assert_eq!(aggregate.pg_externs[0].sig.ident.to_string(), "demo_agg_demo_name_state");

        let _ = enriched.to_token_stream();
        Ok(())
    }

    /// An `Aggregate` impl without a generic argument or required items is rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let input: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };

        assert!(PgAggregate::new(input).is_err());
        Ok(())
    }
}