Skip to main content

reinhardt_query/query/function/
create_function.rs

1//! CREATE FUNCTION statement builder
2//!
3//! This module provides the `CreateFunctionStatement` type for building SQL CREATE FUNCTION queries.
4
5use crate::{
6	backend::QueryBuilder,
7	types::{
8		IntoIden,
9		function::{FunctionBehavior, FunctionDef, FunctionLanguage, FunctionSecurity},
10	},
11};
12
13use crate::query::traits::{QueryBuilderTrait, QueryStatementBuilder, QueryStatementWriter};
14
15/// CREATE FUNCTION statement builder
16///
17/// This struct provides a fluent API for constructing CREATE FUNCTION queries.
18///
19/// # Examples
20///
21/// ```rust
22/// use reinhardt_query::prelude::*;
23/// use reinhardt_query::types::function::{FunctionLanguage, FunctionBehavior};
24///
25/// // CREATE FUNCTION my_func() RETURNS integer LANGUAGE SQL AS 'SELECT 1'
26/// let query = Query::create_function()
27///     .name("my_func")
28///     .returns("integer")
29///     .language(FunctionLanguage::Sql)
30///     .body("SELECT 1");
31///
32/// // CREATE OR REPLACE FUNCTION my_func(a integer) RETURNS integer
33/// // LANGUAGE PLPGSQL IMMUTABLE AS 'BEGIN RETURN a + 1; END;'
34/// let query = Query::create_function()
35///     .name("my_func")
36///     .or_replace()
37///     .add_parameter("a", "integer")
38///     .returns("integer")
39///     .language(FunctionLanguage::PlPgSql)
40///     .behavior(FunctionBehavior::Immutable)
41///     .body("BEGIN RETURN a + 1; END;");
42/// ```
43#[derive(Debug, Clone)]
44pub struct CreateFunctionStatement {
45	pub(crate) function_def: FunctionDef,
46}
47
48impl CreateFunctionStatement {
49	/// Create a new CREATE FUNCTION statement
50	///
51	/// # Examples
52	///
53	/// ```rust
54	/// use reinhardt_query::prelude::*;
55	///
56	/// let query = Query::create_function();
57	/// ```
58	pub fn new() -> Self {
59		// Start with empty name - will be set via .name()
60		Self {
61			function_def: FunctionDef::new(""),
62		}
63	}
64
65	/// Take the ownership of data in the current [`CreateFunctionStatement`]
66	pub fn take(&mut self) -> Self {
67		let taken = Self {
68			function_def: self.function_def.clone(),
69		};
70		// Reset self to empty state
71		self.function_def = FunctionDef::new("");
72		taken
73	}
74
75	/// Set the function name
76	///
77	/// # Examples
78	///
79	/// ```rust
80	/// use reinhardt_query::prelude::*;
81	///
82	/// let query = Query::create_function()
83	///     .name("my_func");
84	/// ```
85	pub fn name<N>(&mut self, name: N) -> &mut Self
86	where
87		N: IntoIden,
88	{
89		self.function_def.name = name.into_iden();
90		self
91	}
92
93	/// Add OR REPLACE clause
94	///
95	/// # Examples
96	///
97	/// ```rust
98	/// use reinhardt_query::prelude::*;
99	///
100	/// let query = Query::create_function()
101	///     .name("my_func")
102	///     .or_replace();
103	/// ```
104	pub fn or_replace(&mut self) -> &mut Self {
105		self.function_def.or_replace = true;
106		self
107	}
108
109	/// Add a function parameter
110	///
111	/// # Examples
112	///
113	/// ```rust
114	/// use reinhardt_query::prelude::*;
115	///
116	/// let query = Query::create_function()
117	///     .name("my_func")
118	///     .add_parameter("param1", "integer")
119	///     .add_parameter("param2", "text");
120	/// ```
121	pub fn add_parameter<N: IntoIden, T: Into<String>>(
122		&mut self,
123		name: N,
124		param_type: T,
125	) -> &mut Self {
126		self.function_def = self.function_def.clone().add_parameter(name, param_type);
127		self
128	}
129
130	/// Set RETURNS type
131	///
132	/// # Examples
133	///
134	/// ```rust
135	/// use reinhardt_query::prelude::*;
136	///
137	/// let query = Query::create_function()
138	///     .name("my_func")
139	///     .returns("integer");
140	/// ```
141	pub fn returns<T: Into<String>>(&mut self, returns: T) -> &mut Self {
142		self.function_def.returns = Some(returns.into());
143		self
144	}
145
146	/// Set LANGUAGE
147	///
148	/// # Examples
149	///
150	/// ```rust
151	/// use reinhardt_query::prelude::*;
152	/// use reinhardt_query::types::function::FunctionLanguage;
153	///
154	/// let query = Query::create_function()
155	///     .name("my_func")
156	///     .language(FunctionLanguage::PlPgSql);
157	/// ```
158	pub fn language(&mut self, language: FunctionLanguage) -> &mut Self {
159		self.function_def.language = Some(language);
160		self
161	}
162
163	/// Set function behavior (IMMUTABLE/STABLE/VOLATILE)
164	///
165	/// # Examples
166	///
167	/// ```rust
168	/// use reinhardt_query::prelude::*;
169	/// use reinhardt_query::types::function::FunctionBehavior;
170	///
171	/// let query = Query::create_function()
172	///     .name("my_func")
173	///     .behavior(FunctionBehavior::Immutable);
174	/// ```
175	pub fn behavior(&mut self, behavior: FunctionBehavior) -> &mut Self {
176		self.function_def.behavior = Some(behavior);
177		self
178	}
179
180	/// Set security context (DEFINER/INVOKER)
181	///
182	/// # Examples
183	///
184	/// ```rust
185	/// use reinhardt_query::prelude::*;
186	/// use reinhardt_query::types::function::FunctionSecurity;
187	///
188	/// let query = Query::create_function()
189	///     .name("my_func")
190	///     .security(FunctionSecurity::Definer);
191	/// ```
192	pub fn security(&mut self, security: FunctionSecurity) -> &mut Self {
193		self.function_def.security = Some(security);
194		self
195	}
196
197	/// Set function body (AS clause)
198	///
199	/// # Security Warning
200	///
201	/// The function body is embedded directly into the CREATE FUNCTION statement
202	/// and will be stored in the database. **DO NOT** pass user input directly to
203	/// this method, as it could lead to arbitrary code execution.
204	///
205	/// Only use with trusted, validated code.
206	///
207	/// # Examples
208	///
209	/// ```rust
210	/// use reinhardt_query::prelude::*;
211	/// use reinhardt_query::types::function::{FunctionLanguage, FunctionBehavior};
212	///
213	/// // ✅ SAFE: Static code
214	/// let query = Query::create_function()
215	///     .name("my_func")
216	///     .returns("integer")
217	///     .language(FunctionLanguage::Sql)
218	///     .body("SELECT 1");
219	///
220	/// // ❌ UNSAFE: User input
221	/// // let query = Query::create_function()
222	/// //     .name("user_func")
223	/// //     .body(&user_code);
224	/// ```
225	pub fn body<B: Into<String>>(&mut self, body: B) -> &mut Self {
226		self.function_def.body = Some(body.into());
227		self
228	}
229}
230
231impl Default for CreateFunctionStatement {
232	fn default() -> Self {
233		Self::new()
234	}
235}
236
237impl QueryStatementBuilder for CreateFunctionStatement {
238	fn build_any(&self, query_builder: &dyn QueryBuilderTrait) -> (String, crate::value::Values) {
239		// Downcast to concrete QueryBuilder type
240		use std::any::Any;
241		if let Some(builder) =
242			(query_builder as &dyn Any).downcast_ref::<crate::backend::PostgresQueryBuilder>()
243		{
244			return builder.build_create_function(self);
245		}
246		if let Some(builder) =
247			(query_builder as &dyn Any).downcast_ref::<crate::backend::MySqlQueryBuilder>()
248		{
249			return builder.build_create_function(self);
250		}
251		if let Some(builder) =
252			(query_builder as &dyn Any).downcast_ref::<crate::backend::SqliteQueryBuilder>()
253		{
254			return builder.build_create_function(self);
255		}
256		if let Some(builder) =
257			(query_builder as &dyn Any).downcast_ref::<crate::backend::CockroachDBQueryBuilder>()
258		{
259			return builder.build_create_function(self);
260		}
261		panic!("Unsupported query builder type");
262	}
263}
264
265impl QueryStatementWriter for CreateFunctionStatement {}
266
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::*;

	/// Returns a statement whose name has already been set to `my_func`.
	fn named_stmt() -> CreateFunctionStatement {
		let mut stmt = CreateFunctionStatement::new();
		stmt.name("my_func");
		stmt
	}

	// A fresh statement has an empty name and every option unset.
	#[rstest]
	fn test_create_function_new() {
		let def = CreateFunctionStatement::new().function_def;
		assert_eq!(def.name.to_string(), "");
		assert!(!def.or_replace);
		assert_eq!(def.parameters.len(), 0);
		assert_eq!(def.returns, None);
		assert_eq!(def.language, None);
		assert_eq!(def.behavior, None);
		assert_eq!(def.security, None);
		assert_eq!(def.body, None);
	}

	#[rstest]
	fn test_create_function_with_name() {
		let stmt = named_stmt();
		assert_eq!(stmt.function_def.name.to_string(), "my_func");
	}

	#[rstest]
	fn test_create_function_or_replace() {
		let mut stmt = named_stmt();
		stmt.or_replace();
		assert!(stmt.function_def.or_replace);
	}

	#[rstest]
	fn test_create_function_add_parameter() {
		let mut stmt = named_stmt();
		stmt.add_parameter("param1", "integer");

		let params = &stmt.function_def.parameters;
		assert_eq!(params.len(), 1);

		let only = &params[0];
		assert_eq!(only.name.as_ref().unwrap().to_string(), "param1");
		assert_eq!(only.param_type.as_ref().unwrap(), "integer");
	}

	// Parameters keep the order in which they were added.
	#[rstest]
	fn test_create_function_multiple_parameters() {
		let mut stmt = named_stmt();
		stmt.add_parameter("param1", "integer")
			.add_parameter("param2", "text");

		let params = &stmt.function_def.parameters;
		assert_eq!(params.len(), 2);

		let first_name = params[0].name.as_ref().unwrap().to_string();
		let second_name = params[1].name.as_ref().unwrap().to_string();
		assert_eq!(first_name, "param1");
		assert_eq!(second_name, "param2");
	}

	#[rstest]
	fn test_create_function_returns() {
		let mut stmt = named_stmt();
		stmt.returns("integer");
		assert_eq!(stmt.function_def.returns.as_deref(), Some("integer"));
	}

	#[rstest]
	fn test_create_function_language() {
		let mut stmt = named_stmt();
		stmt.language(FunctionLanguage::PlPgSql);
		assert_eq!(stmt.function_def.language, Some(FunctionLanguage::PlPgSql));
	}

	#[rstest]
	fn test_create_function_behavior() {
		let mut stmt = named_stmt();
		stmt.behavior(FunctionBehavior::Immutable);
		let expected = Some(FunctionBehavior::Immutable);
		assert_eq!(stmt.function_def.behavior, expected);
	}

	#[rstest]
	fn test_create_function_security() {
		let mut stmt = named_stmt();
		stmt.security(FunctionSecurity::Definer);
		assert_eq!(stmt.function_def.security, Some(FunctionSecurity::Definer));
	}

	#[rstest]
	fn test_create_function_body() {
		let mut stmt = named_stmt();
		stmt.body("SELECT 1");
		assert_eq!(stmt.function_def.body.as_deref(), Some("SELECT 1"));
	}

	// Every setter used in one chain; each field must hold what was passed in.
	#[rstest]
	fn test_create_function_all_options() {
		let source = "BEGIN RETURN a + LENGTH(b); END;";

		let mut stmt = CreateFunctionStatement::new();
		stmt.name("my_func")
			.or_replace()
			.add_parameter("a", "integer")
			.add_parameter("b", "text")
			.returns("integer")
			.language(FunctionLanguage::PlPgSql)
			.behavior(FunctionBehavior::Immutable)
			.security(FunctionSecurity::Definer)
			.body(source);

		let def = &stmt.function_def;
		assert_eq!(def.name.to_string(), "my_func");
		assert!(def.or_replace);
		assert_eq!(def.parameters.len(), 2);
		assert_eq!(def.returns.as_deref(), Some("integer"));
		assert_eq!(def.language, Some(FunctionLanguage::PlPgSql));
		assert_eq!(def.behavior, Some(FunctionBehavior::Immutable));
		assert_eq!(def.security, Some(FunctionSecurity::Definer));
		assert_eq!(def.body.as_deref(), Some(source));
	}

	// `take` hands back the configured statement and resets the original.
	#[rstest]
	fn test_create_function_take() {
		let mut stmt = named_stmt();
		let taken = stmt.take();
		assert_eq!(taken.function_def.name.to_string(), "my_func");
		assert!(stmt.function_def.name.to_string().is_empty());
	}
}