Skip to main content

tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
16//! | ALL_HEADERS       | (TDS 7.2+, optional)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30
/// Stored procedures that SQL Server lets a client invoke by numeric ID.
///
/// Sending one of these IDs in the RPC header (instead of a name) spares
/// the client from transmitting the procedure name on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor` — ID 1.
    Cursor = 1,
    /// `sp_cursoropen` — ID 2.
    CursorOpen = 2,
    /// `sp_cursorprepare` — ID 3.
    CursorPrepare = 3,
    /// `sp_cursorexecute` — ID 4.
    CursorExecute = 4,
    /// `sp_cursorprepexec` — ID 5.
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` — ID 6.
    CursorUnprepare = 6,
    /// `sp_cursorfetch` — ID 7.
    CursorFetch = 7,
    /// `sp_cursoroption` — ID 8.
    CursorOption = 8,
    /// `sp_cursorclose` — ID 9.
    CursorClose = 9,
    /// `sp_executesql` — ID 10; the main entry point for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` — ID 11.
    Prepare = 11,
    /// `sp_execute` — ID 12.
    Execute = 12,
    /// `sp_prepexec` — ID 13; prepares and executes in a single round trip.
    PrepExec = 13,
    /// `sp_prepexecrpc` — ID 14.
    PrepExecRpc = 14,
    /// `sp_unprepare` — ID 15.
    Unprepare = 15,
}
70
/// Request-level option flags, sent as the 2-byte `OptionFlags` field of an RPC request.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Ask the server to recompile the procedure (bit 0, `0x0001`).
    pub with_recompile: bool,
    /// Ask the server to omit metadata from the response (bit 1, `0x0002`).
    pub no_metadata: bool,
    /// Ask the server to reuse metadata from a previous call (bit 2, `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Returns a value with every flag cleared.
    pub fn new() -> Self {
        Self {
            with_recompile: false,
            no_metadata: false,
            reuse_metadata: false,
        }
    }

    /// Returns a copy whose recompile flag is set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the flags into their 16-bit wire value.
    pub fn encode(&self) -> u16 {
        // (is the flag set?, bit it contributes)
        [
            (self.with_recompile, 0x0001u16),
            (self.no_metadata, 0x0002),
            (self.reuse_metadata, 0x0004),
        ]
        .iter()
        .filter(|&&(set, _)| set)
        .map(|&(_, mask)| mask)
        .fold(0u16, |acc, mask| acc | mask)
    }
}
110
/// Per-parameter status flags, sent as the 1-byte `StatusFlags` field of an RPC parameter.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Parameter is passed by reference, i.e. it is an OUTPUT parameter (`0x01`).
    pub by_ref: bool,
    /// Parameter should use its default value (`0x02`).
    pub default: bool,
    /// Parameter value is encrypted with Always Encrypted (`0x08`).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns flags with nothing set.
    pub fn new() -> Self {
        Self {
            by_ref: false,
            default: false,
            encrypted: false,
        }
    }

    /// Returns a copy marked as an OUTPUT (by-reference) parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their single wire byte.
    pub fn encode(&self) -> u8 {
        let by_ref = if self.by_ref { 0x01 } else { 0x00 };
        let default = if self.default { 0x02 } else { 0x00 };
        // 0x04 is deliberately skipped; the encrypted flag lives at 0x08.
        let encrypted = if self.encrypted { 0x08 } else { 0x00 };
        by_ref | default | encrypted
    }
}
150
151/// TDS type information for RPC parameters.
152#[derive(Debug, Clone)]
153pub struct TypeInfo {
154    /// Type ID.
155    pub type_id: u8,
156    /// Maximum length for variable-length types.
157    pub max_length: Option<u16>,
158    /// Precision for numeric types.
159    pub precision: Option<u8>,
160    /// Scale for numeric types.
161    pub scale: Option<u8>,
162    /// Collation for string types.
163    pub collation: Option<[u8; 5]>,
164    /// TVP type name (e.g., "dbo.IntIdList") for Table-Valued Parameters.
165    pub tvp_type_name: Option<String>,
166}
167
168impl TypeInfo {
169    /// Create type info for INT.
170    pub fn int() -> Self {
171        Self {
172            type_id: 0x26, // INTNTYPE (variable-length int)
173            max_length: Some(4),
174            precision: None,
175            scale: None,
176            collation: None,
177            tvp_type_name: None,
178        }
179    }
180
181    /// Create type info for BIGINT.
182    pub fn bigint() -> Self {
183        Self {
184            type_id: 0x26, // INTNTYPE
185            max_length: Some(8),
186            precision: None,
187            scale: None,
188            collation: None,
189            tvp_type_name: None,
190        }
191    }
192
193    /// Create type info for SMALLINT.
194    pub fn smallint() -> Self {
195        Self {
196            type_id: 0x26, // INTNTYPE
197            max_length: Some(2),
198            precision: None,
199            scale: None,
200            collation: None,
201            tvp_type_name: None,
202        }
203    }
204
205    /// Create type info for TINYINT.
206    pub fn tinyint() -> Self {
207        Self {
208            type_id: 0x26, // INTNTYPE
209            max_length: Some(1),
210            precision: None,
211            scale: None,
212            collation: None,
213            tvp_type_name: None,
214        }
215    }
216
217    /// Create type info for BIT.
218    pub fn bit() -> Self {
219        Self {
220            type_id: 0x68, // BITNTYPE
221            max_length: Some(1),
222            precision: None,
223            scale: None,
224            collation: None,
225            tvp_type_name: None,
226        }
227    }
228
229    /// Create type info for FLOAT.
230    pub fn float() -> Self {
231        Self {
232            type_id: 0x6D, // FLTNTYPE
233            max_length: Some(8),
234            precision: None,
235            scale: None,
236            collation: None,
237            tvp_type_name: None,
238        }
239    }
240
241    /// Create type info for REAL.
242    pub fn real() -> Self {
243        Self {
244            type_id: 0x6D, // FLTNTYPE
245            max_length: Some(4),
246            precision: None,
247            scale: None,
248            collation: None,
249            tvp_type_name: None,
250        }
251    }
252
253    /// Create type info for NVARCHAR with max length.
254    pub fn nvarchar(max_len: u16) -> Self {
255        Self {
256            type_id: 0xE7,                 // NVARCHARTYPE
257            max_length: Some(max_len * 2), // UTF-16, so double the char count
258            precision: None,
259            scale: None,
260            // Default collation (Latin1_General_CI_AS equivalent)
261            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
262            tvp_type_name: None,
263        }
264    }
265
266    /// Create type info for NVARCHAR(MAX).
267    pub fn nvarchar_max() -> Self {
268        Self {
269            type_id: 0xE7,            // NVARCHARTYPE
270            max_length: Some(0xFFFF), // MAX indicator
271            precision: None,
272            scale: None,
273            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
274            tvp_type_name: None,
275        }
276    }
277
278    /// Create type info for VARBINARY with max length.
279    pub fn varbinary(max_len: u16) -> Self {
280        Self {
281            type_id: 0xA5, // BIGVARBINTYPE
282            max_length: Some(max_len),
283            precision: None,
284            scale: None,
285            collation: None,
286            tvp_type_name: None,
287        }
288    }
289
290    /// Create type info for UNIQUEIDENTIFIER.
291    pub fn uniqueidentifier() -> Self {
292        Self {
293            type_id: 0x24, // GUIDTYPE
294            max_length: Some(16),
295            precision: None,
296            scale: None,
297            collation: None,
298            tvp_type_name: None,
299        }
300    }
301
302    /// Create type info for DATE.
303    pub fn date() -> Self {
304        Self {
305            type_id: 0x28, // DATETYPE
306            max_length: None,
307            precision: None,
308            scale: None,
309            collation: None,
310            tvp_type_name: None,
311        }
312    }
313
314    /// Create type info for DATETIME2.
315    pub fn datetime2(scale: u8) -> Self {
316        Self {
317            type_id: 0x2A, // DATETIME2TYPE
318            max_length: None,
319            precision: None,
320            scale: Some(scale),
321            collation: None,
322            tvp_type_name: None,
323        }
324    }
325
326    /// Create type info for DECIMAL.
327    pub fn decimal(precision: u8, scale: u8) -> Self {
328        Self {
329            type_id: 0x6C,        // DECIMALNTYPE
330            max_length: Some(17), // Max decimal size
331            precision: Some(precision),
332            scale: Some(scale),
333            collation: None,
334            tvp_type_name: None,
335        }
336    }
337
338    /// Create type info for a Table-Valued Parameter.
339    ///
340    /// # Arguments
341    /// * `type_name` - The fully qualified table type name (e.g., "dbo.IntIdList")
342    pub fn tvp(type_name: impl Into<String>) -> Self {
343        Self {
344            type_id: 0xF3, // TVP type
345            max_length: None,
346            precision: None,
347            scale: None,
348            collation: None,
349            tvp_type_name: Some(type_name.into()),
350        }
351    }
352
353    /// Encode type info to buffer.
354    pub fn encode(&self, buf: &mut BytesMut) {
355        // TVP (0xF3) has type_id embedded in the value data itself
356        // (written by TvpEncoder::encode_metadata), so don't write it here
357        if self.type_id != 0xF3 {
358            buf.put_u8(self.type_id);
359        }
360
361        // Variable-length types need max length
362        match self.type_id {
363            0x26 | 0x68 | 0x6D => {
364                // INTNTYPE, BITNTYPE, FLTNTYPE
365                if let Some(len) = self.max_length {
366                    buf.put_u8(len as u8);
367                }
368            }
369            0xE7 | 0xA5 | 0xEF => {
370                // NVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
371                if let Some(len) = self.max_length {
372                    buf.put_u16_le(len);
373                }
374                // Collation for string types
375                if let Some(collation) = self.collation {
376                    buf.put_slice(&collation);
377                }
378            }
379            0x24 => {
380                // GUIDTYPE
381                if let Some(len) = self.max_length {
382                    buf.put_u8(len as u8);
383                }
384            }
385            0x29..=0x2B => {
386                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
387                if let Some(scale) = self.scale {
388                    buf.put_u8(scale);
389                }
390            }
391            0x6C | 0x6A => {
392                // DECIMALNTYPE, NUMERICNTYPE
393                if let Some(len) = self.max_length {
394                    buf.put_u8(len as u8);
395                }
396                if let Some(precision) = self.precision {
397                    buf.put_u8(precision);
398                }
399                if let Some(scale) = self.scale {
400                    buf.put_u8(scale);
401                }
402            }
403            _ => {}
404        }
405    }
406}
407
408/// An RPC parameter.
409#[derive(Debug, Clone)]
410pub struct RpcParam {
411    /// Parameter name (can be empty for positional params).
412    pub name: String,
413    /// Status flags.
414    pub flags: ParamFlags,
415    /// Type information.
416    pub type_info: TypeInfo,
417    /// Parameter value (raw bytes).
418    pub value: Option<Bytes>,
419}
420
421impl RpcParam {
422    /// Create a new parameter with a value.
423    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
424        Self {
425            name: name.into(),
426            flags: ParamFlags::default(),
427            type_info,
428            value: Some(value),
429        }
430    }
431
432    /// Create a NULL parameter.
433    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
434        Self {
435            name: name.into(),
436            flags: ParamFlags::default(),
437            type_info,
438            value: None,
439        }
440    }
441
442    /// Create an INT parameter.
443    pub fn int(name: impl Into<String>, value: i32) -> Self {
444        let mut buf = BytesMut::with_capacity(4);
445        buf.put_i32_le(value);
446        Self::new(name, TypeInfo::int(), buf.freeze())
447    }
448
449    /// Create a BIGINT parameter.
450    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
451        let mut buf = BytesMut::with_capacity(8);
452        buf.put_i64_le(value);
453        Self::new(name, TypeInfo::bigint(), buf.freeze())
454    }
455
456    /// Create an NVARCHAR parameter.
457    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
458        let mut buf = BytesMut::new();
459        // Encode as UTF-16LE
460        for code_unit in value.encode_utf16() {
461            buf.put_u16_le(code_unit);
462        }
463        let char_len = value.chars().count();
464        let type_info = if char_len > 4000 {
465            TypeInfo::nvarchar_max()
466        } else {
467            TypeInfo::nvarchar(char_len.max(1) as u16)
468        };
469        Self::new(name, type_info, buf.freeze())
470    }
471
472    /// Mark as output parameter.
473    #[must_use]
474    pub fn as_output(mut self) -> Self {
475        self.flags = self.flags.output();
476        self
477    }
478
479    /// Encode the parameter to buffer.
480    pub fn encode(&self, buf: &mut BytesMut) {
481        // Parameter name (B_VARCHAR - length-prefixed)
482        let name_len = self.name.encode_utf16().count() as u8;
483        buf.put_u8(name_len);
484        if name_len > 0 {
485            for code_unit in self.name.encode_utf16() {
486                buf.put_u16_le(code_unit);
487            }
488        }
489
490        // Status flags
491        buf.put_u8(self.flags.encode());
492
493        // Type info
494        self.type_info.encode(buf);
495
496        // Value
497        if let Some(ref value) = self.value {
498            // Length prefix based on type
499            match self.type_info.type_id {
500                0x26 => {
501                    // INTNTYPE
502                    buf.put_u8(value.len() as u8);
503                    buf.put_slice(value);
504                }
505                0x68 | 0x6D => {
506                    // BITNTYPE, FLTNTYPE
507                    buf.put_u8(value.len() as u8);
508                    buf.put_slice(value);
509                }
510                0xE7 | 0xA5 => {
511                    // NVARCHARTYPE, BIGVARBINTYPE
512                    if self.type_info.max_length == Some(0xFFFF) {
513                        // MAX type - use PLP format
514                        // For simplicity, send as single chunk
515                        let total_len = value.len() as u64;
516                        buf.put_u64_le(total_len);
517                        buf.put_u32_le(value.len() as u32);
518                        buf.put_slice(value);
519                        buf.put_u32_le(0); // Terminator
520                    } else {
521                        buf.put_u16_le(value.len() as u16);
522                        buf.put_slice(value);
523                    }
524                }
525                0x24 => {
526                    // GUIDTYPE
527                    buf.put_u8(value.len() as u8);
528                    buf.put_slice(value);
529                }
530                0x28 => {
531                    // DATETYPE (fixed 3 bytes)
532                    buf.put_slice(value);
533                }
534                0x2A => {
535                    // DATETIME2TYPE
536                    buf.put_u8(value.len() as u8);
537                    buf.put_slice(value);
538                }
539                0x6C => {
540                    // DECIMALNTYPE
541                    buf.put_u8(value.len() as u8);
542                    buf.put_slice(value);
543                }
544                0xF3 => {
545                    // TVP (Table-Valued Parameter)
546                    // TVP values are self-delimiting: they contain complete metadata,
547                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
548                    buf.put_slice(value);
549                }
550                _ => {
551                    // Generic: assume length-prefixed
552                    buf.put_u8(value.len() as u8);
553                    buf.put_slice(value);
554                }
555            }
556        } else {
557            // NULL value
558            match self.type_info.type_id {
559                0xE7 | 0xA5 => {
560                    // Variable-length types use 0xFFFF for NULL
561                    if self.type_info.max_length == Some(0xFFFF) {
562                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
563                    } else {
564                        buf.put_u16_le(0xFFFF);
565                    }
566                }
567                _ => {
568                    buf.put_u8(0); // Zero-length for NULL
569                }
570            }
571        }
572    }
573}
574
575/// RPC request builder.
576#[derive(Debug, Clone)]
577pub struct RpcRequest {
578    /// Procedure name (if using named procedure).
579    proc_name: Option<String>,
580    /// Procedure ID (if using well-known procedure).
581    proc_id: Option<ProcId>,
582    /// Option flags.
583    options: RpcOptionFlags,
584    /// Parameters.
585    params: Vec<RpcParam>,
586}
587
588impl RpcRequest {
589    /// Create a new RPC request for a named procedure.
590    pub fn named(proc_name: impl Into<String>) -> Self {
591        Self {
592            proc_name: Some(proc_name.into()),
593            proc_id: None,
594            options: RpcOptionFlags::default(),
595            params: Vec::new(),
596        }
597    }
598
599    /// Create a new RPC request for a well-known procedure.
600    pub fn by_id(proc_id: ProcId) -> Self {
601        Self {
602            proc_name: None,
603            proc_id: Some(proc_id),
604            options: RpcOptionFlags::default(),
605            params: Vec::new(),
606        }
607    }
608
609    /// Create an sp_executesql request.
610    ///
611    /// This is the primary method for parameterized queries.
612    ///
613    /// # Example
614    ///
615    /// ```
616    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
617    ///
618    /// let rpc = RpcRequest::execute_sql(
619    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
620    ///     vec![
621    ///         RpcParam::int("@p1", 42),
622    ///         RpcParam::nvarchar("@p2", "Alice"),
623    ///     ],
624    /// );
625    /// ```
626    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
627        let mut request = Self::by_id(ProcId::ExecuteSql);
628
629        // First parameter: the SQL statement (NVARCHAR(MAX))
630        request.params.push(RpcParam::nvarchar("", sql));
631
632        // Second parameter: parameter declarations
633        if !params.is_empty() {
634            let declarations = Self::build_param_declarations(&params);
635            request.params.push(RpcParam::nvarchar("", &declarations));
636        }
637
638        // Add the actual parameters
639        request.params.extend(params);
640
641        request
642    }
643
644    /// Build parameter declaration string for sp_executesql.
645    fn build_param_declarations(params: &[RpcParam]) -> String {
646        params
647            .iter()
648            .map(|p| {
649                let name = if p.name.starts_with('@') {
650                    p.name.clone()
651                } else if p.name.is_empty() {
652                    // Generate positional name
653                    format!(
654                        "@p{}",
655                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
656                    )
657                } else {
658                    format!("@{}", p.name)
659                };
660
661                let type_name: String = match p.type_info.type_id {
662                    0x26 => match p.type_info.max_length {
663                        Some(1) => "tinyint".to_string(),
664                        Some(2) => "smallint".to_string(),
665                        Some(4) => "int".to_string(),
666                        Some(8) => "bigint".to_string(),
667                        _ => "int".to_string(),
668                    },
669                    0x68 => "bit".to_string(),
670                    0x6D => match p.type_info.max_length {
671                        Some(4) => "real".to_string(),
672                        _ => "float".to_string(),
673                    },
674                    0xE7 => {
675                        if p.type_info.max_length == Some(0xFFFF) {
676                            "nvarchar(max)".to_string()
677                        } else {
678                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
679                            format!("nvarchar({len})")
680                        }
681                    }
682                    0xA5 => {
683                        if p.type_info.max_length == Some(0xFFFF) {
684                            "varbinary(max)".to_string()
685                        } else {
686                            let len = p.type_info.max_length.unwrap_or(8000);
687                            format!("varbinary({len})")
688                        }
689                    }
690                    0x24 => "uniqueidentifier".to_string(),
691                    0x28 => "date".to_string(),
692                    0x2A => {
693                        let scale = p.type_info.scale.unwrap_or(7);
694                        format!("datetime2({scale})")
695                    }
696                    0x6C => {
697                        let precision = p.type_info.precision.unwrap_or(18);
698                        let scale = p.type_info.scale.unwrap_or(0);
699                        format!("decimal({precision}, {scale})")
700                    }
701                    0xF3 => {
702                        // TVP - Table-Valued Parameter
703                        // Must be declared with the table type name and READONLY
704                        if let Some(ref tvp_name) = p.type_info.tvp_type_name {
705                            format!("{tvp_name} READONLY")
706                        } else {
707                            // Fallback if type name is missing (shouldn't happen)
708                            "sql_variant".to_string()
709                        }
710                    }
711                    _ => "sql_variant".to_string(),
712                };
713
714                format!("{name} {type_name}")
715            })
716            .collect::<Vec<_>>()
717            .join(", ")
718    }
719
720    /// Create an sp_prepare request.
721    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
722        let mut request = Self::by_id(ProcId::Prepare);
723
724        // OUT: handle (INT)
725        request
726            .params
727            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
728
729        // Param declarations
730        let declarations = Self::build_param_declarations(params);
731        request
732            .params
733            .push(RpcParam::nvarchar("@params", &declarations));
734
735        // SQL statement
736        request.params.push(RpcParam::nvarchar("@stmt", sql));
737
738        // Options (1 = WITH RECOMPILE)
739        request.params.push(RpcParam::int("@options", 1));
740
741        request
742    }
743
744    /// Create an sp_execute request.
745    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
746        let mut request = Self::by_id(ProcId::Execute);
747
748        // Handle from sp_prepare
749        request.params.push(RpcParam::int("@handle", handle));
750
751        // Add parameters
752        request.params.extend(params);
753
754        request
755    }
756
757    /// Create an sp_unprepare request.
758    pub fn unprepare(handle: i32) -> Self {
759        let mut request = Self::by_id(ProcId::Unprepare);
760        request.params.push(RpcParam::int("@handle", handle));
761        request
762    }
763
764    /// Set option flags.
765    #[must_use]
766    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
767        self.options = options;
768        self
769    }
770
771    /// Add a parameter.
772    #[must_use]
773    pub fn param(mut self, param: RpcParam) -> Self {
774        self.params.push(param);
775        self
776    }
777
778    /// Encode the RPC request to bytes (auto-commit mode).
779    ///
780    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
781    #[must_use]
782    pub fn encode(&self) -> Bytes {
783        self.encode_with_transaction(0)
784    }
785
786    /// Encode the RPC request with a transaction descriptor.
787    ///
788    /// Per MS-TDS spec, when executing within an explicit transaction:
789    /// - The `transaction_descriptor` MUST be the value returned by the server
790    ///   in the BeginTransaction EnvChange token.
791    /// - For auto-commit mode (no explicit transaction), use 0.
792    ///
793    /// # Arguments
794    ///
795    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
796    ///   or 0 for auto-commit mode.
797    #[must_use]
798    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
799        let mut buf = BytesMut::with_capacity(256);
800
801        // ALL_HEADERS - TDS 7.2+ requires this section
802        // Total length placeholder (will be filled in)
803        let all_headers_start = buf.len();
804        buf.put_u32_le(0); // Total length placeholder
805
806        // Transaction descriptor header (required for RPC)
807        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
808        buf.put_u32_le(18); // Header length
809        buf.put_u16_le(0x0002); // Header type: transaction descriptor
810        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
811        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
812
813        // Fill in ALL_HEADERS total length
814        let all_headers_len = buf.len() - all_headers_start;
815        let len_bytes = (all_headers_len as u32).to_le_bytes();
816        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
817
818        // Procedure name or ID
819        if let Some(proc_id) = self.proc_id {
820            // Use PROCID format
821            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
822            buf.put_u16_le(proc_id as u16);
823        } else if let Some(ref proc_name) = self.proc_name {
824            // Use procedure name
825            let name_len = proc_name.encode_utf16().count() as u16;
826            buf.put_u16_le(name_len);
827            write_utf16_string(&mut buf, proc_name);
828        }
829
830        // Option flags
831        buf.put_u16_le(self.options.encode());
832
833        // Parameters
834        for param in &self.params {
835            param.encode(&mut buf);
836        }
837
838        buf.freeze()
839    }
840}
841
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        let expected = [
            (ProcId::ExecuteSql, 0x000A_u16),
            (ProcId::Prepare, 0x000B),
            (ProcId::Execute, 0x000C),
            (ProcId::Unprepare, 0x000F),
        ];
        for (proc_id, code) in expected {
            assert_eq!(proc_id as u16, code);
        }
    }

    #[test]
    fn test_option_flags_encode() {
        let encoded = RpcOptionFlags::new().with_recompile(true).encode();
        assert_eq!(encoded, 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let encoded = ParamFlags::new().output().encode();
        assert_eq!(encoded, 0x01);
    }

    #[test]
    fn test_int_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::int("@p1", 42);
        assert_eq!(name, "@p1");
        assert_eq!(type_info.type_id, 0x26);
        assert!(value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(name, "@name");
        assert_eq!(type_info.type_id, 0xE7);
        // "Alice" is five UTF-16 code units, i.e. ten bytes.
        assert_eq!(value.unwrap().len(), 10);
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );
        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // Statement, declarations, then the single user parameter.
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ]);
        for fragment in ["@p1 int", "@name nvarchar"] {
            assert!(decls.contains(fragment), "`{decls}` lacks `{fragment}`");
        }
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let encoded = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );
        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // @handle (OUTPUT), @params, @stmt, @options.
        assert_eq!(rpc.params.len(), 4);
        let handle = rpc.params.first().unwrap();
        assert!(handle.flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);
        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // The handle followed by the single user parameter.
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);
        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // Only the handle is sent.
        assert_eq!(rpc.params.len(), 1);
    }
}