1use crate::sql_entity_graph::aggregate::options::{FinalizeModify, ParallelOption};
19use crate::sql_entity_graph::metadata::SqlMapping;
20use crate::sql_entity_graph::pgx_sql::PgxSql;
21use crate::sql_entity_graph::to_sql::entity::ToSqlConfigEntity;
22use crate::sql_entity_graph::to_sql::ToSql;
23use crate::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
24use core::any::TypeId;
25use core::cmp::Ordering;
26use eyre::{eyre, WrapErr};
27
/// A single type slot of an aggregate (an argument, a direct argument, or the state type),
/// optionally paired with a name.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct AggregateTypeEntity {
    /// The Rust-side type description that gets mapped to a SQL type when rendering.
    pub used_ty: UsedTypeEntity,
    /// Optional name; when present it is emitted as a double-quoted identifier
    /// in front of the SQL type in the argument list.
    pub name: Option<&'static str>,
}
33
/// Everything needed to render one `CREATE AGGREGATE` statement.
///
/// Each optional field maps to one optional attribute of the generated statement;
/// a `None` simply omits that attribute.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct PgAggregateEntity {
    /// Full Rust path of the aggregate; used as the graph/Rust identifier and in the
    /// explanatory comments of the generated SQL.
    pub full_path: &'static str,
    /// Rust module path (not read by the code in this section; presumably the module
    /// of the defining item -- confirm against the producer).
    pub module_path: &'static str,
    /// Source file, emitted in the `-- file:line` header of the generated SQL.
    pub file: &'static str,
    /// Source line, emitted in the `-- file:line` header of the generated SQL.
    pub line: u32,
    /// `TypeId` associated with the aggregate (not read by the code in this section;
    /// presumably identifies the implementing Rust type -- confirm).
    pub ty_id: TypeId,

    /// SQL name used in `CREATE AGGREGATE <schema><name>`.
    pub name: &'static str,

    /// When `true`, `ORDER BY` is emitted between the direct arguments and the
    /// aggregated arguments.
    pub ordered_set: bool,

    /// The aggregated arguments, rendered inside the parentheses after the name.
    pub args: Vec<AggregateTypeEntity>,

    /// Direct arguments, rendered before `ORDER BY` / the aggregated arguments.
    pub direct_args: Option<Vec<AggregateTypeEntity>>,

    /// State type, rendered as `STYPE`.
    pub stype: AggregateTypeEntity,

    /// State transition function name, rendered as `SFUNC`.
    pub sfunc: &'static str,

    /// Rendered as `FINALFUNC` when set.
    pub finalfunc: Option<&'static str>,

    /// Rendered as `FINALFUNC_MODIFY` when set.
    pub finalfunc_modify: Option<FinalizeModify>,

    /// Rendered as `COMBINEFUNC` when set.
    pub combinefunc: Option<&'static str>,

    /// Rendered as `SERIALFUNC` when set.
    pub serialfunc: Option<&'static str>,

    /// Rendered as `DESERIALFUNC` when set.
    pub deserialfunc: Option<&'static str>,

    /// Rendered as `INITCOND = '<value>'` when set; the value is inserted verbatim
    /// between single quotes (no escaping is applied).
    pub initcond: Option<&'static str>,

    /// Rendered as `MSFUNC` when set.
    pub msfunc: Option<&'static str>,

    /// Rendered as `MINVFUNC` when set.
    pub minvfunc: Option<&'static str>,

    /// Moving-state type, mapped to SQL and rendered as `MSTYPE` when set.
    pub mstype: Option<UsedTypeEntity>,

    /// Rendered as `MFINALFUNC` when set.
    pub mfinalfunc: Option<&'static str>,

    /// Rendered as `MFINALFUNC_MODIFY` when set.
    pub mfinalfunc_modify: Option<FinalizeModify>,

    /// Rendered as `MINITCOND = '<value>'` when set; the value is inserted verbatim
    /// between single quotes (no escaping is applied).
    pub minitcond: Option<&'static str>,

    /// Rendered as `SORTOP = "<value>"` when set.
    pub sortop: Option<&'static str>,

    /// Rendered as `PARALLEL` when set.
    pub parallel: Option<ParallelOption>,

    /// When `true`, the `HYPOTHETICAL` attribute is emitted.
    pub hypothetical: bool,
    /// SQL generation configuration (not read by the code in this section;
    /// presumably consumed by the surrounding SQL generator -- confirm).
    pub to_sql_config: ToSqlConfigEntity,
}
149
150impl Ord for PgAggregateEntity {
151 fn cmp(&self, other: &Self) -> Ordering {
152 self.file.cmp(other.full_path).then_with(|| self.file.cmp(other.full_path))
153 }
154}
155
156impl PartialOrd for PgAggregateEntity {
157 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
158 Some(self.cmp(other))
159 }
160}
161
162impl From<PgAggregateEntity> for SqlGraphEntity {
163 fn from(val: PgAggregateEntity) -> Self {
164 SqlGraphEntity::Aggregate(val)
165 }
166}
167
168impl SqlGraphIdentifier for PgAggregateEntity {
169 fn dot_identifier(&self) -> String {
170 format!("aggregate {}", self.full_path)
171 }
172 fn rust_identifier(&self) -> String {
173 self.full_path.to_string()
174 }
175 fn file(&self) -> Option<&'static str> {
176 Some(self.file)
177 }
178 fn line(&self) -> Option<u32> {
179 Some(self.line)
180 }
181}
182
183impl ToSql for PgAggregateEntity {
184 #[tracing::instrument(level = "debug", err, skip(self, context), fields(identifier = %self.rust_identifier()))]
185 fn to_sql(&self, context: &PgxSql) -> eyre::Result<String> {
186 let self_index = context.aggregates[self];
187 let mut optional_attributes = Vec::new();
188 let schema = context.schema_prefix_for(&self_index);
189
190 if let Some(value) = self.finalfunc {
191 optional_attributes.push((
192 format!("\tFINALFUNC = {}\"{}\"", schema, value),
193 format!("/* {}::final */", self.full_path),
194 ));
195 }
196 if let Some(value) = self.finalfunc_modify {
197 optional_attributes.push((
198 format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
199 format!("/* {}::FINALIZE_MODIFY */", self.full_path),
200 ));
201 }
202 if let Some(value) = self.combinefunc {
203 optional_attributes.push((
204 format!("\tCOMBINEFUNC = {}\"{}\"", schema, value),
205 format!("/* {}::combine */", self.full_path),
206 ));
207 }
208 if let Some(value) = self.serialfunc {
209 optional_attributes.push((
210 format!("\tSERIALFUNC = {}\"{}\"", schema, value),
211 format!("/* {}::serial */", self.full_path),
212 ));
213 }
214 if let Some(value) = self.deserialfunc {
215 optional_attributes.push((
216 format!("\tDESERIALFUNC ={} \"{}\"", schema, value),
217 format!("/* {}::deserial */", self.full_path),
218 ));
219 }
220 if let Some(value) = self.initcond {
221 optional_attributes.push((
222 format!("\tINITCOND = '{}'", value),
223 format!("/* {}::INITIAL_CONDITION */", self.full_path),
224 ));
225 }
226 if let Some(value) = self.msfunc {
227 optional_attributes.push((
228 format!("\tMSFUNC = {}\"{}\"", schema, value),
229 format!("/* {}::moving_state */", self.full_path),
230 ));
231 }
232 if let Some(value) = self.minvfunc {
233 optional_attributes.push((
234 format!("\tMINVFUNC = {}\"{}\"", schema, value),
235 format!("/* {}::moving_state_inverse */", self.full_path),
236 ));
237 }
238 if let Some(value) = self.mfinalfunc {
239 optional_attributes.push((
240 format!("\tMFINALFUNC = {}\"{}\"", schema, value),
241 format!("/* {}::moving_state_finalize */", self.full_path),
242 ));
243 }
244 if let Some(value) = self.mfinalfunc_modify {
245 optional_attributes.push((
246 format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
247 format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
248 ));
249 }
250 if let Some(value) = self.minitcond {
251 optional_attributes.push((
252 format!("\tMINITCOND = '{}'", value),
253 format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
254 ));
255 }
256 if let Some(value) = self.sortop {
257 optional_attributes.push((
258 format!("\tSORTOP = \"{}\"", value),
259 format!("/* {}::SORT_OPERATOR */", self.full_path),
260 ));
261 }
262 if let Some(value) = self.parallel {
263 optional_attributes.push((
264 format!("\tPARALLEL = {}", value.to_sql(context)?),
265 format!("/* {}::PARALLEL */", self.full_path),
266 ));
267 }
268 if self.hypothetical {
269 optional_attributes.push((
270 String::from("\tHYPOTHETICAL"),
271 format!("/* {}::hypothetical */", self.full_path),
272 ))
273 }
274
275 let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
276 match used_ty.metadata.argument_sql {
277 Ok(SqlMapping::As(ref argument_sql)) => Ok(argument_sql.to_string()),
278 Ok(SqlMapping::Composite { array_brackets }) => used_ty
279 .composite_type
280 .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
281 .ok_or_else(|| {
282 eyre!("Macro expansion time suggested a composite_type!() in return")
283 }),
284 Ok(SqlMapping::Source { array_brackets }) => {
285 let sql = context
286 .source_only_to_sql_type(used_ty.ty_source)
287 .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
288 .ok_or_else(|| {
289 eyre!("Macro expansion time suggested a source only mapping in return")
290 })?;
291 Ok(sql)
292 }
293 Ok(SqlMapping::Skip) => {
294 Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
295 }
296 Err(err) => match context.source_only_to_sql_type(used_ty.ty_source) {
297 Some(source_only_mapping) => Ok(source_only_mapping.to_string()),
298 None => return Err(err).wrap_err("While mapping argument"),
299 },
300 }
301 };
302
303 let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
304
305 if let Some(value) = &self.mstype {
306 let mstype_sql = map_ty(&value).wrap_err("Mapping moving state type")?;
307 optional_attributes.push((
308 format!("\tMSTYPE = {}", mstype_sql),
309 format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
310 ));
311 }
312
313 let mut optional_attributes_string = String::new();
314 for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
315 let optional_attribute_string = format!(
316 "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
317 optional_attribute = optional_attribute,
318 maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
319 comment = comment,
320 maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
321 );
322 optional_attributes_string += &optional_attribute_string;
323 }
324
325 let sql = format!(
326 "\n\
327 -- {file}:{line}\n\
328 -- {full_path}\n\
329 CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
330 (\n\
331 \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
332 \tSTYPE = {schema}{stype}{maybe_comma_after_stype} /* {stype_full_path} */\
333 {optional_attributes}\
334 );\
335 ",
336 schema = schema,
337 name = self.name,
338 full_path = self.full_path,
339 file = self.file,
340 line = self.line,
341 sfunc = self.sfunc,
342 stype = stype_sql,
343 stype_full_path = self.stype.used_ty.full_path,
344 maybe_comma_after_stype = if optional_attributes.len() == 0 { "" } else { "," },
345 args = {
346 let mut args = Vec::new();
347 for (idx, arg) in self.args.iter().enumerate() {
348 let graph_index = context
349 .graph
350 .neighbors_undirected(self_index)
351 .find(|neighbor| match &context.graph[*neighbor] {
352 SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
353 SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
354 SqlGraphEntity::BuiltinType(defined) => {
355 defined == &arg.used_ty.full_path
356 }
357 _ => false,
358 })
359 .ok_or_else(|| {
360 eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
361 })?;
362 let needs_comma = idx < (self.args.len() - 1);
363 let buf = format!("\
364 \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
365 ",
366 schema_prefix = context.schema_prefix_for(&graph_index),
367 sql_type = match arg.used_ty.metadata.argument_sql {
369 Ok(SqlMapping::As(ref argument_sql)) => {
370 argument_sql.to_string()
371 }
372 Ok(SqlMapping::Composite {
373 array_brackets,
374 }) => {
375 arg.used_ty
376 .composite_type
377 .map(|v| {
378 if array_brackets {
379 format!("{v}[]")
380 } else {
381 format!("{v}")
382 }
383 })
384 .ok_or_else(|| {
385 eyre!(
386 "Macro expansion time suggested a composite_type!() in return"
387 )
388 })?
389 }
390 Ok(SqlMapping::Source {
391 array_brackets,
392 }) => {
393 let sql = context
394 .source_only_to_sql_type(arg.used_ty.ty_source)
395 .map(|v| {
396 if array_brackets {
397 format!("{v}[]")
398 } else {
399 format!("{v}")
400 }
401 })
402 .ok_or_else(|| {
403 eyre!(
404 "Macro expansion time suggested a source only mapping in return"
405 )
406 })?;
407 sql
408 }
409 Ok(SqlMapping::Skip) => return Err(eyre!("Got a skipped SQL translatable type in aggregate args, this is not permitted")),
410 Err(err) => {
411 match context.source_only_to_sql_type(arg.used_ty.ty_source) {
412 Some(source_only_mapping) => {
413 source_only_mapping.to_string()
414 }
415 None => return Err(err).wrap_err("While mapping argument"),
416 }
417 }
418 },
419 variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
420 maybe_comma = if needs_comma { ", " } else { " " },
421 full_path = arg.used_ty.full_path,
422 name = if let Some(name) = arg.name {
423 format!(r#""{}" "#, name)
424 } else { "".to_string() },
425 );
426 args.push(buf);
427 }
428 "\n".to_string() + &args.join("\n") + "\n"
429 },
430 direct_args = if let Some(direct_args) = &self.direct_args {
431 let mut args = Vec::new();
432 for (idx, arg) in direct_args.iter().enumerate() {
433 let graph_index = context
434 .graph
435 .neighbors_undirected(self_index)
436 .find(|neighbor| match &context.graph[*neighbor] {
437 SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
438 SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
439 SqlGraphEntity::BuiltinType(defined) => {
440 defined == &arg.used_ty.full_path
441 }
442 _ => false,
443 })
444 .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
445 let needs_comma = idx < (direct_args.len() - 1);
446 let buf = format!(
447 "\
448 \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
449 ",
450 schema_prefix = context.schema_prefix_for(&graph_index),
451 sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
453 maybe_name = if let Some(name) = arg.name {
454 "\"".to_string() + name + "\" "
455 } else {
456 "".to_string()
457 },
458 maybe_comma = if needs_comma { ", " } else { " " },
459 full_path = arg.used_ty.full_path,
460 );
461 args.push(buf);
462 }
463 "\n".to_string() + &args.join("\n,") + "\n"
464 } else {
465 String::default()
466 },
467 maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
468 optional_attributes = if optional_attributes.len() == 0 {
469 String::from("\n")
470 } else {
471 String::from("\n")
472 } + &optional_attributes_string
473 + if optional_attributes.len() == 0 { "" } else { "\n" },
474 );
475 tracing::trace!(%sql);
476 Ok(sql)
477 }
478}