Skip to main content

pgrx_sql_entity_graph/aggregate/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_aggregate]` related macro expansion for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25use syn::PathArguments;
26
27use crate::enrich::CodeEnrichment;
28use crate::enrich::ToEntityGraphTokens;
29use crate::enrich::ToRustCodeTokens;
30use convert_case::{Case, Casing};
31use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
32use quote::quote;
33use syn::parse::{Parse, ParseStream};
34use syn::punctuated::Punctuated;
35use syn::spanned::Spanned;
36use syn::{
37    Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type, parse_quote,
38};
39
40use crate::ToSqlConfig;
41
42use super::UsedType;
43
/// Fallback Rust identifiers for aggregate arguments, indexed by position.
///
/// An argument that has no explicit name is called `ARG_NAMES[i]`, where `i` is its
/// position in the argument tuple. This table therefore bounds how many positional
/// arguments (32) an aggregate can declare.
#[rustfmt::skip]
const ARG_NAMES: [&str; 32] = [
    "arg_one", "arg_two", "arg_three", "arg_four",
    "arg_five", "arg_six", "arg_seven", "arg_eight",
    "arg_nine", "arg_ten", "arg_eleven", "arg_twelve",
    "arg_thirteen", "arg_fourteen", "arg_fifteen", "arg_sixteen",
    "arg_seventeen", "arg_eighteen", "arg_nineteen", "arg_twenty",
    "arg_twenty_one", "arg_twenty_two", "arg_twenty_three", "arg_twenty_four",
    "arg_twenty_five", "arg_twenty_six", "arg_twenty_seven", "arg_twenty_eight",
    "arg_twenty_nine", "arg_thirty", "arg_thirty_one", "arg_thirty_two",
];
79
/** A parsed `#[pg_aggregate]` item.

