1use std::borrow::Cow;
48use std::collections::HashMap;
49use std::marker::PhantomData;
50
51use serde::{Deserialize, Serialize};
52
53use crate::error::{QueryError, QueryResult};
54use crate::filter::FilterValue;
55use crate::sql::DatabaseType;
56use crate::traits::{BoxFuture, QueryEngine};
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum ParameterMode {
61 In,
63 Out,
65 InOut,
67}
68
69impl Default for ParameterMode {
70 fn default() -> Self {
71 Self::In
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct Parameter {
78 pub name: String,
80 pub value: Option<FilterValue>,
82 pub mode: ParameterMode,
84 pub type_hint: Option<String>,
86}
87
88impl Parameter {
89 pub fn input(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
91 Self {
92 name: name.into(),
93 value: Some(value.into()),
94 mode: ParameterMode::In,
95 type_hint: None,
96 }
97 }
98
99 pub fn output(name: impl Into<String>) -> Self {
101 Self {
102 name: name.into(),
103 value: None,
104 mode: ParameterMode::Out,
105 type_hint: None,
106 }
107 }
108
109 pub fn inout(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
111 Self {
112 name: name.into(),
113 value: Some(value.into()),
114 mode: ParameterMode::InOut,
115 type_hint: None,
116 }
117 }
118
119 pub fn with_type_hint(mut self, type_name: impl Into<String>) -> Self {
121 self.type_hint = Some(type_name.into());
122 self
123 }
124}
125
126#[derive(Debug, Clone, Default)]
128pub struct ProcedureResult {
129 pub outputs: HashMap<String, FilterValue>,
131 pub return_value: Option<FilterValue>,
133 pub rows_affected: Option<u64>,
135}
136
137impl ProcedureResult {
138 pub fn get(&self, name: &str) -> Option<&FilterValue> {
140 self.outputs.get(name)
141 }
142
143 pub fn get_as<T>(&self, name: &str) -> Option<T>
145 where
146 T: TryFrom<FilterValue>,
147 {
148 self.outputs
149 .get(name)
150 .and_then(|v| T::try_from(v.clone()).ok())
151 }
152
153 pub fn return_value(&self) -> Option<&FilterValue> {
155 self.return_value.as_ref()
156 }
157
158 pub fn return_value_as<T>(&self) -> Option<T>
160 where
161 T: TryFrom<FilterValue>,
162 {
163 self.return_value.clone().and_then(|v| T::try_from(v).ok())
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct ProcedureCall {
170 pub name: String,
172 pub schema: Option<String>,
174 pub parameters: Vec<Parameter>,
176 pub db_type: DatabaseType,
178 pub is_function: bool,
180}
181
182impl ProcedureCall {
183 pub fn new(name: impl Into<String>) -> Self {
185 Self {
186 name: name.into(),
187 schema: None,
188 parameters: Vec::new(),
189 db_type: DatabaseType::PostgreSQL,
190 is_function: false,
191 }
192 }
193
194 pub fn function(name: impl Into<String>) -> Self {
196 Self {
197 name: name.into(),
198 schema: None,
199 parameters: Vec::new(),
200 db_type: DatabaseType::PostgreSQL,
201 is_function: true,
202 }
203 }
204
205 pub fn schema(mut self, schema: impl Into<String>) -> Self {
207 self.schema = Some(schema.into());
208 self
209 }
210
211 pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
213 self.db_type = db_type;
214 self
215 }
216
217 pub fn param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
219 self.parameters.push(Parameter::input(name, value));
220 self
221 }
222
223 pub fn in_param(self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
225 self.param(name, value)
226 }
227
228 pub fn out_param(mut self, name: impl Into<String>) -> Self {
230 self.parameters.push(Parameter::output(name));
231 self
232 }
233
234 pub fn out_param_typed(
236 mut self,
237 name: impl Into<String>,
238 type_hint: impl Into<String>,
239 ) -> Self {
240 self.parameters
241 .push(Parameter::output(name).with_type_hint(type_hint));
242 self
243 }
244
245 pub fn inout_param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
247 self.parameters.push(Parameter::inout(name, value));
248 self
249 }
250
251 pub fn add_parameter(mut self, param: Parameter) -> Self {
253 self.parameters.push(param);
254 self
255 }
256
257 pub fn qualified_name(&self) -> Cow<'_, str> {
259 match &self.schema {
260 Some(schema) => Cow::Owned(format!("{}.{}", schema, self.name)),
261 None => Cow::Borrowed(&self.name),
262 }
263 }
264
265 pub fn has_outputs(&self) -> bool {
267 self.parameters
268 .iter()
269 .any(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
270 }
271
272 pub fn input_values(&self) -> Vec<FilterValue> {
274 self.parameters
275 .iter()
276 .filter(|p| matches!(p.mode, ParameterMode::In | ParameterMode::InOut))
277 .filter_map(|p| p.value.clone())
278 .collect()
279 }
280
281 pub fn to_postgres_sql(&self) -> (String, Vec<FilterValue>) {
283 let name = self.qualified_name();
284 let params = self.input_values();
285 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("${}", i)).collect();
286
287 let sql = if self.is_function {
288 format!("SELECT {}({})", name, placeholders.join(", "))
289 } else {
290 format!("CALL {}({})", name, placeholders.join(", "))
291 };
292
293 (sql, params)
294 }
295
296 pub fn to_mysql_sql(&self) -> (String, Vec<FilterValue>) {
298 let name = self.qualified_name();
299 let params = self.input_values();
300 let placeholders = vec!["?"; params.len()].join(", ");
301
302 let sql = if self.is_function {
303 format!("SELECT {}({})", name, placeholders)
304 } else {
305 format!("CALL {}({})", name, placeholders)
306 };
307
308 (sql, params)
309 }
310
311 pub fn to_mssql_sql(&self) -> (String, Vec<FilterValue>) {
313 let name = self.qualified_name();
314 let params = self.input_values();
315 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("@P{}", i)).collect();
316
317 if self.is_function {
318 (
319 format!("SELECT {}({})", name, placeholders.join(", ")),
320 params,
321 )
322 } else if self.has_outputs() {
323 let mut parts = vec![String::from("DECLARE ")];
325
326 let out_params: Vec<_> = self
328 .parameters
329 .iter()
330 .filter(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
331 .collect();
332
333 for (i, param) in out_params.iter().enumerate() {
334 if i > 0 {
335 parts.push(String::from(", "));
336 }
337 let type_name = param.type_hint.as_deref().unwrap_or("SQL_VARIANT");
338 parts.push(format!("@{} {}", param.name, type_name));
339 }
340 parts.push(String::from("; "));
341
342 parts.push(format!("EXEC {} ", name));
344
345 let param_parts: Vec<String> = self
346 .parameters
347 .iter()
348 .enumerate()
349 .map(|(i, p)| match p.mode {
350 ParameterMode::In => format!("@P{}", i + 1),
351 ParameterMode::Out => format!("@{} OUTPUT", p.name),
352 ParameterMode::InOut => format!("@P{} = @{} OUTPUT", i + 1, p.name),
353 })
354 .collect();
355
356 parts.push(param_parts.join(", "));
357 parts.push(String::from("; "));
358
359 let select_parts: Vec<String> = out_params
361 .iter()
362 .map(|p| format!("@{} AS {}", p.name, p.name))
363 .collect();
364 parts.push(format!("SELECT {}", select_parts.join(", ")));
365
366 (parts.join(""), params)
367 } else {
368 (format!("EXEC {} {}", name, placeholders.join(", ")), params)
369 }
370 }
371
372 pub fn to_sqlite_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
374 if !self.is_function {
375 return Err(QueryError::unsupported(
376 "SQLite does not support stored procedures. Use Rust UDFs instead.",
377 ));
378 }
379
380 let name = self.qualified_name();
381 let params = self.input_values();
382 let placeholders = vec!["?"; params.len()].join(", ");
383
384 Ok((format!("SELECT {}({})", name, placeholders), params))
385 }
386
387 pub fn to_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
389 match self.db_type {
390 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
391 DatabaseType::MySQL => Ok(self.to_mysql_sql()),
392 DatabaseType::SQLite => self.to_sqlite_sql(),
393 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
394 }
395 }
396}
397
398pub struct ProcedureCallOperation<E: QueryEngine> {
400 engine: E,
401 call: ProcedureCall,
402}
403
404impl<E: QueryEngine> ProcedureCallOperation<E> {
405 pub fn new(engine: E, call: ProcedureCall) -> Self {
407 Self { engine, call }
408 }
409
410 pub async fn exec(self) -> QueryResult<ProcedureResult> {
412 let (sql, params) = self.call.to_sql()?;
413 let affected = self.engine.execute_raw(&sql, params).await?;
414
415 Ok(ProcedureResult {
416 outputs: HashMap::new(),
417 return_value: None,
418 rows_affected: Some(affected),
419 })
420 }
421
422 pub async fn exec_returning<T>(self) -> QueryResult<Vec<T>>
424 where
425 T: crate::traits::Model + crate::row::FromRow + Send + 'static,
426 {
427 let (sql, params) = self.call.to_sql()?;
428 self.engine.query_many(&sql, params).await
429 }
430
431 pub async fn exec_scalar<T>(self) -> QueryResult<T>
439 where
440 T: TryFrom<FilterValue, Error = String> + Send + 'static,
441 {
442 let (sql, params) = self.call.to_sql()?;
443 let mut rows = self.engine.aggregate_query(&sql, params).await?;
444 let first = rows
445 .drain(..)
446 .next()
447 .ok_or_else(|| QueryError::not_found("scalar function returned no row".to_string()))?;
448 let value = first.into_values().next().ok_or_else(|| {
451 QueryError::deserialization(
452 "scalar function returned a row with no columns".to_string(),
453 )
454 })?;
455 T::try_from(value).map_err(QueryError::deserialization)
456 }
457}
458
459#[allow(dead_code)]
461pub struct FunctionCallOperation<E: QueryEngine, T> {
462 engine: E,
463 call: ProcedureCall,
464 _marker: PhantomData<T>,
465}
466
467impl<E: QueryEngine, T> FunctionCallOperation<E, T> {
468 pub fn new(engine: E, call: ProcedureCall) -> Self {
470 Self {
471 engine,
472 call,
473 _marker: PhantomData,
474 }
475 }
476}
477
478pub trait ProcedureEngine: QueryEngine {
480 fn call(&self, name: impl Into<String>) -> ProcedureCall {
482 ProcedureCall::new(name)
483 }
484
485 fn function(&self, name: impl Into<String>) -> ProcedureCall {
487 ProcedureCall::function(name)
488 }
489
490 fn execute_procedure(&self, call: ProcedureCall) -> BoxFuture<'_, QueryResult<ProcedureResult>>
492 where
493 Self: Clone + 'static,
494 {
495 let engine = self.clone();
496 Box::pin(async move {
497 let op = ProcedureCallOperation::new(engine, call);
498 op.exec().await
499 })
500 }
501}
502
503impl<T: QueryEngine + Clone + 'static> ProcedureEngine for T {}
505
/// Descriptors for SQLite user-defined functions (UDFs).
///
/// SQLite procedure calls are rejected by `ProcedureCall::to_sqlite_sql`, so
/// server-side logic is described here as scalar, aggregate or window UDFs.
/// This module only describes such functions; it performs no registration.
pub mod sqlite_udf {
    /// A user-defined SQLite function described through a trait.
    pub trait SqliteFunction: Send + Sync + 'static {
        /// SQL name under which the function is invoked.
        fn name(&self) -> &str;

