pgx_sql_entity_graph/aggregate/
mod.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9/*!
10
11`#[pg_aggregate]` related macro expansion for Rust to SQL translation
12
13> Like all of the [`sql_entity_graph`][crate::pgx_sql_entity_graph] APIs, this is considered **internal**
14to the `pgx` framework and very subject to change between versions. While you may use this, please do it with caution.
15
16
17*/
18mod aggregate_type;
19pub(crate) mod entity;
20mod options;
21
22pub use aggregate_type::{AggregateType, AggregateTypeList};
23pub use options::{FinalizeModify, ParallelOption};
24
25use crate::enrich::CodeEnrichment;
26use crate::enrich::ToEntityGraphTokens;
27use crate::enrich::ToRustCodeTokens;
28use convert_case::{Case, Casing};
29use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
30use quote::quote;
31use syn::parse::{Parse, ParseStream};
32use syn::spanned::Spanned;
33use syn::{
34    parse_quote, Expr, ImplItemConst, ImplItemMethod, ImplItemType, ItemFn, ItemImpl, Path, Type,
35};
36
37use crate::ToSqlConfig;
38
39use super::UsedType;
40
// Positional fallback names for aggregate arguments that were not given a custom name.
// Entry `i` names the argument at position `i`; only 32 entries exist, so at most
// 32 arguments can be named this way.
const ARG_NAMES: [&str; 32] = [
    "arg_one",
    "arg_two",
    "arg_three",
    "arg_four",
    "arg_five",
    "arg_six",
    "arg_seven",
    "arg_eight",
    "arg_nine",
    "arg_ten",
    "arg_eleven",
    "arg_twelve",
    "arg_thirteen",
    "arg_fourteen",
    "arg_fifteen",
    "arg_sixteen",
    "arg_seventeen",
    "arg_eighteen",
    "arg_nineteen",
    "arg_twenty",
    "arg_twenty_one",
    "arg_twenty_two",
    "arg_twenty_three",
    "arg_twenty_four",
    "arg_twenty_five",
    "arg_twenty_six",
    "arg_twenty_seven",
    "arg_twenty_eight",
    "arg_twenty_nine",
    "arg_thirty",
    "arg_thirty_one",
    "arg_thirty_two",
];
76
/** A parsed `#[pg_aggregate]` item.

Built by `PgAggregate::new` from the annotated `impl` block. It carries both the
(possibly extended) `impl` block and the free functions generated to wrap each
trait method the user implemented.
*/
#[derive(Debug, Clone)]
pub struct PgAggregate {
    // The user's `impl` block; `new` pushes default items into it for anything omitted.
    item_impl: ItemImpl,
    // Expression yielding the aggregate name: the user's `NAME` string literal, or
    // `stringify!(<target ident>)` when `NAME` is absent.
    name: Expr,
    // Identifier of the final segment of the type the trait is implemented for.
    target_ident: Ident,
    // Generated wrapper functions, one per trait method found in the `impl` block.
    pg_externs: Vec<ItemFn>,
    // Note these should not be considered *writable*, they're snapshots from construction.
    // Parsed from the required `Args` associated type.
    type_args: AggregateTypeList,
    // Parsed from the `OrderedSetArgs` associated type, when present.
    type_ordered_set_args: Option<AggregateTypeList>,
    // Parsed from the `MovingState` associated type, when present.
    type_moving_state: Option<UsedType>,
    // The `State` type (defaulting to `Self`) with `Self` remapped to the target type.
    type_stype: AggregateType,
    // Value of the `ORDERED_SET` const; `false` when it is absent.
    const_ordered_set: bool,
    // Expression of the `PARALLEL` const, when present.
    const_parallel: Option<syn::Expr>,
    // Expression of the `FINALIZE_MODIFY` const, when present.
    const_finalize_modify: Option<syn::Expr>,
    // Expression of the `MOVING_FINALIZE_MODIFY` const, when present.
    const_moving_finalize_modify: Option<syn::Expr>,
    // String value of the `INITIAL_CONDITION` const, when present.
    const_initial_condition: Option<String>,
    // String value of the `SORT_OPERATOR` const, when present.
    const_sort_operator: Option<String>,
    // String value of the `MOVING_INITIAL_CONDITION` const, when present.
    // (The field name's "intial" spelling is historical.)
    const_moving_intial_condition: Option<String>,
    // Name of the generated `<snake_case_target>_state` wrapper (always present).
    fn_state: Ident,
    // Names of the generated wrappers for the optional trait methods; `None` when the
    // user did not implement the corresponding method.
    fn_finalize: Option<Ident>,
    fn_combine: Option<Ident>,
    fn_serial: Option<Ident>,
    fn_deserial: Option<Ident>,
    fn_moving_state: Option<Ident>,
    fn_moving_state_inverse: Option<Ident>,
    fn_moving_finalize: Option<Ident>,
    // Value of the `HYPOTHETICAL` const; `false` when it is absent.
    hypothetical: bool,
    // SQL generation settings read from the attributes on the `impl` block.
    to_sql_config: ToSqlConfig,
}
108
109impl PgAggregate {
110    pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
111        let to_sql_config =
112            ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
113        let target_path = get_target_path(&item_impl)?;
114        let target_ident = get_target_ident(&target_path)?;
115
116        let snake_case_target_ident =
117            Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
118        crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
119
120        let mut pg_externs = Vec::default();
121        // We want to avoid having multiple borrows, so we take a snapshot to scan from,
122        // and mutate the actual one.
123        let item_impl_snapshot = item_impl.clone();
124
125        if let Some((_, ref path, _)) = item_impl.trait_ {
126            // TODO: Consider checking the path if there is more than one segment to make sure it's pgx.
127            if let Some(last) = path.segments.last() {
128                if last.ident.to_string() != "Aggregate" {
129                    return Err(syn::Error::new(
130                        last.ident.span(),
131                        "`#[pg_aggregate]` only works with the `Aggregate` trait.",
132                    ));
133                }
134            }
135        }
136
137        let name = match get_impl_const_by_name(&item_impl_snapshot, "NAME") {
138            Some(item_const) => match &item_const.expr {
139                syn::Expr::Lit(ref expr) => {
140                    if let syn::Lit::Str(_) = &expr.lit {
141                        item_const.expr.clone()
142                    } else {
143                        return Err(syn::Error::new(
144                            expr.span(),
145                            "`NAME` must be a `&'static str` for Aggregate implementations.",
146                        ));
147                    }
148                }
149                e => {
150                    return Err(syn::Error::new(
151                        e.span(),
152                        "`NAME` must be a `&'static str` for Aggregate implementations.",
153                    ));
154                }
155            },
156            None => {
157                item_impl.items.push(parse_quote! {
158                    const NAME: &'static str = stringify!(Self);
159                });
160                parse_quote! {
161                    stringify!(#target_ident)
162                }
163            }
164        };
165
166        // `State` is an optional value, we default to `Self`.
167        let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
168        let _type_state_value = type_state.map(|v| v.ty.clone());
169
170        let type_state_without_self = if let Some(inner) = type_state {
171            let mut remapped = inner.ty.clone();
172            remap_self_to_target(&mut remapped, &target_ident);
173            remapped
174        } else {
175            item_impl.items.push(parse_quote! {
176                type State = Self;
177            });
178            let mut remapped = parse_quote!(Self);
179            remap_self_to_target(&mut remapped, &target_ident);
180            remapped
181        };
182        let type_stype = AggregateType {
183            used_ty: UsedType::new(type_state_without_self.clone())?,
184            name: Some("state".into()),
185        };
186
187        // `MovingState` is an optional value, we default to nothing.
188        let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
189        let type_moving_state;
190        let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
191            type_moving_state = impl_type_moving_state.ty.clone();
192            Some(UsedType::new(type_moving_state.clone())?)
193        } else {
194            item_impl.items.push(parse_quote! {
195                type MovingState = ();
196            });
197            type_moving_state = parse_quote! { () };
198            None
199        };
200
201        // `OrderBy` is an optional value, we default to nothing.
202        let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
203        let type_ordered_set_args_value =
204            type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
205        if type_ordered_set_args.is_none() {
206            item_impl.items.push(parse_quote! {
207                type OrderedSetArgs = ();
208            })
209        }
210        let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
211            type_ordered_set_args_value
212        {
213            let direct_args = order_by_direct_args
214                .found
215                .iter()
216                .map(|x| {
217                    (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
218                })
219                .collect::<Vec<_>>();
220            let direct_arg_names = ARG_NAMES[0..direct_args.len()]
221                .iter()
222                .zip(direct_args.iter())
223                .map(|(default_name, (custom_name, _ty, _orig))| {
224                    Ident::new(
225                        &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
226                        Span::mixed_site(),
227                    )
228                })
229                .collect::<Vec<_>>();
230            let direct_args_with_names = direct_args
231                .iter()
232                .zip(direct_arg_names.iter())
233                .map(|(arg, name)| {
234                    let arg_ty = &arg.2; // original_type
235                    parse_quote! {
236                        #name: #arg_ty
237                    }
238                })
239                .collect::<Vec<syn::FnArg>>();
240            (direct_args_with_names, direct_arg_names)
241        } else {
242            (Vec::default(), Vec::default())
243        };
244
245        // `Args` is an optional value, we default to nothing.
246        let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
247            syn::Error::new(
248                item_impl_snapshot.span(),
249                "`#[pg_aggregate]` requires the `Args` type defined.",
250            )
251        })?;
252        let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
253        let args = type_args_value
254            .found
255            .iter()
256            .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
257            .collect::<Vec<_>>();
258        let arg_names = ARG_NAMES[0..args.len()]
259            .iter()
260            .zip(args.iter())
261            .map(|(default_name, (custom_name, ty))| {
262                Ident::new(
263                    &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
264                    ty.span(),
265                )
266            })
267            .collect::<Vec<_>>();
268        let args_with_names = args
269            .iter()
270            .zip(arg_names.iter())
271            .map(|(arg, name)| {
272                let arg_ty = &arg.1;
273                quote! {
274                    #name: #arg_ty
275                }
276            })
277            .collect::<Vec<_>>();
278
279        // `Finalize` is an optional value, we default to nothing.
280        let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
281        let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
282            type_finalize.ty.clone()
283        } else {
284            item_impl.items.push(parse_quote! {
285                type Finalize = ();
286            });
287            parse_quote! { () }
288        };
289
290        let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
291
292        let fn_state_name = if let Some(found) = fn_state {
293            let fn_name =
294                Ident::new(&format!("{}_state", snake_case_target_ident), found.sig.ident.span());
295            let pg_extern_attr = pg_extern_attr(found);
296
297            pg_externs.push(parse_quote! {
298                #[allow(non_snake_case, clippy::too_many_arguments)]
299                #pg_extern_attr
300                fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
301                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
302                        fcinfo,
303                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::state(this, (#(#arg_names),*), fcinfo)
304                    )
305                }
306            });
307            fn_name
308        } else {
309            return Err(syn::Error::new(
310                item_impl.span(),
311                "Aggregate implementation must include state function.",
312            ));
313        };
314
315        let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
316        let fn_combine_name = if let Some(found) = fn_combine {
317            let fn_name =
318                Ident::new(&format!("{}_combine", snake_case_target_ident), found.sig.ident.span());
319            let pg_extern_attr = pg_extern_attr(found);
320            pg_externs.push(parse_quote! {
321                #[allow(non_snake_case, clippy::too_many_arguments)]
322                #pg_extern_attr
323                fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
324                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
325                        fcinfo,
326                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::combine(this, v, fcinfo)
327                    )
328                }
329            });
330            Some(fn_name)
331        } else {
332            item_impl.items.push(parse_quote! {
333                fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
334                    unimplemented!("Call to combine on an aggregate which does not support it.")
335                }
336            });
337            None
338        };
339
340        let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
341        let fn_finalize_name = if let Some(found) = fn_finalize {
342            let fn_name = Ident::new(
343                &format!("{}_finalize", snake_case_target_ident),
344                found.sig.ident.span(),
345            );
346            let pg_extern_attr = pg_extern_attr(found);
347
348            if direct_args_with_names.len() > 0 {
349                pg_externs.push(parse_quote! {
350                    #[allow(non_snake_case, clippy::too_many_arguments)]
351                    #pg_extern_attr
352                    fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
353                        <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
354                            fcinfo,
355                            move |_context| <#target_path as ::pgx::aggregate::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo)
356                        )
357                    }
358                });
359            } else {
360                pg_externs.push(parse_quote! {
361                    #[allow(non_snake_case, clippy::too_many_arguments)]
362                    #pg_extern_attr
363                    fn #fn_name(this: #type_state_without_self, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
364                        <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
365                            fcinfo,
366                            move |_context| <#target_path as ::pgx::aggregate::Aggregate>::finalize(this, (), fcinfo)
367                        )
368                    }
369                });
370            };
371            Some(fn_name)
372        } else {
373            item_impl.items.push(parse_quote! {
374                fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
375                    unimplemented!("Call to finalize on an aggregate which does not support it.")
376                }
377            });
378            None
379        };
380
381        let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
382        let fn_serial_name = if let Some(found) = fn_serial {
383            let fn_name =
384                Ident::new(&format!("{}_serial", snake_case_target_ident), found.sig.ident.span());
385            let pg_extern_attr = pg_extern_attr(found);
386            pg_externs.push(parse_quote! {
387                #[allow(non_snake_case, clippy::too_many_arguments)]
388                #pg_extern_attr
389                fn #fn_name(this: #type_state_without_self, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
390                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
391                        fcinfo,
392                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::serial(this, fcinfo)
393                    )
394                }
395            });
396            Some(fn_name)
397        } else {
398            item_impl.items.push(parse_quote! {
399                fn serial(current: #type_state_without_self, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
400                    unimplemented!("Call to serial on an aggregate which does not support it.")
401                }
402            });
403            None
404        };
405
406        let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
407        let fn_deserial_name = if let Some(found) = fn_deserial {
408            let fn_name = Ident::new(
409                &format!("{}_deserial", snake_case_target_ident),
410                found.sig.ident.span(),
411            );
412            let pg_extern_attr = pg_extern_attr(found);
413            pg_externs.push(parse_quote! {
414                #[allow(non_snake_case, clippy::too_many_arguments)]
415                #pg_extern_attr
416                fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pgbox::PgBox<#type_state_without_self> {
417                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
418                        fcinfo,
419                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::deserial(this, buf, internal, fcinfo)
420                    )
421                }
422            });
423            Some(fn_name)
424        } else {
425            item_impl.items.push(parse_quote! {
426                fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pgbox::PgBox<#type_state_without_self> {
427                    unimplemented!("Call to deserial on an aggregate which does not support it.")
428                }
429            });
430            None
431        };
432
433        let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
434        let fn_moving_state_name = if let Some(found) = fn_moving_state {
435            let fn_name = Ident::new(
436                &format!("{}_moving_state", snake_case_target_ident),
437                found.sig.ident.span(),
438            );
439            let pg_extern_attr = pg_extern_attr(found);
440
441            pg_externs.push(parse_quote! {
442                #[allow(non_snake_case, clippy::too_many_arguments)]
443                #pg_extern_attr
444                fn #fn_name(
445                    mstate: #type_moving_state,
446                    #(#args_with_names),*,
447                    fcinfo: ::pgx::pg_sys::FunctionCallInfo,
448                ) -> #type_moving_state {
449                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
450                        fcinfo,
451                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::moving_state(mstate, (#(#arg_names),*), fcinfo)
452                    )
453                }
454            });
455            Some(fn_name)
456        } else {
457            item_impl.items.push(parse_quote! {
458                fn moving_state(
459                    _mstate: <#target_path as ::pgx::aggregate::Aggregate>::MovingState,
460                    _v: Self::Args,
461                    _fcinfo: ::pgx::pg_sys::FunctionCallInfo,
462                ) -> <#target_path as ::pgx::aggregate::Aggregate>::MovingState {
463                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
464                }
465            });
466            None
467        };
468
469        let fn_moving_state_inverse =
470            get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
471        let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
472            let fn_name = Ident::new(
473                &format!("{}_moving_state_inverse", snake_case_target_ident),
474                found.sig.ident.span(),
475            );
476            let pg_extern_attr = pg_extern_attr(found);
477            pg_externs.push(parse_quote! {
478                #[allow(non_snake_case, clippy::too_many_arguments)]
479                #pg_extern_attr
480                fn #fn_name(
481                    mstate: #type_moving_state,
482                    #(#args_with_names),*,
483                    fcinfo: ::pgx::pg_sys::FunctionCallInfo,
484                ) -> #type_moving_state {
485                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
486                        fcinfo,
487                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
488                    )
489                }
490            });
491            Some(fn_name)
492        } else {
493            item_impl.items.push(parse_quote! {
494                fn moving_state_inverse(
495                    _mstate: #type_moving_state,
496                    _v: Self::Args,
497                    _fcinfo: ::pgx::pg_sys::FunctionCallInfo,
498                ) -> #type_moving_state {
499                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
500                }
501            });
502            None
503        };
504
505        let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
506        let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
507            let fn_name = Ident::new(
508                &format!("{}_moving_finalize", snake_case_target_ident),
509                found.sig.ident.span(),
510            );
511            let pg_extern_attr = pg_extern_attr(found);
512            let maybe_comma: Option<syn::Token![,]> =
513                if direct_args_with_names.len() > 0 { Some(parse_quote! {,}) } else { None };
514
515            pg_externs.push(parse_quote! {
516                #[allow(non_snake_case, clippy::too_many_arguments)]
517                #pg_extern_attr
518                fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
519                    <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
520                        fcinfo,
521                        move |_context| <#target_path as ::pgx::aggregate::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
522                    )
523                }
524            });
525            Some(fn_name)
526        } else {
527            item_impl.items.push(parse_quote! {
528                fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> Self::Finalize {
529                    unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
530                }
531            });
532            None
533        };
534
535        Ok(CodeEnrichment(Self {
536            item_impl,
537            target_ident,
538            pg_externs,
539            name,
540            type_args: type_args_value,
541            type_ordered_set_args: type_ordered_set_args_value,
542            type_moving_state: type_moving_state_value,
543            type_stype: type_stype,
544            const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
545                .map(|x| x.expr.clone()),
546            const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
547                .map(|x| x.expr.clone()),
548            const_moving_finalize_modify: get_impl_const_by_name(
549                &item_impl_snapshot,
550                "MOVING_FINALIZE_MODIFY",
551            )
552            .map(|x| x.expr.clone()),
553            const_initial_condition: get_impl_const_by_name(
554                &item_impl_snapshot,
555                "INITIAL_CONDITION",
556            )
557            .and_then(|e| get_const_litstr(e).transpose())
558            .transpose()?,
559            const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
560                .and_then(get_const_litbool)
561                .unwrap_or(false),
562            const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
563                .and_then(|e| get_const_litstr(e).transpose())
564                .transpose()?,
565            const_moving_intial_condition: get_impl_const_by_name(
566                &item_impl_snapshot,
567                "MOVING_INITIAL_CONDITION",
568            )
569            .and_then(|e| get_const_litstr(e).transpose())
570            .transpose()?,
571            fn_state: fn_state_name,
572            fn_finalize: fn_finalize_name,
573            fn_combine: fn_combine_name,
574            fn_serial: fn_serial_name,
575            fn_deserial: fn_deserial_name,
576            fn_moving_state: fn_moving_state_name,
577            fn_moving_state_inverse: fn_moving_state_inverse_name,
578            fn_moving_finalize: fn_moving_finalize_name,
579            hypothetical: if let Some(value) =
580                get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
581            {
582                match &value.expr {
583                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
584                        syn::Lit::Bool(lit) => lit.value,
585                        _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
586                    },
587                    _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
588                }
589            } else {
590                false
591            },
592            to_sql_config,
593        }))
594    }
595}
596
597impl ToEntityGraphTokens for PgAggregate {
598    fn to_entity_graph_tokens(&self) -> TokenStream2 {
599        let target_ident = &self.target_ident;
600        let snake_case_target_ident =
601            Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
602        let sql_graph_entity_fn_name = syn::Ident::new(
603            &format!("__pgx_internals_aggregate_{}", snake_case_target_ident),
604            target_ident.span(),
605        );
606
607        let name = &self.name;
608        let type_args_iter = &self.type_args.entity_tokens();
609        let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
610
611        let type_moving_state_entity_tokens =
612            self.type_moving_state.clone().map(|v| v.entity_tokens());
613        let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
614        let type_stype = self.type_stype.entity_tokens();
615        let const_ordered_set = self.const_ordered_set;
616        let const_parallel_iter = self.const_parallel.iter();
617        let const_finalize_modify_iter = self.const_finalize_modify.iter();
618        let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
619        let const_initial_condition_iter = self.const_initial_condition.iter();
620        let const_sort_operator_iter = self.const_sort_operator.iter();
621        let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
622        let hypothetical = self.hypothetical;
623        let fn_state = &self.fn_state;
624        let fn_finalize_iter = self.fn_finalize.iter();
625        let fn_combine_iter = self.fn_combine.iter();
626        let fn_serial_iter = self.fn_serial.iter();
627        let fn_deserial_iter = self.fn_deserial.iter();
628        let fn_moving_state_iter = self.fn_moving_state.iter();
629        let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
630        let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
631        let to_sql_config = &self.to_sql_config;
632
633        quote! {
634            #[no_mangle]
635            #[doc(hidden)]
636            pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::pgx_sql_entity_graph::SqlGraphEntity {
637                let submission = ::pgx::pgx_sql_entity_graph::PgAggregateEntity {
638                    full_path: ::core::any::type_name::<#target_ident>(),
639                    module_path: module_path!(),
640                    file: file!(),
641                    line: line!(),
642                    name: #name,
643                    ordered_set: #const_ordered_set,
644                    ty_id: ::core::any::TypeId::of::<#target_ident>(),
645                    args: #type_args_iter,
646                    direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
647                    stype: #type_stype,
648                    sfunc: stringify!(#fn_state),
649                    combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
650                    finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
651                    finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
652                    initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
653                    serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
654                    deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
655                    msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
656                    minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
657                    mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
658                    mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
659                    mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
660                    minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
661                    sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
662                    parallel: None #( .unwrap_or(#const_parallel_iter) )*,
663                    hypothetical: #hypothetical,
664                    to_sql_config: #to_sql_config,
665                };
666                ::pgx::pgx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
667            }
668        }
669    }
670}
671
672impl ToRustCodeTokens for PgAggregate {
673    fn to_rust_code_tokens(&self) -> TokenStream2 {
674        let impl_item = &self.item_impl;
675        let pg_externs = self.pg_externs.iter();
676        quote! {
677            #impl_item
678            #(#pg_externs)*
679        }
680    }
681}
682
683impl Parse for CodeEnrichment<PgAggregate> {
684    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
685        PgAggregate::new(input.parse()?)
686    }
687}
688
689fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
690    let last = path.segments.last().ok_or_else(|| {
691        syn::Error::new(
692            path.span(),
693            "`#[pg_aggregate]` only works with types whose path have a final segment.",
694        )
695    })?;
696    Ok(last.ident.clone())
697}
698
699fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
700    let target_ident = match &*item_impl.self_ty {
701        syn::Type::Path(ref type_path) => {
702            let last_segment = type_path.path.segments.last().ok_or_else(|| {
703                syn::Error::new(
704                    type_path.span(),
705                    "`#[pg_aggregate]` only works with types whose path have a final segment.",
706                )
707            })?;
708            if last_segment.ident.to_string() == "PgVarlena" {
709                match &last_segment.arguments {
710                    syn::PathArguments::AngleBracketed(angled) => {
711                        let first = angled.args.first().ok_or_else(|| syn::Error::new(
712                            type_path.span(),
713                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
714                        ))?;
715                        match &first {
716                            syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
717                            _ => return Err(syn::Error::new(
718                                type_path.span(),
719                                "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
720                            )),
721                        }
722                    },
723                    _ => return Err(syn::Error::new(
724                        type_path.span(),
725                        "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
726                    )),
727                }
728            } else {
729                type_path.path.clone()
730            }
731        }
732        something_else => {
733            return Err(syn::Error::new(
734                something_else.span(),
735                "`#[pg_aggregate]` only works with types.",
736            ))
737        }
738    };
739    Ok(target_ident)
740}
741
742fn pg_extern_attr(item: &ImplItemMethod) -> syn::Attribute {
743    let mut found = None;
744    for attr in item.attrs.iter() {
745        match attr.path.segments.last() {
746            Some(segment) if segment.ident.to_string() == "pgx" => {
747                found = Some(attr.tokens.clone());
748                break;
749            }
750            _ => (),
751        };
752    }
753    match found {
754        Some(args) => parse_quote! {
755            #[::pgx::pg_extern #args]
756        },
757        None => parse_quote! {
758            #[::pgx::pg_extern]
759        },
760    }
761}
762
763fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
764    let mut needle = None;
765    for impl_item in item_impl.items.iter() {
766        match impl_item {
767            syn::ImplItem::Type(impl_item_type) => {
768                let ident_string = impl_item_type.ident.to_string();
769                if ident_string == name {
770                    needle = Some(impl_item_type);
771                }
772            }
773            _ => (),
774        }
775    }
776    needle
777}
778
779fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemMethod> {
780    let mut needle = None;
781    for impl_item in item_impl.items.iter() {
782        match impl_item {
783            syn::ImplItem::Method(impl_item_method) => {
784                let ident_string = impl_item_method.sig.ident.to_string();
785                if ident_string == name {
786                    needle = Some(impl_item_method);
787                }
788            }
789            _ => (),
790        }
791    }
792    needle
793}
794
795fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
796    let mut needle = None;
797    for impl_item in item_impl.items.iter() {
798        match impl_item {
799            syn::ImplItem::Const(impl_item_const) => {
800                let ident_string = impl_item_const.ident.to_string();
801                if ident_string == name {
802                    needle = Some(impl_item_const);
803                }
804            }
805            _ => (),
806        }
807    }
808    needle
809}
810
811fn get_const_litbool<'a>(item: &'a ImplItemConst) -> Option<bool> {
812    match &item.expr {
813        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
814            syn::Lit::Bool(lit) => Some(lit.value()),
815            _ => None,
816        },
817        _ => None,
818    }
819}
820
821fn get_const_litstr<'a>(item: &'a ImplItemConst) -> syn::Result<Option<String>> {
822    match &item.expr {
823        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
824            syn::Lit::Str(lit) => Ok(Some(lit.value())),
825            _ => Ok(None),
826        },
827        syn::Expr::Call(expr_call) => match &*expr_call.func {
828            syn::Expr::Path(expr_path) => {
829                let Some(last) = expr_path.path.segments.last() else {
830                    return Ok(None);
831                };
832                if last.ident.to_string() == "Some" {
833                    match expr_call.args.first() {
834                        Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
835                            syn::Lit::Str(lit) => Ok(Some(lit.value())),
836                            _ => Ok(None),
837                        },
838                        _ => Ok(None),
839                    }
840                } else {
841                    Ok(None)
842                }
843            }
844            _ => Ok(None),
845        },
846        ex => Err(syn::Error::new(ex.span(), "")),
847    }
848}
849
850fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
851    match ty {
852        Type::Path(ref mut ty_path) => {
853            for segment in ty_path.path.segments.iter_mut() {
854                if segment.ident.to_string() == "Self" {
855                    segment.ident = target.clone()
856                }
857                use syn::{GenericArgument, PathArguments};
858                match segment.arguments {
859                    PathArguments::AngleBracketed(ref mut angle_args) => {
860                        for arg in angle_args.args.iter_mut() {
861                            match arg {
862                                GenericArgument::Type(inner_ty) => {
863                                    remap_self_to_target(inner_ty, target)
864                                }
865                                _ => (),
866                            }
867                        }
868                    }
869                    PathArguments::Parenthesized(_) => (),
870                    PathArguments::None => (),
871                }
872            }
873        }
874        _ => (),
875    }
876}
877
878fn get_pgx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
879    match &ty {
880        syn::Type::Macro(ty_macro) => {
881            let mut found_pgx = false;
882            let mut found_attr = false;
883            // We don't actually have type resolution here, this is a "Best guess".
884            for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
885                match segment.ident.to_string().as_str() {
886                    "pgx" if idx == 0 => found_pgx = true,
887                    attr if attr == attr_name.as_ref() => found_attr = true,
888                    _ => (),
889                }
890            }
891            if (ty_macro.mac.path.segments.len() == 1 && found_attr) || (found_pgx && found_attr) {
892                Some(ty_macro.mac.tokens.clone())
893            } else {
894                None
895            }
896        }
897        _ => None,
898    }
899}
900
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{parse_quote, ItemImpl};

    /// An `impl` providing only `State`, `Args`, `NAME` and `state` is
    /// accepted and yields a single extern.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                const NAME: &'static str = "DEMO";

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        // Valid input must be accepted.
        let agg = PgAggregate::new(tokens).expect("a valid aggregate impl should be accepted");
        // Exactly one extern is expected: the state function.
        assert_eq!(agg.0.pg_externs.len(), 1);
        // Its name is derived from the target type and the method name.
        let state_fn = &agg.0.pg_externs[0];
        assert_eq!(state_fn.sig.ident.to_string(), "demo_agg_state");
        // Generating entity tokens must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An `impl` exercising every optional item is accepted and yields one
    /// extern per provided function.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const NAME: &'static str = "DEMO";

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        // Valid input must be accepted.
        let agg = PgAggregate::new(tokens).expect("a valid aggregate impl should be accepted");
        // All eight functions above become externs.
        assert_eq!(agg.0.pg_externs.len(), 8);
        // The first extern is the state function, named after the target type.
        let state_fn = &agg.0.pg_externs[0];
        assert_eq!(state_fn.sig.ident.to_string(), "demo_agg_state");
        // Generating entity tokens must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An empty `impl` lacks the required types/consts and must be rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(PgAggregate::new(tokens).is_err());
        Ok(())
    }
}