mssql-client 0.19.3

High-level async SQL Server client with type-state connection management
Documentation
//! Stored procedure builder for constructing and executing RPC calls.
//!
//! Provides a builder pattern for calling stored procedures with full
//! control over named parameters (both input and output).
//!
//! # Example
//!
//! ```rust,no_run
//! # async fn ex(client: &mut mssql_client::Client<mssql_client::Ready>) -> Result<(), mssql_client::Error> {
//! // Simple positional call (input parameters only)
//! let result = client.call_procedure("dbo.GetUser", &[&1i32]).await?;
//!
//! // Builder with named input/output parameters
//! let result = client.procedure("dbo.CalculateSum")?
//!     .input("@a", &10i32)
//!     .input("@b", &20i32)
//!     .output_int("@result")
//!     .execute().await?;
//!
//! let sum = result.get_output("@result").unwrap();
//! # let _ = sum;
//! # Ok(())
//! # }
//! ```
//!
//! ## How it works
//!
//! Both entry points issue a TDS RPC request (not a SQL batch): parameters are
//! sent as typed RPC parameters, never interpolated into SQL. The response
//! stream carries any result sets, `OUTPUT` parameter values, and the
//! procedure's `RETURN` value, surfaced on [`ProcedureResult`] as `result_sets`,
//! `output_params`, and `return_value`. The procedure name is validated as a SQL
//! identifier (per dotted part) before use.
//!
//! [`Client::call_procedure`] takes positional input parameters (named `@p1`,
//! `@p2`, ...); [`Client::procedure`] returns a [`ProcedureBuilder`] for named
//! input and `OUTPUT` parameters. Both work in the `Ready` and `InTransaction`
//! states, so procedures compose with [`Client::begin_transaction`].
//!
//! ## Output parameters
//!
//! Declare each `OUTPUT` parameter with the matching typed setter:
//! [`ProcedureBuilder::output_int`], [`ProcedureBuilder::output_bigint`],
//! [`ProcedureBuilder::output_bit`], [`ProcedureBuilder::output_float`],
//! [`ProcedureBuilder::output_nvarchar`] (takes a length; `0` means
//! `NVARCHAR(MAX)`), and [`ProcedureBuilder::output_decimal`] (takes precision
//! and scale). After [`ProcedureBuilder::execute`] returns, read the values
//! with [`ProcedureResult::get_output`] and the procedure's return code with
//! [`ProcedureResult::get_return_value`].

use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};

use crate::client::Client;
use crate::error::{Error, Result};
use crate::state::ConnectionState;
use crate::stream::ProcedureResult;

/// Builder for constructing stored procedure calls with named parameters.
///
/// Created via [`Client::procedure()`]. Supports both input and output
/// parameters with type-safe output declarations. Nothing is sent to the
/// server until [`execute`](Self::execute) is awaited, so a builder that is
/// dropped without executing has no effect.
///
/// # Example
///
/// ```rust,no_run
/// # use mssql_client::SqlValue;
/// # async fn ex(client: &mut mssql_client::Client<mssql_client::Ready>) -> Result<(), mssql_client::Error> {
/// let result = client.procedure("dbo.CalculateSum")?
///     .input("@a", &10i32)
///     .input("@b", &20i32)
///     .output_int("@result")
///     .execute().await?;
///
/// // Access the output parameter
/// let output = result.get_output("@result").expect("output param present");
/// assert_eq!(output.value, SqlValue::Int(30));
/// # Ok(())
/// # }
/// ```
#[must_use = "a ProcedureBuilder does nothing until `.execute().await` is called"]
pub struct ProcedureBuilder<'a, S: ConnectionState> {
    /// Connection the RPC request is sent on, borrowed mutably for as long
    /// as the builder lives.
    client: &'a mut Client<S>,
    /// Name of the procedure to call; validated by the caller before the
    /// builder is constructed.
    proc_name: String,
    /// Input and output parameters, in the order they were added.
    params: Vec<RpcParam>,
    /// First parameter-conversion failure, surfaced by `execute()`. Kept as
    /// a field so `input()` stays chainable.
    deferred_error: Option<Error>,
}

impl<'a, S: ConnectionState> ProcedureBuilder<'a, S> {
    /// Create a new procedure builder.
    ///
    /// The procedure name must already be validated by the caller.
    pub(crate) fn new(client: &'a mut Client<S>, proc_name: &str) -> Self {
        Self {
            client,
            proc_name: proc_name.to_string(),
            params: Vec::new(),
            deferred_error: None,
        }
    }

