pgx_utils/sql_entity_graph/aggregate/
mod.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9/*!
10
11`#[pg_aggregate]` related macro expansion for Rust to SQL translation
12
13> Like all of the [`sql_entity_graph`][crate::sql_entity_graph] APIs, this is considered **internal**
14to the `pgx` framework and very subject to change between versions. While you may use this, please do it with caution.
15
16
17*/
18mod aggregate_type;
19pub(crate) mod entity;
20mod options;
21
22pub use aggregate_type::{AggregateType, AggregateTypeList};
23pub use options::{FinalizeModify, ParallelOption};
24
25use convert_case::{Case, Casing};
26use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
27use quote::{quote, ToTokens, TokenStreamExt};
28use syn::parse::{Parse, ParseStream};
29use syn::spanned::Spanned;
30use syn::{
31    parse_quote, Expr, ImplItemConst, ImplItemMethod, ImplItemType, ItemFn, ItemImpl, Path, Type,
32};
33
34use crate::sql_entity_graph::ToSqlConfig;
35
36use super::UsedType;
37
// Fallback parameter names for generated wrapper functions, indexed by argument position.
// An entry is used whenever the corresponding `Args`/`OrderedSetArgs` element does not carry
// a custom name. The table has 32 entries, so only argument lists of up to 32 elements can
// be named from it.
const ARG_NAMES: [&str; 32] = [
    "arg_one",
    "arg_two",
    "arg_three",
    "arg_four",
    "arg_five",
    "arg_six",
    "arg_seven",
    "arg_eight",
    "arg_nine",
    "arg_ten",
    "arg_eleven",
    "arg_twelve",
    "arg_thirteen",
    "arg_fourteen",
    "arg_fifteen",
    "arg_sixteen",
    "arg_seventeen",
    "arg_eighteen",
    "arg_nineteen",
    "arg_twenty",
    "arg_twenty_one",
    "arg_twenty_two",
    "arg_twenty_three",
    "arg_twenty_four",
    "arg_twenty_five",
    "arg_twenty_six",
    "arg_twenty_seven",
    "arg_twenty_eight",
    "arg_twenty_nine",
    "arg_thirty",
    "arg_thirty_one",
    "arg_thirty_two",
];
73
/** A parsed `#[pg_aggregate]` item.

Built by [`PgAggregate::new`] from an `impl` block. It holds the (possibly augmented) `impl`
block, the wrapper functions generated for it, and snapshots of the associated types,
constants and function names that are later emitted as entity metadata.
*/
#[derive(Debug, Clone)]
pub struct PgAggregate {
    /// The `impl` block, including any default items pushed onto it during construction.
    item_impl: ItemImpl,
    /// Expression yielding the aggregate name: the `NAME` literal if the `impl` declares one,
    /// otherwise a `stringify!(..)` of the target type's identifier.
    name: Expr,
    /// Wrapper functions generated for each aggregate function found in the `impl`.
    pg_externs: Vec<ItemFn>,
    // Note these should not be considered *writable*, they're snapshots from construction.
    /// Parsed from the `Args` associated type.
    type_args: AggregateTypeList,
    /// Parsed from the `OrderedSetArgs` associated type, if declared.
    type_ordered_set_args: Option<AggregateTypeList>,
    /// Parsed from the `MovingState` associated type, if declared.
    type_moving_state: Option<UsedType>,
    /// The `State` associated type (or the target type when `State` is omitted), named `state`.
    type_stype: AggregateType,
    /// Value of the `ORDERED_SET` constant; `false` when absent.
    const_ordered_set: bool,
    /// Expression of the `PARALLEL` constant, if declared.
    const_parallel: Option<syn::Expr>,
    /// Expression of the `FINALIZE_MODIFY` constant, if declared.
    const_finalize_modify: Option<syn::Expr>,
    /// Expression of the `MOVING_FINALIZE_MODIFY` constant, if declared.
    const_moving_finalize_modify: Option<syn::Expr>,
    /// String taken from the `INITIAL_CONDITION` constant, if available.
    const_initial_condition: Option<String>,
    /// String taken from the `SORT_OPERATOR` constant, if available.
    const_sort_operator: Option<String>,
    /// String taken from the `MOVING_INITIAL_CONDITION` constant, if available.
    const_moving_intial_condition: Option<String>,
    /// Name of the generated `<type>_state` wrapper function.
    fn_state: Ident,
    /// Name of the generated `<type>_finalize` wrapper, if `finalize` was implemented.
    fn_finalize: Option<Ident>,
    /// Name of the generated `<type>_combine` wrapper, if `combine` was implemented.
    fn_combine: Option<Ident>,
    /// Name of the generated `<type>_serial` wrapper, if `serial` was implemented.
    fn_serial: Option<Ident>,
    /// Name of the generated `<type>_deserial` wrapper, if `deserial` was implemented.
    fn_deserial: Option<Ident>,
    /// Name of the generated `<type>_moving_state` wrapper, if `moving_state` was implemented.
    fn_moving_state: Option<Ident>,
    /// Name of the generated `<type>_moving_state_inverse` wrapper, if
    /// `moving_state_inverse` was implemented.
    fn_moving_state_inverse: Option<Ident>,
    /// Name of the generated `<type>_moving_finalize` wrapper, if `moving_finalize` was
    /// implemented.
    fn_moving_finalize: Option<Ident>,
    /// Value of the `HYPOTHETICAL` constant; `false` when absent.
    hypothetical: bool,
    /// SQL generation settings parsed from the attributes on the `impl` block.
    to_sql_config: ToSqlConfig,
}
104
105impl PgAggregate {
106    pub fn new(mut item_impl: ItemImpl) -> Result<Self, syn::Error> {
107        let to_sql_config =
108            ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
109        let target_path = get_target_path(&item_impl)?;
110        let target_ident = get_target_ident(&target_path)?;
111
112        let snake_case_target_ident =
113            Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
114        crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
115
116        let mut pg_externs = Vec::default();
117        // We want to avoid having multiple borrows, so we take a snapshot to scan from,
118        // and mutate the actual one.
119        let item_impl_snapshot = item_impl.clone();
120
121        if let Some((_, ref path, _)) = item_impl.trait_ {
122            // TODO: Consider checking the path if there is more than one segment to make sure it's pgx.
123            if let Some(last) = path.segments.last() {
124                if last.ident.to_string() != "Aggregate" {
125                    return Err(syn::Error::new(
126                        last.ident.span(),
127                        "`#[pg_aggregate]` only works with the `Aggregate` trait.",
128                    ));
129                }
130            }
131        }
132
133        let name = match get_impl_const_by_name(&item_impl_snapshot, "NAME") {
134            Some(item_const) => match item_const.expr {
135                syn::Expr::Lit(ref expr) => {
136                    if let syn::Lit::Str(_) = expr.lit {
137                        item_const.expr.clone()
138                    } else {
139                        panic!("`NAME` must be a `&'static str` for Aggregate implementations.")
140                    }
141                }
142                _ => panic!("`NAME` must be a `&'static str` for Aggregate implementations."),
143            },
144            None => {
145                item_impl.items.push(parse_quote! {
146                    const NAME: &'static str = stringify!(Self);
147                });
148                parse_quote! {
149                    stringify!(#target_ident)
150                }
151            }
152        };
153
154        // `State` is an optional value, we default to `Self`.
155        let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
156        let _type_state_value = type_state.map(|v| v.ty.clone());
157
158        let type_state_without_self = if let Some(inner) = type_state {
159            let mut remapped = inner.ty.clone();
160            remap_self_to_target(&mut remapped, &target_ident);
161            remapped
162        } else {
163            item_impl.items.push(parse_quote! {
164                type State = Self;
165            });
166            let mut remapped = parse_quote!(Self);
167            remap_self_to_target(&mut remapped, &target_ident);
168            remapped
169        };
170        let type_stype = AggregateType {
171            used_ty: UsedType::new(type_state_without_self.clone())?,
172            name: Some("state".into()),
173        };
174
175        // `MovingState` is an optional value, we default to nothing.
176        let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
177        let type_moving_state;
178        let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
179            type_moving_state = impl_type_moving_state.ty.clone();
180            Some(UsedType::new(type_moving_state.clone())?)
181        } else {
182            item_impl.items.push(parse_quote! {
183                type MovingState = ();
184            });
185            type_moving_state = parse_quote! { () };
186            None
187        };
188
189        // `OrderBy` is an optional value, we default to nothing.
190        let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
191        let type_ordered_set_args_value =
192            type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
193        if type_ordered_set_args.is_none() {
194            item_impl.items.push(parse_quote! {
195                type OrderedSetArgs = ();
196            })
197        }
198        let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
199            type_ordered_set_args_value
200        {
201            let direct_args = order_by_direct_args
202                .found
203                .iter()
204                .map(|x| {
205                    (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
206                })
207                .collect::<Vec<_>>();
208            let direct_arg_names = ARG_NAMES[0..direct_args.len()]
209                .iter()
210                .zip(direct_args.iter())
211                .map(|(default_name, (custom_name, _ty, _orig))| {
212                    Ident::new(
213                        &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
214                        Span::mixed_site(),
215                    )
216                })
217                .collect::<Vec<_>>();
218            let direct_args_with_names = direct_args
219                .iter()
220                .zip(direct_arg_names.iter())
221                .map(|(arg, name)| {
222                    let arg_ty = &arg.2; // original_type
223                    parse_quote! {
224                        #name: #arg_ty
225                    }
226                })
227                .collect::<Vec<syn::FnArg>>();
228            (direct_args_with_names, direct_arg_names)
229        } else {
230            (Vec::default(), Vec::default())
231        };
232
233        // `Args` is an optional value, we default to nothing.
234        let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
235            syn::Error::new(
236                item_impl_snapshot.span(),
237                "`#[pg_aggregate]` requires the `Args` type defined.",
238            )
239        })?;
240        let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
241        let args = type_args_value
242            .found
243            .iter()
244            .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
245            .collect::<Vec<_>>();
246        let arg_names = ARG_NAMES[0..args.len()]
247            .iter()
248            .zip(args.iter())
249            .map(|(default_name, (custom_name, ty))| {
250                Ident::new(
251                    &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
252                    ty.span(),
253                )
254            })
255            .collect::<Vec<_>>();
256        let args_with_names = args
257            .iter()
258            .zip(arg_names.iter())
259            .map(|(arg, name)| {
260                let arg_ty = &arg.1;
261                quote! {
262                    #name: #arg_ty
263                }
264            })
265            .collect::<Vec<_>>();
266
267        // `Finalize` is an optional value, we default to nothing.
268        let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
269        let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
270            type_finalize.ty.clone()
271        } else {
272            item_impl.items.push(parse_quote! {
273                type Finalize = ();
274            });
275            parse_quote! { () }
276        };
277
278        let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
279
280        let fn_state_name = if let Some(found) = fn_state {
281            let fn_name =
282                Ident::new(&format!("{}_state", snake_case_target_ident), found.sig.ident.span());
283            let pg_extern_attr = pg_extern_attr(found);
284
285            pg_externs.push(parse_quote! {
286                #[allow(non_snake_case, clippy::too_many_arguments)]
287                #pg_extern_attr
288                fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
289                    <#target_path as pgx::Aggregate>::in_memory_context(
290                        fcinfo,
291                        move |_context| <#target_path as pgx::Aggregate>::state(this, (#(#arg_names),*), fcinfo)
292                    )
293                }
294            });
295            fn_name
296        } else {
297            return Err(syn::Error::new(
298                item_impl.span(),
299                "Aggregate implementation must include state function.",
300            ));
301        };
302
303        let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
304        let fn_combine_name = if let Some(found) = fn_combine {
305            let fn_name =
306                Ident::new(&format!("{}_combine", snake_case_target_ident), found.sig.ident.span());
307            let pg_extern_attr = pg_extern_attr(found);
308            pg_externs.push(parse_quote! {
309                #[allow(non_snake_case, clippy::too_many_arguments)]
310                #pg_extern_attr
311                fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
312                    <#target_path as pgx::Aggregate>::in_memory_context(
313                        fcinfo,
314                        move |_context| <#target_path as pgx::Aggregate>::combine(this, v, fcinfo)
315                    )
316                }
317            });
318            Some(fn_name)
319        } else {
320            item_impl.items.push(parse_quote! {
321                fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
322                    unimplemented!("Call to combine on an aggregate which does not support it.")
323                }
324            });
325            None
326        };
327
328        let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
329        let fn_finalize_name = if let Some(found) = fn_finalize {
330            let fn_name = Ident::new(
331                &format!("{}_finalize", snake_case_target_ident),
332                found.sig.ident.span(),
333            );
334            let pg_extern_attr = pg_extern_attr(found);
335
336            if direct_args_with_names.len() > 0 {
337                pg_externs.push(parse_quote! {
338                    #[allow(non_snake_case, clippy::too_many_arguments)]
339                    #pg_extern_attr
340                    fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
341                        <#target_path as pgx::Aggregate>::in_memory_context(
342                            fcinfo,
343                            move |_context| <#target_path as pgx::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo)
344                        )
345                    }
346                });
347            } else {
348                pg_externs.push(parse_quote! {
349                    #[allow(non_snake_case, clippy::too_many_arguments)]
350                    #pg_extern_attr
351                    fn #fn_name(this: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
352                        <#target_path as pgx::Aggregate>::in_memory_context(
353                            fcinfo,
354                            move |_context| <#target_path as pgx::Aggregate>::finalize(this, (), fcinfo)
355                        )
356                    }
357                });
358            };
359            Some(fn_name)
360        } else {
361            item_impl.items.push(parse_quote! {
362                fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
363                    unimplemented!("Call to finalize on an aggregate which does not support it.")
364                }
365            });
366            None
367        };
368
369        let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
370        let fn_serial_name = if let Some(found) = fn_serial {
371            let fn_name =
372                Ident::new(&format!("{}_serial", snake_case_target_ident), found.sig.ident.span());
373            let pg_extern_attr = pg_extern_attr(found);
374            pg_externs.push(parse_quote! {
375                #[allow(non_snake_case, clippy::too_many_arguments)]
376                #pg_extern_attr
377                fn #fn_name(this: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
378                    <#target_path as pgx::Aggregate>::in_memory_context(
379                        fcinfo,
380                        move |_context| <#target_path as pgx::Aggregate>::serial(this, fcinfo)
381                    )
382                }
383            });
384            Some(fn_name)
385        } else {
386            item_impl.items.push(parse_quote! {
387                fn serial(current: #type_state_without_self, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
388                    unimplemented!("Call to serial on an aggregate which does not support it.")
389                }
390            });
391            None
392        };
393
394        let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
395        let fn_deserial_name = if let Some(found) = fn_deserial {
396            let fn_name = Ident::new(
397                &format!("{}_deserial", snake_case_target_ident),
398                found.sig.ident.span(),
399            );
400            let pg_extern_attr = pg_extern_attr(found);
401            pg_externs.push(parse_quote! {
402                #[allow(non_snake_case, clippy::too_many_arguments)]
403                #pg_extern_attr
404                fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: pgx::PgBox<#type_state_without_self>, fcinfo: pgx::pg_sys::FunctionCallInfo) -> pgx::PgBox<#type_state_without_self> {
405                    <#target_path as pgx::Aggregate>::in_memory_context(
406                        fcinfo,
407                        move |_context| <#target_path as pgx::Aggregate>::deserial(this, buf, internal, fcinfo)
408                    )
409                }
410            });
411            Some(fn_name)
412        } else {
413            item_impl.items.push(parse_quote! {
414                fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: pgx::PgBox<#type_state_without_self>, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> pgx::PgBox<#type_state_without_self> {
415                    unimplemented!("Call to deserial on an aggregate which does not support it.")
416                }
417            });
418            None
419        };
420
421        let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
422        let fn_moving_state_name = if let Some(found) = fn_moving_state {
423            let fn_name = Ident::new(
424                &format!("{}_moving_state", snake_case_target_ident),
425                found.sig.ident.span(),
426            );
427            let pg_extern_attr = pg_extern_attr(found);
428
429            pg_externs.push(parse_quote! {
430                #[allow(non_snake_case, clippy::too_many_arguments)]
431                #pg_extern_attr
432                fn #fn_name(
433                    mstate: #type_moving_state,
434                    #(#args_with_names),*,
435                    fcinfo: pgx::pg_sys::FunctionCallInfo,
436                ) -> #type_moving_state {
437                    <#target_path as pgx::Aggregate>::in_memory_context(
438                        fcinfo,
439                        move |_context| <#target_path as pgx::Aggregate>::moving_state(mstate, (#(#arg_names),*), fcinfo)
440                    )
441                }
442            });
443            Some(fn_name)
444        } else {
445            item_impl.items.push(parse_quote! {
446                fn moving_state(
447                    _mstate: <#target_path as pgx::Aggregate>::MovingState,
448                    _v: Self::Args,
449                    _fcinfo: pgx::pg_sys::FunctionCallInfo,
450                ) -> <#target_path as pgx::Aggregate>::MovingState {
451                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
452                }
453            });
454            None
455        };
456
457        let fn_moving_state_inverse =
458            get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
459        let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
460            let fn_name = Ident::new(
461                &format!("{}_moving_state_inverse", snake_case_target_ident),
462                found.sig.ident.span(),
463            );
464            let pg_extern_attr = pg_extern_attr(found);
465            pg_externs.push(parse_quote! {
466                #[allow(non_snake_case, clippy::too_many_arguments)]
467                #pg_extern_attr
468                fn #fn_name(
469                    mstate: #type_moving_state,
470                    #(#args_with_names),*,
471                    fcinfo: pgx::pg_sys::FunctionCallInfo,
472                ) -> #type_moving_state {
473                    <#target_path as pgx::Aggregate>::in_memory_context(
474                        fcinfo,
475                        move |_context| <#target_path as pgx::Aggregate>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
476                    )
477                }
478            });
479            Some(fn_name)
480        } else {
481            item_impl.items.push(parse_quote! {
482                fn moving_state_inverse(
483                    _mstate: #type_moving_state,
484                    _v: Self::Args,
485                    _fcinfo: pgx::pg_sys::FunctionCallInfo,
486                ) -> #type_moving_state {
487                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
488                }
489            });
490            None
491        };
492
493        let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
494        let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
495            let fn_name = Ident::new(
496                &format!("{}_moving_finalize", snake_case_target_ident),
497                found.sig.ident.span(),
498            );
499            let pg_extern_attr = pg_extern_attr(found);
500            let maybe_comma: Option<syn::Token![,]> =
501                if direct_args_with_names.len() > 0 { Some(parse_quote! {,}) } else { None };
502
503            pg_externs.push(parse_quote! {
504                #[allow(non_snake_case, clippy::too_many_arguments)]
505                #pg_extern_attr
506                fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
507                    <#target_path as pgx::Aggregate>::in_memory_context(
508                        fcinfo,
509                        move |_context| <#target_path as pgx::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
510                    )
511                }
512            });
513            Some(fn_name)
514        } else {
515            item_impl.items.push(parse_quote! {
516                fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize {
517                    unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
518                }
519            });
520            None
521        };
522
523        Ok(Self {
524            item_impl,
525            pg_externs,
526            name,
527            type_args: type_args_value,
528            type_ordered_set_args: type_ordered_set_args_value,
529            type_moving_state: type_moving_state_value,
530            type_stype: type_stype,
531            const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
532                .map(|x| x.expr.clone()),
533            const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
534                .map(|x| x.expr.clone()),
535            const_moving_finalize_modify: get_impl_const_by_name(
536                &item_impl_snapshot,
537                "MOVING_FINALIZE_MODIFY",
538            )
539            .map(|x| x.expr.clone()),
540            const_initial_condition: get_impl_const_by_name(
541                &item_impl_snapshot,
542                "INITIAL_CONDITION",
543            )
544            .and_then(get_const_litstr),
545            const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
546                .and_then(get_const_litbool)
547                .unwrap_or(false),
548            const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
549                .and_then(get_const_litstr),
550            const_moving_intial_condition: get_impl_const_by_name(
551                &item_impl_snapshot,
552                "MOVING_INITIAL_CONDITION",
553            )
554            .and_then(get_const_litstr),
555            fn_state: fn_state_name,
556            fn_finalize: fn_finalize_name,
557            fn_combine: fn_combine_name,
558            fn_serial: fn_serial_name,
559            fn_deserial: fn_deserial_name,
560            fn_moving_state: fn_moving_state_name,
561            fn_moving_state_inverse: fn_moving_state_inverse_name,
562            fn_moving_finalize: fn_moving_finalize_name,
563            hypothetical: if let Some(value) =
564                get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
565            {
566                match &value.expr {
567                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
568                        syn::Lit::Bool(lit) => lit.value,
569                        _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
570                    },
571                    _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
572                }
573            } else {
574                false
575            },
576            to_sql_config,
577        })
578    }
579
580    fn entity_tokens(&self) -> ItemFn {
581        let target_path = get_target_path(&self.item_impl)
582            .expect("Expected constructed PgAggregate to have target path.");
583        let target_ident = get_target_ident(&target_path)
584            .expect("Expected constructed PgAggregate to have target ident.");
585        let snake_case_target_ident =
586            Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
587        let sql_graph_entity_fn_name = syn::Ident::new(
588            &format!("__pgx_internals_aggregate_{}", snake_case_target_ident),
589            target_ident.span(),
590        );
591
592        let name = &self.name;
593        let type_args_iter = &self.type_args.entity_tokens();
594        let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
595
596        let type_moving_state_entity_tokens =
597            self.type_moving_state.clone().map(|v| v.entity_tokens());
598        let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
599        let type_stype = self.type_stype.entity_tokens();
600        let const_ordered_set = self.const_ordered_set;
601        let const_parallel_iter = self.const_parallel.iter();
602        let const_finalize_modify_iter = self.const_finalize_modify.iter();
603        let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
604        let const_initial_condition_iter = self.const_initial_condition.iter();
605        let const_sort_operator_iter = self.const_sort_operator.iter();
606        let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
607        let hypothetical = self.hypothetical;
608        let fn_state = &self.fn_state;
609        let fn_finalize_iter = self.fn_finalize.iter();
610        let fn_combine_iter = self.fn_combine.iter();
611        let fn_serial_iter = self.fn_serial.iter();
612        let fn_deserial_iter = self.fn_deserial.iter();
613        let fn_moving_state_iter = self.fn_moving_state.iter();
614        let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
615        let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
616        let to_sql_config = &self.to_sql_config;
617
618        let entity_item_fn: ItemFn = parse_quote! {
619            #[no_mangle]
620            #[doc(hidden)]
621            pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::utils::sql_entity_graph::SqlGraphEntity {
622                let submission = ::pgx::utils::sql_entity_graph::PgAggregateEntity {
623                    full_path: ::core::any::type_name::<#target_ident>(),
624                    module_path: module_path!(),
625                    file: file!(),
626                    line: line!(),
627                    name: #name,
628                    ordered_set: #const_ordered_set,
629                    ty_id: ::core::any::TypeId::of::<#target_ident>(),
630                    args: #type_args_iter,
631                    direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
632                    stype: #type_stype,
633                    sfunc: stringify!(#fn_state),
634                    combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
635                    finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
636                    finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
637                    initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
638                    serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
639                    deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
640                    msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
641                    minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
642                    mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
643                    mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
644                    mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
645                    minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
646                    sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
647                    parallel: None #( .unwrap_or(#const_parallel_iter) )*,
648                    hypothetical: #hypothetical,
649                    to_sql_config: #to_sql_config,
650                };
651                ::pgx::utils::sql_entity_graph::SqlGraphEntity::Aggregate(submission)
652            }
653        };
654        entity_item_fn
655    }
656}
657
658impl Parse for PgAggregate {
659    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
660        Self::new(input.parse()?)
661    }
662}
663
664impl ToTokens for PgAggregate {
665    fn to_tokens(&self, tokens: &mut TokenStream2) {
666        let entity_fn = self.entity_tokens();
667        let impl_item = &self.item_impl;
668        let pg_externs = self.pg_externs.iter();
669        let inv = quote! {
670            #impl_item
671
672            #(#pg_externs)*
673
674            #entity_fn
675        };
676        tokens.append_all(inv);
677    }
678}
679
680fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
681    let last = path.segments.last().ok_or_else(|| {
682        syn::Error::new(
683            path.span(),
684            "`#[pg_aggregate]` only works with types whose path have a final segment.",
685        )
686    })?;
687    Ok(last.ident.clone())
688}
689
690fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
691    let target_ident = match &*item_impl.self_ty {
692        syn::Type::Path(ref type_path) => {
693            let last_segment = type_path.path.segments.last().ok_or_else(|| {
694                syn::Error::new(
695                    type_path.span(),
696                    "`#[pg_aggregate]` only works with types whose path have a final segment.",
697                )
698            })?;
699            if last_segment.ident.to_string() == "PgVarlena" {
700                match &last_segment.arguments {
701                    syn::PathArguments::AngleBracketed(angled) => {
702                        let first = angled.args.first().ok_or_else(|| syn::Error::new(
703                            type_path.span(),
704                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
705                        ))?;
706                        match &first {
707                            syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
708                            _ => return Err(syn::Error::new(
709                                type_path.span(),
710                                "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
711                            )),
712                        }
713                    },
714                    _ => return Err(syn::Error::new(
715                        type_path.span(),
716                        "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
717                    )),
718                }
719            } else {
720                type_path.path.clone()
721            }
722        }
723        something_else => {
724            return Err(syn::Error::new(
725                something_else.span(),
726                "`#[pg_aggregate]` only works with types.",
727            ))
728        }
729    };
730    Ok(target_ident)
731}
732
733fn pg_extern_attr(item: &ImplItemMethod) -> syn::Attribute {
734    let mut found = None;
735    for attr in item.attrs.iter() {
736        match attr.path.segments.last() {
737            Some(segment) if segment.ident.to_string() == "pgx" => {
738                found = Some(attr.tokens.clone());
739                break;
740            }
741            _ => (),
742        };
743    }
744    match found {
745        Some(args) => parse_quote! {
746            #[pg_extern #args]
747        },
748        None => parse_quote! {
749            #[pg_extern]
750        },
751    }
752}
753
754fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
755    let mut needle = None;
756    for impl_item in item_impl.items.iter() {
757        match impl_item {
758            syn::ImplItem::Type(impl_item_type) => {
759                let ident_string = impl_item_type.ident.to_string();
760                if ident_string == name {
761                    needle = Some(impl_item_type);
762                }
763            }
764            _ => (),
765        }
766    }
767    needle
768}
769
770fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemMethod> {
771    let mut needle = None;
772    for impl_item in item_impl.items.iter() {
773        match impl_item {
774            syn::ImplItem::Method(impl_item_method) => {
775                let ident_string = impl_item_method.sig.ident.to_string();
776                if ident_string == name {
777                    needle = Some(impl_item_method);
778                }
779            }
780            _ => (),
781        }
782    }
783    needle
784}
785
786fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
787    let mut needle = None;
788    for impl_item in item_impl.items.iter() {
789        match impl_item {
790            syn::ImplItem::Const(impl_item_const) => {
791                let ident_string = impl_item_const.ident.to_string();
792                if ident_string == name {
793                    needle = Some(impl_item_const);
794                }
795            }
796            _ => (),
797        }
798    }
799    needle
800}
801
802fn get_const_litbool<'a>(item: &'a ImplItemConst) -> Option<bool> {
803    match &item.expr {
804        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
805            syn::Lit::Bool(lit) => Some(lit.value()),
806            _ => None,
807        },
808        _ => None,
809    }
810}
811
812fn get_const_litstr<'a>(item: &'a ImplItemConst) -> Option<String> {
813    match &item.expr {
814        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
815            syn::Lit::Str(lit) => Some(lit.value()),
816            _ => None,
817        },
818        syn::Expr::Call(expr_call) => match &*expr_call.func {
819            syn::Expr::Path(expr_path) => {
820                if expr_path.path.segments.last()?.ident.to_string() == "Some" {
821                    match expr_call.args.first()? {
822                        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
823                            syn::Lit::Str(lit) => Some(lit.value()),
824                            _ => None,
825                        },
826                        _ => None,
827                    }
828                } else {
829                    None
830                }
831            }
832            _ => None,
833        },
834        _ => panic!("Got {:?}", item.expr),
835    }
836}
837
838fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
839    match ty {
840        Type::Path(ref mut ty_path) => {
841            for segment in ty_path.path.segments.iter_mut() {
842                if segment.ident.to_string() == "Self" {
843                    segment.ident = target.clone()
844                }
845                use syn::{GenericArgument, PathArguments};
846                match segment.arguments {
847                    PathArguments::AngleBracketed(ref mut angle_args) => {
848                        for arg in angle_args.args.iter_mut() {
849                            match arg {
850                                GenericArgument::Type(inner_ty) => {
851                                    remap_self_to_target(inner_ty, target)
852                                }
853                                _ => (),
854                            }
855                        }
856                    }
857                    PathArguments::Parenthesized(_) => (),
858                    PathArguments::None => (),
859                }
860            }
861        }
862        _ => (),
863    }
864}
865
866fn get_pgx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
867    match &ty {
868        syn::Type::Macro(ty_macro) => {
869            let mut found_pgx = false;
870            let mut found_attr = false;
871            // We don't actually have type resolution here, this is a "Best guess".
872            for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
873                match segment.ident.to_string().as_str() {
874                    "pgx" if idx == 0 => found_pgx = true,
875                    attr if attr == attr_name.as_ref() => found_attr = true,
876                    _ => (),
877                }
878            }
879            if (ty_macro.mac.path.segments.len() == 1 && found_attr) || (found_pgx && found_attr) {
880                Some(ty_macro.mac.tokens.clone())
881            } else {
882                None
883            }
884        }
885        _ => None,
886    }
887}
888
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use syn::{parse_quote, ItemImpl};

    /// An aggregate that supplies only `State`, `Args`, `NAME` and `state` is
    /// accepted and yields exactly one extern function.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                const NAME: &'static str = "DEMO";

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };

        // A valid definition must be accepted.
        let agg = PgAggregate::new(tokens).expect("a minimal aggregate should be accepted");

        // Only the state function is turned into an extern.
        assert_eq!(agg.pg_externs.len(), 1);

        // The extern's name is derived from the target type and the function name.
        let state_extern = &agg.pg_externs[0];
        assert_eq!(state_extern.sig.ident.to_string(), "demo_agg_state");

        // Generating the entity tokens must not panic.
        let _ = agg.entity_tokens();
        Ok(())
    }

    /// An aggregate that supplies every optional type, constant and function is
    /// accepted and yields one extern per function.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const NAME: &'static str = "DEMO";

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };

        // A valid definition must be accepted.
        let agg = PgAggregate::new(tokens).expect("a fully specified aggregate should be accepted");

        // All eight functions become externs.
        assert_eq!(agg.pg_externs.len(), 8);

        // The first extern is the state function, named after the target type.
        let state_extern = &agg.pg_externs[0];
        assert_eq!(state_extern.sig.ident.to_string(), "demo_agg_state");

        // Generating the entity tokens must not panic.
        let _ = agg.entity_tokens();
        Ok(())
    }

    /// An empty `impl` lacks the required types/consts and must be rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(PgAggregate::new(tokens).is_err());
        Ok(())
    }
}