// dbx_core/automation/procedure_executor.rs

use std::collections::HashMap;

use crate::automation::StoredProcedure;
use crate::engine::Database;
use crate::error::{DbxError, DbxResult};

10pub struct ProcedureExecutor {
12 procedures: HashMap<String, StoredProcedure>,
14}
15
16impl ProcedureExecutor {
17 pub fn new() -> Self {
19 Self {
20 procedures: HashMap::new(),
21 }
22 }
23
24 pub fn register(&mut self, procedure: StoredProcedure) {
26 self.procedures.insert(procedure.name.clone(), procedure);
27 }
28
29 pub fn register_all(&mut self, procedures: Vec<StoredProcedure>) {
31 for proc in procedures {
32 self.register(proc);
33 }
34 }
35
36 pub fn unregister(&mut self, name: &str) -> bool {
38 self.procedures.remove(name).is_some()
39 }
40
41 pub fn list_procedures(&self) -> Vec<&StoredProcedure> {
43 self.procedures.values().collect()
44 }
45
46 pub fn get(&self, name: &str) -> Option<&StoredProcedure> {
48 self.procedures.get(name)
49 }
50
51 fn validate_parameter_type(sql_type: &str, value: &str) -> Result<(), String> {
55 let sql_type_upper = sql_type.to_uppercase();
56
57 match sql_type_upper.as_str() {
58 "INT" | "INTEGER" | "BIGINT" | "SMALLINT" | "TINYINT" => value
59 .parse::<i64>()
60 .map(|_| ())
61 .map_err(|_| format!("Expected integer, got '{}'", value)),
62 "REAL" | "FLOAT" | "DOUBLE" | "DECIMAL" | "NUMERIC" => value
63 .parse::<f64>()
64 .map(|_| ())
65 .map_err(|_| format!("Expected number, got '{}'", value)),
66 "TEXT" | "VARCHAR" | "CHAR" | "STRING" => {
67 Ok(())
69 }
70 "BOOLEAN" | "BOOL" => {
71 let val_lower = value.to_lowercase();
72 if val_lower == "true"
73 || val_lower == "false"
74 || val_lower == "1"
75 || val_lower == "0"
76 {
77 Ok(())
78 } else {
79 Err(format!(
80 "Expected boolean (true/false/1/0), got '{}'",
81 value
82 ))
83 }
84 }
85 _ => {
86 #[cfg(debug_assertions)]
88 eprintln!(
89 "[Procedure] Unknown parameter type '{}', skipping validation",
90 sql_type
91 );
92 Ok(())
93 }
94 }
95 }
96
97 pub fn execute(&self, db: &Database, name: &str, arguments: &[String]) -> DbxResult<()> {
99 let procedure = self
101 .procedures
102 .get(name)
103 .ok_or_else(|| DbxError::InvalidOperation {
104 message: format!("Procedure '{}' not found", name),
105 context: "CALL PROCEDURE".to_string(),
106 })?;
107
108 if arguments.len() != procedure.parameters.len() {
110 return Err(DbxError::InvalidOperation {
111 message: format!(
112 "Procedure '{}' expects {} arguments, got {}",
113 name,
114 procedure.parameters.len(),
115 arguments.len()
116 ),
117 context: format!("CALL {}", name),
118 });
119 }
120
121 for (i, param) in procedure.parameters.iter().enumerate() {
123 let arg_value = &arguments[i];
124 if let Err(e) = Self::validate_parameter_type(¶m.data_type, arg_value) {
125 return Err(DbxError::InvalidOperation {
126 message: format!(
127 "Procedure '{}' parameter '{}' type mismatch: {}",
128 name, param.name, e
129 ),
130 context: format!("CALL {}", name),
131 });
132 }
133 }
134
135 for sql in &procedure.body {
137 let mut bound_sql = sql.clone();
141 for (i, param) in procedure.parameters.iter().enumerate() {
142 bound_sql = bound_sql.replace(¶m.name, &arguments[i]);
143 }
144
145 match db.execute_sql(&bound_sql) {
147 Ok(_) => {
148 #[cfg(debug_assertions)]
149 println!(
150 "[Procedure] Successfully executed '{}': {}",
151 name, bound_sql
152 );
153 }
154 Err(e) => {
155 return Err(DbxError::InvalidOperation {
157 message: format!("Failed to execute procedure '{}': {}", name, e),
158 context: bound_sql,
159 });
160 }
161 }
162 }
163
164 Ok(())
165 }
166}
167
168impl Default for ProcedureExecutor {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::ProcedureParameter;

    #[test]
    fn test_procedure_executor_register() {
        let mut executor = ProcedureExecutor::new();

        let proc = StoredProcedure::new("test_proc", vec![], vec!["SELECT 1".to_string()]);

        executor.register(proc);
        assert_eq!(executor.list_procedures().len(), 1);
    }

    #[test]
    fn test_procedure_executor_register_replaces_same_name() {
        let mut executor = ProcedureExecutor::new();

        executor.register(StoredProcedure::new(
            "test_proc",
            vec![],
            vec!["SELECT 1".to_string()],
        ));
        executor.register(StoredProcedure::new(
            "test_proc",
            vec![],
            vec!["SELECT 2".to_string()],
        ));

        assert_eq!(executor.list_procedures().len(), 1);
        assert_eq!(
            executor.get("test_proc").unwrap().body,
            vec!["SELECT 2".to_string()]
        );
    }

    #[test]
    fn test_procedure_executor_register_all() {
        let mut executor = ProcedureExecutor::new();

        executor.register_all(vec![
            StoredProcedure::new("proc_a", vec![], vec!["SELECT 1".to_string()]),
            StoredProcedure::new("proc_b", vec![], vec!["SELECT 2".to_string()]),
        ]);

        assert_eq!(executor.list_procedures().len(), 2);
        assert!(executor.get("proc_a").is_some());
        assert!(executor.get("proc_b").is_some());
    }

    #[test]
    fn test_procedure_executor_unregister() {
        let mut executor = ProcedureExecutor::new();

        let proc = StoredProcedure::new("test_proc", vec![], vec!["SELECT 1".to_string()]);

        executor.register(proc);
        assert_eq!(executor.list_procedures().len(), 1);

        let removed = executor.unregister("test_proc");
        assert!(removed);
        assert_eq!(executor.list_procedures().len(), 0);
    }

    #[test]
    fn test_procedure_executor_missing_name() {
        let mut executor = ProcedureExecutor::new();

        assert!(executor.get("missing").is_none());
        assert!(!executor.unregister("missing"));
    }

    #[test]
    fn test_procedure_executor_get() {
        let mut executor = ProcedureExecutor::new();

        let proc = StoredProcedure::new(
            "test_proc",
            vec![ProcedureParameter {
                name: "user_id".to_string(),
                data_type: "INT".to_string(),
            }],
            vec!["UPDATE users SET active = 1 WHERE id = user_id".to_string()],
        );

        executor.register(proc);

        let found = executor.get("test_proc");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "test_proc");
    }

    #[test]
    fn test_validate_parameter_type() {
        // Integers, case-insensitive type names.
        assert!(ProcedureExecutor::validate_parameter_type("INT", "42").is_ok());
        assert!(ProcedureExecutor::validate_parameter_type("integer", "-7").is_ok());
        assert!(ProcedureExecutor::validate_parameter_type("INT", "4.2").is_err());

        // Floating-point numbers.
        assert!(ProcedureExecutor::validate_parameter_type("REAL", "3.5").is_ok());
        assert!(ProcedureExecutor::validate_parameter_type("FLOAT", "abc").is_err());

        // Text accepts anything.
        assert!(ProcedureExecutor::validate_parameter_type("TEXT", "anything at all").is_ok());

        // Booleans.
        assert!(ProcedureExecutor::validate_parameter_type("BOOLEAN", "TRUE").is_ok());
        assert!(ProcedureExecutor::validate_parameter_type("BOOL", "0").is_ok());
        assert!(ProcedureExecutor::validate_parameter_type("BOOL", "yes").is_err());

        // Unknown types are not validated.
        assert!(ProcedureExecutor::validate_parameter_type("BLOB", "whatever").is_ok());
    }
}