        /// Number of arguments the function takes.
        fn num_args(&self) -> i32;

        /// Whether the function is deterministic. Defaults to `true`.
        fn deterministic(&self) -> bool {
            true
        }
    }

    /// Descriptor of a scalar UDF.
    #[derive(Debug, Clone)]
    pub struct ScalarUdf {
        /// SQL name of the function.
        pub name: String,
        /// Number of arguments the function takes.
        pub num_args: i32,
        /// Whether the function is deterministic (`true` by default).
        pub deterministic: bool,
    }

    impl ScalarUdf {
        /// Describes a deterministic scalar UDF `name` taking `num_args` arguments.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
                deterministic: true,
            }
        }

        /// Returns the descriptor with its determinism flag set to `deterministic`.
        #[must_use]
        pub fn deterministic(mut self, deterministic: bool) -> Self {
            self.deterministic = deterministic;
            self
        }
    }

    /// Descriptor of an aggregate UDF.
    #[derive(Debug, Clone)]
    pub struct AggregateUdf {
        /// SQL name of the function.
        pub name: String,
        /// Number of arguments the function takes.
        pub num_args: i32,
    }

    impl AggregateUdf {
        /// Describes an aggregate UDF `name` taking `num_args` arguments.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }

    /// Descriptor of a window UDF.
    #[derive(Debug, Clone)]
    pub struct WindowUdf {
        /// SQL name of the function.
        pub name: String,
        /// Number of arguments the function takes.
        pub num_args: i32,
    }

    impl WindowUdf {
        /// Describes a window UDF `name` taking `num_args` arguments.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }
}
591
592pub mod mongodb_func {
594 use super::*;
595
596 #[derive(Debug, Clone, Serialize, Deserialize)]
598 pub struct MongoFunction {
599 pub body: String,
601 pub args: Vec<String>,
603 pub lang: String,
605 }
606
607 impl MongoFunction {
608 pub fn new(body: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
610 Self {
611 body: body.into(),
612 args: args.into_iter().map(Into::into).collect(),
613 lang: "js".to_string(),
614 }
615 }
616
617 #[cfg(feature = "mongodb")]
619 pub fn to_bson(&self) -> bson::Document {
620 use bson::doc;
621 doc! {
622 "$function": {
623 "body": &self.body,
624 "args": &self.args,
625 "lang": &self.lang,
626 }
627 }
628 }
629 }
630
631 #[derive(Debug, Clone, Serialize, Deserialize)]
633 pub struct MongoAccumulator {
634 pub init: String,
636 pub init_args: Vec<String>,
638 pub accumulate: String,
640 pub accumulate_args: Vec<String>,
642 pub merge: String,
644 pub finalize: Option<String>,
646 pub lang: String,
648 }
649
650 impl MongoAccumulator {
651 pub fn new(
653 init: impl Into<String>,
654 accumulate: impl Into<String>,
655 merge: impl Into<String>,
656 ) -> Self {
657 Self {
658 init: init.into(),
659 init_args: Vec::new(),
660 accumulate: accumulate.into(),
661 accumulate_args: Vec::new(),
662 merge: merge.into(),
663 finalize: None,
664 lang: "js".to_string(),
665 }
666 }
667
668 pub fn with_init_args(mut self, args: Vec<impl Into<String>>) -> Self {
670 self.init_args = args.into_iter().map(Into::into).collect();
671 self
672 }
673
674 pub fn with_accumulate_args(mut self, args: Vec<impl Into<String>>) -> Self {
676 self.accumulate_args = args.into_iter().map(Into::into).collect();
677 self
678 }
679
680 pub fn with_finalize(mut self, finalize: impl Into<String>) -> Self {
682 self.finalize = Some(finalize.into());
683 self
684 }
685
686 #[cfg(feature = "mongodb")]
688 pub fn to_bson(&self) -> bson::Document {
689 use bson::doc;
690 let mut doc = doc! {
691 "$accumulator": {
692 "init": &self.init,
693 "accumulate": &self.accumulate,
694 "accumulateArgs": &self.accumulate_args,
695 "merge": &self.merge,
696 "lang": &self.lang,
697 }
698 };
699
700 if !self.init_args.is_empty() {
701 doc.get_document_mut("$accumulator")
702 .unwrap()
703 .insert("initArgs", &self.init_args);
704 }
705
706 if let Some(ref finalize) = self.finalize {
707 doc.get_document_mut("$accumulator")
708 .unwrap()
709 .insert("finalize", finalize);
710 }
711
712 doc
713 }
714 }
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_procedure_call_basic() {
        let call = ProcedureCall::new("get_user")
            .param("id", 42i32)
            .param("active", true);

        assert!(!call.is_function, "ProcedureCall::new builds a procedure");
        assert_eq!(call.name, "get_user");
        assert_eq!(call.parameters.len(), 2);
    }

    #[test]
    fn test_function_call() {
        let call = ProcedureCall::function("calculate_tax")
            .param("amount", 100.0f64)
            .param("rate", 0.08f64);

        assert!(call.is_function, "ProcedureCall::function builds a function");
        assert_eq!(call.name, "calculate_tax");
    }

    #[test]
    fn test_postgres_sql_generation() {
        let (sql, bound) = ProcedureCall::new("get_orders")
            .param("user_id", 42i32)
            .param("status", "pending".to_string())
            .to_postgres_sql();

        assert_eq!(sql, "CALL get_orders($1, $2)");
        assert_eq!(bound.len(), 2);
    }

    #[test]
    fn test_postgres_function_sql() {
        let (sql, bound) = ProcedureCall::function("calculate_total")
            .param("order_id", 123i32)
            .to_postgres_sql();

        assert_eq!(sql, "SELECT calculate_total($1)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_mysql_sql_generation() {
        let (sql, bound) = ProcedureCall::new("get_orders")
            .with_db_type(DatabaseType::MySQL)
            .param("user_id", 42i32)
            .to_mysql_sql();

        assert_eq!(sql, "CALL get_orders(?)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_mssql_sql_generation() {
        let (sql, bound) = ProcedureCall::new("GetOrders")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("UserId", 42i32)
            .to_mssql_sql();

        assert!(sql.contains("EXEC dbo.GetOrders"), "unexpected SQL: {}", sql);
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_mssql_with_output_params() {
        let (sql, _bound) = ProcedureCall::new("CalculateTotals")
            .with_db_type(DatabaseType::MSSQL)
            .in_param("OrderId", 123i32)
            .out_param_typed("TotalAmount", "DECIMAL(18,2)")
            .out_param_typed("ItemCount", "INT")
            .to_mssql_sql();

        for fragment in ["DECLARE", "OUTPUT", "SELECT"].iter() {
            assert!(sql.contains(*fragment), "missing {:?} in {}", fragment, sql);
        }
    }

    #[test]
    fn test_sqlite_function() {
        let call = ProcedureCall::function("custom_hash")
            .with_db_type(DatabaseType::SQLite)
            .param("input", "test".to_string());

        let (sql, bound) = call
            .to_sqlite_sql()
            .expect("SQLite functions should render");
        assert_eq!(sql, "SELECT custom_hash(?)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_sqlite_procedure_error() {
        let call = ProcedureCall::new("some_procedure")
            .with_db_type(DatabaseType::SQLite)
            .param("id", 42i32);

        assert!(call.to_sqlite_sql().is_err());
    }

    #[test]
    fn test_qualified_name() {
        let with_schema = ProcedureCall::new("get_user").schema("public");
        assert_eq!(with_schema.qualified_name(), "public.get_user");

        let without_schema = ProcedureCall::new("get_user");
        assert_eq!(without_schema.qualified_name(), "get_user");
    }

    #[test]
    fn test_parameter_modes() {
        let call = ProcedureCall::new("calculate")
            .in_param("input", 100i32)
            .out_param("result")
            .inout_param("running_total", 50i32);

        let modes: Vec<ParameterMode> = call.parameters.iter().map(|p| p.mode).collect();
        assert_eq!(
            modes,
            vec![ParameterMode::In, ParameterMode::Out, ParameterMode::InOut]
        );
        assert!(call.has_outputs());
    }

    #[test]
    fn test_procedure_result() {
        let mut result = ProcedureResult::default();
        result.outputs.insert("total".to_string(), FilterValue::Int(100));
        result.return_value = Some(FilterValue::Bool(true));

        assert!(result.get("total").is_some());
        assert!(result.get("nonexistent").is_none());
        assert!(result.return_value().is_some());
    }

    #[test]
    fn test_mongo_function() {
        use mongodb_func::MongoFunction;

        let func = MongoFunction::new(
            "function(x, y) { return x + y; }",
            vec!["$field1", "$field2"],
        );

        assert_eq!(func.lang, "js");
        assert_eq!(func.args.len(), 2);
    }

    #[test]
    fn test_mongo_accumulator() {
        use mongodb_func::MongoAccumulator;

        let acc = MongoAccumulator::new(
            "function() { return { sum: 0, count: 0 }; }",
            "function(state, value) { state.sum += value; state.count++; return state; }",
            "function(s1, s2) { return { sum: s1.sum + s2.sum, count: s1.count + s2.count }; }",
        )
        .with_finalize("function(state) { return state.sum / state.count; }")
        .with_accumulate_args(vec!["$value"]);

        assert!(acc.finalize.is_some());
        assert_eq!(acc.accumulate_args.len(), 1);
    }

    #[test]
    fn test_sqlite_udf_definitions() {
        use sqlite_udf::{AggregateUdf, ScalarUdf, WindowUdf};

        assert!(ScalarUdf::new("my_hash", 1).deterministic(true).deterministic);
        assert_eq!(AggregateUdf::new("my_sum", 1).num_args, 1);
        assert_eq!(WindowUdf::new("my_rank", 0).num_args, 0);
    }
}