Skip to main content

heliosdb_nano/sql/
functions.rs

1//! Function Registry and Execution
2//!
3//! This module provides storage and execution of user-defined functions and procedures.
4//! Functions are stored in-memory and can be called from SQL queries.
5
6use crate::{Result, Error, Value, DataType};
7use super::logical_plan::{FunctionParam, ParamMode};
8use super::procedural::{ProceduralParser, ProceduralExecutor, ExecutionContext};
9use super::evaluator::Evaluator;
10use serde::{Serialize, Deserialize};
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
/// Definition of a user-defined function as held by [`FunctionRegistry`].
///
/// This is a plain data record: it describes the function but does not run
/// it. `FunctionRegistry::execute_function` looks a definition up by name and
/// dispatches on [`language`](Self::language).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredFunction {
    /// Function name. The registry lowercases it to build the lookup key, so
    /// lookups are case-insensitive; the original spelling is kept here and
    /// used in error messages.
    pub name: String,
    /// When `true`, registering this definition may overwrite an existing
    /// function with the same (lowercased) name; when `false`, registering a
    /// duplicate is rejected with an error.
    pub or_replace: bool,
    /// Declared parameters. Parameters whose mode is `ParamMode::Out` are not
    /// counted as call arguments when `execute_function` validates the
    /// argument count.
    pub params: Vec<FunctionParam>,
    /// Declared return type, if any. Not consulted by the execution code
    /// visible in this module.
    pub return_type: Option<DataType>,
    /// Function body (raw source). For `sql` functions it is used as a SQL
    /// template with `$n` / `$name` placeholders; for `plpgsql` functions it
    /// is handed to `ProceduralParser`.
    pub body: String,
    /// Language (plpgsql, sql). Compared case-insensitively at execution
    /// time; any other value is rejected as unsupported.
    pub language: String,
    /// Volatility (IMMUTABLE, STABLE, VOLATILE). Stored as free-form text and
    /// not interpreted by the code visible in this module.
    pub volatility: Option<String>,
    /// Creation timestamp.
    /// NOTE(review): unit and epoch are not established here — confirm with
    /// the code that populates this field.
    pub created_at: u64,
}
34
/// Definition of a user-defined procedure as held by [`FunctionRegistry`].
///
/// Mirrors [`StoredFunction`] but carries no return type or volatility;
/// `FunctionRegistry::execute_procedure` runs it purely for its effects and
/// returns `()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredProcedure {
    /// Procedure name. The registry lowercases it to build the lookup key, so
    /// lookups are case-insensitive; the original spelling is kept here and
    /// used in error messages.
    pub name: String,
    /// When `true`, registering this definition may overwrite an existing
    /// procedure with the same (lowercased) name; when `false`, registering a
    /// duplicate is rejected with an error.
    pub or_replace: bool,
    /// Declared parameters.
    pub params: Vec<FunctionParam>,
    /// Procedure body (raw source). For `sql` procedures it is used as a SQL
    /// template with `$n` / `$name` placeholders; for `plpgsql` procedures it
    /// is handed to `ProceduralParser`.
    pub body: String,
    /// Language (plpgsql, sql). Compared case-insensitively at execution
    /// time; any other value is rejected as unsupported.
    pub language: String,
    /// Creation timestamp.
    /// NOTE(review): unit and epoch are not established here — confirm with
    /// the code that populates this field.
    pub created_at: u64,
}
51
/// Registry for user-defined functions and procedures.
///
/// Functions and procedures live in two independent maps, so a function and a
/// procedure may share a name. Both maps are keyed by the lowercased name and
/// each sits behind its own `RwLock`, which lets every method take `&self`.
pub struct FunctionRegistry {
    /// Stored functions, keyed by lowercased function name.
    functions: Arc<RwLock<HashMap<String, StoredFunction>>>,
    /// Stored procedures, keyed by lowercased procedure name.
    procedures: Arc<RwLock<HashMap<String, StoredProcedure>>>,
}
59
60impl FunctionRegistry {
61    /// Create a new function registry
62    pub fn new() -> Self {
63        Self {
64            functions: Arc::new(RwLock::new(HashMap::new())),
65            procedures: Arc::new(RwLock::new(HashMap::new())),
66        }
67    }
68
69    /// Register a function
70    pub fn register_function(&self, func: StoredFunction) -> Result<()> {
71        let mut functions = self.functions.write()
72            .map_err(|e| Error::internal(format!("Failed to acquire function lock: {}", e)))?;
73
74        let name = func.name.to_lowercase();
75
76        if functions.contains_key(&name) && !func.or_replace {
77            return Err(Error::query_execution(format!(
78                "Function '{}' already exists",
79                func.name
80            )));
81        }
82
83        functions.insert(name, func);
84        Ok(())
85    }
86
87    /// Register a procedure
88    pub fn register_procedure(&self, proc: StoredProcedure) -> Result<()> {
89        let mut procedures = self.procedures.write()
90            .map_err(|e| Error::internal(format!("Failed to acquire procedure lock: {}", e)))?;
91
92        let name = proc.name.to_lowercase();
93
94        if procedures.contains_key(&name) && !proc.or_replace {
95            return Err(Error::query_execution(format!(
96                "Procedure '{}' already exists",
97                proc.name
98            )));
99        }
100
101        procedures.insert(name, proc);
102        Ok(())
103    }
104
105    /// Get a function by name
106    pub fn get_function(&self, name: &str) -> Option<StoredFunction> {
107        let functions = self.functions.read().ok()?;
108        functions.get(&name.to_lowercase()).cloned()
109    }
110
111    /// Get a procedure by name
112    pub fn get_procedure(&self, name: &str) -> Option<StoredProcedure> {
113        let procedures = self.procedures.read().ok()?;
114        procedures.get(&name.to_lowercase()).cloned()
115    }
116
117    /// Drop a function
118    pub fn drop_function(&self, name: &str, if_exists: bool) -> Result<bool> {
119        let mut functions = self.functions.write()
120            .map_err(|e| Error::internal(format!("Failed to acquire function lock: {}", e)))?;
121
122        let name_lower = name.to_lowercase();
123
124        if functions.remove(&name_lower).is_some() {
125            Ok(true)
126        } else if if_exists {
127            Ok(false)
128        } else {
129            Err(Error::query_execution(format!(
130                "Function '{}' does not exist",
131                name
132            )))
133        }
134    }
135
136    /// Drop a procedure
137    pub fn drop_procedure(&self, name: &str, if_exists: bool) -> Result<bool> {
138        let mut procedures = self.procedures.write()
139            .map_err(|e| Error::internal(format!("Failed to acquire procedure lock: {}", e)))?;
140
141        let name_lower = name.to_lowercase();
142
143        if procedures.remove(&name_lower).is_some() {
144            Ok(true)
145        } else if if_exists {
146            Ok(false)
147        } else {
148            Err(Error::query_execution(format!(
149                "Procedure '{}' does not exist",
150                name
151            )))
152        }
153    }
154
155    /// Check if a function exists
156    pub fn function_exists(&self, name: &str) -> bool {
157        self.functions.read()
158            .map(|f| f.contains_key(&name.to_lowercase()))
159            .unwrap_or(false)
160    }
161
162    /// Check if a procedure exists
163    pub fn procedure_exists(&self, name: &str) -> bool {
164        self.procedures.read()
165            .map(|p| p.contains_key(&name.to_lowercase()))
166            .unwrap_or(false)
167    }
168
169    /// List all function names
170    pub fn list_functions(&self) -> Vec<String> {
171        self.functions.read()
172            .map(|f| f.keys().cloned().collect())
173            .unwrap_or_default()
174    }
175
176    /// List all procedure names
177    pub fn list_procedures(&self) -> Vec<String> {
178        self.procedures.read()
179            .map(|p| p.keys().cloned().collect())
180            .unwrap_or_default()
181    }
182
183    /// Execute a stored function with arguments
184    pub fn execute_function(
185        &self,
186        name: &str,
187        args: &[Value],
188        sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
189    ) -> Result<Value> {
190        let func = self.get_function(name)
191            .ok_or_else(|| Error::query_execution(format!(
192                "Function '{}' does not exist",
193                name
194            )))?;
195
196        // Validate argument count
197        let required_params: Vec<_> = func.params.iter()
198            .filter(|p| p.default.is_none() && p.mode != ParamMode::Out)
199            .collect();
200
201        if args.len() < required_params.len() {
202            return Err(Error::query_execution(format!(
203                "Function '{}' requires at least {} arguments, got {}",
204                name, required_params.len(), args.len()
205            )));
206        }
207
208        let max_in_params = func.params.iter()
209            .filter(|p| p.mode != ParamMode::Out)
210            .count();
211
212        if args.len() > max_in_params {
213            return Err(Error::query_execution(format!(
214                "Function '{}' accepts at most {} arguments, got {}",
215                name, max_in_params, args.len()
216            )));
217        }
218
219        // Execute based on language
220        match func.language.to_lowercase().as_str() {
221            "sql" => self.execute_sql_function(&func, args, sql_executor),
222            "plpgsql" => self.execute_plpgsql_function(&func, args, sql_executor),
223            lang => Err(Error::query_execution(format!(
224                "Unsupported function language: {}",
225                lang
226            ))),
227        }
228    }
229
230    /// Execute a SQL language function
231    // SAFETY: args[i] guarded by i < args.len(); results[0][0] guarded by is_empty() checks.
232    #[allow(clippy::indexing_slicing)]
233    fn execute_sql_function(
234        &self,
235        func: &StoredFunction,
236        args: &[Value],
237        mut sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
238    ) -> Result<Value> {
239        // For SQL functions, the body is raw SQL
240        // Replace $1, $2, etc. with actual argument values
241        let mut body = func.body.clone();
242
243        for (i, arg) in args.iter().enumerate() {
244            let placeholder = format!("${}", i + 1);
245            let value_str = value_to_sql_literal(arg);
246            body = body.replace(&placeholder, &value_str);
247        }
248
249        // Also replace named parameters
250        for (i, param) in func.params.iter().enumerate() {
251            if i < args.len() {
252                let value_str = value_to_sql_literal(&args[i]);
253                // Replace both $name and name patterns
254                body = body.replace(&format!("${}", param.name), &value_str);
255            }
256        }
257
258        // Execute the SQL and get the result
259        let results = sql_executor(&body)?;
260
261        if results.is_empty() || results[0].is_empty() {
262            Ok(Value::Null)
263        } else {
264            Ok(results[0][0].clone())
265        }
266    }
267
268    /// Execute a PL/pgSQL function
269    // SAFETY: args[i] guarded by i < args.len() check.
270    #[allow(clippy::indexing_slicing)]
271    fn execute_plpgsql_function(
272        &self,
273        func: &StoredFunction,
274        args: &[Value],
275        sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
276    ) -> Result<Value> {
277        // Parse the function body into a procedural block
278        let mut parser = ProceduralParser::new(&func.body);
279        let block = parser.parse_block()
280            .map_err(|e| Error::query_execution(format!(
281                "Failed to parse function body: {}",
282                e
283            )))?;
284
285        // Create execution context
286        let schema = Arc::new(crate::Schema { columns: vec![] });
287        let evaluator = Evaluator::new(schema);
288        let mut ctx = ExecutionContext::new(&evaluator, sql_executor);
289
290        // Bind parameters to context
291        for (i, param) in func.params.iter().enumerate() {
292            if param.mode == ParamMode::Out {
293                continue;
294            }
295
296            let value = if i < args.len() {
297                args[i].clone()
298            } else if let Some(ref default) = param.default {
299                evaluator.evaluate(default, &crate::Tuple::new(vec![]))?
300            } else {
301                Value::Null
302            };
303
304            ctx.scope.declare(
305                param.name.clone(),
306                super::procedural::Variable {
307                    value,
308                    data_type: Some(param.data_type.clone()),
309                    is_constant: false,
310                    not_null: false,
311                },
312            )?;
313        }
314
315        // Execute the block
316        ProceduralExecutor::execute_block(&block, &mut ctx)?;
317
318        // Return the result
319        Ok(ctx.return_value.unwrap_or(Value::Null))
320    }
321
322    /// Execute a stored procedure with arguments
323    pub fn execute_procedure(
324        &self,
325        name: &str,
326        args: &[Value],
327        sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
328    ) -> Result<()> {
329        let proc = self.get_procedure(name)
330            .ok_or_else(|| Error::query_execution(format!(
331                "Procedure '{}' does not exist",
332                name
333            )))?;
334
335        // Execute based on language
336        match proc.language.to_lowercase().as_str() {
337            "sql" => self.execute_sql_procedure(&proc, args, sql_executor),
338            "plpgsql" => self.execute_plpgsql_procedure(&proc, args, sql_executor),
339            lang => Err(Error::query_execution(format!(
340                "Unsupported procedure language: {}",
341                lang
342            ))),
343        }
344    }
345
346    /// Execute a SQL language procedure
347    // SAFETY: args[i] guarded by i < args.len() check.
348    #[allow(clippy::indexing_slicing)]
349    fn execute_sql_procedure(
350        &self,
351        proc: &StoredProcedure,
352        args: &[Value],
353        mut sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
354    ) -> Result<()> {
355        let mut body = proc.body.clone();
356
357        for (i, arg) in args.iter().enumerate() {
358            let placeholder = format!("${}", i + 1);
359            let value_str = value_to_sql_literal(arg);
360            body = body.replace(&placeholder, &value_str);
361        }
362
363        for (i, param) in proc.params.iter().enumerate() {
364            if i < args.len() {
365                let value_str = value_to_sql_literal(&args[i]);
366                body = body.replace(&format!("${}", param.name), &value_str);
367            }
368        }
369
370        sql_executor(&body)?;
371        Ok(())
372    }
373
374    /// Execute a PL/pgSQL procedure
375    // SAFETY: args[i] guarded by i < args.len() check.
376    #[allow(clippy::indexing_slicing)]
377    fn execute_plpgsql_procedure(
378        &self,
379        proc: &StoredProcedure,
380        args: &[Value],
381        sql_executor: impl FnMut(&str) -> Result<Vec<Vec<Value>>>,
382    ) -> Result<()> {
383        let mut parser = ProceduralParser::new(&proc.body);
384        let block = parser.parse_block()
385            .map_err(|e| Error::query_execution(format!(
386                "Failed to parse procedure body: {}",
387                e
388            )))?;
389
390        let schema = Arc::new(crate::Schema { columns: vec![] });
391        let evaluator = Evaluator::new(schema);
392        let mut ctx = ExecutionContext::new(&evaluator, sql_executor);
393
394        for (i, param) in proc.params.iter().enumerate() {
395            if param.mode == ParamMode::Out {
396                continue;
397            }
398
399            let value = if i < args.len() {
400                args[i].clone()
401            } else if let Some(ref default) = param.default {
402                evaluator.evaluate(default, &crate::Tuple::new(vec![]))?
403            } else {
404                Value::Null
405            };
406
407            ctx.scope.declare(
408                param.name.clone(),
409                super::procedural::Variable {
410                    value,
411                    data_type: Some(param.data_type.clone()),
412                    is_constant: false,
413                    not_null: false,
414                },
415            )?;
416        }
417
418        ProceduralExecutor::execute_block(&block, &mut ctx)?;
419        Ok(())
420    }
421}
422
423impl Default for FunctionRegistry {
424    fn default() -> Self {
425        Self::new()
426    }
427}
428
429/// Convert a Value to a SQL literal string
430fn value_to_sql_literal(value: &Value) -> String {
431    match value {
432        Value::Null => "NULL".to_string(),
433        Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
434        Value::Int2(v) => v.to_string(),
435        Value::Int4(v) => v.to_string(),
436        Value::Int8(v) => v.to_string(),
437        Value::Float4(v) => v.to_string(),
438        Value::Float8(v) => v.to_string(),
439        Value::String(s) => format!("'{}'", s.replace('\'', "''")),
440        Value::Numeric(d) => d.clone(),
441        Value::Date(d) => format!("'{}'", d),
442        Value::Time(t) => format!("'{}'", t),
443        Value::Timestamp(ts) => format!("'{}'", ts),
444        Value::Uuid(u) => format!("'{}'", u),
445        Value::Json(j) => format!("'{}'", j.replace('\'', "''")),
446        Value::Bytes(b) => format!("E'\\\\x{}'", hex::encode(b)),
447        Value::Vector(v) => format!("[{}]", v.iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",")),
448        Value::Array(arr) => {
449            let elements: Vec<String> = arr.iter().map(value_to_sql_literal).collect();
450            format!("ARRAY[{}]", elements.join(","))
451        }
452        // Storage references (should be resolved before reaching here)
453        Value::DictRef { dict_id } => format!("'dict:{}'", dict_id),
454        Value::CasRef { hash } => format!("E'\\\\x{}'", hex::encode(hash)),
455        Value::ColumnarRef => "NULL".to_string(), // Placeholder
456        Value::Interval(iv) => format!("INTERVAL '{} microseconds'", iv),
457    }
458}
459
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an `IN` parameter of type `Int4` with no default.
    fn int4_param(name: &str) -> FunctionParam {
        FunctionParam {
            name: name.to_string(),
            data_type: DataType::Int4,
            mode: ParamMode::In,
            default: None,
        }
    }

    /// Build a SQL-language function returning `Int4`, created at time 0.
    fn sql_function(
        name: &str,
        body: &str,
        or_replace: bool,
        params: Vec<FunctionParam>,
        volatility: Option<&str>,
    ) -> StoredFunction {
        StoredFunction {
            name: name.to_string(),
            or_replace,
            params,
            return_type: Some(DataType::Int4),
            body: body.to_string(),
            language: "sql".to_string(),
            volatility: volatility.map(str::to_string),
            created_at: 0,
        }
    }

    #[test]
    fn test_register_function() {
        let registry = FunctionRegistry::new();
        let func = sql_function(
            "add_numbers",
            "SELECT $1 + $2",
            false,
            vec![int4_param("a"), int4_param("b")],
            Some("IMMUTABLE"),
        );

        registry.register_function(func).unwrap();

        // Lookup ignores case.
        for spelling in ["add_numbers", "ADD_NUMBERS"] {
            assert!(registry.function_exists(spelling));
        }
    }

    #[test]
    fn test_duplicate_function_error() {
        let registry = FunctionRegistry::new();
        let func = sql_function("my_func", "SELECT 1", false, vec![], None);

        registry.register_function(func.clone()).unwrap();

        // Registering the same name again without OR REPLACE must fail.
        assert!(registry.register_function(func).is_err());
    }

    #[test]
    fn test_or_replace() {
        let registry = FunctionRegistry::new();

        let original = sql_function("my_func", "SELECT 1", false, vec![], None);
        registry.register_function(original).unwrap();

        // OR REPLACE lets the second definition overwrite the first.
        let replacement = sql_function("my_func", "SELECT 2", true, vec![], None);
        registry.register_function(replacement).unwrap();

        let stored = registry.get_function("my_func").unwrap();
        assert_eq!(stored.body, "SELECT 2");
    }

    #[test]
    fn test_drop_function() {
        let registry = FunctionRegistry::new();
        let func = sql_function("to_drop", "SELECT 1", false, vec![], None);

        registry.register_function(func).unwrap();
        assert!(registry.function_exists("to_drop"));

        registry.drop_function("to_drop", false).unwrap();
        assert!(!registry.function_exists("to_drop"));
    }

    #[test]
    fn test_execute_sql_function() {
        let registry = FunctionRegistry::new();
        let func = sql_function(
            "double_it",
            "SELECT $1 * 2",
            false,
            vec![int4_param("x")],
            Some("IMMUTABLE"),
        );
        registry.register_function(func).unwrap();

        // Mock executor: verifies the substituted SQL and returns a canned row.
        let mock_executor = |sql: &str| -> Result<Vec<Vec<Value>>> {
            // The SQL should be "SELECT 21 * 2"
            assert!(sql.contains("21"));
            Ok(vec![vec![Value::Int4(42)]])
        };

        let result = registry
            .execute_function("double_it", &[Value::Int4(21)], mock_executor)
            .unwrap();

        assert_eq!(result, Value::Int4(42));
    }

    #[test]
    fn test_value_to_sql_literal() {
        let cases = [
            (Value::Null, "NULL"),
            (Value::Boolean(true), "TRUE"),
            (Value::Int4(42), "42"),
            (Value::String("hello".to_string()), "'hello'"),
            (Value::String("it's".to_string()), "'it''s'"),
        ];

        for (value, expected) in cases.iter() {
            assert_eq!(value_to_sql_literal(value), *expected);
        }
    }
}
623}