reinhardt_query/query/function/
create_function.rs1use crate::{
6 backend::QueryBuilder,
7 types::{
8 IntoIden,
9 function::{FunctionBehavior, FunctionDef, FunctionLanguage, FunctionSecurity},
10 },
11};
12
13use crate::query::traits::{QueryBuilderTrait, QueryStatementBuilder, QueryStatementWriter};
14
15#[derive(Debug, Clone)]
44pub struct CreateFunctionStatement {
45 pub(crate) function_def: FunctionDef,
46}
47
48impl CreateFunctionStatement {
49 pub fn new() -> Self {
59 Self {
61 function_def: FunctionDef::new(""),
62 }
63 }
64
65 pub fn take(&mut self) -> Self {
67 let taken = Self {
68 function_def: self.function_def.clone(),
69 };
70 self.function_def = FunctionDef::new("");
72 taken
73 }
74
75 pub fn name<N>(&mut self, name: N) -> &mut Self
86 where
87 N: IntoIden,
88 {
89 self.function_def.name = name.into_iden();
90 self
91 }
92
93 pub fn or_replace(&mut self) -> &mut Self {
105 self.function_def.or_replace = true;
106 self
107 }
108
109 pub fn add_parameter<N: IntoIden, T: Into<String>>(
122 &mut self,
123 name: N,
124 param_type: T,
125 ) -> &mut Self {
126 self.function_def = self.function_def.clone().add_parameter(name, param_type);
127 self
128 }
129
130 pub fn returns<T: Into<String>>(&mut self, returns: T) -> &mut Self {
142 self.function_def.returns = Some(returns.into());
143 self
144 }
145
146 pub fn language(&mut self, language: FunctionLanguage) -> &mut Self {
159 self.function_def.language = Some(language);
160 self
161 }
162
163 pub fn behavior(&mut self, behavior: FunctionBehavior) -> &mut Self {
176 self.function_def.behavior = Some(behavior);
177 self
178 }
179
180 pub fn security(&mut self, security: FunctionSecurity) -> &mut Self {
193 self.function_def.security = Some(security);
194 self
195 }
196
197 pub fn body<B: Into<String>>(&mut self, body: B) -> &mut Self {
226 self.function_def.body = Some(body.into());
227 self
228 }
229}
230
231impl Default for CreateFunctionStatement {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
237impl QueryStatementBuilder for CreateFunctionStatement {
238 fn build_any(&self, query_builder: &dyn QueryBuilderTrait) -> (String, crate::value::Values) {
239 use std::any::Any;
241 if let Some(builder) =
242 (query_builder as &dyn Any).downcast_ref::<crate::backend::PostgresQueryBuilder>()
243 {
244 return builder.build_create_function(self);
245 }
246 if let Some(builder) =
247 (query_builder as &dyn Any).downcast_ref::<crate::backend::MySqlQueryBuilder>()
248 {
249 return builder.build_create_function(self);
250 }
251 if let Some(builder) =
252 (query_builder as &dyn Any).downcast_ref::<crate::backend::SqliteQueryBuilder>()
253 {
254 return builder.build_create_function(self);
255 }
256 if let Some(builder) =
257 (query_builder as &dyn Any).downcast_ref::<crate::backend::CockroachDBQueryBuilder>()
258 {
259 return builder.build_create_function(self);
260 }
261 panic!("Unsupported query builder type");
262 }
263}
264
265impl QueryStatementWriter for CreateFunctionStatement {}
266
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Returns the name of the parameter at `idx` as a `String`.
    fn param_name(stmt: &CreateFunctionStatement, idx: usize) -> String {
        stmt.function_def.parameters[idx]
            .name
            .as_ref()
            .unwrap()
            .to_string()
    }

    #[rstest]
    fn test_create_function_new() {
        let stmt = CreateFunctionStatement::new();
        assert!(stmt.function_def.name.to_string().is_empty());
        assert!(!stmt.function_def.or_replace);
        assert!(stmt.function_def.parameters.is_empty());
        assert!(stmt.function_def.returns.is_none());
        assert!(stmt.function_def.language.is_none());
        assert!(stmt.function_def.behavior.is_none());
        assert!(stmt.function_def.security.is_none());
        assert!(stmt.function_def.body.is_none());
    }

    #[rstest]
    fn test_create_function_with_name() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func");
        assert_eq!(stmt.function_def.name.to_string(), "my_func");
    }

    #[rstest]
    fn test_create_function_or_replace() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").or_replace();
        assert!(stmt.function_def.or_replace);
    }

    #[rstest]
    fn test_create_function_add_parameter() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").add_parameter("param1", "integer");
        assert_eq!(stmt.function_def.parameters.len(), 1);
        assert_eq!(param_name(&stmt, 0), "param1");
        assert_eq!(
            stmt.function_def.parameters[0].param_type.as_ref().unwrap(),
            "integer"
        );
    }

    #[rstest]
    fn test_create_function_multiple_parameters() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func")
            .add_parameter("param1", "integer")
            .add_parameter("param2", "text");
        assert_eq!(stmt.function_def.parameters.len(), 2);
        assert_eq!(param_name(&stmt, 0), "param1");
        assert_eq!(param_name(&stmt, 1), "param2");
    }

    #[rstest]
    fn test_create_function_add_parameter_preserves_other_fields() {
        // Adding a parameter must not reset anything configured earlier.
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func")
            .or_replace()
            .returns("integer")
            .language(FunctionLanguage::PlPgSql)
            .body("SELECT 1")
            .add_parameter("a", "integer");

        assert_eq!(stmt.function_def.name.to_string(), "my_func");
        assert!(stmt.function_def.or_replace);
        assert_eq!(stmt.function_def.returns.as_ref().unwrap(), "integer");
        assert_eq!(stmt.function_def.language, Some(FunctionLanguage::PlPgSql));
        assert_eq!(stmt.function_def.body.as_ref().unwrap(), "SELECT 1");
        assert_eq!(stmt.function_def.parameters.len(), 1);
        assert_eq!(param_name(&stmt, 0), "a");
    }

    #[rstest]
    fn test_create_function_returns() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").returns("integer");
        assert_eq!(stmt.function_def.returns.as_ref().unwrap(), "integer");
    }

    #[rstest]
    fn test_create_function_language() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").language(FunctionLanguage::PlPgSql);
        assert_eq!(stmt.function_def.language, Some(FunctionLanguage::PlPgSql));
    }

    #[rstest]
    fn test_create_function_behavior() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").behavior(FunctionBehavior::Immutable);
        assert_eq!(
            stmt.function_def.behavior,
            Some(FunctionBehavior::Immutable)
        );
    }

    #[rstest]
    fn test_create_function_security() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").security(FunctionSecurity::Definer);
        assert_eq!(stmt.function_def.security, Some(FunctionSecurity::Definer));
    }

    #[rstest]
    fn test_create_function_body() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func").body("SELECT 1");
        assert_eq!(stmt.function_def.body.as_ref().unwrap(), "SELECT 1");
    }

    #[rstest]
    fn test_create_function_all_options() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func")
            .or_replace()
            .add_parameter("a", "integer")
            .add_parameter("b", "text")
            .returns("integer")
            .language(FunctionLanguage::PlPgSql)
            .behavior(FunctionBehavior::Immutable)
            .security(FunctionSecurity::Definer)
            .body("BEGIN RETURN a + LENGTH(b); END;");

        assert_eq!(stmt.function_def.name.to_string(), "my_func");
        assert!(stmt.function_def.or_replace);
        assert_eq!(stmt.function_def.parameters.len(), 2);
        assert_eq!(stmt.function_def.returns.as_ref().unwrap(), "integer");
        assert_eq!(stmt.function_def.language, Some(FunctionLanguage::PlPgSql));
        assert_eq!(
            stmt.function_def.behavior,
            Some(FunctionBehavior::Immutable)
        );
        assert_eq!(stmt.function_def.security, Some(FunctionSecurity::Definer));
        assert_eq!(
            stmt.function_def.body.as_ref().unwrap(),
            "BEGIN RETURN a + LENGTH(b); END;"
        );
    }

    #[rstest]
    fn test_create_function_take() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func");
        let taken = stmt.take();
        assert!(stmt.function_def.name.to_string().is_empty());
        assert_eq!(taken.function_def.name.to_string(), "my_func");
    }

    #[rstest]
    fn test_create_function_take_resets_all_fields() {
        let mut stmt = CreateFunctionStatement::new();
        stmt.name("my_func")
            .or_replace()
            .add_parameter("a", "integer")
            .returns("integer")
            .language(FunctionLanguage::PlPgSql)
            .behavior(FunctionBehavior::Immutable)
            .security(FunctionSecurity::Definer)
            .body("SELECT 1");

        let taken = stmt.take();

        // The returned statement carries everything that was configured...
        assert_eq!(taken.function_def.name.to_string(), "my_func");
        assert!(taken.function_def.or_replace);
        assert_eq!(taken.function_def.parameters.len(), 1);
        assert_eq!(taken.function_def.returns.as_ref().unwrap(), "integer");
        assert_eq!(taken.function_def.language, Some(FunctionLanguage::PlPgSql));
        assert_eq!(
            taken.function_def.behavior,
            Some(FunctionBehavior::Immutable)
        );
        assert_eq!(taken.function_def.security, Some(FunctionSecurity::Definer));
        assert_eq!(taken.function_def.body.as_ref().unwrap(), "SELECT 1");

        // ...and the source builder is back to its freshly constructed state.
        assert!(stmt.function_def.name.to_string().is_empty());
        assert!(!stmt.function_def.or_replace);
        assert!(stmt.function_def.parameters.is_empty());
        assert!(stmt.function_def.returns.is_none());
        assert!(stmt.function_def.language.is_none());
        assert!(stmt.function_def.behavior.is_none());
        assert!(stmt.function_def.security.is_none());
        assert!(stmt.function_def.body.is_none());
    }
}