elefant_tools/models/function.rs

1use crate::object_id::ObjectId;
2use crate::postgres_client_wrapper::FromPgChar;
3use crate::quoting::AttemptedKeywordUsage::TypeOrFunctionName;
4use crate::quoting::{quote_value_string, IdentifierQuoter, Quotable};
5use crate::whitespace_ignorant_string::WhitespaceIgnorantString;
6use crate::{ElefantToolsError, PostgresSchema};
7use ordered_float::NotNan;
8use serde::{Deserialize, Serialize};
9use std::fmt::Display;
10
11#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
12pub enum FunctionKind {
13    #[default]
14    Function,
15    Procedure,
16    Aggregate,
17    Window,
18}
19
20impl FromPgChar for FunctionKind {
21    fn from_pg_char(c: char) -> Result<Self, crate::ElefantToolsError> {
22        match c {
23            'f' => Ok(FunctionKind::Function),
24            'p' => Ok(FunctionKind::Procedure),
25            'a' => Ok(FunctionKind::Aggregate),
26            'w' => Ok(FunctionKind::Window),
27            _ => Err(ElefantToolsError::UnknownFunctionKind(c.to_string())),
28        }
29    }
30}
31
32#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
33pub enum Volatility {
34    Immutable,
35    Stable,
36    #[default]
37    Volatile,
38}
39
40impl FromPgChar for Volatility {
41    fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
42        match c {
43            'i' => Ok(Volatility::Immutable),
44            's' => Ok(Volatility::Stable),
45            'v' => Ok(Volatility::Volatile),
46            _ => Err(ElefantToolsError::UnknownVolatility(c.to_string())),
47        }
48    }
49}
50
51#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
52pub enum Parallel {
53    Safe,
54    Restricted,
55    #[default]
56    Unsafe,
57}
58
59impl FromPgChar for Parallel {
60    fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
61        match c {
62            's' => Ok(Parallel::Safe),
63            'r' => Ok(Parallel::Restricted),
64            'u' => Ok(Parallel::Unsafe),
65            _ => Err(ElefantToolsError::UnknownParallel(c.to_string())),
66        }
67    }
68}
69
70#[derive(Debug, Eq, PartialEq, Default, Clone, Serialize, Deserialize)]
71pub struct PostgresFunction {
72    pub function_name: String,
73    pub language: String,
74    pub estimated_cost: NotNan<f32>,
75    pub estimated_rows: NotNan<f32>,
76    pub support_function: Option<String>,
77    pub kind: FunctionKind,
78    pub security_definer: bool,
79    pub leak_proof: bool,
80    pub strict: bool,
81    pub returns_set: bool,
82    pub volatility: Volatility,
83    pub parallel: Parallel,
84    pub sql_body: WhitespaceIgnorantString,
85    pub configuration: Option<Vec<String>>,
86    pub arguments: String,
87    pub result: Option<String>,
88    pub comment: Option<String>,
89    pub object_id: ObjectId,
90    pub depends_on: Vec<ObjectId>,
91}
92
93impl PostgresFunction {
94    pub fn get_create_statement(
95        &self,
96        schema: &PostgresSchema,
97        identifier_quoter: &IdentifierQuoter,
98    ) -> String {
99        let fn_name = format!(
100            "{}.{}",
101            schema.name.quote(identifier_quoter, TypeOrFunctionName),
102            &self
103                .function_name
104                .quote(identifier_quoter, TypeOrFunctionName)
105        );
106
107        let function_keyword = if self.kind == FunctionKind::Procedure {
108            "procedure"
109        } else {
110            "function"
111        };
112
113        let mut sql = format!(
114            "create {} {} ({})",
115            function_keyword, fn_name, self.arguments
116        );
117
118        if let Some(result) = &self.result {
119            sql.push_str(" returns ");
120
121            sql.push_str(result);
122        }
123
124        sql.push_str(" language ");
125        sql.push_str(&self.language);
126
127        if self.kind == FunctionKind::Window {
128            sql.push_str("window ");
129        }
130
131        if self.kind != FunctionKind::Procedure {
132            match self.volatility {
133                Volatility::Immutable => sql.push_str(" immutable "),
134                Volatility::Stable => sql.push_str(" stable "),
135                Volatility::Volatile => sql.push_str(" volatile "),
136            }
137
138            match self.parallel {
139                Parallel::Safe => sql.push_str(" parallel safe "),
140                Parallel::Restricted => sql.push_str(" parallel restricted "),
141                Parallel::Unsafe => sql.push_str(" parallel unsafe "),
142            }
143
144            if self.leak_proof {
145                sql.push_str(" leakproof ");
146            }
147
148            if self.strict {
149                sql.push_str(" strict ");
150            }
151        }
152
153        if self.security_definer {
154            sql.push_str(" security definer ");
155        }
156
157        if let Some(configuration) = &self.configuration {
158            sql.push_str(" set ");
159            for cfg in configuration {
160                sql.push_str(cfg);
161            }
162        }
163
164        if self.kind != FunctionKind::Procedure {
165            sql.push_str("cost ");
166            sql.push_str(&self.estimated_cost.to_string());
167
168            if self.estimated_rows.into_inner() > 0. {
169                sql.push_str(" rows ");
170                sql.push_str(&self.estimated_rows.to_string());
171            }
172
173            if let Some(support_function_name) = &self.support_function {
174                sql.push_str(" support ");
175                sql.push_str(support_function_name);
176            }
177        }
178
179        sql.push_str(" as $$");
180        sql.push_str(&self.sql_body);
181        sql.push_str("$$;");
182
183        if let Some(comment) = &self.comment {
184            sql.push_str("\ncomment on ");
185            sql.push_str(function_keyword);
186            String::push_str(&mut sql, " ");
187            sql.push_str(&fn_name);
188            sql.push_str(" is ");
189            sql.push_str(&quote_value_string(comment));
190            sql.push(';');
191        }
192
193        sql
194    }
195}
196
197#[derive(Debug, Eq, PartialEq, Copy, Clone, Default, Serialize, Deserialize)]
198pub enum FinalModify {
199    #[default]
200    ReadOnly,
201    Shareable,
202    ReadWrite,
203}
204
205impl FromPgChar for FinalModify {
206    fn from_pg_char(c: char) -> Result<Self, ElefantToolsError> {
207        match c {
208            'r' => Ok(FinalModify::ReadOnly),
209            's' => Ok(FinalModify::Shareable),
210            'w' => Ok(FinalModify::ReadWrite),
211            _ => Err(ElefantToolsError::UnknownAggregateFinalFunctionModify(
212                c.to_string(),
213            )),
214        }
215    }
216}
217
218impl Display for FinalModify {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        let str = match self {
221            FinalModify::ReadOnly => "read_only",
222            FinalModify::Shareable => "shareable",
223            FinalModify::ReadWrite => "read_write",
224        };
225        write!(f, "{}", str)
226    }
227}
228
229#[derive(Debug, Eq, PartialEq, Default, Clone, Serialize, Deserialize)]
230pub struct PostgresAggregateFunction {
231    pub function_name: String,
232    pub arguments: String,
233    pub state_transition_function: String,
234    pub final_function: Option<String>,
235    pub combine_function: Option<String>,
236    pub serial_function: Option<String>,
237    pub deserial_function: Option<String>,
238    pub moving_state_transition_function: Option<String>,
239    pub inverse_moving_state_transition_function: Option<String>,
240    pub moving_final_function: Option<String>,
241    pub final_extra_data: bool,
242    pub moving_final_extra_data: bool,
243    pub final_modify: FinalModify,
244    pub moving_final_modify: FinalModify,
245    pub sort_operator: Option<String>,
246    pub transition_type: String,
247    pub transition_space: Option<i32>,
248    pub moving_transition_type: Option<String>,
249    pub moving_transition_space: Option<i32>,
250    pub initial_value: Option<String>,
251    pub moving_initial_value: Option<String>,
252    pub parallel: Parallel,
253    pub object_id: ObjectId,
254    pub depends_on: Vec<ObjectId>,
255}
256
257impl PostgresAggregateFunction {
258    pub fn get_create_statement(
259        &self,
260        schema: &PostgresSchema,
261        identifier_quoter: &IdentifierQuoter,
262    ) -> String {
263        let fn_name = format!(
264            "{}.{}",
265            schema.name.quote(identifier_quoter, TypeOrFunctionName),
266            &self
267                .function_name
268                .quote(identifier_quoter, TypeOrFunctionName)
269        );
270
271        let mut sql = format!("create aggregate {} ({}) (\n", fn_name, self.arguments);
272
273        sql.push_str("\tsfunc = ");
274        sql.push_str(&self.state_transition_function);
275        sql.push_str(",\n\tstype=");
276        sql.push_str(&self.transition_type);
277
278        if let Some(transition_space) = &self.transition_space {
279            sql.push_str(",\n\tsspace=");
280            sql.push_str(&transition_space.to_string());
281        }
282
283        if let Some(serial_function) = &self.serial_function {
284            sql.push_str(",\n\tsfunc=");
285            sql.push_str(serial_function);
286        }
287
288        if let Some(deserial_function) = &self.deserial_function {
289            sql.push_str(",\n\tdfunc=");
290            sql.push_str(deserial_function);
291        }
292
293        if let Some(initial_value) = &self.initial_value {
294            sql.push_str(",\n\tinitcond=");
295            sql.push_str(initial_value);
296        }
297
298        if let Some(final_function) = &self.final_function {
299            sql.push_str(",\n\tfinalfunc=");
300            sql.push_str(final_function);
301
302            sql.push_str(",\n\tfinalfunc_modify=");
303            sql.push_str(&self.final_modify.to_string());
304
305            if self.final_extra_data {
306                sql.push_str(",\n\tfinalfunc_extra");
307            }
308        }
309
310        if let Some(moving_state_transition_function) = &self.moving_state_transition_function {
311            sql.push_str(",\n\tmsfunc=");
312            sql.push_str(moving_state_transition_function);
313        }
314
315        if let Some(inverse_moving_state_transition_function) =
316            &self.inverse_moving_state_transition_function
317        {
318            sql.push_str(",\n\tminv_sfunc=");
319            sql.push_str(inverse_moving_state_transition_function);
320        }
321
322        if let Some(moving_final_function) = &self.moving_final_function {
323            sql.push_str(",\n\tmfinalfunc=");
324            sql.push_str(moving_final_function);
325
326            sql.push_str(",\n\tmfinalfunc_modify=");
327            sql.push_str(&self.moving_final_modify.to_string());
328
329            if self.moving_final_extra_data {
330                sql.push_str(",\n\tmfinalfunc_extra");
331            }
332        }
333
334        if let Some(moving_transition_type) = &self.moving_transition_type {
335            sql.push_str(",\n\tmstype=");
336            sql.push_str(moving_transition_type);
337
338            if let Some(moving_transition_space) = &self.moving_transition_space {
339                sql.push_str(",\n\tmsspace=");
340                sql.push_str(&moving_transition_space.to_string());
341            }
342        }
343
344        if let Some(moving_initial_value) = &self.moving_initial_value {
345            sql.push_str(",\n\tminitcond=");
346            sql.push_str(moving_initial_value);
347        }
348
349        if let Some(sort_operator) = &self.sort_operator {
350            sql.push_str(",\n\tsortop=");
351            sql.push_str(sort_operator);
352        }
353
354        match self.parallel {
355            Parallel::Safe => sql.push_str(",\n\tparallel=safe"),
356            Parallel::Restricted => sql.push_str(",\n\tparallel=restricted"),
357            Parallel::Unsafe => sql.push_str(",\n\tparallel=unsafe"),
358        }
359
360        sql.push_str("\n);");
361
362        sql
363    }
364}