tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
//! | ALL_HEADERS       | (TDS 7.2+; always written by this encoder)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30
/// Numeric identifiers of the system stored procedures that SQL Server can
/// invoke without a name.
///
/// An RPC request targeting one of these procedures sends the `0xFFFF`
/// name-length marker followed by this ID instead of the UTF-16 procedure name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ProcId {
    /// `sp_cursor` (ID 1).
    Cursor = 1,
    /// `sp_cursoropen` (ID 2).
    CursorOpen = 2,
    /// `sp_cursorprepare` (ID 3).
    CursorPrepare = 3,
    /// `sp_cursorexecute` (ID 4).
    CursorExecute = 4,
    /// `sp_cursorprepexec` (ID 5).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` (ID 6).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` (ID 7).
    CursorFetch = 7,
    /// `sp_cursoroption` (ID 8).
    CursorOption = 8,
    /// `sp_cursorclose` (ID 9).
    CursorClose = 9,
    /// `sp_executesql` (ID 10), the usual entry point for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` (ID 11).
    Prepare = 11,
    /// `sp_execute` (ID 12).
    Execute = 12,
    /// `sp_prepexec` (ID 13), which prepares and executes in a single call.
    PrepExec = 13,
    /// `sp_prepexecrpc` (ID 14).
    PrepExecRpc = 14,
    /// `sp_unprepare` (ID 15).
    Unprepare = 15,
}
69
/// RPC option flags.
///
/// Sent as a 2-byte bit set directly after the procedure name/ID in an RPC
/// request; see [`RpcOptionFlags::encode`] for the bit assignments.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (bit `0x0001`).
    pub with_recompile: bool,
    /// No metadata in response (bit `0x0002`).
    pub no_metadata: bool,
    /// Reuse metadata from previous call (bit `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Bit set when `with_recompile` is true.
    const WITH_RECOMPILE: u16 = 0x0001;
    /// Bit set when `no_metadata` is true.
    const NO_METADATA: u16 = 0x0002;
    /// Bit set when `reuse_metadata` is true.
    const REUSE_METADATA: u16 = 0x0004;

    /// Create new empty flags (no bits set).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the with-recompile flag.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Set the no-metadata flag.
    ///
    /// Builder counterpart of the public `no_metadata` field, so all three flags
    /// can be configured fluently like `with_recompile`.
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Set the reuse-metadata flag.
    ///
    /// Builder counterpart of the public `reuse_metadata` field.
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Encode to wire format (2 bytes).
    ///
    /// Returns the OR of the bits for every flag that is set; all other bits are zero.
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA;
        }
        flags
    }
}
109
/// Per-parameter status byte of an RPC request.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Parameter is passed by reference (OUTPUT parameter).
    pub by_ref: bool,
    /// Parameter has a default value.
    pub default: bool,
    /// Parameter is encrypted (Always Encrypted).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Flags with every bit cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return a copy of these flags marked as an OUTPUT (by-reference) parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Pack the flags into the single status byte sent on the wire.
    ///
    /// Bit 0 is by-reference, bit 1 is default value, bit 3 is encrypted; all
    /// other bits stay zero.
    pub fn encode(&self) -> u8 {
        u8::from(self.by_ref) | (u8::from(self.default) << 1) | (u8::from(self.encrypted) << 3)
    }
}
149
150/// TDS type information for RPC parameters.
151#[derive(Debug, Clone)]
152pub struct TypeInfo {
153    /// Type ID.
154    pub type_id: u8,
155    /// Maximum length for variable-length types.
156    pub max_length: Option<u16>,
157    /// Precision for numeric types.
158    pub precision: Option<u8>,
159    /// Scale for numeric types.
160    pub scale: Option<u8>,
161    /// Collation for string types.
162    pub collation: Option<[u8; 5]>,
163}
164
165impl TypeInfo {
166    /// Create type info for INT.
167    pub fn int() -> Self {
168        Self {
169            type_id: 0x26, // INTNTYPE (variable-length int)
170            max_length: Some(4),
171            precision: None,
172            scale: None,
173            collation: None,
174        }
175    }
176
177    /// Create type info for BIGINT.
178    pub fn bigint() -> Self {
179        Self {
180            type_id: 0x26, // INTNTYPE
181            max_length: Some(8),
182            precision: None,
183            scale: None,
184            collation: None,
185        }
186    }
187
188    /// Create type info for SMALLINT.
189    pub fn smallint() -> Self {
190        Self {
191            type_id: 0x26, // INTNTYPE
192            max_length: Some(2),
193            precision: None,
194            scale: None,
195            collation: None,
196        }
197    }
198
199    /// Create type info for TINYINT.
200    pub fn tinyint() -> Self {
201        Self {
202            type_id: 0x26, // INTNTYPE
203            max_length: Some(1),
204            precision: None,
205            scale: None,
206            collation: None,
207        }
208    }
209
210    /// Create type info for BIT.
211    pub fn bit() -> Self {
212        Self {
213            type_id: 0x68, // BITNTYPE
214            max_length: Some(1),
215            precision: None,
216            scale: None,
217            collation: None,
218        }
219    }
220
221    /// Create type info for FLOAT.
222    pub fn float() -> Self {
223        Self {
224            type_id: 0x6D, // FLTNTYPE
225            max_length: Some(8),
226            precision: None,
227            scale: None,
228            collation: None,
229        }
230    }
231
232    /// Create type info for REAL.
233    pub fn real() -> Self {
234        Self {
235            type_id: 0x6D, // FLTNTYPE
236            max_length: Some(4),
237            precision: None,
238            scale: None,
239            collation: None,
240        }
241    }
242
243    /// Create type info for NVARCHAR with max length.
244    pub fn nvarchar(max_len: u16) -> Self {
245        Self {
246            type_id: 0xE7,                 // NVARCHARTYPE
247            max_length: Some(max_len * 2), // UTF-16, so double the char count
248            precision: None,
249            scale: None,
250            // Default collation (Latin1_General_CI_AS equivalent)
251            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
252        }
253    }
254
255    /// Create type info for NVARCHAR(MAX).
256    pub fn nvarchar_max() -> Self {
257        Self {
258            type_id: 0xE7,            // NVARCHARTYPE
259            max_length: Some(0xFFFF), // MAX indicator
260            precision: None,
261            scale: None,
262            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
263        }
264    }
265
266    /// Create type info for VARBINARY with max length.
267    pub fn varbinary(max_len: u16) -> Self {
268        Self {
269            type_id: 0xA5, // BIGVARBINTYPE
270            max_length: Some(max_len),
271            precision: None,
272            scale: None,
273            collation: None,
274        }
275    }
276
277    /// Create type info for UNIQUEIDENTIFIER.
278    pub fn uniqueidentifier() -> Self {
279        Self {
280            type_id: 0x24, // GUIDTYPE
281            max_length: Some(16),
282            precision: None,
283            scale: None,
284            collation: None,
285        }
286    }
287
288    /// Create type info for DATE.
289    pub fn date() -> Self {
290        Self {
291            type_id: 0x28, // DATETYPE
292            max_length: None,
293            precision: None,
294            scale: None,
295            collation: None,
296        }
297    }
298
299    /// Create type info for DATETIME2.
300    pub fn datetime2(scale: u8) -> Self {
301        Self {
302            type_id: 0x2A, // DATETIME2TYPE
303            max_length: None,
304            precision: None,
305            scale: Some(scale),
306            collation: None,
307        }
308    }
309
310    /// Create type info for DECIMAL.
311    pub fn decimal(precision: u8, scale: u8) -> Self {
312        Self {
313            type_id: 0x6C,        // DECIMALNTYPE
314            max_length: Some(17), // Max decimal size
315            precision: Some(precision),
316            scale: Some(scale),
317            collation: None,
318        }
319    }
320
321    /// Encode type info to buffer.
322    pub fn encode(&self, buf: &mut BytesMut) {
323        // TVP (0xF3) has type_id embedded in the value data itself
324        // (written by TvpEncoder::encode_metadata), so don't write it here
325        if self.type_id != 0xF3 {
326            buf.put_u8(self.type_id);
327        }
328
329        // Variable-length types need max length
330        match self.type_id {
331            0x26 | 0x68 | 0x6D => {
332                // INTNTYPE, BITNTYPE, FLTNTYPE
333                if let Some(len) = self.max_length {
334                    buf.put_u8(len as u8);
335                }
336            }
337            0xE7 | 0xA5 | 0xEF => {
338                // NVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
339                if let Some(len) = self.max_length {
340                    buf.put_u16_le(len);
341                }
342                // Collation for string types
343                if let Some(collation) = self.collation {
344                    buf.put_slice(&collation);
345                }
346            }
347            0x24 => {
348                // GUIDTYPE
349                if let Some(len) = self.max_length {
350                    buf.put_u8(len as u8);
351                }
352            }
353            0x29..=0x2B => {
354                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
355                if let Some(scale) = self.scale {
356                    buf.put_u8(scale);
357                }
358            }
359            0x6C | 0x6A => {
360                // DECIMALNTYPE, NUMERICNTYPE
361                if let Some(len) = self.max_length {
362                    buf.put_u8(len as u8);
363                }
364                if let Some(precision) = self.precision {
365                    buf.put_u8(precision);
366                }
367                if let Some(scale) = self.scale {
368                    buf.put_u8(scale);
369                }
370            }
371            _ => {}
372        }
373    }
374}
375
376/// An RPC parameter.
377#[derive(Debug, Clone)]
378pub struct RpcParam {
379    /// Parameter name (can be empty for positional params).
380    pub name: String,
381    /// Status flags.
382    pub flags: ParamFlags,
383    /// Type information.
384    pub type_info: TypeInfo,
385    /// Parameter value (raw bytes).
386    pub value: Option<Bytes>,
387}
388
389impl RpcParam {
390    /// Create a new parameter with a value.
391    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
392        Self {
393            name: name.into(),
394            flags: ParamFlags::default(),
395            type_info,
396            value: Some(value),
397        }
398    }
399
400    /// Create a NULL parameter.
401    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
402        Self {
403            name: name.into(),
404            flags: ParamFlags::default(),
405            type_info,
406            value: None,
407        }
408    }
409
410    /// Create an INT parameter.
411    pub fn int(name: impl Into<String>, value: i32) -> Self {
412        let mut buf = BytesMut::with_capacity(4);
413        buf.put_i32_le(value);
414        Self::new(name, TypeInfo::int(), buf.freeze())
415    }
416
417    /// Create a BIGINT parameter.
418    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
419        let mut buf = BytesMut::with_capacity(8);
420        buf.put_i64_le(value);
421        Self::new(name, TypeInfo::bigint(), buf.freeze())
422    }
423
424    /// Create an NVARCHAR parameter.
425    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
426        let mut buf = BytesMut::new();
427        // Encode as UTF-16LE
428        for code_unit in value.encode_utf16() {
429            buf.put_u16_le(code_unit);
430        }
431        let char_len = value.chars().count();
432        let type_info = if char_len > 4000 {
433            TypeInfo::nvarchar_max()
434        } else {
435            TypeInfo::nvarchar(char_len.max(1) as u16)
436        };
437        Self::new(name, type_info, buf.freeze())
438    }
439
440    /// Mark as output parameter.
441    #[must_use]
442    pub fn as_output(mut self) -> Self {
443        self.flags = self.flags.output();
444        self
445    }
446
447    /// Encode the parameter to buffer.
448    pub fn encode(&self, buf: &mut BytesMut) {
449        // Parameter name (B_VARCHAR - length-prefixed)
450        let name_len = self.name.encode_utf16().count() as u8;
451        buf.put_u8(name_len);
452        if name_len > 0 {
453            for code_unit in self.name.encode_utf16() {
454                buf.put_u16_le(code_unit);
455            }
456        }
457
458        // Status flags
459        buf.put_u8(self.flags.encode());
460
461        // Type info
462        self.type_info.encode(buf);
463
464        // Value
465        if let Some(ref value) = self.value {
466            // Length prefix based on type
467            match self.type_info.type_id {
468                0x26 => {
469                    // INTNTYPE
470                    buf.put_u8(value.len() as u8);
471                    buf.put_slice(value);
472                }
473                0x68 | 0x6D => {
474                    // BITNTYPE, FLTNTYPE
475                    buf.put_u8(value.len() as u8);
476                    buf.put_slice(value);
477                }
478                0xE7 | 0xA5 => {
479                    // NVARCHARTYPE, BIGVARBINTYPE
480                    if self.type_info.max_length == Some(0xFFFF) {
481                        // MAX type - use PLP format
482                        // For simplicity, send as single chunk
483                        let total_len = value.len() as u64;
484                        buf.put_u64_le(total_len);
485                        buf.put_u32_le(value.len() as u32);
486                        buf.put_slice(value);
487                        buf.put_u32_le(0); // Terminator
488                    } else {
489                        buf.put_u16_le(value.len() as u16);
490                        buf.put_slice(value);
491                    }
492                }
493                0x24 => {
494                    // GUIDTYPE
495                    buf.put_u8(value.len() as u8);
496                    buf.put_slice(value);
497                }
498                0x28 => {
499                    // DATETYPE (fixed 3 bytes)
500                    buf.put_slice(value);
501                }
502                0x2A => {
503                    // DATETIME2TYPE
504                    buf.put_u8(value.len() as u8);
505                    buf.put_slice(value);
506                }
507                0x6C => {
508                    // DECIMALNTYPE
509                    buf.put_u8(value.len() as u8);
510                    buf.put_slice(value);
511                }
512                0xF3 => {
513                    // TVP (Table-Valued Parameter)
514                    // TVP values are self-delimiting: they contain complete metadata,
515                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
516                    buf.put_slice(value);
517                }
518                _ => {
519                    // Generic: assume length-prefixed
520                    buf.put_u8(value.len() as u8);
521                    buf.put_slice(value);
522                }
523            }
524        } else {
525            // NULL value
526            match self.type_info.type_id {
527                0xE7 | 0xA5 => {
528                    // Variable-length types use 0xFFFF for NULL
529                    if self.type_info.max_length == Some(0xFFFF) {
530                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
531                    } else {
532                        buf.put_u16_le(0xFFFF);
533                    }
534                }
535                _ => {
536                    buf.put_u8(0); // Zero-length for NULL
537                }
538            }
539        }
540    }
541}
542
543/// RPC request builder.
544#[derive(Debug, Clone)]
545pub struct RpcRequest {
546    /// Procedure name (if using named procedure).
547    proc_name: Option<String>,
548    /// Procedure ID (if using well-known procedure).
549    proc_id: Option<ProcId>,
550    /// Option flags.
551    options: RpcOptionFlags,
552    /// Parameters.
553    params: Vec<RpcParam>,
554}
555
556impl RpcRequest {
557    /// Create a new RPC request for a named procedure.
558    pub fn named(proc_name: impl Into<String>) -> Self {
559        Self {
560            proc_name: Some(proc_name.into()),
561            proc_id: None,
562            options: RpcOptionFlags::default(),
563            params: Vec::new(),
564        }
565    }
566
567    /// Create a new RPC request for a well-known procedure.
568    pub fn by_id(proc_id: ProcId) -> Self {
569        Self {
570            proc_name: None,
571            proc_id: Some(proc_id),
572            options: RpcOptionFlags::default(),
573            params: Vec::new(),
574        }
575    }
576
577    /// Create an sp_executesql request.
578    ///
579    /// This is the primary method for parameterized queries.
580    ///
581    /// # Example
582    ///
583    /// ```
584    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
585    ///
586    /// let rpc = RpcRequest::execute_sql(
587    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
588    ///     vec![
589    ///         RpcParam::int("@p1", 42),
590    ///         RpcParam::nvarchar("@p2", "Alice"),
591    ///     ],
592    /// );
593    /// ```
594    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
595        let mut request = Self::by_id(ProcId::ExecuteSql);
596
597        // First parameter: the SQL statement (NVARCHAR(MAX))
598        request.params.push(RpcParam::nvarchar("", sql));
599
600        // Second parameter: parameter declarations
601        if !params.is_empty() {
602            let declarations = Self::build_param_declarations(&params);
603            request.params.push(RpcParam::nvarchar("", &declarations));
604        }
605
606        // Add the actual parameters
607        request.params.extend(params);
608
609        request
610    }
611
612    /// Build parameter declaration string for sp_executesql.
613    fn build_param_declarations(params: &[RpcParam]) -> String {
614        params
615            .iter()
616            .map(|p| {
617                let name = if p.name.starts_with('@') {
618                    p.name.clone()
619                } else if p.name.is_empty() {
620                    // Generate positional name
621                    format!(
622                        "@p{}",
623                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
624                    )
625                } else {
626                    format!("@{}", p.name)
627                };
628
629                let type_name: String = match p.type_info.type_id {
630                    0x26 => match p.type_info.max_length {
631                        Some(1) => "tinyint".to_string(),
632                        Some(2) => "smallint".to_string(),
633                        Some(4) => "int".to_string(),
634                        Some(8) => "bigint".to_string(),
635                        _ => "int".to_string(),
636                    },
637                    0x68 => "bit".to_string(),
638                    0x6D => match p.type_info.max_length {
639                        Some(4) => "real".to_string(),
640                        _ => "float".to_string(),
641                    },
642                    0xE7 => {
643                        if p.type_info.max_length == Some(0xFFFF) {
644                            "nvarchar(max)".to_string()
645                        } else {
646                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
647                            format!("nvarchar({})", len)
648                        }
649                    }
650                    0xA5 => {
651                        if p.type_info.max_length == Some(0xFFFF) {
652                            "varbinary(max)".to_string()
653                        } else {
654                            let len = p.type_info.max_length.unwrap_or(8000);
655                            format!("varbinary({})", len)
656                        }
657                    }
658                    0x24 => "uniqueidentifier".to_string(),
659                    0x28 => "date".to_string(),
660                    0x2A => {
661                        let scale = p.type_info.scale.unwrap_or(7);
662                        format!("datetime2({})", scale)
663                    }
664                    0x6C => {
665                        let precision = p.type_info.precision.unwrap_or(18);
666                        let scale = p.type_info.scale.unwrap_or(0);
667                        format!("decimal({}, {})", precision, scale)
668                    }
669                    _ => "sql_variant".to_string(),
670                };
671
672                format!("{} {}", name, type_name)
673            })
674            .collect::<Vec<_>>()
675            .join(", ")
676    }
677
678    /// Create an sp_prepare request.
679    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
680        let mut request = Self::by_id(ProcId::Prepare);
681
682        // OUT: handle (INT)
683        request
684            .params
685            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
686
687        // Param declarations
688        let declarations = Self::build_param_declarations(params);
689        request
690            .params
691            .push(RpcParam::nvarchar("@params", &declarations));
692
693        // SQL statement
694        request.params.push(RpcParam::nvarchar("@stmt", sql));
695
696        // Options (1 = WITH RECOMPILE)
697        request.params.push(RpcParam::int("@options", 1));
698
699        request
700    }
701
702    /// Create an sp_execute request.
703    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
704        let mut request = Self::by_id(ProcId::Execute);
705
706        // Handle from sp_prepare
707        request.params.push(RpcParam::int("@handle", handle));
708
709        // Add parameters
710        request.params.extend(params);
711
712        request
713    }
714
715    /// Create an sp_unprepare request.
716    pub fn unprepare(handle: i32) -> Self {
717        let mut request = Self::by_id(ProcId::Unprepare);
718        request.params.push(RpcParam::int("@handle", handle));
719        request
720    }
721
722    /// Set option flags.
723    #[must_use]
724    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
725        self.options = options;
726        self
727    }
728
729    /// Add a parameter.
730    #[must_use]
731    pub fn param(mut self, param: RpcParam) -> Self {
732        self.params.push(param);
733        self
734    }
735
736    /// Encode the RPC request to bytes (auto-commit mode).
737    ///
738    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
739    #[must_use]
740    pub fn encode(&self) -> Bytes {
741        self.encode_with_transaction(0)
742    }
743
744    /// Encode the RPC request with a transaction descriptor.
745    ///
746    /// Per MS-TDS spec, when executing within an explicit transaction:
747    /// - The `transaction_descriptor` MUST be the value returned by the server
748    ///   in the BeginTransaction EnvChange token.
749    /// - For auto-commit mode (no explicit transaction), use 0.
750    ///
751    /// # Arguments
752    ///
753    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
754    ///   or 0 for auto-commit mode.
755    #[must_use]
756    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
757        let mut buf = BytesMut::with_capacity(256);
758
759        // ALL_HEADERS - TDS 7.2+ requires this section
760        // Total length placeholder (will be filled in)
761        let all_headers_start = buf.len();
762        buf.put_u32_le(0); // Total length placeholder
763
764        // Transaction descriptor header (required for RPC)
765        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
766        buf.put_u32_le(18); // Header length
767        buf.put_u16_le(0x0002); // Header type: transaction descriptor
768        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
769        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
770
771        // Fill in ALL_HEADERS total length
772        let all_headers_len = buf.len() - all_headers_start;
773        let len_bytes = (all_headers_len as u32).to_le_bytes();
774        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
775
776        // Procedure name or ID
777        if let Some(proc_id) = self.proc_id {
778            // Use PROCID format
779            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
780            buf.put_u16_le(proc_id as u16);
781        } else if let Some(ref proc_name) = self.proc_name {
782            // Use procedure name
783            let name_len = proc_name.encode_utf16().count() as u16;
784            buf.put_u16_le(name_len);
785            write_utf16_string(&mut buf, proc_name);
786        }
787
788        // Option flags
789        buf.put_u16_le(self.options.encode());
790
791        // Parameters
792        for param in &self.params {
793            param.encode(&mut buf);
794        }
795
796        buf.freeze()
797    }
798}
799
#[cfg(test)]
// Tests may unwrap freely: a panic is the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        let expected: [(ProcId, u16); 4] = [
            (ProcId::ExecuteSql, 0x000A),
            (ProcId::Prepare, 0x000B),
            (ProcId::Execute, 0x000C),
            (ProcId::Unprepare, 0x000F),
        ];
        for (proc_id, id) in expected {
            assert_eq!(proc_id as u16, id);
        }
    }

    #[test]
    fn test_option_flags_encode() {
        let encoded = RpcOptionFlags::new().with_recompile(true).encode();
        assert_eq!(encoded, 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let encoded = ParamFlags::new().output().encode();
        assert_eq!(encoded, 0x01);
    }

    #[test]
    fn test_int_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::int("@p1", 42);
        assert_eq!(name, "@p1");
        assert_eq!(type_info.type_id, 0x26);
        assert!(value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // "Alice" is five UTF-16 code units, i.e. ten bytes.
        let encoded = param.value.as_deref().unwrap();
        assert_eq!(encoded.len(), 10);
    }

    #[test]
    fn test_execute_sql_request() {
        let query = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::execute_sql(query, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // Statement, declaration list, then the single caller-supplied parameter.
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ]);
        for fragment in ["@p1 int", "@name nvarchar"] {
            assert!(decls.contains(fragment));
        }
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let bytes = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let query = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::prepare(query, &[RpcParam::int("@p1", 0)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // Handle (OUTPUT), declarations, statement, options.
        assert_eq!(rpc.params.len(), 4);
        let handle = &rpc.params[0];
        assert!(handle.flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // Handle followed by the single parameter.
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // Only the handle is sent.
        assert_eq!(rpc.params.len(), 1);
    }
}