1use crate::aggregate::options::{FinalizeModify, ParallelOption};
19use crate::metadata::SqlMapping;
20use crate::pgx_sql::PgxSql;
21use crate::to_sql::entity::ToSqlConfigEntity;
22use crate::to_sql::ToSql;
23use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
24use core::any::TypeId;
25use eyre::{eyre, WrapErr};
26
/// A type used by an aggregate (an argument, a direct argument, or the state type),
/// together with an optional SQL-level name.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct AggregateTypeEntity {
    /// The Rust type as seen by the SQL generator: its SQL mapping (`metadata.argument_sql`),
    /// `TypeId`, Rust path, variadic flag, etc.
    pub used_ty: UsedTypeEntity,
    /// Optional name; when set on an argument it is emitted double-quoted before the SQL type.
    pub name: Option<&'static str>,
}
32
/// Everything needed to emit a `CREATE AGGREGATE` statement for one Rust aggregate
/// implementation registered in the SQL entity graph.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PgAggregateEntity {
    /// Fully qualified Rust path; used as the graph identifier and in generated SQL comments.
    pub full_path: &'static str,
    /// NOTE(review): not read in this file; presumably the Rust module path of the declaration.
    pub module_path: &'static str,
    /// Source file of the declaration, emitted in the `-- {file}:{line}` SQL header comment.
    pub file: &'static str,
    /// Source line of the declaration, emitted in the `-- {file}:{line}` SQL header comment.
    pub line: u32,
    /// NOTE(review): not read in this file; presumably the `TypeId` of the implementing type.
    pub ty_id: TypeId,

    /// SQL name of the aggregate (`CREATE AGGREGATE {schema}{name}`).
    pub name: &'static str,

    /// When `true`, `ORDER BY` is emitted between `direct_args` and `args` in the signature.
    pub ordered_set: bool,

    /// Aggregated arguments (emitted after `ORDER BY` when `ordered_set` is set).
    pub args: Vec<AggregateTypeEntity>,

    /// Direct arguments, emitted before the `ORDER BY` (if any) in the signature.
    pub direct_args: Option<Vec<AggregateTypeEntity>>,

    /// State value type, rendered as `STYPE`.
    pub stype: AggregateTypeEntity,

    /// State transition function name, rendered schema-qualified as `SFUNC`.
    pub sfunc: &'static str,

    /// Optional final function name, rendered schema-qualified as `FINALFUNC`.
    pub finalfunc: Option<&'static str>,

    /// Optional `FINALFUNC_MODIFY` setting.
    pub finalfunc_modify: Option<FinalizeModify>,

    /// Optional combine function name, rendered schema-qualified as `COMBINEFUNC`.
    pub combinefunc: Option<&'static str>,

    /// Optional serialization function name, rendered schema-qualified as `SERIALFUNC`.
    pub serialfunc: Option<&'static str>,

    /// Optional deserialization function name, rendered as `DESERIALFUNC`.
    pub deserialfunc: Option<&'static str>,

    /// Optional initial state, rendered verbatim inside single quotes as `INITCOND`
    /// (no quote escaping is applied).
    pub initcond: Option<&'static str>,

    /// Optional moving-aggregate transition function, rendered schema-qualified as `MSFUNC`.
    pub msfunc: Option<&'static str>,

    /// Optional moving-aggregate inverse transition function, rendered as `MINVFUNC`.
    pub minvfunc: Option<&'static str>,

    /// Optional moving-aggregate state type, rendered (without a schema prefix) as `MSTYPE`.
    pub mstype: Option<UsedTypeEntity>,

    /// Optional moving-aggregate final function, rendered schema-qualified as `MFINALFUNC`.
    pub mfinalfunc: Option<&'static str>,

    /// Optional `MFINALFUNC_MODIFY` setting.
    pub mfinalfunc_modify: Option<FinalizeModify>,

    /// Optional moving-aggregate initial state, rendered inside single quotes as `MINITCOND`.
    pub minitcond: Option<&'static str>,

    /// Optional sort operator, rendered double-quoted (not schema-qualified) as `SORTOP`.
    pub sortop: Option<&'static str>,

    /// Optional `PARALLEL` setting.
    pub parallel: Option<ParallelOption>,

    /// When `true`, the `HYPOTHETICAL` flag is emitted.
    pub hypothetical: bool,
    /// NOTE(review): not read by this file's `to_sql`; presumably controls whether/how SQL
    /// is generated for this entity elsewhere.
    pub to_sql_config: ToSqlConfigEntity,
}
148
149impl From<PgAggregateEntity> for SqlGraphEntity {
150 fn from(val: PgAggregateEntity) -> Self {
151 SqlGraphEntity::Aggregate(val)
152 }
153}
154
155impl SqlGraphIdentifier for PgAggregateEntity {
156 fn dot_identifier(&self) -> String {
157 format!("aggregate {}", self.full_path)
158 }
159 fn rust_identifier(&self) -> String {
160 self.full_path.to_string()
161 }
162 fn file(&self) -> Option<&'static str> {
163 Some(self.file)
164 }
165 fn line(&self) -> Option<u32> {
166 Some(self.line)
167 }
168}
169
170impl ToSql for PgAggregateEntity {
171 fn to_sql(&self, context: &PgxSql) -> eyre::Result<String> {
172 let self_index = context.aggregates[self];
173 let mut optional_attributes = Vec::new();
174 let schema = context.schema_prefix_for(&self_index);
175
176 if let Some(value) = self.finalfunc {
177 optional_attributes.push((
178 format!("\tFINALFUNC = {}\"{}\"", schema, value),
179 format!("/* {}::final */", self.full_path),
180 ));
181 }
182 if let Some(value) = self.finalfunc_modify {
183 optional_attributes.push((
184 format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
185 format!("/* {}::FINALIZE_MODIFY */", self.full_path),
186 ));
187 }
188 if let Some(value) = self.combinefunc {
189 optional_attributes.push((
190 format!("\tCOMBINEFUNC = {}\"{}\"", schema, value),
191 format!("/* {}::combine */", self.full_path),
192 ));
193 }
194 if let Some(value) = self.serialfunc {
195 optional_attributes.push((
196 format!("\tSERIALFUNC = {}\"{}\"", schema, value),
197 format!("/* {}::serial */", self.full_path),
198 ));
199 }
200 if let Some(value) = self.deserialfunc {
201 optional_attributes.push((
202 format!("\tDESERIALFUNC ={} \"{}\"", schema, value),
203 format!("/* {}::deserial */", self.full_path),
204 ));
205 }
206 if let Some(value) = self.initcond {
207 optional_attributes.push((
208 format!("\tINITCOND = '{}'", value),
209 format!("/* {}::INITIAL_CONDITION */", self.full_path),
210 ));
211 }
212 if let Some(value) = self.msfunc {
213 optional_attributes.push((
214 format!("\tMSFUNC = {}\"{}\"", schema, value),
215 format!("/* {}::moving_state */", self.full_path),
216 ));
217 }
218 if let Some(value) = self.minvfunc {
219 optional_attributes.push((
220 format!("\tMINVFUNC = {}\"{}\"", schema, value),
221 format!("/* {}::moving_state_inverse */", self.full_path),
222 ));
223 }
224 if let Some(value) = self.mfinalfunc {
225 optional_attributes.push((
226 format!("\tMFINALFUNC = {}\"{}\"", schema, value),
227 format!("/* {}::moving_state_finalize */", self.full_path),
228 ));
229 }
230 if let Some(value) = self.mfinalfunc_modify {
231 optional_attributes.push((
232 format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
233 format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
234 ));
235 }
236 if let Some(value) = self.minitcond {
237 optional_attributes.push((
238 format!("\tMINITCOND = '{}'", value),
239 format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
240 ));
241 }
242 if let Some(value) = self.sortop {
243 optional_attributes.push((
244 format!("\tSORTOP = \"{}\"", value),
245 format!("/* {}::SORT_OPERATOR */", self.full_path),
246 ));
247 }
248 if let Some(value) = self.parallel {
249 optional_attributes.push((
250 format!("\tPARALLEL = {}", value.to_sql(context)?),
251 format!("/* {}::PARALLEL */", self.full_path),
252 ));
253 }
254 if self.hypothetical {
255 optional_attributes.push((
256 String::from("\tHYPOTHETICAL"),
257 format!("/* {}::hypothetical */", self.full_path),
258 ))
259 }
260
261 let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
262 match used_ty.metadata.argument_sql {
263 Ok(SqlMapping::As(ref argument_sql)) => Ok(argument_sql.to_string()),
264 Ok(SqlMapping::Composite { array_brackets }) => used_ty
265 .composite_type
266 .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
267 .ok_or_else(|| {
268 eyre!("Macro expansion time suggested a composite_type!() in return")
269 }),
270 Ok(SqlMapping::Source { array_brackets }) => {
271 let sql = context
272 .source_only_to_sql_type(used_ty.ty_source)
273 .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
274 .ok_or_else(|| {
275 eyre!("Macro expansion time suggested a source only mapping in return")
276 })?;
277 Ok(sql)
278 }
279 Ok(SqlMapping::Skip) => {
280 Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
281 }
282 Err(err) => match context.source_only_to_sql_type(used_ty.ty_source) {
283 Some(source_only_mapping) => Ok(source_only_mapping.to_string()),
284 None => return Err(err).wrap_err("While mapping argument"),
285 },
286 }
287 };
288
289 let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
290 let mut stype_schema = String::from("");
291 for (ty_item, ty_index) in context.types.iter() {
292 if ty_item.id_matches(&self.stype.used_ty.ty_id) {
293 stype_schema = context.schema_prefix_for(ty_index);
294 break;
295 }
296 }
297 if String::is_empty(&stype_schema) {
298 for (ty_item, ty_index) in context.enums.iter() {
299 if ty_item.id_matches(&self.stype.used_ty.ty_id) {
300 stype_schema = context.schema_prefix_for(ty_index);
301 break;
302 }
303 }
304 }
305
306 if let Some(value) = &self.mstype {
307 let mstype_sql = map_ty(&value).wrap_err("Mapping moving state type")?;
308 optional_attributes.push((
309 format!("\tMSTYPE = {}", mstype_sql),
310 format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
311 ));
312 }
313
314 let mut optional_attributes_string = String::new();
315 for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
316 let optional_attribute_string = format!(
317 "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
318 optional_attribute = optional_attribute,
319 maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
320 comment = comment,
321 maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
322 );
323 optional_attributes_string += &optional_attribute_string;
324 }
325
326 let sql = format!(
327 "\n\
328 -- {file}:{line}\n\
329 -- {full_path}\n\
330 CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
331 (\n\
332 \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
333 \tSTYPE = {stype_schema}{stype}{maybe_comma_after_stype} /* {stype_full_path} */\
334 {optional_attributes}\
335 );\
336 ",
337 schema = schema,
338 name = self.name,
339 full_path = self.full_path,
340 file = self.file,
341 line = self.line,
342 sfunc = self.sfunc,
343 stype_schema = stype_schema,
344 stype = stype_sql,
345 stype_full_path = self.stype.used_ty.full_path,
346 maybe_comma_after_stype = if optional_attributes.len() == 0 { "" } else { "," },
347 args = {
348 let mut args = Vec::new();
349 for (idx, arg) in self.args.iter().enumerate() {
350 let graph_index = context
351 .graph
352 .neighbors_undirected(self_index)
353 .find(|neighbor| match &context.graph[*neighbor] {
354 SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
355 SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
356 SqlGraphEntity::BuiltinType(defined) => {
357 defined == &arg.used_ty.full_path
358 }
359 _ => false,
360 })
361 .ok_or_else(|| {
362 eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
363 })?;
364 let needs_comma = idx < (self.args.len() - 1);
365 let buf = format!("\
366 \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
367 ",
368 schema_prefix = context.schema_prefix_for(&graph_index),
369 sql_type = match arg.used_ty.metadata.argument_sql {
371 Ok(SqlMapping::As(ref argument_sql)) => {
372 argument_sql.to_string()
373 }
374 Ok(SqlMapping::Composite {
375 array_brackets,
376 }) => {
377 arg.used_ty
378 .composite_type
379 .map(|v| {
380 if array_brackets {
381 format!("{v}[]")
382 } else {
383 format!("{v}")
384 }
385 })
386 .ok_or_else(|| {
387 eyre!(
388 "Macro expansion time suggested a composite_type!() in return"
389 )
390 })?
391 }
392 Ok(SqlMapping::Source {
393 array_brackets,
394 }) => {
395 let sql = context
396 .source_only_to_sql_type(arg.used_ty.ty_source)
397 .map(|v| {
398 if array_brackets {
399 format!("{v}[]")
400 } else {
401 format!("{v}")
402 }
403 })
404 .ok_or_else(|| {
405 eyre!(
406 "Macro expansion time suggested a source only mapping in return"
407 )
408 })?;
409 sql
410 }
411 Ok(SqlMapping::Skip) => return Err(eyre!("Got a skipped SQL translatable type in aggregate args, this is not permitted")),
412 Err(err) => {
413 match context.source_only_to_sql_type(arg.used_ty.ty_source) {
414 Some(source_only_mapping) => {
415 source_only_mapping.to_string()
416 }
417 None => return Err(err).wrap_err("While mapping argument"),
418 }
419 }
420 },
421 variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
422 maybe_comma = if needs_comma { ", " } else { " " },
423 full_path = arg.used_ty.full_path,
424 name = if let Some(name) = arg.name {
425 format!(r#""{}" "#, name)
426 } else { "".to_string() },
427 );
428 args.push(buf);
429 }
430 "\n".to_string() + &args.join("\n") + "\n"
431 },
432 direct_args = if let Some(direct_args) = &self.direct_args {
433 let mut args = Vec::new();
434 for (idx, arg) in direct_args.iter().enumerate() {
435 let graph_index = context
436 .graph
437 .neighbors_undirected(self_index)
438 .find(|neighbor| match &context.graph[*neighbor] {
439 SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
440 SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
441 SqlGraphEntity::BuiltinType(defined) => {
442 defined == &arg.used_ty.full_path
443 }
444 _ => false,
445 })
446 .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
447 let needs_comma = idx < (direct_args.len() - 1);
448 let buf = format!(
449 "\
450 \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
451 ",
452 schema_prefix = context.schema_prefix_for(&graph_index),
453 sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
455 maybe_name = if let Some(name) = arg.name {
456 "\"".to_string() + name + "\" "
457 } else {
458 "".to_string()
459 },
460 maybe_comma = if needs_comma { ", " } else { " " },
461 full_path = arg.used_ty.full_path,
462 );
463 args.push(buf);
464 }
465 "\n".to_string() + &args.join("\n,") + "\n"
466 } else {
467 String::default()
468 },
469 maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
470 optional_attributes = String::from("\n")
471 + &optional_attributes_string
472 + if optional_attributes.len() == 0 { "" } else { "\n" },
473 );
474 Ok(sql)
475 }
476}