pgrx_sql_entity_graph/aggregate/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_aggregate]` related macro expansion for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25use syn::PathArguments;
26
27use crate::enrich::CodeEnrichment;
28use crate::enrich::ToEntityGraphTokens;
29use crate::enrich::ToRustCodeTokens;
30use convert_case::{Case, Casing};
31use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
32use quote::quote;
33use syn::parse::{Parse, ParseStream};
34use syn::punctuated::Punctuated;
35use syn::spanned::Spanned;
36use syn::{
37    parse_quote, Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type,
38};
39
40use crate::ToSqlConfig;
41
42use super::UsedType;
43
/// Default Rust identifiers for positional aggregate arguments.
///
/// When an argument in `Args` (or `OrderedSetArgs`) has no explicit name, the generated wrapper
/// binds it to the identifier at the same position in this table. The table holds 32 entries,
/// matching the largest argument tuple this macro supports.
const ARG_NAMES: [&str; 32] = [
    // 1 ..= 10
    "arg_one", "arg_two", "arg_three", "arg_four", "arg_five",
    "arg_six", "arg_seven", "arg_eight", "arg_nine", "arg_ten",
    // 11 ..= 20
    "arg_eleven", "arg_twelve", "arg_thirteen", "arg_fourteen", "arg_fifteen",
    "arg_sixteen", "arg_seventeen", "arg_eighteen", "arg_nineteen", "arg_twenty",
    // 21 ..= 30
    "arg_twenty_one", "arg_twenty_two", "arg_twenty_three", "arg_twenty_four", "arg_twenty_five",
    "arg_twenty_six", "arg_twenty_seven", "arg_twenty_eight", "arg_twenty_nine", "arg_thirty",
    // 31 ..= 32
    "arg_thirty_one", "arg_thirty_two",
];
79
80/** A parsed `#[pg_aggregate]` item.
81*/
82#[derive(Debug, Clone)]
83pub struct PgAggregate {
84    item_impl: ItemImpl,
85    name: Expr,
86    target_ident: Ident,
87    snake_case_target_ident: Ident,
88    pg_externs: Vec<ItemFn>,
89    // Note these should not be considered *writable*, they're snapshots from construction.
90    type_args: AggregateTypeList,
91    type_ordered_set_args: Option<AggregateTypeList>,
92    type_moving_state: Option<UsedType>,
93    type_stype: AggregateType,
94    const_ordered_set: bool,
95    const_parallel: Option<syn::Expr>,
96    const_finalize_modify: Option<syn::Expr>,
97    const_moving_finalize_modify: Option<syn::Expr>,
98    const_initial_condition: Option<String>,
99    const_sort_operator: Option<String>,
100    const_moving_intial_condition: Option<String>,
101    fn_state: Ident,
102    fn_finalize: Option<Ident>,
103    fn_combine: Option<Ident>,
104    fn_serial: Option<Ident>,
105    fn_deserial: Option<Ident>,
106    fn_moving_state: Option<Ident>,
107    fn_moving_state_inverse: Option<Ident>,
108    fn_moving_finalize: Option<Ident>,
109    hypothetical: bool,
110    to_sql_config: ToSqlConfig,
111}
112
113fn extract_generic_from_trait(item_impl: &ItemImpl) -> Result<&Type, syn::Error> {
114    let (_, path, _) = item_impl.trait_.as_ref().ok_or_else(|| {
115        syn::Error::new_spanned(
116            item_impl,
117            "`#[pg_aggregate]` can only be used on `impl` blocks for a trait.",
118        )
119    })?;
120
121    let last_segment = path
122        .segments
123        .last()
124        .ok_or_else(|| syn::Error::new_spanned(path, "Trait path is empty or malformed."))?;
125
126    if last_segment.ident != "Aggregate" {
127        return Err(syn::Error::new_spanned(
128            last_segment.ident.clone(),
129            "`#[pg_aggregate]` only works with the `Aggregate` trait.",
130        ));
131    }
132
133    let args = match &last_segment.arguments {
134        PathArguments::AngleBracketed(args) => args,
135        _ => {
136            return Err(syn::Error::new_spanned(
137                last_segment.ident.clone(),
138                "`Aggregate` trait must have angle-bracketed generic arguments (e.g., `Aggregate<T>`). Missing generic argument.",
139            ));
140        }
141    };
142
143    let generic_arg = args.args.first().ok_or_else(|| {
144        syn::Error::new_spanned(
145            args,
146            "`Aggregate` trait requires at least one generic argument (e.g., `Aggregate<T>`).",
147        )
148    })?;
149
150    if let syn::GenericArgument::Type(ty) = generic_arg {
151        Ok(ty)
152    } else {
153        Err(syn::Error::new_spanned(
154            generic_arg,
155            "Expected a type as the generic argument for `Aggregate` (e.g., `Aggregate<MyType>`).",
156        ))
157    }
158}
159
160fn get_generic_type_name(ty: &syn::Type) -> Result<String, syn::Error> {
161    if let Type::Path(type_path) = ty {
162        if let Some(ident) = type_path.path.segments.last().map(|s| &s.ident) {
163            let ident = ident.to_string();
164
165            match ident.as_str() {
166                "!" => Ok("never".to_string()),
167                "()" => Ok("unit".to_string()),
168                _ => Ok(ident),
169            }
170        } else {
171            Err(syn::Error::new_spanned(ty, "Generic type path is empty or malformed."))
172        }
173    } else {
174        Err(syn::Error::new_spanned(ty, "Expected a path type for the generic argument."))
175    }
176}
177
178impl PgAggregate {
179    pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
180        let to_sql_config =
181            ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
182        let target_path = get_target_path(&item_impl)?;
183        let target_ident = get_target_ident(&target_path)?;
184
185        let mut pg_externs = Vec::default();
186        // We want to avoid having multiple borrows, so we take a snapshot to scan from,
187        // and mutate the actual one.
188        let item_impl_snapshot = item_impl.clone();
189
190        let generic_type = extract_generic_from_trait(&item_impl)?.clone();
191        let generic_type_name = get_generic_type_name(&generic_type)?;
192
193        let snake_case_target_ident =
194            format!("{target_ident}_{generic_type_name}").to_case(Case::Snake);
195        let snake_case_target_ident = Ident::new(&snake_case_target_ident, target_ident.span());
196        crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
197
198        let name = parse_quote! {
199            <#generic_type as ::pgrx::aggregate::ToAggregateName>::NAME
200        };
201
202        // `State` is an optional value, we default to `Self`.
203        let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
204        let _type_state_value = type_state.map(|v| v.ty.clone());
205
206        let type_state_without_self = if let Some(inner) = type_state {
207            let mut remapped = inner.ty.clone();
208            remap_self_to_target(&mut remapped, &target_ident);
209            remapped
210        } else {
211            item_impl.items.push(parse_quote! {
212                type State = Self;
213            });
214            let mut remapped = parse_quote!(Self);
215            remap_self_to_target(&mut remapped, &target_ident);
216            remapped
217        };
218        let type_stype = AggregateType {
219            used_ty: UsedType::new(type_state_without_self.clone())?,
220            name: Some("state".into()),
221        };
222
223        // `MovingState` is an optional value, we default to nothing.
224        let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
225        let type_moving_state;
226        let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
227            type_moving_state = impl_type_moving_state.ty.clone();
228            Some(UsedType::new(type_moving_state.clone())?)
229        } else {
230            item_impl.items.push(parse_quote! {
231                type MovingState = ();
232            });
233            type_moving_state = parse_quote! { () };
234            None
235        };
236
237        // `OrderBy` is an optional value, we default to nothing.
238        let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
239        let type_ordered_set_args_value =
240            type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
241        if type_ordered_set_args.is_none() {
242            item_impl.items.push(parse_quote! {
243                type OrderedSetArgs = ();
244            })
245        }
246        let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
247            type_ordered_set_args_value
248        {
249            let direct_args = order_by_direct_args
250                .found
251                .iter()
252                .map(|x| {
253                    (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
254                })
255                .collect::<Vec<_>>();
256            let direct_arg_names = ARG_NAMES[0..direct_args.len()]
257                .iter()
258                .zip(direct_args.iter())
259                .map(|(default_name, (custom_name, _ty, _orig))| {
260                    Ident::new(
261                        &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
262                        Span::mixed_site(),
263                    )
264                })
265                .collect::<Vec<_>>();
266            let direct_args_with_names = direct_args
267                .iter()
268                .zip(direct_arg_names.iter())
269                .map(|(arg, name)| {
270                    let arg_ty = &arg.2; // original_type
271                    parse_quote! {
272                        #name: #arg_ty
273                    }
274                })
275                .collect::<Vec<syn::FnArg>>();
276            (direct_args_with_names, direct_arg_names)
277        } else {
278            (Vec::default(), Vec::default())
279        };
280
281        // `Args` is an optional value, we default to nothing.
282        let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
283            syn::Error::new(
284                item_impl_snapshot.span(),
285                "`#[pg_aggregate]` requires the `Args` type defined.",
286            )
287        })?;
288        let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
289        let args = type_args_value
290            .found
291            .iter()
292            .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
293            .collect::<Vec<_>>();
294        let arg_names = ARG_NAMES[0..args.len()]
295            .iter()
296            .zip(args.iter())
297            .map(|(default_name, (custom_name, ty))| {
298                Ident::new(
299                    &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
300                    ty.span(),
301                )
302            })
303            .collect::<Vec<_>>();
304        let args_with_names = args
305            .iter()
306            .zip(arg_names.iter())
307            .map(|(arg, name)| {
308                let arg_ty = &arg.1;
309                quote! {
310                    #name: #arg_ty
311                }
312            })
313            .collect::<Vec<_>>();
314
315        // `Finalize` is an optional value, we default to nothing.
316        let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
317        let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
318            type_finalize.ty.clone()
319        } else {
320            item_impl.items.push(parse_quote! {
321                type Finalize = ();
322            });
323            parse_quote! { () }
324        };
325
326        let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
327
328        let fn_state_name = if let Some(found) = fn_state {
329            let fn_name =
330                Ident::new(&format!("{snake_case_target_ident}_state"), found.sig.ident.span());
331            let pg_extern_attr = pg_extern_attr(found);
332
333            pg_externs.push(parse_quote! {
334                #[allow(non_snake_case, clippy::too_many_arguments)]
335                #pg_extern_attr
336                fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
337                    unsafe {
338                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
339                            fcinfo,
340                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::state(this, (#(#arg_names),*), fcinfo)
341                        )
342                    }
343                }
344            });
345            fn_name
346        } else {
347            return Err(syn::Error::new(
348                item_impl.span(),
349                "Aggregate implementation must include state function.",
350            ));
351        };
352
353        let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
354        let fn_combine_name = if let Some(found) = fn_combine {
355            let fn_name =
356                Ident::new(&format!("{snake_case_target_ident}_combine"), found.sig.ident.span());
357            let pg_extern_attr = pg_extern_attr(found);
358            pg_externs.push(parse_quote! {
359                #[allow(non_snake_case, clippy::too_many_arguments)]
360                #pg_extern_attr
361                fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
362                    unsafe {
363                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
364                            fcinfo,
365                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::combine(this, v, fcinfo)
366                        )
367                    }
368                }
369            });
370            Some(fn_name)
371        } else {
372            item_impl.items.push(parse_quote! {
373                fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
374                    unimplemented!("Call to combine on an aggregate which does not support it.")
375                }
376            });
377            None
378        };
379
380        let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
381        let fn_finalize_name = if let Some(found) = fn_finalize {
382            let fn_name =
383                Ident::new(&format!("{snake_case_target_ident}_finalize"), found.sig.ident.span());
384            let pg_extern_attr = pg_extern_attr(found);
385
386            if !direct_args_with_names.is_empty() {
387                pg_externs.push(parse_quote! {
388                    #[allow(non_snake_case, clippy::too_many_arguments)]
389                    #pg_extern_attr
390                    fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
391                        unsafe {
392                            <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
393                                fcinfo,
394                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (#(#direct_arg_names),*), fcinfo)
395                            )
396                        }
397                    }
398                });
399            } else {
400                pg_externs.push(parse_quote! {
401                    #[allow(non_snake_case, clippy::too_many_arguments)]
402                    #pg_extern_attr
403                    fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
404                        unsafe {
405                            <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
406                                fcinfo,
407                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (), fcinfo)
408                            )
409                        }
410                    }
411                });
412            };
413            Some(fn_name)
414        } else {
415            item_impl.items.push(parse_quote! {
416                fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
417                    unimplemented!("Call to finalize on an aggregate which does not support it.")
418                }
419            });
420            None
421        };
422
423        let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
424        let fn_serial_name = if let Some(found) = fn_serial {
425            let fn_name =
426                Ident::new(&format!("{snake_case_target_ident}_serial"), found.sig.ident.span());
427            let pg_extern_attr = pg_extern_attr(found);
428            pg_externs.push(parse_quote! {
429                #[allow(non_snake_case, clippy::too_many_arguments)]
430                #pg_extern_attr
431                fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
432                    unsafe {
433                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
434                            fcinfo,
435                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::serial(this, fcinfo)
436                        )
437                    }
438                }
439            });
440            Some(fn_name)
441        } else {
442            item_impl.items.push(parse_quote! {
443                fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
444                    unimplemented!("Call to serial on an aggregate which does not support it.")
445                }
446            });
447            None
448        };
449
450        let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
451        let fn_deserial_name = if let Some(found) = fn_deserial {
452            let fn_name =
453                Ident::new(&format!("{snake_case_target_ident}_deserial"), found.sig.ident.span());
454            let pg_extern_attr = pg_extern_attr(found);
455            pg_externs.push(parse_quote! {
456                #[allow(non_snake_case, clippy::too_many_arguments)]
457                #pg_extern_attr
458                fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
459                    unsafe {
460                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
461                            fcinfo,
462                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::deserial(this, buf, internal, fcinfo)
463                        )
464                    }
465                }
466            });
467            Some(fn_name)
468        } else {
469            item_impl.items.push(parse_quote! {
470                fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
471                    unimplemented!("Call to deserial on an aggregate which does not support it.")
472                }
473            });
474            None
475        };
476
477        let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
478        let fn_moving_state_name = if let Some(found) = fn_moving_state {
479            let fn_name = Ident::new(
480                &format!("{snake_case_target_ident}_moving_state"),
481                found.sig.ident.span(),
482            );
483            let pg_extern_attr = pg_extern_attr(found);
484
485            pg_externs.push(parse_quote! {
486                #[allow(non_snake_case, clippy::too_many_arguments)]
487                #pg_extern_attr
488                fn #fn_name(
489                    mstate: #type_moving_state,
490                    #(#args_with_names),*,
491                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
492                ) -> #type_moving_state {
493                    unsafe {
494                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
495                            fcinfo,
496                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state(mstate, (#(#arg_names),*), fcinfo)
497                        )
498                    }
499                }
500            });
501            Some(fn_name)
502        } else {
503            item_impl.items.push(parse_quote! {
504                fn moving_state(
505                    _mstate: <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState,
506                    _v: Self::Args,
507                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
508                ) -> <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState {
509                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
510                }
511            });
512            None
513        };
514
515        let fn_moving_state_inverse =
516            get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
517        let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
518            let fn_name = Ident::new(
519                &format!("{snake_case_target_ident}_moving_state_inverse"),
520                found.sig.ident.span(),
521            );
522            let pg_extern_attr = pg_extern_attr(found);
523            pg_externs.push(parse_quote! {
524                #[allow(non_snake_case, clippy::too_many_arguments)]
525                #pg_extern_attr
526                fn #fn_name(
527                    mstate: #type_moving_state,
528                    #(#args_with_names),*,
529                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
530                ) -> #type_moving_state {
531                    unsafe {
532                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
533                            fcinfo,
534                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
535                        )
536                    }
537                }
538            });
539            Some(fn_name)
540        } else {
541            item_impl.items.push(parse_quote! {
542                fn moving_state_inverse(
543                    _mstate: #type_moving_state,
544                    _v: Self::Args,
545                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
546                ) -> #type_moving_state {
547                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
548                }
549            });
550            None
551        };
552
553        let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
554        let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
555            let fn_name = Ident::new(
556                &format!("{snake_case_target_ident}_moving_finalize"),
557                found.sig.ident.span(),
558            );
559            let pg_extern_attr = pg_extern_attr(found);
560            let maybe_comma: Option<syn::Token![,]> =
561                if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
562
563            pg_externs.push(parse_quote! {
564                #[allow(non_snake_case, clippy::too_many_arguments)]
565                #pg_extern_attr
566                fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
567                    unsafe {
568                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
569                            fcinfo,
570                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
571                        )
572                    }
573                }
574            });
575            Some(fn_name)
576        } else {
577            item_impl.items.push(parse_quote! {
578                fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
579                    unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
580                }
581            });
582            None
583        };
584
585        Ok(CodeEnrichment(Self {
586            item_impl,
587            target_ident,
588            pg_externs,
589            name,
590            snake_case_target_ident,
591            type_args: type_args_value,
592            type_ordered_set_args: type_ordered_set_args_value,
593            type_moving_state: type_moving_state_value,
594            type_stype,
595            const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
596                .map(|x| x.expr.clone()),
597            const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
598                .map(|x| x.expr.clone()),
599            const_moving_finalize_modify: get_impl_const_by_name(
600                &item_impl_snapshot,
601                "MOVING_FINALIZE_MODIFY",
602            )
603            .map(|x| x.expr.clone()),
604            const_initial_condition: get_impl_const_by_name(
605                &item_impl_snapshot,
606                "INITIAL_CONDITION",
607            )
608            .and_then(|e| get_const_litstr(e).transpose())
609            .transpose()?,
610            const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
611                .and_then(get_const_litbool)
612                .unwrap_or(false),
613            const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
614                .and_then(|e| get_const_litstr(e).transpose())
615                .transpose()?,
616            const_moving_intial_condition: get_impl_const_by_name(
617                &item_impl_snapshot,
618                "MOVING_INITIAL_CONDITION",
619            )
620            .and_then(|e| get_const_litstr(e).transpose())
621            .transpose()?,
622            fn_state: fn_state_name,
623            fn_finalize: fn_finalize_name,
624            fn_combine: fn_combine_name,
625            fn_serial: fn_serial_name,
626            fn_deserial: fn_deserial_name,
627            fn_moving_state: fn_moving_state_name,
628            fn_moving_state_inverse: fn_moving_state_inverse_name,
629            fn_moving_finalize: fn_moving_finalize_name,
630            hypothetical: if let Some(value) =
631                get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
632            {
633                match &value.expr {
634                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
635                        syn::Lit::Bool(lit) => lit.value,
636                        _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
637                    },
638                    _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
639                }
640            } else {
641                false
642            },
643            to_sql_config,
644        }))
645    }
646}
647
648impl ToEntityGraphTokens for PgAggregate {
649    fn to_entity_graph_tokens(&self) -> TokenStream2 {
650        let target_ident = &self.target_ident;
651        let sql_graph_entity_fn_name = syn::Ident::new(
652            &format!("__pgrx_internals_aggregate_{}", self.snake_case_target_ident),
653            target_ident.span(),
654        );
655
656        let name = &self.name;
657        let type_args_iter = &self.type_args.entity_tokens();
658        let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
659
660        let type_moving_state_entity_tokens =
661            self.type_moving_state.clone().map(|v| v.entity_tokens());
662        let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
663        let type_stype = self.type_stype.entity_tokens();
664        let const_ordered_set = self.const_ordered_set;
665        let const_parallel_iter = self.const_parallel.iter();
666        let const_finalize_modify_iter = self.const_finalize_modify.iter();
667        let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
668        let const_initial_condition_iter = self.const_initial_condition.iter();
669        let const_sort_operator_iter = self.const_sort_operator.iter();
670        let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
671        let hypothetical = self.hypothetical;
672        let fn_state = &self.fn_state;
673        let fn_finalize_iter = self.fn_finalize.iter();
674        let fn_combine_iter = self.fn_combine.iter();
675        let fn_serial_iter = self.fn_serial.iter();
676        let fn_deserial_iter = self.fn_deserial.iter();
677        let fn_moving_state_iter = self.fn_moving_state.iter();
678        let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
679        let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
680        let to_sql_config = &self.to_sql_config;
681
682        quote! {
683            #[no_mangle]
684            #[doc(hidden)]
685            #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
686            pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
687                let submission = ::pgrx::pgrx_sql_entity_graph::PgAggregateEntity {
688                    full_path: ::core::any::type_name::<#target_ident>(),
689                    module_path: module_path!(),
690                    file: file!(),
691                    line: line!(),
692                    name: #name,
693                    ordered_set: #const_ordered_set,
694                    ty_id: ::core::any::TypeId::of::<#target_ident>(),
695                    args: #type_args_iter,
696                    direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
697                    stype: #type_stype,
698                    sfunc: stringify!(#fn_state),
699                    combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
700                    finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
701                    finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
702                    initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
703                    serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
704                    deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
705                    msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
706                    minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
707                    mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
708                    mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
709                    mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
710                    minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
711                    sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
712                    parallel: None #( .unwrap_or(#const_parallel_iter) )*,
713                    hypothetical: #hypothetical,
714                    to_sql_config: #to_sql_config,
715                };
716                ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
717            }
718        }
719    }
720}
721
722impl ToRustCodeTokens for PgAggregate {
723    fn to_rust_code_tokens(&self) -> TokenStream2 {
724        let impl_item = &self.item_impl;
725        let pg_externs = self.pg_externs.iter();
726
727        quote! {
728            #impl_item
729            #(#pg_externs)*
730        }
731    }
732}
733
734impl Parse for CodeEnrichment<PgAggregate> {
735    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
736        PgAggregate::new(input.parse()?)
737    }
738}
739
740fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
741    let last = path.segments.last().ok_or_else(|| {
742        syn::Error::new(
743            path.span(),
744            "`#[pg_aggregate]` only works with types whose path have a final segment.",
745        )
746    })?;
747    Ok(last.ident.clone())
748}
749
750fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
751    let target_ident = match &*item_impl.self_ty {
752        syn::Type::Path(ref type_path) => {
753            let last_segment = type_path.path.segments.last().ok_or_else(|| {
754                syn::Error::new(
755                    type_path.span(),
756                    "`#[pg_aggregate]` only works with types whose path have a final segment.",
757                )
758            })?;
759            if last_segment.ident == "PgVarlena" {
760                match &last_segment.arguments {
761                    syn::PathArguments::AngleBracketed(angled) => {
762                        let first = angled.args.first().ok_or_else(|| syn::Error::new(
763                            type_path.span(),
764                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
765                        ))?;
766                        match &first {
767                            syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
768                            _ => return Err(syn::Error::new(
769                                type_path.span(),
770                                "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
771                            )),
772                        }
773                    },
774                    _ => return Err(syn::Error::new(
775                        type_path.span(),
776                        "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
777                    )),
778                }
779            } else {
780                type_path.path.clone()
781            }
782        }
783        something_else => {
784            return Err(syn::Error::new(
785                something_else.span(),
786                "`#[pg_aggregate]` only works with types.",
787            ))
788        }
789    };
790    Ok(target_ident)
791}
792
793fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
794    let mut found = None;
795    for attr in item.attrs.iter() {
796        match attr.path().segments.last() {
797            Some(segment) if segment.ident == "pgrx" => {
798                found = Some(attr);
799                break;
800            }
801            _ => (),
802        };
803    }
804
805    let attrs = if let Some(attr) = found {
806        let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
807        let attrs = attr.parse_args_with(parser);
808        attrs.ok()
809    } else {
810        None
811    };
812
813    match attrs {
814        Some(args) => parse_quote! {
815            #[::pgrx::pg_extern(#args)]
816        },
817        None => parse_quote! {
818            #[::pgrx::pg_extern]
819        },
820    }
821}
822
823fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
824    let mut needle = None;
825    for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
826        syn::ImplItem::Type(iitype) => Some(iitype),
827        _ => None,
828    }) {
829        let ident_string = impl_item_type.ident.to_string();
830        if ident_string == name {
831            needle = Some(impl_item_type);
832        }
833    }
834    needle
835}
836
837fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
838    let mut needle = None;
839    for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
840        syn::ImplItem::Fn(iifn) => Some(iifn),
841        _ => None,
842    }) {
843        let ident_string = impl_item_fn.sig.ident.to_string();
844        if ident_string == name {
845            needle = Some(impl_item_fn);
846        }
847    }
848    needle
849}
850
851fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
852    let mut needle = None;
853    for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
854        syn::ImplItem::Const(iiconst) => Some(iiconst),
855        _ => None,
856    }) {
857        let ident_string = impl_item_const.ident.to_string();
858        if ident_string == name {
859            needle = Some(impl_item_const);
860        }
861    }
862    needle
863}
864
865fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
866    match &item.expr {
867        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
868            syn::Lit::Bool(lit) => Some(lit.value()),
869            _ => None,
870        },
871        _ => None,
872    }
873}
874
875fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
876    match &item.expr {
877        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
878            syn::Lit::Str(lit) => Ok(Some(lit.value())),
879            _ => Ok(None),
880        },
881        syn::Expr::Call(expr_call) => match &*expr_call.func {
882            syn::Expr::Path(expr_path) => {
883                let Some(last) = expr_path.path.segments.last() else {
884                    return Ok(None);
885                };
886                if last.ident == "Some" {
887                    match expr_call.args.first() {
888                        Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
889                            syn::Lit::Str(lit) => Ok(Some(lit.value())),
890                            _ => Ok(None),
891                        },
892                        _ => Ok(None),
893                    }
894                } else {
895                    Ok(None)
896                }
897            }
898            _ => Ok(None),
899        },
900        ex => Err(syn::Error::new(ex.span(), "")),
901    }
902}
903
904fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
905    if let Type::Path(ref mut ty_path) = ty {
906        for segment in ty_path.path.segments.iter_mut() {
907            if segment.ident == "Self" {
908                segment.ident = target.clone()
909            }
910            use syn::{GenericArgument, PathArguments};
911            match segment.arguments {
912                PathArguments::AngleBracketed(ref mut angle_args) => {
913                    for arg in angle_args.args.iter_mut() {
914                        if let GenericArgument::Type(inner_ty) = arg {
915                            remap_self_to_target(inner_ty, target)
916                        }
917                    }
918                }
919                PathArguments::Parenthesized(_) => (),
920                PathArguments::None => (),
921            }
922        }
923    }
924}
925
926fn get_pgrx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
927    match &ty {
928        syn::Type::Macro(ty_macro) => {
929            let mut found_pgrx = false;
930            let mut found_attr = false;
931            // We don't actually have type resolution here, this is a "Best guess".
932            for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
933                match segment.ident.to_string().as_str() {
934                    "pgrx" if idx == 0 => found_pgrx = true,
935                    attr if attr == attr_name.as_ref() => found_attr = true,
936                    _ => (),
937                }
938            }
939            if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
940                Some(ty_macro.mac.tokens.clone())
941            } else {
942                None
943            }
944        }
945        _ => None,
946    }
947}
948
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{parse_quote, ItemImpl};

    /// An aggregate with only the required items expands to a single `state` extern.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        // It should not error, as it's valid. `expect` prints the `syn::Error` if it does.
        let agg = PgAggregate::new(tokens).expect("a valid aggregate impl should parse");
        // It should create 1 extern, the state.
        assert_eq!(agg.0.pg_externs.len(), 1);
        // That extern should be named specifically:
        let extern_fn = &agg.0.pg_externs[0];
        assert_eq!(extern_fn.sig.ident.to_string(), "demo_agg_demo_name_state");
        // It should be possible to generate entity tokens.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An aggregate that sets every optional type, const, and function expands
    /// to one extern per function.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        // It should not error, as it's valid. `expect` prints the `syn::Error` if it does.
        let agg = PgAggregate::new(tokens).expect("a valid aggregate impl should parse");
        // It should create 8 externs, one per function in the impl.
        assert_eq!(agg.0.pg_externs.len(), 8);
        // The first extern should be the state function, named specifically:
        let extern_fn = &agg.0.pg_externs[0];
        assert_eq!(extern_fn.sig.ident.to_string(), "demo_agg_demo_name_state");
        // It should be possible to generate entity tokens.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An `impl` missing the required associated types and functions is rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        // This is not valid as it is missing required types/consts.
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_err());
        Ok(())
    }
}