// prax_query/procedure.rs
//! Stored procedure and function call support.
//!
//! This module provides a type-safe way to call stored procedures and functions
//! across different database backends.
//!
//! # Supported Features
//!
//! | Feature                  | PostgreSQL | MySQL | MSSQL | SQLite | MongoDB |
//! |--------------------------|------------|-------|-------|--------|---------|
//! | Stored Procedures        | ✅         | ✅    | ✅    | ❌     | ❌      |
//! | User-Defined Functions   | ✅         | ✅    | ✅    | ✅*    | ✅      |
//! | Table-Valued Functions   | ✅         | ❌    | ✅    | ❌     | ❌      |
//! | IN/OUT/INOUT Parameters  | ✅         | ✅    | ✅    | ❌     | ❌      |
//!
//! > *SQLite requires Rust UDFs via `rusqlite::functions`.
//!
//! A [`ProcedureCall`] generates PostgreSQL SQL unless another backend is
//! selected with [`ProcedureCall::with_db_type`].
//!
//! # Example Usage
//!
//! ```rust,ignore
//! use prax_query::procedure::{ProcedureCallOperation, ProcedureEngine};
//!
//! // Call a stored procedure and map the rows it returns
//! let orders: Vec<OrderResult> = ProcedureCallOperation::new(
//!     client.clone(),
//!     client.call("get_user_orders").param("user_id", 42),
//! )
//! .exec_returning()
//! .await?;
//!
//! // Call a procedure with OUT parameters
//! let call = client
//!     .call("calculate_totals")
//!     .in_param("order_id", 123)
//!     .out_param_typed("total_amount", "BIGINT")
//!     .out_param_typed("item_count", "INT");
//! // Note: `execute_procedure` currently reports only `rows_affected`.
//! let result = client.execute_procedure(call).await?;
//!
//! // Call a scalar function and convert its single return value
//! let order_count: i64 = ProcedureCallOperation::new(
//!     client.clone(),
//!     client.function("count_user_orders").param("user_id", 42),
//! )
//! .exec_scalar()
//! .await?;
//! ```
46
47use std::borrow::Cow;
48use std::collections::HashMap;
49use std::marker::PhantomData;
50
51use serde::{Deserialize, Serialize};
52
53use crate::error::{QueryError, QueryResult};
54use crate::filter::FilterValue;
55use crate::sql::DatabaseType;
56use crate::traits::{BoxFuture, QueryEngine};
57
58/// Parameter direction mode for stored procedures.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
60pub enum ParameterMode {
61    /// Input parameter (default).
62    #[default]
63    In,
64    /// Output parameter.
65    Out,
66    /// Input/Output parameter.
67    InOut,
68}
69
70/// A parameter for a stored procedure or function call.
71#[derive(Debug, Clone)]
72pub struct Parameter {
73    /// Parameter name.
74    pub name: String,
75    /// Parameter value (None for OUT parameters without initial value).
76    pub value: Option<FilterValue>,
77    /// Parameter mode (IN, OUT, INOUT).
78    pub mode: ParameterMode,
79    /// Expected type name for OUT parameters.
80    pub type_hint: Option<String>,
81}
82
83impl Parameter {
84    /// Create a new input parameter.
85    pub fn input(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
86        Self {
87            name: name.into(),
88            value: Some(value.into()),
89            mode: ParameterMode::In,
90            type_hint: None,
91        }
92    }
93
94    /// Create a new output parameter.
95    pub fn output(name: impl Into<String>) -> Self {
96        Self {
97            name: name.into(),
98            value: None,
99            mode: ParameterMode::Out,
100            type_hint: None,
101        }
102    }
103
104    /// Create a new input/output parameter.
105    pub fn inout(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
106        Self {
107            name: name.into(),
108            value: Some(value.into()),
109            mode: ParameterMode::InOut,
110            type_hint: None,
111        }
112    }
113
114    /// Set a type hint for the parameter.
115    pub fn with_type_hint(mut self, type_name: impl Into<String>) -> Self {
116        self.type_hint = Some(type_name.into());
117        self
118    }
119}
120
121/// Result from a procedure call with OUT/INOUT parameters.
122#[derive(Debug, Clone, Default)]
123pub struct ProcedureResult {
124    /// Output parameter values by name.
125    pub outputs: HashMap<String, FilterValue>,
126    /// Return value (for functions).
127    pub return_value: Option<FilterValue>,
128    /// Number of rows affected (if applicable).
129    pub rows_affected: Option<u64>,
130}
131
132impl ProcedureResult {
133    /// Get an output parameter value.
134    pub fn get(&self, name: &str) -> Option<&FilterValue> {
135        self.outputs.get(name)
136    }
137
138    /// Get an output parameter as a specific type.
139    pub fn get_as<T>(&self, name: &str) -> Option<T>
140    where
141        T: TryFrom<FilterValue>,
142    {
143        self.outputs
144            .get(name)
145            .and_then(|v| T::try_from(v.clone()).ok())
146    }
147
148    /// Get the return value.
149    pub fn return_value(&self) -> Option<&FilterValue> {
150        self.return_value.as_ref()
151    }
152
153    /// Get the return value as a specific type.
154    pub fn return_value_as<T>(&self) -> Option<T>
155    where
156        T: TryFrom<FilterValue>,
157    {
158        self.return_value.clone().and_then(|v| T::try_from(v).ok())
159    }
160}
161
162/// Builder for stored procedure calls.
163#[derive(Debug, Clone)]
164pub struct ProcedureCall {
165    /// Procedure/function name.
166    pub name: String,
167    /// Schema name (optional).
168    pub schema: Option<String>,
169    /// Parameters.
170    pub parameters: Vec<Parameter>,
171    /// Database type for SQL generation.
172    pub db_type: DatabaseType,
173    /// Whether this is a function call (vs procedure).
174    pub is_function: bool,
175}
176
177impl ProcedureCall {
178    /// Create a new procedure call.
179    pub fn new(name: impl Into<String>) -> Self {
180        Self {
181            name: name.into(),
182            schema: None,
183            parameters: Vec::new(),
184            db_type: DatabaseType::PostgreSQL,
185            is_function: false,
186        }
187    }
188
189    /// Create a new function call.
190    pub fn function(name: impl Into<String>) -> Self {
191        Self {
192            name: name.into(),
193            schema: None,
194            parameters: Vec::new(),
195            db_type: DatabaseType::PostgreSQL,
196            is_function: true,
197        }
198    }
199
200    /// Set the schema name.
201    pub fn schema(mut self, schema: impl Into<String>) -> Self {
202        self.schema = Some(schema.into());
203        self
204    }
205
206    /// Set the database type.
207    pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
208        self.db_type = db_type;
209        self
210    }
211
212    /// Add an input parameter.
213    pub fn param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
214        self.parameters.push(Parameter::input(name, value));
215        self
216    }
217
218    /// Add an input parameter (alias for param).
219    pub fn in_param(self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
220        self.param(name, value)
221    }
222
223    /// Add an output parameter.
224    pub fn out_param(mut self, name: impl Into<String>) -> Self {
225        self.parameters.push(Parameter::output(name));
226        self
227    }
228
229    /// Add an output parameter with type hint.
230    pub fn out_param_typed(
231        mut self,
232        name: impl Into<String>,
233        type_hint: impl Into<String>,
234    ) -> Self {
235        self.parameters
236            .push(Parameter::output(name).with_type_hint(type_hint));
237        self
238    }
239
240    /// Add an input/output parameter.
241    pub fn inout_param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
242        self.parameters.push(Parameter::inout(name, value));
243        self
244    }
245
246    /// Add a raw parameter.
247    pub fn add_parameter(mut self, param: Parameter) -> Self {
248        self.parameters.push(param);
249        self
250    }
251
252    /// Get the fully qualified name.
253    pub fn qualified_name(&self) -> Cow<'_, str> {
254        match &self.schema {
255            Some(schema) => Cow::Owned(format!("{}.{}", schema, self.name)),
256            None => Cow::Borrowed(&self.name),
257        }
258    }
259
260    /// Check if any parameters are OUT or INOUT.
261    pub fn has_outputs(&self) -> bool {
262        self.parameters
263            .iter()
264            .any(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
265    }
266
267    /// Get input parameter values.
268    pub fn input_values(&self) -> Vec<FilterValue> {
269        self.parameters
270            .iter()
271            .filter(|p| matches!(p.mode, ParameterMode::In | ParameterMode::InOut))
272            .filter_map(|p| p.value.clone())
273            .collect()
274    }
275
276    /// Generate SQL for PostgreSQL.
277    pub fn to_postgres_sql(&self) -> (String, Vec<FilterValue>) {
278        let name = self.qualified_name();
279        let params = self.input_values();
280        let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("${}", i)).collect();
281
282        let sql = if self.is_function {
283            format!("SELECT {}({})", name, placeholders.join(", "))
284        } else {
285            format!("CALL {}({})", name, placeholders.join(", "))
286        };
287
288        (sql, params)
289    }
290
291    /// Generate SQL for MySQL.
292    pub fn to_mysql_sql(&self) -> (String, Vec<FilterValue>) {
293        let name = self.qualified_name();
294        let params = self.input_values();
295        let placeholders = vec!["?"; params.len()].join(", ");
296
297        let sql = if self.is_function {
298            format!("SELECT {}({})", name, placeholders)
299        } else {
300            format!("CALL {}({})", name, placeholders)
301        };
302
303        (sql, params)
304    }
305
306    /// Generate SQL for MSSQL.
307    pub fn to_mssql_sql(&self) -> (String, Vec<FilterValue>) {
308        let name = self.qualified_name();
309        let params = self.input_values();
310        let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("@P{}", i)).collect();
311
312        if self.is_function {
313            (
314                format!("SELECT {}({})", name, placeholders.join(", ")),
315                params,
316            )
317        } else if self.has_outputs() {
318            // For procedures with OUT params, use EXEC with output variable declarations
319            let mut parts = vec![String::from("DECLARE ")];
320
321            // Declare output variables
322            let out_params: Vec<_> = self
323                .parameters
324                .iter()
325                .filter(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
326                .collect();
327
328            for (i, param) in out_params.iter().enumerate() {
329                if i > 0 {
330                    parts.push(String::from(", "));
331                }
332                let type_name = param.type_hint.as_deref().unwrap_or("SQL_VARIANT");
333                parts.push(format!("@{} {}", param.name, type_name));
334            }
335            parts.push(String::from("; "));
336
337            // Build EXEC statement
338            parts.push(format!("EXEC {} ", name));
339
340            let param_parts: Vec<String> = self
341                .parameters
342                .iter()
343                .enumerate()
344                .map(|(i, p)| match p.mode {
345                    ParameterMode::In => format!("@P{}", i + 1),
346                    ParameterMode::Out => format!("@{} OUTPUT", p.name),
347                    ParameterMode::InOut => format!("@P{} = @{} OUTPUT", i + 1, p.name),
348                })
349                .collect();
350
351            parts.push(param_parts.join(", "));
352            parts.push(String::from("; "));
353
354            // Select output values
355            let select_parts: Vec<String> = out_params
356                .iter()
357                .map(|p| format!("@{} AS {}", p.name, p.name))
358                .collect();
359            parts.push(format!("SELECT {}", select_parts.join(", ")));
360
361            (parts.join(""), params)
362        } else {
363            (format!("EXEC {} {}", name, placeholders.join(", ")), params)
364        }
365    }
366
367    /// Generate SQL for SQLite (only functions supported).
368    pub fn to_sqlite_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
369        if !self.is_function {
370            return Err(QueryError::unsupported(
371                "SQLite does not support stored procedures. Use Rust UDFs instead.",
372            ));
373        }
374
375        let name = self.qualified_name();
376        let params = self.input_values();
377        let placeholders = vec!["?"; params.len()].join(", ");
378
379        Ok((format!("SELECT {}({})", name, placeholders), params))
380    }
381
382    /// Generate SQL for the configured database type.
383    pub fn to_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
384        match self.db_type {
385            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
386            DatabaseType::MySQL => Ok(self.to_mysql_sql()),
387            DatabaseType::SQLite => self.to_sqlite_sql(),
388            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
389        }
390    }
391}
392
393/// Operation for executing a procedure call.
394pub struct ProcedureCallOperation<E: QueryEngine> {
395    engine: E,
396    call: ProcedureCall,
397}
398
399impl<E: QueryEngine> ProcedureCallOperation<E> {
400    /// Create a new procedure call operation.
401    pub fn new(engine: E, call: ProcedureCall) -> Self {
402        Self { engine, call }
403    }
404
405    /// Execute the procedure and return the result.
406    pub async fn exec(self) -> QueryResult<ProcedureResult> {
407        let (sql, params) = self.call.to_sql()?;
408        let affected = self.engine.execute_raw(&sql, params).await?;
409
410        Ok(ProcedureResult {
411            outputs: HashMap::new(),
412            return_value: None,
413            rows_affected: Some(affected),
414        })
415    }
416
417    /// Execute the procedure and return typed results.
418    pub async fn exec_returning<T>(self) -> QueryResult<Vec<T>>
419    where
420        T: crate::traits::Model + crate::row::FromRow + Send + 'static,
421    {
422        let (sql, params) = self.call.to_sql()?;
423        self.engine.query_many(&sql, params).await
424    }
425
426    /// Execute a function and return a single value.
427    ///
428    /// Routes through [`QueryEngine::aggregate_query`] so the scalar
429    /// return lands in the first column of the first row as a
430    /// [`FilterValue`]. The caller's `T: TryFrom<FilterValue>` impl
431    /// handles the final type coercion — e.g., `T = i64` succeeds on
432    /// `FilterValue::Int`, errors on `FilterValue::String`.
433    pub async fn exec_scalar<T>(self) -> QueryResult<T>
434    where
435        T: TryFrom<FilterValue, Error = String> + Send + 'static,
436    {
437        let (sql, params) = self.call.to_sql()?;
438        let mut rows = self.engine.aggregate_query(&sql, params).await?;
439        let first = rows
440            .drain(..)
441            .next()
442            .ok_or_else(|| QueryError::not_found("scalar function returned no row".to_string()))?;
443        // Take any value from the map — scalar functions produce a
444        // single column, but the column name is driver-dependent.
445        let value = first.into_values().next().ok_or_else(|| {
446            QueryError::deserialization(
447                "scalar function returned a row with no columns".to_string(),
448            )
449        })?;
450        T::try_from(value).map_err(QueryError::deserialization)
451    }
452}
453
454/// Operation for executing a function call that returns a value.
455#[allow(dead_code)]
456pub struct FunctionCallOperation<E: QueryEngine, T> {
457    engine: E,
458    call: ProcedureCall,
459    _marker: PhantomData<T>,
460}
461
462impl<E: QueryEngine, T> FunctionCallOperation<E, T> {
463    /// Create a new function call operation.
464    pub fn new(engine: E, call: ProcedureCall) -> Self {
465        Self {
466            engine,
467            call,
468            _marker: PhantomData,
469        }
470    }
471}
472
473/// Extension trait for query engines to support procedure calls.
474pub trait ProcedureEngine: QueryEngine {
475    /// Call a stored procedure.
476    fn call(&self, name: impl Into<String>) -> ProcedureCall {
477        ProcedureCall::new(name)
478    }
479
480    /// Call a function.
481    fn function(&self, name: impl Into<String>) -> ProcedureCall {
482        ProcedureCall::function(name)
483    }
484
485    /// Execute a procedure call.
486    fn execute_procedure(&self, call: ProcedureCall) -> BoxFuture<'_, QueryResult<ProcedureResult>>
487    where
488        Self: Clone + 'static,
489    {
490        let engine = self.clone();
491        Box::pin(async move {
492            let op = ProcedureCallOperation::new(engine, call);
493            op.exec().await
494        })
495    }
496}
497
498// Implement ProcedureEngine for all QueryEngine implementations
499impl<T: QueryEngine + Clone + 'static> ProcedureEngine for T {}
500
/// SQLite-specific UDF registration support.
///
/// The types here only describe a UDF (name, arity, determinism). Nothing in
/// this module registers them with a SQLite connection itself.
pub mod sqlite_udf {
    // No `use super::*`: every name used here comes from the prelude.

