1mod aggregate_type;
19pub(crate) mod entity;
20mod options;
21
22pub use aggregate_type::{AggregateType, AggregateTypeList};
23pub use options::{FinalizeModify, ParallelOption};
24
25use convert_case::{Case, Casing};
26use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
27use quote::{quote, ToTokens, TokenStreamExt};
28use syn::parse::{Parse, ParseStream};
29use syn::spanned::Spanned;
30use syn::{
31 parse_quote, Expr, ImplItemConst, ImplItemMethod, ImplItemType, ItemFn, ItemImpl, Path, Type,
32};
33
34use crate::sql_entity_graph::ToSqlConfig;
35
36use super::UsedType;
37
/// Positional fallback parameter names for the generated `#[pg_extern]` wrappers.
///
/// The `n`-th aggregate argument (or ordered-set direct argument) that does not carry a custom
/// name is given `ARG_NAMES[n]`. The table's length is therefore the maximum number of
/// arguments a `#[pg_aggregate]` can name.
const ARG_NAMES: [&str; 32] = [
    "arg_one", "arg_two", "arg_three", "arg_four",
    "arg_five", "arg_six", "arg_seven", "arg_eight",
    "arg_nine", "arg_ten", "arg_eleven", "arg_twelve",
    "arg_thirteen", "arg_fourteen", "arg_fifteen", "arg_sixteen",
    "arg_seventeen", "arg_eighteen", "arg_nineteen", "arg_twenty",
    "arg_twenty_one", "arg_twenty_two", "arg_twenty_three", "arg_twenty_four",
    "arg_twenty_five", "arg_twenty_six", "arg_twenty_seven", "arg_twenty_eight",
    "arg_twenty_nine", "arg_thirty", "arg_thirty_one", "arg_thirty_two",
];
73
/// Parsed form of an `impl Aggregate for Target { ... }` block annotated with `#[pg_aggregate]`.
///
/// Built by [`PgAggregate::new`] and emitted through its [`ToTokens`] impl. The output is:
/// * the (default-completed) impl;
/// * one `#[pg_extern]` wrapper per user-provided aggregate function;
/// * a hidden function that registers the aggregate with the SQL entity graph.
#[derive(Debug, Clone)]
pub struct PgAggregate {
    /// The user's impl, with default items pushed in for anything the user omitted.
    item_impl: ItemImpl,
    /// Expression for the aggregate's SQL name: the user's `NAME` string literal, or
    /// `stringify!(Target)` when `NAME` is absent.
    name: Expr,
    /// Generated `#[pg_extern]` wrapper functions, one per user-provided aggregate function.
    pg_externs: Vec<ItemFn>,
    /// Argument types parsed from the `Args` associated type.
    type_args: AggregateTypeList,
    /// Direct-argument types parsed from `OrderedSetArgs`, when that type is declared.
    type_ordered_set_args: Option<AggregateTypeList>,
    /// The `MovingState` associated type, when declared.
    type_moving_state: Option<UsedType>,
    /// The `State` type (with `Self` rewritten to the target type), named `state`.
    type_stype: AggregateType,
    /// Value of `ORDERED_SET` when it is a literal boolean; `false` otherwise.
    const_ordered_set: bool,
    /// Initializer expression of `PARALLEL`, copied verbatim, when declared.
    const_parallel: Option<syn::Expr>,
    /// Initializer expression of `FINALIZE_MODIFY`, copied verbatim, when declared.
    const_finalize_modify: Option<syn::Expr>,
    /// Initializer expression of `MOVING_FINALIZE_MODIFY`, copied verbatim, when declared.
    const_moving_finalize_modify: Option<syn::Expr>,
    /// String from `INITIAL_CONDITION` when it is `"..."` or `Some("...")`.
    const_initial_condition: Option<String>,
    /// String from `SORT_OPERATOR` when it is `"..."` or `Some("...")`.
    const_sort_operator: Option<String>,
    /// String from `MOVING_INITIAL_CONDITION` when it is `"..."` or `Some("...")`.
    /// (Field name misspelling kept as-is; it is referenced elsewhere in this file.)
    const_moving_intial_condition: Option<String>,
    /// Name of the generated `state` wrapper (`<snake_case_target>_state`).
    fn_state: Ident,
    /// Name of the generated `finalize` wrapper; `None` when the user did not provide one.
    fn_finalize: Option<Ident>,
    /// Name of the generated `combine` wrapper; `None` when the user did not provide one.
    fn_combine: Option<Ident>,
    /// Name of the generated `serial` wrapper; `None` when the user did not provide one.
    fn_serial: Option<Ident>,
    /// Name of the generated `deserial` wrapper; `None` when the user did not provide one.
    fn_deserial: Option<Ident>,
    /// Name of the generated `moving_state` wrapper; `None` when the user did not provide one.
    fn_moving_state: Option<Ident>,
    /// Name of the generated `moving_state_inverse` wrapper; `None` when not provided.
    fn_moving_state_inverse: Option<Ident>,
    /// Name of the generated `moving_finalize` wrapper; `None` when not provided.
    fn_moving_finalize: Option<Ident>,
    /// Value of the literal-boolean `HYPOTHETICAL` const; `false` when absent.
    hypothetical: bool,
    /// SQL generation settings read from the impl's attributes (default when none present).
    to_sql_config: ToSqlConfig,
}
104
105impl PgAggregate {
106 pub fn new(mut item_impl: ItemImpl) -> Result<Self, syn::Error> {
107 let to_sql_config =
108 ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
109 let target_path = get_target_path(&item_impl)?;
110 let target_ident = get_target_ident(&target_path)?;
111
112 let snake_case_target_ident =
113 Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
114 crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
115
116 let mut pg_externs = Vec::default();
117 let item_impl_snapshot = item_impl.clone();
120
121 if let Some((_, ref path, _)) = item_impl.trait_ {
122 if let Some(last) = path.segments.last() {
124 if last.ident.to_string() != "Aggregate" {
125 return Err(syn::Error::new(
126 last.ident.span(),
127 "`#[pg_aggregate]` only works with the `Aggregate` trait.",
128 ));
129 }
130 }
131 }
132
133 let name = match get_impl_const_by_name(&item_impl_snapshot, "NAME") {
134 Some(item_const) => match item_const.expr {
135 syn::Expr::Lit(ref expr) => {
136 if let syn::Lit::Str(_) = expr.lit {
137 item_const.expr.clone()
138 } else {
139 panic!("`NAME` must be a `&'static str` for Aggregate implementations.")
140 }
141 }
142 _ => panic!("`NAME` must be a `&'static str` for Aggregate implementations."),
143 },
144 None => {
145 item_impl.items.push(parse_quote! {
146 const NAME: &'static str = stringify!(Self);
147 });
148 parse_quote! {
149 stringify!(#target_ident)
150 }
151 }
152 };
153
154 let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
156 let _type_state_value = type_state.map(|v| v.ty.clone());
157
158 let type_state_without_self = if let Some(inner) = type_state {
159 let mut remapped = inner.ty.clone();
160 remap_self_to_target(&mut remapped, &target_ident);
161 remapped
162 } else {
163 item_impl.items.push(parse_quote! {
164 type State = Self;
165 });
166 let mut remapped = parse_quote!(Self);
167 remap_self_to_target(&mut remapped, &target_ident);
168 remapped
169 };
170 let type_stype = AggregateType {
171 used_ty: UsedType::new(type_state_without_self.clone())?,
172 name: Some("state".into()),
173 };
174
175 let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
177 let type_moving_state;
178 let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
179 type_moving_state = impl_type_moving_state.ty.clone();
180 Some(UsedType::new(type_moving_state.clone())?)
181 } else {
182 item_impl.items.push(parse_quote! {
183 type MovingState = ();
184 });
185 type_moving_state = parse_quote! { () };
186 None
187 };
188
189 let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
191 let type_ordered_set_args_value =
192 type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
193 if type_ordered_set_args.is_none() {
194 item_impl.items.push(parse_quote! {
195 type OrderedSetArgs = ();
196 })
197 }
198 let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
199 type_ordered_set_args_value
200 {
201 let direct_args = order_by_direct_args
202 .found
203 .iter()
204 .map(|x| {
205 (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
206 })
207 .collect::<Vec<_>>();
208 let direct_arg_names = ARG_NAMES[0..direct_args.len()]
209 .iter()
210 .zip(direct_args.iter())
211 .map(|(default_name, (custom_name, _ty, _orig))| {
212 Ident::new(
213 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
214 Span::mixed_site(),
215 )
216 })
217 .collect::<Vec<_>>();
218 let direct_args_with_names = direct_args
219 .iter()
220 .zip(direct_arg_names.iter())
221 .map(|(arg, name)| {
222 let arg_ty = &arg.2; parse_quote! {
224 #name: #arg_ty
225 }
226 })
227 .collect::<Vec<syn::FnArg>>();
228 (direct_args_with_names, direct_arg_names)
229 } else {
230 (Vec::default(), Vec::default())
231 };
232
233 let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
235 syn::Error::new(
236 item_impl_snapshot.span(),
237 "`#[pg_aggregate]` requires the `Args` type defined.",
238 )
239 })?;
240 let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
241 let args = type_args_value
242 .found
243 .iter()
244 .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
245 .collect::<Vec<_>>();
246 let arg_names = ARG_NAMES[0..args.len()]
247 .iter()
248 .zip(args.iter())
249 .map(|(default_name, (custom_name, ty))| {
250 Ident::new(
251 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
252 ty.span(),
253 )
254 })
255 .collect::<Vec<_>>();
256 let args_with_names = args
257 .iter()
258 .zip(arg_names.iter())
259 .map(|(arg, name)| {
260 let arg_ty = &arg.1;
261 quote! {
262 #name: #arg_ty
263 }
264 })
265 .collect::<Vec<_>>();
266
267 let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
269 let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
270 type_finalize.ty.clone()
271 } else {
272 item_impl.items.push(parse_quote! {
273 type Finalize = ();
274 });
275 parse_quote! { () }
276 };
277
278 let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
279
280 let fn_state_name = if let Some(found) = fn_state {
281 let fn_name =
282 Ident::new(&format!("{}_state", snake_case_target_ident), found.sig.ident.span());
283 let pg_extern_attr = pg_extern_attr(found);
284
285 pg_externs.push(parse_quote! {
286 #[allow(non_snake_case, clippy::too_many_arguments)]
287 #pg_extern_attr
288 fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
289 <#target_path as pgx::Aggregate>::in_memory_context(
290 fcinfo,
291 move |_context| <#target_path as pgx::Aggregate>::state(this, (#(#arg_names),*), fcinfo)
292 )
293 }
294 });
295 fn_name
296 } else {
297 return Err(syn::Error::new(
298 item_impl.span(),
299 "Aggregate implementation must include state function.",
300 ));
301 };
302
303 let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
304 let fn_combine_name = if let Some(found) = fn_combine {
305 let fn_name =
306 Ident::new(&format!("{}_combine", snake_case_target_ident), found.sig.ident.span());
307 let pg_extern_attr = pg_extern_attr(found);
308 pg_externs.push(parse_quote! {
309 #[allow(non_snake_case, clippy::too_many_arguments)]
310 #pg_extern_attr
311 fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
312 <#target_path as pgx::Aggregate>::in_memory_context(
313 fcinfo,
314 move |_context| <#target_path as pgx::Aggregate>::combine(this, v, fcinfo)
315 )
316 }
317 });
318 Some(fn_name)
319 } else {
320 item_impl.items.push(parse_quote! {
321 fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
322 unimplemented!("Call to combine on an aggregate which does not support it.")
323 }
324 });
325 None
326 };
327
328 let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
329 let fn_finalize_name = if let Some(found) = fn_finalize {
330 let fn_name = Ident::new(
331 &format!("{}_finalize", snake_case_target_ident),
332 found.sig.ident.span(),
333 );
334 let pg_extern_attr = pg_extern_attr(found);
335
336 if direct_args_with_names.len() > 0 {
337 pg_externs.push(parse_quote! {
338 #[allow(non_snake_case, clippy::too_many_arguments)]
339 #pg_extern_attr
340 fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
341 <#target_path as pgx::Aggregate>::in_memory_context(
342 fcinfo,
343 move |_context| <#target_path as pgx::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo)
344 )
345 }
346 });
347 } else {
348 pg_externs.push(parse_quote! {
349 #[allow(non_snake_case, clippy::too_many_arguments)]
350 #pg_extern_attr
351 fn #fn_name(this: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
352 <#target_path as pgx::Aggregate>::in_memory_context(
353 fcinfo,
354 move |_context| <#target_path as pgx::Aggregate>::finalize(this, (), fcinfo)
355 )
356 }
357 });
358 };
359 Some(fn_name)
360 } else {
361 item_impl.items.push(parse_quote! {
362 fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
363 unimplemented!("Call to finalize on an aggregate which does not support it.")
364 }
365 });
366 None
367 };
368
369 let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
370 let fn_serial_name = if let Some(found) = fn_serial {
371 let fn_name =
372 Ident::new(&format!("{}_serial", snake_case_target_ident), found.sig.ident.span());
373 let pg_extern_attr = pg_extern_attr(found);
374 pg_externs.push(parse_quote! {
375 #[allow(non_snake_case, clippy::too_many_arguments)]
376 #pg_extern_attr
377 fn #fn_name(this: #type_state_without_self, fcinfo: pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
378 <#target_path as pgx::Aggregate>::in_memory_context(
379 fcinfo,
380 move |_context| <#target_path as pgx::Aggregate>::serial(this, fcinfo)
381 )
382 }
383 });
384 Some(fn_name)
385 } else {
386 item_impl.items.push(parse_quote! {
387 fn serial(current: #type_state_without_self, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Vec<u8> {
388 unimplemented!("Call to serial on an aggregate which does not support it.")
389 }
390 });
391 None
392 };
393
394 let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
395 let fn_deserial_name = if let Some(found) = fn_deserial {
396 let fn_name = Ident::new(
397 &format!("{}_deserial", snake_case_target_ident),
398 found.sig.ident.span(),
399 );
400 let pg_extern_attr = pg_extern_attr(found);
401 pg_externs.push(parse_quote! {
402 #[allow(non_snake_case, clippy::too_many_arguments)]
403 #pg_extern_attr
404 fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: pgx::PgBox<#type_state_without_self>, fcinfo: pgx::pg_sys::FunctionCallInfo) -> pgx::PgBox<#type_state_without_self> {
405 <#target_path as pgx::Aggregate>::in_memory_context(
406 fcinfo,
407 move |_context| <#target_path as pgx::Aggregate>::deserial(this, buf, internal, fcinfo)
408 )
409 }
410 });
411 Some(fn_name)
412 } else {
413 item_impl.items.push(parse_quote! {
414 fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: pgx::PgBox<#type_state_without_self>, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> pgx::PgBox<#type_state_without_self> {
415 unimplemented!("Call to deserial on an aggregate which does not support it.")
416 }
417 });
418 None
419 };
420
421 let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
422 let fn_moving_state_name = if let Some(found) = fn_moving_state {
423 let fn_name = Ident::new(
424 &format!("{}_moving_state", snake_case_target_ident),
425 found.sig.ident.span(),
426 );
427 let pg_extern_attr = pg_extern_attr(found);
428
429 pg_externs.push(parse_quote! {
430 #[allow(non_snake_case, clippy::too_many_arguments)]
431 #pg_extern_attr
432 fn #fn_name(
433 mstate: #type_moving_state,
434 #(#args_with_names),*,
435 fcinfo: pgx::pg_sys::FunctionCallInfo,
436 ) -> #type_moving_state {
437 <#target_path as pgx::Aggregate>::in_memory_context(
438 fcinfo,
439 move |_context| <#target_path as pgx::Aggregate>::moving_state(mstate, (#(#arg_names),*), fcinfo)
440 )
441 }
442 });
443 Some(fn_name)
444 } else {
445 item_impl.items.push(parse_quote! {
446 fn moving_state(
447 _mstate: <#target_path as pgx::Aggregate>::MovingState,
448 _v: Self::Args,
449 _fcinfo: pgx::pg_sys::FunctionCallInfo,
450 ) -> <#target_path as pgx::Aggregate>::MovingState {
451 unimplemented!("Call to moving_state on an aggregate which does not support it.")
452 }
453 });
454 None
455 };
456
457 let fn_moving_state_inverse =
458 get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
459 let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
460 let fn_name = Ident::new(
461 &format!("{}_moving_state_inverse", snake_case_target_ident),
462 found.sig.ident.span(),
463 );
464 let pg_extern_attr = pg_extern_attr(found);
465 pg_externs.push(parse_quote! {
466 #[allow(non_snake_case, clippy::too_many_arguments)]
467 #pg_extern_attr
468 fn #fn_name(
469 mstate: #type_moving_state,
470 #(#args_with_names),*,
471 fcinfo: pgx::pg_sys::FunctionCallInfo,
472 ) -> #type_moving_state {
473 <#target_path as pgx::Aggregate>::in_memory_context(
474 fcinfo,
475 move |_context| <#target_path as pgx::Aggregate>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
476 )
477 }
478 });
479 Some(fn_name)
480 } else {
481 item_impl.items.push(parse_quote! {
482 fn moving_state_inverse(
483 _mstate: #type_moving_state,
484 _v: Self::Args,
485 _fcinfo: pgx::pg_sys::FunctionCallInfo,
486 ) -> #type_moving_state {
487 unimplemented!("Call to moving_state on an aggregate which does not support it.")
488 }
489 });
490 None
491 };
492
493 let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
494 let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
495 let fn_name = Ident::new(
496 &format!("{}_moving_finalize", snake_case_target_ident),
497 found.sig.ident.span(),
498 );
499 let pg_extern_attr = pg_extern_attr(found);
500 let maybe_comma: Option<syn::Token![,]> =
501 if direct_args_with_names.len() > 0 { Some(parse_quote! {,}) } else { None };
502
503 pg_externs.push(parse_quote! {
504 #[allow(non_snake_case, clippy::too_many_arguments)]
505 #pg_extern_attr
506 fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: pgx::pg_sys::FunctionCallInfo) -> #type_finalize {
507 <#target_path as pgx::Aggregate>::in_memory_context(
508 fcinfo,
509 move |_context| <#target_path as pgx::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
510 )
511 }
512 });
513 Some(fn_name)
514 } else {
515 item_impl.items.push(parse_quote! {
516 fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: pgx::pg_sys::FunctionCallInfo) -> Self::Finalize {
517 unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
518 }
519 });
520 None
521 };
522
523 Ok(Self {
524 item_impl,
525 pg_externs,
526 name,
527 type_args: type_args_value,
528 type_ordered_set_args: type_ordered_set_args_value,
529 type_moving_state: type_moving_state_value,
530 type_stype: type_stype,
531 const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
532 .map(|x| x.expr.clone()),
533 const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
534 .map(|x| x.expr.clone()),
535 const_moving_finalize_modify: get_impl_const_by_name(
536 &item_impl_snapshot,
537 "MOVING_FINALIZE_MODIFY",
538 )
539 .map(|x| x.expr.clone()),
540 const_initial_condition: get_impl_const_by_name(
541 &item_impl_snapshot,
542 "INITIAL_CONDITION",
543 )
544 .and_then(get_const_litstr),
545 const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
546 .and_then(get_const_litbool)
547 .unwrap_or(false),
548 const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
549 .and_then(get_const_litstr),
550 const_moving_intial_condition: get_impl_const_by_name(
551 &item_impl_snapshot,
552 "MOVING_INITIAL_CONDITION",
553 )
554 .and_then(get_const_litstr),
555 fn_state: fn_state_name,
556 fn_finalize: fn_finalize_name,
557 fn_combine: fn_combine_name,
558 fn_serial: fn_serial_name,
559 fn_deserial: fn_deserial_name,
560 fn_moving_state: fn_moving_state_name,
561 fn_moving_state_inverse: fn_moving_state_inverse_name,
562 fn_moving_finalize: fn_moving_finalize_name,
563 hypothetical: if let Some(value) =
564 get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
565 {
566 match &value.expr {
567 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
568 syn::Lit::Bool(lit) => lit.value,
569 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
570 },
571 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
572 }
573 } else {
574 false
575 },
576 to_sql_config,
577 })
578 }
579
580 fn entity_tokens(&self) -> ItemFn {
581 let target_path = get_target_path(&self.item_impl)
582 .expect("Expected constructed PgAggregate to have target path.");
583 let target_ident = get_target_ident(&target_path)
584 .expect("Expected constructed PgAggregate to have target ident.");
585 let snake_case_target_ident =
586 Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
587 let sql_graph_entity_fn_name = syn::Ident::new(
588 &format!("__pgx_internals_aggregate_{}", snake_case_target_ident),
589 target_ident.span(),
590 );
591
592 let name = &self.name;
593 let type_args_iter = &self.type_args.entity_tokens();
594 let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
595
596 let type_moving_state_entity_tokens =
597 self.type_moving_state.clone().map(|v| v.entity_tokens());
598 let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
599 let type_stype = self.type_stype.entity_tokens();
600 let const_ordered_set = self.const_ordered_set;
601 let const_parallel_iter = self.const_parallel.iter();
602 let const_finalize_modify_iter = self.const_finalize_modify.iter();
603 let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
604 let const_initial_condition_iter = self.const_initial_condition.iter();
605 let const_sort_operator_iter = self.const_sort_operator.iter();
606 let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
607 let hypothetical = self.hypothetical;
608 let fn_state = &self.fn_state;
609 let fn_finalize_iter = self.fn_finalize.iter();
610 let fn_combine_iter = self.fn_combine.iter();
611 let fn_serial_iter = self.fn_serial.iter();
612 let fn_deserial_iter = self.fn_deserial.iter();
613 let fn_moving_state_iter = self.fn_moving_state.iter();
614 let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
615 let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
616 let to_sql_config = &self.to_sql_config;
617
618 let entity_item_fn: ItemFn = parse_quote! {
619 #[no_mangle]
620 #[doc(hidden)]
621 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::utils::sql_entity_graph::SqlGraphEntity {
622 let submission = ::pgx::utils::sql_entity_graph::PgAggregateEntity {
623 full_path: ::core::any::type_name::<#target_ident>(),
624 module_path: module_path!(),
625 file: file!(),
626 line: line!(),
627 name: #name,
628 ordered_set: #const_ordered_set,
629 ty_id: ::core::any::TypeId::of::<#target_ident>(),
630 args: #type_args_iter,
631 direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
632 stype: #type_stype,
633 sfunc: stringify!(#fn_state),
634 combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
635 finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
636 finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
637 initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
638 serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
639 deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
640 msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
641 minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
642 mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
643 mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
644 mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
645 minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
646 sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
647 parallel: None #( .unwrap_or(#const_parallel_iter) )*,
648 hypothetical: #hypothetical,
649 to_sql_config: #to_sql_config,
650 };
651 ::pgx::utils::sql_entity_graph::SqlGraphEntity::Aggregate(submission)
652 }
653 };
654 entity_item_fn
655 }
656}
657
658impl Parse for PgAggregate {
659 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
660 Self::new(input.parse()?)
661 }
662}
663
664impl ToTokens for PgAggregate {
665 fn to_tokens(&self, tokens: &mut TokenStream2) {
666 let entity_fn = self.entity_tokens();
667 let impl_item = &self.item_impl;
668 let pg_externs = self.pg_externs.iter();
669 let inv = quote! {
670 #impl_item
671
672 #(#pg_externs)*
673
674 #entity_fn
675 };
676 tokens.append_all(inv);
677 }
678}
679
680fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
681 let last = path.segments.last().ok_or_else(|| {
682 syn::Error::new(
683 path.span(),
684 "`#[pg_aggregate]` only works with types whose path have a final segment.",
685 )
686 })?;
687 Ok(last.ident.clone())
688}
689
690fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
691 let target_ident = match &*item_impl.self_ty {
692 syn::Type::Path(ref type_path) => {
693 let last_segment = type_path.path.segments.last().ok_or_else(|| {
694 syn::Error::new(
695 type_path.span(),
696 "`#[pg_aggregate]` only works with types whose path have a final segment.",
697 )
698 })?;
699 if last_segment.ident.to_string() == "PgVarlena" {
700 match &last_segment.arguments {
701 syn::PathArguments::AngleBracketed(angled) => {
702 let first = angled.args.first().ok_or_else(|| syn::Error::new(
703 type_path.span(),
704 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
705 ))?;
706 match &first {
707 syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
708 _ => return Err(syn::Error::new(
709 type_path.span(),
710 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
711 )),
712 }
713 },
714 _ => return Err(syn::Error::new(
715 type_path.span(),
716 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
717 )),
718 }
719 } else {
720 type_path.path.clone()
721 }
722 }
723 something_else => {
724 return Err(syn::Error::new(
725 something_else.span(),
726 "`#[pg_aggregate]` only works with types.",
727 ))
728 }
729 };
730 Ok(target_ident)
731}
732
733fn pg_extern_attr(item: &ImplItemMethod) -> syn::Attribute {
734 let mut found = None;
735 for attr in item.attrs.iter() {
736 match attr.path.segments.last() {
737 Some(segment) if segment.ident.to_string() == "pgx" => {
738 found = Some(attr.tokens.clone());
739 break;
740 }
741 _ => (),
742 };
743 }
744 match found {
745 Some(args) => parse_quote! {
746 #[pg_extern #args]
747 },
748 None => parse_quote! {
749 #[pg_extern]
750 },
751 }
752}
753
754fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
755 let mut needle = None;
756 for impl_item in item_impl.items.iter() {
757 match impl_item {
758 syn::ImplItem::Type(impl_item_type) => {
759 let ident_string = impl_item_type.ident.to_string();
760 if ident_string == name {
761 needle = Some(impl_item_type);
762 }
763 }
764 _ => (),
765 }
766 }
767 needle
768}
769
770fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemMethod> {
771 let mut needle = None;
772 for impl_item in item_impl.items.iter() {
773 match impl_item {
774 syn::ImplItem::Method(impl_item_method) => {
775 let ident_string = impl_item_method.sig.ident.to_string();
776 if ident_string == name {
777 needle = Some(impl_item_method);
778 }
779 }
780 _ => (),
781 }
782 }
783 needle
784}
785
786fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
787 let mut needle = None;
788 for impl_item in item_impl.items.iter() {
789 match impl_item {
790 syn::ImplItem::Const(impl_item_const) => {
791 let ident_string = impl_item_const.ident.to_string();
792 if ident_string == name {
793 needle = Some(impl_item_const);
794 }
795 }
796 _ => (),
797 }
798 }
799 needle
800}
801
802fn get_const_litbool<'a>(item: &'a ImplItemConst) -> Option<bool> {
803 match &item.expr {
804 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
805 syn::Lit::Bool(lit) => Some(lit.value()),
806 _ => None,
807 },
808 _ => None,
809 }
810}
811
812fn get_const_litstr<'a>(item: &'a ImplItemConst) -> Option<String> {
813 match &item.expr {
814 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
815 syn::Lit::Str(lit) => Some(lit.value()),
816 _ => None,
817 },
818 syn::Expr::Call(expr_call) => match &*expr_call.func {
819 syn::Expr::Path(expr_path) => {
820 if expr_path.path.segments.last()?.ident.to_string() == "Some" {
821 match expr_call.args.first()? {
822 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
823 syn::Lit::Str(lit) => Some(lit.value()),
824 _ => None,
825 },
826 _ => None,
827 }
828 } else {
829 None
830 }
831 }
832 _ => None,
833 },
834 _ => panic!("Got {:?}", item.expr),
835 }
836}
837
838fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
839 match ty {
840 Type::Path(ref mut ty_path) => {
841 for segment in ty_path.path.segments.iter_mut() {
842 if segment.ident.to_string() == "Self" {
843 segment.ident = target.clone()
844 }
845 use syn::{GenericArgument, PathArguments};
846 match segment.arguments {
847 PathArguments::AngleBracketed(ref mut angle_args) => {
848 for arg in angle_args.args.iter_mut() {
849 match arg {
850 GenericArgument::Type(inner_ty) => {
851 remap_self_to_target(inner_ty, target)
852 }
853 _ => (),
854 }
855 }
856 }
857 PathArguments::Parenthesized(_) => (),
858 PathArguments::None => (),
859 }
860 }
861 }
862 _ => (),
863 }
864}
865
866fn get_pgx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
867 match &ty {
868 syn::Type::Macro(ty_macro) => {
869 let mut found_pgx = false;
870 let mut found_attr = false;
871 for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
873 match segment.ident.to_string().as_str() {
874 "pgx" if idx == 0 => found_pgx = true,
875 attr if attr == attr_name.as_ref() => found_attr = true,
876 _ => (),
877 }
878 }
879 if (ty_macro.mac.path.segments.len() == 1 && found_attr) || (found_pgx && found_attr) {
880 Some(ty_macro.mac.tokens.clone())
881 } else {
882 None
883 }
884 }
885 _ => None,
886 }
887}
888
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use syn::{parse_quote, ItemImpl};

    /// Only the required items: a single `state` wrapper is generated.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                const NAME: &'static str = "DEMO";

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_ok());
        let agg = agg.unwrap();
        assert_eq!(agg.pg_externs.len(), 1);
        let extern_fn = &agg.pg_externs[0];
        assert_eq!(extern_fn.sig.ident.to_string(), "demo_agg_state");
        let _ = agg.entity_tokens();
        Ok(())
    }

    /// Every optional function is provided, so all eight wrappers are generated.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                // NOTE(review): the macro reads `OrderedSetArgs`, not `OrderBy`, so this item
                // is carried through untouched and does not enable direct args.
                type OrderBy = i32;
                type MovingState = i32;

                const NAME: &'static str = "DEMO";

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_ok());
        let agg = agg.unwrap();
        assert_eq!(agg.pg_externs.len(), 8);
        let extern_fn = &agg.pg_externs[0];
        assert_eq!(extern_fn.sig.ident.to_string(), "demo_agg_state");
        let _ = agg.entity_tokens();
        Ok(())
    }

    /// An empty impl lacks the required `Args` type and `state` function.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_err());
        Ok(())
    }

    /// `#[pg_aggregate]` only accepts impls of the `Aggregate` trait.
    #[test]
    fn agg_wrong_trait() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl NotAggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;

                fn state(current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_err());
        Ok(())
    }

    /// `Args` alone is not enough: the `state` function is mandatory.
    #[test]
    fn agg_missing_state() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = i32;
                type Args = i32;
            }
        };
        let agg = PgAggregate::new(tokens);
        assert!(agg.is_err());
        Ok(())
    }
}