1mod aggregate_type;
19pub(crate) mod entity;
20mod options;
21
22pub use aggregate_type::{AggregateType, AggregateTypeList};
23pub use options::{FinalizeModify, ParallelOption};
24
25use crate::enrich::CodeEnrichment;
26use crate::enrich::ToEntityGraphTokens;
27use crate::enrich::ToRustCodeTokens;
28use convert_case::{Case, Casing};
29use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
30use quote::quote;
31use syn::parse::{Parse, ParseStream};
32use syn::spanned::Spanned;
33use syn::{
34 parse_quote, Expr, ImplItemConst, ImplItemMethod, ImplItemType, ItemFn, ItemImpl, Path, Type,
35};
36
37use crate::ToSqlConfig;
38
39use super::UsedType;
40
/// Positional fallback names for generated function parameters.
///
/// When an entry of the aggregate's `Args` or `OrderedSetArgs` carries no
/// custom name, the entry at the same index in this table is used instead.
/// Only 32 positions have a fallback name.
const ARG_NAMES: [&str; 32] = [
    "arg_one",
    "arg_two",
    "arg_three",
    "arg_four",
    "arg_five",
    "arg_six",
    "arg_seven",
    "arg_eight",
    "arg_nine",
    "arg_ten",
    "arg_eleven",
    "arg_twelve",
    "arg_thirteen",
    "arg_fourteen",
    "arg_fifteen",
    "arg_sixteen",
    "arg_seventeen",
    "arg_eighteen",
    "arg_nineteen",
    "arg_twenty",
    "arg_twenty_one",
    "arg_twenty_two",
    "arg_twenty_three",
    "arg_twenty_four",
    "arg_twenty_five",
    "arg_twenty_six",
    "arg_twenty_seven",
    "arg_twenty_eight",
    "arg_twenty_nine",
    "arg_thirty",
    "arg_thirty_one",
    "arg_thirty_two",
];
76
/// A parsed `#[pg_aggregate]` implementation block.
///
/// Built by [`PgAggregate::new`] from an `impl Aggregate for T` item. It holds
/// the (possibly extended) impl block, the generated wrapper functions, and
/// the pieces of metadata that are later emitted as a `PgAggregateEntity`.
#[derive(Debug, Clone)]
pub struct PgAggregate {
    /// The user's impl block, with default items appended for every optional
    /// associated type, constant, or function the user did not write.
    item_impl: ItemImpl,
    /// Expression producing the aggregate's name: the user's `NAME` string
    /// literal, or `stringify!(<target ident>)` when `NAME` is absent.
    name: Expr,
    /// Identifier of the final path segment of the type being implemented
    /// (or of the type argument when the impl target is `PgVarlena<..>`).
    target_ident: Ident,
    /// Generated free functions wrapping the trait methods the user defined.
    pg_externs: Vec<ItemFn>,
    /// Parsed form of the required `Args` associated type.
    type_args: AggregateTypeList,
    /// Parsed form of `OrderedSetArgs`; `None` when the user did not declare it.
    type_ordered_set_args: Option<AggregateTypeList>,
    /// The `MovingState` associated type; `None` when the user did not declare it.
    type_moving_state: Option<UsedType>,
    /// The `State` associated type with `Self` replaced by the target
    /// identifier, carrying the argument name `"state"`.
    type_stype: AggregateType,
    /// Value of `ORDERED_SET` when it is a boolean literal, otherwise `false`.
    const_ordered_set: bool,
    /// The `PARALLEL` constant's expression, copied verbatim.
    const_parallel: Option<syn::Expr>,
    /// The `FINALIZE_MODIFY` constant's expression, copied verbatim.
    const_finalize_modify: Option<syn::Expr>,
    /// The `MOVING_FINALIZE_MODIFY` constant's expression, copied verbatim.
    const_moving_finalize_modify: Option<syn::Expr>,
    /// String extracted from the `INITIAL_CONDITION` constant, if any.
    const_initial_condition: Option<String>,
    /// String extracted from the `SORT_OPERATOR` constant, if any.
    const_sort_operator: Option<String>,
    /// String extracted from the `MOVING_INITIAL_CONDITION` constant, if any.
    const_moving_intial_condition: Option<String>,
    /// Name of the generated wrapper around `state` (always present).
    fn_state: Ident,
    /// Name of the generated wrapper around `finalize`, if the user defined it.
    fn_finalize: Option<Ident>,
    /// Name of the generated wrapper around `combine`, if the user defined it.
    fn_combine: Option<Ident>,
    /// Name of the generated wrapper around `serial`, if the user defined it.
    fn_serial: Option<Ident>,
    /// Name of the generated wrapper around `deserial`, if the user defined it.
    fn_deserial: Option<Ident>,
    /// Name of the generated wrapper around `moving_state`, if the user defined it.
    fn_moving_state: Option<Ident>,
    /// Name of the generated wrapper around `moving_state_inverse`, if the user defined it.
    fn_moving_state_inverse: Option<Ident>,
    /// Name of the generated wrapper around `moving_finalize`, if the user defined it.
    fn_moving_finalize: Option<Ident>,
    /// Value of the `HYPOTHETICAL` boolean literal; `false` when absent.
    hypothetical: bool,
    /// SQL generation settings read from the impl block's attributes
    /// (the default configuration when none are given).
    to_sql_config: ToSqlConfig,
}
108
109impl PgAggregate {
110 pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
111 let to_sql_config =
112 ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
113 let target_path = get_target_path(&item_impl)?;
114 let target_ident = get_target_ident(&target_path)?;
115
116 let snake_case_target_ident =
117 Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
118 crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
119
120 let mut pg_externs = Vec::default();
121 let item_impl_snapshot = item_impl.clone();
124
125 if let Some((_, ref path, _)) = item_impl.trait_ {
126 if let Some(last) = path.segments.last() {
128 if last.ident.to_string() != "Aggregate" {
129 return Err(syn::Error::new(
130 last.ident.span(),
131 "`#[pg_aggregate]` only works with the `Aggregate` trait.",
132 ));
133 }
134 }
135 }
136
137 let name = match get_impl_const_by_name(&item_impl_snapshot, "NAME") {
138 Some(item_const) => match &item_const.expr {
139 syn::Expr::Lit(ref expr) => {
140 if let syn::Lit::Str(_) = &expr.lit {
141 item_const.expr.clone()
142 } else {
143 return Err(syn::Error::new(
144 expr.span(),
145 "`NAME` must be a `&'static str` for Aggregate implementations.",
146 ));
147 }
148 }
149 e => {
150 return Err(syn::Error::new(
151 e.span(),
152 "`NAME` must be a `&'static str` for Aggregate implementations.",
153 ));
154 }
155 },
156 None => {
157 item_impl.items.push(parse_quote! {
158 const NAME: &'static str = stringify!(Self);
159 });
160 parse_quote! {
161 stringify!(#target_ident)
162 }
163 }
164 };
165
166 let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
168 let _type_state_value = type_state.map(|v| v.ty.clone());
169
170 let type_state_without_self = if let Some(inner) = type_state {
171 let mut remapped = inner.ty.clone();
172 remap_self_to_target(&mut remapped, &target_ident);
173 remapped
174 } else {
175 item_impl.items.push(parse_quote! {
176 type State = Self;
177 });
178 let mut remapped = parse_quote!(Self);
179 remap_self_to_target(&mut remapped, &target_ident);
180 remapped
181 };
182 let type_stype = AggregateType {
183 used_ty: UsedType::new(type_state_without_self.clone())?,
184 name: Some("state".into()),
185 };
186
187 let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
189 let type_moving_state;
190 let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
191 type_moving_state = impl_type_moving_state.ty.clone();
192 Some(UsedType::new(type_moving_state.clone())?)
193 } else {
194 item_impl.items.push(parse_quote! {
195 type MovingState = ();
196 });
197 type_moving_state = parse_quote! { () };
198 None
199 };
200
201 let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
203 let type_ordered_set_args_value =
204 type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
205 if type_ordered_set_args.is_none() {
206 item_impl.items.push(parse_quote! {
207 type OrderedSetArgs = ();
208 })
209 }
210 let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
211 type_ordered_set_args_value
212 {
213 let direct_args = order_by_direct_args
214 .found
215 .iter()
216 .map(|x| {
217 (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
218 })
219 .collect::<Vec<_>>();
220 let direct_arg_names = ARG_NAMES[0..direct_args.len()]
221 .iter()
222 .zip(direct_args.iter())
223 .map(|(default_name, (custom_name, _ty, _orig))| {
224 Ident::new(
225 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
226 Span::mixed_site(),
227 )
228 })
229 .collect::<Vec<_>>();
230 let direct_args_with_names = direct_args
231 .iter()
232 .zip(direct_arg_names.iter())
233 .map(|(arg, name)| {
234 let arg_ty = &arg.2; parse_quote! {
236 #name: #arg_ty
237 }
238 })
239 .collect::<Vec<syn::FnArg>>();
240 (direct_args_with_names, direct_arg_names)
241 } else {
242 (Vec::default(), Vec::default())
243 };
244
245 let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
247 syn::Error::new(
248 item_impl_snapshot.span(),
249 "`#[pg_aggregate]` requires the `Args` type defined.",
250 )
251 })?;
252 let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
253 let args = type_args_value
254 .found
255 .iter()
256 .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
257 .collect::<Vec<_>>();
258 let arg_names = ARG_NAMES[0..args.len()]
259 .iter()
260 .zip(args.iter())
261 .map(|(default_name, (custom_name, ty))| {
262 Ident::new(
263 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
264 ty.span(),
265 )
266 })
267 .collect::<Vec<_>>();
268 let args_with_names = args
269 .iter()
270 .zip(arg_names.iter())
271 .map(|(arg, name)| {
272 let arg_ty = &arg.1;
273 quote! {
274 #name: #arg_ty
275 }
276 })
277 .collect::<Vec<_>>();
278
279 let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
281 let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
282 type_finalize.ty.clone()
283 } else {
284 item_impl.items.push(parse_quote! {
285 type Finalize = ();
286 });
287 parse_quote! { () }
288 };
289
290 let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
291
292 let fn_state_name = if let Some(found) = fn_state {
293 let fn_name =
294 Ident::new(&format!("{}_state", snake_case_target_ident), found.sig.ident.span());
295 let pg_extern_attr = pg_extern_attr(found);
296
297 pg_externs.push(parse_quote! {
298 #[allow(non_snake_case, clippy::too_many_arguments)]
299 #pg_extern_attr
300 fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
301 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
302 fcinfo,
303 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::state(this, (#(#arg_names),*), fcinfo)
304 )
305 }
306 });
307 fn_name
308 } else {
309 return Err(syn::Error::new(
310 item_impl.span(),
311 "Aggregate implementation must include state function.",
312 ));
313 };
314
315 let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
316 let fn_combine_name = if let Some(found) = fn_combine {
317 let fn_name =
318 Ident::new(&format!("{}_combine", snake_case_target_ident), found.sig.ident.span());
319 let pg_extern_attr = pg_extern_attr(found);
320 pg_externs.push(parse_quote! {
321 #[allow(non_snake_case, clippy::too_many_arguments)]
322 #pg_extern_attr
323 fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
324 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
325 fcinfo,
326 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::combine(this, v, fcinfo)
327 )
328 }
329 });
330 Some(fn_name)
331 } else {
332 item_impl.items.push(parse_quote! {
333 fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
334 unimplemented!("Call to combine on an aggregate which does not support it.")
335 }
336 });
337 None
338 };
339
340 let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
341 let fn_finalize_name = if let Some(found) = fn_finalize {
342 let fn_name = Ident::new(
343 &format!("{}_finalize", snake_case_target_ident),
344 found.sig.ident.span(),
345 );
346 let pg_extern_attr = pg_extern_attr(found);
347
348 if direct_args_with_names.len() > 0 {
349 pg_externs.push(parse_quote! {
350 #[allow(non_snake_case, clippy::too_many_arguments)]
351 #pg_extern_attr
352 fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
353 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
354 fcinfo,
355 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo)
356 )
357 }
358 });
359 } else {
360 pg_externs.push(parse_quote! {
361 #[allow(non_snake_case, clippy::too_many_arguments)]
362 #pg_extern_attr
363 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
364 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
365 fcinfo,
366 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::finalize(this, (), fcinfo)
367 )
368 }
369 });
370 };
371 Some(fn_name)
372 } else {
373 item_impl.items.push(parse_quote! {
374 fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
375 unimplemented!("Call to finalize on an aggregate which does not support it.")
376 }
377 });
378 None
379 };
380
381 let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
382 let fn_serial_name = if let Some(found) = fn_serial {
383 let fn_name =
384 Ident::new(&format!("{}_serial", snake_case_target_ident), found.sig.ident.span());
385 let pg_extern_attr = pg_extern_attr(found);
386 pg_externs.push(parse_quote! {
387 #[allow(non_snake_case, clippy::too_many_arguments)]
388 #pg_extern_attr
389 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
390 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
391 fcinfo,
392 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::serial(this, fcinfo)
393 )
394 }
395 });
396 Some(fn_name)
397 } else {
398 item_impl.items.push(parse_quote! {
399 fn serial(current: #type_state_without_self, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
400 unimplemented!("Call to serial on an aggregate which does not support it.")
401 }
402 });
403 None
404 };
405
406 let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
407 let fn_deserial_name = if let Some(found) = fn_deserial {
408 let fn_name = Ident::new(
409 &format!("{}_deserial", snake_case_target_ident),
410 found.sig.ident.span(),
411 );
412 let pg_extern_attr = pg_extern_attr(found);
413 pg_externs.push(parse_quote! {
414 #[allow(non_snake_case, clippy::too_many_arguments)]
415 #pg_extern_attr
416 fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pgbox::PgBox<#type_state_without_self> {
417 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
418 fcinfo,
419 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::deserial(this, buf, internal, fcinfo)
420 )
421 }
422 });
423 Some(fn_name)
424 } else {
425 item_impl.items.push(parse_quote! {
426 fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pgbox::PgBox<#type_state_without_self> {
427 unimplemented!("Call to deserial on an aggregate which does not support it.")
428 }
429 });
430 None
431 };
432
433 let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
434 let fn_moving_state_name = if let Some(found) = fn_moving_state {
435 let fn_name = Ident::new(
436 &format!("{}_moving_state", snake_case_target_ident),
437 found.sig.ident.span(),
438 );
439 let pg_extern_attr = pg_extern_attr(found);
440
441 pg_externs.push(parse_quote! {
442 #[allow(non_snake_case, clippy::too_many_arguments)]
443 #pg_extern_attr
444 fn #fn_name(
445 mstate: #type_moving_state,
446 #(#args_with_names),*,
447 fcinfo: ::pgx::pg_sys::FunctionCallInfo,
448 ) -> #type_moving_state {
449 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
450 fcinfo,
451 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::moving_state(mstate, (#(#arg_names),*), fcinfo)
452 )
453 }
454 });
455 Some(fn_name)
456 } else {
457 item_impl.items.push(parse_quote! {
458 fn moving_state(
459 _mstate: <#target_path as ::pgx::aggregate::Aggregate>::MovingState,
460 _v: Self::Args,
461 _fcinfo: ::pgx::pg_sys::FunctionCallInfo,
462 ) -> <#target_path as ::pgx::aggregate::Aggregate>::MovingState {
463 unimplemented!("Call to moving_state on an aggregate which does not support it.")
464 }
465 });
466 None
467 };
468
469 let fn_moving_state_inverse =
470 get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
471 let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
472 let fn_name = Ident::new(
473 &format!("{}_moving_state_inverse", snake_case_target_ident),
474 found.sig.ident.span(),
475 );
476 let pg_extern_attr = pg_extern_attr(found);
477 pg_externs.push(parse_quote! {
478 #[allow(non_snake_case, clippy::too_many_arguments)]
479 #pg_extern_attr
480 fn #fn_name(
481 mstate: #type_moving_state,
482 #(#args_with_names),*,
483 fcinfo: ::pgx::pg_sys::FunctionCallInfo,
484 ) -> #type_moving_state {
485 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
486 fcinfo,
487 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
488 )
489 }
490 });
491 Some(fn_name)
492 } else {
493 item_impl.items.push(parse_quote! {
494 fn moving_state_inverse(
495 _mstate: #type_moving_state,
496 _v: Self::Args,
497 _fcinfo: ::pgx::pg_sys::FunctionCallInfo,
498 ) -> #type_moving_state {
499 unimplemented!("Call to moving_state on an aggregate which does not support it.")
500 }
501 });
502 None
503 };
504
505 let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
506 let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
507 let fn_name = Ident::new(
508 &format!("{}_moving_finalize", snake_case_target_ident),
509 found.sig.ident.span(),
510 );
511 let pg_extern_attr = pg_extern_attr(found);
512 let maybe_comma: Option<syn::Token![,]> =
513 if direct_args_with_names.len() > 0 { Some(parse_quote! {,}) } else { None };
514
515 pg_externs.push(parse_quote! {
516 #[allow(non_snake_case, clippy::too_many_arguments)]
517 #pg_extern_attr
518 fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
519 <#target_path as ::pgx::aggregate::Aggregate>::in_memory_context(
520 fcinfo,
521 move |_context| <#target_path as ::pgx::aggregate::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
522 )
523 }
524 });
525 Some(fn_name)
526 } else {
527 item_impl.items.push(parse_quote! {
528 fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgx::pg_sys::FunctionCallInfo) -> Self::Finalize {
529 unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
530 }
531 });
532 None
533 };
534
535 Ok(CodeEnrichment(Self {
536 item_impl,
537 target_ident,
538 pg_externs,
539 name,
540 type_args: type_args_value,
541 type_ordered_set_args: type_ordered_set_args_value,
542 type_moving_state: type_moving_state_value,
543 type_stype: type_stype,
544 const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
545 .map(|x| x.expr.clone()),
546 const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
547 .map(|x| x.expr.clone()),
548 const_moving_finalize_modify: get_impl_const_by_name(
549 &item_impl_snapshot,
550 "MOVING_FINALIZE_MODIFY",
551 )
552 .map(|x| x.expr.clone()),
553 const_initial_condition: get_impl_const_by_name(
554 &item_impl_snapshot,
555 "INITIAL_CONDITION",
556 )
557 .and_then(|e| get_const_litstr(e).transpose())
558 .transpose()?,
559 const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
560 .and_then(get_const_litbool)
561 .unwrap_or(false),
562 const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
563 .and_then(|e| get_const_litstr(e).transpose())
564 .transpose()?,
565 const_moving_intial_condition: get_impl_const_by_name(
566 &item_impl_snapshot,
567 "MOVING_INITIAL_CONDITION",
568 )
569 .and_then(|e| get_const_litstr(e).transpose())
570 .transpose()?,
571 fn_state: fn_state_name,
572 fn_finalize: fn_finalize_name,
573 fn_combine: fn_combine_name,
574 fn_serial: fn_serial_name,
575 fn_deserial: fn_deserial_name,
576 fn_moving_state: fn_moving_state_name,
577 fn_moving_state_inverse: fn_moving_state_inverse_name,
578 fn_moving_finalize: fn_moving_finalize_name,
579 hypothetical: if let Some(value) =
580 get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
581 {
582 match &value.expr {
583 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
584 syn::Lit::Bool(lit) => lit.value,
585 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
586 },
587 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
588 }
589 } else {
590 false
591 },
592 to_sql_config,
593 }))
594 }
595}
596
597impl ToEntityGraphTokens for PgAggregate {
598 fn to_entity_graph_tokens(&self) -> TokenStream2 {
599 let target_ident = &self.target_ident;
600 let snake_case_target_ident =
601 Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
602 let sql_graph_entity_fn_name = syn::Ident::new(
603 &format!("__pgx_internals_aggregate_{}", snake_case_target_ident),
604 target_ident.span(),
605 );
606
607 let name = &self.name;
608 let type_args_iter = &self.type_args.entity_tokens();
609 let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
610
611 let type_moving_state_entity_tokens =
612 self.type_moving_state.clone().map(|v| v.entity_tokens());
613 let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
614 let type_stype = self.type_stype.entity_tokens();
615 let const_ordered_set = self.const_ordered_set;
616 let const_parallel_iter = self.const_parallel.iter();
617 let const_finalize_modify_iter = self.const_finalize_modify.iter();
618 let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
619 let const_initial_condition_iter = self.const_initial_condition.iter();
620 let const_sort_operator_iter = self.const_sort_operator.iter();
621 let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
622 let hypothetical = self.hypothetical;
623 let fn_state = &self.fn_state;
624 let fn_finalize_iter = self.fn_finalize.iter();
625 let fn_combine_iter = self.fn_combine.iter();
626 let fn_serial_iter = self.fn_serial.iter();
627 let fn_deserial_iter = self.fn_deserial.iter();
628 let fn_moving_state_iter = self.fn_moving_state.iter();
629 let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
630 let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
631 let to_sql_config = &self.to_sql_config;
632
633 quote! {
634 #[no_mangle]
635 #[doc(hidden)]
636 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::pgx_sql_entity_graph::SqlGraphEntity {
637 let submission = ::pgx::pgx_sql_entity_graph::PgAggregateEntity {
638 full_path: ::core::any::type_name::<#target_ident>(),
639 module_path: module_path!(),
640 file: file!(),
641 line: line!(),
642 name: #name,
643 ordered_set: #const_ordered_set,
644 ty_id: ::core::any::TypeId::of::<#target_ident>(),
645 args: #type_args_iter,
646 direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
647 stype: #type_stype,
648 sfunc: stringify!(#fn_state),
649 combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
650 finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
651 finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
652 initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
653 serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
654 deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
655 msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
656 minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
657 mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
658 mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
659 mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
660 minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
661 sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
662 parallel: None #( .unwrap_or(#const_parallel_iter) )*,
663 hypothetical: #hypothetical,
664 to_sql_config: #to_sql_config,
665 };
666 ::pgx::pgx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
667 }
668 }
669 }
670}
671
672impl ToRustCodeTokens for PgAggregate {
673 fn to_rust_code_tokens(&self) -> TokenStream2 {
674 let impl_item = &self.item_impl;
675 let pg_externs = self.pg_externs.iter();
676 quote! {
677 #impl_item
678 #(#pg_externs)*
679 }
680 }
681}
682
683impl Parse for CodeEnrichment<PgAggregate> {
684 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
685 PgAggregate::new(input.parse()?)
686 }
687}
688
689fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
690 let last = path.segments.last().ok_or_else(|| {
691 syn::Error::new(
692 path.span(),
693 "`#[pg_aggregate]` only works with types whose path have a final segment.",
694 )
695 })?;
696 Ok(last.ident.clone())
697}
698
699fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
700 let target_ident = match &*item_impl.self_ty {
701 syn::Type::Path(ref type_path) => {
702 let last_segment = type_path.path.segments.last().ok_or_else(|| {
703 syn::Error::new(
704 type_path.span(),
705 "`#[pg_aggregate]` only works with types whose path have a final segment.",
706 )
707 })?;
708 if last_segment.ident.to_string() == "PgVarlena" {
709 match &last_segment.arguments {
710 syn::PathArguments::AngleBracketed(angled) => {
711 let first = angled.args.first().ok_or_else(|| syn::Error::new(
712 type_path.span(),
713 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
714 ))?;
715 match &first {
716 syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
717 _ => return Err(syn::Error::new(
718 type_path.span(),
719 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
720 )),
721 }
722 },
723 _ => return Err(syn::Error::new(
724 type_path.span(),
725 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
726 )),
727 }
728 } else {
729 type_path.path.clone()
730 }
731 }
732 something_else => {
733 return Err(syn::Error::new(
734 something_else.span(),
735 "`#[pg_aggregate]` only works with types.",
736 ))
737 }
738 };
739 Ok(target_ident)
740}
741
742fn pg_extern_attr(item: &ImplItemMethod) -> syn::Attribute {
743 let mut found = None;
744 for attr in item.attrs.iter() {
745 match attr.path.segments.last() {
746 Some(segment) if segment.ident.to_string() == "pgx" => {
747 found = Some(attr.tokens.clone());
748 break;
749 }
750 _ => (),
751 };
752 }
753 match found {
754 Some(args) => parse_quote! {
755 #[::pgx::pg_extern #args]
756 },
757 None => parse_quote! {
758 #[::pgx::pg_extern]
759 },
760 }
761}
762
763fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
764 let mut needle = None;
765 for impl_item in item_impl.items.iter() {
766 match impl_item {
767 syn::ImplItem::Type(impl_item_type) => {
768 let ident_string = impl_item_type.ident.to_string();
769 if ident_string == name {
770 needle = Some(impl_item_type);
771 }
772 }
773 _ => (),
774 }
775 }
776 needle
777}
778
779fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemMethod> {
780 let mut needle = None;
781 for impl_item in item_impl.items.iter() {
782 match impl_item {
783 syn::ImplItem::Method(impl_item_method) => {
784 let ident_string = impl_item_method.sig.ident.to_string();
785 if ident_string == name {
786 needle = Some(impl_item_method);
787 }
788 }
789 _ => (),
790 }
791 }
792 needle
793}
794
795fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
796 let mut needle = None;
797 for impl_item in item_impl.items.iter() {
798 match impl_item {
799 syn::ImplItem::Const(impl_item_const) => {
800 let ident_string = impl_item_const.ident.to_string();
801 if ident_string == name {
802 needle = Some(impl_item_const);
803 }
804 }
805 _ => (),
806 }
807 }
808 needle
809}
810
811fn get_const_litbool<'a>(item: &'a ImplItemConst) -> Option<bool> {
812 match &item.expr {
813 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
814 syn::Lit::Bool(lit) => Some(lit.value()),
815 _ => None,
816 },
817 _ => None,
818 }
819}
820
821fn get_const_litstr<'a>(item: &'a ImplItemConst) -> syn::Result<Option<String>> {
822 match &item.expr {
823 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
824 syn::Lit::Str(lit) => Ok(Some(lit.value())),
825 _ => Ok(None),
826 },
827 syn::Expr::Call(expr_call) => match &*expr_call.func {
828 syn::Expr::Path(expr_path) => {
829 let Some(last) = expr_path.path.segments.last() else {
830 return Ok(None);
831 };
832 if last.ident.to_string() == "Some" {
833 match expr_call.args.first() {
834 Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
835 syn::Lit::Str(lit) => Ok(Some(lit.value())),
836 _ => Ok(None),
837 },
838 _ => Ok(None),
839 }
840 } else {
841 Ok(None)
842 }
843 }
844 _ => Ok(None),
845 },
846 ex => Err(syn::Error::new(ex.span(), "")),
847 }
848}
849
850fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
851 match ty {
852 Type::Path(ref mut ty_path) => {
853 for segment in ty_path.path.segments.iter_mut() {
854 if segment.ident.to_string() == "Self" {
855 segment.ident = target.clone()
856 }
857 use syn::{GenericArgument, PathArguments};
858 match segment.arguments {
859 PathArguments::AngleBracketed(ref mut angle_args) => {
860 for arg in angle_args.args.iter_mut() {
861 match arg {
862 GenericArgument::Type(inner_ty) => {
863 remap_self_to_target(inner_ty, target)
864 }
865 _ => (),
866 }
867 }
868 }
869 PathArguments::Parenthesized(_) => (),
870 PathArguments::None => (),
871 }
872 }
873 }
874 _ => (),
875 }
876}
877
878fn get_pgx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
879 match &ty {
880 syn::Type::Macro(ty_macro) => {
881 let mut found_pgx = false;
882 let mut found_attr = false;
883 for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
885 match segment.ident.to_string().as_str() {
886 "pgx" if idx == 0 => found_pgx = true,
887 attr if attr == attr_name.as_ref() => found_attr = true,
888 _ => (),
889 }
890 }
891 if (ty_macro.mac.path.segments.len() == 1 && found_attr) || (found_pgx && found_attr) {
892 Some(ty_macro.mac.tokens.clone())
893 } else {
894 None
895 }
896 }
897 _ => None,
898 }
899}
900
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{parse_quote, ItemImpl};

    /// Only the required items are present: exactly one wrapper (for `state`)
    /// must be generated, and token generation must not panic.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                const NAME: &'static str = "DEMO";

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        // A parse failure is propagated and fails the test.
        let agg = PgAggregate::new(tokens)?;

        let externs = &agg.0.pg_externs;
        assert_eq!(externs.len(), 1);
        assert_eq!(externs[0].sig.ident.to_string(), "demo_agg_state");

        let _ = agg.to_token_stream();
        Ok(())
    }

    /// Every optional method is present: one wrapper per method (eight in
    /// total), with the `state` wrapper first.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const NAME: &'static str = "DEMO";

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        // A parse failure is propagated and fails the test.
        let agg = PgAggregate::new(tokens)?;

        let externs = &agg.0.pg_externs;
        assert_eq!(externs.len(), 8);
        assert_eq!(externs[0].sig.ident.to_string(), "demo_agg_state");

        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An empty impl block lacks the required items and must be rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(PgAggregate::new(tokens).is_err());
        Ok(())
    }
}