    /// A Rust function that can be registered as a SQLite UDF.
    pub trait SqliteFunction: Send + Sync + 'static {
        /// The SQL-visible name of the function.
        fn name(&self) -> &str;

        /// The number of arguments (-1 for variadic).
        fn num_args(&self) -> i32;

        /// Whether the function is deterministic. Defaults to `true`.
        fn deterministic(&self) -> bool {
            true
        }
    }

    /// A scalar UDF definition.
    #[derive(Debug, Clone)]
    pub struct ScalarUdf {
        /// Function name.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
        /// Whether deterministic.
        pub deterministic: bool,
    }

    impl ScalarUdf {
        /// Create a new scalar UDF definition; it starts out deterministic.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
                deterministic: true,
            }
        }

        /// Set whether the function is deterministic.
        #[must_use]
        pub fn deterministic(mut self, deterministic: bool) -> Self {
            self.deterministic = deterministic;
            self
        }
    }

    /// An aggregate UDF definition.
    #[derive(Debug, Clone)]
    pub struct AggregateUdf {
        /// Function name.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
    }

    impl AggregateUdf {
        /// Create a new aggregate UDF definition.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }

    /// A window UDF definition.
    #[derive(Debug, Clone)]
    pub struct WindowUdf {
        /// Function name.
        pub name: String,
        /// Number of arguments.
        pub num_args: i32,
    }

    impl WindowUdf {
        /// Create a new window UDF definition.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }
}
586
587/// MongoDB-specific function support.
588pub mod mongodb_func {
589    use super::*;
590
591    /// A MongoDB `$function` expression for custom JavaScript functions.
592    #[derive(Debug, Clone, Serialize, Deserialize)]
593    pub struct MongoFunction {
594        /// JavaScript function body.
595        pub body: String,
596        /// Function arguments (field references or values).
597        pub args: Vec<String>,
598        /// Language (always "js" for now).
599        pub lang: String,
600    }
601
602    impl MongoFunction {
603        /// Create a new MongoDB function.
604        pub fn new(body: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
605            Self {
606                body: body.into(),
607                args: args.into_iter().map(Into::into).collect(),
608                lang: "js".to_string(),
609            }
610        }
611
612        /// Convert to a BSON document for use in aggregation.
613        #[cfg(feature = "mongodb")]
614        pub fn to_bson(&self) -> bson::Document {
615            use bson::doc;
616            doc! {
617                "$function": {
618                    "body": &self.body,
619                    "args": &self.args,
620                    "lang": &self.lang,
621                }
622            }
623        }
624    }
625
626    /// A MongoDB `$accumulator` expression for custom aggregation.
627    #[derive(Debug, Clone, Serialize, Deserialize)]
628    pub struct MongoAccumulator {
629        /// Initialize the accumulator state.
630        pub init: String,
631        /// Initialize arguments.
632        pub init_args: Vec<String>,
633        /// Accumulate function.
634        pub accumulate: String,
635        /// Accumulate arguments.
636        pub accumulate_args: Vec<String>,
637        /// Merge function.
638        pub merge: String,
639        /// Finalize function (optional).
640        pub finalize: Option<String>,
641        /// Language.
642        pub lang: String,
643    }
644
645    impl MongoAccumulator {
646        /// Create a new MongoDB accumulator.
647        pub fn new(
648            init: impl Into<String>,
649            accumulate: impl Into<String>,
650            merge: impl Into<String>,
651        ) -> Self {
652            Self {
653                init: init.into(),
654                init_args: Vec::new(),
655                accumulate: accumulate.into(),
656                accumulate_args: Vec::new(),
657                merge: merge.into(),
658                finalize: None,
659                lang: "js".to_string(),
660            }
661        }
662
663        /// Set init arguments.
664        pub fn with_init_args(mut self, args: Vec<impl Into<String>>) -> Self {
665            self.init_args = args.into_iter().map(Into::into).collect();
666            self
667        }
668
669        /// Set accumulate arguments.
670        pub fn with_accumulate_args(mut self, args: Vec<impl Into<String>>) -> Self {
671            self.accumulate_args = args.into_iter().map(Into::into).collect();
672            self
673        }
674
675        /// Set finalize function.
676        pub fn with_finalize(mut self, finalize: impl Into<String>) -> Self {
677            self.finalize = Some(finalize.into());
678            self
679        }
680
681        /// Convert to a BSON document for use in aggregation.
682        #[cfg(feature = "mongodb")]
683        pub fn to_bson(&self) -> bson::Document {
684            use bson::doc;
685            let mut doc = doc! {
686                "$accumulator": {
687                    "init": &self.init,
688                    "accumulate": &self.accumulate,
689                    "accumulateArgs": &self.accumulate_args,
690                    "merge": &self.merge,
691                    "lang": &self.lang,
692                }
693            };
694
695            if !self.init_args.is_empty() {
696                doc.get_document_mut("$accumulator")
697                    .unwrap()
698                    .insert("initArgs", &self.init_args);
699            }
700
701            if let Some(ref finalize) = self.finalize {
702                doc.get_document_mut("$accumulator")
703                    .unwrap()
704                    .insert("finalize", finalize);
705            }
706
707            doc
708        }
709    }
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_procedure_call_basic() {
        let built = ProcedureCall::new("get_user")
            .param("id", 42i32)
            .param("active", true);

        assert_eq!(built.name, "get_user");
        assert_eq!(built.parameters.len(), 2);
        assert!(!built.is_function);
    }

    #[test]
    fn test_function_call() {
        let built = ProcedureCall::function("calculate_tax")
            .param("amount", 100.0f64)
            .param("rate", 0.08f64);

        assert_eq!(built.name, "calculate_tax");
        assert!(built.is_function);
    }

    #[test]
    fn test_postgres_sql_generation() {
        let (sql, bound) = ProcedureCall::new("get_orders")
            .param("user_id", 42i32)
            .param("status", "pending".to_string())
            .to_postgres_sql();

        assert_eq!(sql, "CALL get_orders($1, $2)");
        assert_eq!(bound.len(), 2);
    }

    #[test]
    fn test_postgres_function_sql() {
        let (sql, bound) = ProcedureCall::function("calculate_total")
            .param("order_id", 123i32)
            .to_postgres_sql();

        assert_eq!(sql, "SELECT calculate_total($1)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_mysql_sql_generation() {
        let (sql, bound) = ProcedureCall::new("get_orders")
            .with_db_type(DatabaseType::MySQL)
            .param("user_id", 42i32)
            .to_mysql_sql();

        assert_eq!(sql, "CALL get_orders(?)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_mssql_sql_generation() {
        let (sql, bound) = ProcedureCall::new("GetOrders")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("UserId", 42i32)
            .to_mssql_sql();

        assert!(sql.contains("EXEC dbo.GetOrders"));
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_mssql_with_output_params() {
        let (sql, _bound) = ProcedureCall::new("CalculateTotals")
            .with_db_type(DatabaseType::MSSQL)
            .in_param("OrderId", 123i32)
            .out_param_typed("TotalAmount", "DECIMAL(18,2)")
            .out_param_typed("ItemCount", "INT")
            .to_mssql_sql();

        for keyword in ["DECLARE", "OUTPUT", "SELECT"] {
            assert!(sql.contains(keyword));
        }
    }

    #[test]
    fn test_sqlite_function() {
        let outcome = ProcedureCall::function("custom_hash")
            .with_db_type(DatabaseType::SQLite)
            .param("input", "test".to_string())
            .to_sqlite_sql();
        assert!(outcome.is_ok());

        let (sql, bound) = outcome.unwrap();
        assert_eq!(sql, "SELECT custom_hash(?)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_sqlite_procedure_error() {
        let outcome = ProcedureCall::new("some_procedure")
            .with_db_type(DatabaseType::SQLite)
            .param("id", 42i32)
            .to_sqlite_sql();

        assert!(outcome.is_err());
    }

    #[test]
    fn test_qualified_name() {
        let qualified = ProcedureCall::new("get_user").schema("public");
        assert_eq!(qualified.qualified_name(), "public.get_user");

        let bare = ProcedureCall::new("get_user");
        assert_eq!(bare.qualified_name(), "get_user");
    }

    #[test]
    fn test_parameter_modes() {
        let built = ProcedureCall::new("calculate")
            .in_param("input", 100i32)
            .out_param("result")
            .inout_param("running_total", 50i32);

        let modes: Vec<ParameterMode> = built.parameters.iter().map(|p| p.mode).collect();
        assert_eq!(
            modes,
            [ParameterMode::In, ParameterMode::Out, ParameterMode::InOut]
        );
        assert!(built.has_outputs());
    }

    #[test]
    fn test_procedure_result() {
        let mut outcome = ProcedureResult::default();
        outcome
            .outputs
            .insert("total".to_string(), FilterValue::Int(100));
        outcome.return_value = Some(FilterValue::Bool(true));

        assert!(outcome.get("total").is_some());
        assert!(outcome.get("nonexistent").is_none());
        assert!(outcome.return_value().is_some());
    }

    #[test]
    fn test_mongo_function() {
        use mongodb_func::MongoFunction;

        let adder = MongoFunction::new(
            "function(x, y) { return x + y; }",
            vec!["$field1", "$field2"],
        );

        assert_eq!(adder.lang, "js");
        assert_eq!(adder.args.len(), 2);
    }

    #[test]
    fn test_mongo_accumulator() {
        use mongodb_func::MongoAccumulator;

        let averager = MongoAccumulator::new(
            "function() { return { sum: 0, count: 0 }; }",
            "function(state, value) { state.sum += value; state.count++; return state; }",
            "function(s1, s2) { return { sum: s1.sum + s2.sum, count: s1.count + s2.count }; }",
        )
        .with_finalize("function(state) { return state.sum / state.count; }")
        .with_accumulate_args(vec!["$value"]);

        assert!(averager.finalize.is_some());
        assert_eq!(averager.accumulate_args.len(), 1);
    }

    #[test]
    fn test_sqlite_udf_definitions() {
        use sqlite_udf::{AggregateUdf, ScalarUdf, WindowUdf};

        let scalar = ScalarUdf::new("my_hash", 1).deterministic(true);
        assert!(scalar.deterministic);

        let aggregate = AggregateUdf::new("my_sum", 1);
        assert_eq!(aggregate.num_args, 1);

        let window = WindowUdf::new("my_rank", 0);
        assert_eq!(window.num_args, 0);
    }
}