1use std::borrow::Cow;
48use std::collections::HashMap;
49use std::marker::PhantomData;
50
51use serde::{Deserialize, Serialize};
52
53use crate::error::{QueryError, QueryResult};
54use crate::filter::FilterValue;
55use crate::sql::DatabaseType;
56use crate::traits::{BoxFuture, QueryEngine};
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum ParameterMode {
61 In,
63 Out,
65 InOut,
67}
68
69impl Default for ParameterMode {
70 fn default() -> Self {
71 Self::In
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct Parameter {
78 pub name: String,
80 pub value: Option<FilterValue>,
82 pub mode: ParameterMode,
84 pub type_hint: Option<String>,
86}
87
88impl Parameter {
89 pub fn input(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
91 Self {
92 name: name.into(),
93 value: Some(value.into()),
94 mode: ParameterMode::In,
95 type_hint: None,
96 }
97 }
98
99 pub fn output(name: impl Into<String>) -> Self {
101 Self {
102 name: name.into(),
103 value: None,
104 mode: ParameterMode::Out,
105 type_hint: None,
106 }
107 }
108
109 pub fn inout(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
111 Self {
112 name: name.into(),
113 value: Some(value.into()),
114 mode: ParameterMode::InOut,
115 type_hint: None,
116 }
117 }
118
119 pub fn with_type_hint(mut self, type_name: impl Into<String>) -> Self {
121 self.type_hint = Some(type_name.into());
122 self
123 }
124}
125
126#[derive(Debug, Clone, Default)]
128pub struct ProcedureResult {
129 pub outputs: HashMap<String, FilterValue>,
131 pub return_value: Option<FilterValue>,
133 pub rows_affected: Option<u64>,
135}
136
137impl ProcedureResult {
138 pub fn get(&self, name: &str) -> Option<&FilterValue> {
140 self.outputs.get(name)
141 }
142
143 pub fn get_as<T>(&self, name: &str) -> Option<T>
145 where
146 T: TryFrom<FilterValue>,
147 {
148 self.outputs.get(name).and_then(|v| T::try_from(v.clone()).ok())
149 }
150
151 pub fn return_value(&self) -> Option<&FilterValue> {
153 self.return_value.as_ref()
154 }
155
156 pub fn return_value_as<T>(&self) -> Option<T>
158 where
159 T: TryFrom<FilterValue>,
160 {
161 self.return_value.clone().and_then(|v| T::try_from(v).ok())
162 }
163}
164
165#[derive(Debug, Clone)]
167pub struct ProcedureCall {
168 pub name: String,
170 pub schema: Option<String>,
172 pub parameters: Vec<Parameter>,
174 pub db_type: DatabaseType,
176 pub is_function: bool,
178}
179
180impl ProcedureCall {
181 pub fn new(name: impl Into<String>) -> Self {
183 Self {
184 name: name.into(),
185 schema: None,
186 parameters: Vec::new(),
187 db_type: DatabaseType::PostgreSQL,
188 is_function: false,
189 }
190 }
191
192 pub fn function(name: impl Into<String>) -> Self {
194 Self {
195 name: name.into(),
196 schema: None,
197 parameters: Vec::new(),
198 db_type: DatabaseType::PostgreSQL,
199 is_function: true,
200 }
201 }
202
203 pub fn schema(mut self, schema: impl Into<String>) -> Self {
205 self.schema = Some(schema.into());
206 self
207 }
208
209 pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
211 self.db_type = db_type;
212 self
213 }
214
215 pub fn param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
217 self.parameters.push(Parameter::input(name, value));
218 self
219 }
220
221 pub fn in_param(self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
223 self.param(name, value)
224 }
225
226 pub fn out_param(mut self, name: impl Into<String>) -> Self {
228 self.parameters.push(Parameter::output(name));
229 self
230 }
231
232 pub fn out_param_typed(mut self, name: impl Into<String>, type_hint: impl Into<String>) -> Self {
234 self.parameters
235 .push(Parameter::output(name).with_type_hint(type_hint));
236 self
237 }
238
239 pub fn inout_param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
241 self.parameters.push(Parameter::inout(name, value));
242 self
243 }
244
245 pub fn add_parameter(mut self, param: Parameter) -> Self {
247 self.parameters.push(param);
248 self
249 }
250
251 pub fn qualified_name(&self) -> Cow<'_, str> {
253 match &self.schema {
254 Some(schema) => Cow::Owned(format!("{}.{}", schema, self.name)),
255 None => Cow::Borrowed(&self.name),
256 }
257 }
258
259 pub fn has_outputs(&self) -> bool {
261 self.parameters
262 .iter()
263 .any(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
264 }
265
266 pub fn input_values(&self) -> Vec<FilterValue> {
268 self.parameters
269 .iter()
270 .filter(|p| matches!(p.mode, ParameterMode::In | ParameterMode::InOut))
271 .filter_map(|p| p.value.clone())
272 .collect()
273 }
274
275 pub fn to_postgres_sql(&self) -> (String, Vec<FilterValue>) {
277 let name = self.qualified_name();
278 let params = self.input_values();
279 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("${}", i)).collect();
280
281 let sql = if self.is_function {
282 format!("SELECT {}({})", name, placeholders.join(", "))
283 } else {
284 format!("CALL {}({})", name, placeholders.join(", "))
285 };
286
287 (sql, params)
288 }
289
290 pub fn to_mysql_sql(&self) -> (String, Vec<FilterValue>) {
292 let name = self.qualified_name();
293 let params = self.input_values();
294 let placeholders = vec!["?"; params.len()].join(", ");
295
296 let sql = if self.is_function {
297 format!("SELECT {}({})", name, placeholders)
298 } else {
299 format!("CALL {}({})", name, placeholders)
300 };
301
302 (sql, params)
303 }
304
305 pub fn to_mssql_sql(&self) -> (String, Vec<FilterValue>) {
307 let name = self.qualified_name();
308 let params = self.input_values();
309 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("@P{}", i)).collect();
310
311 if self.is_function {
312 (format!("SELECT {}({})", name, placeholders.join(", ")), params)
313 } else if self.has_outputs() {
314 let mut parts = vec![String::from("DECLARE ")];
316
317 let out_params: Vec<_> = self
319 .parameters
320 .iter()
321 .filter(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
322 .collect();
323
324 for (i, param) in out_params.iter().enumerate() {
325 if i > 0 {
326 parts.push(String::from(", "));
327 }
328 let type_name = param.type_hint.as_deref().unwrap_or("SQL_VARIANT");
329 parts.push(format!("@{} {}", param.name, type_name));
330 }
331 parts.push(String::from("; "));
332
333 parts.push(format!("EXEC {} ", name));
335
336 let param_parts: Vec<String> = self
337 .parameters
338 .iter()
339 .enumerate()
340 .map(|(i, p)| match p.mode {
341 ParameterMode::In => format!("@P{}", i + 1),
342 ParameterMode::Out => format!("@{} OUTPUT", p.name),
343 ParameterMode::InOut => format!("@P{} = @{} OUTPUT", i + 1, p.name),
344 })
345 .collect();
346
347 parts.push(param_parts.join(", "));
348 parts.push(String::from("; "));
349
350 let select_parts: Vec<String> = out_params
352 .iter()
353 .map(|p| format!("@{} AS {}", p.name, p.name))
354 .collect();
355 parts.push(format!("SELECT {}", select_parts.join(", ")));
356
357 (parts.join(""), params)
358 } else {
359 (format!("EXEC {} {}", name, placeholders.join(", ")), params)
360 }
361 }
362
363 pub fn to_sqlite_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
365 if !self.is_function {
366 return Err(QueryError::unsupported(
367 "SQLite does not support stored procedures. Use Rust UDFs instead.",
368 ));
369 }
370
371 let name = self.qualified_name();
372 let params = self.input_values();
373 let placeholders = vec!["?"; params.len()].join(", ");
374
375 Ok((format!("SELECT {}({})", name, placeholders), params))
376 }
377
378 pub fn to_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
380 match self.db_type {
381 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
382 DatabaseType::MySQL => Ok(self.to_mysql_sql()),
383 DatabaseType::SQLite => self.to_sqlite_sql(),
384 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
385 }
386 }
387}
388
389pub struct ProcedureCallOperation<E: QueryEngine> {
391 engine: E,
392 call: ProcedureCall,
393}
394
395impl<E: QueryEngine> ProcedureCallOperation<E> {
396 pub fn new(engine: E, call: ProcedureCall) -> Self {
398 Self { engine, call }
399 }
400
401 pub async fn exec(self) -> QueryResult<ProcedureResult> {
403 let (sql, params) = self.call.to_sql()?;
404 let affected = self.engine.execute_raw(&sql, params).await?;
405
406 Ok(ProcedureResult {
407 outputs: HashMap::new(),
408 return_value: None,
409 rows_affected: Some(affected),
410 })
411 }
412
413 pub async fn exec_returning<T>(self) -> QueryResult<Vec<T>>
415 where
416 T: crate::traits::Model + Send + 'static,
417 {
418 let (sql, params) = self.call.to_sql()?;
419 self.engine.query_many(&sql, params).await
420 }
421
422 pub async fn exec_scalar<T>(self) -> QueryResult<T>
424 where
425 T: TryFrom<FilterValue, Error = String> + Send + 'static,
426 {
427 let (sql, params) = self.call.to_sql()?;
428 let result = self.engine.execute_raw(&sql, params).await?;
429
430 Err(QueryError::internal(format!(
433 "Scalar function result parsing not yet implemented (affected: {})",
434 result
435 )))
436 }
437}
438
439#[allow(dead_code)]
441pub struct FunctionCallOperation<E: QueryEngine, T> {
442 engine: E,
443 call: ProcedureCall,
444 _marker: PhantomData<T>,
445}
446
447impl<E: QueryEngine, T> FunctionCallOperation<E, T> {
448 pub fn new(engine: E, call: ProcedureCall) -> Self {
450 Self {
451 engine,
452 call,
453 _marker: PhantomData,
454 }
455 }
456}
457
458pub trait ProcedureEngine: QueryEngine {
460 fn call(&self, name: impl Into<String>) -> ProcedureCall {
462 ProcedureCall::new(name)
463 }
464
465 fn function(&self, name: impl Into<String>) -> ProcedureCall {
467 ProcedureCall::function(name)
468 }
469
470 fn execute_procedure(&self, call: ProcedureCall) -> BoxFuture<'_, QueryResult<ProcedureResult>>
472 where
473 Self: Clone + 'static,
474 {
475 let engine = self.clone();
476 Box::pin(async move {
477 let op = ProcedureCallOperation::new(engine, call);
478 op.exec().await
479 })
480 }
481}
482
483impl<T: QueryEngine + Clone + 'static> ProcedureEngine for T {}
485
/// Descriptors for SQLite user-defined functions (UDFs).
///
/// SQLite has no stored procedures, so custom logic is exposed as UDFs instead.
/// The types here only describe a function's name, arity and flags; they do not
/// register anything themselves.
pub mod sqlite_udf {
    /// Metadata exposed by a SQLite UDF implementation.
    pub trait SqliteFunction: Send + Sync + 'static {
        /// Name the function is called by in SQL.
        fn name(&self) -> &str;

