1use std::borrow::Cow;
48use std::collections::HashMap;
49use std::marker::PhantomData;
50
51use serde::{Deserialize, Serialize};
52
53use crate::error::{QueryError, QueryResult};
54use crate::filter::FilterValue;
55use crate::sql::DatabaseType;
56use crate::traits::{BoxFuture, QueryEngine};
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
60pub enum ParameterMode {
61 #[default]
63 In,
64 Out,
66 InOut,
68}
69
70#[derive(Debug, Clone)]
72pub struct Parameter {
73 pub name: String,
75 pub value: Option<FilterValue>,
77 pub mode: ParameterMode,
79 pub type_hint: Option<String>,
81}
82
83impl Parameter {
84 pub fn input(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
86 Self {
87 name: name.into(),
88 value: Some(value.into()),
89 mode: ParameterMode::In,
90 type_hint: None,
91 }
92 }
93
94 pub fn output(name: impl Into<String>) -> Self {
96 Self {
97 name: name.into(),
98 value: None,
99 mode: ParameterMode::Out,
100 type_hint: None,
101 }
102 }
103
104 pub fn inout(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
106 Self {
107 name: name.into(),
108 value: Some(value.into()),
109 mode: ParameterMode::InOut,
110 type_hint: None,
111 }
112 }
113
114 pub fn with_type_hint(mut self, type_name: impl Into<String>) -> Self {
116 self.type_hint = Some(type_name.into());
117 self
118 }
119}
120
121#[derive(Debug, Clone, Default)]
123pub struct ProcedureResult {
124 pub outputs: HashMap<String, FilterValue>,
126 pub return_value: Option<FilterValue>,
128 pub rows_affected: Option<u64>,
130}
131
132impl ProcedureResult {
133 pub fn get(&self, name: &str) -> Option<&FilterValue> {
135 self.outputs.get(name)
136 }
137
138 pub fn get_as<T>(&self, name: &str) -> Option<T>
140 where
141 T: TryFrom<FilterValue>,
142 {
143 self.outputs
144 .get(name)
145 .and_then(|v| T::try_from(v.clone()).ok())
146 }
147
148 pub fn return_value(&self) -> Option<&FilterValue> {
150 self.return_value.as_ref()
151 }
152
153 pub fn return_value_as<T>(&self) -> Option<T>
155 where
156 T: TryFrom<FilterValue>,
157 {
158 self.return_value.clone().and_then(|v| T::try_from(v).ok())
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct ProcedureCall {
165 pub name: String,
167 pub schema: Option<String>,
169 pub parameters: Vec<Parameter>,
171 pub db_type: DatabaseType,
173 pub is_function: bool,
175}
176
177impl ProcedureCall {
178 pub fn new(name: impl Into<String>) -> Self {
180 Self {
181 name: name.into(),
182 schema: None,
183 parameters: Vec::new(),
184 db_type: DatabaseType::PostgreSQL,
185 is_function: false,
186 }
187 }
188
189 pub fn function(name: impl Into<String>) -> Self {
191 Self {
192 name: name.into(),
193 schema: None,
194 parameters: Vec::new(),
195 db_type: DatabaseType::PostgreSQL,
196 is_function: true,
197 }
198 }
199
200 pub fn schema(mut self, schema: impl Into<String>) -> Self {
202 self.schema = Some(schema.into());
203 self
204 }
205
206 pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
208 self.db_type = db_type;
209 self
210 }
211
212 pub fn param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
214 self.parameters.push(Parameter::input(name, value));
215 self
216 }
217
218 pub fn in_param(self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
220 self.param(name, value)
221 }
222
223 pub fn out_param(mut self, name: impl Into<String>) -> Self {
225 self.parameters.push(Parameter::output(name));
226 self
227 }
228
229 pub fn out_param_typed(
231 mut self,
232 name: impl Into<String>,
233 type_hint: impl Into<String>,
234 ) -> Self {
235 self.parameters
236 .push(Parameter::output(name).with_type_hint(type_hint));
237 self
238 }
239
240 pub fn inout_param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
242 self.parameters.push(Parameter::inout(name, value));
243 self
244 }
245
246 pub fn add_parameter(mut self, param: Parameter) -> Self {
248 self.parameters.push(param);
249 self
250 }
251
252 pub fn qualified_name(&self) -> Cow<'_, str> {
254 match &self.schema {
255 Some(schema) => Cow::Owned(format!("{}.{}", schema, self.name)),
256 None => Cow::Borrowed(&self.name),
257 }
258 }
259
260 pub fn has_outputs(&self) -> bool {
262 self.parameters
263 .iter()
264 .any(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
265 }
266
267 pub fn input_values(&self) -> Vec<FilterValue> {
269 self.parameters
270 .iter()
271 .filter(|p| matches!(p.mode, ParameterMode::In | ParameterMode::InOut))
272 .filter_map(|p| p.value.clone())
273 .collect()
274 }
275
276 pub fn to_postgres_sql(&self) -> (String, Vec<FilterValue>) {
278 let name = self.qualified_name();
279 let params = self.input_values();
280 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("${}", i)).collect();
281
282 let sql = if self.is_function {
283 format!("SELECT {}({})", name, placeholders.join(", "))
284 } else {
285 format!("CALL {}({})", name, placeholders.join(", "))
286 };
287
288 (sql, params)
289 }
290
291 pub fn to_mysql_sql(&self) -> (String, Vec<FilterValue>) {
293 let name = self.qualified_name();
294 let params = self.input_values();
295 let placeholders = vec!["?"; params.len()].join(", ");
296
297 let sql = if self.is_function {
298 format!("SELECT {}({})", name, placeholders)
299 } else {
300 format!("CALL {}({})", name, placeholders)
301 };
302
303 (sql, params)
304 }
305
306 pub fn to_mssql_sql(&self) -> (String, Vec<FilterValue>) {
308 let name = self.qualified_name();
309 let params = self.input_values();
310 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("@P{}", i)).collect();
311
312 if self.is_function {
313 (
314 format!("SELECT {}({})", name, placeholders.join(", ")),
315 params,
316 )
317 } else if self.has_outputs() {
318 let mut parts = vec![String::from("DECLARE ")];
320
321 let out_params: Vec<_> = self
323 .parameters
324 .iter()
325 .filter(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
326 .collect();
327
328 for (i, param) in out_params.iter().enumerate() {
329 if i > 0 {
330 parts.push(String::from(", "));
331 }
332 let type_name = param.type_hint.as_deref().unwrap_or("SQL_VARIANT");
333 parts.push(format!("@{} {}", param.name, type_name));
334 }
335 parts.push(String::from("; "));
336
337 parts.push(format!("EXEC {} ", name));
339
340 let param_parts: Vec<String> = self
341 .parameters
342 .iter()
343 .enumerate()
344 .map(|(i, p)| match p.mode {
345 ParameterMode::In => format!("@P{}", i + 1),
346 ParameterMode::Out => format!("@{} OUTPUT", p.name),
347 ParameterMode::InOut => format!("@P{} = @{} OUTPUT", i + 1, p.name),
348 })
349 .collect();
350
351 parts.push(param_parts.join(", "));
352 parts.push(String::from("; "));
353
354 let select_parts: Vec<String> = out_params
356 .iter()
357 .map(|p| format!("@{} AS {}", p.name, p.name))
358 .collect();
359 parts.push(format!("SELECT {}", select_parts.join(", ")));
360
361 (parts.join(""), params)
362 } else {
363 (format!("EXEC {} {}", name, placeholders.join(", ")), params)
364 }
365 }
366
367 pub fn to_sqlite_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
369 if !self.is_function {
370 return Err(QueryError::unsupported(
371 "SQLite does not support stored procedures. Use Rust UDFs instead.",
372 ));
373 }
374
375 let name = self.qualified_name();
376 let params = self.input_values();
377 let placeholders = vec!["?"; params.len()].join(", ");
378
379 Ok((format!("SELECT {}({})", name, placeholders), params))
380 }
381
382 pub fn to_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
384 match self.db_type {
385 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
386 DatabaseType::MySQL => Ok(self.to_mysql_sql()),
387 DatabaseType::SQLite => self.to_sqlite_sql(),
388 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
389 }
390 }
391}
392
393pub struct ProcedureCallOperation<E: QueryEngine> {
395 engine: E,
396 call: ProcedureCall,
397}
398
399impl<E: QueryEngine> ProcedureCallOperation<E> {
400 pub fn new(engine: E, call: ProcedureCall) -> Self {
402 Self { engine, call }
403 }
404
405 pub async fn exec(self) -> QueryResult<ProcedureResult> {
407 let (sql, params) = self.call.to_sql()?;
408 let affected = self.engine.execute_raw(&sql, params).await?;
409
410 Ok(ProcedureResult {
411 outputs: HashMap::new(),
412 return_value: None,
413 rows_affected: Some(affected),
414 })
415 }
416
417 pub async fn exec_returning<T>(self) -> QueryResult<Vec<T>>
419 where
420 T: crate::traits::Model + crate::row::FromRow + Send + 'static,
421 {
422 let (sql, params) = self.call.to_sql()?;
423 self.engine.query_many(&sql, params).await
424 }
425
426 pub async fn exec_scalar<T>(self) -> QueryResult<T>
434 where
435 T: TryFrom<FilterValue, Error = String> + Send + 'static,
436 {
437 let (sql, params) = self.call.to_sql()?;
438 let mut rows = self.engine.aggregate_query(&sql, params).await?;
439 let first = rows
440 .drain(..)
441 .next()
442 .ok_or_else(|| QueryError::not_found("scalar function returned no row".to_string()))?;
443 let value = first.into_values().next().ok_or_else(|| {
446 QueryError::deserialization(
447 "scalar function returned a row with no columns".to_string(),
448 )
449 })?;
450 T::try_from(value).map_err(QueryError::deserialization)
451 }
452}
453
454#[allow(dead_code)]
456pub struct FunctionCallOperation<E: QueryEngine, T> {
457 engine: E,
458 call: ProcedureCall,
459 _marker: PhantomData<T>,
460}
461
462impl<E: QueryEngine, T> FunctionCallOperation<E, T> {
463 pub fn new(engine: E, call: ProcedureCall) -> Self {
465 Self {
466 engine,
467 call,
468 _marker: PhantomData,
469 }
470 }
471}
472
473pub trait ProcedureEngine: QueryEngine {
475 fn call(&self, name: impl Into<String>) -> ProcedureCall {
477 ProcedureCall::new(name)
478 }
479
480 fn function(&self, name: impl Into<String>) -> ProcedureCall {
482 ProcedureCall::function(name)
483 }
484
485 fn execute_procedure(&self, call: ProcedureCall) -> BoxFuture<'_, QueryResult<ProcedureResult>>
487 where
488 Self: Clone + 'static,
489 {
490 let engine = self.clone();
491 Box::pin(async move {
492 let op = ProcedureCallOperation::new(engine, call);
493 op.exec().await
494 })
495 }
496}
497
498impl<T: QueryEngine + Clone + 'static> ProcedureEngine for T {}
500
/// Descriptors for SQLite user-defined functions.
///
/// These are what SQLite offers instead of stored procedures; see
/// `ProcedureCall::to_sqlite_sql`.
pub mod sqlite_udf {
    // Nothing from the parent module is referenced yet, hence the lint allowance.
    #[allow(unused_imports)]
    use super::*;

