Skip to main content

mssql_client/
procedure.rs

//! Stored procedure builder for constructing and executing RPC calls.
//!
//! Provides a builder pattern for calling stored procedures with full
//! control over named parameters (both input and output).
//!
//! # Example
//!
//! ```rust,ignore
//! // Simple positional call (input parameters only)
//! let result = client.call_procedure("dbo.GetUser", &[&1i32]).await?;
//!
//! // Builder with named input/output parameters
//! let result = client.procedure("dbo.CalculateSum")?
//!     .input("@a", &10i32)
//!     .input("@b", &20i32)
//!     .output_int("@result")
//!     .execute().await?;
//!
//! let sum = result.get_output("@result").unwrap();
//! ```
21
22use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
23
24use crate::client::Client;
25use crate::error::Result;
26use crate::state::ConnectionState;
27use crate::stream::ProcedureResult;
28
29/// Builder for constructing stored procedure calls with named parameters.
30///
31/// Created via [`Client::procedure()`]. Supports both input and output
32/// parameters with type-safe output declarations.
33///
34/// # Example
35///
36/// ```rust,ignore
37/// let result = client.procedure("dbo.CalculateSum")?
38///     .input("@a", &10i32)
39///     .input("@b", &20i32)
40///     .output_int("@result")
41///     .execute().await?;
42///
43/// // Access the output parameter
44/// let output = result.get_output("@result").expect("output param present");
45/// assert_eq!(output.value, SqlValue::Int(30));
46/// ```
47pub struct ProcedureBuilder<'a, S: ConnectionState> {
48    client: &'a mut Client<S>,
49    proc_name: String,
50    params: Vec<RpcParam>,
51}
52
53impl<'a, S: ConnectionState> ProcedureBuilder<'a, S> {
54    /// Create a new procedure builder.
55    ///
56    /// The procedure name must already be validated by the caller.
57    pub(crate) fn new(client: &'a mut Client<S>, proc_name: &str) -> Self {
58        Self {
59            client,
60            proc_name: proc_name.to_string(),
61            params: Vec::new(),
62        }
63    }
64
65    /// Add a named input parameter.
66    ///
67    /// The name should include the `@` prefix (e.g., `"@id"`).
68    /// The value is converted using the same logic as query parameters.
69    ///
70    /// # Example
71    ///
72    /// ```rust,ignore
73    /// client.procedure("dbo.UpdateUser")?
74    ///     .input("@id", &42i32)
75    ///     .input("@name", &"Alice")
76    ///     .execute().await?;
77    /// ```
78    pub fn input(&mut self, name: &str, value: &(dyn crate::ToSql + Sync)) -> &mut Self {
79        // Use the shared conversion logic from params.rs.
80        // If conversion fails, the error is deferred to execute().
81        match Client::<S>::convert_single_param(
82            name,
83            value,
84            self.client.send_unicode(),
85            self.client.server_collation(),
86        ) {
87            Ok(param) => self.params.push(param),
88            Err(e) => {
89                tracing::warn!(name = name, error = %e, "failed to convert input parameter");
90                // Store a null placeholder so parameter ordering is preserved.
91                // The error will surface if the server rejects the call.
92                self.params
93                    .push(RpcParam::null(name, RpcTypeInfo::nvarchar(1)));
94            }
95        }
96        self
97    }
98
99    /// Add a named output parameter of type INT.
100    ///
101    /// # Example
102    ///
103    /// ```rust,ignore
104    /// client.procedure("dbo.GetCount")?
105    ///     .output_int("@count")
106    ///     .execute().await?;
107    /// ```
108    pub fn output_int(&mut self, name: &str) -> &mut Self {
109        self.params
110            .push(RpcParam::null(name, RpcTypeInfo::int()).as_output());
111        self
112    }
113
114    /// Add a named output parameter of type BIGINT.
115    pub fn output_bigint(&mut self, name: &str) -> &mut Self {
116        self.params
117            .push(RpcParam::null(name, RpcTypeInfo::bigint()).as_output());
118        self
119    }
120
121    /// Add a named output parameter of type NVARCHAR with the given max length.
122    ///
123    /// Use `max_len = 0` for NVARCHAR(MAX).
124    pub fn output_nvarchar(&mut self, name: &str, max_len: u16) -> &mut Self {
125        let type_info = if max_len == 0 {
126            RpcTypeInfo::nvarchar_max()
127        } else {
128            RpcTypeInfo::nvarchar(max_len)
129        };
130        self.params
131            .push(RpcParam::null(name, type_info).as_output());
132        self
133    }
134
135    /// Add a named output parameter of type BIT.
136    pub fn output_bit(&mut self, name: &str) -> &mut Self {
137        self.params
138            .push(RpcParam::null(name, RpcTypeInfo::bit()).as_output());
139        self
140    }
141
142    /// Add a named output parameter of type FLOAT (64-bit).
143    pub fn output_float(&mut self, name: &str) -> &mut Self {
144        self.params
145            .push(RpcParam::null(name, RpcTypeInfo::float()).as_output());
146        self
147    }
148
149    /// Add a named output parameter of type DECIMAL with given precision and scale.
150    pub fn output_decimal(&mut self, name: &str, precision: u8, scale: u8) -> &mut Self {
151        self.params
152            .push(RpcParam::null(name, RpcTypeInfo::decimal(precision, scale)).as_output());
153        self
154    }
155
156    /// Add a named output parameter with a raw `TypeInfo` for uncommon types.
157    ///
158    /// This is an escape hatch for types not covered by the typed output methods.
159    ///
160    /// # Example
161    ///
162    /// ```rust,ignore
163    /// use tds_protocol::rpc::TypeInfo;
164    ///
165    /// client.procedure("dbo.GetGuid")?
166    ///     .output_raw("@id", TypeInfo::uniqueidentifier())
167    ///     .execute().await?;
168    /// ```
169    pub fn output_raw(&mut self, name: &str, type_info: RpcTypeInfo) -> &mut Self {
170        self.params
171            .push(RpcParam::null(name, type_info).as_output());
172        self
173    }
174
175    /// Execute the stored procedure and return the result.
176    ///
177    /// Sends an RPC request to SQL Server with the accumulated parameters
178    /// and reads the complete response including result sets, output
179    /// parameters, and the procedure return value.
180    pub async fn execute(&mut self) -> Result<ProcedureResult> {
181        let mut rpc = RpcRequest::named(&self.proc_name);
182        for param in self.params.drain(..) {
183            rpc = rpc.param(param);
184        }
185
186        self.client.send_rpc(&rpc).await?;
187        self.client.read_procedure_result().await
188    }
189}
190
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use tds_protocol::rpc::{RpcParam, TypeInfo as RpcTypeInfo};

    /// An INT output parameter is sent by reference, carries no value, and
    /// keeps its name.
    #[test]
    fn test_output_int_has_by_ref_flag() {
        let out = RpcParam::null("@result", RpcTypeInfo::int()).as_output();

        assert!(out.flags.by_ref);
        assert!(out.value.is_none());
        assert_eq!(out.name, "@result");
    }

    /// NVARCHAR(MAX) output uses the 0xFFFF max-length marker.
    #[test]
    fn test_output_nvarchar_max() {
        let out = RpcParam::null("@msg", RpcTypeInfo::nvarchar_max()).as_output();

        assert!(out.flags.by_ref);
        assert_eq!(out.type_info.max_length, Some(0xFFFF));
    }

    /// DECIMAL output preserves the requested precision and scale.
    #[test]
    fn test_output_decimal_precision_scale() {
        let out = RpcParam::null("@total", RpcTypeInfo::decimal(18, 2)).as_output();

        assert!(out.flags.by_ref);
        assert_eq!(out.type_info.precision, Some(18));
        assert_eq!(out.type_info.scale, Some(2));
    }
}
224}