        /// Number of arguments the function takes.
        // NOTE(review): `i32` presumably mirrors SQLite's `nArg`, where -1 means
        // "any number of arguments" — confirm against the registration code.
        fn num_args(&self) -> i32;

        /// Whether the function always yields the same result for the same
        /// arguments. Defaults to `true`.
        fn deterministic(&self) -> bool {
            true
        }
    }

    /// Descriptor for a scalar UDF.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct ScalarUdf {
        /// SQL-visible function name.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
        /// Deterministic flag; `true` by default.
        pub deterministic: bool,
    }

    impl ScalarUdf {
        /// Creates a deterministic scalar UDF descriptor.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
                deterministic: true,
            }
        }

        /// Overrides the deterministic flag.
        pub fn deterministic(mut self, deterministic: bool) -> Self {
            self.deterministic = deterministic;
            self
        }
    }

    /// Descriptor for an aggregate UDF.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct AggregateUdf {
        /// SQL-visible function name.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
    }

    impl AggregateUdf {
        /// Creates an aggregate UDF descriptor.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }

    /// Descriptor for a window UDF.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct WindowUdf {
        /// SQL-visible function name.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
    }

    impl WindowUdf {
        /// Creates a window UDF descriptor.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }
}
571
572pub mod mongodb_func {
574 use super::*;
575
576 #[derive(Debug, Clone, Serialize, Deserialize)]
578 pub struct MongoFunction {
579 pub body: String,
581 pub args: Vec<String>,
583 pub lang: String,
585 }
586
587 impl MongoFunction {
588 pub fn new(body: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
590 Self {
591 body: body.into(),
592 args: args.into_iter().map(Into::into).collect(),
593 lang: "js".to_string(),
594 }
595 }
596
597 #[cfg(feature = "mongodb")]
599 pub fn to_bson(&self) -> bson::Document {
600 use bson::doc;
601 doc! {
602 "$function": {
603 "body": &self.body,
604 "args": &self.args,
605 "lang": &self.lang,
606 }
607 }
608 }
609 }
610
611 #[derive(Debug, Clone, Serialize, Deserialize)]
613 pub struct MongoAccumulator {
614 pub init: String,
616 pub init_args: Vec<String>,
618 pub accumulate: String,
620 pub accumulate_args: Vec<String>,
622 pub merge: String,
624 pub finalize: Option<String>,
626 pub lang: String,
628 }
629
630 impl MongoAccumulator {
631 pub fn new(
633 init: impl Into<String>,
634 accumulate: impl Into<String>,
635 merge: impl Into<String>,
636 ) -> Self {
637 Self {
638 init: init.into(),
639 init_args: Vec::new(),
640 accumulate: accumulate.into(),
641 accumulate_args: Vec::new(),
642 merge: merge.into(),
643 finalize: None,
644 lang: "js".to_string(),
645 }
646 }
647
648 pub fn with_init_args(mut self, args: Vec<impl Into<String>>) -> Self {
650 self.init_args = args.into_iter().map(Into::into).collect();
651 self
652 }
653
654 pub fn with_accumulate_args(mut self, args: Vec<impl Into<String>>) -> Self {
656 self.accumulate_args = args.into_iter().map(Into::into).collect();
657 self
658 }
659
660 pub fn with_finalize(mut self, finalize: impl Into<String>) -> Self {
662 self.finalize = Some(finalize.into());
663 self
664 }
665
666 #[cfg(feature = "mongodb")]
668 pub fn to_bson(&self) -> bson::Document {
669 use bson::doc;
670 let mut doc = doc! {
671 "$accumulator": {
672 "init": &self.init,
673 "accumulate": &self.accumulate,
674 "accumulateArgs": &self.accumulate_args,
675 "merge": &self.merge,
676 "lang": &self.lang,
677 }
678 };
679
680 if !self.init_args.is_empty() {
681 doc.get_document_mut("$accumulator")
682 .unwrap()
683 .insert("initArgs", &self.init_args);
684 }
685
686 if let Some(ref finalize) = self.finalize {
687 doc.get_document_mut("$accumulator")
688 .unwrap()
689 .insert("finalize", finalize);
690 }
691
692 doc
693 }
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_procedure_call_basic() {
        let call = ProcedureCall::new("get_user")
            .param("id", 42i32)
            .param("active", true);

        assert_eq!(call.name, "get_user");
        assert_eq!(call.parameters.len(), 2);
        assert!(!call.is_function, "`new` should build a procedure, not a function");
    }

    #[test]
    fn test_function_call() {
        let call = ProcedureCall::function("calculate_tax")
            .param("amount", 100.0f64)
            .param("rate", 0.08f64);

        assert_eq!(call.name, "calculate_tax");
        assert!(call.is_function, "`function` should set the function flag");
    }

    #[test]
    fn test_postgres_sql_generation() {
        let (sql, params) = ProcedureCall::new("get_orders")
            .param("user_id", 42i32)
            .param("status", "pending".to_string())
            .to_postgres_sql();

        assert_eq!(sql, "CALL get_orders($1, $2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_postgres_function_sql() {
        let (sql, params) = ProcedureCall::function("calculate_total")
            .param("order_id", 123i32)
            .to_postgres_sql();

        assert_eq!(sql, "SELECT calculate_total($1)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mysql_sql_generation() {
        let (sql, params) = ProcedureCall::new("get_orders")
            .with_db_type(DatabaseType::MySQL)
            .param("user_id", 42i32)
            .to_mysql_sql();

        assert_eq!(sql, "CALL get_orders(?)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mssql_sql_generation() {
        let (sql, params) = ProcedureCall::new("GetOrders")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("UserId", 42i32)
            .to_mssql_sql();

        assert!(sql.contains("EXEC dbo.GetOrders"), "unexpected SQL: {}", sql);
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mssql_with_output_params() {
        let (sql, _) = ProcedureCall::new("CalculateTotals")
            .with_db_type(DatabaseType::MSSQL)
            .in_param("OrderId", 123i32)
            .out_param_typed("TotalAmount", "DECIMAL(18,2)")
            .out_param_typed("ItemCount", "INT")
            .to_mssql_sql();

        for fragment in &["DECLARE", "OUTPUT", "SELECT"] {
            assert!(sql.contains(*fragment), "missing {} in {}", fragment, sql);
        }
    }

    #[test]
    fn test_sqlite_function() {
        let (sql, params) = ProcedureCall::function("custom_hash")
            .with_db_type(DatabaseType::SQLite)
            .param("input", "test".to_string())
            .to_sqlite_sql()
            .expect("SQLite functions should render");

        assert_eq!(sql, "SELECT custom_hash(?)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sqlite_procedure_error() {
        let outcome = ProcedureCall::new("some_procedure")
            .with_db_type(DatabaseType::SQLite)
            .param("id", 42i32)
            .to_sqlite_sql();

        assert!(outcome.is_err(), "SQLite has no stored procedures");
    }

    #[test]
    fn test_qualified_name() {
        let with_schema = ProcedureCall::new("get_user").schema("public");
        let without_schema = ProcedureCall::new("get_user");

        assert_eq!(with_schema.qualified_name(), "public.get_user");
        assert_eq!(without_schema.qualified_name(), "get_user");
    }

    #[test]
    fn test_parameter_modes() {
        let call = ProcedureCall::new("calculate")
            .in_param("input", 100i32)
            .out_param("result")
            .inout_param("running_total", 50i32);

        let modes: Vec<ParameterMode> = call.parameters.iter().map(|p| p.mode).collect();
        assert_eq!(
            modes,
            vec![ParameterMode::In, ParameterMode::Out, ParameterMode::InOut]
        );
        assert!(call.has_outputs());
    }

    #[test]
    fn test_procedure_result() {
        let result = ProcedureResult {
            outputs: std::iter::once(("total".to_string(), FilterValue::Int(100))).collect(),
            return_value: Some(FilterValue::Bool(true)),
            ..ProcedureResult::default()
        };

        assert!(result.get("total").is_some());
        assert!(result.get("nonexistent").is_none());
        assert!(result.return_value().is_some());
    }

    #[test]
    fn test_mongo_function() {
        let func = mongodb_func::MongoFunction::new(
            "function(x, y) { return x + y; }",
            vec!["$field1", "$field2"],
        );

        assert_eq!(func.lang, "js");
        assert_eq!(func.args.len(), 2);
    }

    #[test]
    fn test_mongo_accumulator() {
        let acc = mongodb_func::MongoAccumulator::new(
            "function() { return { sum: 0, count: 0 }; }",
            "function(state, value) { state.sum += value; state.count++; return state; }",
            "function(s1, s2) { return { sum: s1.sum + s2.sum, count: s1.count + s2.count }; }",
        )
        .with_finalize("function(state) { return state.sum / state.count; }")
        .with_accumulate_args(vec!["$value"]);

        assert!(acc.finalize.is_some());
        assert_eq!(acc.accumulate_args, vec!["$value".to_string()]);
    }

    #[test]
    fn test_sqlite_udf_definitions() {
        let scalar = sqlite_udf::ScalarUdf::new("my_hash", 1).deterministic(true);
        let aggregate = sqlite_udf::AggregateUdf::new("my_sum", 1);
        let window = sqlite_udf::WindowUdf::new("my_rank", 0);

        assert!(scalar.deterministic);
        assert_eq!(aggregate.num_args, 1);
        assert_eq!(window.num_args, 0);
    }
}
881