tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
//! | ALL_HEADERS       | (TDS 7.2+, required)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29
/// Well-known stored procedure IDs.
///
/// Instead of sending a procedure name, an RPC request may name one of these
/// system procedures by number: the encoder writes a `0xFFFF` name length
/// followed by the ID as a little-endian `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ProcId {
    /// `sp_cursor` (ID 1, `0x0001`).
    Cursor = 1,
    /// `sp_cursoropen` (ID 2, `0x0002`).
    CursorOpen = 2,
    /// `sp_cursorprepare` (ID 3, `0x0003`).
    CursorPrepare = 3,
    /// `sp_cursorexecute` (ID 4, `0x0004`).
    CursorExecute = 4,
    /// `sp_cursorprepexec` (ID 5, `0x0005`).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` (ID 6, `0x0006`).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` (ID 7, `0x0007`).
    CursorFetch = 7,
    /// `sp_cursoroption` (ID 8, `0x0008`).
    CursorOption = 8,
    /// `sp_cursorclose` (ID 9, `0x0009`).
    CursorClose = 9,
    /// `sp_executesql` (ID 10, `0x000A`): the main route for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` (ID 11, `0x000B`).
    Prepare = 11,
    /// `sp_execute` (ID 12, `0x000C`).
    Execute = 12,
    /// `sp_prepexec` (ID 13, `0x000D`): prepare and execute in a single call.
    PrepExec = 13,
    /// `sp_prepexecrpc` (ID 14, `0x000E`).
    PrepExecRpc = 14,
    /// `sp_unprepare` (ID 15, `0x000F`).
    Unprepare = 15,
}
68
/// RPC option flags (the 2-byte `OptionFlags` field of an RPC request).
///
/// Every flag has a chaining setter, so a value can be built fluently, e.g.
/// `RpcOptionFlags::new().with_recompile(true).no_metadata(true)`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (bit `0x0001`).
    pub with_recompile: bool,
    /// No metadata in response (bit `0x0002`).
    pub no_metadata: bool,
    /// Reuse metadata from previous call (bit `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Wire bit for [`RpcOptionFlags::with_recompile`].
    const WITH_RECOMPILE: u16 = 0x0001;
    /// Wire bit for [`RpcOptionFlags::no_metadata`].
    const NO_METADATA: u16 = 0x0002;
    /// Wire bit for [`RpcOptionFlags::reuse_metadata`].
    const REUSE_METADATA: u16 = 0x0004;

    /// Create new empty flags (all bits clear).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the with-recompile flag.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Set the no-metadata flag.
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Set the reuse-metadata flag.
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Encode to wire format (2 bytes; written little-endian by the caller).
    #[must_use]
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA;
        }
        flags
    }
}
108
/// Status flags carried by every RPC parameter.
///
/// Packed into a single byte by [`ParamFlags::encode`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Passed by reference, i.e. an OUTPUT parameter (bit `0x01`).
    pub by_ref: bool,
    /// Parameter has a default value (bit `0x02`).
    pub default: bool,
    /// Parameter is encrypted for Always Encrypted (bit `0x08`).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns flags with every bit cleared.
    pub fn new() -> Self {
        Self {
            by_ref: false,
            default: false,
            encrypted: false,
        }
    }

    /// Returns a copy of these flags marked as an OUTPUT parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their one-byte wire representation.
    pub fn encode(&self) -> u8 {
        let by_ref_bit = u8::from(self.by_ref);
        let default_bit = u8::from(self.default) << 1;
        let encrypted_bit = u8::from(self.encrypted) << 3;
        by_ref_bit | default_bit | encrypted_bit
    }
}
148
149/// TDS type information for RPC parameters.
150#[derive(Debug, Clone)]
151pub struct TypeInfo {
152    /// Type ID.
153    pub type_id: u8,
154    /// Maximum length for variable-length types.
155    pub max_length: Option<u16>,
156    /// Precision for numeric types.
157    pub precision: Option<u8>,
158    /// Scale for numeric types.
159    pub scale: Option<u8>,
160    /// Collation for string types.
161    pub collation: Option<[u8; 5]>,
162}
163
164impl TypeInfo {
165    /// Create type info for INT.
166    pub fn int() -> Self {
167        Self {
168            type_id: 0x26, // INTNTYPE (variable-length int)
169            max_length: Some(4),
170            precision: None,
171            scale: None,
172            collation: None,
173        }
174    }
175
176    /// Create type info for BIGINT.
177    pub fn bigint() -> Self {
178        Self {
179            type_id: 0x26, // INTNTYPE
180            max_length: Some(8),
181            precision: None,
182            scale: None,
183            collation: None,
184        }
185    }
186
187    /// Create type info for SMALLINT.
188    pub fn smallint() -> Self {
189        Self {
190            type_id: 0x26, // INTNTYPE
191            max_length: Some(2),
192            precision: None,
193            scale: None,
194            collation: None,
195        }
196    }
197
198    /// Create type info for TINYINT.
199    pub fn tinyint() -> Self {
200        Self {
201            type_id: 0x26, // INTNTYPE
202            max_length: Some(1),
203            precision: None,
204            scale: None,
205            collation: None,
206        }
207    }
208
209    /// Create type info for BIT.
210    pub fn bit() -> Self {
211        Self {
212            type_id: 0x68, // BITNTYPE
213            max_length: Some(1),
214            precision: None,
215            scale: None,
216            collation: None,
217        }
218    }
219
220    /// Create type info for FLOAT.
221    pub fn float() -> Self {
222        Self {
223            type_id: 0x6D, // FLTNTYPE
224            max_length: Some(8),
225            precision: None,
226            scale: None,
227            collation: None,
228        }
229    }
230
231    /// Create type info for REAL.
232    pub fn real() -> Self {
233        Self {
234            type_id: 0x6D, // FLTNTYPE
235            max_length: Some(4),
236            precision: None,
237            scale: None,
238            collation: None,
239        }
240    }
241
242    /// Create type info for NVARCHAR with max length.
243    pub fn nvarchar(max_len: u16) -> Self {
244        Self {
245            type_id: 0xE7,                 // NVARCHARTYPE
246            max_length: Some(max_len * 2), // UTF-16, so double the char count
247            precision: None,
248            scale: None,
249            // Default collation (Latin1_General_CI_AS equivalent)
250            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
251        }
252    }
253
254    /// Create type info for NVARCHAR(MAX).
255    pub fn nvarchar_max() -> Self {
256        Self {
257            type_id: 0xE7,            // NVARCHARTYPE
258            max_length: Some(0xFFFF), // MAX indicator
259            precision: None,
260            scale: None,
261            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
262        }
263    }
264
265    /// Create type info for VARBINARY with max length.
266    pub fn varbinary(max_len: u16) -> Self {
267        Self {
268            type_id: 0xA5, // BIGVARBINTYPE
269            max_length: Some(max_len),
270            precision: None,
271            scale: None,
272            collation: None,
273        }
274    }
275
276    /// Create type info for UNIQUEIDENTIFIER.
277    pub fn uniqueidentifier() -> Self {
278        Self {
279            type_id: 0x24, // GUIDTYPE
280            max_length: Some(16),
281            precision: None,
282            scale: None,
283            collation: None,
284        }
285    }
286
287    /// Create type info for DATE.
288    pub fn date() -> Self {
289        Self {
290            type_id: 0x28, // DATETYPE
291            max_length: None,
292            precision: None,
293            scale: None,
294            collation: None,
295        }
296    }
297
298    /// Create type info for DATETIME2.
299    pub fn datetime2(scale: u8) -> Self {
300        Self {
301            type_id: 0x2A, // DATETIME2TYPE
302            max_length: None,
303            precision: None,
304            scale: Some(scale),
305            collation: None,
306        }
307    }
308
309    /// Create type info for DECIMAL.
310    pub fn decimal(precision: u8, scale: u8) -> Self {
311        Self {
312            type_id: 0x6C,        // DECIMALNTYPE
313            max_length: Some(17), // Max decimal size
314            precision: Some(precision),
315            scale: Some(scale),
316            collation: None,
317        }
318    }
319
320    /// Encode type info to buffer.
321    pub fn encode(&self, buf: &mut BytesMut) {
322        // TVP (0xF3) has type_id embedded in the value data itself
323        // (written by TvpEncoder::encode_metadata), so don't write it here
324        if self.type_id != 0xF3 {
325            buf.put_u8(self.type_id);
326        }
327
328        // Variable-length types need max length
329        match self.type_id {
330            0x26 | 0x68 | 0x6D => {
331                // INTNTYPE, BITNTYPE, FLTNTYPE
332                if let Some(len) = self.max_length {
333                    buf.put_u8(len as u8);
334                }
335            }
336            0xE7 | 0xA5 | 0xEF => {
337                // NVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
338                if let Some(len) = self.max_length {
339                    buf.put_u16_le(len);
340                }
341                // Collation for string types
342                if let Some(collation) = self.collation {
343                    buf.put_slice(&collation);
344                }
345            }
346            0x24 => {
347                // GUIDTYPE
348                if let Some(len) = self.max_length {
349                    buf.put_u8(len as u8);
350                }
351            }
352            0x29..=0x2B => {
353                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
354                if let Some(scale) = self.scale {
355                    buf.put_u8(scale);
356                }
357            }
358            0x6C | 0x6A => {
359                // DECIMALNTYPE, NUMERICNTYPE
360                if let Some(len) = self.max_length {
361                    buf.put_u8(len as u8);
362                }
363                if let Some(precision) = self.precision {
364                    buf.put_u8(precision);
365                }
366                if let Some(scale) = self.scale {
367                    buf.put_u8(scale);
368                }
369            }
370            _ => {}
371        }
372    }
373}
374
375/// An RPC parameter.
376#[derive(Debug, Clone)]
377pub struct RpcParam {
378    /// Parameter name (can be empty for positional params).
379    pub name: String,
380    /// Status flags.
381    pub flags: ParamFlags,
382    /// Type information.
383    pub type_info: TypeInfo,
384    /// Parameter value (raw bytes).
385    pub value: Option<Bytes>,
386}
387
388impl RpcParam {
389    /// Create a new parameter with a value.
390    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
391        Self {
392            name: name.into(),
393            flags: ParamFlags::default(),
394            type_info,
395            value: Some(value),
396        }
397    }
398
399    /// Create a NULL parameter.
400    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
401        Self {
402            name: name.into(),
403            flags: ParamFlags::default(),
404            type_info,
405            value: None,
406        }
407    }
408
409    /// Create an INT parameter.
410    pub fn int(name: impl Into<String>, value: i32) -> Self {
411        let mut buf = BytesMut::with_capacity(4);
412        buf.put_i32_le(value);
413        Self::new(name, TypeInfo::int(), buf.freeze())
414    }
415
416    /// Create a BIGINT parameter.
417    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
418        let mut buf = BytesMut::with_capacity(8);
419        buf.put_i64_le(value);
420        Self::new(name, TypeInfo::bigint(), buf.freeze())
421    }
422
423    /// Create an NVARCHAR parameter.
424    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
425        let mut buf = BytesMut::new();
426        // Encode as UTF-16LE
427        for code_unit in value.encode_utf16() {
428            buf.put_u16_le(code_unit);
429        }
430        let char_len = value.chars().count();
431        let type_info = if char_len > 4000 {
432            TypeInfo::nvarchar_max()
433        } else {
434            TypeInfo::nvarchar(char_len.max(1) as u16)
435        };
436        Self::new(name, type_info, buf.freeze())
437    }
438
439    /// Mark as output parameter.
440    #[must_use]
441    pub fn as_output(mut self) -> Self {
442        self.flags = self.flags.output();
443        self
444    }
445
446    /// Encode the parameter to buffer.
447    pub fn encode(&self, buf: &mut BytesMut) {
448        // Parameter name (B_VARCHAR - length-prefixed)
449        let name_len = self.name.encode_utf16().count() as u8;
450        buf.put_u8(name_len);
451        if name_len > 0 {
452            for code_unit in self.name.encode_utf16() {
453                buf.put_u16_le(code_unit);
454            }
455        }
456
457        // Status flags
458        buf.put_u8(self.flags.encode());
459
460        // Type info
461        self.type_info.encode(buf);
462
463        // Value
464        if let Some(ref value) = self.value {
465            // Length prefix based on type
466            match self.type_info.type_id {
467                0x26 => {
468                    // INTNTYPE
469                    buf.put_u8(value.len() as u8);
470                    buf.put_slice(value);
471                }
472                0x68 | 0x6D => {
473                    // BITNTYPE, FLTNTYPE
474                    buf.put_u8(value.len() as u8);
475                    buf.put_slice(value);
476                }
477                0xE7 | 0xA5 => {
478                    // NVARCHARTYPE, BIGVARBINTYPE
479                    if self.type_info.max_length == Some(0xFFFF) {
480                        // MAX type - use PLP format
481                        // For simplicity, send as single chunk
482                        let total_len = value.len() as u64;
483                        buf.put_u64_le(total_len);
484                        buf.put_u32_le(value.len() as u32);
485                        buf.put_slice(value);
486                        buf.put_u32_le(0); // Terminator
487                    } else {
488                        buf.put_u16_le(value.len() as u16);
489                        buf.put_slice(value);
490                    }
491                }
492                0x24 => {
493                    // GUIDTYPE
494                    buf.put_u8(value.len() as u8);
495                    buf.put_slice(value);
496                }
497                0x28 => {
498                    // DATETYPE (fixed 3 bytes)
499                    buf.put_slice(value);
500                }
501                0x2A => {
502                    // DATETIME2TYPE
503                    buf.put_u8(value.len() as u8);
504                    buf.put_slice(value);
505                }
506                0x6C => {
507                    // DECIMALNTYPE
508                    buf.put_u8(value.len() as u8);
509                    buf.put_slice(value);
510                }
511                0xF3 => {
512                    // TVP (Table-Valued Parameter)
513                    // TVP values are self-delimiting: they contain complete metadata,
514                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
515                    buf.put_slice(value);
516                }
517                _ => {
518                    // Generic: assume length-prefixed
519                    buf.put_u8(value.len() as u8);
520                    buf.put_slice(value);
521                }
522            }
523        } else {
524            // NULL value
525            match self.type_info.type_id {
526                0xE7 | 0xA5 => {
527                    // Variable-length types use 0xFFFF for NULL
528                    if self.type_info.max_length == Some(0xFFFF) {
529                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
530                    } else {
531                        buf.put_u16_le(0xFFFF);
532                    }
533                }
534                _ => {
535                    buf.put_u8(0); // Zero-length for NULL
536                }
537            }
538        }
539    }
540}
541
542/// RPC request builder.
543#[derive(Debug, Clone)]
544pub struct RpcRequest {
545    /// Procedure name (if using named procedure).
546    proc_name: Option<String>,
547    /// Procedure ID (if using well-known procedure).
548    proc_id: Option<ProcId>,
549    /// Option flags.
550    options: RpcOptionFlags,
551    /// Parameters.
552    params: Vec<RpcParam>,
553}
554
555impl RpcRequest {
556    /// Create a new RPC request for a named procedure.
557    pub fn named(proc_name: impl Into<String>) -> Self {
558        Self {
559            proc_name: Some(proc_name.into()),
560            proc_id: None,
561            options: RpcOptionFlags::default(),
562            params: Vec::new(),
563        }
564    }
565
566    /// Create a new RPC request for a well-known procedure.
567    pub fn by_id(proc_id: ProcId) -> Self {
568        Self {
569            proc_name: None,
570            proc_id: Some(proc_id),
571            options: RpcOptionFlags::default(),
572            params: Vec::new(),
573        }
574    }
575
576    /// Create an sp_executesql request.
577    ///
578    /// This is the primary method for parameterized queries.
579    ///
580    /// # Example
581    ///
582    /// ```
583    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
584    ///
585    /// let rpc = RpcRequest::execute_sql(
586    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
587    ///     vec![
588    ///         RpcParam::int("@p1", 42),
589    ///         RpcParam::nvarchar("@p2", "Alice"),
590    ///     ],
591    /// );
592    /// ```
593    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
594        let mut request = Self::by_id(ProcId::ExecuteSql);
595
596        // First parameter: the SQL statement (NVARCHAR(MAX))
597        request.params.push(RpcParam::nvarchar("", sql));
598
599        // Second parameter: parameter declarations
600        if !params.is_empty() {
601            let declarations = Self::build_param_declarations(&params);
602            request.params.push(RpcParam::nvarchar("", &declarations));
603        }
604
605        // Add the actual parameters
606        request.params.extend(params);
607
608        request
609    }
610
611    /// Build parameter declaration string for sp_executesql.
612    fn build_param_declarations(params: &[RpcParam]) -> String {
613        params
614            .iter()
615            .map(|p| {
616                let name = if p.name.starts_with('@') {
617                    p.name.clone()
618                } else if p.name.is_empty() {
619                    // Generate positional name
620                    format!(
621                        "@p{}",
622                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
623                    )
624                } else {
625                    format!("@{}", p.name)
626                };
627
628                let type_name: String = match p.type_info.type_id {
629                    0x26 => match p.type_info.max_length {
630                        Some(1) => "tinyint".to_string(),
631                        Some(2) => "smallint".to_string(),
632                        Some(4) => "int".to_string(),
633                        Some(8) => "bigint".to_string(),
634                        _ => "int".to_string(),
635                    },
636                    0x68 => "bit".to_string(),
637                    0x6D => match p.type_info.max_length {
638                        Some(4) => "real".to_string(),
639                        _ => "float".to_string(),
640                    },
641                    0xE7 => {
642                        if p.type_info.max_length == Some(0xFFFF) {
643                            "nvarchar(max)".to_string()
644                        } else {
645                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
646                            format!("nvarchar({})", len)
647                        }
648                    }
649                    0xA5 => {
650                        if p.type_info.max_length == Some(0xFFFF) {
651                            "varbinary(max)".to_string()
652                        } else {
653                            let len = p.type_info.max_length.unwrap_or(8000);
654                            format!("varbinary({})", len)
655                        }
656                    }
657                    0x24 => "uniqueidentifier".to_string(),
658                    0x28 => "date".to_string(),
659                    0x2A => {
660                        let scale = p.type_info.scale.unwrap_or(7);
661                        format!("datetime2({})", scale)
662                    }
663                    0x6C => {
664                        let precision = p.type_info.precision.unwrap_or(18);
665                        let scale = p.type_info.scale.unwrap_or(0);
666                        format!("decimal({}, {})", precision, scale)
667                    }
668                    _ => "sql_variant".to_string(),
669                };
670
671                format!("{} {}", name, type_name)
672            })
673            .collect::<Vec<_>>()
674            .join(", ")
675    }
676
677    /// Create an sp_prepare request.
678    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
679        let mut request = Self::by_id(ProcId::Prepare);
680
681        // OUT: handle (INT)
682        request
683            .params
684            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
685
686        // Param declarations
687        let declarations = Self::build_param_declarations(params);
688        request
689            .params
690            .push(RpcParam::nvarchar("@params", &declarations));
691
692        // SQL statement
693        request.params.push(RpcParam::nvarchar("@stmt", sql));
694
695        // Options (1 = WITH RECOMPILE)
696        request.params.push(RpcParam::int("@options", 1));
697
698        request
699    }
700
701    /// Create an sp_execute request.
702    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
703        let mut request = Self::by_id(ProcId::Execute);
704
705        // Handle from sp_prepare
706        request.params.push(RpcParam::int("@handle", handle));
707
708        // Add parameters
709        request.params.extend(params);
710
711        request
712    }
713
714    /// Create an sp_unprepare request.
715    pub fn unprepare(handle: i32) -> Self {
716        let mut request = Self::by_id(ProcId::Unprepare);
717        request.params.push(RpcParam::int("@handle", handle));
718        request
719    }
720
721    /// Set option flags.
722    #[must_use]
723    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
724        self.options = options;
725        self
726    }
727
728    /// Add a parameter.
729    #[must_use]
730    pub fn param(mut self, param: RpcParam) -> Self {
731        self.params.push(param);
732        self
733    }
734
735    /// Encode the RPC request to bytes (auto-commit mode).
736    ///
737    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
738    #[must_use]
739    pub fn encode(&self) -> Bytes {
740        self.encode_with_transaction(0)
741    }
742
743    /// Encode the RPC request with a transaction descriptor.
744    ///
745    /// Per MS-TDS spec, when executing within an explicit transaction:
746    /// - The `transaction_descriptor` MUST be the value returned by the server
747    ///   in the BeginTransaction EnvChange token.
748    /// - For auto-commit mode (no explicit transaction), use 0.
749    ///
750    /// # Arguments
751    ///
752    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
753    ///   or 0 for auto-commit mode.
754    #[must_use]
755    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
756        let mut buf = BytesMut::with_capacity(256);
757
758        // ALL_HEADERS - TDS 7.2+ requires this section
759        // Total length placeholder (will be filled in)
760        let all_headers_start = buf.len();
761        buf.put_u32_le(0); // Total length placeholder
762
763        // Transaction descriptor header (required for RPC)
764        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
765        buf.put_u32_le(18); // Header length
766        buf.put_u16_le(0x0002); // Header type: transaction descriptor
767        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
768        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
769
770        // Fill in ALL_HEADERS total length
771        let all_headers_len = buf.len() - all_headers_start;
772        let len_bytes = (all_headers_len as u32).to_le_bytes();
773        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
774
775        // Procedure name or ID
776        if let Some(proc_id) = self.proc_id {
777            // Use PROCID format
778            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
779            buf.put_u16_le(proc_id as u16);
780        } else if let Some(ref proc_name) = self.proc_name {
781            // Use procedure name
782            let name_len = proc_name.encode_utf16().count() as u16;
783            buf.put_u16_le(name_len);
784            write_utf16_string(&mut buf, proc_name);
785        }
786
787        // Option flags
788        buf.put_u16_le(self.options.encode());
789
790        // Parameters
791        for param in &self.params {
792            param.encode(&mut buf);
793        }
794
795        buf.freeze()
796    }
797}
798
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        let expected = [
            (ProcId::ExecuteSql, 0x000A),
            (ProcId::Prepare, 0x000B),
            (ProcId::Execute, 0x000C),
            (ProcId::Unprepare, 0x000F),
        ];
        for &(id, code) in &expected {
            assert_eq!(id as u16, code);
        }
    }

    #[test]
    fn test_option_flags_encode() {
        let encoded = RpcOptionFlags::new().with_recompile(true).encode();
        assert_eq!(encoded, 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let encoded = ParamFlags::new().output().encode();
        assert_eq!(encoded, 0x01);
    }

    #[test]
    fn test_int_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::int("@p1", 42);
        assert_eq!(name, "@p1");
        assert_eq!(type_info.type_id, 0x26);
        assert!(value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // "Alice" is five UTF-16 code units, i.e. ten bytes.
        let byte_len = param.value.as_ref().map(|v| v.len());
        assert_eq!(byte_len, Some(10));
    }

    #[test]
    fn test_execute_sql_request() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::execute_sql(sql, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // statement, declarations, then the single bound parameter
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ]);
        for fragment in ["@p1 int", "@name nvarchar"].iter() {
            assert!(decls.contains(*fragment), "missing {:?} in {:?}", fragment, decls);
        }
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let bytes = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::prepare(sql, &[RpcParam::int("@p1", 0)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // handle (OUTPUT), declarations, statement, options
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params.first().map_or(false, |p| p.flags.by_ref));
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // handle followed by the one parameter
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // only the handle
        assert_eq!(rpc.params.len(), 1);
    }
}