1use crate::object_id::ObjectId;
2use crate::postgres_client_wrapper::FromPgChar;
3use crate::quoting::AttemptedKeywordUsage::TypeOrFunctionName;
4use crate::quoting::{quote_value_string, IdentifierQuoter, Quotable};
5use crate::whitespace_ignorant_string::WhitespaceIgnorantString;
6use crate::{ElefantToolsError, PostgresSchema};
7use ordered_float::NotNan;
8use serde::{Deserialize, Serialize};
9use std::fmt::Display;
10
11#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
12pub enum FunctionKind {
13 #[default]
14 Function,
15 Procedure,
16 Aggregate,
17 Window,
18}
19
20impl FromPgChar for FunctionKind {
21 fn from_pg_char(c: char) -> Result<Self, crate::ElefantToolsError> {
22 match c {
23 'f' => Ok(FunctionKind::Function),
24 'p' => Ok(FunctionKind::Procedure),
25 'a' => Ok(FunctionKind::Aggregate),
26 'w' => Ok(FunctionKind::Window),
27 _ => Err(ElefantToolsError::UnknownFunctionKind(c.to_string())),
28 }
29 }
30}
31
32#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
33pub enum Volatility {
34 Immutable,
35 Stable,
36 #[default]
37 Volatile,
38}
39
40impl FromPgChar for Volatility {
41 fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
42 match c {
43 'i' => Ok(Volatility::Immutable),
44 's' => Ok(Volatility::Stable),
45 'v' => Ok(Volatility::Volatile),
46 _ => Err(ElefantToolsError::UnknownVolatility(c.to_string())),
47 }
48 }
49}
50
51#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
52pub enum Parallel {
53 Safe,
54 Restricted,
55 #[default]
56 Unsafe,
57}
58
59impl FromPgChar for Parallel {
60 fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
61 match c {
62 's' => Ok(Parallel::Safe),
63 'r' => Ok(Parallel::Restricted),
64 'u' => Ok(Parallel::Unsafe),
65 _ => Err(ElefantToolsError::UnknownParallel(c.to_string())),
66 }
67 }
68}
69
/// Description of a PostgreSQL function or procedure, holding the attributes that
/// `get_create_statement` renders into DDL.
#[derive(Debug, Eq, PartialEq, Default, Clone, Serialize, Deserialize)]
pub struct PostgresFunction {
    /// Unqualified routine name; it is quoted and prefixed with the schema name when rendered.
    pub function_name: String,
    /// Language name, emitted verbatim after `language`.
    pub language: String,
    /// Planner cost estimate, emitted as `cost` for non-procedures.
    pub estimated_cost: NotNan<f32>,
    /// Planner row estimate, emitted as `rows` for non-procedures, and only when greater than zero.
    pub estimated_rows: NotNan<f32>,
    /// Optional support function, emitted verbatim after `support` for non-procedures.
    pub support_function: Option<String>,
    /// Routine kind. `Procedure` selects `create procedure`, every other kind uses
    /// `create function`.
    pub kind: FunctionKind,
    /// When true, `security definer` is emitted.
    pub security_definer: bool,
    /// When true, `leakproof` is emitted for non-procedures.
    pub leak_proof: bool,
    /// When true, `strict` is emitted for non-procedures.
    pub strict: bool,
    /// Whether the routine returns a set. This is not read when rendering the create
    /// statement; presumably it is already reflected in `result` — confirm.
    pub returns_set: bool,
    /// Volatility category, emitted for non-procedures.
    pub volatility: Volatility,
    /// Parallel-safety level, emitted as `parallel ...` for non-procedures.
    pub parallel: Parallel,
    /// Routine body text, placed verbatim inside a dollar-quoted string.
    pub sql_body: WhitespaceIgnorantString,
    /// Optional configuration entries, emitted after `set`.
    pub configuration: Option<Vec<String>>,
    /// Argument list text, inserted verbatim between the parentheses after the name.
    pub arguments: String,
    /// Optional return type text, emitted verbatim after `returns`.
    pub result: Option<String>,
    /// Optional comment, emitted as a separate `comment on` statement.
    pub comment: Option<String>,
    /// Identifier of this object (presumably used for dependency ordering — confirm).
    pub object_id: ObjectId,
    /// Identifiers of objects this routine depends on (presumably used to order creation).
    pub depends_on: Vec<ObjectId>,
}
92
93impl PostgresFunction {
94 pub fn get_create_statement(
95 &self,
96 schema: &PostgresSchema,
97 identifier_quoter: &IdentifierQuoter,
98 ) -> String {
99 let fn_name = format!(
100 "{}.{}",
101 schema.name.quote(identifier_quoter, TypeOrFunctionName),
102 &self
103 .function_name
104 .quote(identifier_quoter, TypeOrFunctionName)
105 );
106
107 let function_keyword = if self.kind == FunctionKind::Procedure {
108 "procedure"
109 } else {
110 "function"
111 };
112
113 let mut sql = format!(
114 "create {} {} ({})",
115 function_keyword, fn_name, self.arguments
116 );
117
118 if let Some(result) = &self.result {
119 sql.push_str(" returns ");
120
121 sql.push_str(result);
122 }
123
124 sql.push_str(" language ");
125 sql.push_str(&self.language);
126
127 if self.kind == FunctionKind::Window {
128 sql.push_str("window ");
129 }
130
131 if self.kind != FunctionKind::Procedure {
132 match self.volatility {
133 Volatility::Immutable => sql.push_str(" immutable "),
134 Volatility::Stable => sql.push_str(" stable "),
135 Volatility::Volatile => sql.push_str(" volatile "),
136 }
137
138 match self.parallel {
139 Parallel::Safe => sql.push_str(" parallel safe "),
140 Parallel::Restricted => sql.push_str(" parallel restricted "),
141 Parallel::Unsafe => sql.push_str(" parallel unsafe "),
142 }
143
144 if self.leak_proof {
145 sql.push_str(" leakproof ");
146 }
147
148 if self.strict {
149 sql.push_str(" strict ");
150 }
151 }
152
153 if self.security_definer {
154 sql.push_str(" security definer ");
155 }
156
157 if let Some(configuration) = &self.configuration {
158 sql.push_str(" set ");
159 for cfg in configuration {
160 sql.push_str(cfg);
161 }
162 }
163
164 if self.kind != FunctionKind::Procedure {
165 sql.push_str("cost ");
166 sql.push_str(&self.estimated_cost.to_string());
167
168 if self.estimated_rows.into_inner() > 0. {
169 sql.push_str(" rows ");
170 sql.push_str(&self.estimated_rows.to_string());
171 }
172
173 if let Some(support_function_name) = &self.support_function {
174 sql.push_str(" support ");
175 sql.push_str(support_function_name);
176 }
177 }
178
179 sql.push_str(" as $$");
180 sql.push_str(&self.sql_body);
181 sql.push_str("$$;");
182
183 if let Some(comment) = &self.comment {
184 sql.push_str("\ncomment on ");
185 sql.push_str(function_keyword);
186 String::push_str(&mut sql, " ");
187 sql.push_str(&fn_name);
188 sql.push_str(" is ");
189 sql.push_str("e_value_string(comment));
190 sql.push(';');
191 }
192
193 sql
194 }
195}
196
197#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
198pub enum FinalModify {
199 #[default]
200 ReadOnly,
201 Shareable,
202 ReadWrite,
203}
204
205impl FromPgChar for FinalModify {
206 fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
207 match c {
208 'r' => Ok(FinalModify::ReadOnly),
209 's' => Ok(FinalModify::Shareable),
210 'w' => Ok(FinalModify::ReadWrite),
211 _ => Err(ElefantToolsError::UnknownAggregateFinalFunctionModify(
212 c.to_string(),
213 )),
214 }
215 }
216}
217
218impl Display for FinalModify {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 let str = match self {
221 FinalModify::ReadOnly => "read_only",
222 FinalModify::Shareable => "shareable",
223 FinalModify::ReadWrite => "read_write",
224 };
225 write!(f, "{}", str)
226 }
227}
228
/// Description of a PostgreSQL aggregate, holding the attributes that
/// `get_create_statement` renders into a `create aggregate` statement.
#[derive(Debug, Eq, PartialEq, Default, Clone, Serialize, Deserialize)]
pub struct PostgresAggregateFunction {
    /// Unqualified aggregate name; it is quoted and prefixed with the schema name when rendered.
    pub function_name: String,
    /// Argument list text, inserted verbatim between the parentheses after the name.
    pub arguments: String,
    /// State transition function, emitted as `sfunc`.
    pub state_transition_function: String,
    /// Optional final function, emitted as `finalfunc`.
    pub final_function: Option<String>,
    /// Optional combine function name.
    pub combine_function: Option<String>,
    /// Optional serialization function name.
    pub serial_function: Option<String>,
    /// Optional deserialization function name.
    pub deserial_function: Option<String>,
    /// Optional moving-aggregate state transition function, emitted as `msfunc`.
    pub moving_state_transition_function: Option<String>,
    /// Optional inverse moving-aggregate transition function, emitted as `minv_sfunc`.
    pub inverse_moving_state_transition_function: Option<String>,
    /// Optional moving-aggregate final function, emitted as `mfinalfunc`.
    pub moving_final_function: Option<String>,
    /// When true and a final function exists, `finalfunc_extra` is emitted.
    pub final_extra_data: bool,
    /// When true and a moving final function exists, `mfinalfunc_extra` is emitted.
    pub moving_final_extra_data: bool,
    /// Emitted as `finalfunc_modify` when a final function exists.
    pub final_modify: FinalModify,
    /// Emitted as `mfinalfunc_modify` when a moving final function exists.
    pub moving_final_modify: FinalModify,
    /// Optional sort operator, emitted verbatim as `sortop`.
    pub sort_operator: Option<String>,
    /// State data type, emitted as `stype`.
    pub transition_type: String,
    /// Optional state size estimate, emitted as `sspace`.
    pub transition_space: Option<i32>,
    /// Optional moving-aggregate state type, emitted as `mstype`.
    pub moving_transition_type: Option<String>,
    /// Optional moving state size estimate, emitted as `msspace` only alongside `mstype`.
    pub moving_transition_space: Option<i32>,
    /// Optional initial state value, emitted verbatim as `initcond`.
    pub initial_value: Option<String>,
    /// Optional moving-aggregate initial state value, emitted verbatim as `minitcond`.
    pub moving_initial_value: Option<String>,
    /// Parallel-safety level; always emitted as `parallel`.
    pub parallel: Parallel,
    /// Identifier of this object (presumably used for dependency ordering — confirm).
    pub object_id: ObjectId,
    /// Identifiers of objects this aggregate depends on (presumably used to order creation).
    pub depends_on: Vec<ObjectId>,
}
256
257impl PostgresAggregateFunction {
258 pub fn get_create_statement(
259 &self,
260 schema: &PostgresSchema,
261 identifier_quoter: &IdentifierQuoter,
262 ) -> String {
263 let fn_name = format!(
264 "{}.{}",
265 schema.name.quote(identifier_quoter, TypeOrFunctionName),
266 &self
267 .function_name
268 .quote(identifier_quoter, TypeOrFunctionName)
269 );
270
271 let mut sql = format!("create aggregate {} ({}) (\n", fn_name, self.arguments);
272
273 sql.push_str("\tsfunc = ");
274 sql.push_str(&self.state_transition_function);
275 sql.push_str(",\n\tstype=");
276 sql.push_str(&self.transition_type);
277
278 if let Some(transition_space) = &self.transition_space {
279 sql.push_str(",\n\tsspace=");
280 sql.push_str(&transition_space.to_string());
281 }
282
283 if let Some(serial_function) = &self.serial_function {
284 sql.push_str(",\n\tsfunc=");
285 sql.push_str(serial_function);
286 }
287
288 if let Some(deserial_function) = &self.deserial_function {
289 sql.push_str(",\n\tdfunc=");
290 sql.push_str(deserial_function);
291 }
292
293 if let Some(initial_value) = &self.initial_value {
294 sql.push_str(",\n\tinitcond=");
295 sql.push_str(initial_value);
296 }
297
298 if let Some(final_function) = &self.final_function {
299 sql.push_str(",\n\tfinalfunc=");
300 sql.push_str(final_function);
301
302 sql.push_str(",\n\tfinalfunc_modify=");
303 sql.push_str(&self.final_modify.to_string());
304
305 if self.final_extra_data {
306 sql.push_str(",\n\tfinalfunc_extra");
307 }
308 }
309
310 if let Some(moving_state_transition_function) = &self.moving_state_transition_function {
311 sql.push_str(",\n\tmsfunc=");
312 sql.push_str(moving_state_transition_function);
313 }
314
315 if let Some(inverse_moving_state_transition_function) =
316 &self.inverse_moving_state_transition_function
317 {
318 sql.push_str(",\n\tminv_sfunc=");
319 sql.push_str(inverse_moving_state_transition_function);
320 }
321
322 if let Some(moving_final_function) = &self.moving_final_function {
323 sql.push_str(",\n\tmfinalfunc=");
324 sql.push_str(moving_final_function);
325
326 sql.push_str(",\n\tmfinalfunc_modify=");
327 sql.push_str(&self.moving_final_modify.to_string());
328
329 if self.moving_final_extra_data {
330 sql.push_str(",\n\tmfinalfunc_extra");
331 }
332 }
333
334 if let Some(moving_transition_type) = &self.moving_transition_type {
335 sql.push_str(",\n\tmstype=");
336 sql.push_str(moving_transition_type);
337
338 if let Some(moving_transition_space) = &self.moving_transition_space {
339 sql.push_str(",\n\tmsspace=");
340 sql.push_str(&moving_transition_space.to_string());
341 }
342 }
343
344 if let Some(moving_initial_value) = &self.moving_initial_value {
345 sql.push_str(",\n\tminitcond=");
346 sql.push_str(moving_initial_value);
347 }
348
349 if let Some(sort_operator) = &self.sort_operator {
350 sql.push_str(",\n\tsortop=");
351 sql.push_str(sort_operator);
352 }
353
354 match self.parallel {
355 Parallel::Safe => sql.push_str(",\n\tparallel=safe"),
356 Parallel::Restricted => sql.push_str(",\n\tparallel=restricted"),
357 Parallel::Unsafe => sql.push_str(",\n\tparallel=unsafe"),
358 }
359
360 sql.push_str("\n);");
361
362 sql
363 }
364}