1use crate::{Result, Error, Value, DataType};
7use super::logical_plan::{FunctionParam, ParamMode};
8use super::procedural::{ProceduralParser, ProceduralExecutor, ExecutionContext};
9use super::evaluator::Evaluator;
10use serde::{Serialize, Deserialize};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
/// Definition of a user-defined function as held by the function registry.
///
/// The struct is plain data: it derives `Serialize`/`Deserialize` and carries
/// the function source as text, which is parsed or substituted at call time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredFunction {
    /// Function name as supplied by the creator. The registry looks functions
    /// up by the lower-cased form of this name.
    pub name: String,
    /// When `true`, registering this definition may overwrite an existing
    /// function of the same name instead of failing.
    pub or_replace: bool,
    /// Declared parameters, in declaration order.
    pub params: Vec<FunctionParam>,
    /// Declared return type, if one was given.
    pub return_type: Option<DataType>,
    /// Source text of the function body.
    pub body: String,
    /// Implementation language. The registry executes `"sql"` and `"plpgsql"`
    /// (compared case-insensitively) and rejects anything else.
    pub language: String,
    /// Optional volatility label, stored as given (for example `"IMMUTABLE"`).
    pub volatility: Option<String>,
    /// Creation timestamp. NOTE(review): unit and epoch are not established
    /// here; confirm against the code that fills this field.
    pub created_at: u64,
}
34
/// Definition of a user-defined procedure as held by the function registry.
///
/// Unlike a stored function, a procedure declares no return type and no
/// volatility.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredProcedure {
    /// Procedure name as supplied by the creator. The registry looks
    /// procedures up by the lower-cased form of this name.
    pub name: String,
    /// When `true`, registering this definition may overwrite an existing
    /// procedure of the same name instead of failing.
    pub or_replace: bool,
    /// Declared parameters, in declaration order.
    pub params: Vec<FunctionParam>,
    /// Source text of the procedure body.
    pub body: String,
    /// Implementation language. The registry executes `"sql"` and `"plpgsql"`
    /// (compared case-insensitively) and rejects anything else.
    pub language: String,
    /// Creation timestamp. NOTE(review): unit and epoch are not established
    /// here; confirm against the code that fills this field.
    pub created_at: u64,
}
51
/// In-memory catalogue of stored functions and stored procedures.
///
/// Functions and procedures live in separate maps, so the same name may be
/// used for one of each. Both maps are keyed by the lower-cased object name
/// and are guarded by their own `RwLock`.
pub struct FunctionRegistry {
    /// Stored functions keyed by lower-cased name.
    functions: Arc<RwLock<HashMap<String, StoredFunction>>>,
    /// Stored procedures keyed by lower-cased name.
    procedures: Arc<RwLock<HashMap<String, StoredProcedure>>>,
}
59
60impl FunctionRegistry {
61 pub fn new() -> Self {
63 Self {
64 functions: Arc::new(RwLock::new(HashMap::new())),
65 procedures: Arc::new(RwLock::new(HashMap::new())),
66 }
67 }
68
69 pub fn register_function(&self, func: StoredFunction) -> Result<()> {
71 let mut functions = self.functions.write()
72 .map_err(|e| Error::internal(format!("Failed to acquire function lock: {}", e)))?;
73
74 let name = func.name.to_lowercase();
75
76 if functions.contains_key(&name) && !func.or_replace {
77 return Err(Error::query_execution(format!(
78 "Function '{}' already exists",
79 func.name
80 )));
81 }
82
83 functions.insert(name, func);
84 Ok(())
85 }
86
87 pub fn register_procedure(&self, proc: StoredProcedure) -> Result<()> {
89 let mut procedures = self.procedures.write()
90 .map_err(|e| Error::internal(format!("Failed to acquire procedure lock: {}", e)))?;
91
92 let name = proc.name.to_lowercase();
93
94 if procedures.contains_key(&name) && !proc.or_replace {
95 return Err(Error::query_execution(format!(
96 "Procedure '{}' already exists",
97 proc.name
98 )));
99 }
100
101 procedures.insert(name, proc);
102 Ok(())
103 }
104
105 pub fn get_function(&self, name: &str) -> Option<StoredFunction> {
107 let functions = self.functions.read().ok()?;
108 functions.get(&name.to_lowercase()).cloned()
109 }
110
111 pub fn get_procedure(&self, name: &str) -> Option<StoredProcedure> {
113 let procedures = self.procedures.read().ok()?;
114 procedures.get(&name.to_lowercase()).cloned()
115 }
116
117 pub fn drop_function(&self, name: &str, if_exists: bool) -> Result<bool> {
119 let mut functions = self.functions.write()
120 .map_err(|e| Error::internal(format!("Failed to acquire function lock: {}", e)))?;
121
122 let name_lower = name.to_lowercase();
123
124 if functions.remove(&name_lower).is_some() {
125 Ok(true)
126 } else if if_exists {
127 Ok(false)
128 } else {
129 Err(Error::query_execution(format!(
130 "Function '{}' does not exist",
131 name
132 )))
133 }
134 }
135
136 pub fn drop_procedure(&self, name: &str, if_exists: bool) -> Result<bool> {
138 let mut procedures = self.procedures.write()
139 .map_err(|e| Error::internal(format!("Failed to acquire procedure lock: {}", e)))?;
140
141 let name_lower = name.to_lowercase();
142
143 if procedures.remove(&name_lower).is_some() {
144 Ok(true)
145 } else if if_exists {
146 Ok(false)
147 } else {
148 Err(Error::query_execution(format!(
149 "Procedure '{}' does not exist",
150 name
151 )))
152 }
153 }
154
155 pub fn function_exists(&self, name: &str) -> bool {
157 self.functions.read()
158 .map(|f| f.contains_key(&name.to_lowercase()))
159 .unwrap_or(false)
160 }
161
162 pub fn procedure_exists(&self, name: &str) -> bool {
164 self.procedures.read()
165 .map(|p| p.contains_key(&name.to_lowercase()))
166 .unwrap_or(false)
167 }
168
169 pub fn list_functions(&self) -> Vec<String> {
171 self.functions.read()
172 .map(|f| f.keys().cloned().collect())
173 .unwrap_or_default()
174 }
175
176 pub fn list_procedures(&self) -> Vec<String> {
178 self.procedures.read()
179 .map(|p| p.keys().cloned().collect())
180 .unwrap_or_default()
181 }
182
183 pub fn execute_function(
185 &self,
186 name: &str,
187 args: &[Value],
188 sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
189 ) -> Result<Value> {
190 let func = self.get_function(name)
191 .ok_or_else(|| Error::query_execution(format!(
192 "Function '{}' does not exist",
193 name
194 )))?;
195
196 let required_params: Vec<_> = func.params.iter()
198 .filter(|p| p.default.is_none() && p.mode != ParamMode::Out)
199 .collect();
200
201 if args.len() < required_params.len() {
202 return Err(Error::query_execution(format!(
203 "Function '{}' requires at least {} arguments, got {}",
204 name, required_params.len(), args.len()
205 )));
206 }
207
208 let max_in_params = func.params.iter()
209 .filter(|p| p.mode != ParamMode::Out)
210 .count();
211
212 if args.len() > max_in_params {
213 return Err(Error::query_execution(format!(
214 "Function '{}' accepts at most {} arguments, got {}",
215 name, max_in_params, args.len()
216 )));
217 }
218
219 match func.language.to_lowercase().as_str() {
221 "sql" => self.execute_sql_function(&func, args, sql_executor),
222 "plpgsql" => self.execute_plpgsql_function(&func, args, sql_executor),
223 lang => Err(Error::query_execution(format!(
224 "Unsupported function language: {}",
225 lang
226 ))),
227 }
228 }
229
230 #[allow(clippy::indexing_slicing)]
233 fn execute_sql_function(
234 &self,
235 func: &StoredFunction,
236 args: &[Value],
237 mut sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
238 ) -> Result<Value> {
239 let mut body = func.body.clone();
242
243 for (i, arg) in args.iter().enumerate() {
244 let placeholder = format!("${}", i + 1);
245 let value_str = value_to_sql_literal(arg);
246 body = body.replace(&placeholder, &value_str);
247 }
248
249 for (i, param) in func.params.iter().enumerate() {
251 if i < args.len() {
252 let value_str = value_to_sql_literal(&args[i]);
253 body = body.replace(&format!("${}", param.name), &value_str);
255 }
256 }
257
258 let results = sql_executor(&body)?;
260
261 if results.is_empty() || results[0].is_empty() {
262 Ok(Value::Null)
263 } else {
264 Ok(results[0][0].clone())
265 }
266 }
267
268 #[allow(clippy::indexing_slicing)]
271 fn execute_plpgsql_function(
272 &self,
273 func: &StoredFunction,
274 args: &[Value],
275 sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
276 ) -> Result<Value> {
277 let mut parser = ProceduralParser::new(&func.body);
279 let block = parser.parse_block()
280 .map_err(|e| Error::query_execution(format!(
281 "Failed to parse function body: {}",
282 e
283 )))?;
284
285 let schema = Arc::new(crate::Schema { columns: vec![] });
287 let evaluator = Evaluator::new(schema);
288 let mut ctx = ExecutionContext::new(&evaluator, sql_executor);
289
290 for (i, param) in func.params.iter().enumerate() {
292 if param.mode == ParamMode::Out {
293 continue;
294 }
295
296 let value = if i < args.len() {
297 args[i].clone()
298 } else if let Some(ref default) = param.default {
299 evaluator.evaluate(default, &crate::Tuple::new(vec![]))?
300 } else {
301 Value::Null
302 };
303
304 ctx.scope.declare(
305 param.name.clone(),
306 super::procedural::Variable {
307 value,
308 data_type: Some(param.data_type.clone()),
309 is_constant: false,
310 not_null: false,
311 },
312 )?;
313 }
314
315 ProceduralExecutor::execute_block(&block, &mut ctx)?;
317
318 Ok(ctx.return_value.unwrap_or(Value::Null))
320 }
321
322 pub fn execute_procedure(
324 &self,
325 name: &str,
326 args: &[Value],
327 sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
328 ) -> Result<()> {
329 let proc = self.get_procedure(name)
330 .ok_or_else(|| Error::query_execution(format!(
331 "Procedure '{}' does not exist",
332 name
333 )))?;
334
335 match proc.language.to_lowercase().as_str() {
337 "sql" => self.execute_sql_procedure(&proc, args, sql_executor),
338 "plpgsql" => self.execute_plpgsql_procedure(&proc, args, sql_executor),
339 lang => Err(Error::query_execution(format!(
340 "Unsupported procedure language: {}",
341 lang
342 ))),
343 }
344 }
345
346 #[allow(clippy::indexing_slicing)]
349 fn execute_sql_procedure(
350 &self,
351 proc: &StoredProcedure,
352 args: &[Value],
353 mut sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
354 ) -> Result<()> {
355 let mut body = proc.body.clone();
356
357 for (i, arg) in args.iter().enumerate() {
358 let placeholder = format!("${}", i + 1);
359 let value_str = value_to_sql_literal(arg);
360 body = body.replace(&placeholder, &value_str);
361 }
362
363 for (i, param) in proc.params.iter().enumerate() {
364 if i < args.len() {
365 let value_str = value_to_sql_literal(&args[i]);
366 body = body.replace(&format!("${}", param.name), &value_str);
367 }
368 }
369
370 sql_executor(&body)?;
371 Ok(())
372 }
373
374 #[allow(clippy::indexing_slicing)]
377 fn execute_plpgsql_procedure(
378 &self,
379 proc: &StoredProcedure,
380 args: &[Value],
381 sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
382 ) -> Result<()> {
383 let mut parser = ProceduralParser::new(&proc.body);
384 let block = parser.parse_block()
385 .map_err(|e| Error::query_execution(format!(
386 "Failed to parse procedure body: {}",
387 e
388 )))?;
389
390 let schema = Arc::new(crate::Schema { columns: vec![] });
391 let evaluator = Evaluator::new(schema);
392 let mut ctx = ExecutionContext::new(&evaluator, sql_executor);
393
394 for (i, param) in proc.params.iter().enumerate() {
395 if param.mode == ParamMode::Out {
396 continue;
397 }
398
399 let value = if i < args.len() {
400 args[i].clone()
401 } else if let Some(ref default) = param.default {
402 evaluator.evaluate(default, &crate::Tuple::new(vec![]))?
403 } else {
404 Value::Null
405 };
406
407 ctx.scope.declare(
408 param.name.clone(),
409 super::procedural::Variable {
410 value,
411 data_type: Some(param.data_type.clone()),
412 is_constant: false,
413 not_null: false,
414 },
415 )?;
416 }
417
418 ProceduralExecutor::execute_block(&block, &mut ctx)?;
419 Ok(())
420 }
421}
422
423impl Default for FunctionRegistry {
424 fn default() -> Self {
425 Self::new()
426 }
427}
428
429fn value_to_sql_literal(value: &Value) -> String {
431 match value {
432 Value::Null => "NULL".to_string(),
433 Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
434 Value::Int2(v) => v.to_string(),
435 Value::Int4(v) => v.to_string(),
436 Value::Int8(v) => v.to_string(),
437 Value::Float4(v) => v.to_string(),
438 Value::Float8(v) => v.to_string(),
439 Value::String(s) => format!("'{}'", s.replace('\'', "''")),
440 Value::Numeric(d) => d.clone(),
441 Value::Date(d) => format!("'{}'", d),
442 Value::Time(t) => format!("'{}'", t),
443 Value::Timestamp(ts) => format!("'{}'", ts),
444 Value::Uuid(u) => format!("'{}'", u),
445 Value::Json(j) => format!("'{}'", j.replace('\'', "''")),
446 Value::Bytes(b) => format!("E'\\\\x{}'", hex::encode(b)),
447 Value::Vector(v) => format!("[{}]", v.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")),
448 Value::Array(arr) => {
449 let elements: Vec<String> = arr.iter().map(value_to_sql_literal).collect();
450 format!("ARRAY[{}]", elements.join(","))
451 }
452 Value::DictRef { dict_id } => format!("'dict:{}'", dict_id),
454 Value::CasRef { hash } => format!("E'\\\\x{}'", hex::encode(hash)),
455 Value::ColumnarRef => "NULL".to_string(), Value::Interval(iv) => format!("INTERVAL '{} microseconds'", iv),
457 }
458}
459
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an `Int4` IN parameter without a default value.
    fn int_param(name: &str) -> FunctionParam {
        FunctionParam {
            name: name.to_string(),
            data_type: DataType::Int4,
            mode: ParamMode::In,
            default: None,
        }
    }

    /// Builds a parameterless `sql` function that returns `Int4`.
    fn sql_function(name: &str, body: &str, or_replace: bool) -> StoredFunction {
        StoredFunction {
            name: name.to_string(),
            or_replace,
            params: Vec::new(),
            return_type: Some(DataType::Int4),
            body: body.to_string(),
            language: "sql".to_string(),
            volatility: None,
            created_at: 0,
        }
    }

    /// Registration succeeds and lookups ignore the case of the name.
    #[test]
    fn test_register_function() {
        let registry = FunctionRegistry::new();
        let add_numbers = StoredFunction {
            params: vec![int_param("a"), int_param("b")],
            volatility: Some("IMMUTABLE".to_string()),
            ..sql_function("add_numbers", "SELECT $1 + $2", false)
        };

        registry.register_function(add_numbers).unwrap();

        for spelling in ["add_numbers", "ADD_NUMBERS"] {
            assert!(registry.function_exists(spelling));
        }
    }

    /// Registering the same name twice without `or_replace` is rejected.
    #[test]
    fn test_duplicate_function_error() {
        let registry = FunctionRegistry::new();

        registry
            .register_function(sql_function("my_func", "SELECT 1", false))
            .unwrap();
        let second_attempt = registry.register_function(sql_function("my_func", "SELECT 1", false));

        assert!(second_attempt.is_err());
    }

    /// `or_replace` overwrites the previously stored definition.
    #[test]
    fn test_or_replace() {
        let registry = FunctionRegistry::new();

        registry
            .register_function(sql_function("my_func", "SELECT 1", false))
            .unwrap();
        registry
            .register_function(sql_function("my_func", "SELECT 2", true))
            .unwrap();

        let stored = registry.get_function("my_func").unwrap();
        assert_eq!(stored.body, "SELECT 2");
    }

    /// A dropped function is no longer reported as existing.
    #[test]
    fn test_drop_function() {
        let registry = FunctionRegistry::new();
        registry
            .register_function(sql_function("to_drop", "SELECT 1", false))
            .unwrap();
        assert!(registry.function_exists("to_drop"));

        registry.drop_function("to_drop", false).unwrap();

        assert!(!registry.function_exists("to_drop"));
    }

    /// The argument literal reaches the executor and its first cell is returned.
    #[test]
    fn test_execute_sql_function() {
        let registry = FunctionRegistry::new();
        let double_it = StoredFunction {
            params: vec![int_param("x")],
            volatility: Some("IMMUTABLE".to_string()),
            ..sql_function("double_it", "SELECT $1 * 2", false)
        };
        registry.register_function(double_it).unwrap();

        let executor = |sql: &str| {
            assert!(sql.contains("21"));
            Ok(vec![vec![Value::Int4(42)]])
        };
        let result = registry
            .execute_function("double_it", &[Value::Int4(21)], executor)
            .unwrap();

        assert_eq!(result, Value::Int4(42));
    }

    /// Spot-checks literal rendering, including quote doubling in strings.
    #[test]
    fn test_value_to_sql_literal() {
        let cases = [
            (Value::Null, "NULL"),
            (Value::Boolean(true), "TRUE"),
            (Value::Int4(42), "42"),
            (Value::String("hello".to_string()), "'hello'"),
            (Value::String("it's".to_string()), "'it''s'"),
        ];

        for (value, expected) in &cases {
            assert_eq!(value_to_sql_literal(value), *expected);
        }
    }
}