Skip to main content

prax_query/
procedure.rs

1//! Stored procedure and function call support.
2//!
3//! This module provides a type-safe way to call stored procedures and functions
4//! across different database backends.
5//!
6//! # Supported Features
7//!
8//! | Feature                  | PostgreSQL | MySQL | MSSQL | SQLite | MongoDB |
9//! |--------------------------|------------|-------|-------|--------|---------|
10//! | Stored Procedures        | ✅         | ✅    | ✅    | ❌     | ❌      |
11//! | User-Defined Functions   | ✅         | ✅    | ✅    | ✅*    | ✅      |
12//! | Table-Valued Functions   | ✅         | ❌    | ✅    | ❌     | ❌      |
13//! | IN/OUT/INOUT Parameters  | ✅         | ✅    | ✅    | ❌     | ❌      |
14//!
15//! > *SQLite requires Rust UDFs via `rusqlite::functions`
16//!
17//! # Example Usage
18//!
19//! ```rust,ignore
20//! use prax_query::procedure::{ProcedureCall, ParameterMode};
21//!
22//! // Call a stored procedure
23//! let result = client
24//!     .call("get_user_orders")
25//!     .param("user_id", 42)
26//!     .exec::<OrderResult>()
27//!     .await?;
28//!
29//! // Call a procedure with OUT parameters
30//! let result = client
31//!     .call("calculate_totals")
32//!     .in_param("order_id", 123)
//!     .out_param_typed("total_amount", "BIGINT")
//!     .out_param_typed("item_count", "INT")
35//!     .exec()
36//!     .await?;
37//!
38//! // Call a function
39//! let result = client
40//!     .function("calculate_tax")
41//!     .param("amount", 100.0)
42//!     .param("rate", 0.08)
43//!     .exec::<f64>()
44//!     .await?;
45//! ```
46
47use std::borrow::Cow;
48use std::collections::HashMap;
49use std::marker::PhantomData;
50
51use serde::{Deserialize, Serialize};
52
53use crate::error::{QueryError, QueryResult};
54use crate::filter::FilterValue;
55use crate::sql::DatabaseType;
56use crate::traits::{BoxFuture, QueryEngine};
57
58/// Parameter direction mode for stored procedures.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum ParameterMode {
61    /// Input parameter (default).
62    In,
63    /// Output parameter.
64    Out,
65    /// Input/Output parameter.
66    InOut,
67}
68
69impl Default for ParameterMode {
70    fn default() -> Self {
71        Self::In
72    }
73}
74
75/// A parameter for a stored procedure or function call.
76#[derive(Debug, Clone)]
77pub struct Parameter {
78    /// Parameter name.
79    pub name: String,
80    /// Parameter value (None for OUT parameters without initial value).
81    pub value: Option<FilterValue>,
82    /// Parameter mode (IN, OUT, INOUT).
83    pub mode: ParameterMode,
84    /// Expected type name for OUT parameters.
85    pub type_hint: Option<String>,
86}
87
88impl Parameter {
89    /// Create a new input parameter.
90    pub fn input(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
91        Self {
92            name: name.into(),
93            value: Some(value.into()),
94            mode: ParameterMode::In,
95            type_hint: None,
96        }
97    }
98
99    /// Create a new output parameter.
100    pub fn output(name: impl Into<String>) -> Self {
101        Self {
102            name: name.into(),
103            value: None,
104            mode: ParameterMode::Out,
105            type_hint: None,
106        }
107    }
108
109    /// Create a new input/output parameter.
110    pub fn inout(name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
111        Self {
112            name: name.into(),
113            value: Some(value.into()),
114            mode: ParameterMode::InOut,
115            type_hint: None,
116        }
117    }
118
119    /// Set a type hint for the parameter.
120    pub fn with_type_hint(mut self, type_name: impl Into<String>) -> Self {
121        self.type_hint = Some(type_name.into());
122        self
123    }
124}
125
126/// Result from a procedure call with OUT/INOUT parameters.
127#[derive(Debug, Clone, Default)]
128pub struct ProcedureResult {
129    /// Output parameter values by name.
130    pub outputs: HashMap<String, FilterValue>,
131    /// Return value (for functions).
132    pub return_value: Option<FilterValue>,
133    /// Number of rows affected (if applicable).
134    pub rows_affected: Option<u64>,
135}
136
137impl ProcedureResult {
138    /// Get an output parameter value.
139    pub fn get(&self, name: &str) -> Option<&FilterValue> {
140        self.outputs.get(name)
141    }
142
143    /// Get an output parameter as a specific type.
144    pub fn get_as<T>(&self, name: &str) -> Option<T>
145    where
146        T: TryFrom<FilterValue>,
147    {
148        self.outputs
149            .get(name)
150            .and_then(|v| T::try_from(v.clone()).ok())
151    }
152
153    /// Get the return value.
154    pub fn return_value(&self) -> Option<&FilterValue> {
155        self.return_value.as_ref()
156    }
157
158    /// Get the return value as a specific type.
159    pub fn return_value_as<T>(&self) -> Option<T>
160    where
161        T: TryFrom<FilterValue>,
162    {
163        self.return_value.clone().and_then(|v| T::try_from(v).ok())
164    }
165}
166
167/// Builder for stored procedure calls.
168#[derive(Debug, Clone)]
169pub struct ProcedureCall {
170    /// Procedure/function name.
171    pub name: String,
172    /// Schema name (optional).
173    pub schema: Option<String>,
174    /// Parameters.
175    pub parameters: Vec<Parameter>,
176    /// Database type for SQL generation.
177    pub db_type: DatabaseType,
178    /// Whether this is a function call (vs procedure).
179    pub is_function: bool,
180}
181
182impl ProcedureCall {
183    /// Create a new procedure call.
184    pub fn new(name: impl Into<String>) -> Self {
185        Self {
186            name: name.into(),
187            schema: None,
188            parameters: Vec::new(),
189            db_type: DatabaseType::PostgreSQL,
190            is_function: false,
191        }
192    }
193
194    /// Create a new function call.
195    pub fn function(name: impl Into<String>) -> Self {
196        Self {
197            name: name.into(),
198            schema: None,
199            parameters: Vec::new(),
200            db_type: DatabaseType::PostgreSQL,
201            is_function: true,
202        }
203    }
204
205    /// Set the schema name.
206    pub fn schema(mut self, schema: impl Into<String>) -> Self {
207        self.schema = Some(schema.into());
208        self
209    }
210
211    /// Set the database type.
212    pub fn with_db_type(mut self, db_type: DatabaseType) -> Self {
213        self.db_type = db_type;
214        self
215    }
216
217    /// Add an input parameter.
218    pub fn param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
219        self.parameters.push(Parameter::input(name, value));
220        self
221    }
222
223    /// Add an input parameter (alias for param).
224    pub fn in_param(self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
225        self.param(name, value)
226    }
227
228    /// Add an output parameter.
229    pub fn out_param(mut self, name: impl Into<String>) -> Self {
230        self.parameters.push(Parameter::output(name));
231        self
232    }
233
234    /// Add an output parameter with type hint.
235    pub fn out_param_typed(
236        mut self,
237        name: impl Into<String>,
238        type_hint: impl Into<String>,
239    ) -> Self {
240        self.parameters
241            .push(Parameter::output(name).with_type_hint(type_hint));
242        self
243    }
244
245    /// Add an input/output parameter.
246    pub fn inout_param(mut self, name: impl Into<String>, value: impl Into<FilterValue>) -> Self {
247        self.parameters.push(Parameter::inout(name, value));
248        self
249    }
250
251    /// Add a raw parameter.
252    pub fn add_parameter(mut self, param: Parameter) -> Self {
253        self.parameters.push(param);
254        self
255    }
256
257    /// Get the fully qualified name.
258    pub fn qualified_name(&self) -> Cow<'_, str> {
259        match &self.schema {
260            Some(schema) => Cow::Owned(format!("{}.{}", schema, self.name)),
261            None => Cow::Borrowed(&self.name),
262        }
263    }
264
265    /// Check if any parameters are OUT or INOUT.
266    pub fn has_outputs(&self) -> bool {
267        self.parameters
268            .iter()
269            .any(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
270    }
271
272    /// Get input parameter values.
273    pub fn input_values(&self) -> Vec<FilterValue> {
274        self.parameters
275            .iter()
276            .filter(|p| matches!(p.mode, ParameterMode::In | ParameterMode::InOut))
277            .filter_map(|p| p.value.clone())
278            .collect()
279    }
280
281    /// Generate SQL for PostgreSQL.
282    pub fn to_postgres_sql(&self) -> (String, Vec<FilterValue>) {
283        let name = self.qualified_name();
284        let params = self.input_values();
285        let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("${}", i)).collect();
286
287        let sql = if self.is_function {
288            format!("SELECT {}({})", name, placeholders.join(", "))
289        } else {
290            format!("CALL {}({})", name, placeholders.join(", "))
291        };
292
293        (sql, params)
294    }
295
296    /// Generate SQL for MySQL.
297    pub fn to_mysql_sql(&self) -> (String, Vec<FilterValue>) {
298        let name = self.qualified_name();
299        let params = self.input_values();
300        let placeholders = vec!["?"; params.len()].join(", ");
301
302        let sql = if self.is_function {
303            format!("SELECT {}({})", name, placeholders)
304        } else {
305            format!("CALL {}({})", name, placeholders)
306        };
307
308        (sql, params)
309    }
310
311    /// Generate SQL for MSSQL.
312    pub fn to_mssql_sql(&self) -> (String, Vec<FilterValue>) {
313        let name = self.qualified_name();
314        let params = self.input_values();
315        let placeholders: Vec<String> = (1..=params.len()).map(|i| format!("@P{}", i)).collect();
316
317        if self.is_function {
318            (
319                format!("SELECT {}({})", name, placeholders.join(", ")),
320                params,
321            )
322        } else if self.has_outputs() {
323            // For procedures with OUT params, use EXEC with output variable declarations
324            let mut parts = vec![String::from("DECLARE ")];
325
326            // Declare output variables
327            let out_params: Vec<_> = self
328                .parameters
329                .iter()
330                .filter(|p| matches!(p.mode, ParameterMode::Out | ParameterMode::InOut))
331                .collect();
332
333            for (i, param) in out_params.iter().enumerate() {
334                if i > 0 {
335                    parts.push(String::from(", "));
336                }
337                let type_name = param.type_hint.as_deref().unwrap_or("SQL_VARIANT");
338                parts.push(format!("@{} {}", param.name, type_name));
339            }
340            parts.push(String::from("; "));
341
342            // Build EXEC statement
343            parts.push(format!("EXEC {} ", name));
344
345            let param_parts: Vec<String> = self
346                .parameters
347                .iter()
348                .enumerate()
349                .map(|(i, p)| match p.mode {
350                    ParameterMode::In => format!("@P{}", i + 1),
351                    ParameterMode::Out => format!("@{} OUTPUT", p.name),
352                    ParameterMode::InOut => format!("@P{} = @{} OUTPUT", i + 1, p.name),
353                })
354                .collect();
355
356            parts.push(param_parts.join(", "));
357            parts.push(String::from("; "));
358
359            // Select output values
360            let select_parts: Vec<String> = out_params
361                .iter()
362                .map(|p| format!("@{} AS {}", p.name, p.name))
363                .collect();
364            parts.push(format!("SELECT {}", select_parts.join(", ")));
365
366            (parts.join(""), params)
367        } else {
368            (format!("EXEC {} {}", name, placeholders.join(", ")), params)
369        }
370    }
371
372    /// Generate SQL for SQLite (only functions supported).
373    pub fn to_sqlite_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
374        if !self.is_function {
375            return Err(QueryError::unsupported(
376                "SQLite does not support stored procedures. Use Rust UDFs instead.",
377            ));
378        }
379
380        let name = self.qualified_name();
381        let params = self.input_values();
382        let placeholders = vec!["?"; params.len()].join(", ");
383
384        Ok((format!("SELECT {}({})", name, placeholders), params))
385    }
386
387    /// Generate SQL for the configured database type.
388    pub fn to_sql(&self) -> QueryResult<(String, Vec<FilterValue>)> {
389        match self.db_type {
390            DatabaseType::PostgreSQL => Ok(self.to_postgres_sql()),
391            DatabaseType::MySQL => Ok(self.to_mysql_sql()),
392            DatabaseType::SQLite => self.to_sqlite_sql(),
393            DatabaseType::MSSQL => Ok(self.to_mssql_sql()),
394        }
395    }
396}
397
398/// Operation for executing a procedure call.
399pub struct ProcedureCallOperation<E: QueryEngine> {
400    engine: E,
401    call: ProcedureCall,
402}
403
404impl<E: QueryEngine> ProcedureCallOperation<E> {
405    /// Create a new procedure call operation.
406    pub fn new(engine: E, call: ProcedureCall) -> Self {
407        Self { engine, call }
408    }
409
410    /// Execute the procedure and return the result.
411    pub async fn exec(self) -> QueryResult<ProcedureResult> {
412        let (sql, params) = self.call.to_sql()?;
413        let affected = self.engine.execute_raw(&sql, params).await?;
414
415        Ok(ProcedureResult {
416            outputs: HashMap::new(),
417            return_value: None,
418            rows_affected: Some(affected),
419        })
420    }
421
422    /// Execute the procedure and return typed results.
423    pub async fn exec_returning<T>(self) -> QueryResult<Vec<T>>
424    where
425        T: crate::traits::Model + crate::row::FromRow + Send + 'static,
426    {
427        let (sql, params) = self.call.to_sql()?;
428        self.engine.query_many(&sql, params).await
429    }
430
431    /// Execute a function and return a single value.
432    ///
433    /// Routes through [`QueryEngine::aggregate_query`] so the scalar
434    /// return lands in the first column of the first row as a
435    /// [`FilterValue`]. The caller's `T: TryFrom<FilterValue>` impl
436    /// handles the final type coercion — e.g., `T = i64` succeeds on
437    /// `FilterValue::Int`, errors on `FilterValue::String`.
438    pub async fn exec_scalar<T>(self) -> QueryResult<T>
439    where
440        T: TryFrom<FilterValue, Error = String> + Send + 'static,
441    {
442        let (sql, params) = self.call.to_sql()?;
443        let mut rows = self.engine.aggregate_query(&sql, params).await?;
444        let first = rows
445            .drain(..)
446            .next()
447            .ok_or_else(|| QueryError::not_found("scalar function returned no row".to_string()))?;
448        // Take any value from the map — scalar functions produce a
449        // single column, but the column name is driver-dependent.
450        let value = first.into_values().next().ok_or_else(|| {
451            QueryError::deserialization(
452                "scalar function returned a row with no columns".to_string(),
453            )
454        })?;
455        T::try_from(value).map_err(QueryError::deserialization)
456    }
457}
458
459/// Operation for executing a function call that returns a value.
460#[allow(dead_code)]
461pub struct FunctionCallOperation<E: QueryEngine, T> {
462    engine: E,
463    call: ProcedureCall,
464    _marker: PhantomData<T>,
465}
466
467impl<E: QueryEngine, T> FunctionCallOperation<E, T> {
468    /// Create a new function call operation.
469    pub fn new(engine: E, call: ProcedureCall) -> Self {
470        Self {
471            engine,
472            call,
473            _marker: PhantomData,
474        }
475    }
476}
477
478/// Extension trait for query engines to support procedure calls.
479pub trait ProcedureEngine: QueryEngine {
480    /// Call a stored procedure.
481    fn call(&self, name: impl Into<String>) -> ProcedureCall {
482        ProcedureCall::new(name)
483    }
484
485    /// Call a function.
486    fn function(&self, name: impl Into<String>) -> ProcedureCall {
487        ProcedureCall::function(name)
488    }
489
490    /// Execute a procedure call.
491    fn execute_procedure(&self, call: ProcedureCall) -> BoxFuture<'_, QueryResult<ProcedureResult>>
492    where
493        Self: Clone + 'static,
494    {
495        let engine = self.clone();
496        Box::pin(async move {
497            let op = ProcedureCallOperation::new(engine, call);
498            op.exec().await
499        })
500    }
501}
502
503// Implement ProcedureEngine for all QueryEngine implementations
504impl<T: QueryEngine + Clone + 'static> ProcedureEngine for T {}
505
/// SQLite-specific UDF registration support.
///
/// These are plain descriptors (name, arity, determinism). Registration
/// itself happens elsewhere.
pub mod sqlite_udf {
    // The former `#[allow(unused_imports)] use super::*;` was removed.
    // Nothing in this module uses items from the parent.

