1use std::borrow::Cow;
48use std::collections::HashMap;
49use std::marker::PhantomData;
50
51use serde::{Deserialize, Serialize};
52
53use crate::error::{QueryError, QueryResult};
54use crate::filter::FilterValue;
55use crate::sql::DatabaseType;
56use crate::traits::{BoxFuture, QueryEngine};
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum ParameterMode {
61 In,
63 Out,
65 InOut,
67}
68
69impl Default for ParameterMode {
70 fn default() -> Self {
71 Self::In
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct Parameter {
78 pub name: String,
80 pub value: Option<FilterValue>,
82 pub mode: ParameterMode,
84 pub type_hint: Option<String>,
86}
87
88impl Parameter {
89 pub fn input(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
91 Self {
92 name: name.into(),
93 value: Some(value.into()),
94 mode: ParameterMode::In,
95 type_hint: None,
96 }
97 }
98
99 pub fn output(name: impl Into<String>) -> Self {
101 Self {
102 name: name.into(),
103 value: None,
104 mode: ParameterMode::Out,
105 type_hint: None,
106 }
107 }
108
109 pub fn inout(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
111 Self {
112 name: name.into(),
113 value: Some(value.into()),
114 mode: ParameterMode::InOut,
115 type_hint: None,
116 }
117 }
118
119 pub fn with_type_hint(mut self, type_name: impl Into<String>) -> Self {
121 self.type_hint = Some(type_name.into());
122 self
123 }
124}
125
126#[derive(Debug, Clone, Default)]
128pub struct ProcedureResult {
129 pub outputs: HashMap<String, FilterValue>,
131 pub return_value: Option<FilterValue>,
133 pub rows_affected: Option<u64>,
135}
136
137impl ProcedureResult {
138 pub fn get(&self, name: &str) -> Option<&FilterValue> {
140 self.outputs.get(name)
141 }
142
143 pub fn get_as<T>(&self, name: &str) -> Option<T>
145 where
146 T: TryFrom<FilterValue>,
147 {
148 self.outputs
149 .get(name)
150 .and_then(|v| T::try_from(v.clone()).ok())
151 }
152
153 pub fn return_value(&self) -> Option<&FilterValue> {
155 self.return_value.as_ref()
156 }
157
158 pub fn return_value_as<T>(&self) -> Option<T>
160 where
161 T: TryFrom<FilterValue>,
162 {
163 self.return_value.clone().and_then(|v| T::try_from(v).ok())
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct ProcedureCall {
170 pub name: String,
172 pub schema: Option<String>,
174 pub parameters: Vec<Parameter>,
176 pub db_type: DatabaseType,
178 pub is_function: bool,
180}
181
182impl ProcedureCall {
183 pub fn new(name: impl Into<String>) -> Self {
185 Self {
186 name: name.into(),
187 schema: None,
188 parameters: Vec::new(),
189 db_type: DatabaseType::PostgreSQL,
190 is_function: false,
191 }
192 }
193
194 pub fn function(name: impl Into<String>) -> Self {
196 Self {
197 name: name.into(),
198 schema: None,
199 parameters: Vec::new(),
200 db_type: DatabaseType::PostgreSQL,
201 is_function: true,
202 }
203 }
204
205 pub fn schema(mut self, schema: impl Into<String>) -> Self {
207 self.schema = Some(schema.into());
208 self
209 }
210
211 pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
213 self.db_type = db_type;
214 self
215 }
216
217 pub fn param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
219 self.parameters.push(Parameter::input(name, value));
220 self
221 }
222
223 pub fn in_param(self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
225 self.param(name, value)
226 }
227
228 pub fn out_param(mut self, name: impl Into<String>) -> Self {
230 self.parameters.push(Parameter::output(name));
231 self
232 }
233
234 pub fn out_param_typed(
236 mut self,
237 name: impl Into<String>,
238 type_hint: impl Into<String>,
239 ) -> Self {
240 self.parameters
241 .push(Parameter::output(name).with_type_hint(type_hint));
242 self
243 }
244
245 pub fn inout_param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
247 self.parameters.push(Parameter::inout(name, value));
248 self
249 }
250
251 pub fn add_parameter(mut self, param: Parameter) -> Self {
253 self.parameters.push(param);
254 self
255 }
256
257 pub fn qualified_name(&self) -> Cow<'_, str> {
259 match &self.schema {
260 Some(schema) => Cow::Owned(format!("{}.{}", schema, self.name)),
261 None => Cow::Borrowed(&self.name),
262 }
263 }
264
265 pub fn has_outputs(&self) -> bool {
267 self.parameters
268 .iter()
269 .any(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
270 }
271
272 pub fn input_values(&self) -> Vec<FilterValue> {
274 self.parameters
275 .iter()
276 .filter(|p| matches!(p.mode, ParameterMode::In | ParameterMode::InOut))
277 .filter_map(|p| p.value.clone())
278 .collect()
279 }
280
281 pub fn to_postgres_sql(&self) -> (String, Vec<FilterValue>) {
283 let name = self.qualified_name();
284 let params = self.input_values();
285 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("${}", i)).collect();
286
287 let sql = if self.is_function {
288 format!("SELECT {}({})", name, placeholders.join(", "))
289 } else {
290 format!("CALL {}({})", name, placeholders.join(", "))
291 };
292
293 (sql, params)
294 }
295
296 pub fn to_mysql_sql(&self) -> (String, Vec<FilterValue>) {
298 let name = self.qualified_name();
299 let params = self.input_values();
300 let placeholders = vec!["?"; params.len()].join(", ");
301
302 let sql = if self.is_function {
303 format!("SELECT {}({})", name, placeholders)
304 } else {
305 format!("CALL {}({})", name, placeholders)
306 };
307
308 (sql, params)
309 }
310
311 pub fn to_mssql_sql(&self) -> (String, Vec<FilterValue>) {
313 let name = self.qualified_name();
314 let params = self.input_values();
315 let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("@P{}", i)).collect();
316
317 if self.is_function {
318 (
319 format!("SELECT {}({})", name, placeholders.join(", ")),
320 params,
321 )
322 } else if self.has_outputs() {
323 let mut parts = vec![String::from("DECLARE ")];
325
326 let out_params: Vec<_> = self
328 .parameters
329 .iter()
330 .filter(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
331 .collect();
332
333 for (i, param) in out_params.iter().enumerate() {
334 if i > 0 {
335 parts.push(String::from(", "));
336 }
337 let type_name = param.type_hint.as_deref().unwrap_or("SQL_VARIANT");
338 parts.push(format!("@{} {}", param.name, type_name));
339 }
340 parts.push(String::from("; "));
341
342 parts.push(format!("EXEC {} ", name));
344
345 let param_parts: Vec<String> = self
346 .parameters
347 .iter()
348 .enumerate()
349 .map(|(i, p)| match p.mode {
350 ParameterMode::In => format!("@P{}", i + 1),
351 ParameterMode::Out => format!("@{} OUTPUT", p.name),
352 ParameterMode::InOut => format!("@P{} = @{} OUTPUT", i + 1, p.name),
353 })
354 .collect();
355
356 parts.push(param_parts.join(", "));
357 parts.push(String::from("; "));
358
359 let select_parts: Vec<String> = out_params
361 .iter()
362 .map(|p| format!("@{} AS {}", p.name, p.name))
363 .collect();
364 parts.push(format!("SELECT {}", select_parts.join(", ")));
365
366 (parts.join(""), params)
367 } else {
368 (format!("EXEC {} {}", name, placeholders.join(", ")), params)
369 }
370 }
371
372 pub fn to_sqlite_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
374 if !self.is_function {
375 return Err(QueryError::unsupported(
376 "SQLite does not support stored procedures. Use Rust UDFs instead.",
377 ));
378 }
379
380 let name = self.qualified_name();
381 let params = self.input_values();
382 let placeholders = vec!["?"; params.len()].join(", ");
383
384 Ok((format!("SELECT {}({})", name, placeholders), params))
385 }
386
387 pub fn to_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
389 match self.db_type {
390 DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
391 DatabaseType::MySQL => Ok(self.to_mysql_sql()),
392 DatabaseType::SQLite => self.to_sqlite_sql(),
393 DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
394 }
395 }
396}
397
398pub struct ProcedureCallOperation<E: QueryEngine> {
400 engine: E,
401 call: ProcedureCall,
402}
403
404impl<E: QueryEngine> ProcedureCallOperation<E> {
405 pub fn new(engine: E, call: ProcedureCall) -> Self {
407 Self { engine, call }
408 }
409
410 pub async fn exec(self) -> QueryResult<ProcedureResult> {
412 let (sql, params) = self.call.to_sql()?;
413 let affected = self.engine.execute_raw(&sql, params).await?;
414
415 Ok(ProcedureResult {
416 outputs: HashMap::new(),
417 return_value: None,
418 rows_affected: Some(affected),
419 })
420 }
421
422 pub async fn exec_returning<T>(self) -> QueryResult<Vec<T>>
424 where
425 T: crate::traits::Model + Send + 'static,
426 {
427 let (sql, params) = self.call.to_sql()?;
428 self.engine.query_many(&sql, params).await
429 }
430
431 pub async fn exec_scalar<T>(self) -> QueryResult<T>
433 where
434 T: TryFrom<FilterValue, Error = String> + Send + 'static,
435 {
436 let (sql, params) = self.call.to_sql()?;
437 let result = self.engine.execute_raw(&sql, params).await?;
438
439 Err(QueryError::internal(format!(
442 "Scalar function result parsing not yet implemented (affected: {})",
443 result
444 )))
445 }
446}
447
448#[allow(dead_code)]
450pub struct FunctionCallOperation<E: QueryEngine, T> {
451 engine: E,
452 call: ProcedureCall,
453 _marker: PhantomData<T>,
454}
455
456impl<E: QueryEngine, T> FunctionCallOperation<E, T> {
457 pub fn new(engine: E, call: ProcedureCall) -> Self {
459 Self {
460 engine,
461 call,
462 _marker: PhantomData,
463 }
464 }
465}
466
467pub trait ProcedureEngine: QueryEngine {
469 fn call(&self, name: impl Into<String>) -> ProcedureCall {
471 ProcedureCall::new(name)
472 }
473
474 fn function(&self, name: impl Into<String>) -> ProcedureCall {
476 ProcedureCall::function(name)
477 }
478
479 fn execute_procedure(&self, call: ProcedureCall) -> BoxFuture<'_, QueryResult<ProcedureResult>>
481 where
482 Self: Clone + 'static,
483 {
484 let engine = self.clone();
485 Box::pin(async move {
486 let op = ProcedureCallOperation::new(engine, call);
487 op.exec().await
488 })
489 }
490}
491
492impl<T: QueryEngine + Clone + 'static> ProcedureEngine for T {}
494
/// Descriptors for SQLite user-defined functions.
///
/// SQLite has no stored procedures, so routines are written in Rust instead.
/// This module only describes functions; nothing in it registers them with a
/// connection.
pub mod sqlite_udf {
    /// A SQLite function implemented in Rust.
    pub trait SqliteFunction: Send + Sync + 'static {
        /// Name the function is invoked by in SQL.
        fn name(&self) -> &str;