    /// Metadata a Rust-implemented SQLite function exposes.
    pub trait SqliteFunction: Send + Sync + 'static {
        /// Name used to call the function from SQL.
        fn name(&self) -> &str;

        /// Number of arguments the function takes.
        fn num_args(&self) -> i32;

        /// Whether equal inputs always give equal output; `true` unless overridden.
        fn deterministic(&self) -> bool {
            true
        }
    }

    /// Descriptor for a scalar UDF.
    #[derive(Debug, Clone)]
    pub struct ScalarUdf {
        /// SQL-visible name.
        pub name: String,
        /// Argument count.
        pub num_args: i32,
        /// Deterministic flag; `true` after [`ScalarUdf::new`].
        pub deterministic: bool,
    }

    impl ScalarUdf {
        /// Describes a deterministic scalar UDF named `name` taking `num_args` arguments.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            let name = name.into();
            Self {
                name,
                num_args,
                deterministic: true,
            }
        }

        /// Returns this descriptor with the deterministic flag replaced.
        pub fn deterministic(self, deterministic: bool) -> Self {
            Self {
                deterministic,
                ..self
            }
        }
    }

    /// Descriptor for an aggregate UDF.
    #[derive(Debug, Clone)]
    pub struct AggregateUdf {
        /// SQL-visible name.
        pub name: String,
        /// Argument count.
        pub num_args: i32,
    }

    impl AggregateUdf {
        /// Describes an aggregate UDF named `name` taking `num_args` arguments.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            let name = name.into();
            Self { name, num_args }
        }
    }

    /// Descriptor for a window UDF.
    #[derive(Debug, Clone)]
    pub struct WindowUdf {
        /// SQL-visible name.
        pub name: String,
        /// Argument count.
        pub num_args: i32,
    }

    impl WindowUdf {
        /// Describes a window UDF named `name` taking `num_args` arguments.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            let name = name.into();
            Self { name, num_args }
        }
    }
}
586
587pub mod mongodb_func {
589 use super::*;
590
591 #[derive(Debug, Clone, Serialize, Deserialize)]
593 pub struct MongoFunction {
594 pub body: String,
596 pub args: Vec<String>,
598 pub lang: String,
600 }
601
602 impl MongoFunction {
603 pub fn new(body: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
605 Self {
606 body: body.into(),
607 args: args.into_iter().map(Into::into).collect(),
608 lang: "js".to_string(),
609 }
610 }
611
612 #[cfg(feature = "mongodb")]
614 pub fn to_bson(&self) -> bson::Document {
615 use bson::doc;
616 doc! {
617 "$function": {
618 "body": &self.body,
619 "args": &self.args,
620 "lang": &self.lang,
621 }
622 }
623 }
624 }
625
626 #[derive(Debug, Clone, Serialize, Deserialize)]
628 pub struct MongoAccumulator {
629 pub init: String,
631 pub init_args: Vec<String>,
633 pub accumulate: String,
635 pub accumulate_args: Vec<String>,
637 pub merge: String,
639 pub finalize: Option<String>,
641 pub lang: String,
643 }
644
645 impl MongoAccumulator {
646 pub fn new(
648 init: impl Into<String>,
649 accumulate: impl Into<String>,
650 merge: impl Into<String>,
651 ) -> Self {
652 Self {
653 init: init.into(),
654 init_args: Vec::new(),
655 accumulate: accumulate.into(),
656 accumulate_args: Vec::new(),
657 merge: merge.into(),
658 finalize: None,
659 lang: "js".to_string(),
660 }
661 }
662
663 pub fn with_init_args(mut self, args: Vec<impl Into<String>>) -> Self {
665 self.init_args = args.into_iter().map(Into::into).collect();
666 self
667 }
668
669 pub fn with_accumulate_args(mut self, args: Vec<impl Into<String>>) -> Self {
671 self.accumulate_args = args.into_iter().map(Into::into).collect();
672 self
673 }
674
675 pub fn with_finalize(mut self, finalize: impl Into<String>) -> Self {
677 self.finalize = Some(finalize.into());
678 self
679 }
680
681 #[cfg(feature = "mongodb")]
683 pub fn to_bson(&self) -> bson::Document {
684 use bson::doc;
685 let mut doc = doc! {
686 "$accumulator": {
687 "init": &self.init,
688 "accumulate": &self.accumulate,
689 "accumulateArgs": &self.accumulate_args,
690 "merge": &self.merge,
691 "lang": &self.lang,
692 }
693 };
694
695 if !self.init_args.is_empty() {
696 doc.get_document_mut("$accumulator")
697 .unwrap()
698 .insert("initArgs", &self.init_args);
699 }
700
701 if let Some(ref finalize) = self.finalize {
702 doc.get_document_mut("$accumulator")
703 .unwrap()
704 .insert("finalize", finalize);
705 }
706
707 doc
708 }
709 }
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_procedure_call_basic() {
        let call = ProcedureCall::new("get_user")
            .param("id", 42i32)
            .param("active", true);

        assert_eq!(call.name, "get_user");
        assert_eq!(call.parameters.len(), 2);
        assert!(!call.is_function);
    }

    #[test]
    fn test_function_call() {
        let call = ProcedureCall::function("calculate_tax")
            .param("amount", 100.0f64)
            .param("rate", 0.08f64);

        assert_eq!(call.name, "calculate_tax");
        assert!(call.is_function);
    }

    #[test]
    fn test_function_defaults() {
        let call = ProcedureCall::function("now_utc");
        assert!(call.is_function);
        assert!(call.schema.is_none());
        assert!(call.parameters.is_empty());
        // Default dialect is PostgreSQL.
        let (sql, params) = call.to_sql().unwrap();
        assert_eq!(sql, "SELECT now_utc()");
        assert!(params.is_empty());
    }

    #[test]
    fn test_postgres_sql_generation() {
        let call = ProcedureCall::new("get_orders")
            .param("user_id", 42i32)
            .param("status", "pending".to_string());

        let (sql, params) = call.to_postgres_sql();
        assert_eq!(sql, "CALL get_orders($1, $2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_postgres_function_sql() {
        let call = ProcedureCall::function("calculate_total").param("order_id", 123i32);

        let (sql, params) = call.to_postgres_sql();
        assert_eq!(sql, "SELECT calculate_total($1)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mysql_sql_generation() {
        let call = ProcedureCall::new("get_orders")
            .with_db_type(DatabaseType::MySQL)
            .param("user_id", 42i32);

        let (sql, params) = call.to_mysql_sql();
        assert_eq!(sql, "CALL get_orders(?)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mssql_sql_generation() {
        let call = ProcedureCall::new("GetOrders")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("UserId", 42i32);

        let (sql, params) = call.to_mssql_sql();
        assert_eq!(sql, "EXEC dbo.GetOrders @P1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mssql_function_sql() {
        let call = ProcedureCall::function("fn_total")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("a", 1i32)
            .param("b", 2i32);

        let (sql, params) = call.to_mssql_sql();
        assert_eq!(sql, "SELECT dbo.fn_total(@P1, @P2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_mssql_with_output_params() {
        let call = ProcedureCall::new("CalculateTotals")
            .with_db_type(DatabaseType::MSSQL)
            .in_param("OrderId", 123i32)
            .out_param_typed("TotalAmount", "DECIMAL(18,2)")
            .out_param_typed("ItemCount", "INT");

        let (sql, params) = call.to_mssql_sql();
        // Pin the whole batch, not just keywords, so placeholder/variable wiring is checked.
        assert_eq!(
            sql,
            "DECLARE @TotalAmount DECIMAL(18,2), @ItemCount INT; \
             EXEC CalculateTotals @P1, @TotalAmount OUTPUT, @ItemCount OUTPUT; \
             SELECT @TotalAmount AS TotalAmount, @ItemCount AS ItemCount"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sqlite_function() {
        let call = ProcedureCall::function("custom_hash")
            .with_db_type(DatabaseType::SQLite)
            .param("input", "test".to_string());

        let result = call.to_sqlite_sql();
        assert!(result.is_ok());

        let (sql, params) = result.unwrap();
        assert_eq!(sql, "SELECT custom_hash(?)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sqlite_procedure_error() {
        let call = ProcedureCall::new("some_procedure")
            .with_db_type(DatabaseType::SQLite)
            .param("id", 42i32);

        let result = call.to_sqlite_sql();
        assert!(result.is_err());
    }

    #[test]
    fn test_to_sql_dispatch() {
        let pg = ProcedureCall::new("p").param("x", 1i32);
        assert_eq!(pg.to_sql().unwrap().0, "CALL p($1)");

        let mysql = ProcedureCall::new("p")
            .with_db_type(DatabaseType::MySQL)
            .param("x", 1i32);
        assert_eq!(mysql.to_sql().unwrap().0, "CALL p(?)");

        let mssql = ProcedureCall::new("p")
            .with_db_type(DatabaseType::MSSQL)
            .param("x", 1i32);
        assert_eq!(mssql.to_sql().unwrap().0, "EXEC p @P1");

        let sqlite_proc = ProcedureCall::new("p").with_db_type(DatabaseType::SQLite);
        assert!(sqlite_proc.to_sql().is_err());
    }

    #[test]
    fn test_qualified_name() {
        let call = ProcedureCall::new("get_user").schema("public");
        assert_eq!(call.qualified_name(), "public.get_user");

        let call = ProcedureCall::new("get_user");
        assert_eq!(call.qualified_name(), "get_user");
    }

    #[test]
    fn test_parameter_modes() {
        let call = ProcedureCall::new("calculate")
            .in_param("input", 100i32)
            .out_param("result")
            .inout_param("running_total", 50i32);

        assert_eq!(call.parameters.len(), 3);
        assert_eq!(call.parameters[0].mode, ParameterMode::In);
        assert_eq!(call.parameters[1].mode, ParameterMode::Out);
        assert_eq!(call.parameters[2].mode, ParameterMode::InOut);
        assert!(call.has_outputs());
        // Only In/InOut values are bound; the Out parameter contributes nothing.
        assert_eq!(call.input_values().len(), 2);
    }

    #[test]
    fn test_no_outputs() {
        let call = ProcedureCall::new("log_event").in_param("id", 1i32);
        assert!(!call.has_outputs());
        assert_eq!(call.input_values().len(), 1);
    }

    #[test]
    fn test_procedure_result() {
        let mut result = ProcedureResult::default();
        result
            .outputs
            .insert("total".to_string(), FilterValue::Int(100));
        result.return_value = Some(FilterValue::Bool(true));

        assert!(result.get("total").is_some());
        assert!(result.get("nonexistent").is_none());
        assert!(result.return_value().is_some());
    }

    #[test]
    fn test_mongo_function() {
        use mongodb_func::MongoFunction;

        let func = MongoFunction::new(
            "function(x, y) { return x + y; }",
            vec!["$field1", "$field2"],
        );

        assert_eq!(func.lang, "js");
        assert_eq!(func.args.len(), 2);
    }

    #[test]
    fn test_mongo_accumulator() {
        use mongodb_func::MongoAccumulator;

        let acc = MongoAccumulator::new(
            "function() { return { sum: 0, count: 0 }; }",
            "function(state, value) { state.sum += value; state.count++; return state; }",
            "function(s1, s2) { return { sum: s1.sum + s2.sum, count: s1.count + s2.count }; }",
        )
        .with_finalize("function(state) { return state.sum / state.count; }")
        .with_accumulate_args(vec!["$value"]);

        assert!(acc.finalize.is_some());
        assert_eq!(acc.accumulate_args.len(), 1);
        assert!(acc.init_args.is_empty());
        assert_eq!(acc.lang, "js");
    }

    #[test]
    fn test_sqlite_udf_definitions() {
        use sqlite_udf::{AggregateUdf, ScalarUdf, WindowUdf};

        let scalar = ScalarUdf::new("my_hash", 1).deterministic(true);
        assert!(scalar.deterministic);

        let aggregate = AggregateUdf::new("my_sum", 1);
        assert_eq!(aggregate.num_args, 1);

        let window = WindowUdf::new("my_rank", 0);
        assert_eq!(window.num_args, 0);
    }
}