    /// A Rust function that can be registered as a SQLite UDF.
    pub trait SqliteFunction: Send + Sync + 'static {
        /// The name of the function.
        fn name(&self) -> &str;

        /// The number of arguments (-1 for variadic).
        fn num_args(&self) -> i32;

        /// Whether the function is deterministic (defaults to `true`).
        fn deterministic(&self) -> bool {
            true
        }
    }

    /// A scalar UDF definition.
    #[derive(Debug, Clone)]
    pub struct ScalarUdf {
        /// Function name.
        pub name: String,
        /// Number of arguments (-1 for variadic).
        pub num_args: i32,
        /// Whether deterministic.
        pub deterministic: bool,
    }

    impl ScalarUdf {
        /// Create a new scalar UDF definition, deterministic by default.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
                deterministic: true,
            }
        }

        /// Set whether the function is deterministic.
        pub fn deterministic(mut self, deterministic: bool) -> Self {
            self.deterministic = deterministic;
            self
        }
    }

    /// An aggregate UDF definition.
    #[derive(Debug, Clone)]
    pub struct AggregateUdf {
        /// Function name.
        pub name: String,
        /// Number of arguments (-1 for variadic).
        pub num_args: i32,
    }

    impl AggregateUdf {
        /// Create a new aggregate UDF definition.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }

    /// A window UDF definition.
    #[derive(Debug, Clone)]
    pub struct WindowUdf {
        /// Function name.
        pub name: String,
        /// Number of arguments (-1 for variadic).
        pub num_args: i32,
    }

    impl WindowUdf {
        /// Create a new window UDF definition.
        pub fn new(name: impl Into<String>, num_args: i32) -> Self {
            Self {
                name: name.into(),
                num_args,
            }
        }
    }
}
591
592/// MongoDB-specific function support.
593pub mod mongodb_func {
594    use super::*;
595
596    /// A MongoDB `$function` expression for custom JavaScript functions.
597    #[derive(Debug, Clone, Serialize, Deserialize)]
598    pub struct MongoFunction {
599        /// JavaScript function body.
600        pub body: String,
601        /// Function arguments (field references or values).
602        pub args: Vec<String>,
603        /// Language (always "js" for now).
604        pub lang: String,
605    }
606
607    impl MongoFunction {
608        /// Create a new MongoDB function.
609        pub fn new(body: impl Into<String>, args: Vec<impl Into<String>>) -> Self {
610            Self {
611                body: body.into(),
612                args: args.into_iter().map(Into::into).collect(),
613                lang: "js".to_string(),
614            }
615        }
616
617        /// Convert to a BSON document for use in aggregation.
618        #[cfg(feature = "mongodb")]
619        pub fn to_bson(&self) -> bson::Document {
620            use bson::doc;
621            doc! {
622                "$function": {
623                    "body": &self.body,
624                    "args": &self.args,
625                    "lang": &self.lang,
626                }
627            }
628        }
629    }
630
631    /// A MongoDB `$accumulator` expression for custom aggregation.
632    #[derive(Debug, Clone, Serialize, Deserialize)]
633    pub struct MongoAccumulator {
634        /// Initialize the accumulator state.
635        pub init: String,
636        /// Initialize arguments.
637        pub init_args: Vec<String>,
638        /// Accumulate function.
639        pub accumulate: String,
640        /// Accumulate arguments.
641        pub accumulate_args: Vec<String>,
642        /// Merge function.
643        pub merge: String,
644        /// Finalize function (optional).
645        pub finalize: Option<String>,
646        /// Language.
647        pub lang: String,
648    }
649
650    impl MongoAccumulator {
651        /// Create a new MongoDB accumulator.
652        pub fn new(
653            init: impl Into<String>,
654            accumulate: impl Into<String>,
655            merge: impl Into<String>,
656        ) -> Self {
657            Self {
658                init: init.into(),
659                init_args: Vec::new(),
660                accumulate: accumulate.into(),
661                accumulate_args: Vec::new(),
662                merge: merge.into(),
663                finalize: None,
664                lang: "js".to_string(),
665            }
666        }
667
668        /// Set init arguments.
669        pub fn with_init_args(mut self, args: Vec<impl Into<String>>) -> Self {
670            self.init_args = args.into_iter().map(Into::into).collect();
671            self
672        }
673
674        /// Set accumulate arguments.
675        pub fn with_accumulate_args(mut self, args: Vec<impl Into<String>>) -> Self {
676            self.accumulate_args = args.into_iter().map(Into::into).collect();
677            self
678        }
679
680        /// Set finalize function.
681        pub fn with_finalize(mut self, finalize: impl Into<String>) -> Self {
682            self.finalize = Some(finalize.into());
683            self
684        }
685
686        /// Convert to a BSON document for use in aggregation.
687        #[cfg(feature = "mongodb")]
688        pub fn to_bson(&self) -> bson::Document {
689            use bson::doc;
690            let mut doc = doc! {
691                "$accumulator": {
692                    "init": &self.init,
693                    "accumulate": &self.accumulate,
694                    "accumulateArgs": &self.accumulate_args,
695                    "merge": &self.merge,
696                    "lang": &self.lang,
697                }
698            };
699
700            if !self.init_args.is_empty() {
701                doc.get_document_mut("$accumulator")
702                    .unwrap()
703                    .insert("initArgs", &self.init_args);
704            }
705
706            if let Some(ref finalize) = self.finalize {
707                doc.get_document_mut("$accumulator")
708                    .unwrap()
709                    .insert("finalize", finalize);
710            }
711
712            doc
713        }
714    }
715}
716
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_procedure_call_basic() {
        let call = ProcedureCall::new("get_user")
            .param("id", 42i32)
            .param("active", true);

        assert!(!call.is_function, "ProcedureCall::new builds a procedure");
        assert_eq!(call.name, "get_user");
        assert_eq!(call.parameters.len(), 2);
    }

    #[test]
    fn test_function_call() {
        let call = ProcedureCall::function("calculate_tax")
            .param("amount", 100.0f64)
            .param("rate", 0.08f64);

        assert!(call.is_function, "ProcedureCall::function builds a function");
        assert_eq!(call.name, "calculate_tax");
    }

    #[test]
    fn test_postgres_sql_generation() {
        let call = ProcedureCall::new("get_orders")
            .param("user_id", 42i32)
            .param("status", "pending".to_string());

        let (sql, bound) = call.to_postgres_sql();
        assert_eq!(bound.len(), 2);
        assert_eq!(sql, "CALL get_orders($1, $2)");
    }

    #[test]
    fn test_postgres_function_sql() {
        let call = ProcedureCall::function("calculate_total").param("order_id", 123i32);

        let (sql, bound) = call.to_postgres_sql();
        assert_eq!(bound.len(), 1);
        assert_eq!(sql, "SELECT calculate_total($1)");
    }

    #[test]
    fn test_mysql_sql_generation() {
        let call = ProcedureCall::new("get_orders")
            .with_db_type(DatabaseType::MySQL)
            .param("user_id", 42i32);

        let (sql, bound) = call.to_mysql_sql();
        assert_eq!(bound.len(), 1);
        assert_eq!(sql, "CALL get_orders(?)");
    }

    #[test]
    fn test_mssql_sql_generation() {
        let call = ProcedureCall::new("GetOrders")
            .schema("dbo")
            .with_db_type(DatabaseType::MSSQL)
            .param("UserId", 42i32);

        let (sql, bound) = call.to_mssql_sql();
        assert_eq!(bound.len(), 1);
        assert!(sql.contains("EXEC dbo.GetOrders"), "unexpected SQL: {}", sql);
    }

    #[test]
    fn test_mssql_with_output_params() {
        let call = ProcedureCall::new("CalculateTotals")
            .with_db_type(DatabaseType::MSSQL)
            .in_param("OrderId", 123i32)
            .out_param_typed("TotalAmount", "DECIMAL(18,2)")
            .out_param_typed("ItemCount", "INT");

        let (sql, _) = call.to_mssql_sql();
        for fragment in ["DECLARE", "OUTPUT", "SELECT"].iter() {
            assert!(sql.contains(*fragment), "missing {} in {}", fragment, sql);
        }
    }

    #[test]
    fn test_sqlite_function() {
        let call = ProcedureCall::function("custom_hash")
            .with_db_type(DatabaseType::SQLite)
            .param("input", "test".to_string());

        let (sql, bound) = call
            .to_sqlite_sql()
            .expect("SQLite supports function calls");
        assert_eq!(sql, "SELECT custom_hash(?)");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_sqlite_procedure_error() {
        let call = ProcedureCall::new("some_procedure")
            .with_db_type(DatabaseType::SQLite)
            .param("id", 42i32);

        assert!(
            call.to_sqlite_sql().is_err(),
            "SQLite must reject stored procedures"
        );
    }

    #[test]
    fn test_qualified_name() {
        let with_schema = ProcedureCall::new("get_user").schema("public");
        let without_schema = ProcedureCall::new("get_user");

        assert_eq!(with_schema.qualified_name(), "public.get_user");
        assert_eq!(without_schema.qualified_name(), "get_user");
    }

    #[test]
    fn test_parameter_modes() {
        let call = ProcedureCall::new("calculate")
            .in_param("input", 100i32)
            .out_param("result")
            .inout_param("running_total", 50i32);

        let modes: Vec<ParameterMode> = call.parameters.iter().map(|p| p.mode).collect();
        assert_eq!(
            modes,
            vec![ParameterMode::In, ParameterMode::Out, ParameterMode::InOut]
        );
        assert!(call.has_outputs());
    }

    #[test]
    fn test_procedure_result() {
        let result = ProcedureResult {
            outputs: std::iter::once(("total".to_string(), FilterValue::Int(100))).collect(),
            return_value: Some(FilterValue::Bool(true)),
            ..ProcedureResult::default()
        };

        assert!(result.get("total").is_some());
        assert!(result.get("nonexistent").is_none());
        assert!(result.return_value().is_some());
    }

    #[test]
    fn test_mongo_function() {
        let func = mongodb_func::MongoFunction::new(
            "function(x, y) { return x + y; }",
            vec!["$field1", "$field2"],
        );

        assert_eq!(func.args.len(), 2);
        assert_eq!(func.lang, "js");
    }

    #[test]
    fn test_mongo_accumulator() {
        let acc = mongodb_func::MongoAccumulator::new(
            "function() { return { sum: 0, count: 0 }; }",
            "function(state, value) { state.sum += value; state.count++; return state; }",
            "function(s1, s2) { return { sum: s1.sum + s2.sum, count: s1.count + s2.count }; }",
        )
        .with_finalize("function(state) { return state.sum / state.count; }")
        .with_accumulate_args(vec!["$value"]);

        assert_eq!(acc.accumulate_args.len(), 1);
        assert!(acc.finalize.is_some());
    }

    #[test]
    fn test_sqlite_udf_definitions() {
        let scalar = sqlite_udf::ScalarUdf::new("my_hash", 1).deterministic(true);
        let aggregate = sqlite_udf::AggregateUdf::new("my_sum", 1);
        let window = sqlite_udf::WindowUdf::new("my_rank", 0);

        assert!(scalar.deterministic);
        assert_eq!(aggregate.num_args, 1);
        assert_eq!(window.num_args, 0);
    }
}