1use crate::aggregate::options::{FinalizeModify, ParallelOption};
20use crate::fmt;
21use crate::metadata::{SqlArrayMapping, SqlMapping};
22use crate::pgrx_sql::PgrxSql;
23use crate::to_sql::ToSql;
24use crate::to_sql::entity::ToSqlConfigEntity;
25use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
26use eyre::{WrapErr, eyre};
27use petgraph::graph::NodeIndex;
28
/// One typed slot of an aggregate: an aggregated argument, a direct argument, or the state type.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct AggregateTypeEntity<'a> {
    /// The Rust type occupying this slot; its `metadata.argument_sql` mapping decides the SQL type.
    pub used_ty: UsedTypeEntity<'a>,
    /// Optional argument name. When present it is emitted as a double-quoted identifier before
    /// the argument type, and it is used as the slot label for schema-prefix lookups.
    pub name: Option<&'a str>,
}
34
/// Metadata collected for one Rust-defined aggregate, rendered as a `CREATE AGGREGATE` statement.
// NOTE: the derived `PartialOrd`/`Ord` compare fields in declaration order, so field order matters.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PgAggregateEntity<'a> {
    /// Fully-qualified Rust path of the aggregate; used in generated SQL comments and as its
    /// Rust identifier in the graph.
    pub full_path: &'a str,
    /// Rust module path of the aggregate (not read in this part of the file).
    pub module_path: &'a str,
    /// Source file of the declaration; emitted in a leading `-- file:line` SQL comment.
    pub file: &'a str,
    /// Source line of the declaration; emitted in a leading `-- file:line` SQL comment.
    pub line: u32,

    /// SQL name of the aggregate (`CREATE AGGREGATE <schema><name>`).
    pub name: &'a str,

    /// Whether this is an ordered-set aggregate. When `true`, `ORDER BY` separates the direct
    /// arguments from the aggregated arguments in the generated signature.
    pub ordered_set: bool,

    /// Aggregated arguments (those after `ORDER BY` for ordered-set aggregates).
    pub args: Vec<AggregateTypeEntity<'a>>,

    /// Direct arguments, rendered before the `ORDER BY` separator.
    pub direct_args: Option<Vec<AggregateTypeEntity<'a>>>,

    /// State value type (`STYPE`).
    pub stype: AggregateTypeEntity<'a>,

    /// State transition function name (`SFUNC`), emitted quoted with the aggregate's schema prefix.
    pub sfunc: &'a str,

    /// Final function name (`FINALFUNC`).
    pub finalfunc: Option<&'a str>,

    /// `FINALFUNC_MODIFY` setting.
    pub finalfunc_modify: Option<FinalizeModify>,

    /// Combine function name (`COMBINEFUNC`).
    pub combinefunc: Option<&'a str>,

    /// Serialization function name (`SERIALFUNC`).
    pub serialfunc: Option<&'a str>,

    /// Deserialization function name (`DESERIALFUNC`).
    pub deserialfunc: Option<&'a str>,

    /// Initial state value, emitted verbatim between single quotes (`INITCOND`).
    pub initcond: Option<&'a str>,

    /// Moving-aggregate forward transition function name (`MSFUNC`).
    pub msfunc: Option<&'a str>,

    /// Moving-aggregate inverse transition function name (`MINVFUNC`).
    pub minvfunc: Option<&'a str>,

    /// Moving-aggregate state type (`MSTYPE`).
    pub mstype: Option<UsedTypeEntity<'a>>,

    /// Moving-aggregate final function name (`MFINALFUNC`).
    pub mfinalfunc: Option<&'a str>,

    /// `MFINALFUNC_MODIFY` setting.
    pub mfinalfunc_modify: Option<FinalizeModify>,

    /// Moving-aggregate initial state, emitted verbatim between single quotes (`MINITCOND`).
    pub minitcond: Option<&'a str>,

    /// Sort operator (`SORTOP`), emitted double-quoted.
    pub sortop: Option<&'a str>,

    /// `PARALLEL` setting.
    pub parallel: Option<ParallelOption>,

    /// Whether to emit the `HYPOTHETICAL` flag.
    pub hypothetical: bool,
    /// SQL generation configuration for this entity (not read in this part of the file).
    pub to_sql_config: ToSqlConfigEntity<'a>,
}
149
150impl<'a> From<PgAggregateEntity<'a>> for SqlGraphEntity<'a> {
151 fn from(val: PgAggregateEntity<'a>) -> Self {
152 SqlGraphEntity::Aggregate(val)
153 }
154}
155
156impl SqlGraphIdentifier for PgAggregateEntity<'_> {
157 fn dot_identifier(&self) -> String {
158 format!("aggregate {}", self.full_path)
159 }
160 fn rust_identifier(&self) -> String {
161 self.full_path.to_string()
162 }
163 fn file(&self) -> Option<&str> {
164 Some(self.file)
165 }
166 fn line(&self) -> Option<u32> {
167 Some(self.line)
168 }
169}
170
171fn aggregate_sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyre::Result<String> {
172 match mapping {
173 SqlMapping::As(sql) => Ok(sql.clone()),
174 SqlMapping::Composite => composite_type
175 .map(ToString::to_string)
176 .ok_or_else(|| eyre!("Composite mapping requires composite_type")),
177 SqlMapping::Array(SqlArrayMapping::As(sql)) => Ok(fmt::with_array_brackets(sql.clone(), 1)),
178 SqlMapping::Array(SqlArrayMapping::Composite) => composite_type
179 .map(ToString::to_string)
180 .map(|sql| fmt::with_array_brackets(sql, 1))
181 .ok_or_else(|| eyre!("Composite mapping requires composite_type")),
182 SqlMapping::Skip => {
183 Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
184 }
185 }
186}
187
188pub(crate) fn render_aggregate_argtypes(
194 context: &PgrxSql,
195 owner: NodeIndex,
196 a: &PgAggregateEntity,
197) -> eyre::Result<String> {
198 let render_slot = |arg: &AggregateTypeEntity| -> eyre::Result<String> {
199 let slot = arg.name.unwrap_or("aggregate argument");
200 let prefix = context.schema_prefix_for_used_type(&owner, slot, &arg.used_ty)?;
201 let sql = match arg.used_ty.metadata.argument_sql {
202 Ok(ref mapping) => aggregate_sql_type(mapping, arg.used_ty.composite_type)?,
203 Err(err) => return Err(err.into()),
204 };
205 let variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" };
206 Ok(format!("{variadic}{prefix}{sql}"))
207 };
208
209 let args = a.args.iter().map(render_slot).collect::<eyre::Result<Vec<_>>>()?.join(", ");
210 let direct = a.direct_args.as_deref().unwrap_or(&[]);
211
212 if a.ordered_set {
213 let direct_rendered =
214 direct.iter().map(render_slot).collect::<eyre::Result<Vec<_>>>()?.join(", ");
215 Ok(format!("({direct_rendered} ORDER BY {args})"))
216 } else {
217 Ok(format!("({args})"))
218 }
219}
220
221impl ToSql for PgAggregateEntity<'_> {
222 fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
223 let self_index = context.aggregates[self];
224 let mut optional_attributes = Vec::new();
225 let schema = context.schema_prefix_for(&self_index);
226
227 if let Some(value) = self.finalfunc {
228 optional_attributes.push((
229 format!("\tFINALFUNC = {schema}\"{value}\""),
230 format!("/* {}::final */", self.full_path),
231 ));
232 }
233 if let Some(value) = self.finalfunc_modify {
234 optional_attributes.push((
235 format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
236 format!("/* {}::FINALIZE_MODIFY */", self.full_path),
237 ));
238 }
239 if let Some(value) = self.combinefunc {
240 optional_attributes.push((
241 format!("\tCOMBINEFUNC = {schema}\"{value}\""),
242 format!("/* {}::combine */", self.full_path),
243 ));
244 }
245 if let Some(value) = self.serialfunc {
246 optional_attributes.push((
247 format!("\tSERIALFUNC = {schema}\"{value}\""),
248 format!("/* {}::serial */", self.full_path),
249 ));
250 }
251 if let Some(value) = self.deserialfunc {
252 optional_attributes.push((
253 format!("\tDESERIALFUNC ={schema} \"{value}\""),
254 format!("/* {}::deserial */", self.full_path),
255 ));
256 }
257 if let Some(value) = self.initcond {
258 optional_attributes.push((
259 format!("\tINITCOND = '{value}'"),
260 format!("/* {}::INITIAL_CONDITION */", self.full_path),
261 ));
262 }
263 if let Some(value) = self.msfunc {
264 optional_attributes.push((
265 format!("\tMSFUNC = {schema}\"{value}\""),
266 format!("/* {}::moving_state */", self.full_path),
267 ));
268 }
269 if let Some(value) = self.minvfunc {
270 optional_attributes.push((
271 format!("\tMINVFUNC = {schema}\"{value}\""),
272 format!("/* {}::moving_state_inverse */", self.full_path),
273 ));
274 }
275 if let Some(value) = self.mfinalfunc {
276 optional_attributes.push((
277 format!("\tMFINALFUNC = {schema}\"{value}\""),
278 format!("/* {}::moving_state_finalize */", self.full_path),
279 ));
280 }
281 if let Some(value) = self.mfinalfunc_modify {
282 optional_attributes.push((
283 format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
284 format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
285 ));
286 }
287 if let Some(value) = self.minitcond {
288 optional_attributes.push((
289 format!("\tMINITCOND = '{value}'"),
290 format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
291 ));
292 }
293 if let Some(value) = self.sortop {
294 optional_attributes.push((
295 format!("\tSORTOP = \"{value}\""),
296 format!("/* {}::SORT_OPERATOR */", self.full_path),
297 ));
298 }
299 if let Some(value) = self.parallel {
300 optional_attributes.push((
301 format!("\tPARALLEL = {}", value.to_sql(context)?),
302 format!("/* {}::PARALLEL */", self.full_path),
303 ));
304 }
305 if self.hypothetical {
306 optional_attributes.push((
307 String::from("\tHYPOTHETICAL"),
308 format!("/* {}::hypothetical */", self.full_path),
309 ))
310 }
311
312 let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
313 match used_ty.metadata.argument_sql {
314 Ok(ref mapping) => aggregate_sql_type(mapping, used_ty.composite_type),
315 Err(err) => Err(err).wrap_err("While mapping argument"),
316 }
317 };
318
319 let sql_type_for_slot = |slot: &str,
320 used_ty: &UsedTypeEntity|
321 -> eyre::Result<(String, String)> {
322 let sql = map_ty(used_ty).wrap_err_with(|| format!("Mapping {slot}"))?;
323 let schema_prefix = context.schema_prefix_for_used_type(&self_index, slot, used_ty)?;
324 Ok((schema_prefix, sql))
325 };
326 let (stype_schema, stype_sql) = sql_type_for_slot("STYPE", &self.stype.used_ty)?;
327
328 if let Some(value) = &self.mstype {
329 let (mstype_schema, mstype_sql) = sql_type_for_slot("MSTYPE", value)?;
330 optional_attributes.push((
331 format!("\tMSTYPE = {mstype_schema}{mstype_sql}"),
332 format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
333 ));
334 }
335
336 let mut optional_attributes_string = String::new();
337 for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
338 let optional_attribute_string = format!(
339 "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
340 optional_attribute = optional_attribute,
341 maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
342 comment = comment,
343 maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
344 );
345 optional_attributes_string += &optional_attribute_string;
346 }
347
348 let args = {
349 let mut args = Vec::new();
350 for (idx, arg) in self.args.iter().enumerate() {
351 let needs_comma = idx < (self.args.len() - 1);
352 let schema_prefix = context.schema_prefix_for_used_type(
353 &self_index,
354 arg.name.unwrap_or("aggregate argument"),
355 &arg.used_ty,
356 )?;
357 let buf = format!(
358 "\
359 \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
360 ",
361 schema_prefix = schema_prefix,
362 sql_type = match arg.used_ty.metadata.argument_sql {
364 Ok(ref mapping) => aggregate_sql_type(mapping, arg.used_ty.composite_type)?,
365 Err(err) => return Err(err).wrap_err("While mapping argument"),
366 },
367 variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
368 maybe_comma = if needs_comma { ", " } else { " " },
369 full_path = arg.used_ty.full_path,
370 name = if let Some(name) = arg.name {
371 format!(r#""{name}" "#)
372 } else {
373 "".to_string()
374 },
375 );
376 args.push(buf);
377 }
378 "\n".to_string() + &args.join("\n") + "\n"
379 };
380 let direct_args = if let Some(direct_args) = &self.direct_args {
381 let mut args = Vec::new();
382 for (idx, arg) in direct_args.iter().enumerate() {
383 let schema_prefix = context.schema_prefix_for_used_type(
384 &self_index,
385 arg.name.unwrap_or("aggregate direct argument"),
386 &arg.used_ty,
387 )?;
388 let needs_comma = idx < (direct_args.len() - 1);
389 let buf = format!(
390 "\
391 \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
392 ",
393 schema_prefix = schema_prefix,
394 sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
396 maybe_name = if let Some(name) = arg.name {
397 "\"".to_string() + name + "\" "
398 } else {
399 "".to_string()
400 },
401 maybe_comma = if needs_comma { ", " } else { " " },
402 full_path = arg.used_ty.full_path,
403 );
404 args.push(buf);
405 }
406 "\n".to_string() + &args.join("\n") + "\n"
407 } else {
408 String::default()
409 };
410
411 let PgAggregateEntity { name, full_path, file, line, sfunc, .. } = self;
412
413 let sql = format!(
414 "\n\
415 -- {file}:{line}\n\
416 -- {full_path}\n\
417 CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
418 (\n\
419 \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
420 \tSTYPE = {stype_schema}{stype_sql}{maybe_comma_after_stype} /* {stype_full_path} */\
421 {optional_attributes}\
422 );\
423 ",
424 stype_full_path = self.stype.used_ty.full_path,
425 maybe_comma_after_stype = if optional_attributes.is_empty() { "" } else { "," },
426 maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
427 optional_attributes = String::from("\n")
428 + &optional_attributes_string
429 + if optional_attributes.is_empty() { "" } else { "\n" },
430 );
431 Ok(sql)
432 }
433}