Skip to main content

pgrx_sql_entity_graph/aggregate/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_aggregate]` related macro expansion for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25use syn::PathArguments;
26
27use crate::enrich::CodeEnrichment;
28use crate::enrich::ToEntityGraphTokens;
29use crate::enrich::ToRustCodeTokens;
30use convert_case::{Case, Casing};
31use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
32use quote::quote;
33use syn::parse::{Parse, ParseStream};
34use syn::punctuated::Punctuated;
35use syn::spanned::Spanned;
36use syn::{
37    Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type, parse_quote,
38};
39
40use crate::ToSqlConfig;
41
42use super::UsedType;
43
44// We support only 32 tuples...
45const ARG_NAMES: [&str; 32] = [
46    "arg_one",
47    "arg_two",
48    "arg_three",
49    "arg_four",
50    "arg_five",
51    "arg_six",
52    "arg_seven",
53    "arg_eight",
54    "arg_nine",
55    "arg_ten",
56    "arg_eleven",
57    "arg_twelve",
58    "arg_thirteen",
59    "arg_fourteen",
60    "arg_fifteen",
61    "arg_sixteen",
62    "arg_seventeen",
63    "arg_eighteen",
64    "arg_nineteen",
65    "arg_twenty",
66    "arg_twenty_one",
67    "arg_twenty_two",
68    "arg_twenty_three",
69    "arg_twenty_four",
70    "arg_twenty_five",
71    "arg_twenty_six",
72    "arg_twenty_seven",
73    "arg_twenty_eight",
74    "arg_twenty_nine",
75    "arg_thirty",
76    "arg_thirty_one",
77    "arg_thirty_two",
78];
79
80/** A parsed `#[pg_aggregate]` item.
81*/
82#[derive(Debug, Clone)]
83pub struct PgAggregate {
84    item_impl: ItemImpl,
85    name: Expr,
86    target_ident: Ident,
87    snake_case_target_ident: Ident,
88    pg_externs: Vec<ItemFn>,
89    // Note these should not be considered *writable*, they're snapshots from construction.
90    type_args: AggregateTypeList,
91    type_ordered_set_args: Option<AggregateTypeList>,
92    type_moving_state: Option<UsedType>,
93    type_stype: AggregateType,
94    const_ordered_set: bool,
95    const_parallel: Option<syn::Expr>,
96    const_finalize_modify: Option<syn::Expr>,
97    const_moving_finalize_modify: Option<syn::Expr>,
98    const_initial_condition: Option<String>,
99    const_sort_operator: Option<String>,
100    const_moving_intial_condition: Option<String>,
101    fn_state: Ident,
102    fn_finalize: Option<Ident>,
103    fn_combine: Option<Ident>,
104    fn_serial: Option<Ident>,
105    fn_deserial: Option<Ident>,
106    fn_moving_state: Option<Ident>,
107    fn_moving_state_inverse: Option<Ident>,
108    fn_moving_finalize: Option<Ident>,
109    hypothetical: bool,
110    to_sql_config: ToSqlConfig,
111}
112
113fn extract_generic_from_trait(item_impl: &ItemImpl) -> Result<&Type, syn::Error> {
114    let (_, path, _) = item_impl.trait_.as_ref().ok_or_else(|| {
115        syn::Error::new_spanned(
116            item_impl,
117            "`#[pg_aggregate]` can only be used on `impl` blocks for a trait.",
118        )
119    })?;
120
121    let last_segment = path
122        .segments
123        .last()
124        .ok_or_else(|| syn::Error::new_spanned(path, "Trait path is empty or malformed."))?;
125
126    if last_segment.ident != "Aggregate" {
127        return Err(syn::Error::new_spanned(
128            last_segment.ident.clone(),
129            "`#[pg_aggregate]` only works with the `Aggregate` trait.",
130        ));
131    }
132
133    let args = match &last_segment.arguments {
134        PathArguments::AngleBracketed(args) => args,
135        _ => {
136            return Err(syn::Error::new_spanned(
137                last_segment.ident.clone(),
138                "`Aggregate` trait must have angle-bracketed generic arguments (e.g., `Aggregate<T>`). Missing generic argument.",
139            ));
140        }
141    };
142
143    let generic_arg = args.args.first().ok_or_else(|| {
144        syn::Error::new_spanned(
145            args,
146            "`Aggregate` trait requires at least one generic argument (e.g., `Aggregate<T>`).",
147        )
148    })?;
149
150    if let syn::GenericArgument::Type(ty) = generic_arg {
151        Ok(ty)
152    } else {
153        Err(syn::Error::new_spanned(
154            generic_arg,
155            "Expected a type as the generic argument for `Aggregate` (e.g., `Aggregate<MyType>`).",
156        ))
157    }
158}
159
160fn get_generic_type_name(ty: &syn::Type) -> Result<String, syn::Error> {
161    if let Type::Path(type_path) = ty
162        && let Some(ident) = type_path.path.segments.last().map(|s| &s.ident)
163    {
164        let ident = ident.to_string();
165
166        match ident.as_str() {
167            "!" => Ok("never".to_string()),
168            "()" => Ok("unit".to_string()),
169            _ => Ok(ident),
170        }
171    } else {
172        Err(syn::Error::new_spanned(ty, "Generic type path is empty or malformed."))
173    }
174}
175
176impl PgAggregate {
177    pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
178        let to_sql_config =
179            ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
180        let target_path = get_target_path(&item_impl)?;
181        let target_ident = get_target_ident(&target_path)?;
182
183        let mut pg_externs = Vec::default();
184        // We want to avoid having multiple borrows, so we take a snapshot to scan from,
185        // and mutate the actual one.
186        let item_impl_snapshot = item_impl.clone();
187
188        let generic_type = extract_generic_from_trait(&item_impl)?.clone();
189        let generic_type_name = get_generic_type_name(&generic_type)?;
190
191        let snake_case_target_ident =
192            format!("{target_ident}_{generic_type_name}").to_case(Case::Snake);
193        let snake_case_target_ident = Ident::new(&snake_case_target_ident, target_ident.span());
194        crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
195
196        let name = parse_quote! {
197            <#generic_type as ::pgrx::aggregate::ToAggregateName>::NAME
198        };
199
200        // `State` is an optional value, we default to `Self`.
201        let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
202        let _type_state_value = type_state.map(|v| v.ty.clone());
203
204        let type_state_without_self = if let Some(inner) = type_state {
205            let mut remapped = inner.ty.clone();
206            remap_self_to_target(&mut remapped, &target_ident);
207            remapped
208        } else {
209            item_impl.items.push(parse_quote! {
210                type State = Self;
211            });
212            let mut remapped = parse_quote!(Self);
213            remap_self_to_target(&mut remapped, &target_ident);
214            remapped
215        };
216        let type_stype = AggregateType {
217            used_ty: UsedType::new(type_state_without_self.clone())?,
218            name: Some("state".into()),
219        };
220
221        // `MovingState` is an optional value, we default to nothing.
222        let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
223        let type_moving_state;
224        let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
225            type_moving_state = impl_type_moving_state.ty.clone();
226            Some(UsedType::new(type_moving_state.clone())?)
227        } else {
228            item_impl.items.push(parse_quote! {
229                type MovingState = ();
230            });
231            type_moving_state = parse_quote! { () };
232            None
233        };
234
235        // `OrderBy` is an optional value, we default to nothing.
236        let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
237        let type_ordered_set_args_value =
238            type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
239        if type_ordered_set_args.is_none() {
240            item_impl.items.push(parse_quote! {
241                type OrderedSetArgs = ();
242            })
243        }
244        let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
245            type_ordered_set_args_value
246        {
247            let direct_args = order_by_direct_args
248                .found
249                .iter()
250                .map(|x| {
251                    (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
252                })
253                .collect::<Vec<_>>();
254            let direct_arg_names = ARG_NAMES[0..direct_args.len()]
255                .iter()
256                .zip(direct_args.iter())
257                .map(|(default_name, (custom_name, _ty, _orig))| {
258                    Ident::new(
259                        &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
260                        Span::mixed_site(),
261                    )
262                })
263                .collect::<Vec<_>>();
264            let direct_args_with_names = direct_args
265                .iter()
266                .zip(direct_arg_names.iter())
267                .map(|(arg, name)| {
268                    let arg_ty = &arg.2; // original_type
269                    parse_quote! {
270                        #name: #arg_ty
271                    }
272                })
273                .collect::<Vec<syn::FnArg>>();
274            (direct_args_with_names, direct_arg_names)
275        } else {
276            (Vec::default(), Vec::default())
277        };
278
279        // `Args` is an optional value, we default to nothing.
280        let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
281            syn::Error::new(
282                item_impl_snapshot.span(),
283                "`#[pg_aggregate]` requires the `Args` type defined.",
284            )
285        })?;
286        let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
287        let args = type_args_value
288            .found
289            .iter()
290            .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
291            .collect::<Vec<_>>();
292        let arg_names = ARG_NAMES[0..args.len()]
293            .iter()
294            .zip(args.iter())
295            .map(|(default_name, (custom_name, ty))| {
296                Ident::new(
297                    &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
298                    ty.span(),
299                )
300            })
301            .collect::<Vec<_>>();
302        let args_with_names = args
303            .iter()
304            .zip(arg_names.iter())
305            .map(|(arg, name)| {
306                let arg_ty = &arg.1;
307                quote! {
308                    #name: #arg_ty
309                }
310            })
311            .collect::<Vec<_>>();
312
313        // `Finalize` is an optional value, we default to nothing.
314        let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
315        let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
316            type_finalize.ty.clone()
317        } else {
318            item_impl.items.push(parse_quote! {
319                type Finalize = ();
320            });
321            parse_quote! { () }
322        };
323
324        let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
325
326        let fn_state_name = if let Some(found) = fn_state {
327            let fn_name =
328                Ident::new(&format!("{snake_case_target_ident}_state"), found.sig.ident.span());
329            let pg_extern_attr = pg_extern_attr(found);
330
331            pg_externs.push(parse_quote! {
332                #[allow(non_snake_case, clippy::too_many_arguments)]
333                #pg_extern_attr
334                fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
335                    unsafe {
336                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
337                            fcinfo,
338                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::state(this, (#(#arg_names),*), fcinfo)
339                        )
340                    }
341                }
342            });
343            fn_name
344        } else {
345            return Err(syn::Error::new(
346                item_impl.span(),
347                "Aggregate implementation must include state function.",
348            ));
349        };
350
351        let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
352        let fn_combine_name = if let Some(found) = fn_combine {
353            let fn_name =
354                Ident::new(&format!("{snake_case_target_ident}_combine"), found.sig.ident.span());
355            let pg_extern_attr = pg_extern_attr(found);
356            pg_externs.push(parse_quote! {
357                #[allow(non_snake_case, clippy::too_many_arguments)]
358                #pg_extern_attr
359                fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
360                    unsafe {
361                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
362                            fcinfo,
363                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::combine(this, v, fcinfo)
364                        )
365                    }
366                }
367            });
368            Some(fn_name)
369        } else {
370            item_impl.items.push(parse_quote! {
371                fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
372                    unimplemented!("Call to combine on an aggregate which does not support it.")
373                }
374            });
375            None
376        };
377
378        let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
379        let fn_finalize_name = if let Some(found) = fn_finalize {
380            let fn_name =
381                Ident::new(&format!("{snake_case_target_ident}_finalize"), found.sig.ident.span());
382            let pg_extern_attr = pg_extern_attr(found);
383
384            if !direct_args_with_names.is_empty() {
385                pg_externs.push(parse_quote! {
386                    #[allow(non_snake_case, clippy::too_many_arguments)]
387                    #pg_extern_attr
388                    fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
389                        unsafe {
390                            <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
391                                fcinfo,
392                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (#(#direct_arg_names),*), fcinfo)
393                            )
394                        }
395                    }
396                });
397            } else {
398                pg_externs.push(parse_quote! {
399                    #[allow(non_snake_case, clippy::too_many_arguments)]
400                    #pg_extern_attr
401                    fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
402                        unsafe {
403                            <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
404                                fcinfo,
405                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::finalize(this, (), fcinfo)
406                            )
407                        }
408                    }
409                });
410            };
411            Some(fn_name)
412        } else {
413            item_impl.items.push(parse_quote! {
414                fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
415                    unimplemented!("Call to finalize on an aggregate which does not support it.")
416                }
417            });
418            None
419        };
420
421        let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
422        let fn_serial_name = if let Some(found) = fn_serial {
423            let fn_name =
424                Ident::new(&format!("{snake_case_target_ident}_serial"), found.sig.ident.span());
425            let pg_extern_attr = pg_extern_attr(found);
426            pg_externs.push(parse_quote! {
427                #[allow(non_snake_case, clippy::too_many_arguments)]
428                #pg_extern_attr
429                fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
430                    unsafe {
431                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
432                            fcinfo,
433                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::serial(this, fcinfo)
434                        )
435                    }
436                }
437            });
438            Some(fn_name)
439        } else {
440            item_impl.items.push(parse_quote! {
441                fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
442                    unimplemented!("Call to serial on an aggregate which does not support it.")
443                }
444            });
445            None
446        };
447
448        let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
449        let fn_deserial_name = if let Some(found) = fn_deserial {
450            let fn_name =
451                Ident::new(&format!("{snake_case_target_ident}_deserial"), found.sig.ident.span());
452            let pg_extern_attr = pg_extern_attr(found);
453            pg_externs.push(parse_quote! {
454                #[allow(non_snake_case, clippy::too_many_arguments)]
455                #pg_extern_attr
456                fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
457                    unsafe {
458                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
459                            fcinfo,
460                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::deserial(this, buf, internal, fcinfo)
461                        )
462                    }
463                }
464            });
465            Some(fn_name)
466        } else {
467            item_impl.items.push(parse_quote! {
468                fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
469                    unimplemented!("Call to deserial on an aggregate which does not support it.")
470                }
471            });
472            None
473        };
474
475        let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
476        let fn_moving_state_name = if let Some(found) = fn_moving_state {
477            let fn_name = Ident::new(
478                &format!("{snake_case_target_ident}_moving_state"),
479                found.sig.ident.span(),
480            );
481            let pg_extern_attr = pg_extern_attr(found);
482
483            pg_externs.push(parse_quote! {
484                #[allow(non_snake_case, clippy::too_many_arguments)]
485                #pg_extern_attr
486                fn #fn_name(
487                    mstate: #type_moving_state,
488                    #(#args_with_names),*,
489                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
490                ) -> #type_moving_state {
491                    unsafe {
492                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
493                            fcinfo,
494                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state(mstate, (#(#arg_names),*), fcinfo)
495                        )
496                    }
497                }
498            });
499            Some(fn_name)
500        } else {
501            item_impl.items.push(parse_quote! {
502                fn moving_state(
503                    _mstate: <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState,
504                    _v: Self::Args,
505                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
506                ) -> <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::MovingState {
507                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
508                }
509            });
510            None
511        };
512
513        let fn_moving_state_inverse =
514            get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
515        let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
516            let fn_name = Ident::new(
517                &format!("{snake_case_target_ident}_moving_state_inverse"),
518                found.sig.ident.span(),
519            );
520            let pg_extern_attr = pg_extern_attr(found);
521            pg_externs.push(parse_quote! {
522                #[allow(non_snake_case, clippy::too_many_arguments)]
523                #pg_extern_attr
524                fn #fn_name(
525                    mstate: #type_moving_state,
526                    #(#args_with_names),*,
527                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
528                ) -> #type_moving_state {
529                    unsafe {
530                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
531                            fcinfo,
532                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
533                        )
534                    }
535                }
536            });
537            Some(fn_name)
538        } else {
539            item_impl.items.push(parse_quote! {
540                fn moving_state_inverse(
541                    _mstate: #type_moving_state,
542                    _v: Self::Args,
543                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
544                ) -> #type_moving_state {
545                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
546                }
547            });
548            None
549        };
550
551        let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
552        let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
553            let fn_name = Ident::new(
554                &format!("{snake_case_target_ident}_moving_finalize"),
555                found.sig.ident.span(),
556            );
557            let pg_extern_attr = pg_extern_attr(found);
558            let maybe_comma: Option<syn::Token![,]> =
559                if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
560
561            pg_externs.push(parse_quote! {
562                #[allow(non_snake_case, clippy::too_many_arguments)]
563                #pg_extern_attr
564                fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
565                    unsafe {
566                        <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::in_memory_context(
567                            fcinfo,
568                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate::<#generic_type>>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
569                        )
570                    }
571                }
572            });
573            Some(fn_name)
574        } else {
575            item_impl.items.push(parse_quote! {
576                fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
577                    unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
578                }
579            });
580            None
581        };
582
583        Ok(CodeEnrichment(Self {
584            item_impl,
585            target_ident,
586            pg_externs,
587            name,
588            snake_case_target_ident,
589            type_args: type_args_value,
590            type_ordered_set_args: type_ordered_set_args_value,
591            type_moving_state: type_moving_state_value,
592            type_stype,
593            const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
594                .map(|x| x.expr.clone()),
595            const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
596                .map(|x| x.expr.clone()),
597            const_moving_finalize_modify: get_impl_const_by_name(
598                &item_impl_snapshot,
599                "MOVING_FINALIZE_MODIFY",
600            )
601            .map(|x| x.expr.clone()),
602            const_initial_condition: get_impl_const_by_name(
603                &item_impl_snapshot,
604                "INITIAL_CONDITION",
605            )
606            .and_then(|e| get_const_litstr(e).transpose())
607            .transpose()?,
608            const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
609                .and_then(get_const_litbool)
610                .unwrap_or(false),
611            const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
612                .and_then(|e| get_const_litstr(e).transpose())
613                .transpose()?,
614            const_moving_intial_condition: get_impl_const_by_name(
615                &item_impl_snapshot,
616                "MOVING_INITIAL_CONDITION",
617            )
618            .and_then(|e| get_const_litstr(e).transpose())
619            .transpose()?,
620            fn_state: fn_state_name,
621            fn_finalize: fn_finalize_name,
622            fn_combine: fn_combine_name,
623            fn_serial: fn_serial_name,
624            fn_deserial: fn_deserial_name,
625            fn_moving_state: fn_moving_state_name,
626            fn_moving_state_inverse: fn_moving_state_inverse_name,
627            fn_moving_finalize: fn_moving_finalize_name,
628            hypothetical: if let Some(value) =
629                get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
630            {
631                match &value.expr {
632                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
633                        syn::Lit::Bool(lit) => lit.value,
634                        _ => {
635                            return Err(syn::Error::new(
636                                value.span(),
637                                "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.",
638                            ));
639                        }
640                    },
641                    _ => {
642                        return Err(syn::Error::new(
643                            value.span(),
644                            "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.",
645                        ));
646                    }
647                }
648            } else {
649                false
650            },
651            to_sql_config,
652        }))
653    }
654}
655
656impl ToEntityGraphTokens for PgAggregate {
657    fn to_entity_graph_tokens(&self) -> TokenStream2 {
658        let target_ident = &self.target_ident;
659        let sql_graph_entity_fn_name = syn::Ident::new(
660            &format!("__pgrx_schema_aggregate_{}", self.snake_case_target_ident),
661            target_ident.span(),
662        );
663
664        let name = &self.name;
665        let const_ordered_set = self.const_ordered_set;
666        let hypothetical = self.hypothetical;
667        let fn_state = &self.fn_state;
668        let to_sql_config = &self.to_sql_config;
669        let to_sql_config_len = to_sql_config.section_len_tokens();
670        let type_args_len = self.type_args.section_len_tokens();
671        let direct_args_len = self
672            .type_ordered_set_args
673            .as_ref()
674            .map(|value| {
675                let inner = value.section_len_tokens();
676                quote! {
677                    ::pgrx::pgrx_sql_entity_graph::section::bool_len() + (#inner)
678                }
679            })
680            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
681        let stype_len = self.type_stype.section_len_tokens();
682        let moving_state_len = self
683            .type_moving_state
684            .as_ref()
685            .map(|value| {
686                let inner = value.section_len_tokens();
687                quote! {
688                    ::pgrx::pgrx_sql_entity_graph::section::bool_len() + (#inner)
689                }
690            })
691            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
692        let finalfunc_len = self
693            .fn_finalize
694            .as_ref()
695            .map(|value| {
696                quote! {
697                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
698                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
699                }
700            })
701            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
702        let combinefunc_len = self
703            .fn_combine
704            .as_ref()
705            .map(|value| {
706                quote! {
707                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
708                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
709                }
710            })
711            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
712        let serialfunc_len = self
713            .fn_serial
714            .as_ref()
715            .map(|value| {
716                quote! {
717                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
718                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
719                }
720            })
721            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
722        let deserialfunc_len = self
723            .fn_deserial
724            .as_ref()
725            .map(|value| {
726                quote! {
727                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
728                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
729                }
730            })
731            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
732        let initcond_len = self
733            .const_initial_condition
734            .as_ref()
735            .map(|value| {
736                quote! {
737                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
738                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
739                }
740            })
741            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
742        let msfunc_len = self
743            .fn_moving_state
744            .as_ref()
745            .map(|value| {
746                quote! {
747                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
748                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
749                }
750            })
751            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
752        let minvfunc_len = self
753            .fn_moving_state_inverse
754            .as_ref()
755            .map(|value| {
756                quote! {
757                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
758                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
759                }
760            })
761            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
762        let mfinalfunc_len = self
763            .fn_moving_finalize
764            .as_ref()
765            .map(|value| {
766                quote! {
767                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
768                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#value))
769                }
770            })
771            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
772        let minitcond_len = self
773            .const_moving_intial_condition
774            .as_ref()
775            .map(|value| {
776                quote! {
777                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
778                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
779                }
780            })
781            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
782        let sortop_len = self
783            .const_sort_operator
784            .as_ref()
785            .map(|value| {
786                quote! {
787                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
788                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
789                }
790            })
791            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
792        let finalize_modify_len = self
793            .const_finalize_modify
794            .as_ref()
795            .map(|value| {
796                quote! {
797                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
798                        + match #value {
799                            Some(_) => ::pgrx::pgrx_sql_entity_graph::section::u8_len(),
800                            None => 0,
801                        }
802                }
803            })
804            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
805        let moving_finalize_modify_len = self
806            .const_moving_finalize_modify
807            .as_ref()
808            .map(|value| {
809                quote! {
810                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
811                        + match #value {
812                            Some(_) => ::pgrx::pgrx_sql_entity_graph::section::u8_len(),
813                            None => 0,
814                        }
815                }
816            })
817            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
818        let parallel_len = self
819            .const_parallel
820            .as_ref()
821            .map(|value| {
822                quote! {
823                    ::pgrx::pgrx_sql_entity_graph::section::bool_len()
824                        + match #value {
825                            Some(_) => ::pgrx::pgrx_sql_entity_graph::section::u8_len(),
826                            None => 0,
827                        }
828                }
829            })
830            .unwrap_or_else(|| quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() });
831        let payload_len = quote! {
832            ::pgrx::pgrx_sql_entity_graph::section::u8_len()
833                + ::pgrx::pgrx_sql_entity_graph::section::str_len(concat!(module_path!(), "::", stringify!(#target_ident)))
834                + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
835                + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
836                + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
837                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
838                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
839                + (#type_args_len)
840                + (#direct_args_len)
841                + (#stype_len)
842                + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#fn_state))
843                + (#finalfunc_len)
844                + (#finalize_modify_len)
845                + (#combinefunc_len)
846                + (#serialfunc_len)
847                + (#deserialfunc_len)
848                + (#initcond_len)
849                + (#msfunc_len)
850                + (#minvfunc_len)
851                + (#moving_state_len)
852                + (#mfinalfunc_len)
853                + (#moving_finalize_modify_len)
854                + (#minitcond_len)
855                + (#sortop_len)
856                + (#parallel_len)
857                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
858                + (#to_sql_config_len)
859        };
860        let total_len = quote! {
861            ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
862        };
863
864        let direct_args_writer = self
865            .type_ordered_set_args
866            .as_ref()
867            .map(|value| value.section_writer_tokens(quote! { writer.bool(true) }))
868            .unwrap_or_else(|| quote! { writer.bool(false) });
869        let moving_state_writer = self
870            .type_moving_state
871            .as_ref()
872            .map(|value| value.section_writer_tokens(quote! { writer.bool(true) }))
873            .unwrap_or_else(|| quote! { writer.bool(false) });
874        let finalfunc_writer = self
875            .fn_finalize
876            .as_ref()
877            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
878            .unwrap_or_else(|| quote! { writer.bool(false) });
879        let combinefunc_writer = self
880            .fn_combine
881            .as_ref()
882            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
883            .unwrap_or_else(|| quote! { writer.bool(false) });
884        let serialfunc_writer = self
885            .fn_serial
886            .as_ref()
887            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
888            .unwrap_or_else(|| quote! { writer.bool(false) });
889        let deserialfunc_writer = self
890            .fn_deserial
891            .as_ref()
892            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
893            .unwrap_or_else(|| quote! { writer.bool(false) });
894        let initcond_writer = self
895            .const_initial_condition
896            .as_ref()
897            .map(|value| quote! { writer.bool(true).str(#value) })
898            .unwrap_or_else(|| quote! { writer.bool(false) });
899        let msfunc_writer = self
900            .fn_moving_state
901            .as_ref()
902            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
903            .unwrap_or_else(|| quote! { writer.bool(false) });
904        let minvfunc_writer = self
905            .fn_moving_state_inverse
906            .as_ref()
907            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
908            .unwrap_or_else(|| quote! { writer.bool(false) });
909        let mfinalfunc_writer = self
910            .fn_moving_finalize
911            .as_ref()
912            .map(|value| quote! { writer.bool(true).str(stringify!(#value)) })
913            .unwrap_or_else(|| quote! { writer.bool(false) });
914        let minitcond_writer = self
915            .const_moving_intial_condition
916            .as_ref()
917            .map(|value| quote! { writer.bool(true).str(#value) })
918            .unwrap_or_else(|| quote! { writer.bool(false) });
919        let sortop_writer = self
920            .const_sort_operator
921            .as_ref()
922            .map(|value| quote! { writer.bool(true).str(#value) })
923            .unwrap_or_else(|| quote! { writer.bool(false) });
924        let finalize_modify_writer = self
925            .const_finalize_modify
926            .as_ref()
927            .map(|value| quote! { match #value {
928                Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadOnly) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_ONLY),
929                Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::Shareable) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_SHAREABLE),
930                Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadWrite) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_WRITE),
931                None => writer.bool(false),
932            } })
933            .unwrap_or_else(|| quote! { writer.bool(false) });
934        let moving_finalize_modify_writer = self
935            .const_moving_finalize_modify
936            .as_ref()
937            .map(|value| quote! { match #value {
938                Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadOnly) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_ONLY),
939                Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::Shareable) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_SHAREABLE),
940                Some(::pgrx::pgrx_sql_entity_graph::FinalizeModify::ReadWrite) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_FINALIZE_READ_WRITE),
941                None => writer.bool(false),
942            } })
943            .unwrap_or_else(|| quote! { writer.bool(false) });
944        let parallel_writer = self
945            .const_parallel
946            .as_ref()
947            .map(|value| quote! { match #value {
948                Some(::pgrx::pgrx_sql_entity_graph::ParallelOption::Safe) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_PARALLEL_SAFE),
949                Some(::pgrx::pgrx_sql_entity_graph::ParallelOption::Restricted) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_PARALLEL_RESTRICTED),
950                Some(::pgrx::pgrx_sql_entity_graph::ParallelOption::Unsafe) => writer.bool(true).u8(::pgrx::pgrx_sql_entity_graph::section::AGGREGATE_PARALLEL_UNSAFE),
951                None => writer.bool(false),
952            } })
953            .unwrap_or_else(|| quote! { writer.bool(false) });
954        let args_writer = self.type_args.section_writer_tokens(quote! { writer });
955        let stype_writer = self.type_stype.section_writer_tokens(quote! { writer });
956        let to_sql_config_writer = to_sql_config.section_writer_tokens(quote! { writer });
957
958        quote! {
959            ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
960                #sql_graph_entity_fn_name,
961                #total_len,
962                {
963                    let writer = ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
964                        .u32((#payload_len) as u32)
965                        .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_AGGREGATE)
966                        .str(concat!(module_path!(), "::", stringify!(#target_ident)))
967                        .str(module_path!())
968                        .str(file!())
969                        .u32(line!())
970                        .str(#name)
971                        .bool(#const_ordered_set);
972                    let writer = { #args_writer };
973                    let writer = { #direct_args_writer };
974                    let writer = { #stype_writer };
975                    let writer = writer.str(stringify!(#fn_state));
976                    let writer = { #finalfunc_writer };
977                    let writer = { #finalize_modify_writer };
978                    let writer = { #combinefunc_writer };
979                    let writer = { #serialfunc_writer };
980                    let writer = { #deserialfunc_writer };
981                    let writer = { #initcond_writer };
982                    let writer = { #msfunc_writer };
983                    let writer = { #minvfunc_writer };
984                    let writer = { #moving_state_writer };
985                    let writer = { #mfinalfunc_writer };
986                    let writer = { #moving_finalize_modify_writer };
987                    let writer = { #minitcond_writer };
988                    let writer = { #sortop_writer };
989                    let writer = { #parallel_writer };
990                    let writer = writer.bool(#hypothetical);
991                    let writer = { #to_sql_config_writer };
992                    writer.finish()
993                }
994            );
995        }
996    }
997}
998
999impl ToRustCodeTokens for PgAggregate {
1000    fn to_rust_code_tokens(&self) -> TokenStream2 {
1001        let impl_item = &self.item_impl;
1002        let pg_externs = self.pg_externs.iter();
1003
1004        quote! {
1005            #impl_item
1006            #(#pg_externs)*
1007        }
1008    }
1009}
1010
1011impl Parse for CodeEnrichment<PgAggregate> {
1012    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
1013        PgAggregate::new(input.parse()?)
1014    }
1015}
1016
1017fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
1018    let last = path.segments.last().ok_or_else(|| {
1019        syn::Error::new(
1020            path.span(),
1021            "`#[pg_aggregate]` only works with types whose path have a final segment.",
1022        )
1023    })?;
1024    Ok(last.ident.clone())
1025}
1026
1027fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
1028    let target_ident = match &*item_impl.self_ty {
1029        syn::Type::Path(type_path) => {
1030            let last_segment = type_path.path.segments.last().ok_or_else(|| {
1031                syn::Error::new(
1032                    type_path.span(),
1033                    "`#[pg_aggregate]` only works with types whose path have a final segment.",
1034                )
1035            })?;
1036            if last_segment.ident == "PgVarlena" {
1037                match &last_segment.arguments {
1038                    syn::PathArguments::AngleBracketed(angled) => {
1039                        let first = angled.args.first().ok_or_else(|| syn::Error::new(
1040                            type_path.span(),
1041                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
1042                        ))?;
1043                        match &first {
1044                            syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
1045                            _ => {
1046                                return Err(syn::Error::new(
1047                                    type_path.span(),
1048                                    "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
1049                                ));
1050                            }
1051                        }
1052                    }
1053                    _ => {
1054                        return Err(syn::Error::new(
1055                            type_path.span(),
1056                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
1057                        ));
1058                    }
1059                }
1060            } else {
1061                type_path.path.clone()
1062            }
1063        }
1064        something_else => {
1065            return Err(syn::Error::new(
1066                something_else.span(),
1067                "`#[pg_aggregate]` only works with types.",
1068            ));
1069        }
1070    };
1071    Ok(target_ident)
1072}
1073
1074fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
1075    let mut found = None;
1076    for attr in item.attrs.iter() {
1077        match attr.path().segments.last() {
1078            Some(segment) if segment.ident == "pgrx" => {
1079                found = Some(attr);
1080                break;
1081            }
1082            _ => (),
1083        };
1084    }
1085
1086    let attrs = if let Some(attr) = found {
1087        let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
1088        let attrs = attr.parse_args_with(parser);
1089        attrs.ok()
1090    } else {
1091        None
1092    };
1093
1094    match attrs {
1095        Some(args) => parse_quote! {
1096            #[::pgrx::pg_extern(#args)]
1097        },
1098        None => parse_quote! {
1099            #[::pgrx::pg_extern]
1100        },
1101    }
1102}
1103
1104fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
1105    let mut needle = None;
1106    for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
1107        syn::ImplItem::Type(iitype) => Some(iitype),
1108        _ => None,
1109    }) {
1110        let ident_string = impl_item_type.ident.to_string();
1111        if ident_string == name {
1112            needle = Some(impl_item_type);
1113        }
1114    }
1115    needle
1116}
1117
1118fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
1119    let mut needle = None;
1120    for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
1121        syn::ImplItem::Fn(iifn) => Some(iifn),
1122        _ => None,
1123    }) {
1124        let ident_string = impl_item_fn.sig.ident.to_string();
1125        if ident_string == name {
1126            needle = Some(impl_item_fn);
1127        }
1128    }
1129    needle
1130}
1131
1132fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
1133    let mut needle = None;
1134    for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
1135        syn::ImplItem::Const(iiconst) => Some(iiconst),
1136        _ => None,
1137    }) {
1138        let ident_string = impl_item_const.ident.to_string();
1139        if ident_string == name {
1140            needle = Some(impl_item_const);
1141        }
1142    }
1143    needle
1144}
1145
1146fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
1147    match &item.expr {
1148        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
1149            syn::Lit::Bool(lit) => Some(lit.value()),
1150            _ => None,
1151        },
1152        _ => None,
1153    }
1154}
1155
1156fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
1157    match &item.expr {
1158        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
1159            syn::Lit::Str(lit) => Ok(Some(lit.value())),
1160            _ => Ok(None),
1161        },
1162        syn::Expr::Call(expr_call) => match &*expr_call.func {
1163            syn::Expr::Path(expr_path) => {
1164                let Some(last) = expr_path.path.segments.last() else {
1165                    return Ok(None);
1166                };
1167                if last.ident == "Some" {
1168                    match expr_call.args.first() {
1169                        Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
1170                            syn::Lit::Str(lit) => Ok(Some(lit.value())),
1171                            _ => Ok(None),
1172                        },
1173                        _ => Ok(None),
1174                    }
1175                } else {
1176                    Ok(None)
1177                }
1178            }
1179            _ => Ok(None),
1180        },
1181        ex => Err(syn::Error::new(ex.span(), "")),
1182    }
1183}
1184
1185fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
1186    if let Type::Path(ty_path) = ty {
1187        for segment in ty_path.path.segments.iter_mut() {
1188            if segment.ident == "Self" {
1189                segment.ident = target.clone()
1190            }
1191            use syn::{GenericArgument, PathArguments};
1192            match segment.arguments {
1193                PathArguments::AngleBracketed(ref mut angle_args) => {
1194                    for arg in angle_args.args.iter_mut() {
1195                        if let GenericArgument::Type(inner_ty) = arg {
1196                            remap_self_to_target(inner_ty, target)
1197                        }
1198                    }
1199                }
1200                PathArguments::Parenthesized(_) => (),
1201                PathArguments::None => (),
1202            }
1203        }
1204    }
1205}
1206
1207fn get_pgrx_attr_macro(attr_name: &str, ty: &syn::Type) -> Option<TokenStream2> {
1208    match &ty {
1209        syn::Type::Macro(ty_macro) => {
1210            let mut found_pgrx = false;
1211            let mut found_attr = false;
1212            // We don't actually have type resolution here, this is a "Best guess".
1213            for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
1214                match segment.ident.to_string().as_str() {
1215                    "pgrx" if idx == 0 => found_pgrx = true,
1216                    attr if attr == attr_name => found_attr = true,
1217                    _ => (),
1218                }
1219            }
1220            if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
1221                Some(ty_macro.mac.tokens.clone())
1222            } else {
1223                None
1224            }
1225        }
1226        _ => None,
1227    }
1228}
1229
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{ItemImpl, parse_quote};

    /// A minimal aggregate (only `State`, `Args`, and `state`) is accepted and produces a
    /// single `pg_extern` wrapper for `state`.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        // It should not error, as it's valid. `expect` prints the `syn::Error` on failure,
        // which a bare `assert!(agg.is_ok())` would hide.
        let agg = PgAggregate::new(tokens).expect("a minimal aggregate should be accepted");
        // It should create 1 extern, the state.
        assert_eq!(agg.0.pg_externs.len(), 1);
        // That extern should be named specifically:
        let extern_fn = &agg.0.pg_externs[0];
        assert_eq!(extern_fn.sig.ident.to_string(), "demo_agg_demo_name_state");
        // It should be possible to generate entity tokens.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An aggregate using every optional type, const, and function is accepted and produces
    /// one `pg_extern` wrapper per provided function.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate<DemoName> for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        // It should not error, as it's valid. `expect` prints the `syn::Error` on failure.
        let agg = PgAggregate::new(tokens).expect("a fully-specified aggregate should be accepted");
        // It should create 8 externs, one per provided function.
        assert_eq!(agg.0.pg_externs.len(), 8);
        // The first extern (the state function) should be named specifically:
        let extern_fn = &agg.0.pg_externs[0];
        assert_eq!(extern_fn.sig.ident.to_string(), "demo_agg_demo_name_state");
        // It should be possible to generate entity tokens.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An aggregate missing its required types/functions is rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        // This is not valid as it is missing required types/consts.
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_err(), "an aggregate without required items should be rejected");
        Ok(())
    }
}