Skip to main content

tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
16//! | ALL_HEADERS       | (TDS 7.2+, optional)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30use crate::token::Collation;
31
/// Stored procedures that SQL Server can be asked to run by numeric ID.
///
/// Sending one of these IDs in place of a procedure name lets the server
/// dispatch the call without resolving a name. The discriminant is the
/// 16-bit ProcID written on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor` — ID 1 (0x0001).
    Cursor = 1,
    /// `sp_cursoropen` — ID 2 (0x0002).
    CursorOpen = 2,
    /// `sp_cursorprepare` — ID 3 (0x0003).
    CursorPrepare = 3,
    /// `sp_cursorexecute` — ID 4 (0x0004).
    CursorExecute = 4,
    /// `sp_cursorprepexec` — ID 5 (0x0005).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` — ID 6 (0x0006).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` — ID 7 (0x0007).
    CursorFetch = 7,
    /// `sp_cursoroption` — ID 8 (0x0008).
    CursorOption = 8,
    /// `sp_cursorclose` — ID 9 (0x0009).
    CursorClose = 9,
    /// `sp_executesql` — ID 10 (0x000A); the main vehicle for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` — ID 11 (0x000B).
    Prepare = 11,
    /// `sp_execute` — ID 12 (0x000C).
    Execute = 12,
    /// `sp_prepexec` — ID 13 (0x000D); prepares and executes in a single call.
    PrepExec = 13,
    /// `sp_prepexecrpc` — ID 14 (0x000E).
    PrepExecRpc = 14,
    /// `sp_unprepare` — ID 15 (0x000F).
    Unprepare = 15,
}
71
/// RPC option flags (the 2-byte `OptionFlags` field that follows the procedure
/// name or ID).
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (bit 0, `0x0001`).
    pub with_recompile: bool,
    /// No metadata in response (bit 1, `0x0002`).
    pub no_metadata: bool,
    /// Reuse metadata from previous call (bit 2, `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Wire bit for [`with_recompile`](Self::with_recompile).
    const WITH_RECOMPILE: u16 = 0x0001;
    /// Wire bit for [`no_metadata`](Self::no_metadata).
    const NO_METADATA: u16 = 0x0002;
    /// Wire bit for [`reuse_metadata`](Self::reuse_metadata).
    const REUSE_METADATA: u16 = 0x0004;

    /// Create new empty flags (all bits clear).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the with-recompile flag.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Set the no-metadata flag.
    ///
    /// Added so every flag can be set fluently, not just `with_recompile`.
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Set the reuse-metadata flag.
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Encode to wire format (2 bytes, to be written little-endian by the caller).
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA;
        }
        flags
    }
}
111
/// Per-parameter status byte sent after the parameter name.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// OUTPUT parameter: the value is passed by reference (bit `0x01`).
    pub by_ref: bool,
    /// The server should use the parameter's default value (bit `0x02`).
    pub default: bool,
    /// Always Encrypted parameter (bit `0x08`).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Flags with nothing set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return a copy of these flags with the by-reference (OUTPUT) bit set.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Pack the flags into the single status byte used on the wire.
    pub fn encode(&self) -> u8 {
        [
            (self.by_ref, 0x01u8),
            (self.default, 0x02u8),
            (self.encrypted, 0x08u8),
        ]
        .iter()
        .filter(|&&(is_set, _)| is_set)
        .fold(0u8, |packed, &(_, bit)| packed | bit)
    }
}
151
/// TYPE_INFO metadata describing the SQL type of an RPC parameter.
///
/// Which optional fields are meaningful depends on `type_id`; unused ones stay `None`.
#[derive(Debug, Clone)]
pub struct TypeInfo {
    /// TDS data type identifier byte.
    pub type_id: u8,
    /// Declared maximum length for length-carrying types (`0xFFFF` marks MAX).
    pub max_length: Option<u16>,
    /// Decimal/numeric precision.
    pub precision: Option<u8>,
    /// Decimal/numeric scale, or fractional-second scale for time types.
    pub scale: Option<u8>,
    /// 5-byte collation attached to character types.
    pub collation: Option<[u8; 5]>,
    /// Table type name for a Table-Valued Parameter, e.g. "dbo.IntIdList".
    pub tvp_type_name: Option<String>,
}
168
169impl TypeInfo {
170    /// Create type info for INT.
171    pub fn int() -> Self {
172        Self {
173            type_id: 0x26, // INTNTYPE (variable-length int)
174            max_length: Some(4),
175            precision: None,
176            scale: None,
177            collation: None,
178            tvp_type_name: None,
179        }
180    }
181
182    /// Create type info for BIGINT.
183    pub fn bigint() -> Self {
184        Self {
185            type_id: 0x26, // INTNTYPE
186            max_length: Some(8),
187            precision: None,
188            scale: None,
189            collation: None,
190            tvp_type_name: None,
191        }
192    }
193
194    /// Create type info for SMALLINT.
195    pub fn smallint() -> Self {
196        Self {
197            type_id: 0x26, // INTNTYPE
198            max_length: Some(2),
199            precision: None,
200            scale: None,
201            collation: None,
202            tvp_type_name: None,
203        }
204    }
205
206    /// Create type info for TINYINT.
207    pub fn tinyint() -> Self {
208        Self {
209            type_id: 0x26, // INTNTYPE
210            max_length: Some(1),
211            precision: None,
212            scale: None,
213            collation: None,
214            tvp_type_name: None,
215        }
216    }
217
218    /// Create type info for BIT.
219    pub fn bit() -> Self {
220        Self {
221            type_id: 0x68, // BITNTYPE
222            max_length: Some(1),
223            precision: None,
224            scale: None,
225            collation: None,
226            tvp_type_name: None,
227        }
228    }
229
230    /// Create type info for FLOAT.
231    pub fn float() -> Self {
232        Self {
233            type_id: 0x6D, // FLTNTYPE
234            max_length: Some(8),
235            precision: None,
236            scale: None,
237            collation: None,
238            tvp_type_name: None,
239        }
240    }
241
242    /// Create type info for REAL.
243    pub fn real() -> Self {
244        Self {
245            type_id: 0x6D, // FLTNTYPE
246            max_length: Some(4),
247            precision: None,
248            scale: None,
249            collation: None,
250            tvp_type_name: None,
251        }
252    }
253
254    /// Create type info for NVARCHAR with max length.
255    pub fn nvarchar(max_len: u16) -> Self {
256        Self {
257            type_id: 0xE7,                 // NVARCHARTYPE
258            max_length: Some(max_len * 2), // UTF-16, so double the char count
259            precision: None,
260            scale: None,
261            // Default collation (Latin1_General_CI_AS equivalent)
262            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
263            tvp_type_name: None,
264        }
265    }
266
267    /// Create type info for NVARCHAR(MAX).
268    pub fn nvarchar_max() -> Self {
269        Self {
270            type_id: 0xE7,            // NVARCHARTYPE
271            max_length: Some(0xFFFF), // MAX indicator
272            precision: None,
273            scale: None,
274            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
275            tvp_type_name: None,
276        }
277    }
278
279    /// Default collation bytes: Latin1_General_CI_AS (LCID 0x0409, sort ID 0x34).
280    const DEFAULT_COLLATION: [u8; 5] = [0x09, 0x04, 0xD0, 0x00, 0x34];
281
282    /// Create type info for VARCHAR with max length (in bytes).
283    pub fn varchar(max_len: u16) -> Self {
284        Self::varchar_with_collation(max_len, Self::DEFAULT_COLLATION)
285    }
286
287    /// Create type info for VARCHAR with max length and explicit collation.
288    pub fn varchar_with_collation(max_len: u16, collation: [u8; 5]) -> Self {
289        Self {
290            type_id: 0xA7, // BIGVARCHARTYPE
291            max_length: Some(max_len),
292            precision: None,
293            scale: None,
294            collation: Some(collation),
295            tvp_type_name: None,
296        }
297    }
298
299    /// Create type info for VARCHAR(MAX).
300    pub fn varchar_max() -> Self {
301        Self::varchar_max_with_collation(Self::DEFAULT_COLLATION)
302    }
303
304    /// Create type info for VARCHAR(MAX) with explicit collation.
305    pub fn varchar_max_with_collation(collation: [u8; 5]) -> Self {
306        Self {
307            type_id: 0xA7,            // BIGVARCHARTYPE
308            max_length: Some(0xFFFF), // MAX indicator
309            precision: None,
310            scale: None,
311            collation: Some(collation),
312            tvp_type_name: None,
313        }
314    }
315
316    /// Create type info for VARBINARY with max length.
317    pub fn varbinary(max_len: u16) -> Self {
318        Self {
319            type_id: 0xA5, // BIGVARBINTYPE
320            max_length: Some(max_len),
321            precision: None,
322            scale: None,
323            collation: None,
324            tvp_type_name: None,
325        }
326    }
327
328    /// Create type info for VARBINARY(MAX).
329    pub fn varbinary_max() -> Self {
330        Self {
331            type_id: 0xA5,            // BIGVARBINTYPE
332            max_length: Some(0xFFFF), // MAX indicator β€” triggers PLP encoding
333            precision: None,
334            scale: None,
335            collation: None,
336            tvp_type_name: None,
337        }
338    }
339
340    /// Create type info for UNIQUEIDENTIFIER.
341    pub fn uniqueidentifier() -> Self {
342        Self {
343            type_id: 0x24, // GUIDTYPE
344            max_length: Some(16),
345            precision: None,
346            scale: None,
347            collation: None,
348            tvp_type_name: None,
349        }
350    }
351
352    /// Create type info for DATE.
353    pub fn date() -> Self {
354        Self {
355            type_id: 0x28, // DATETYPE
356            max_length: None,
357            precision: None,
358            scale: None,
359            collation: None,
360            tvp_type_name: None,
361        }
362    }
363
364    /// Create type info for TIME.
365    pub fn time(scale: u8) -> Self {
366        Self {
367            type_id: 0x29, // TIMETYPE
368            max_length: None,
369            precision: None,
370            scale: Some(scale),
371            collation: None,
372            tvp_type_name: None,
373        }
374    }
375
376    /// Create type info for DATETIME2.
377    pub fn datetime2(scale: u8) -> Self {
378        Self {
379            type_id: 0x2A, // DATETIME2TYPE
380            max_length: None,
381            precision: None,
382            scale: Some(scale),
383            collation: None,
384            tvp_type_name: None,
385        }
386    }
387
388    /// Create type info for DATETIMEOFFSET.
389    pub fn datetimeoffset(scale: u8) -> Self {
390        Self {
391            type_id: 0x2B, // DATETIMEOFFSETTYPE
392            max_length: None,
393            precision: None,
394            scale: Some(scale),
395            collation: None,
396            tvp_type_name: None,
397        }
398    }
399
400    /// Create type info for DECIMAL.
401    pub fn decimal(precision: u8, scale: u8) -> Self {
402        Self {
403            type_id: 0x6C,        // DECIMALNTYPE
404            max_length: Some(17), // Max decimal size
405            precision: Some(precision),
406            scale: Some(scale),
407            collation: None,
408            tvp_type_name: None,
409        }
410    }
411
412    /// Create type info for MONEY (8-byte scaled integer via MONEYN / 0x6E).
413    pub fn money() -> Self {
414        Self {
415            type_id: 0x6E, // MONEYNTYPE
416            max_length: Some(8),
417            precision: None,
418            scale: None,
419            collation: None,
420            tvp_type_name: None,
421        }
422    }
423
424    /// Create type info for SMALLMONEY (4-byte scaled integer via MONEYN / 0x6E).
425    pub fn smallmoney() -> Self {
426        Self {
427            type_id: 0x6E, // MONEYNTYPE
428            max_length: Some(4),
429            precision: None,
430            scale: None,
431            collation: None,
432            tvp_type_name: None,
433        }
434    }
435
436    /// Create type info for SMALLDATETIME (4-byte days+minutes via DATETIMEN / 0x6F).
437    pub fn smalldatetime() -> Self {
438        Self {
439            type_id: 0x6F, // DATETIMENTYPE
440            max_length: Some(4),
441            precision: None,
442            scale: None,
443            collation: None,
444            tvp_type_name: None,
445        }
446    }
447
448    /// Create type info for a Table-Valued Parameter.
449    ///
450    /// # Arguments
451    /// * `type_name` - The fully qualified table type name (e.g., "dbo.IntIdList")
452    pub fn tvp(type_name: impl Into<String>) -> Self {
453        Self {
454            type_id: 0xF3, // TVP type
455            max_length: None,
456            precision: None,
457            scale: None,
458            collation: None,
459            tvp_type_name: Some(type_name.into()),
460        }
461    }
462
463    /// Encode type info to buffer.
464    pub fn encode(&self, buf: &mut BytesMut) {
465        // TVP (0xF3) has type_id embedded in the value data itself
466        // (written by TvpEncoder::encode_metadata), so don't write it here
467        if self.type_id != 0xF3 {
468            buf.put_u8(self.type_id);
469        }
470
471        // Variable-length types need max length
472        match self.type_id {
473            0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
474                // INTNTYPE, BITNTYPE, FLTNTYPE, MONEYNTYPE, DATETIMENTYPE
475                if let Some(len) = self.max_length {
476                    buf.put_u8(len as u8);
477                }
478            }
479            0xE7 | 0xA7 | 0xA5 | 0xEF => {
480                // NVARCHARTYPE, BIGVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
481                if let Some(len) = self.max_length {
482                    buf.put_u16_le(len);
483                }
484                // Collation for string types
485                if let Some(collation) = self.collation {
486                    buf.put_slice(&collation);
487                }
488            }
489            0x24 => {
490                // GUIDTYPE
491                if let Some(len) = self.max_length {
492                    buf.put_u8(len as u8);
493                }
494            }
495            0x29..=0x2B => {
496                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
497                if let Some(scale) = self.scale {
498                    buf.put_u8(scale);
499                }
500            }
501            0x6C | 0x6A => {
502                // DECIMALNTYPE, NUMERICNTYPE
503                if let Some(len) = self.max_length {
504                    buf.put_u8(len as u8);
505                }
506                if let Some(precision) = self.precision {
507                    buf.put_u8(precision);
508                }
509                if let Some(scale) = self.scale {
510                    buf.put_u8(scale);
511                }
512            }
513            _ => {}
514        }
515    }
516}
517
518/// An RPC parameter.
519#[derive(Debug, Clone)]
520pub struct RpcParam {
521    /// Parameter name (can be empty for positional params).
522    pub name: String,
523    /// Status flags.
524    pub flags: ParamFlags,
525    /// Type information.
526    pub type_info: TypeInfo,
527    /// Parameter value (raw bytes).
528    pub value: Option<Bytes>,
529}
530
531impl RpcParam {
532    /// Create a new parameter with a value.
533    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
534        Self {
535            name: name.into(),
536            flags: ParamFlags::default(),
537            type_info,
538            value: Some(value),
539        }
540    }
541
542    /// Create a NULL parameter.
543    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
544        Self {
545            name: name.into(),
546            flags: ParamFlags::default(),
547            type_info,
548            value: None,
549        }
550    }
551
552    /// Create an INT parameter.
553    pub fn int(name: impl Into<String>, value: i32) -> Self {
554        let mut buf = BytesMut::with_capacity(4);
555        buf.put_i32_le(value);
556        Self::new(name, TypeInfo::int(), buf.freeze())
557    }
558
559    /// Create a BIGINT parameter.
560    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
561        let mut buf = BytesMut::with_capacity(8);
562        buf.put_i64_le(value);
563        Self::new(name, TypeInfo::bigint(), buf.freeze())
564    }
565
566    /// Create an NVARCHAR parameter.
567    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
568        let mut buf = BytesMut::new();
569        let mut code_units: usize = 0;
570        for code_unit in value.encode_utf16() {
571            buf.put_u16_le(code_unit);
572            code_units += 1;
573        }
574        // NVARCHAR length is measured in UTF-16 code units, not Rust chars β€”
575        // supplementary characters (emoji, CJK extension B) encode to a surrogate
576        // pair (2 code units) but count as 1 `char`. Using chars().count() here
577        // under-reports the buffer length and the server rejects the RPC with
578        // "Data type 0xE7 has an invalid data length or metadata length."
579        let type_info = if code_units > 4000 {
580            TypeInfo::nvarchar_max()
581        } else {
582            TypeInfo::nvarchar(code_units.max(1) as u16)
583        };
584        Self::new(name, type_info, buf.freeze())
585    }
586
587    /// Create a VARCHAR parameter.
588    ///
589    /// Encodes the string as single-byte characters using Windows-1252 encoding
590    /// (when the `encoding` feature is enabled) or Latin-1 fallback. Characters
591    /// not representable in the target encoding are replaced with `?`.
592    ///
593    /// Use this instead of [`nvarchar`](Self::nvarchar) when
594    /// `SendStringParametersAsUnicode=false` to allow SQL Server to use
595    /// index seeks on VARCHAR columns.
596    pub fn varchar(name: impl Into<String>, value: &str) -> Self {
597        let encoded = Self::encode_varchar_bytes(value);
598        let byte_len = encoded.len();
599        let type_info = if byte_len > 8000 {
600            TypeInfo::varchar_max()
601        } else {
602            TypeInfo::varchar(byte_len.max(1) as u16)
603        };
604        Self::new(name, type_info, Bytes::from(encoded))
605    }
606
607    /// Encode a string as single-byte VARCHAR data using the default
608    /// Windows-1252 encoding (or Latin-1 fallback without the `encoding` feature).
609    fn encode_varchar_bytes(value: &str) -> Vec<u8> {
610        #[cfg(feature = "encoding")]
611        {
612            let (encoded, _, _) = encoding_rs::WINDOWS_1252.encode(value);
613            encoded.into_owned()
614        }
615        #[cfg(not(feature = "encoding"))]
616        {
617            // Latin-1 fallback: chars ≀ 0xFF pass through, others become '?'
618            value
619                .chars()
620                .map(|ch| if (ch as u32) <= 0xFF { ch as u8 } else { b'?' })
621                .collect()
622        }
623    }
624
625    /// Create a VARCHAR parameter using the server's collation for encoding.
626    ///
627    /// Uses the collation's character encoding instead of the default Windows-1252.
628    /// For UTF-8 collations (SQL Server 2019+), the string bytes are used directly.
629    pub fn varchar_with_collation(
630        name: impl Into<String>,
631        value: &str,
632        collation: &Collation,
633    ) -> Self {
634        let collation_bytes = collation.to_bytes();
635        let encoded = Self::encode_varchar_bytes_for_collation(value, collation);
636        let byte_len = encoded.len();
637        let type_info = if byte_len > 8000 {
638            TypeInfo::varchar_max_with_collation(collation_bytes)
639        } else {
640            TypeInfo::varchar_with_collation(byte_len.max(1) as u16, collation_bytes)
641        };
642        Self::new(name, type_info, Bytes::from(encoded))
643    }
644
645    /// Encode a string using the collation's character encoding.
646    fn encode_varchar_bytes_for_collation(value: &str, collation: &Collation) -> Vec<u8> {
647        crate::collation::encode_str_for_collation(value, Some(collation))
648    }
649
650    /// Mark as output parameter.
651    #[must_use]
652    pub fn as_output(mut self) -> Self {
653        self.flags = self.flags.output();
654        self
655    }
656
657    /// Encode the parameter to buffer.
658    pub fn encode(&self, buf: &mut BytesMut) {
659        // Parameter name (B_VARCHAR - length-prefixed)
660        let name_len = self.name.encode_utf16().count() as u8;
661        buf.put_u8(name_len);
662        if name_len > 0 {
663            for code_unit in self.name.encode_utf16() {
664                buf.put_u16_le(code_unit);
665            }
666        }
667
668        // Status flags
669        buf.put_u8(self.flags.encode());
670
671        // Type info
672        self.type_info.encode(buf);
673
674        // Value
675        if let Some(ref value) = self.value {
676            // Length prefix based on type
677            match self.type_info.type_id {
678                0x26 => {
679                    // INTNTYPE
680                    buf.put_u8(value.len() as u8);
681                    buf.put_slice(value);
682                }
683                0x68 | 0x6D | 0x6E | 0x6F => {
684                    // BITNTYPE, FLTNTYPE, MONEYNTYPE, DATETIMENTYPE
685                    buf.put_u8(value.len() as u8);
686                    buf.put_slice(value);
687                }
688                0xE7 | 0xA7 | 0xA5 => {
689                    // NVARCHARTYPE, BIGVARCHARTYPE, BIGVARBINTYPE
690                    if self.type_info.max_length == Some(0xFFFF) {
691                        // MAX type - use PLP format
692                        // For simplicity, send as single chunk
693                        let total_len = value.len() as u64;
694                        buf.put_u64_le(total_len);
695                        buf.put_u32_le(value.len() as u32);
696                        buf.put_slice(value);
697                        buf.put_u32_le(0); // Terminator
698                    } else {
699                        buf.put_u16_le(value.len() as u16);
700                        buf.put_slice(value);
701                    }
702                }
703                0x24 => {
704                    // GUIDTYPE
705                    buf.put_u8(value.len() as u8);
706                    buf.put_slice(value);
707                }
708                0x28..=0x2B => {
709                    // DATE, TIME, DATETIME2, DATETIMEOFFSET
710                    buf.put_u8(value.len() as u8);
711                    buf.put_slice(value);
712                }
713                0x6C => {
714                    // DECIMALNTYPE
715                    buf.put_u8(value.len() as u8);
716                    buf.put_slice(value);
717                }
718                0xF3 => {
719                    // TVP (Table-Valued Parameter)
720                    // TVP values are self-delimiting: they contain complete metadata,
721                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
722                    buf.put_slice(value);
723                }
724                _ => {
725                    // Generic: assume length-prefixed
726                    buf.put_u8(value.len() as u8);
727                    buf.put_slice(value);
728                }
729            }
730        } else {
731            // NULL value
732            match self.type_info.type_id {
733                0xE7 | 0xA7 | 0xA5 => {
734                    // Variable-length types use 0xFFFF for NULL
735                    if self.type_info.max_length == Some(0xFFFF) {
736                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
737                    } else {
738                        buf.put_u16_le(0xFFFF);
739                    }
740                }
741                _ => {
742                    buf.put_u8(0); // Zero-length for NULL
743                }
744            }
745        }
746    }
747}
748
749/// RPC request builder.
750#[derive(Debug, Clone)]
751pub struct RpcRequest {
752    /// Procedure name (if using named procedure).
753    proc_name: Option<String>,
754    /// Procedure ID (if using well-known procedure).
755    proc_id: Option<ProcId>,
756    /// Option flags.
757    options: RpcOptionFlags,
758    /// Parameters.
759    params: Vec<RpcParam>,
760}
761
762impl RpcRequest {
763    /// Create a new RPC request for a named procedure.
764    pub fn named(proc_name: impl Into<String>) -> Self {
765        Self {
766            proc_name: Some(proc_name.into()),
767            proc_id: None,
768            options: RpcOptionFlags::default(),
769            params: Vec::new(),
770        }
771    }
772
773    /// Create a new RPC request for a well-known procedure.
774    pub fn by_id(proc_id: ProcId) -> Self {
775        Self {
776            proc_name: None,
777            proc_id: Some(proc_id),
778            options: RpcOptionFlags::default(),
779            params: Vec::new(),
780        }
781    }
782
783    /// Create an sp_executesql request.
784    ///
785    /// This is the primary method for parameterized queries.
786    ///
787    /// # Example
788    ///
789    /// ```
790    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
791    ///
792    /// let rpc = RpcRequest::execute_sql(
793    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
794    ///     vec![
795    ///         RpcParam::int("@p1", 42),
796    ///         RpcParam::nvarchar("@p2", "Alice"),
797    ///     ],
798    /// );
799    /// ```
800    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
801        let mut request = Self::by_id(ProcId::ExecuteSql);
802
803        // First parameter: the SQL statement (NVARCHAR(MAX))
804        request.params.push(RpcParam::nvarchar("", sql));
805
806        // Second parameter: parameter declarations
807        if !params.is_empty() {
808            let declarations = Self::build_param_declarations(&params);
809            request.params.push(RpcParam::nvarchar("", &declarations));
810        }
811
812        // Add the actual parameters
813        request.params.extend(params);
814
815        request
816    }
817
818    /// Build parameter declaration string for sp_executesql.
819    fn build_param_declarations(params: &[RpcParam]) -> String {
820        params
821            .iter()
822            .map(|p| {
823                let name = if p.name.starts_with('@') {
824                    p.name.clone()
825                } else if p.name.is_empty() {
826                    // Generate positional name
827                    format!(
828                        "@p{}",
829                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
830                    )
831                } else {
832                    format!("@{}", p.name)
833                };
834
835                let type_name: String = match p.type_info.type_id {
836                    0x26 => match p.type_info.max_length {
837                        Some(1) => "tinyint".to_string(),
838                        Some(2) => "smallint".to_string(),
839                        Some(4) => "int".to_string(),
840                        Some(8) => "bigint".to_string(),
841                        _ => "int".to_string(),
842                    },
843                    0x68 => "bit".to_string(),
844                    0x6D => match p.type_info.max_length {
845                        Some(4) => "real".to_string(),
846                        _ => "float".to_string(),
847                    },
848                    0xE7 => {
849                        if p.type_info.max_length == Some(0xFFFF) {
850                            "nvarchar(max)".to_string()
851                        } else {
852                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
853                            format!("nvarchar({len})")
854                        }
855                    }
856                    0xA7 => {
857                        if p.type_info.max_length == Some(0xFFFF) {
858                            "varchar(max)".to_string()
859                        } else {
860                            let len = p.type_info.max_length.unwrap_or(8000);
861                            format!("varchar({len})")
862                        }
863                    }
864                    0xA5 => {
865                        if p.type_info.max_length == Some(0xFFFF) {
866                            "varbinary(max)".to_string()
867                        } else {
868                            let len = p.type_info.max_length.unwrap_or(8000);
869                            format!("varbinary({len})")
870                        }
871                    }
872                    0x24 => "uniqueidentifier".to_string(),
873                    0x28 => "date".to_string(),
874                    0x29 => {
875                        let scale = p.type_info.scale.unwrap_or(7);
876                        format!("time({scale})")
877                    }
878                    0x2A => {
879                        let scale = p.type_info.scale.unwrap_or(7);
880                        format!("datetime2({scale})")
881                    }
882                    0x2B => {
883                        let scale = p.type_info.scale.unwrap_or(7);
884                        format!("datetimeoffset({scale})")
885                    }
886                    0x6C => {
887                        let precision = p.type_info.precision.unwrap_or(18);
888                        let scale = p.type_info.scale.unwrap_or(0);
889                        format!("decimal({precision}, {scale})")
890                    }
891                    0x6E => match p.type_info.max_length {
892                        Some(4) => "smallmoney".to_string(),
893                        _ => "money".to_string(),
894                    },
895                    0x6F => match p.type_info.max_length {
896                        Some(4) => "smalldatetime".to_string(),
897                        _ => "datetime".to_string(),
898                    },
899                    0xF3 => {
900                        // TVP - Table-Valued Parameter
901                        // Must be declared with the table type name and READONLY
902                        if let Some(ref tvp_name) = p.type_info.tvp_type_name {
903                            format!("{tvp_name} READONLY")
904                        } else {
905                            // Fallback if type name is missing (shouldn't happen)
906                            "sql_variant".to_string()
907                        }
908                    }
909                    _ => "sql_variant".to_string(),
910                };
911
912                format!("{name} {type_name}")
913            })
914            .collect::<Vec<_>>()
915            .join(", ")
916    }
917
918    /// Create an sp_prepare request.
919    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
920        let mut request = Self::by_id(ProcId::Prepare);
921
922        // OUT: handle (INT)
923        request
924            .params
925            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
926
927        // Param declarations
928        let declarations = Self::build_param_declarations(params);
929        request
930            .params
931            .push(RpcParam::nvarchar("@params", &declarations));
932
933        // SQL statement
934        request.params.push(RpcParam::nvarchar("@stmt", sql));
935
936        // Options (1 = WITH RECOMPILE)
937        request.params.push(RpcParam::int("@options", 1));
938
939        request
940    }
941
942    /// Create an sp_execute request.
943    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
944        let mut request = Self::by_id(ProcId::Execute);
945
946        // Handle from sp_prepare
947        request.params.push(RpcParam::int("@handle", handle));
948
949        // Add parameters
950        request.params.extend(params);
951
952        request
953    }
954
955    /// Create an sp_unprepare request.
956    pub fn unprepare(handle: i32) -> Self {
957        let mut request = Self::by_id(ProcId::Unprepare);
958        request.params.push(RpcParam::int("@handle", handle));
959        request
960    }
961
962    /// Set option flags.
963    #[must_use]
964    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
965        self.options = options;
966        self
967    }
968
969    /// Add a parameter.
970    #[must_use]
971    pub fn param(mut self, param: RpcParam) -> Self {
972        self.params.push(param);
973        self
974    }
975
976    /// Encode the RPC request to bytes (auto-commit mode).
977    ///
978    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
979    #[must_use]
980    pub fn encode(&self) -> Bytes {
981        self.encode_with_transaction(0)
982    }
983
984    /// Encode the RPC request with a transaction descriptor.
985    ///
986    /// Per MS-TDS spec, when executing within an explicit transaction:
987    /// - The `transaction_descriptor` MUST be the value returned by the server
988    ///   in the BeginTransaction EnvChange token.
989    /// - For auto-commit mode (no explicit transaction), use 0.
990    ///
991    /// # Arguments
992    ///
993    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
994    ///   or 0 for auto-commit mode.
995    #[must_use]
996    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
997        let mut buf = BytesMut::with_capacity(256);
998
999        // ALL_HEADERS - TDS 7.2+ requires this section
1000        // Total length placeholder (will be filled in)
1001        let all_headers_start = buf.len();
1002        buf.put_u32_le(0); // Total length placeholder
1003
1004        // Transaction descriptor header (required for RPC)
1005        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
1006        buf.put_u32_le(18); // Header length
1007        buf.put_u16_le(0x0002); // Header type: transaction descriptor
1008        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
1009        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
1010
1011        // Fill in ALL_HEADERS total length
1012        let all_headers_len = buf.len() - all_headers_start;
1013        let len_bytes = (all_headers_len as u32).to_le_bytes();
1014        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
1015
1016        // Procedure name or ID
1017        if let Some(proc_id) = self.proc_id {
1018            // Use PROCID format
1019            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
1020            buf.put_u16_le(proc_id as u16);
1021        } else if let Some(ref proc_name) = self.proc_name {
1022            // Use procedure name
1023            let name_len = proc_name.encode_utf16().count() as u16;
1024            buf.put_u16_le(name_len);
1025            write_utf16_string(&mut buf, proc_name);
1026        }
1027
1028        // Option flags
1029        buf.put_u16_le(self.options.encode());
1030
1031        // Parameters
1032        for param in &self.params {
1033            param.encode(&mut buf);
1034        }
1035
1036        buf.freeze()
1037    }
1038}
1039
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is the desired failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // UTF-16 encoded "Alice" = 10 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_nvarchar_param_surrogate_pair_length() {
        // 🌍 is a supplementary character — 1 Rust char but 2 UTF-16 code units
        // (4 bytes). TypeInfo.max_length is stored doubled internally, so
        // the metadata must declare 2 code units for the buffer to match.
        let param = RpcParam::nvarchar("@p", "🌍");
        assert_eq!(param.value.as_ref().unwrap().len(), 4);
        // TypeInfo::nvarchar(n) stores max_length as n*2 bytes.
        assert_eq!(param.type_info.max_length, Some(4));

        let param = RpcParam::nvarchar("@p", "Hello 世界 🌍");
        // "Hello 世界 " = 9 BMP code units + 🌍 = 2 surrogate units → 11 code units, 22 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 22);
        assert_eq!(param.type_info.max_length, Some(22));
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // SQL statement + param declarations + actual params
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_rpc_encode_header_layout() {
        let descriptor: u64 = 0x0102_0304_0506_0708;
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode_with_transaction(descriptor);

        // ALL_HEADERS: total length 22, one transaction descriptor header of 18 bytes
        assert_eq!(&encoded[0..4], &22u32.to_le_bytes());
        assert_eq!(&encoded[4..8], &18u32.to_le_bytes());
        assert_eq!(&encoded[8..10], &0x0002u16.to_le_bytes());
        assert_eq!(&encoded[10..18], &descriptor.to_le_bytes());
        assert_eq!(&encoded[18..22], &1u32.to_le_bytes());
        // PROCID form: 0xFFFF marker followed by sp_executesql's well-known ID
        assert_eq!(&encoded[22..24], &0xFFFFu16.to_le_bytes());
        assert_eq!(&encoded[24..26], &(ProcId::ExecuteSql as u16).to_le_bytes());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // handle (output), params, stmt, options
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params[0].flags.by_ref); // handle is OUTPUT
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        assert_eq!(rpc.params.len(), 2); // handle + param
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        assert_eq!(rpc.params.len(), 1); // just the handle
    }

    #[test]
    fn test_varchar_param() {
        let param = RpcParam::varchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xA7);
        // Single-byte encoded "Alice" = 5 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 5);
        assert_eq!(&param.value.as_ref().unwrap()[..], b"Alice");
    }

    #[test]
    fn test_varchar_param_max() {
        // String > 8000 bytes should use VARCHAR(MAX)
        let long_str = "a".repeat(9000);
        let param = RpcParam::varchar("@big", &long_str);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert_eq!(param.value.as_ref().unwrap().len(), 9000);
    }

    #[test]
    fn test_varchar_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::varchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name varchar(5)"));
    }

    #[test]
    fn test_varchar_type_info_has_collation() {
        let ti = TypeInfo::varchar(100);
        assert_eq!(ti.type_id, 0xA7);
        assert_eq!(ti.max_length, Some(100));
        assert!(ti.collation.is_some());
    }

    #[test]
    fn test_varchar_encode_round_trip() {
        // Verify the encoded param can be serialized without panics
        let param = RpcParam::varchar("@val", "test value");
        let mut buf = bytes::BytesMut::new();
        param.encode(&mut buf);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_collation_round_trip() {
        let collation = Collation {
            lcid: 0x00D0_0409,
            sort_id: 0x34,
        };
        let bytes = collation.to_bytes();
        assert_eq!(bytes, [0x09, 0x04, 0xD0, 0x00, 0x34]);

        let restored = Collation::from_bytes(&bytes);
        assert_eq!(restored.lcid, collation.lcid);
        assert_eq!(restored.sort_id, collation.sort_id);
    }

    #[test]
    fn test_varchar_with_collation_uses_custom_collation_bytes() {
        // Chinese_PRC_CI_AS collation (LCID 0x0804)
        let collation = Collation {
            lcid: 0x0804,
            sort_id: 0,
        };
        let param = RpcParam::varchar_with_collation("@val", "test", &collation);
        assert_eq!(param.type_info.type_id, 0xA7);
        // Collation bytes should match the custom collation, not default Latin1
        assert_eq!(param.type_info.collation, Some(collation.to_bytes()));
    }

    #[test]
    fn test_money_type_info() {
        let ti = TypeInfo::money();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(8));
    }

    #[test]
    fn test_smallmoney_type_info() {
        let ti = TypeInfo::smallmoney();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_smalldatetime_type_info() {
        let ti = TypeInfo::smalldatetime();
        assert_eq!(ti.type_id, 0x6F);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_money_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::new("@m", TypeInfo::money(), Bytes::from_static(&[0u8; 8])),
            RpcParam::new("@sm", TypeInfo::smallmoney(), Bytes::from_static(&[0u8; 4])),
            RpcParam::new(
                "@sdt",
                TypeInfo::smalldatetime(),
                Bytes::from_static(&[0u8; 4]),
            ),
        ]);
        assert!(decls.contains("@m money"), "got: {decls}");
        assert!(decls.contains("@sm smallmoney"), "got: {decls}");
        assert!(decls.contains("@sdt smalldatetime"), "got: {decls}");
    }

    #[test]
    fn test_money_typeinfo_encodes_max_length_byte() {
        let mut buf = bytes::BytesMut::new();
        TypeInfo::money().encode(&mut buf);
        // type_id 0x6E + max_length byte 0x08
        assert_eq!(&buf[..], &[0x6E, 0x08]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smallmoney().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x04]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smalldatetime().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6F, 0x04]);
    }

    #[test]
    fn test_varchar_with_collation_default_vs_custom_differ() {
        let default_param = RpcParam::varchar("@val", "test");
        let custom_collation = Collation {
            lcid: 0x0419, // Russian
            sort_id: 0,
        };
        let custom_param = RpcParam::varchar_with_collation("@val", "test", &custom_collation);
        // The collation bytes should differ
        assert_ne!(
            default_param.type_info.collation,
            custom_param.type_info.collation
        );
    }
}