Produced by [`PgAggregate::new`] from an `impl Aggregate<T> for Target` block. It holds the
(completed) impl block and the wrapper functions generated for it. It also holds the associated
types, constants and function names needed to describe the aggregate to the SQL entity graph.
*/
#[derive(Debug, Clone)]
pub struct PgAggregate {
    /// The annotated `impl` block. `new` appends default associated types and stub functions
    /// for the optional items the user did not write.
    item_impl: ItemImpl,
    /// Expression naming the aggregate: `<T as ::pgrx::aggregate::ToAggregateName>::NAME`,
    /// where `T` is the `Aggregate<T>` generic argument.
    name: Expr,
    /// Identifier of the implementing type (as resolved by `get_target_ident`).
    target_ident: Ident,
    /// `{target_ident}_{generic type name}` converted to snake case. Used as the prefix of the
    /// generated wrapper functions and of the entity-graph function.
    snake_case_target_ident: Ident,
    /// Wrapper functions generated for each aggregate support function the impl defines.
    pg_externs: Vec<ItemFn>,
    // Note these should not be considered *writable*, they're snapshots from construction.
    /// Aggregated arguments, taken from the required `type Args`.
    type_args: AggregateTypeList,
    /// Direct arguments, taken from `type OrderedSetArgs` when it is declared.
    type_ordered_set_args: Option<AggregateTypeList>,
    /// Moving-aggregate state type, taken from `type MovingState` when it is declared.
    type_moving_state: Option<UsedType>,
    /// State type (named `state`), taken from `type State`, which defaults to the target type.
    type_stype: AggregateType,
    /// Literal `ORDERED_SET` constant; `false` when absent.
    const_ordered_set: bool,
    /// `PARALLEL` constant expression, copied verbatim.
    const_parallel: Option<syn::Expr>,
    /// `FINALIZE_MODIFY` constant expression, copied verbatim.
    const_finalize_modify: Option<syn::Expr>,
    /// `MOVING_FINALIZE_MODIFY` constant expression, copied verbatim.
    const_moving_finalize_modify: Option<syn::Expr>,
    /// `INITIAL_CONDITION` constant, as a string.
    const_initial_condition: Option<String>,
    /// `SORT_OPERATOR` constant, as a string.
    const_sort_operator: Option<String>,
    /// `MOVING_INITIAL_CONDITION` constant, as a string.
    /// NOTE(review): the field name misspells "initial"; kept as-is for compatibility.
    const_moving_intial_condition: Option<String>,
    /// Name of the generated wrapper for the required `state` function.
    fn_state: Ident,
    /// Names of the generated wrappers for the optional support functions. Each is `None`
    /// when the impl does not define the corresponding function.
    fn_finalize: Option<Ident>,
    fn_combine: Option<Ident>,
    fn_serial: Option<Ident>,
    fn_deserial: Option<Ident>,
    fn_moving_state: Option<Ident>,
    fn_moving_state_inverse: Option<Ident>,
    fn_moving_finalize: Option<Ident>,
    /// Literal `HYPOTHETICAL` constant; `false` when absent.
    hypothetical: bool,
    /// SQL generation configuration read from the impl block's attributes.
    to_sql_config: ToSqlConfig,
}
112
113fn extract_generic_from_trait(item_impl: &ItemImpl) -> Result<&Type, syn::Error> {
114    let (_, path, _) = item_impl.trait_.as_ref().ok_or_else(|| {
115        syn::Error::new_spanned(
116            item_impl,
117            "`#[pg_aggregate]` can only be used on `impl` blocks for a trait.",
118        )
119    })?;
120
121    let last_segment = path
122        .segments
123        .last()
124        .ok_or_else(|| syn::Error::new_spanned(path, "Trait path is empty or malformed."))?;
125
126    if last_segment.ident != "Aggregate" {
127        return Err(syn::Error::new_spanned(
128            last_segment.ident.clone(),
129            "`#[pg_aggregate]` only works with the `Aggregate` trait.",
130        ));
131    }
132
133    let args = match &last_segment.arguments {
134        PathArguments::AngleBracketed(args) => args,
135        _ => {
136            return Err(syn::Error::new_spanned(
137                last_segment.ident.clone(),
138                "`Aggregate` trait must have angle-bracketed generic arguments (e.g., `Aggregate<T>`). Missing generic argument.",
139            ));
140        }
141    };
142
143    let generic_arg = args.args.first().ok_or_else(|| {
144        syn::Error::new_spanned(
145            args,
146            "`Aggregate` trait requires at least one generic argument (e.g., `Aggregate<T>`).",
147        )
148    })?;
149
150    if let syn::GenericArgument::Type(ty) = generic_arg {
151        Ok(ty)
152    } else {
153        Err(syn::Error::new_spanned(
154            generic_arg,
155            "Expected a type as the generic argument for `Aggregate` (e.g., `Aggregate<MyType>`).",
156        ))
157    }
158}
159
160fn get_generic_type_name(ty: &syn::Type) -> Result<String, syn::Error> {
161    if let Type::Path(type_path) = ty
162        && let Some(ident) = type_path.path.segments.last().map(|s| &s.ident)
163    {
164        let ident = ident.to_string();
165
166        match ident.as_str() {
167            "!" => Ok("never".to_string()),
168            "()" => Ok("unit".to_string()),
169            _ => Ok(ident),
170        }
171    } else {
172        Err(syn::Error::new_spanned(ty, "Generic type path is empty or malformed."))
173    }
174}
175
176impl PgAggregate {
177    pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
178        let to_sql_config =
179            ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
180        let target_path = get_target_path(&item_impl)?;
181        let target_ident = get_target_ident(&target_path)?;
182
183        let mut pg_externs = Vec::default();
184        // We want to avoid having multiple borrows, so we take a snapshot to scan from,
185        // and mutate the actual one.
186        let item_impl_snapshot = item_impl.clone();
187
188        let generic_type = extract_generic_from_trait(&item_impl)?.clone();
189        let generic_type_name = get_generic_type_name(&generic_type)?;
190
191        let snake_case_target_ident =
192            format!("{target_ident}_{generic_type_name}").to_case(Case::Snake);
193        let snake_case_target_ident = Ident::new(&snake_case_target_ident, target_ident.span());
194        crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
195
196        let name = parse_quote! {
197            <#generic_type as ::pgrx::aggregate::ToAggregateName>::NAME
198        };
199
200        // `State` is an optional value, we default to `Self`.
201        let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
202        let _type_state_value = type_state.map(|v| v.ty.clone());
203
204        let type_state_without_self = if let Some(inner) = type_state {
205            let mut remapped = inner.ty.clone();
206            remap_self_to_target(&mut remapped, &target_ident);
207            remapped
208        } else {
209            item_impl.items.push(parse_quote! {
210                type State = Self;
211            });
212            let mut remapped = parse_quote!(Self);
213            remap_self_to_target(&mut remapped, &target_ident);
214            remapped
215        };
216        let type_stype = AggregateType {
217            used_ty: UsedType::new(type_state_without_self.clone())?,
218            name: Some("state".into()),
219        };
220
221        // `MovingState` is an optional value, we default to nothing.
222        let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
223        let type_moving_state;
224        let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
225            type_moving_state = impl_type_moving_state.ty.clone();
226            Some(UsedType::new(type_moving_state.clone())?)
227        } else {
228            item_impl.items.push(parse_quote! {
229                type MovingState = ();
230            });
231            type_moving_state = parse_quote! { () };
232            None
233        };
234
235        // `OrderBy` is an optional value, we default to nothing.
236        let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
237        let type_ordered_set_args_value =
238            type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
239        if type_ordered_set_args.is_none() {
240            item_impl.items.push(parse_quote! {
241                type OrderedSetArgs = ();
242            })
243        }
244        let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
245            type_ordered_set_args_value
246        {
247            let direct_args = order_by_direct_args
248                .found
249                .iter()
250                .map(|x| {
251                    (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
252                })
253                .collect::<Vec<_>>();
254            let direct_arg_names = ARG_NAMES[0..direct_args.len()]
255                .iter()
256                .zip(direct_args.iter())
257                .map(|(default_name, (custom_name, _ty, _orig))| {
258                    Ident::new(
259                        &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
260                        Span::mixed_site(),
261                    )
262                })
263                .collect::<Vec<_>>();
264            let direct_args_with_names = direct_args
265                .iter()
266                .zip(direct_arg_names.iter())
267                .map(|(arg, name)| {
268                    let arg_ty = &arg.2; // original_type
269                    parse_quote! {
270                        #name: #arg_ty
271                    }
272                })
273                .collect::<Vec<syn::FnArg>>();
274            (direct_args_with_names, direct_arg_names)
275        } else {
276            (Vec::default(), Vec::default())
277        };
278
279        // `Args` is an optional value, we default to nothing.
280        let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
281            syn::Error::new(
282                item_impl_snapshot.span(),
283                "`#[pg_aggregate]` requires the `Args` type defined.",
284            )
285        })?;
286        let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
287        let args = type_args_value
288            .found
289            .iter()
290            .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
291            .collect::<Vec<_>>();
292        let arg_names = ARG_NAMES[0..args.len()]
293            .iter()
294            .zip(args.iter())
295            .map(|(default_name, (custom_name, ty))| {
296                Ident::new(
297                    &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
298                    ty.span(),
299                )
300            })
301            .collect::<Vec<_>>();
302        let args_with_names = args
303            .iter()
304            .zip(arg_names.iter())
305            .map(|(arg, name)| {
306                let arg_ty = &arg.1;
307                quote! {
308                    #name: #arg_ty
309                }
310            })
311            .collect::<Vec<_>>();
312
313        // `Finalize` is an optional value, we default to nothing.
314        let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
315        let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
316            type_finalize.ty.clone()
317        } else {
318            item_impl.items.push(parse_quote! {
319                type Finalize = ();
320            });
321            parse_quote! { () }
322        };
323
324        let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
325
326        let fn_state_name = if let Some(found) = fn_state {
327            let fn_name =
328                Ident::new(&format!("{snake_case_target_ident}_state"), found.sig.ident.span());
329            let pg_extern_attr = pg_extern_attr(found);
330
331            pg_externs.push(parse_quote! {
332                #[allow(non_snake_case, clippy::too_many_arguments)]
333                #pg_extern_attr
334                fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
335                    unsafe {
336                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
337                            fcinfo,
338                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::state(this, (#(#arg_names),*), fcinfo)
339                        )
340                    }
341                }
342            });
343            fn_name
344        } else {
345            return Err(syn::Error::new(
346                item_impl.span(),
347                "Aggregate implementation must include state function.",
348            ));
349        };
350
351        let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
352        let fn_combine_name = if let Some(found) = fn_combine {
353            let fn_name =
354                Ident::new(&format!("{snake_case_target_ident}_combine"), found.sig.ident.span());
355            let pg_extern_attr = pg_extern_attr(found);
356            pg_externs.push(parse_quote! {
357                #[allow(non_snake_case, clippy::too_many_arguments)]
358                #pg_extern_attr
359                fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
360                    unsafe {
361                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
362                            fcinfo,
363                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::combine(this, v, fcinfo)
364                        )
365                    }
366                }
367            });
368            Some(fn_name)
369        } else {
370            item_impl.items.push(parse_quote! {
371                fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
372                    unimplemented!("Call to combine on an aggregate which does not support it.")
373                }
374            });
375            None
376        };
377
378        let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
379        let fn_finalize_name = if let Some(found) = fn_finalize {
380            let fn_name =
381                Ident::new(&format!("{snake_case_target_ident}_finalize"), found.sig.ident.span());
382            let pg_extern_attr = pg_extern_attr(found);
383
384            if !direct_args_with_names.is_empty() {
385                pg_externs.push(parse_quote! {
386                    #[allow(non_snake_case, clippy::too_many_arguments)]
387                    #pg_extern_attr
388                    fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
389                        unsafe {
390                            <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
391                                fcinfo,
392                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (#(#direct_arg_names),*), fcinfo)
393                            )
394                        }
395                    }
396                });
397            } else {
398                pg_externs.push(parse_quote! {
399                    #[allow(non_snake_case, clippy::too_many_arguments)]
400                    #pg_extern_attr
401                    fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
402                        unsafe {
403                            <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
404                                fcinfo,
405                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (), fcinfo)
406                            )
407                        }
408                    }
409                });
410            };
411            Some(fn_name)
412        } else {
413            item_impl.items.push(parse_quote! {
414                fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
415                    unimplemented!("Call to finalize on an aggregate which does not support it.")
416                }
417            });
418            None
419        };
420
421        let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
422        let fn_serial_name = if let Some(found) = fn_serial {
423            let fn_name =
424                Ident::new(&format!("{snake_case_target_ident}_serial"), found.sig.ident.span());
425            let pg_extern_attr = pg_extern_attr(found);
426            pg_externs.push(parse_quote! {
427                #[allow(non_snake_case, clippy::too_many_arguments)]
428                #pg_extern_attr
429                fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
430                    unsafe {
431                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
432                            fcinfo,
433                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::serial(this, fcinfo)
434                        )
435                    }
436                }
437            });
438            Some(fn_name)
439        } else {
440            item_impl.items.push(parse_quote! {
441                fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
442                    unimplemented!("Call to serial on an aggregate which does not support it.")
443                }
444            });
445            None
446        };
447
448        let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
449        let fn_deserial_name = if let Some(found) = fn_deserial {
450            let fn_name =
451                Ident::new(&format!("{snake_case_target_ident}_deserial"), found.sig.ident.span());
452            let pg_extern_attr = pg_extern_attr(found);
453            pg_externs.push(parse_quote! {
454                #[allow(non_snake_case, clippy::too_many_arguments)]
455                #pg_extern_attr
456                fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
457                    unsafe {
458                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
459                            fcinfo,
460                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::deserial(this, buf, internal, fcinfo)
461                        )
462                    }
463                }
464            });
465            Some(fn_name)
466        } else {
467            item_impl.items.push(parse_quote! {
468                fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
469                    unimplemented!("Call to deserial on an aggregate which does not support it.")
470                }
471            });
472            None
473        };
474
475        let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
476        let fn_moving_state_name = if let Some(found) = fn_moving_state {
477            let fn_name = Ident::new(
478                &format!("{snake_case_target_ident}_moving_state"),
479                found.sig.ident.span(),
480            );
481            let pg_extern_attr = pg_extern_attr(found);
482
483            pg_externs.push(parse_quote! {
484                #[allow(non_snake_case, clippy::too_many_arguments)]
485                #pg_extern_attr
486                fn #fn_name(
487                    mstate: #type_moving_state,
488                    #(#args_with_names),*,
489                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
490                ) -> #type_moving_state {
491                    unsafe {
492                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
493                            fcinfo,
494                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state(mstate, (#(#arg_names),*), fcinfo)
495                        )
496                    }
497                }
498            });
499            Some(fn_name)
500        } else {
501            item_impl.items.push(parse_quote! {
502                fn moving_state(
503                    _mstate: <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState,
504                    _v: Self::Args,
505                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
506                ) -> <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState {
507                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
508                }
509            });
510            None
511        };
512
513        let fn_moving_state_inverse =
514            get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
515        let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
516            let fn_name = Ident::new(
517                &format!("{snake_case_target_ident}_moving_state_inverse"),
518                found.sig.ident.span(),
519            );
520            let pg_extern_attr = pg_extern_attr(found);
521            pg_externs.push(parse_quote! {
522                #[allow(non_snake_case, clippy::too_many_arguments)]
523                #pg_extern_attr
524                fn #fn_name(
525                    mstate: #type_moving_state,
526                    #(#args_with_names),*,
527                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
528                ) -> #type_moving_state {
529                    unsafe {
530                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
531                            fcinfo,
532                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
533                        )
534                    }
535                }
536            });
537            Some(fn_name)
538        } else {
539            item_impl.items.push(parse_quote! {
540                fn moving_state_inverse(
541                    _mstate: #type_moving_state,
542                    _v: Self::Args,
543                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
544                ) -> #type_moving_state {
545                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
546                }
547            });
548            None
549        };
550
551        let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
552        let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
553            let fn_name = Ident::new(
554                &format!("{snake_case_target_ident}_moving_finalize"),
555                found.sig.ident.span(),
556            );
557            let pg_extern_attr = pg_extern_attr(found);
558            let maybe_comma: Option<syn::Token![,]> =
559                if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
560
561            pg_externs.push(parse_quote! {
562                #[allow(non_snake_case, clippy::too_many_arguments)]
563                #pg_extern_attr
564                fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
565                    unsafe {
566                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
567                            fcinfo,
568                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
569                        )
570                    }
571                }
572            });
573            Some(fn_name)
574        } else {
575            item_impl.items.push(parse_quote! {
576                fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
577                    unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
578                }
579            });
580            None
581        };
582
583        Ok(CodeEnrichment(Self {
584            item_impl,
585            target_ident,
586            pg_externs,
587            name,
588            snake_case_target_ident,
589            type_args: type_args_value,
590            type_ordered_set_args: type_ordered_set_args_value,
591            type_moving_state: type_moving_state_value,
592            type_stype,
593            const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
594                .map(|x| x.expr.clone()),
595            const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
596                .map(|x| x.expr.clone()),
597            const_moving_finalize_modify: get_impl_const_by_name(
598                &item_impl_snapshot,
599                "MOVING_FINALIZE_MODIFY",
600            )
601            .map(|x| x.expr.clone()),
602            const_initial_condition: get_impl_const_by_name(
603                &item_impl_snapshot,
604                "INITIAL_CONDITION",
605            )
606            .and_then(|e| get_const_litstr(e).transpose())
607            .transpose()?,
608            const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
609                .and_then(get_const_litbool)
610                .unwrap_or(false),
611            const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
612                .and_then(|e| get_const_litstr(e).transpose())
613                .transpose()?,
614            const_moving_intial_condition: get_impl_const_by_name(
615                &item_impl_snapshot,
616                "MOVING_INITIAL_CONDITION",
617            )
618            .and_then(|e| get_const_litstr(e).transpose())
619            .transpose()?,
620            fn_state: fn_state_name,
621            fn_finalize: fn_finalize_name,
622            fn_combine: fn_combine_name,
623            fn_serial: fn_serial_name,
624            fn_deserial: fn_deserial_name,
625            fn_moving_state: fn_moving_state_name,
626            fn_moving_state_inverse: fn_moving_state_inverse_name,
627            fn_moving_finalize: fn_moving_finalize_name,
628            hypothetical: if let Some(value) =
629                get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
630            {
631                match &value.expr {
632                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
633                        syn::Lit::Bool(lit) => lit.value,
634                        _ => {
635                            return Err(syn::Error::new(
636                                value.span(),
637                                "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.",
638                            ));
639                        }
640                    },
641                    _ => {
642                        return Err(syn::Error::new(
643                            value.span(),
644                            "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.",
645                        ));
646                    }
647                }
648            } else {
649                false
650            },
651            to_sql_config,
652        }))
653    }
654}
655
656impl ToEntityGraphTokens for PgAggregate {
657    fn to_entity_graph_tokens(&self) -> TokenStream2 {
658        let target_ident = &self.target_ident;
659        let sql_graph_entity_fn_name = syn::Ident::new(
660            &format!("__pgrx_internals_aggregate_{}", self.snake_case_target_ident),
661            target_ident.span(),
662        );
663
664        let name = &self.name;
665        let type_args_iter = &self.type_args.entity_tokens();
666        let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
667
668        let type_moving_state_entity_tokens =
669            self.type_moving_state.clone().map(|v| v.entity_tokens());
670        let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
671        let type_stype = self.type_stype.entity_tokens();
672        let const_ordered_set = self.const_ordered_set;
673        let const_parallel_iter = self.const_parallel.iter();
674        let const_finalize_modify_iter = self.const_finalize_modify.iter();
675        let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
676        let const_initial_condition_iter = self.const_initial_condition.iter();
677        let const_sort_operator_iter = self.const_sort_operator.iter();
678        let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
679        let hypothetical = self.hypothetical;
680        let fn_state = &self.fn_state;
681        let fn_finalize_iter = self.fn_finalize.iter();
682        let fn_combine_iter = self.fn_combine.iter();
683        let fn_serial_iter = self.fn_serial.iter();
684        let fn_deserial_iter = self.fn_deserial.iter();
685        let fn_moving_state_iter = self.fn_moving_state.iter();
686        let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
687        let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
688        let to_sql_config = &self.to_sql_config;
689
690        quote! {
691            #[unsafe(no_mangle)]
692            #[doc(hidden)]
693            #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
694            pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
695                let submission = ::pgrx::pgrx_sql_entity_graph::PgAggregateEntity {
696                    full_path: ::core::any::type_name::<#target_ident>(),
697                    module_path: module_path!(),
698                    file: file!(),
699                    line: line!(),
700                    name: #name,
701                    ordered_set: #const_ordered_set,
702                    ty_id: ::core::any::TypeId::of::<#target_ident>(),
703                    args: #type_args_iter,
704                    direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
705                    stype: #type_stype,
706                    sfunc: stringify!(#fn_state),
707                    combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
708                    finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
709                    finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
710                    initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
711                    serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
712                    deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
713                    msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
714                    minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
715                    mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
716                    mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
717                    mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
718                    minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
719                    sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
720                    parallel: None #( .unwrap_or(#const_parallel_iter) )*,
721                    hypothetical: #hypothetical,
722                    to_sql_config: #to_sql_config,
723                };
724                ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
725            }
726        }
727    }
728}
729
730impl ToRustCodeTokens for PgAggregate {
731    fn to_rust_code_tokens(&self) -> TokenStream2 {
732        let impl_item = &self.item_impl;
733        let pg_externs = self.pg_externs.iter();
734
735        quote! {
736            #impl_item
737            #(#pg_externs)*
738        }
739    }
740}
741
742impl Parse for CodeEnrichment<PgAggregate> {
743    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
744        PgAggregate::new(input.parse()?)
745    }
746}
747
748fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
749    let last = path.segments.last().ok_or_else(|| {
750        syn::Error::new(
751            path.span(),
752            "`#[pg_aggregate]` only works with types whose path have a final segment.",
753        )
754    })?;
755    Ok(last.ident.clone())
756}
757
758fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
759    let target_ident = match &*item_impl.self_ty {
760        syn::Type::Path(type_path) => {
761            let last_segment = type_path.path.segments.last().ok_or_else(|| {
762                syn::Error::new(
763                    type_path.span(),
764                    "`#[pg_aggregate]` only works with types whose path have a final segment.",
765                )
766            })?;
767            if last_segment.ident == "PgVarlena" {
768                match &last_segment.arguments {
769                    syn::PathArguments::AngleBracketed(angled) => {
770                        let first = angled.args.first().ok_or_else(|| syn::Error::new(
771                            type_path.span(),
772                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
773                        ))?;
774                        match &first {
775                            syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
776                            _ => {
777                                return Err(syn::Error::new(
778                                    type_path.span(),
779                                    "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
780                                ));
781                            }
782                        }
783                    }
784                    _ => {
785                        return Err(syn::Error::new(
786                            type_path.span(),
787                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
788                        ));
789                    }
790                }
791            } else {
792                type_path.path.clone()
793            }
794        }
795        something_else => {
796            return Err(syn::Error::new(
797                something_else.span(),
798                "`#[pg_aggregate]` only works with types.",
799            ));
800        }
801    };
802    Ok(target_ident)
803}
804
805fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
806    let mut found = None;
807    for attr in item.attrs.iter() {
808        match attr.path().segments.last() {
809            Some(segment) if segment.ident == "pgrx" => {
810                found = Some(attr);
811                break;
812            }
813            _ => (),
814        };
815    }
816
817    let attrs = if let Some(attr) = found {
818        let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
819        let attrs = attr.parse_args_with(parser);
820        attrs.ok()
821    } else {
822        None
823    };
824
825    match attrs {
826        Some(args) => parse_quote! {
827            #[::pgrx::pg_extern(#args)]
828        },
829        None => parse_quote! {
830            #[::pgrx::pg_extern]
831        },
832    }
833}
834
835fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
836    let mut needle = None;
837    for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
838        syn::ImplItem::Type(iitype) => Some(iitype),
839        _ => None,
840    }) {
841        let ident_string = impl_item_type.ident.to_string();
842        if ident_string == name {
843            needle = Some(impl_item_type);
844        }
845    }
846    needle
847}
848
849fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
850    let mut needle = None;
851    for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
852        syn::ImplItem::Fn(iifn) => Some(iifn),
853        _ => None,
854    }) {
855        let ident_string = impl_item_fn.sig.ident.to_string();
856        if ident_string == name {
857            needle = Some(impl_item_fn);
858        }
859    }
860    needle
861}
862
863fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
864    let mut needle = None;
865    for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
866        syn::ImplItem::Const(iiconst) => Some(iiconst),
867        _ => None,
868    }) {
869        let ident_string = impl_item_const.ident.to_string();
870        if ident_string == name {
871            needle = Some(impl_item_const);
872        }
873    }
874    needle
875}
876
877fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
878    match &item.expr {
879        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
880            syn::Lit::Bool(lit) => Some(lit.value()),
881            _ => None,
882        },
883        _ => None,
884    }
885}
886
887fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
888    match &item.expr {
889        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
890            syn::Lit::Str(lit) => Ok(Some(lit.value())),
891            _ => Ok(None),
892        },
893        syn::Expr::Call(expr_call) => match &*expr_call.func {
894            syn::Expr::Path(expr_path) => {
895                let Some(last) = expr_path.path.segments.last() else {
896                    return Ok(None);
897                };
898                if last.ident == "Some" {
899                    match expr_call.args.first() {
900                        Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
901                            syn::Lit::Str(lit) => Ok(Some(lit.value())),
902                            _ => Ok(None),
903                        },
904                        _ => Ok(None),
905                    }
906                } else {
907                    Ok(None)
908                }
909            }
910            _ => Ok(None),
911        },
912        ex => Err(syn::Error::new(ex.span(), "")),
913    }
914}
915
916fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
917    if let Type::Path(ty_path) = ty {
918        for segment in ty_path.path.segments.iter_mut() {
919            if segment.ident == "Self" {
920                segment.ident = target.clone()
921            }
922            use syn::{GenericArgument, PathArguments};
923            match segment.arguments {
924                PathArguments::AngleBracketed(ref mut angle_args) => {
925                    for arg in angle_args.args.iter_mut() {
926                        if let GenericArgument::Type(inner_ty) = arg {
927                            remap_self_to_target(inner_ty, target)
928                        }
929                    }
930                }
931                PathArguments::Parenthesized(_) => (),
932                PathArguments::None => (),
933            }
934        }
935    }
936}
937
938fn get_pgrx_attr_macro(attr_name: &str, ty: &syn::Type) -> Option<TokenStream2> {
939    match &ty {
940        syn::Type::Macro(ty_macro) => {
941            let mut found_pgrx = false;
942            let mut found_attr = false;
943            // We don't actually have type resolution here, this is a "Best guess".
944            for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
945                match segment.ident.to_string().as_str() {
946                    "pgrx" if idx == 0 => found_pgrx = true,
947                    attr if attr == attr_name => found_attr = true,
948                    _ => (),
949                }
950            }
951            if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
952                Some(ty_macro.mac.tokens.clone())
953            } else {
954                None
955            }
956        }
957        _ => None,
958    }
959}
960
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{ItemImpl, parse_quote};

    /// Name the generated state wrapper must get for `DemoAgg` / `DemoName`.
    const EXPECTED_STATE_EXTERN: &str = "demo_agg_demo_name_state";

    #[test]
    fn agg_required_only() -> Result<()> {
        let item: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };

        // A minimal aggregate is valid and must be accepted.
        let agg = PgAggregate::new(item).expect("minimal aggregate should be accepted");

        // Only the state function is wrapped, under a predictable name.
        let externs = &agg.0.pg_externs;
        assert_eq!(externs.len(), 1);
        assert_eq!(externs[0].sig.ident.to_string(), EXPECTED_STATE_EXTERN);

        // Entity-token generation must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    #[test]
    fn agg_all_options() -> Result<()> {
        let item: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };

        // Every optional piece is present, so this must be accepted too.
        let agg = PgAggregate::new(item).expect("fully-specified aggregate should be accepted");

        // One wrapper per provided function (8 total), with the state wrapper first.
        let externs = &agg.0.pg_externs;
        assert_eq!(externs.len(), 8);
        assert_eq!(externs[0].sig.ident.to_string(), EXPECTED_STATE_EXTERN);

        // Entity-token generation must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    #[test]
    fn agg_missing_required() -> Result<()> {
        // No `State`/`Args` types and no `state` function: this must be rejected.
        let item: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(PgAggregate::new(item).is_err());
        Ok(())
    }
}