        /// Number of arguments the function accepts.
        // NOTE(review): presumably forwarded to SQLite's `nArg`, where -1
        // means "any number" — confirm against the registration code.
        fn num_args(&self) -> i32;

        /// Whether the function is deterministic. Defaults to `true`.
        fn deterministic(&self) -> bool {
            true
        }
    }

    /// Descriptor for a scalar UDF.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct ScalarUdf {
        /// SQL name of the function.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
        /// Deterministic flag; `true` unless changed with [`ScalarUdf::deterministic`].
        pub deterministic: bool,
    }

    impl ScalarUdf {
        /// Creates a deterministic scalar UDF descriptor.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
                deterministic: true,
            }
        }

        /// Returns the descriptor with its deterministic flag set.
        #[must_use = "this consumes the descriptor and returns the updated one"]
        pub fn deterministic(mut self, deterministic: bool) -> Self {
            self.deterministic = deterministic;
            self
        }
    }

    /// Descriptor for an aggregate UDF.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct AggregateUdf {
        /// SQL name of the function.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
    }

    impl AggregateUdf {
        /// Creates an aggregate UDF descriptor.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }

    /// Descriptor for a window UDF.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct WindowUdf {
        /// SQL name of the function.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
    }

    impl WindowUdf {
        /// Creates a window UDF descriptor.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }
}
580
581pub mod mongodb_func {
583 use super::*;
584
585 #[derive(Debug, Clone, Serialize, Deserialize)]
587 pub struct MongoFunction {
588 pub body: String,
590 pub args: Vec<String>,
592 pub lang: String,
594 }
595
596 impl MongoFunction {
597 pub fn new(body: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
599 Self {
600 body: body.into(),
601 args: args.into_iter().map(Into::into).collect(),
602 lang: "js".to_string(),
603 }
604 }
605
606 #[cfg(feature = "mongodb")]
608 pub fn to_bson(&self) -> bson::Document {
609 use bson::doc;
610 doc! {
611 "$function": {
612 "body": &self.body,
613 "args": &self.args,
614 "lang": &self.lang,
615 }
616 }
617 }
618 }
619
620 #[derive(Debug, Clone, Serialize, Deserialize)]
622 pub struct MongoAccumulator {
623 pub init: String,
625 pub init_args: Vec<String>,
627 pub accumulate: String,
629 pub accumulate_args: Vec<String>,
631 pub merge: String,
633 pub finalize: Option<String>,
635 pub lang: String,
637 }
638
639 impl MongoAccumulator {
640 pub fn new(
642 init: impl Into<String>,
643 accumulate: impl Into<String>,
644 merge: impl Into<String>,
645 ) -> Self {
646 Self {
647 init: init.into(),
648 init_args: Vec::new(),
649 accumulate: accumulate.into(),
650 accumulate_args: Vec::new(),
651 merge: merge.into(),
652 finalize: None,
653 lang: "js".to_string(),
654 }
655 }
656
657 pub fn with_init_args(mut self, args: Vec<impl Into<String>>) -> Self {
659 self.init_args = args.into_iter().map(Into::into).collect();
660 self
661 }
662
663 pub fn with_accumulate_args(mut self, args: Vec<impl Into<String>>) -> Self {
665 self.accumulate_args = args.into_iter().map(Into::into).collect();
666 self
667 }
668
669 pub fn with_finalize(mut self, finalize: impl Into<String>) -> Self {
671 self.finalize = Some(finalize.into());
672 self
673 }
674
675 #[cfg(feature = "mongodb")]
677 pub fn to_bson(&self) -> bson::Document {
678 use bson::doc;
679 let mut doc = doc! {
680 "$accumulator": {
681 "init": &self.init,
682 "accumulate": &self.accumulate,
683 "accumulateArgs": &self.accumulate_args,
684 "merge": &self.merge,
685 "lang": &self.lang,
686 }
687 };
688
689 if !self.init_args.is_empty() {
690 doc.get_document_mut("$accumulator")
691 .unwrap()
692 .insert("initArgs", &self.init_args);
693 }
694
695 if let Some(ref finalize) = self.finalize {
696 doc.get_document_mut("$accumulator")
697 .unwrap()
698 .insert("finalize", finalize);
699 }
700
701 doc
702 }
703 }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_procedure_call_basic() {
        let call = ProcedureCall::new("get_user")
            .param("id", 42i32)
            .param("active", true);

        assert_eq!(call.name, "get_user");
        assert_eq!(call.parameters.len(), 2);
        assert!(!call.is_function, "`new` should build a procedure call");
    }

    #[test]
    fn test_function_call() {
        let call = ProcedureCall::function("calculate_tax")
            .param("amount", 100.0f64)
            .param("rate", 0.08f64);

        assert_eq!(call.name, "calculate_tax");
        assert!(call.is_function, "`function` should build a function call");
    }

    #[test]
    fn test_postgres_sql_generation() {
        let (sql, params) = ProcedureCall::new("get_orders")
            .param("user_id", 42i32)
            .param("status", "pending".to_string())
            .to_postgres_sql();

        assert_eq!(sql, "CALL get_orders($1, $2)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_postgres_function_sql() {
        let (sql, params) = ProcedureCall::function("calculate_total")
            .param("order_id", 123i32)
            .to_postgres_sql();

        assert_eq!(sql, "SELECT calculate_total($1)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mysql_sql_generation() {
        let (sql, params) = ProcedureCall::new("get_orders")
            .with_db_type(DatabaseType::MySQL)
            .param("user_id", 42i32)
            .to_mysql_sql();

        assert_eq!(sql, "CALL get_orders(?)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mssql_sql_generation() {
        let (sql, params) = ProcedureCall::new("GetOrders")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("UserId", 42i32)
            .to_mssql_sql();

        assert!(sql.contains("EXEC dbo.GetOrders"), "unexpected SQL: {}", sql);
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_mssql_with_output_params() {
        let (sql, _) = ProcedureCall::new("CalculateTotals")
            .with_db_type(DatabaseType::MSSQL)
            .in_param("OrderId", 123i32)
            .out_param_typed("TotalAmount", "DECIMAL(18,2)")
            .out_param_typed("ItemCount", "INT")
            .to_mssql_sql();

        for fragment in ["DECLARE", "OUTPUT", "SELECT"] {
            assert!(sql.contains(fragment), "missing {:?} in {}", fragment, sql);
        }
    }

    #[test]
    fn test_sqlite_function() {
        let (sql, params) = ProcedureCall::function("custom_hash")
            .with_db_type(DatabaseType::SQLite)
            .param("input", "test".to_string())
            .to_sqlite_sql()
            .expect("SQLite should accept function calls");

        assert_eq!(sql, "SELECT custom_hash(?)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sqlite_procedure_error() {
        let outcome = ProcedureCall::new("some_procedure")
            .with_db_type(DatabaseType::SQLite)
            .param("id", 42i32)
            .to_sqlite_sql();

        assert!(outcome.is_err(), "SQLite has no stored procedures");
    }

    #[test]
    fn test_qualified_name() {
        let qualified = ProcedureCall::new("get_user").schema("public");
        let bare = ProcedureCall::new("get_user");

        assert_eq!(qualified.qualified_name(), "public.get_user");
        assert_eq!(bare.qualified_name(), "get_user");
    }

    #[test]
    fn test_parameter_modes() {
        let call = ProcedureCall::new("calculate")
            .in_param("input", 100i32)
            .out_param("result")
            .inout_param("running_total", 50i32);

        let modes: Vec<ParameterMode> = call.parameters.iter().map(|p| p.mode).collect();
        assert_eq!(
            modes,
            vec![ParameterMode::In, ParameterMode::Out, ParameterMode::InOut]
        );
        assert!(call.has_outputs());
    }

    #[test]
    fn test_procedure_result() {
        let result = ProcedureResult {
            outputs: std::iter::once(("total".to_string(), FilterValue::Int(100))).collect(),
            return_value: Some(FilterValue::Bool(true)),
            ..ProcedureResult::default()
        };

        assert!(result.get("total").is_some());
        assert!(result.get("nonexistent").is_none());
        assert!(result.return_value().is_some());
    }

    #[test]
    fn test_mongo_function() {
        let func = mongodb_func::MongoFunction::new(
            "function(x, y) { return x + y; }",
            vec!["$field1", "$field2"],
        );

        assert_eq!(func.lang, "js");
        assert_eq!(func.args.len(), 2);
    }

    #[test]
    fn test_mongo_accumulator() {
        let init = "function() { return { sum: 0, count: 0 }; }";
        let accumulate =
            "function(state, value) { state.sum += value; state.count++; return state; }";
        let merge =
            "function(s1, s2) { return { sum: s1.sum + s2.sum, count: s1.count + s2.count }; }";

        let acc = mongodb_func::MongoAccumulator::new(init, accumulate, merge)
            .with_finalize("function(state) { return state.sum / state.count; }")
            .with_accumulate_args(vec!["$value"]);

        assert!(acc.finalize.is_some());
        assert_eq!(acc.accumulate_args.len(), 1);
    }

    #[test]
    fn test_sqlite_udf_definitions() {
        use sqlite_udf::{AggregateUdf, ScalarUdf, WindowUdf};

        assert!(ScalarUdf::new("my_hash", 1).deterministic(true).deterministic);
        assert_eq!(AggregateUdf::new("my_sum", 1).num_args, 1);
        assert_eq!(WindowUdf::new("my_rank", 0).num_args, 0);
    }
}