tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
16//! | ALL_HEADERS       | (TDS 7.2+, optional)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30
/// Identifiers of the stored procedures that can be called by number.
///
/// SQL Server recognizes these special procedure IDs, so a request that uses
/// one of them does not need to carry the procedure name. Each variant's
/// discriminant is the 16-bit ID value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ProcId {
    /// `sp_cursor` (0x0001).
    Cursor = 1,
    /// `sp_cursoropen` (0x0002).
    CursorOpen = 2,
    /// `sp_cursorprepare` (0x0003).
    CursorPrepare = 3,
    /// `sp_cursorexecute` (0x0004).
    CursorExecute = 4,
    /// `sp_cursorprepexec` (0x0005).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` (0x0006).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` (0x0007).
    CursorFetch = 7,
    /// `sp_cursoroption` (0x0008).
    CursorOption = 8,
    /// `sp_cursorclose` (0x0009).
    CursorClose = 9,
    /// `sp_executesql` (0x000A) - primary method for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` (0x000B).
    Prepare = 11,
    /// `sp_execute` (0x000C).
    Execute = 12,
    /// `sp_prepexec` (0x000D) - prepare and execute in one call.
    PrepExec = 13,
    /// `sp_prepexecrpc` (0x000E).
    PrepExecRpc = 14,
    /// `sp_unprepare` (0x000F).
    Unprepare = 15,
}
69
/// Option flags of an RPC request.
///
/// Every flag is off by default. [`RpcOptionFlags::encode`] packs them into
/// the 2-byte option field of the request.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (bit `0x0001`).
    pub with_recompile: bool,
    /// No metadata in the response (bit `0x0002`).
    pub no_metadata: bool,
    /// Reuse metadata from a previous call (bit `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Returns a flag set with every option disabled.
    pub fn new() -> Self {
        Self {
            with_recompile: false,
            no_metadata: false,
            reuse_metadata: false,
        }
    }

    /// Returns these flags with the recompile option set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the flags into their wire representation (2 bytes).
    pub fn encode(&self) -> u16 {
        // Each boolean becomes 0 or 1 and is shifted to its bit position.
        u16::from(self.with_recompile)
            | (u16::from(self.no_metadata) << 1)
            | (u16::from(self.reuse_metadata) << 2)
    }
}
109
/// Status flags of a single RPC parameter.
///
/// Every flag is off by default. [`ParamFlags::encode`] packs them into the
/// 1-byte status field that follows the parameter name.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Parameter is passed by reference, i.e. an OUTPUT parameter (bit `0x01`).
    pub by_ref: bool,
    /// Parameter has a default value (bit `0x02`).
    pub default: bool,
    /// Parameter is encrypted, Always Encrypted (bit `0x08`).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with every flag cleared.
    pub fn new() -> Self {
        Self {
            by_ref: false,
            default: false,
            encrypted: false,
        }
    }

    /// Returns these flags with the by-reference (OUTPUT) flag set.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their wire representation (1 byte).
    pub fn encode(&self) -> u8 {
        // Bit 0x04 is not used by any flag, hence the shift by 3 for `encrypted`.
        u8::from(self.by_ref) | (u8::from(self.default) << 1) | (u8::from(self.encrypted) << 3)
    }
}
149
150/// TDS type information for RPC parameters.
151#[derive(Debug, Clone)]
152pub struct TypeInfo {
153    /// Type ID.
154    pub type_id: u8,
155    /// Maximum length for variable-length types.
156    pub max_length: Option<u16>,
157    /// Precision for numeric types.
158    pub precision: Option<u8>,
159    /// Scale for numeric types.
160    pub scale: Option<u8>,
161    /// Collation for string types.
162    pub collation: Option<[u8; 5]>,
163    /// TVP type name (e.g., "dbo.IntIdList") for Table-Valued Parameters.
164    pub tvp_type_name: Option<String>,
165}
166
167impl TypeInfo {
168    /// Create type info for INT.
169    pub fn int() -> Self {
170        Self {
171            type_id: 0x26, // INTNTYPE (variable-length int)
172            max_length: Some(4),
173            precision: None,
174            scale: None,
175            collation: None,
176            tvp_type_name: None,
177        }
178    }
179
180    /// Create type info for BIGINT.
181    pub fn bigint() -> Self {
182        Self {
183            type_id: 0x26, // INTNTYPE
184            max_length: Some(8),
185            precision: None,
186            scale: None,
187            collation: None,
188            tvp_type_name: None,
189        }
190    }
191
192    /// Create type info for SMALLINT.
193    pub fn smallint() -> Self {
194        Self {
195            type_id: 0x26, // INTNTYPE
196            max_length: Some(2),
197            precision: None,
198            scale: None,
199            collation: None,
200            tvp_type_name: None,
201        }
202    }
203
204    /// Create type info for TINYINT.
205    pub fn tinyint() -> Self {
206        Self {
207            type_id: 0x26, // INTNTYPE
208            max_length: Some(1),
209            precision: None,
210            scale: None,
211            collation: None,
212            tvp_type_name: None,
213        }
214    }
215
216    /// Create type info for BIT.
217    pub fn bit() -> Self {
218        Self {
219            type_id: 0x68, // BITNTYPE
220            max_length: Some(1),
221            precision: None,
222            scale: None,
223            collation: None,
224            tvp_type_name: None,
225        }
226    }
227
228    /// Create type info for FLOAT.
229    pub fn float() -> Self {
230        Self {
231            type_id: 0x6D, // FLTNTYPE
232            max_length: Some(8),
233            precision: None,
234            scale: None,
235            collation: None,
236            tvp_type_name: None,
237        }
238    }
239
240    /// Create type info for REAL.
241    pub fn real() -> Self {
242        Self {
243            type_id: 0x6D, // FLTNTYPE
244            max_length: Some(4),
245            precision: None,
246            scale: None,
247            collation: None,
248            tvp_type_name: None,
249        }
250    }
251
252    /// Create type info for NVARCHAR with max length.
253    pub fn nvarchar(max_len: u16) -> Self {
254        Self {
255            type_id: 0xE7,                 // NVARCHARTYPE
256            max_length: Some(max_len * 2), // UTF-16, so double the char count
257            precision: None,
258            scale: None,
259            // Default collation (Latin1_General_CI_AS equivalent)
260            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
261            tvp_type_name: None,
262        }
263    }
264
265    /// Create type info for NVARCHAR(MAX).
266    pub fn nvarchar_max() -> Self {
267        Self {
268            type_id: 0xE7,            // NVARCHARTYPE
269            max_length: Some(0xFFFF), // MAX indicator
270            precision: None,
271            scale: None,
272            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
273            tvp_type_name: None,
274        }
275    }
276
277    /// Create type info for VARBINARY with max length.
278    pub fn varbinary(max_len: u16) -> Self {
279        Self {
280            type_id: 0xA5, // BIGVARBINTYPE
281            max_length: Some(max_len),
282            precision: None,
283            scale: None,
284            collation: None,
285            tvp_type_name: None,
286        }
287    }
288
289    /// Create type info for UNIQUEIDENTIFIER.
290    pub fn uniqueidentifier() -> Self {
291        Self {
292            type_id: 0x24, // GUIDTYPE
293            max_length: Some(16),
294            precision: None,
295            scale: None,
296            collation: None,
297            tvp_type_name: None,
298        }
299    }
300
301    /// Create type info for DATE.
302    pub fn date() -> Self {
303        Self {
304            type_id: 0x28, // DATETYPE
305            max_length: None,
306            precision: None,
307            scale: None,
308            collation: None,
309            tvp_type_name: None,
310        }
311    }
312
313    /// Create type info for DATETIME2.
314    pub fn datetime2(scale: u8) -> Self {
315        Self {
316            type_id: 0x2A, // DATETIME2TYPE
317            max_length: None,
318            precision: None,
319            scale: Some(scale),
320            collation: None,
321            tvp_type_name: None,
322        }
323    }
324
325    /// Create type info for DECIMAL.
326    pub fn decimal(precision: u8, scale: u8) -> Self {
327        Self {
328            type_id: 0x6C,        // DECIMALNTYPE
329            max_length: Some(17), // Max decimal size
330            precision: Some(precision),
331            scale: Some(scale),
332            collation: None,
333            tvp_type_name: None,
334        }
335    }
336
337    /// Create type info for a Table-Valued Parameter.
338    ///
339    /// # Arguments
340    /// * `type_name` - The fully qualified table type name (e.g., "dbo.IntIdList")
341    pub fn tvp(type_name: impl Into<String>) -> Self {
342        Self {
343            type_id: 0xF3, // TVP type
344            max_length: None,
345            precision: None,
346            scale: None,
347            collation: None,
348            tvp_type_name: Some(type_name.into()),
349        }
350    }
351
352    /// Encode type info to buffer.
353    pub fn encode(&self, buf: &mut BytesMut) {
354        // TVP (0xF3) has type_id embedded in the value data itself
355        // (written by TvpEncoder::encode_metadata), so don't write it here
356        if self.type_id != 0xF3 {
357            buf.put_u8(self.type_id);
358        }
359
360        // Variable-length types need max length
361        match self.type_id {
362            0x26 | 0x68 | 0x6D => {
363                // INTNTYPE, BITNTYPE, FLTNTYPE
364                if let Some(len) = self.max_length {
365                    buf.put_u8(len as u8);
366                }
367            }
368            0xE7 | 0xA5 | 0xEF => {
369                // NVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
370                if let Some(len) = self.max_length {
371                    buf.put_u16_le(len);
372                }
373                // Collation for string types
374                if let Some(collation) = self.collation {
375                    buf.put_slice(&collation);
376                }
377            }
378            0x24 => {
379                // GUIDTYPE
380                if let Some(len) = self.max_length {
381                    buf.put_u8(len as u8);
382                }
383            }
384            0x29..=0x2B => {
385                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
386                if let Some(scale) = self.scale {
387                    buf.put_u8(scale);
388                }
389            }
390            0x6C | 0x6A => {
391                // DECIMALNTYPE, NUMERICNTYPE
392                if let Some(len) = self.max_length {
393                    buf.put_u8(len as u8);
394                }
395                if let Some(precision) = self.precision {
396                    buf.put_u8(precision);
397                }
398                if let Some(scale) = self.scale {
399                    buf.put_u8(scale);
400                }
401            }
402            _ => {}
403        }
404    }
405}
406
407/// An RPC parameter.
408#[derive(Debug, Clone)]
409pub struct RpcParam {
410    /// Parameter name (can be empty for positional params).
411    pub name: String,
412    /// Status flags.
413    pub flags: ParamFlags,
414    /// Type information.
415    pub type_info: TypeInfo,
416    /// Parameter value (raw bytes).
417    pub value: Option<Bytes>,
418}
419
420impl RpcParam {
421    /// Create a new parameter with a value.
422    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
423        Self {
424            name: name.into(),
425            flags: ParamFlags::default(),
426            type_info,
427            value: Some(value),
428        }
429    }
430
431    /// Create a NULL parameter.
432    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
433        Self {
434            name: name.into(),
435            flags: ParamFlags::default(),
436            type_info,
437            value: None,
438        }
439    }
440
441    /// Create an INT parameter.
442    pub fn int(name: impl Into<String>, value: i32) -> Self {
443        let mut buf = BytesMut::with_capacity(4);
444        buf.put_i32_le(value);
445        Self::new(name, TypeInfo::int(), buf.freeze())
446    }
447
448    /// Create a BIGINT parameter.
449    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
450        let mut buf = BytesMut::with_capacity(8);
451        buf.put_i64_le(value);
452        Self::new(name, TypeInfo::bigint(), buf.freeze())
453    }
454
455    /// Create an NVARCHAR parameter.
456    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
457        let mut buf = BytesMut::new();
458        // Encode as UTF-16LE
459        for code_unit in value.encode_utf16() {
460            buf.put_u16_le(code_unit);
461        }
462        let char_len = value.chars().count();
463        let type_info = if char_len > 4000 {
464            TypeInfo::nvarchar_max()
465        } else {
466            TypeInfo::nvarchar(char_len.max(1) as u16)
467        };
468        Self::new(name, type_info, buf.freeze())
469    }
470
471    /// Mark as output parameter.
472    #[must_use]
473    pub fn as_output(mut self) -> Self {
474        self.flags = self.flags.output();
475        self
476    }
477
478    /// Encode the parameter to buffer.
479    pub fn encode(&self, buf: &mut BytesMut) {
480        // Parameter name (B_VARCHAR - length-prefixed)
481        let name_len = self.name.encode_utf16().count() as u8;
482        buf.put_u8(name_len);
483        if name_len > 0 {
484            for code_unit in self.name.encode_utf16() {
485                buf.put_u16_le(code_unit);
486            }
487        }
488
489        // Status flags
490        buf.put_u8(self.flags.encode());
491
492        // Type info
493        self.type_info.encode(buf);
494
495        // Value
496        if let Some(ref value) = self.value {
497            // Length prefix based on type
498            match self.type_info.type_id {
499                0x26 => {
500                    // INTNTYPE
501                    buf.put_u8(value.len() as u8);
502                    buf.put_slice(value);
503                }
504                0x68 | 0x6D => {
505                    // BITNTYPE, FLTNTYPE
506                    buf.put_u8(value.len() as u8);
507                    buf.put_slice(value);
508                }
509                0xE7 | 0xA5 => {
510                    // NVARCHARTYPE, BIGVARBINTYPE
511                    if self.type_info.max_length == Some(0xFFFF) {
512                        // MAX type - use PLP format
513                        // For simplicity, send as single chunk
514                        let total_len = value.len() as u64;
515                        buf.put_u64_le(total_len);
516                        buf.put_u32_le(value.len() as u32);
517                        buf.put_slice(value);
518                        buf.put_u32_le(0); // Terminator
519                    } else {
520                        buf.put_u16_le(value.len() as u16);
521                        buf.put_slice(value);
522                    }
523                }
524                0x24 => {
525                    // GUIDTYPE
526                    buf.put_u8(value.len() as u8);
527                    buf.put_slice(value);
528                }
529                0x28 => {
530                    // DATETYPE (fixed 3 bytes)
531                    buf.put_slice(value);
532                }
533                0x2A => {
534                    // DATETIME2TYPE
535                    buf.put_u8(value.len() as u8);
536                    buf.put_slice(value);
537                }
538                0x6C => {
539                    // DECIMALNTYPE
540                    buf.put_u8(value.len() as u8);
541                    buf.put_slice(value);
542                }
543                0xF3 => {
544                    // TVP (Table-Valued Parameter)
545                    // TVP values are self-delimiting: they contain complete metadata,
546                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
547                    buf.put_slice(value);
548                }
549                _ => {
550                    // Generic: assume length-prefixed
551                    buf.put_u8(value.len() as u8);
552                    buf.put_slice(value);
553                }
554            }
555        } else {
556            // NULL value
557            match self.type_info.type_id {
558                0xE7 | 0xA5 => {
559                    // Variable-length types use 0xFFFF for NULL
560                    if self.type_info.max_length == Some(0xFFFF) {
561                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
562                    } else {
563                        buf.put_u16_le(0xFFFF);
564                    }
565                }
566                _ => {
567                    buf.put_u8(0); // Zero-length for NULL
568                }
569            }
570        }
571    }
572}
573
574/// RPC request builder.
575#[derive(Debug, Clone)]
576pub struct RpcRequest {
577    /// Procedure name (if using named procedure).
578    proc_name: Option<String>,
579    /// Procedure ID (if using well-known procedure).
580    proc_id: Option<ProcId>,
581    /// Option flags.
582    options: RpcOptionFlags,
583    /// Parameters.
584    params: Vec<RpcParam>,
585}
586
587impl RpcRequest {
588    /// Create a new RPC request for a named procedure.
589    pub fn named(proc_name: impl Into<String>) -> Self {
590        Self {
591            proc_name: Some(proc_name.into()),
592            proc_id: None,
593            options: RpcOptionFlags::default(),
594            params: Vec::new(),
595        }
596    }
597
598    /// Create a new RPC request for a well-known procedure.
599    pub fn by_id(proc_id: ProcId) -> Self {
600        Self {
601            proc_name: None,
602            proc_id: Some(proc_id),
603            options: RpcOptionFlags::default(),
604            params: Vec::new(),
605        }
606    }
607
608    /// Create an sp_executesql request.
609    ///
610    /// This is the primary method for parameterized queries.
611    ///
612    /// # Example
613    ///
614    /// ```
615    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
616    ///
617    /// let rpc = RpcRequest::execute_sql(
618    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
619    ///     vec![
620    ///         RpcParam::int("@p1", 42),
621    ///         RpcParam::nvarchar("@p2", "Alice"),
622    ///     ],
623    /// );
624    /// ```
625    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
626        let mut request = Self::by_id(ProcId::ExecuteSql);
627
628        // First parameter: the SQL statement (NVARCHAR(MAX))
629        request.params.push(RpcParam::nvarchar("", sql));
630
631        // Second parameter: parameter declarations
632        if !params.is_empty() {
633            let declarations = Self::build_param_declarations(&params);
634            request.params.push(RpcParam::nvarchar("", &declarations));
635        }
636
637        // Add the actual parameters
638        request.params.extend(params);
639
640        request
641    }
642
643    /// Build parameter declaration string for sp_executesql.
644    fn build_param_declarations(params: &[RpcParam]) -> String {
645        params
646            .iter()
647            .map(|p| {
648                let name = if p.name.starts_with('@') {
649                    p.name.clone()
650                } else if p.name.is_empty() {
651                    // Generate positional name
652                    format!(
653                        "@p{}",
654                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
655                    )
656                } else {
657                    format!("@{}", p.name)
658                };
659
660                let type_name: String = match p.type_info.type_id {
661                    0x26 => match p.type_info.max_length {
662                        Some(1) => "tinyint".to_string(),
663                        Some(2) => "smallint".to_string(),
664                        Some(4) => "int".to_string(),
665                        Some(8) => "bigint".to_string(),
666                        _ => "int".to_string(),
667                    },
668                    0x68 => "bit".to_string(),
669                    0x6D => match p.type_info.max_length {
670                        Some(4) => "real".to_string(),
671                        _ => "float".to_string(),
672                    },
673                    0xE7 => {
674                        if p.type_info.max_length == Some(0xFFFF) {
675                            "nvarchar(max)".to_string()
676                        } else {
677                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
678                            format!("nvarchar({})", len)
679                        }
680                    }
681                    0xA5 => {
682                        if p.type_info.max_length == Some(0xFFFF) {
683                            "varbinary(max)".to_string()
684                        } else {
685                            let len = p.type_info.max_length.unwrap_or(8000);
686                            format!("varbinary({})", len)
687                        }
688                    }
689                    0x24 => "uniqueidentifier".to_string(),
690                    0x28 => "date".to_string(),
691                    0x2A => {
692                        let scale = p.type_info.scale.unwrap_or(7);
693                        format!("datetime2({})", scale)
694                    }
695                    0x6C => {
696                        let precision = p.type_info.precision.unwrap_or(18);
697                        let scale = p.type_info.scale.unwrap_or(0);
698                        format!("decimal({}, {})", precision, scale)
699                    }
700                    0xF3 => {
701                        // TVP - Table-Valued Parameter
702                        // Must be declared with the table type name and READONLY
703                        if let Some(ref tvp_name) = p.type_info.tvp_type_name {
704                            format!("{} READONLY", tvp_name)
705                        } else {
706                            // Fallback if type name is missing (shouldn't happen)
707                            "sql_variant".to_string()
708                        }
709                    }
710                    _ => "sql_variant".to_string(),
711                };
712
713                format!("{} {}", name, type_name)
714            })
715            .collect::<Vec<_>>()
716            .join(", ")
717    }
718
719    /// Create an sp_prepare request.
720    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
721        let mut request = Self::by_id(ProcId::Prepare);
722
723        // OUT: handle (INT)
724        request
725            .params
726            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
727
728        // Param declarations
729        let declarations = Self::build_param_declarations(params);
730        request
731            .params
732            .push(RpcParam::nvarchar("@params", &declarations));
733
734        // SQL statement
735        request.params.push(RpcParam::nvarchar("@stmt", sql));
736
737        // Options (1 = WITH RECOMPILE)
738        request.params.push(RpcParam::int("@options", 1));
739
740        request
741    }
742
743    /// Create an sp_execute request.
744    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
745        let mut request = Self::by_id(ProcId::Execute);
746
747        // Handle from sp_prepare
748        request.params.push(RpcParam::int("@handle", handle));
749
750        // Add parameters
751        request.params.extend(params);
752
753        request
754    }
755
756    /// Create an sp_unprepare request.
757    pub fn unprepare(handle: i32) -> Self {
758        let mut request = Self::by_id(ProcId::Unprepare);
759        request.params.push(RpcParam::int("@handle", handle));
760        request
761    }
762
763    /// Set option flags.
764    #[must_use]
765    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
766        self.options = options;
767        self
768    }
769
770    /// Add a parameter.
771    #[must_use]
772    pub fn param(mut self, param: RpcParam) -> Self {
773        self.params.push(param);
774        self
775    }
776
777    /// Encode the RPC request to bytes (auto-commit mode).
778    ///
779    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
780    #[must_use]
781    pub fn encode(&self) -> Bytes {
782        self.encode_with_transaction(0)
783    }
784
785    /// Encode the RPC request with a transaction descriptor.
786    ///
787    /// Per MS-TDS spec, when executing within an explicit transaction:
788    /// - The `transaction_descriptor` MUST be the value returned by the server
789    ///   in the BeginTransaction EnvChange token.
790    /// - For auto-commit mode (no explicit transaction), use 0.
791    ///
792    /// # Arguments
793    ///
794    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
795    ///   or 0 for auto-commit mode.
796    #[must_use]
797    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
798        let mut buf = BytesMut::with_capacity(256);
799
800        // ALL_HEADERS - TDS 7.2+ requires this section
801        // Total length placeholder (will be filled in)
802        let all_headers_start = buf.len();
803        buf.put_u32_le(0); // Total length placeholder
804
805        // Transaction descriptor header (required for RPC)
806        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
807        buf.put_u32_le(18); // Header length
808        buf.put_u16_le(0x0002); // Header type: transaction descriptor
809        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
810        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
811
812        // Fill in ALL_HEADERS total length
813        let all_headers_len = buf.len() - all_headers_start;
814        let len_bytes = (all_headers_len as u32).to_le_bytes();
815        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
816
817        // Procedure name or ID
818        if let Some(proc_id) = self.proc_id {
819            // Use PROCID format
820            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
821            buf.put_u16_le(proc_id as u16);
822        } else if let Some(ref proc_name) = self.proc_name {
823            // Use procedure name
824            let name_len = proc_name.encode_utf16().count() as u16;
825            buf.put_u16_le(name_len);
826            write_utf16_string(&mut buf, proc_name);
827        }
828
829        // Option flags
830        buf.put_u16_le(self.options.encode());
831
832        // Parameters
833        for param in &self.params {
834            param.encode(&mut buf);
835        }
836
837        buf.freeze()
838    }
839}
840
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Well-known procedure IDs keep their wire values.
    #[test]
    fn test_proc_id_values() {
        let expected = [
            (ProcId::ExecuteSql, 0x000A_u16),
            (ProcId::Prepare, 0x000B),
            (ProcId::Execute, 0x000C),
            (ProcId::Unprepare, 0x000F),
        ];
        for &(proc_id, wire_value) in expected.iter() {
            assert_eq!(proc_id as u16, wire_value);
        }
    }

    /// The recompile option maps to the lowest bit.
    #[test]
    fn test_option_flags_encode() {
        let encoded = RpcOptionFlags::new().with_recompile(true).encode();
        assert_eq!(encoded, 0x0001);
    }

    /// An output parameter sets the by-reference bit.
    #[test]
    fn test_param_flags_encode() {
        let encoded = ParamFlags::new().output().encode();
        assert_eq!(encoded, 0x01);
    }

    /// `RpcParam::int` keeps the name, uses INTNTYPE and carries a value.
    #[test]
    fn test_int_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::int("@p1", 42);
        assert_eq!(name, "@p1");
        assert_eq!(type_info.type_id, 0x26);
        assert!(value.is_some());
    }

    /// `RpcParam::nvarchar` encodes the value as UTF-16.
    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // UTF-16 encoded "Alice" = 10 bytes
        let value = param.value.as_ref().unwrap();
        assert_eq!(value.len(), 10);
    }

    /// sp_executesql carries the statement, the declarations and the params.
    #[test]
    fn test_execute_sql_request() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::execute_sql(sql, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // SQL statement + param declarations + actual params
        assert_eq!(rpc.params.len(), 3);
    }

    /// Declarations contain each parameter name followed by its SQL type.
    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        for fragment in ["@p1 int", "@name nvarchar"].iter() {
            assert!(decls.contains(fragment));
        }
    }

    /// Encoding a request always produces some bytes.
    #[test]
    fn test_rpc_encode_not_empty() {
        let encoded = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(!encoded.is_empty());
    }

    /// sp_prepare sends handle, declarations, statement and options.
    #[test]
    fn test_prepare_request() {
        let declared = [RpcParam::int("@p1", 0)];
        let rpc = RpcRequest::prepare("SELECT * FROM users WHERE id = @p1", &declared);

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // handle (output), params, stmt, options
        assert_eq!(rpc.params.len(), 4);
        // handle is OUTPUT
        assert!(rpc.params[0].flags.by_ref);
    }

    /// sp_execute sends the handle followed by the parameters.
    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // handle + param
        assert_eq!(rpc.params.len(), 2);
    }

    /// sp_unprepare sends only the handle.
    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // just the handle
        assert_eq!(rpc.params.len(), 1);
    }
}