    /// Add a named input parameter.
    ///
    /// The name should include the `@` prefix (e.g., `"@id"`).
    /// The value is converted using the same logic as query parameters.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # async fn ex(client: &mut mssql_client::Client<mssql_client::Ready>) -> Result<(), mssql_client::Error> {
    /// client.procedure("dbo.UpdateUser")?
    ///     .input("@id", &42i32)
    ///     .input("@name", &"Alice")
    ///     .execute().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn input(&mut self, name: &str, value: &(dyn crate::ToSql + Sync)) -> &mut Self {
        // Use the shared conversion logic from params.rs.
        // If conversion fails, the error is deferred to execute() — the call
        // must NOT proceed with a substituted value: the server cannot tell
        // a placeholder apart from an intentional NULL, so the procedure
        // would silently run with wrong data (issue #157).
        match Client::<S>::convert_single_param(
            name,
            value,
            self.client.send_unicode(),
            self.client.server_collation(),
        ) {
            Ok(param) => self.params.push(param),
            Err(e) => {
                tracing::warn!(name = name, error = %e, "failed to convert input parameter");
                if self.deferred_error.is_none() {
                    self.deferred_error = Some(e);
                }
            }
        }
        self
    }

    /// Add a named output parameter of type INT.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # async fn ex(client: &mut mssql_client::Client<mssql_client::Ready>) -> Result<(), mssql_client::Error> {
    /// client.procedure("dbo.GetCount")?
    ///     .output_int("@count")
    ///     .execute().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn output_int(&mut self, name: &str) -> &mut Self {
        self.params
            .push(RpcParam::null(name, RpcTypeInfo::int()).as_output());
        self
    }

    /// Add a named output parameter of type BIGINT.
    pub fn output_bigint(&mut self, name: &str) -> &mut Self {
        self.params
            .push(RpcParam::null(name, RpcTypeInfo::bigint()).as_output());
        self
    }

    /// Add a named output parameter of type NVARCHAR with the given max length.
    ///
    /// Use `max_len = 0` for NVARCHAR(MAX).
    pub fn output_nvarchar(&mut self, name: &str, max_len: u16) -> &mut Self {
        let type_info = if max_len == 0 {
            RpcTypeInfo::nvarchar_max()
        } else {
            RpcTypeInfo::nvarchar(max_len)
        };
        self.params
            .push(RpcParam::null(name, type_info).as_output());
        self
    }

    /// Add a named output parameter of type BIT.
    pub fn output_bit(&mut self, name: &str) -> &mut Self {
        self.params
            .push(RpcParam::null(name, RpcTypeInfo::bit()).as_output());
        self
    }

    /// Add a named output parameter of type FLOAT (64-bit).
    pub fn output_float(&mut self, name: &str) -> &mut Self {
        self.params
            .push(RpcParam::null(name, RpcTypeInfo::float()).as_output());
        self
    }

    /// Add a named output parameter of type DECIMAL with given precision and scale.
    pub fn output_decimal(&mut self, name: &str, precision: u8, scale: u8) -> &mut Self {
        self.params
            .push(RpcParam::null(name, RpcTypeInfo::decimal(precision, scale)).as_output());
        self
    }

    /// Execute the stored procedure and return the result.
    ///
    /// Sends an RPC request to SQL Server with the accumulated parameters
    /// and reads the complete response including result sets, output
    /// parameters, and the procedure return value.
    ///
    /// # Errors
    ///
    /// Returns the first parameter-conversion error from [`input`](Self::input)
    /// before anything is sent — the procedure is never called with a
    /// substituted value.
    pub async fn execute(&mut self) -> Result<ProcedureResult> {
        if let Some(e) = self.deferred_error.take() {
            return Err(e);
        }
        let mut rpc = RpcRequest::named(&self.proc_name);
        for param in self.params.drain(..) {
            rpc = rpc.param(param);
        }

        #[cfg(feature = "otel")]
        let instrumentation = self.client.instrumentation().clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.procedure_span(&self.proc_name);
        #[cfg(feature = "otel")]
        let timer = crate::instrumentation::OperationTimer::start("EXECUTE");

        let deadline = self.client.command_deadline();
        let canceller = self.client.connection_cancel_handle();
        let result = crate::client::run_with_deadline(
            async {
                self.client.send_rpc(&rpc).await?;
                self.client.read_procedure_result().await
            },
            deadline,
            canceller,
        )
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(r) => crate::instrumentation::InstrumentationContext::record_success(
                &mut span,
                Some(r.rows_affected),
            ),
            Err(e) => crate::instrumentation::InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        timer.finish(instrumentation.metrics(), result.is_ok());
        #[cfg(feature = "otel")]
        drop(span);

        result
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use tds_protocol::rpc::{RpcParam, TypeInfo as RpcTypeInfo};

    /// Builds an OUTPUT parameter the same way the builder's `output_*`
    /// setters do: a NULL value of the given type, marked by-reference.
    fn output_param(name: &str, type_info: RpcTypeInfo) -> RpcParam {
        RpcParam::null(name, type_info).as_output()
    }

    #[test]
    fn test_output_int_has_by_ref_flag() {
        let p = output_param("@result", RpcTypeInfo::int());

        assert!(p.flags.by_ref);
        assert!(p.value.is_none());
        assert_eq!(p.name, "@result");
    }

    #[test]
    fn test_output_nvarchar_max() {
        let p = output_param("@msg", RpcTypeInfo::nvarchar_max());

        assert!(p.flags.by_ref);
        assert_eq!(p.type_info.max_length, Some(0xFFFF));
    }

    #[test]
    fn test_output_decimal_precision_scale() {
        let p = output_param("@total", RpcTypeInfo::decimal(18, 2));

        assert!(p.flags.by_ref);
        assert_eq!(p.type_info.precision, Some(18));
        assert_eq!(p.type_info.scale, Some(2));
    }
}