Skip to main content

tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
16//! | ALL_HEADERS       | (required for TDS 7.2+)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30use crate::token::Collation;
31
/// System stored procedures that an RPC request can name by numeric ID.
///
/// Instead of sending the procedure name, a request may send the 2-byte
/// ProcID of one of these procedures. Each variant's discriminant is exactly
/// the value written on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor` (ProcID 1).
    Cursor = 1,
    /// `sp_cursoropen` (ProcID 2).
    CursorOpen = 2,
    /// `sp_cursorprepare` (ProcID 3).
    CursorPrepare = 3,
    /// `sp_cursorexecute` (ProcID 4).
    CursorExecute = 4,
    /// `sp_cursorprepexec` (ProcID 5).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` (ProcID 6).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` (ProcID 7).
    CursorFetch = 7,
    /// `sp_cursoroption` (ProcID 8).
    CursorOption = 8,
    /// `sp_cursorclose` (ProcID 9).
    CursorClose = 9,
    /// `sp_executesql` (ProcID 10): the usual route for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` (ProcID 11).
    Prepare = 11,
    /// `sp_execute` (ProcID 12).
    Execute = 12,
    /// `sp_prepexec` (ProcID 13): prepares and executes in a single call.
    PrepExec = 13,
    /// `sp_prepexecrpc` (ProcID 14).
    PrepExecRpc = 14,
    /// `sp_unprepare` (ProcID 15).
    Unprepare = 15,
}
71
/// RPC option flags: the 2-byte OptionFlags field of an RPC request.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (bit 0x0001).
    pub with_recompile: bool,
    /// No metadata in response (bit 0x0002).
    pub no_metadata: bool,
    /// Reuse metadata from previous call (bit 0x0004).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Wire bit for [`RpcOptionFlags::with_recompile`].
    const WITH_RECOMPILE_BIT: u16 = 0x0001;
    /// Wire bit for [`RpcOptionFlags::no_metadata`].
    const NO_METADATA_BIT: u16 = 0x0002;
    /// Wire bit for [`RpcOptionFlags::reuse_metadata`].
    const REUSE_METADATA_BIT: u16 = 0x0004;

    /// Create new empty flags.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the with-recompile flag.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Set the no-metadata flag.
    ///
    /// Builder counterpart of the public `no_metadata` field, mirroring
    /// [`RpcOptionFlags::with_recompile`].
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Set the reuse-metadata flag.
    ///
    /// Builder counterpart of the public `reuse_metadata` field, mirroring
    /// [`RpcOptionFlags::with_recompile`].
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Encode to wire format (2 bytes): the OR of the bits of every set flag.
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE_BIT;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA_BIT;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA_BIT;
        }
        flags
    }
}
111
/// Status byte sent with every RPC parameter.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// By-reference (OUTPUT) parameter; encoded as bit 0x01.
    pub by_ref: bool,
    /// Default-value bit; encoded as bit 0x02.
    pub default: bool,
    /// Always Encrypted parameter; encoded as bit 0x08.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a status with no bits set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags marked as an OUTPUT (by-reference) parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into the single status byte used on the wire.
    pub fn encode(&self) -> u8 {
        let bits = [
            (self.by_ref, 0x01u8),
            (self.default, 0x02u8),
            (self.encrypted, 0x08u8),
        ];
        bits.iter()
            .filter(|&&(is_set, _)| is_set)
            .fold(0u8, |packed, &(_, bit)| packed | bit)
    }
}
151
152/// TDS type information for RPC parameters.
153#[derive(Debug, Clone)]
154pub struct TypeInfo {
155    /// Type ID.
156    pub type_id: u8,
157    /// Maximum length for variable-length types.
158    pub max_length: Option<u16>,
159    /// Precision for numeric types.
160    pub precision: Option<u8>,
161    /// Scale for numeric types.
162    pub scale: Option<u8>,
163    /// Collation for string types.
164    pub collation: Option<[u8; 5]>,
165    /// TVP type name (e.g., "dbo.IntIdList") for Table-Valued Parameters.
166    pub tvp_type_name: Option<String>,
167}
168
169impl TypeInfo {
170    /// Create type info for INT.
171    pub fn int() -> Self {
172        Self {
173            type_id: 0x26, // INTNTYPE (variable-length int)
174            max_length: Some(4),
175            precision: None,
176            scale: None,
177            collation: None,
178            tvp_type_name: None,
179        }
180    }
181
182    /// Create type info for BIGINT.
183    pub fn bigint() -> Self {
184        Self {
185            type_id: 0x26, // INTNTYPE
186            max_length: Some(8),
187            precision: None,
188            scale: None,
189            collation: None,
190            tvp_type_name: None,
191        }
192    }
193
194    /// Create type info for SMALLINT.
195    pub fn smallint() -> Self {
196        Self {
197            type_id: 0x26, // INTNTYPE
198            max_length: Some(2),
199            precision: None,
200            scale: None,
201            collation: None,
202            tvp_type_name: None,
203        }
204    }
205
206    /// Create type info for TINYINT.
207    pub fn tinyint() -> Self {
208        Self {
209            type_id: 0x26, // INTNTYPE
210            max_length: Some(1),
211            precision: None,
212            scale: None,
213            collation: None,
214            tvp_type_name: None,
215        }
216    }
217
218    /// Create type info for BIT.
219    pub fn bit() -> Self {
220        Self {
221            type_id: 0x68, // BITNTYPE
222            max_length: Some(1),
223            precision: None,
224            scale: None,
225            collation: None,
226            tvp_type_name: None,
227        }
228    }
229
230    /// Create type info for FLOAT.
231    pub fn float() -> Self {
232        Self {
233            type_id: 0x6D, // FLTNTYPE
234            max_length: Some(8),
235            precision: None,
236            scale: None,
237            collation: None,
238            tvp_type_name: None,
239        }
240    }
241
242    /// Create type info for REAL.
243    pub fn real() -> Self {
244        Self {
245            type_id: 0x6D, // FLTNTYPE
246            max_length: Some(4),
247            precision: None,
248            scale: None,
249            collation: None,
250            tvp_type_name: None,
251        }
252    }
253
254    /// Create type info for NVARCHAR with max length.
255    pub fn nvarchar(max_len: u16) -> Self {
256        Self {
257            type_id: 0xE7,                 // NVARCHARTYPE
258            max_length: Some(max_len * 2), // UTF-16, so double the char count
259            precision: None,
260            scale: None,
261            // Default collation (Latin1_General_CI_AS equivalent)
262            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
263            tvp_type_name: None,
264        }
265    }
266
267    /// Create type info for NVARCHAR(MAX).
268    pub fn nvarchar_max() -> Self {
269        Self {
270            type_id: 0xE7,            // NVARCHARTYPE
271            max_length: Some(0xFFFF), // MAX indicator
272            precision: None,
273            scale: None,
274            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
275            tvp_type_name: None,
276        }
277    }
278
279    /// Default collation bytes: Latin1_General_CI_AS (LCID 0x0409, sort ID 0x34).
280    const DEFAULT_COLLATION: [u8; 5] = [0x09, 0x04, 0xD0, 0x00, 0x34];
281
282    /// Create type info for VARCHAR with max length (in bytes).
283    pub fn varchar(max_len: u16) -> Self {
284        Self::varchar_with_collation(max_len, Self::DEFAULT_COLLATION)
285    }
286
287    /// Create type info for VARCHAR with max length and explicit collation.
288    pub fn varchar_with_collation(max_len: u16, collation: [u8; 5]) -> Self {
289        Self {
290            type_id: 0xA7, // BIGVARCHARTYPE
291            max_length: Some(max_len),
292            precision: None,
293            scale: None,
294            collation: Some(collation),
295            tvp_type_name: None,
296        }
297    }
298
299    /// Create type info for VARCHAR(MAX).
300    pub fn varchar_max() -> Self {
301        Self::varchar_max_with_collation(Self::DEFAULT_COLLATION)
302    }
303
304    /// Create type info for VARCHAR(MAX) with explicit collation.
305    pub fn varchar_max_with_collation(collation: [u8; 5]) -> Self {
306        Self {
307            type_id: 0xA7,            // BIGVARCHARTYPE
308            max_length: Some(0xFFFF), // MAX indicator
309            precision: None,
310            scale: None,
311            collation: Some(collation),
312            tvp_type_name: None,
313        }
314    }
315
316    /// Create type info for VARBINARY with max length.
317    pub fn varbinary(max_len: u16) -> Self {
318        Self {
319            type_id: 0xA5, // BIGVARBINTYPE
320            max_length: Some(max_len),
321            precision: None,
322            scale: None,
323            collation: None,
324            tvp_type_name: None,
325        }
326    }
327
328    /// Create type info for VARBINARY(MAX).
329    pub fn varbinary_max() -> Self {
330        Self {
331            type_id: 0xA5,            // BIGVARBINTYPE
332            max_length: Some(0xFFFF), // MAX indicator β€” triggers PLP encoding
333            precision: None,
334            scale: None,
335            collation: None,
336            tvp_type_name: None,
337        }
338    }
339
340    /// Create type info for UNIQUEIDENTIFIER.
341    pub fn uniqueidentifier() -> Self {
342        Self {
343            type_id: 0x24, // GUIDTYPE
344            max_length: Some(16),
345            precision: None,
346            scale: None,
347            collation: None,
348            tvp_type_name: None,
349        }
350    }
351
352    /// Create type info for DATE.
353    pub fn date() -> Self {
354        Self {
355            type_id: 0x28, // DATETYPE
356            max_length: None,
357            precision: None,
358            scale: None,
359            collation: None,
360            tvp_type_name: None,
361        }
362    }
363
364    /// Create type info for TIME.
365    pub fn time(scale: u8) -> Self {
366        Self {
367            type_id: 0x29, // TIMETYPE
368            max_length: None,
369            precision: None,
370            scale: Some(scale),
371            collation: None,
372            tvp_type_name: None,
373        }
374    }
375
376    /// Create type info for DATETIME2.
377    pub fn datetime2(scale: u8) -> Self {
378        Self {
379            type_id: 0x2A, // DATETIME2TYPE
380            max_length: None,
381            precision: None,
382            scale: Some(scale),
383            collation: None,
384            tvp_type_name: None,
385        }
386    }
387
388    /// Create type info for DATETIMEOFFSET.
389    pub fn datetimeoffset(scale: u8) -> Self {
390        Self {
391            type_id: 0x2B, // DATETIMEOFFSETTYPE
392            max_length: None,
393            precision: None,
394            scale: Some(scale),
395            collation: None,
396            tvp_type_name: None,
397        }
398    }
399
400    /// Create type info for DECIMAL.
401    pub fn decimal(precision: u8, scale: u8) -> Self {
402        Self {
403            type_id: 0x6C,        // DECIMALNTYPE
404            max_length: Some(17), // Max decimal size
405            precision: Some(precision),
406            scale: Some(scale),
407            collation: None,
408            tvp_type_name: None,
409        }
410    }
411
412    /// Create type info for MONEY (8-byte scaled integer via MONEYN / 0x6E).
413    pub fn money() -> Self {
414        Self {
415            type_id: 0x6E, // MONEYNTYPE
416            max_length: Some(8),
417            precision: None,
418            scale: None,
419            collation: None,
420            tvp_type_name: None,
421        }
422    }
423
424    /// Create type info for SMALLMONEY (4-byte scaled integer via MONEYN / 0x6E).
425    pub fn smallmoney() -> Self {
426        Self {
427            type_id: 0x6E, // MONEYNTYPE
428            max_length: Some(4),
429            precision: None,
430            scale: None,
431            collation: None,
432            tvp_type_name: None,
433        }
434    }
435
436    /// Create type info for SMALLDATETIME (4-byte days+minutes via DATETIMEN / 0x6F).
437    pub fn smalldatetime() -> Self {
438        Self {
439            type_id: 0x6F, // DATETIMENTYPE
440            max_length: Some(4),
441            precision: None,
442            scale: None,
443            collation: None,
444            tvp_type_name: None,
445        }
446    }
447
448    /// Create type info for a Table-Valued Parameter.
449    ///
450    /// # Arguments
451    /// * `type_name` - The fully qualified table type name (e.g., "dbo.IntIdList")
452    pub fn tvp(type_name: impl Into<String>) -> Self {
453        Self {
454            type_id: 0xF3, // TVP type
455            max_length: None,
456            precision: None,
457            scale: None,
458            collation: None,
459            tvp_type_name: Some(type_name.into()),
460        }
461    }
462
463    /// Encode type info to buffer.
464    pub fn encode(&self, buf: &mut BytesMut) {
465        // TVP (0xF3) has type_id embedded in the value data itself
466        // (written by TvpEncoder::encode_metadata), so don't write it here
467        if self.type_id != 0xF3 {
468            buf.put_u8(self.type_id);
469        }
470
471        // Variable-length types need max length
472        match self.type_id {
473            0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
474                // INTNTYPE, BITNTYPE, FLTNTYPE, MONEYNTYPE, DATETIMENTYPE
475                if let Some(len) = self.max_length {
476                    buf.put_u8(len as u8);
477                }
478            }
479            0xE7 | 0xA7 | 0xA5 | 0xEF => {
480                // NVARCHARTYPE, BIGVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
481                if let Some(len) = self.max_length {
482                    buf.put_u16_le(len);
483                }
484                // Collation for string types
485                if let Some(collation) = self.collation {
486                    buf.put_slice(&collation);
487                }
488            }
489            0x24 => {
490                // GUIDTYPE
491                if let Some(len) = self.max_length {
492                    buf.put_u8(len as u8);
493                }
494            }
495            0x29..=0x2B => {
496                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
497                if let Some(scale) = self.scale {
498                    buf.put_u8(scale);
499                }
500            }
501            0x6C | 0x6A => {
502                // DECIMALNTYPE, NUMERICNTYPE
503                if let Some(len) = self.max_length {
504                    buf.put_u8(len as u8);
505                }
506                if let Some(precision) = self.precision {
507                    buf.put_u8(precision);
508                }
509                if let Some(scale) = self.scale {
510                    buf.put_u8(scale);
511                }
512            }
513            _ => {}
514        }
515    }
516}
517
518/// An RPC parameter.
519#[derive(Debug, Clone)]
520pub struct RpcParam {
521    /// Parameter name (can be empty for positional params).
522    pub name: String,
523    /// Status flags.
524    pub flags: ParamFlags,
525    /// Type information.
526    pub type_info: TypeInfo,
527    /// Parameter value (raw bytes).
528    pub value: Option<Bytes>,
529}
530
531impl RpcParam {
532    /// Create a new parameter with a value.
533    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
534        Self {
535            name: name.into(),
536            flags: ParamFlags::default(),
537            type_info,
538            value: Some(value),
539        }
540    }
541
542    /// Create a NULL parameter.
543    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
544        Self {
545            name: name.into(),
546            flags: ParamFlags::default(),
547            type_info,
548            value: None,
549        }
550    }
551
552    /// Create an INT parameter.
553    pub fn int(name: impl Into<String>, value: i32) -> Self {
554        let mut buf = BytesMut::with_capacity(4);
555        buf.put_i32_le(value);
556        Self::new(name, TypeInfo::int(), buf.freeze())
557    }
558
559    /// Create a BIGINT parameter.
560    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
561        let mut buf = BytesMut::with_capacity(8);
562        buf.put_i64_le(value);
563        Self::new(name, TypeInfo::bigint(), buf.freeze())
564    }
565
566    /// Create an NVARCHAR parameter.
567    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
568        let mut buf = BytesMut::new();
569        let mut code_units: usize = 0;
570        for code_unit in value.encode_utf16() {
571            buf.put_u16_le(code_unit);
572            code_units += 1;
573        }
574        // NVARCHAR length is measured in UTF-16 code units, not Rust chars β€”
575        // supplementary characters (emoji, CJK extension B) encode to a surrogate
576        // pair (2 code units) but count as 1 `char`. Using chars().count() here
577        // under-reports the buffer length and the server rejects the RPC with
578        // "Data type 0xE7 has an invalid data length or metadata length."
579        let type_info = if code_units > 4000 {
580            TypeInfo::nvarchar_max()
581        } else {
582            TypeInfo::nvarchar(code_units.max(1) as u16)
583        };
584        Self::new(name, type_info, buf.freeze())
585    }
586
587    /// Create a VARCHAR parameter.
588    ///
589    /// Encodes the string as single-byte characters using Windows-1252 encoding
590    /// (when the `encoding` feature is enabled) or Latin-1 fallback. Characters
591    /// not representable in the target encoding are replaced with `?`.
592    ///
593    /// Use this instead of [`nvarchar`](Self::nvarchar) when
594    /// `SendStringParametersAsUnicode=false` to allow SQL Server to use
595    /// index seeks on VARCHAR columns.
596    pub fn varchar(name: impl Into<String>, value: &str) -> Self {
597        let encoded = Self::encode_varchar_bytes(value);
598        let byte_len = encoded.len();
599        let type_info = if byte_len > 8000 {
600            TypeInfo::varchar_max()
601        } else {
602            TypeInfo::varchar(byte_len.max(1) as u16)
603        };
604        Self::new(name, type_info, Bytes::from(encoded))
605    }
606
607    /// Encode a string as single-byte VARCHAR data using the default
608    /// Windows-1252 encoding (or Latin-1 fallback without the `encoding` feature).
609    fn encode_varchar_bytes(value: &str) -> Vec<u8> {
610        crate::collation::encode_str_for_collation(value, None)
611    }
612
613    /// Create a VARCHAR parameter using the server's collation for encoding.
614    ///
615    /// Uses the collation's character encoding instead of the default Windows-1252.
616    /// For UTF-8 collations (SQL Server 2019+), the string bytes are used directly.
617    pub fn varchar_with_collation(
618        name: impl Into<String>,
619        value: &str,
620        collation: &Collation,
621    ) -> Self {
622        let collation_bytes = collation.to_bytes();
623        let encoded = Self::encode_varchar_bytes_for_collation(value, collation);
624        let byte_len = encoded.len();
625        let type_info = if byte_len > 8000 {
626            TypeInfo::varchar_max_with_collation(collation_bytes)
627        } else {
628            TypeInfo::varchar_with_collation(byte_len.max(1) as u16, collation_bytes)
629        };
630        Self::new(name, type_info, Bytes::from(encoded))
631    }
632
633    /// Encode a string using the collation's character encoding.
634    fn encode_varchar_bytes_for_collation(value: &str, collation: &Collation) -> Vec<u8> {
635        crate::collation::encode_str_for_collation(value, Some(collation))
636    }
637
638    /// Mark as output parameter.
639    #[must_use]
640    pub fn as_output(mut self) -> Self {
641        self.flags = self.flags.output();
642        self
643    }
644
645    /// Encode the parameter to buffer.
646    pub fn encode(&self, buf: &mut BytesMut) {
647        // Parameter name (B_VARCHAR - length-prefixed)
648        let name_len = self.name.encode_utf16().count() as u8;
649        buf.put_u8(name_len);
650        if name_len > 0 {
651            for code_unit in self.name.encode_utf16() {
652                buf.put_u16_le(code_unit);
653            }
654        }
655
656        // Status flags
657        buf.put_u8(self.flags.encode());
658
659        // Type info
660        self.type_info.encode(buf);
661
662        // Value
663        if let Some(ref value) = self.value {
664            // Length prefix based on type
665            match self.type_info.type_id {
666                0x26 => {
667                    // INTNTYPE
668                    buf.put_u8(value.len() as u8);
669                    buf.put_slice(value);
670                }
671                0x68 | 0x6D | 0x6E | 0x6F => {
672                    // BITNTYPE, FLTNTYPE, MONEYNTYPE, DATETIMENTYPE
673                    buf.put_u8(value.len() as u8);
674                    buf.put_slice(value);
675                }
676                0xE7 | 0xA7 | 0xA5 => {
677                    // NVARCHARTYPE, BIGVARCHARTYPE, BIGVARBINTYPE
678                    if self.type_info.max_length == Some(0xFFFF) {
679                        // MAX type - use PLP format
680                        // For simplicity, send as single chunk
681                        let total_len = value.len() as u64;
682                        buf.put_u64_le(total_len);
683                        buf.put_u32_le(value.len() as u32);
684                        buf.put_slice(value);
685                        buf.put_u32_le(0); // Terminator
686                    } else {
687                        buf.put_u16_le(value.len() as u16);
688                        buf.put_slice(value);
689                    }
690                }
691                0x24 => {
692                    // GUIDTYPE
693                    buf.put_u8(value.len() as u8);
694                    buf.put_slice(value);
695                }
696                0x28..=0x2B => {
697                    // DATE, TIME, DATETIME2, DATETIMEOFFSET
698                    buf.put_u8(value.len() as u8);
699                    buf.put_slice(value);
700                }
701                0x6C => {
702                    // DECIMALNTYPE
703                    buf.put_u8(value.len() as u8);
704                    buf.put_slice(value);
705                }
706                0xF3 => {
707                    // TVP (Table-Valued Parameter)
708                    // TVP values are self-delimiting: they contain complete metadata,
709                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
710                    buf.put_slice(value);
711                }
712                _ => {
713                    // Generic: assume length-prefixed
714                    buf.put_u8(value.len() as u8);
715                    buf.put_slice(value);
716                }
717            }
718        } else {
719            // NULL value
720            match self.type_info.type_id {
721                0xE7 | 0xA7 | 0xA5 => {
722                    // Variable-length types use 0xFFFF for NULL
723                    if self.type_info.max_length == Some(0xFFFF) {
724                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
725                    } else {
726                        buf.put_u16_le(0xFFFF);
727                    }
728                }
729                _ => {
730                    buf.put_u8(0); // Zero-length for NULL
731                }
732            }
733        }
734    }
735}
736
737/// RPC request builder.
738#[derive(Debug, Clone)]
739pub struct RpcRequest {
740    /// Procedure name (if using named procedure).
741    proc_name: Option<String>,
742    /// Procedure ID (if using well-known procedure).
743    proc_id: Option<ProcId>,
744    /// Option flags.
745    options: RpcOptionFlags,
746    /// Parameters.
747    params: Vec<RpcParam>,
748}
749
750impl RpcRequest {
751    /// Create a new RPC request for a named procedure.
752    pub fn named(proc_name: impl Into<String>) -> Self {
753        Self {
754            proc_name: Some(proc_name.into()),
755            proc_id: None,
756            options: RpcOptionFlags::default(),
757            params: Vec::new(),
758        }
759    }
760
761    /// Create a new RPC request for a well-known procedure.
762    pub fn by_id(proc_id: ProcId) -> Self {
763        Self {
764            proc_name: None,
765            proc_id: Some(proc_id),
766            options: RpcOptionFlags::default(),
767            params: Vec::new(),
768        }
769    }
770
771    /// Create an sp_executesql request.
772    ///
773    /// This is the primary method for parameterized queries.
774    ///
775    /// # Example
776    ///
777    /// ```
778    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
779    ///
780    /// let rpc = RpcRequest::execute_sql(
781    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
782    ///     vec![
783    ///         RpcParam::int("@p1", 42),
784    ///         RpcParam::nvarchar("@p2", "Alice"),
785    ///     ],
786    /// );
787    /// ```
788    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
789        let mut request = Self::by_id(ProcId::ExecuteSql);
790
791        // First parameter: the SQL statement (NVARCHAR(MAX))
792        request.params.push(RpcParam::nvarchar("", sql));
793
794        // Second parameter: parameter declarations
795        if !params.is_empty() {
796            let declarations = Self::build_param_declarations(&params);
797            request.params.push(RpcParam::nvarchar("", &declarations));
798        }
799
800        // Add the actual parameters
801        request.params.extend(params);
802
803        request
804    }
805
806    /// Build parameter declaration string for sp_executesql.
807    fn build_param_declarations(params: &[RpcParam]) -> String {
808        params
809            .iter()
810            .map(|p| {
811                let name = if p.name.starts_with('@') {
812                    p.name.clone()
813                } else if p.name.is_empty() {
814                    // Generate positional name
815                    format!(
816                        "@p{}",
817                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
818                    )
819                } else {
820                    format!("@{}", p.name)
821                };
822
823                let type_name: String = match p.type_info.type_id {
824                    0x26 => match p.type_info.max_length {
825                        Some(1) => "tinyint".to_string(),
826                        Some(2) => "smallint".to_string(),
827                        Some(4) => "int".to_string(),
828                        Some(8) => "bigint".to_string(),
829                        _ => "int".to_string(),
830                    },
831                    0x68 => "bit".to_string(),
832                    0x6D => match p.type_info.max_length {
833                        Some(4) => "real".to_string(),
834                        _ => "float".to_string(),
835                    },
836                    0xE7 => {
837                        if p.type_info.max_length == Some(0xFFFF) {
838                            "nvarchar(max)".to_string()
839                        } else {
840                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
841                            format!("nvarchar({len})")
842                        }
843                    }
844                    0xA7 => {
845                        if p.type_info.max_length == Some(0xFFFF) {
846                            "varchar(max)".to_string()
847                        } else {
848                            let len = p.type_info.max_length.unwrap_or(8000);
849                            format!("varchar({len})")
850                        }
851                    }
852                    0xA5 => {
853                        if p.type_info.max_length == Some(0xFFFF) {
854                            "varbinary(max)".to_string()
855                        } else {
856                            let len = p.type_info.max_length.unwrap_or(8000);
857                            format!("varbinary({len})")
858                        }
859                    }
860                    0x24 => "uniqueidentifier".to_string(),
861                    0x28 => "date".to_string(),
862                    0x29 => {
863                        let scale = p.type_info.scale.unwrap_or(7);
864                        format!("time({scale})")
865                    }
866                    0x2A => {
867                        let scale = p.type_info.scale.unwrap_or(7);
868                        format!("datetime2({scale})")
869                    }
870                    0x2B => {
871                        let scale = p.type_info.scale.unwrap_or(7);
872                        format!("datetimeoffset({scale})")
873                    }
874                    0x6C => {
875                        let precision = p.type_info.precision.unwrap_or(18);
876                        let scale = p.type_info.scale.unwrap_or(0);
877                        format!("decimal({precision}, {scale})")
878                    }
879                    0x6E => match p.type_info.max_length {
880                        Some(4) => "smallmoney".to_string(),
881                        _ => "money".to_string(),
882                    },
883                    0x6F => match p.type_info.max_length {
884                        Some(4) => "smalldatetime".to_string(),
885                        _ => "datetime".to_string(),
886                    },
887                    0xF3 => {
888                        // TVP - Table-Valued Parameter
889                        // Must be declared with the table type name and READONLY
890                        if let Some(ref tvp_name) = p.type_info.tvp_type_name {
891                            format!("{tvp_name} READONLY")
892                        } else {
893                            // Fallback if type name is missing (shouldn't happen)
894                            "sql_variant".to_string()
895                        }
896                    }
897                    _ => "sql_variant".to_string(),
898                };
899
900                format!("{name} {type_name}")
901            })
902            .collect::<Vec<_>>()
903            .join(", ")
904    }
905
906    /// Create an sp_prepare request.
907    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
908        let mut request = Self::by_id(ProcId::Prepare);
909
910        // OUT: handle (INT)
911        request
912            .params
913            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
914
915        // Param declarations
916        let declarations = Self::build_param_declarations(params);
917        request
918            .params
919            .push(RpcParam::nvarchar("@params", &declarations));
920
921        // SQL statement
922        request.params.push(RpcParam::nvarchar("@stmt", sql));
923
924        // Options (1 = WITH RECOMPILE)
925        request.params.push(RpcParam::int("@options", 1));
926
927        request
928    }
929
930    /// Create an sp_execute request.
931    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
932        let mut request = Self::by_id(ProcId::Execute);
933
934        // Handle from sp_prepare
935        request.params.push(RpcParam::int("@handle", handle));
936
937        // Add parameters
938        request.params.extend(params);
939
940        request
941    }
942
943    /// Create an sp_unprepare request.
944    pub fn unprepare(handle: i32) -> Self {
945        let mut request = Self::by_id(ProcId::Unprepare);
946        request.params.push(RpcParam::int("@handle", handle));
947        request
948    }
949
950    /// Set option flags.
951    #[must_use]
952    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
953        self.options = options;
954        self
955    }
956
957    /// Add a parameter.
958    #[must_use]
959    pub fn param(mut self, param: RpcParam) -> Self {
960        self.params.push(param);
961        self
962    }
963
964    /// Encode the RPC request to bytes (auto-commit mode).
965    ///
966    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
967    #[must_use]
968    pub fn encode(&self) -> Bytes {
969        self.encode_with_transaction(0)
970    }
971
972    /// Encode the RPC request with a transaction descriptor.
973    ///
974    /// Per MS-TDS spec, when executing within an explicit transaction:
975    /// - The `transaction_descriptor` MUST be the value returned by the server
976    ///   in the BeginTransaction EnvChange token.
977    /// - For auto-commit mode (no explicit transaction), use 0.
978    ///
979    /// # Arguments
980    ///
981    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
982    ///   or 0 for auto-commit mode.
983    #[must_use]
984    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
985        let mut buf = BytesMut::with_capacity(256);
986
987        // ALL_HEADERS - TDS 7.2+ requires this section
988        // Total length placeholder (will be filled in)
989        let all_headers_start = buf.len();
990        buf.put_u32_le(0); // Total length placeholder
991
992        // Transaction descriptor header (required for RPC)
993        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
994        buf.put_u32_le(18); // Header length
995        buf.put_u16_le(0x0002); // Header type: transaction descriptor
996        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
997        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
998
999        // Fill in ALL_HEADERS total length
1000        let all_headers_len = buf.len() - all_headers_start;
1001        let len_bytes = (all_headers_len as u32).to_le_bytes();
1002        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
1003
1004        // Procedure name or ID
1005        if let Some(proc_id) = self.proc_id {
1006            // Use PROCID format
1007            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
1008            buf.put_u16_le(proc_id as u16);
1009        } else if let Some(ref proc_name) = self.proc_name {
1010            // Use procedure name
1011            let name_len = proc_name.encode_utf16().count() as u16;
1012            buf.put_u16_le(name_len);
1013            write_utf16_string(&mut buf, proc_name);
1014        }
1015
1016        // Option flags
1017        buf.put_u16_le(self.options.encode());
1018
1019        // Parameters
1020        for param in &self.params {
1021            param.encode(&mut buf);
1022        }
1023
1024        buf.freeze()
1025    }
1026}
1027
#[cfg(test)]
// Tests unwrap deliberately: an unexpected `None` here is a test failure, not a recoverable error.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // UTF-16 encoded "Alice" = 10 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_nvarchar_param_surrogate_pair_length() {
        // Literals use \u{...} escapes so the expected lengths cannot be broken by
        // a mis-encoded source file.
        // U+1F30D (earth globe) is a supplementary character: 1 Rust char but
        // 2 UTF-16 code units (4 bytes). TypeInfo.max_length is stored doubled
        // internally, so the metadata must declare 2 code units for the buffer to match.
        let param = RpcParam::nvarchar("@p", "\u{1F30D}");
        assert_eq!(param.value.as_ref().unwrap().len(), 4);
        // TypeInfo::nvarchar(n) stores max_length as n*2 bytes.
        assert_eq!(param.type_info.max_length, Some(4));

        // "Hello " + U+4E16 U+754C + " " = 9 BMP code units, plus U+1F30D = 2 surrogate
        // units -> 11 code units, 22 bytes
        let param = RpcParam::nvarchar("@p", "Hello \u{4E16}\u{754C} \u{1F30D}");
        assert_eq!(param.value.as_ref().unwrap().len(), 22);
        assert_eq!(param.type_info.max_length, Some(22));
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // SQL statement + param declarations + actual params
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_rpc_encode_all_headers_and_proc_id_layout() {
        let encoded = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(encoded.len() >= 26, "encoded too short: {}", encoded.len());
        // ALL_HEADERS TotalLength = 4 (the length field itself) + 18 (transaction header)
        assert_eq!(&encoded[0..4], &22u32.to_le_bytes());
        // Transaction descriptor header: length, type, descriptor, outstanding request count
        assert_eq!(&encoded[4..8], &18u32.to_le_bytes());
        assert_eq!(&encoded[8..10], &0x0002u16.to_le_bytes());
        assert_eq!(&encoded[10..18], &0u64.to_le_bytes());
        assert_eq!(&encoded[18..22], &1u32.to_le_bytes());
        // PROCID form: 0xFFFF marker followed by the procedure id
        assert_eq!(&encoded[22..24], &[0xFF, 0xFF]);
        assert_eq!(&encoded[24..26], &(ProcId::ExecuteSql as u16).to_le_bytes());
    }

    #[test]
    fn test_encode_with_transaction_places_descriptor() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let descriptor = 0x0102_0304_0506_0708u64;
        let encoded = rpc.encode_with_transaction(descriptor);
        assert_eq!(&encoded[10..18], &descriptor.to_le_bytes());
        // encode() is the auto-commit (descriptor 0) form.
        assert_eq!(rpc.encode(), rpc.encode_with_transaction(0));
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // handle (output), params, stmt, options
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params[0].flags.by_ref); // handle is OUTPUT
        // Parameters are positional for sp_prepare, so their order matters.
        assert_eq!(rpc.params[0].name, "@handle");
        assert_eq!(rpc.params[1].name, "@params");
        assert_eq!(rpc.params[2].name, "@stmt");
        assert_eq!(rpc.params[3].name, "@options");
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        assert_eq!(rpc.params.len(), 2); // handle + param
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        assert_eq!(rpc.params.len(), 1); // just the handle
    }

    #[test]
    fn test_varchar_param() {
        let param = RpcParam::varchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xA7);
        // Single-byte encoded "Alice" = 5 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 5);
        assert_eq!(&param.value.as_ref().unwrap()[..], b"Alice");
    }

    #[test]
    fn test_varchar_param_max() {
        // String > 8000 bytes should use VARCHAR(MAX)
        let long_str = "a".repeat(9000);
        let param = RpcParam::varchar("@big", &long_str);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert_eq!(param.value.as_ref().unwrap().len(), 9000);
    }

    #[test]
    fn test_varchar_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::varchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name varchar(5)"));
    }

    #[test]
    fn test_varchar_type_info_has_collation() {
        let ti = TypeInfo::varchar(100);
        assert_eq!(ti.type_id, 0xA7);
        assert_eq!(ti.max_length, Some(100));
        assert!(ti.collation.is_some());
    }

    #[test]
    fn test_varchar_encode_round_trip() {
        // Verify the encoded param can be serialized without panics
        let param = RpcParam::varchar("@val", "test value");
        let mut buf = bytes::BytesMut::new();
        param.encode(&mut buf);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_collation_round_trip() {
        let collation = Collation {
            lcid: 0x00D0_0409,
            sort_id: 0x34,
        };
        let bytes = collation.to_bytes();
        assert_eq!(bytes, [0x09, 0x04, 0xD0, 0x00, 0x34]);

        let restored = Collation::from_bytes(&bytes);
        assert_eq!(restored.lcid, collation.lcid);
        assert_eq!(restored.sort_id, collation.sort_id);
    }

    #[test]
    fn test_varchar_with_collation_uses_custom_collation_bytes() {
        // Chinese_PRC_CI_AS collation (LCID 0x0804)
        let collation = Collation {
            lcid: 0x0804,
            sort_id: 0,
        };
        let param = RpcParam::varchar_with_collation("@val", "test", &collation);
        assert_eq!(param.type_info.type_id, 0xA7);
        // Collation bytes should match the custom collation, not default Latin1
        assert_eq!(param.type_info.collation, Some(collation.to_bytes()));
    }

    #[test]
    fn test_money_type_info() {
        let ti = TypeInfo::money();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(8));
    }

    #[test]
    fn test_smallmoney_type_info() {
        let ti = TypeInfo::smallmoney();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_smalldatetime_type_info() {
        let ti = TypeInfo::smalldatetime();
        assert_eq!(ti.type_id, 0x6F);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_money_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::new("@m", TypeInfo::money(), Bytes::from_static(&[0u8; 8])),
            RpcParam::new("@sm", TypeInfo::smallmoney(), Bytes::from_static(&[0u8; 4])),
            RpcParam::new(
                "@sdt",
                TypeInfo::smalldatetime(),
                Bytes::from_static(&[0u8; 4]),
            ),
        ]);
        assert!(decls.contains("@m money"), "got: {decls}");
        assert!(decls.contains("@sm smallmoney"), "got: {decls}");
        assert!(decls.contains("@sdt smalldatetime"), "got: {decls}");
    }

    #[test]
    fn test_money_typeinfo_encodes_max_length_byte() {
        let mut buf = bytes::BytesMut::new();
        TypeInfo::money().encode(&mut buf);
        // type_id 0x6E + max_length byte 0x08
        assert_eq!(&buf[..], &[0x6E, 0x08]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smallmoney().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x04]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smalldatetime().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6F, 0x04]);
    }

    #[test]
    fn test_varchar_with_collation_default_vs_custom_differ() {
        let default_param = RpcParam::varchar("@val", "test");
        let custom_collation = Collation {
            lcid: 0x0419, // Russian
            sort_id: 0,
        };
        let custom_param = RpcParam::varchar_with_collation("@val", "test", &custom_collation);
        // The collation bytes should differ
        assert_ne!(
            default_param.type_info.collation,
            custom_param.type_info.collation
        );
    }
}