Skip to main content

tds_protocol/
token.rs

//! TDS token stream definitions.
//!
//! Tokens are the fundamental units of TDS response data. The server sends
//! a stream of tokens that describe metadata, rows, errors, and other information.
//!
//! ## Token Structure
//!
//! Each token begins with a 1-byte token type identifier, followed by
//! token-specific data. Some tokens have a fixed length, while others
//! carry a length prefix.
//!
//! ## Usage
//!
//! ```rust,no_run
//! use tds_protocol::token::{Token, TokenParser};
//! use bytes::Bytes;
//!
//! fn parse(data: Bytes) -> Result<(), tds_protocol::ProtocolError> {
//!     let mut parser = TokenParser::new(data);
//!
//!     while let Some(token) = parser.next_token()? {
//!         match token {
//!             Token::Done(done) => println!("Rows affected: {}", done.row_count),
//!             Token::Error(err) => eprintln!("Error {}: {}", err.number, err.message),
//!             _ => {}
//!         }
//!     }
//!     Ok(())
//! }
//! ```
31
32use bytes::{Buf, BufMut, Bytes};
33
34use crate::codec::{read_b_varchar, read_us_varchar};
35use crate::error::ProtocolError;
36use crate::prelude::*;
37use crate::types::TypeId;
38
/// One-byte identifier that introduces every token in a TDS token stream.
///
/// The discriminant of each variant is the byte value used on the wire, so
/// `variant as u8` yields the wire byte and [`TokenType::from_u8`] performs
/// the reverse mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
    /// COLMETADATA (`0x81`): column metadata.
    ColMetaData = 0x81,
    /// ERROR (`0xAA`): error message.
    Error = 0xAA,
    /// INFO (`0xAB`): informational message.
    Info = 0xAB,
    /// LOGINACK (`0xAD`): login acknowledgment.
    LoginAck = 0xAD,
    /// ROW (`0xD1`): row data.
    Row = 0xD1,
    /// NBCROW (`0xD2`): null-bitmap-compressed row.
    NbcRow = 0xD2,
    /// ENVCHANGE (`0xE3`): environment change.
    EnvChange = 0xE3,
    /// SSPI (`0xED`): SSPI authentication.
    Sspi = 0xED,
    /// DONE (`0xFD`): done.
    Done = 0xFD,
    /// DONEINPROC (`0xFF`): done in procedure.
    DoneInProc = 0xFF,
    /// DONEPROC (`0xFE`): done procedure.
    DoneProc = 0xFE,
    /// RETURNSTATUS (`0x79`): return status.
    ReturnStatus = 0x79,
    /// RETURNVALUE (`0xAC`): return value.
    ReturnValue = 0xAC,
    /// ORDER (`0xA9`): order.
    Order = 0xA9,
    /// FEATUREEXTACK (`0xAE`): feature extension acknowledgment.
    FeatureExtAck = 0xAE,
    /// SESSIONSTATE (`0xE4`): session state.
    SessionState = 0xE4,
    /// FEDAUTHINFO (`0xEE`): federated authentication info.
    FedAuthInfo = 0xEE,
    /// COLINFO (`0xA5`): column info.
    ColInfo = 0xA5,
    /// TABNAME (`0xA4`): table name.
    TabName = 0xA4,
    /// OFFSET (`0x78`): offset.
    Offset = 0x78,
}

impl TokenType {
    /// Map a raw wire byte to its token type.
    ///
    /// Returns `None` for any byte that is not one of the discriminants
    /// declared on [`TokenType`].
    pub fn from_u8(value: u8) -> Option<Self> {
        // Arms are listed in ascending byte order to make gaps easy to spot.
        let token_type = match value {
            0x78 => Self::Offset,
            0x79 => Self::ReturnStatus,
            0x81 => Self::ColMetaData,
            0xA4 => Self::TabName,
            0xA5 => Self::ColInfo,
            0xA9 => Self::Order,
            0xAA => Self::Error,
            0xAB => Self::Info,
            0xAC => Self::ReturnValue,
            0xAD => Self::LoginAck,
            0xAE => Self::FeatureExtAck,
            0xD1 => Self::Row,
            0xD2 => Self::NbcRow,
            0xE3 => Self::EnvChange,
            0xE4 => Self::SessionState,
            0xED => Self::Sspi,
            0xEE => Self::FedAuthInfo,
            0xFD => Self::Done,
            0xFE => Self::DoneProc,
            0xFF => Self::DoneInProc,
            _ => return None,
        };
        Some(token_type)
    }
}
114
/// Parsed TDS token.
///
/// This enum represents the tokens that can be received from SQL Server.
/// Each variant wraps the parsed payload for one kind of token.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include
/// a wildcard arm when matching on it.
///
/// Note that [`TokenType`] also declares `ColInfo`, `TabName` and `Offset`,
/// which have no corresponding variant here.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Token {
    /// Column metadata describing result set structure.
    ColMetaData(ColMetaData),
    /// Row data, carried as raw (not yet decoded) bytes.
    Row(RawRow),
    /// Null bitmap compressed row: a null bitmap plus the raw non-null values.
    NbcRow(NbcRow),
    /// Completion of a SQL statement.
    Done(Done),
    /// Completion of a stored procedure.
    DoneProc(DoneProc),
    /// Completion within a stored procedure.
    DoneInProc(DoneInProc),
    /// Return status from stored procedure, as a signed 32-bit value.
    ReturnStatus(i32),
    /// Return value from stored procedure.
    ReturnValue(ReturnValue),
    /// Error message from server.
    Error(ServerError),
    /// Informational message from server.
    Info(ServerInfo),
    /// Login acknowledgment.
    LoginAck(LoginAck),
    /// Environment change notification.
    EnvChange(EnvChange),
    /// Column ordering information.
    Order(Order),
    /// Feature extension acknowledgment.
    FeatureExtAck(FeatureExtAck),
    /// SSPI authentication data.
    Sspi(SspiToken),
    /// Session state information.
    SessionState(SessionState),
    /// Federated authentication info.
    FedAuthInfo(FedAuthInfo),
}
157
/// Column metadata token.
///
/// Describes the columns of a result set. The `Default` value has no columns
/// and no CEK table.
#[derive(Debug, Clone, Default)]
pub struct ColMetaData {
    /// Column definitions, in the order they were read from the wire.
    pub columns: Vec<ColumnData>,
    /// CEK table for Always Encrypted result sets.
    /// Present only when the server sends encrypted column metadata.
    pub cek_table: Option<crate::crypto::CekTable>,
}
167
/// Column definition within metadata.
///
/// One entry of [`ColMetaData::columns`].
#[derive(Debug, Clone)]
pub struct ColumnData {
    /// Column name.
    pub name: String,
    /// Column data type ID (the parsed form of `col_type`).
    pub type_id: TypeId,
    /// Column data type raw byte (for unknown types).
    pub col_type: u8,
    /// Column flags as read from the wire (16-bit bit field).
    ///
    /// Bit `0x0001` marks the column as nullable and bit `0x0800` marks it
    /// as encrypted.
    pub flags: u16,
    /// User type ID.
    pub user_type: u32,
    /// Type-specific metadata.
    pub type_info: TypeInfo,
    /// Per-column encryption metadata (Always Encrypted).
    /// Present only for columns with the encrypted flag (0x0800) set.
    pub crypto_metadata: Option<crate::crypto::CryptoMetadata>,
}
187
/// Type-specific metadata.
///
/// Every field is optional because each data type carries only a subset of
/// them; the `Default` value (all `None`) is used for types that have no
/// additional metadata.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
    /// Maximum length for variable-length types.
    pub max_length: Option<u32>,
    /// Precision for numeric types.
    pub precision: Option<u8>,
    /// Scale for numeric types (also populated for the scaled date/time types).
    pub scale: Option<u8>,
    /// Collation for string types.
    pub collation: Option<Collation>,
}
200
/// SQL Server collation.
///
/// A collation defines the sorting rules and, for `VARCHAR` data, the code
/// page (character encoding) in which the bytes are stored.
///
/// # Encoding Support
///
/// With the `encoding` feature enabled, [`Collation::encoding()`] returns the
/// [`encoding_rs::Encoding`] to use when decoding `VARCHAR` data.
///
/// # Example
///
/// ```rust,ignore
/// use tds_protocol::token::Collation;
///
/// let collation = Collation { lcid: 0x0804, sort_id: 0 }; // Chinese (PRC)
/// if let Some(encoding) = collation.encoding() {
///     let (decoded, _, _) = encoding.decode(raw_bytes);
///     // decoded is now proper Chinese text
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
    /// Locale ID (LCID).
    ///
    /// Holds the first four bytes of the wire collation as a little-endian
    /// `u32`. The LCID encodes both the language and region: the lower
    /// 16 bits contain the primary language ID, and bits 16-19 contain the
    /// sort ID for some collations.
    ///
    /// For UTF-8 collations (SQL Server 2019+), fUTF8 (bit 26, 0x0400_0000) is set.
    pub lcid: u32,
    /// Sort ID.
    ///
    /// Zero for collations whose code page is derived from `lcid`; a non-zero
    /// value identifies a SQL collation whose code page is derived from the
    /// sort ID instead.
    pub sort_id: u8,
}

impl Collation {
    /// Build a `Collation` from its 5-byte TDS wire representation.
    ///
    /// Layout: bytes 0..4 are the LCID as a little-endian `u32`, byte 4 is
    /// the sort ID.
    pub fn from_bytes(bytes: &[u8; 5]) -> Self {
        let [b0, b1, b2, b3, sort_id] = *bytes;
        Self {
            lcid: u32::from_le_bytes([b0, b1, b2, b3]),
            sort_id,
        }
    }

    /// Produce the 5-byte TDS wire representation.
    ///
    /// Inverse of [`Collation::from_bytes`]: LCID as a little-endian `u32`
    /// followed by the sort ID.
    pub fn to_bytes(&self) -> [u8; 5] {
        let mut wire = [0u8; 5];
        wire[..4].copy_from_slice(&self.lcid.to_le_bytes());
        wire[4] = self.sort_id;
        wire
    }

    /// Returns the character encoding for this collation.
    ///
    /// # Returns
    ///
    /// - `Some(&Encoding)` - The encoding to use for decoding `VARCHAR` data
    /// - `None` - If the collation uses UTF-8 (no transcoding needed) or
    ///   the collation is not recognized (caller should use Windows-1252 fallback)
    ///
    /// # UTF-8 Collations
    ///
    /// SQL Server 2019+ supports UTF-8 collations (identified by the `_UTF8`
    /// suffix). These return `None` because no transcoding is needed.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let collation = Collation { lcid: 0x0419, sort_id: 0 }; // Russian
    /// if let Some(encoding) = collation.encoding() {
    ///     // encoding is Windows-1251 for Cyrillic
    ///     let (text, _, had_errors) = encoding.decode(&raw_bytes);
    /// }
    /// ```
    #[cfg(feature = "encoding")]
    pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
        match self.sort_id {
            // Windows collation: the code page follows from the LCID.
            0 => crate::collation::encoding_for_lcid(self.lcid),
            // SQL collation: the code page follows from the SortId, not the
            // LCID (MS-TDS). Looking only at the LCID silently decoded these
            // as windows-1252 (issue #158).
            sort_id => crate::collation::encoding_for_sort_id(sort_id),
        }
    }

    /// Returns whether this collation uses UTF-8 encoding.
    ///
    /// UTF-8 collations were introduced in SQL Server 2019 and are
    /// identified by the `_UTF8` suffix in the collation name. The decision
    /// is made from `lcid` alone.
    #[cfg(feature = "encoding")]
    pub fn is_utf8(&self) -> bool {
        crate::collation::is_utf8_collation(self.lcid)
    }

    /// Returns the Windows code page number for this collation.
    ///
    /// Useful for error messages and debugging.
    ///
    /// # Returns
    ///
    /// The code page number (e.g., 1252 for Western European, 932 for
    /// Japanese), or `None` when the lookup has no entry for this collation.
    #[cfg(feature = "encoding")]
    pub fn code_page(&self) -> Option<u16> {
        match self.sort_id {
            0 => crate::collation::code_page_for_lcid(self.lcid),
            // SQL collations (non-zero SortId) take their code page from the
            // SortId, not the LCID (issue #158).
            sort_id => crate::collation::code_page_for_sort_id(sort_id),
        }
    }

    /// Returns the encoding name for this collation.
    ///
    /// Useful for error messages and debugging. For a non-zero sort ID with
    /// no known encoding the literal `"unsupported"` is returned.
    #[cfg(feature = "encoding")]
    pub fn encoding_name(&self) -> &'static str {
        match self.sort_id {
            0 => crate::collation::encoding_name_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id)
                .map_or("unsupported", |enc| enc.name()),
        }
    }
}
334
/// Raw row data (not yet decoded).
///
/// Holds the undecoded bytes of the row's column values; this type carries
/// no column metadata of its own.
#[derive(Debug, Clone)]
pub struct RawRow {
    /// Raw column values.
    pub data: bytes::Bytes,
}
341
/// Null bitmap compressed row.
///
/// Like [`RawRow`], but NULL columns are signalled through a bitmap and
/// only the non-null values appear in `data`.
#[derive(Debug, Clone)]
pub struct NbcRow {
    /// Null bitmap.
    pub null_bitmap: Vec<u8>,
    /// Raw non-null column values (not yet decoded).
    pub data: bytes::Bytes,
}
350
/// Done token indicating statement completion.
///
/// Has the same field layout as [`DoneProc`] and [`DoneInProc`].
#[derive(Debug, Clone, Copy)]
pub struct Done {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command.
    pub cur_cmd: u16,
    /// Row count (if applicable); see [`DoneStatus::count`] for validity.
    pub row_count: u64,
}
361
/// Done status flags.
///
/// Shared by [`Done`], [`DoneProc`] and [`DoneInProc`]. The `Default` value
/// has every flag cleared.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
    /// More results follow.
    pub more: bool,
    /// Error occurred.
    pub error: bool,
    /// Transaction in progress.
    pub in_xact: bool,
    /// Row count is valid.
    pub count: bool,
    /// Attention acknowledgment.
    pub attn: bool,
    /// Server error caused statement termination.
    pub srverror: bool,
}
379
/// Done in procedure token.
///
/// Has the same field layout as [`Done`] and [`DoneProc`].
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command.
    pub cur_cmd: u16,
    /// Row count; see [`DoneStatus::count`] for validity.
    pub row_count: u64,
}
390
/// Done procedure token.
///
/// Has the same field layout as [`Done`] and [`DoneInProc`].
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command.
    pub cur_cmd: u16,
    /// Row count; see [`DoneStatus::count`] for validity.
    pub row_count: u64,
}
401
/// Return value from stored procedure.
///
/// The value itself is kept as raw bytes; `col_type` and `type_info`
/// describe how those bytes are typed.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReturnValue {
    /// Parameter ordinal.
    pub param_ordinal: u16,
    /// Parameter name.
    pub param_name: String,
    /// Status flags.
    pub status: u8,
    /// User type.
    pub user_type: u32,
    /// Type flags.
    pub flags: u16,
    /// Raw column type byte from the wire.
    pub col_type: u8,
    /// Type info.
    pub type_info: TypeInfo,
    /// Value data (raw bytes).
    pub value: bytes::Bytes,
}
423
/// Server error message.
///
/// Has the same field layout as [`ServerInfo`].
#[derive(Debug, Clone)]
pub struct ServerError {
    /// Error number.
    pub number: i32,
    /// Error state.
    pub state: u8,
    /// Error severity class.
    pub class: u8,
    /// Error message text.
    pub message: String,
    /// Server name.
    pub server: String,
    /// Procedure name.
    pub procedure: String,
    /// Line number.
    pub line: i32,
}
442
/// Server informational message.
///
/// Has the same field layout as [`ServerError`].
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Info number.
    pub number: i32,
    /// Info state.
    pub state: u8,
    /// Info class (severity).
    pub class: u8,
    /// Info message text.
    pub message: String,
    /// Server name.
    pub server: String,
    /// Procedure name.
    pub procedure: String,
    /// Line number.
    pub line: i32,
}
461
/// Login acknowledgment token.
#[derive(Debug, Clone)]
pub struct LoginAck {
    /// Interface type.
    pub interface: u8,
    /// TDS version, as a raw 32-bit value.
    pub tds_version: u32,
    /// Program name.
    pub prog_name: String,
    /// Program version, as a raw 32-bit value.
    pub prog_version: u32,
}
474
/// Environment change token.
///
/// Pairs the kind of change with its new and old values.
#[derive(Debug, Clone)]
pub struct EnvChange {
    /// Type of environment change.
    pub env_type: EnvChangeType,
    /// New value.
    pub new_value: EnvChangeValue,
    /// Old value.
    pub old_value: EnvChangeValue,
}
485
/// Environment change type.
///
/// Each discriminant is the type's numeric identifier. Note that the
/// value 14 has no variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EnvChangeType {
    /// Database changed.
    Database = 1,
    /// Language changed.
    Language = 2,
    /// Character set changed.
    CharacterSet = 3,
    /// Packet size changed.
    PacketSize = 4,
    /// Unicode data sorting locale ID.
    UnicodeSortingLocalId = 5,
    /// Unicode comparison flags.
    UnicodeComparisonFlags = 6,
    /// SQL collation.
    SqlCollation = 7,
    /// Begin transaction.
    BeginTransaction = 8,
    /// Commit transaction.
    CommitTransaction = 9,
    /// Rollback transaction.
    RollbackTransaction = 10,
    /// Enlist DTC transaction.
    EnlistDtcTransaction = 11,
    /// Defect DTC transaction.
    DefectTransaction = 12,
    /// Real-time log shipping.
    RealTimeLogShipping = 13,
    /// Promote transaction.
    PromoteTransaction = 15,
    /// Transaction manager address.
    TransactionManagerAddress = 16,
    /// Transaction ended.
    TransactionEnded = 17,
    /// Reset connection completion acknowledgment.
    ResetConnectionCompletionAck = 18,
    /// User instance started.
    UserInstanceStarted = 19,
    /// Routing information.
    Routing = 20,
}
530
/// Environment change value.
///
/// Used for both the new and the old value of an [`EnvChange`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvChangeValue {
    /// String value.
    String(String),
    /// Binary value (raw bytes).
    Binary(bytes::Bytes),
    /// Routing information.
    Routing {
        /// Host name.
        host: String,
        /// Port number.
        port: u16,
    },
}
547
/// Column ordering information.
#[derive(Debug, Clone)]
pub struct Order {
    /// Ordered column indices, each a 16-bit value.
    pub columns: Vec<u16>,
}
554
/// Feature extension acknowledgment.
///
/// A list of per-feature acknowledgments; see [`FeatureAck`].
#[derive(Debug, Clone)]
pub struct FeatureExtAck {
    /// Acknowledged features.
    pub features: Vec<FeatureAck>,
}
561
/// Individual feature acknowledgment.
///
/// One entry of [`FeatureExtAck::features`].
#[derive(Debug, Clone)]
pub struct FeatureAck {
    /// Feature ID.
    pub feature_id: u8,
    /// Feature data (raw bytes; interpretation depends on `feature_id`).
    pub data: bytes::Bytes,
}
570
/// SSPI authentication token.
#[derive(Debug, Clone)]
pub struct SspiToken {
    /// SSPI data (raw bytes).
    pub data: bytes::Bytes,
}
577
/// Session state token.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Session state data (raw bytes).
    pub data: bytes::Bytes,
}
584
/// Federated authentication info.
#[derive(Debug, Clone)]
pub struct FedAuthInfo {
    /// STS URL.
    pub sts_url: String,
    /// Service principal name.
    pub spn: String,
}
593
594// =============================================================================
595// ColMetaData and Row Parsing Implementation
596// =============================================================================
597
598/// Decode collation information (5 bytes).
599///
600/// Shared by ColMetaData column parsing and CryptoMetadata base type parsing.
601pub(crate) fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
602    if src.remaining() < 5 {
603        return Err(ProtocolError::UnexpectedEof);
604    }
605    // Collation: LCID (4 bytes) + Sort ID (1 byte)
606    let lcid = src.get_u32_le();
607    let sort_id = src.get_u8();
608    Ok(Collation { lcid, sort_id })
609}
610
611/// Decode type-specific metadata for a column based on its TypeId.
612///
613/// Shared by ColMetaData column parsing and CryptoMetadata base type parsing.
614pub(crate) fn decode_type_info(
615    src: &mut impl Buf,
616    type_id: TypeId,
617    col_type: u8,
618) -> Result<TypeInfo, ProtocolError> {
619    match type_id {
620        // Fixed-length types have no additional metadata
621        TypeId::Null => Ok(TypeInfo::default()),
622        TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
623        TypeId::Int2 => Ok(TypeInfo::default()),
624        TypeId::Int4 => Ok(TypeInfo::default()),
625        TypeId::Int8 => Ok(TypeInfo::default()),
626        TypeId::Float4 => Ok(TypeInfo::default()),
627        TypeId::Float8 => Ok(TypeInfo::default()),
628        TypeId::Money => Ok(TypeInfo::default()),
629        TypeId::Money4 => Ok(TypeInfo::default()),
630        TypeId::DateTime => Ok(TypeInfo::default()),
631        TypeId::DateTime4 => Ok(TypeInfo::default()),
632
633        // Variable length integer/float/money (1-byte max length)
634        TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
635            if src.remaining() < 1 {
636                return Err(ProtocolError::UnexpectedEof);
637            }
638            let max_length = src.get_u8() as u32;
639            Ok(TypeInfo {
640                max_length: Some(max_length),
641                ..Default::default()
642            })
643        }
644
645        // GUID has 1-byte length
646        TypeId::Guid => {
647            if src.remaining() < 1 {
648                return Err(ProtocolError::UnexpectedEof);
649            }
650            let max_length = src.get_u8() as u32;
651            Ok(TypeInfo {
652                max_length: Some(max_length),
653                ..Default::default()
654            })
655        }
656
657        // Decimal/Numeric types (1-byte length + precision + scale)
658        TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
659            if src.remaining() < 3 {
660                return Err(ProtocolError::UnexpectedEof);
661            }
662            let max_length = src.get_u8() as u32;
663            let precision = src.get_u8();
664            let scale = src.get_u8();
665            Ok(TypeInfo {
666                max_length: Some(max_length),
667                precision: Some(precision),
668                scale: Some(scale),
669                ..Default::default()
670            })
671        }
672
673        // Old-style byte-length strings (Char, VarChar, Binary, VarBinary)
674        TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
675            if src.remaining() < 1 {
676                return Err(ProtocolError::UnexpectedEof);
677            }
678            let max_length = src.get_u8() as u32;
679            Ok(TypeInfo {
680                max_length: Some(max_length),
681                ..Default::default()
682            })
683        }
684
685        // Big varchar/binary with 2-byte length + collation for strings
686        TypeId::BigVarChar | TypeId::BigChar => {
687            if src.remaining() < 7 {
688                // 2 (length) + 5 (collation)
689                return Err(ProtocolError::UnexpectedEof);
690            }
691            let max_length = src.get_u16_le() as u32;
692            let collation = decode_collation(src)?;
693            Ok(TypeInfo {
694                max_length: Some(max_length),
695                collation: Some(collation),
696                ..Default::default()
697            })
698        }
699
700        // Big binary (2-byte length, no collation)
701        TypeId::BigVarBinary | TypeId::BigBinary => {
702            if src.remaining() < 2 {
703                return Err(ProtocolError::UnexpectedEof);
704            }
705            let max_length = src.get_u16_le() as u32;
706            Ok(TypeInfo {
707                max_length: Some(max_length),
708                ..Default::default()
709            })
710        }
711
712        // Unicode strings (NChar, NVarChar) - 2-byte length + collation
713        TypeId::NChar | TypeId::NVarChar => {
714            if src.remaining() < 7 {
715                // 2 (length) + 5 (collation)
716                return Err(ProtocolError::UnexpectedEof);
717            }
718            let max_length = src.get_u16_le() as u32;
719            let collation = decode_collation(src)?;
720            Ok(TypeInfo {
721                max_length: Some(max_length),
722                collation: Some(collation),
723                ..Default::default()
724            })
725        }
726
727        // Date type (no additional metadata)
728        TypeId::Date => Ok(TypeInfo::default()),
729
730        // Time, DateTime2, DateTimeOffset have scale
731        TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
732            if src.remaining() < 1 {
733                return Err(ProtocolError::UnexpectedEof);
734            }
735            let scale = src.get_u8();
736            Ok(TypeInfo {
737                scale: Some(scale),
738                ..Default::default()
739            })
740        }
741
742        // Text/NText/Image (deprecated LOB types)
743        TypeId::Text | TypeId::NText | TypeId::Image => {
744            // These have complex metadata: length (4) + collation (5) + table name parts
745            if src.remaining() < 4 {
746                return Err(ProtocolError::UnexpectedEof);
747            }
748            let max_length = src.get_u32_le();
749
750            // For Text/NText, read collation
751            let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
752                if src.remaining() < 5 {
753                    return Err(ProtocolError::UnexpectedEof);
754                }
755                Some(decode_collation(src)?)
756            } else {
757                None
758            };
759
760            // Skip table name parts (variable length)
761            // Format: numParts (1 byte) followed by us_varchar for each part
762            if src.remaining() < 1 {
763                return Err(ProtocolError::UnexpectedEof);
764            }
765            let num_parts = src.get_u8();
766            for _ in 0..num_parts {
767                // Read and discard table name part
768                let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
769            }
770
771            Ok(TypeInfo {
772                max_length: Some(max_length),
773                collation,
774                ..Default::default()
775            })
776        }
777
778        // XML type
779        TypeId::Xml => {
780            if src.remaining() < 1 {
781                return Err(ProtocolError::UnexpectedEof);
782            }
783            let schema_present = src.get_u8();
784
785            if schema_present != 0 {
786                // XML_INFO per MS-TDS §2.2.5.5.3: DBNAME and OWNING_SCHEMA are
787                // B_VARCHAR (1-byte length); only XML_SCHEMA_COLLECTION is
788                // US_VARCHAR (2-byte length).
789                let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // db name
790                let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // owning schema
791                let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // xml schema collection
792            }
793
794            Ok(TypeInfo::default())
795        }
796
797        // UDT (User-defined type) - complex metadata
798        TypeId::Udt => {
799            // Max length (2 bytes)
800            if src.remaining() < 2 {
801                return Err(ProtocolError::UnexpectedEof);
802            }
803            let max_length = src.get_u16_le() as u32;
804
805            // UDT_INFO per MS-TDS: DB_NAME, SCHEMA_NAME, and TYPE_NAME are
806            // B_VARCHAR (1-byte length); only ASSEMBLY_QUALIFIED_NAME is
807            // US_VARCHAR (2-byte length).
808            let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // db name
809            let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // schema name
810            let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // type name
811            let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; // assembly qualified name
812
813            Ok(TypeInfo {
814                max_length: Some(max_length),
815                ..Default::default()
816            })
817        }
818
819        // Table-valued parameter - complex metadata (skip for now)
820        TypeId::Tvp => {
821            // TVP has very complex metadata, not commonly used
822            // For now, we can't properly parse this
823            Err(ProtocolError::InvalidTokenType(col_type))
824        }
825
826        // SQL Variant - 4-byte length
827        TypeId::Variant => {
828            if src.remaining() < 4 {
829                return Err(ProtocolError::UnexpectedEof);
830            }
831            let max_length = src.get_u32_le();
832            Ok(TypeInfo {
833                max_length: Some(max_length),
834                ..Default::default()
835            })
836        }
837    }
838}
839
840impl ColMetaData {
841    /// Special value indicating no metadata.
842    pub const NO_METADATA: u16 = 0xFFFF;
843
844    /// Decode a COLMETADATA token from bytes.
845    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
846        if src.remaining() < 2 {
847            return Err(ProtocolError::UnexpectedEof);
848        }
849
850        let column_count = src.get_u16_le();
851
852        // 0xFFFF means no metadata present
853        if column_count == Self::NO_METADATA {
854            return Ok(Self {
855                columns: Vec::new(),
856                cek_table: None,
857            });
858        }
859
860        let mut columns = Vec::with_capacity(column_count as usize);
861
862        for _ in 0..column_count {
863            let column = Self::decode_column(src)?;
864            columns.push(column);
865        }
866
867        Ok(Self {
868            columns,
869            cek_table: None,
870        })
871    }
872
873    /// Decode a single column from the metadata.
874    fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
875        // UserType (4 bytes) + Flags (2 bytes) + TypeId (1 byte)
876        if src.remaining() < 7 {
877            return Err(ProtocolError::UnexpectedEof);
878        }
879
880        let user_type = src.get_u32_le();
881        let flags = src.get_u16_le();
882        let col_type = src.get_u8();
883
884        // An unknown type byte must be a hard error: treating it as Null (a
885        // zero-length column) misaligns every subsequent column and row,
886        // producing plausible garbage values (issue #157).
887        let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
888
889        // Parse type-specific metadata
890        let type_info = decode_type_info(src, type_id, col_type)?;
891
892        // Read column name (B_VARCHAR format - 1 byte length in characters)
893        let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
894
895        Ok(ColumnData {
896            name,
897            type_id,
898            col_type,
899            flags,
900            user_type,
901            type_info,
902            crypto_metadata: None,
903        })
904    }
905
906    /// Decode a COLMETADATA token with Always Encrypted support.
907    ///
908    /// When column encryption was negotiated in Login7, the server sends a CekTable
909    /// before column definitions and per-column CryptoMetadata for encrypted columns.
910    ///
911    /// # Wire Format (with encryption)
912    ///
913    /// ```text
914    /// column_count: USHORT
915    /// cek_table: CekTable (always present when encryption negotiated)
916    /// columns: ColumnData[column_count] (with CryptoMetadata for encrypted columns)
917    /// ```
918    pub fn decode_encrypted(src: &mut impl Buf) -> Result<Self, ProtocolError> {
919        if src.remaining() < 2 {
920            return Err(ProtocolError::UnexpectedEof);
921        }
922
923        let column_count = src.get_u16_le();
924
925        if column_count == Self::NO_METADATA {
926            return Ok(Self {
927                columns: Vec::new(),
928                cek_table: None,
929            });
930        }
931
932        // Parse CEK table (always present when encryption was negotiated)
933        let cek_table = crate::crypto::CekTable::decode(src)?;
934
935        let mut columns = Vec::with_capacity(column_count as usize);
936
937        for _ in 0..column_count {
938            let column = Self::decode_column_encrypted(src)?;
939            columns.push(column);
940        }
941
942        Ok(Self {
943            columns,
944            cek_table: Some(cek_table),
945        })
946    }
947
948    /// Decode a single column definition with Always Encrypted support.
949    ///
950    /// For encrypted columns (flags & 0x0800), parses CryptoMetadata after the type info.
951    fn decode_column_encrypted(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
952        if src.remaining() < 7 {
953            return Err(ProtocolError::UnexpectedEof);
954        }
955
956        let user_type = src.get_u32_le();
957        let flags = src.get_u16_le();
958        let col_type = src.get_u8();
959
960        let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
961
962        // Parse type-specific metadata (for encrypted columns, this is the transport type)
963        let type_info = decode_type_info(src, type_id, col_type)?;
964
965        // Parse CryptoMetadata if the column is encrypted
966        let crypto_metadata = if crate::crypto::is_column_encrypted(flags) {
967            Some(crate::crypto::CryptoMetadata::decode(src)?)
968        } else {
969            None
970        };
971
972        // Read column name
973        let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
974
975        Ok(ColumnData {
976            name,
977            type_id,
978            col_type,
979            flags,
980            user_type,
981            type_info,
982            crypto_metadata,
983        })
984    }
985
986    /// Get the number of columns.
987    #[must_use]
988    pub fn column_count(&self) -> usize {
989        self.columns.len()
990    }
991
992    /// Check if this represents no metadata.
993    #[must_use]
994    pub fn is_empty(&self) -> bool {
995        self.columns.is_empty()
996    }
997}
998
999impl ColumnData {
1000    /// Check if this column is nullable.
1001    #[must_use]
1002    pub fn is_nullable(&self) -> bool {
1003        (self.flags & 0x0001) != 0
1004    }
1005
1006    /// Get the fixed size in bytes for this column, if applicable.
1007    ///
1008    /// Returns `None` for variable-length types.
1009    #[must_use]
1010    pub fn fixed_size(&self) -> Option<usize> {
1011        match self.type_id {
1012            TypeId::Null => Some(0),
1013            TypeId::Int1 | TypeId::Bit => Some(1),
1014            TypeId::Int2 => Some(2),
1015            TypeId::Int4 => Some(4),
1016            TypeId::Int8 => Some(8),
1017            TypeId::Float4 => Some(4),
1018            TypeId::Float8 => Some(8),
1019            TypeId::Money => Some(8),
1020            TypeId::Money4 => Some(4),
1021            TypeId::DateTime => Some(8),
1022            TypeId::DateTime4 => Some(4),
1023            TypeId::Date => Some(3),
1024            _ => None,
1025        }
1026    }
1027}
1028
1029// =============================================================================
1030// Row Parsing Implementation
1031// =============================================================================
1032
1033impl RawRow {
1034    /// Decode a ROW token from bytes.
1035    ///
1036    /// This function requires the column metadata to know how to parse the row.
1037    /// The row data is stored as raw bytes for later parsing.
1038    pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1039        let mut data = bytes::BytesMut::new();
1040
1041        for col in &metadata.columns {
1042            Self::decode_column_value(src, col, &mut data)?;
1043        }
1044
1045        Ok(Self {
1046            data: data.freeze(),
1047        })
1048    }
1049
1050    /// Decode only the first `prefix_len` columns of a ROW token, leaving `src`
1051    /// positioned at the start of column `prefix_len`.
1052    ///
1053    /// Used by the BLOB streaming path to decode the leading scalar columns of a
1054    /// row and stop at a trailing MAX column, whose PLP value is then streamed
1055    /// directly from the socket rather than buffered.
1056    pub fn decode_prefix(
1057        src: &mut impl Buf,
1058        metadata: &ColMetaData,
1059        prefix_len: usize,
1060    ) -> Result<Self, ProtocolError> {
1061        let mut data = bytes::BytesMut::new();
1062        for col in metadata.columns.iter().take(prefix_len) {
1063            Self::decode_column_value(src, col, &mut data)?;
1064        }
1065        Ok(Self {
1066            data: data.freeze(),
1067        })
1068    }
1069
1070    /// Decode a single column value and append to the output buffer.
1071    fn decode_column_value(
1072        src: &mut impl Buf,
1073        col: &ColumnData,
1074        dst: &mut bytes::BytesMut,
1075    ) -> Result<(), ProtocolError> {
1076        match col.type_id {
1077            // Fixed-length types
1078            TypeId::Null => {
1079                // No data
1080            }
1081            TypeId::Int1 | TypeId::Bit => {
1082                if src.remaining() < 1 {
1083                    return Err(ProtocolError::UnexpectedEof);
1084                }
1085                dst.extend_from_slice(&[src.get_u8()]);
1086            }
1087            TypeId::Int2 => {
1088                if src.remaining() < 2 {
1089                    return Err(ProtocolError::UnexpectedEof);
1090                }
1091                dst.extend_from_slice(&src.get_u16_le().to_le_bytes());
1092            }
1093            TypeId::Int4 => {
1094                if src.remaining() < 4 {
1095                    return Err(ProtocolError::UnexpectedEof);
1096                }
1097                dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1098            }
1099            TypeId::Int8 => {
1100                if src.remaining() < 8 {
1101                    return Err(ProtocolError::UnexpectedEof);
1102                }
1103                dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1104            }
1105            TypeId::Float4 => {
1106                if src.remaining() < 4 {
1107                    return Err(ProtocolError::UnexpectedEof);
1108                }
1109                dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1110            }
1111            TypeId::Float8 => {
1112                if src.remaining() < 8 {
1113                    return Err(ProtocolError::UnexpectedEof);
1114                }
1115                dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1116            }
1117            TypeId::Money => {
1118                if src.remaining() < 8 {
1119                    return Err(ProtocolError::UnexpectedEof);
1120                }
1121                let hi = src.get_u32_le();
1122                let lo = src.get_u32_le();
1123                dst.extend_from_slice(&hi.to_le_bytes());
1124                dst.extend_from_slice(&lo.to_le_bytes());
1125            }
1126            TypeId::Money4 => {
1127                if src.remaining() < 4 {
1128                    return Err(ProtocolError::UnexpectedEof);
1129                }
1130                dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1131            }
1132            TypeId::DateTime => {
1133                if src.remaining() < 8 {
1134                    return Err(ProtocolError::UnexpectedEof);
1135                }
1136                let days = src.get_u32_le();
1137                let time = src.get_u32_le();
1138                dst.extend_from_slice(&days.to_le_bytes());
1139                dst.extend_from_slice(&time.to_le_bytes());
1140            }
1141            TypeId::DateTime4 => {
1142                if src.remaining() < 4 {
1143                    return Err(ProtocolError::UnexpectedEof);
1144                }
1145                dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1146            }
1147            // DATE type uses 1-byte length prefix (can be NULL)
1148            TypeId::Date => {
1149                Self::decode_bytelen_type(src, dst)?;
1150            }
1151
1152            // Variable-length nullable types (length-prefixed)
1153            TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
1154                Self::decode_bytelen_type(src, dst)?;
1155            }
1156
1157            TypeId::Guid => {
1158                Self::decode_bytelen_type(src, dst)?;
1159            }
1160
1161            TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1162                Self::decode_bytelen_type(src, dst)?;
1163            }
1164
1165            // Old-style byte-length strings
1166            TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
1167                Self::decode_bytelen_type(src, dst)?;
1168            }
1169
1170            // 2-byte length strings (or PLP for MAX types)
1171            TypeId::BigVarChar | TypeId::BigVarBinary => {
1172                // max_length == 0xFFFF indicates VARCHAR(MAX) or VARBINARY(MAX), which uses PLP
1173                if col.type_info.max_length == Some(0xFFFF) {
1174                    Self::decode_plp_type(src, dst)?;
1175                } else {
1176                    Self::decode_ushortlen_type(src, dst)?;
1177                }
1178            }
1179
1180            // Fixed-length types that don't have MAX variants
1181            TypeId::BigChar | TypeId::BigBinary => {
1182                Self::decode_ushortlen_type(src, dst)?;
1183            }
1184
1185            // Unicode strings (2-byte length in bytes, or PLP for NVARCHAR(MAX))
1186            TypeId::NVarChar => {
1187                // max_length == 0xFFFF indicates NVARCHAR(MAX), which uses PLP
1188                if col.type_info.max_length == Some(0xFFFF) {
1189                    Self::decode_plp_type(src, dst)?;
1190                } else {
1191                    Self::decode_ushortlen_type(src, dst)?;
1192                }
1193            }
1194
1195            // Fixed-length NCHAR doesn't have MAX variant
1196            TypeId::NChar => {
1197                Self::decode_ushortlen_type(src, dst)?;
1198            }
1199
1200            // Time types with scale
1201            TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
1202                Self::decode_bytelen_type(src, dst)?;
1203            }
1204
1205            // TEXT/NTEXT/IMAGE - deprecated LOB types using textptr format
1206            TypeId::Text | TypeId::NText | TypeId::Image => {
1207                Self::decode_textptr_type(src, dst)?;
1208            }
1209
1210            // XML - uses actual PLP format
1211            TypeId::Xml => {
1212                Self::decode_plp_type(src, dst)?;
1213            }
1214
1215            // Complex types
1216            TypeId::Variant => {
1217                Self::decode_intlen_type(src, dst)?;
1218            }
1219
1220            TypeId::Udt => {
1221                // UDT uses PLP encoding
1222                Self::decode_plp_type(src, dst)?;
1223            }
1224
1225            TypeId::Tvp => {
1226                // TVP not supported in row data
1227                return Err(ProtocolError::InvalidTokenType(col.col_type));
1228            }
1229        }
1230
1231        Ok(())
1232    }
1233
1234    /// Decode a 1-byte length-prefixed value.
1235    fn decode_bytelen_type(
1236        src: &mut impl Buf,
1237        dst: &mut bytes::BytesMut,
1238    ) -> Result<(), ProtocolError> {
1239        if src.remaining() < 1 {
1240            return Err(ProtocolError::UnexpectedEof);
1241        }
1242        let len = src.get_u8() as usize;
1243        if len == 0xFF {
1244            // NULL value - store as zero-length with NULL marker
1245            dst.extend_from_slice(&[0xFF]);
1246        } else if len == 0 {
1247            // Empty value
1248            dst.extend_from_slice(&[0x00]);
1249        } else {
1250            if src.remaining() < len {
1251                return Err(ProtocolError::UnexpectedEof);
1252            }
1253            dst.extend_from_slice(&[len as u8]);
1254            for _ in 0..len {
1255                dst.extend_from_slice(&[src.get_u8()]);
1256            }
1257        }
1258        Ok(())
1259    }
1260
1261    /// Decode a 2-byte length-prefixed value.
1262    fn decode_ushortlen_type(
1263        src: &mut impl Buf,
1264        dst: &mut bytes::BytesMut,
1265    ) -> Result<(), ProtocolError> {
1266        if src.remaining() < 2 {
1267            return Err(ProtocolError::UnexpectedEof);
1268        }
1269        let len = src.get_u16_le() as usize;
1270        if len == 0xFFFF {
1271            // NULL value
1272            dst.extend_from_slice(&0xFFFFu16.to_le_bytes());
1273        } else if len == 0 {
1274            // Empty value
1275            dst.extend_from_slice(&0u16.to_le_bytes());
1276        } else {
1277            if src.remaining() < len {
1278                return Err(ProtocolError::UnexpectedEof);
1279            }
1280            dst.extend_from_slice(&(len as u16).to_le_bytes());
1281            for _ in 0..len {
1282                dst.extend_from_slice(&[src.get_u8()]);
1283            }
1284        }
1285        Ok(())
1286    }
1287
1288    /// Decode a 4-byte length-prefixed value.
1289    fn decode_intlen_type(
1290        src: &mut impl Buf,
1291        dst: &mut bytes::BytesMut,
1292    ) -> Result<(), ProtocolError> {
1293        if src.remaining() < 4 {
1294            return Err(ProtocolError::UnexpectedEof);
1295        }
1296        let len = src.get_u32_le() as usize;
1297        if len == 0xFFFFFFFF {
1298            // NULL value
1299            dst.extend_from_slice(&0xFFFFFFFFu32.to_le_bytes());
1300        } else if len == 0 {
1301            // Empty value
1302            dst.extend_from_slice(&0u32.to_le_bytes());
1303        } else {
1304            if src.remaining() < len {
1305                return Err(ProtocolError::UnexpectedEof);
1306            }
1307            dst.extend_from_slice(&(len as u32).to_le_bytes());
1308            for _ in 0..len {
1309                dst.extend_from_slice(&[src.get_u8()]);
1310            }
1311        }
1312        Ok(())
1313    }
1314
1315    /// Decode a TEXT/NTEXT/IMAGE type (textptr format).
1316    ///
1317    /// These deprecated LOB types use a special format:
1318    /// - 1 byte: textptr_len (0 = NULL)
1319    /// - textptr_len bytes: textptr (if not NULL)
1320    /// - 8 bytes: timestamp (if not NULL)
1321    /// - 4 bytes: data length (if not NULL)
1322    /// - data_len bytes: the actual data (if not NULL)
1323    ///
1324    /// We convert this to PLP format for the client to parse:
1325    /// - 8 bytes: total length (0xFFFFFFFFFFFFFFFF = NULL)
1326    /// - 4 bytes: chunk length (= data length)
1327    /// - chunk data
1328    /// - 4 bytes: 0 (terminator)
1329    fn decode_textptr_type(
1330        src: &mut impl Buf,
1331        dst: &mut bytes::BytesMut,
1332    ) -> Result<(), ProtocolError> {
1333        if src.remaining() < 1 {
1334            return Err(ProtocolError::UnexpectedEof);
1335        }
1336
1337        let textptr_len = src.get_u8() as usize;
1338
1339        if textptr_len == 0 {
1340            // NULL value - write PLP NULL marker
1341            dst.extend_from_slice(&0xFFFFFFFFFFFFFFFFu64.to_le_bytes());
1342            return Ok(());
1343        }
1344
1345        // Skip textptr bytes
1346        if src.remaining() < textptr_len {
1347            return Err(ProtocolError::UnexpectedEof);
1348        }
1349        src.advance(textptr_len);
1350
1351        // Skip 8-byte timestamp
1352        if src.remaining() < 8 {
1353            return Err(ProtocolError::UnexpectedEof);
1354        }
1355        src.advance(8);
1356
1357        // Read data length
1358        if src.remaining() < 4 {
1359            return Err(ProtocolError::UnexpectedEof);
1360        }
1361        let data_len = src.get_u32_le() as usize;
1362
1363        if src.remaining() < data_len {
1364            return Err(ProtocolError::UnexpectedEof);
1365        }
1366
1367        // Write in PLP format for client parsing:
1368        // - 8 bytes: total length
1369        // - 4 bytes: chunk length
1370        // - chunk data
1371        // - 4 bytes: 0 (terminator)
1372        dst.extend_from_slice(&(data_len as u64).to_le_bytes());
1373        dst.extend_from_slice(&(data_len as u32).to_le_bytes());
1374        for _ in 0..data_len {
1375            dst.extend_from_slice(&[src.get_u8()]);
1376        }
1377        dst.extend_from_slice(&0u32.to_le_bytes()); // PLP terminator
1378
1379        Ok(())
1380    }
1381
1382    /// Decode a PLP (Partially Length-Prefixed) value.
1383    ///
1384    /// PLP format:
1385    /// - 8 bytes: total length (0xFFFFFFFFFFFFFFFE = unknown, 0xFFFFFFFFFFFFFFFF = NULL)
1386    /// - If not NULL: chunks of (4 byte chunk length + data) until chunk length = 0
1387    fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
1388        if src.remaining() < 8 {
1389            return Err(ProtocolError::UnexpectedEof);
1390        }
1391
1392        let total_len = src.get_u64_le();
1393
1394        // Store the total length marker
1395        dst.extend_from_slice(&total_len.to_le_bytes());
1396
1397        if total_len == 0xFFFFFFFFFFFFFFFF {
1398            // NULL value - no more data
1399            return Ok(());
1400        }
1401
1402        // Read chunks until terminator
1403        loop {
1404            if src.remaining() < 4 {
1405                return Err(ProtocolError::UnexpectedEof);
1406            }
1407            let chunk_len = src.get_u32_le() as usize;
1408            dst.extend_from_slice(&(chunk_len as u32).to_le_bytes());
1409
1410            if chunk_len == 0 {
1411                // End of PLP data
1412                break;
1413            }
1414
1415            if src.remaining() < chunk_len {
1416                return Err(ProtocolError::UnexpectedEof);
1417            }
1418
1419            for _ in 0..chunk_len {
1420                dst.extend_from_slice(&[src.get_u8()]);
1421            }
1422        }
1423
1424        Ok(())
1425    }
1426}
1427
1428// =============================================================================
1429// NbcRow Parsing Implementation
1430// =============================================================================
1431
1432impl NbcRow {
1433    /// Decode an NBCROW token from bytes.
1434    ///
1435    /// NBCROW (Null Bitmap Compressed Row) stores a bitmap indicating which
1436    /// columns are NULL, followed by only the non-NULL values.
1437    pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1438        let col_count = metadata.columns.len();
1439        let bitmap_len = col_count.div_ceil(8);
1440
1441        if src.remaining() < bitmap_len {
1442            return Err(ProtocolError::UnexpectedEof);
1443        }
1444
1445        // Read null bitmap
1446        let mut null_bitmap = vec![0u8; bitmap_len];
1447        for byte in &mut null_bitmap {
1448            *byte = src.get_u8();
1449        }
1450
1451        // Read non-null values
1452        let mut data = bytes::BytesMut::new();
1453
1454        for (i, col) in metadata.columns.iter().enumerate() {
1455            let byte_idx = i / 8;
1456            let bit_idx = i % 8;
1457            let is_null = (null_bitmap[byte_idx] & (1 << bit_idx)) != 0;
1458
1459            if !is_null {
1460                // Read the value - for NBCROW, we read without the length prefix
1461                // for fixed-length types, and with length prefix for variable types
1462                RawRow::decode_column_value(src, col, &mut data)?;
1463            }
1464        }
1465
1466        Ok(Self {
1467            null_bitmap,
1468            data: data.freeze(),
1469        })
1470    }
1471
1472    /// Decode the null bitmap and the first `prefix_len` columns of an NBCROW,
1473    /// leaving `src` positioned at column `prefix_len`'s value (when that column
1474    /// is non-NULL per the bitmap). The returned row carries the full bitmap and
1475    /// the leading non-NULL values; query the trailing column's nullness with
1476    /// [`is_null`](Self::is_null).
1477    pub fn decode_prefix(
1478        src: &mut impl Buf,
1479        metadata: &ColMetaData,
1480        prefix_len: usize,
1481    ) -> Result<Self, ProtocolError> {
1482        let col_count = metadata.columns.len();
1483        let bitmap_len = col_count.div_ceil(8);
1484
1485        if src.remaining() < bitmap_len {
1486            return Err(ProtocolError::UnexpectedEof);
1487        }
1488
1489        let mut null_bitmap = vec![0u8; bitmap_len];
1490        for byte in &mut null_bitmap {
1491            *byte = src.get_u8();
1492        }
1493
1494        let mut data = bytes::BytesMut::new();
1495        for (i, col) in metadata.columns.iter().enumerate().take(prefix_len) {
1496            let is_null = (null_bitmap[i / 8] & (1 << (i % 8))) != 0;
1497            if !is_null {
1498                RawRow::decode_column_value(src, col, &mut data)?;
1499            }
1500        }
1501
1502        Ok(Self {
1503            null_bitmap,
1504            data: data.freeze(),
1505        })
1506    }
1507
1508    /// Check if a column at the given index is NULL.
1509    #[must_use]
1510    pub fn is_null(&self, column_index: usize) -> bool {
1511        let byte_idx = column_index / 8;
1512        let bit_idx = column_index % 8;
1513        if byte_idx < self.null_bitmap.len() {
1514            (self.null_bitmap[byte_idx] & (1 << bit_idx)) != 0
1515        } else {
1516            true // Out of bounds = NULL
1517        }
1518    }
1519}
1520
1521// =============================================================================
1522// ReturnValue Parsing Implementation
1523// =============================================================================
1524
1525impl ReturnValue {
1526    /// Decode a RETURNVALUE token from bytes.
1527    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1528        // MS-TDS §2.2.7.18: the RETURNVALUE token has no length prefix —
1529        // it begins directly with the 2-byte ParamOrdinal. The previous
1530        // spurious 2-byte read consumed the ordinal and shifted every
1531        // subsequent field, leaving the stream parser two bytes ahead and
1532        // reading value bytes as the next token type (e.g. `0x74` from a
1533        // Unicode name fragment was misread as an unknown token).
1534        if src.remaining() < 2 {
1535            return Err(ProtocolError::UnexpectedEof);
1536        }
1537        let param_ordinal = src.get_u16_le();
1538
1539        // Parameter name (B_VARCHAR)
1540        let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1541
1542        // Status (1 byte)
1543        if src.remaining() < 1 {
1544            return Err(ProtocolError::UnexpectedEof);
1545        }
1546        let status = src.get_u8();
1547
1548        // User type (4 bytes) + flags (2 bytes) + type id (1 byte)
1549        if src.remaining() < 7 {
1550            return Err(ProtocolError::UnexpectedEof);
1551        }
1552        let user_type = src.get_u32_le();
1553        let flags = src.get_u16_le();
1554        let col_type = src.get_u8();
1555
1556        let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
1557
1558        // Parse type info
1559        let type_info = decode_type_info(src, type_id, col_type)?;
1560
1561        // Read the value data
1562        let mut value_buf = bytes::BytesMut::new();
1563
1564        // Create a temporary column for value parsing
1565        let temp_col = ColumnData {
1566            name: String::new(),
1567            type_id,
1568            col_type,
1569            flags,
1570            user_type,
1571            type_info: type_info.clone(),
1572            crypto_metadata: None,
1573        };
1574
1575        RawRow::decode_column_value(src, &temp_col, &mut value_buf)?;
1576
1577        Ok(Self {
1578            param_ordinal,
1579            param_name,
1580            status,
1581            user_type,
1582            flags,
1583            col_type,
1584            type_info,
1585            value: value_buf.freeze(),
1586        })
1587    }
1588}
1589
1590// =============================================================================
1591// SessionState Parsing Implementation
1592// =============================================================================
1593
1594impl SessionState {
1595    /// Decode a SESSIONSTATE token from bytes.
1596    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1597        if src.remaining() < 4 {
1598            return Err(ProtocolError::UnexpectedEof);
1599        }
1600
1601        let length = src.get_u32_le() as usize;
1602
1603        if src.remaining() < length {
1604            return Err(ProtocolError::IncompletePacket {
1605                expected: length,
1606                actual: src.remaining(),
1607            });
1608        }
1609
1610        let data = src.copy_to_bytes(length);
1611
1612        Ok(Self { data })
1613    }
1614}
1615
1616// =============================================================================
1617// Token Parsing Implementation
1618// =============================================================================
1619
/// Bit masks for the 16-bit status field of the DONE family of tokens
/// (flag meanings per the MS-TDS DONE token definition).
mod done_status_bits {
    /// More results follow this token.
    pub const DONE_MORE: u16 = 0b0000_0000_0000_0001;
    /// An error occurred.
    pub const DONE_ERROR: u16 = 0b0000_0000_0000_0010;
    /// A transaction is in progress.
    pub const DONE_INXACT: u16 = 0b0000_0000_0000_0100;
    /// The row count field is valid.
    pub const DONE_COUNT: u16 = 0b0000_0000_0001_0000;
    /// Acknowledges a client attention signal.
    pub const DONE_ATTN: u16 = 0b0000_0000_0010_0000;
    /// A server error occurred.
    pub const DONE_SRVERROR: u16 = 0b0000_0001_0000_0000;
}
1629
1630impl DoneStatus {
1631    /// Parse done status from raw bits.
1632    #[must_use]
1633    pub fn from_bits(bits: u16) -> Self {
1634        use done_status_bits::*;
1635        Self {
1636            more: (bits & DONE_MORE) != 0,
1637            error: (bits & DONE_ERROR) != 0,
1638            in_xact: (bits & DONE_INXACT) != 0,
1639            count: (bits & DONE_COUNT) != 0,
1640            attn: (bits & DONE_ATTN) != 0,
1641            srverror: (bits & DONE_SRVERROR) != 0,
1642        }
1643    }
1644
1645    /// Convert to raw bits.
1646    #[must_use]
1647    pub fn to_bits(&self) -> u16 {
1648        use done_status_bits::*;
1649        let mut bits = 0u16;
1650        if self.more {
1651            bits |= DONE_MORE;
1652        }
1653        if self.error {
1654            bits |= DONE_ERROR;
1655        }
1656        if self.in_xact {
1657            bits |= DONE_INXACT;
1658        }
1659        if self.count {
1660            bits |= DONE_COUNT;
1661        }
1662        if self.attn {
1663            bits |= DONE_ATTN;
1664        }
1665        if self.srverror {
1666            bits |= DONE_SRVERROR;
1667        }
1668        bits
1669    }
1670}
1671
1672impl Done {
1673    /// Size of the DONE token in bytes (excluding token type byte).
1674    pub const SIZE: usize = 12; // 2 (status) + 2 (curcmd) + 8 (rowcount)
1675
1676    /// Decode a DONE token from bytes.
1677    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1678        if src.remaining() < Self::SIZE {
1679            return Err(ProtocolError::IncompletePacket {
1680                expected: Self::SIZE,
1681                actual: src.remaining(),
1682            });
1683        }
1684
1685        let status = DoneStatus::from_bits(src.get_u16_le());
1686        let cur_cmd = src.get_u16_le();
1687        let row_count = src.get_u64_le();
1688
1689        Ok(Self {
1690            status,
1691            cur_cmd,
1692            row_count,
1693        })
1694    }
1695
1696    /// Encode the DONE token to bytes.
1697    pub fn encode(&self, dst: &mut impl BufMut) {
1698        dst.put_u8(TokenType::Done as u8);
1699        dst.put_u16_le(self.status.to_bits());
1700        dst.put_u16_le(self.cur_cmd);
1701        dst.put_u64_le(self.row_count);
1702    }
1703
1704    /// Check if more results follow this DONE token.
1705    #[must_use]
1706    pub const fn has_more(&self) -> bool {
1707        self.status.more
1708    }
1709
1710    /// Check if an error occurred.
1711    #[must_use]
1712    pub const fn has_error(&self) -> bool {
1713        self.status.error
1714    }
1715
1716    /// Check if the row count is valid.
1717    #[must_use]
1718    pub const fn has_count(&self) -> bool {
1719        self.status.count
1720    }
1721}
1722
1723impl DoneProc {
1724    /// Size of the DONEPROC token in bytes (excluding token type byte).
1725    pub const SIZE: usize = 12;
1726
1727    /// Decode a DONEPROC token from bytes.
1728    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1729        if src.remaining() < Self::SIZE {
1730            return Err(ProtocolError::IncompletePacket {
1731                expected: Self::SIZE,
1732                actual: src.remaining(),
1733            });
1734        }
1735
1736        let status = DoneStatus::from_bits(src.get_u16_le());
1737        let cur_cmd = src.get_u16_le();
1738        let row_count = src.get_u64_le();
1739
1740        Ok(Self {
1741            status,
1742            cur_cmd,
1743            row_count,
1744        })
1745    }
1746
1747    /// Encode the DONEPROC token to bytes.
1748    pub fn encode(&self, dst: &mut impl BufMut) {
1749        dst.put_u8(TokenType::DoneProc as u8);
1750        dst.put_u16_le(self.status.to_bits());
1751        dst.put_u16_le(self.cur_cmd);
1752        dst.put_u64_le(self.row_count);
1753    }
1754}
1755
1756impl DoneInProc {
1757    /// Size of the DONEINPROC token in bytes (excluding token type byte).
1758    pub const SIZE: usize = 12;
1759
1760    /// Decode a DONEINPROC token from bytes.
1761    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1762        if src.remaining() < Self::SIZE {
1763            return Err(ProtocolError::IncompletePacket {
1764                expected: Self::SIZE,
1765                actual: src.remaining(),
1766            });
1767        }
1768
1769        let status = DoneStatus::from_bits(src.get_u16_le());
1770        let cur_cmd = src.get_u16_le();
1771        let row_count = src.get_u64_le();
1772
1773        Ok(Self {
1774            status,
1775            cur_cmd,
1776            row_count,
1777        })
1778    }
1779
1780    /// Encode the DONEINPROC token to bytes.
1781    pub fn encode(&self, dst: &mut impl BufMut) {
1782        dst.put_u8(TokenType::DoneInProc as u8);
1783        dst.put_u16_le(self.status.to_bits());
1784        dst.put_u16_le(self.cur_cmd);
1785        dst.put_u64_le(self.row_count);
1786    }
1787}
1788
1789impl ServerError {
1790    /// Decode an ERROR token from bytes.
1791    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1792        // ERROR token: length (2) + number (4) + state (1) + class (1) +
1793        //              message (us_varchar) + server (b_varchar) + procedure (b_varchar) + line (4)
1794        if src.remaining() < 2 {
1795            return Err(ProtocolError::UnexpectedEof);
1796        }
1797
1798        let _length = src.get_u16_le();
1799
1800        if src.remaining() < 6 {
1801            return Err(ProtocolError::UnexpectedEof);
1802        }
1803
1804        let number = src.get_i32_le();
1805        let state = src.get_u8();
1806        let class = src.get_u8();
1807
1808        let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1809        let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1810        let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1811
1812        if src.remaining() < 4 {
1813            return Err(ProtocolError::UnexpectedEof);
1814        }
1815        let line = src.get_i32_le();
1816
1817        Ok(Self {
1818            number,
1819            state,
1820            class,
1821            message,
1822            server,
1823            procedure,
1824            line,
1825        })
1826    }
1827
1828    /// Check if this is a fatal error (severity >= 20).
1829    #[must_use]
1830    pub const fn is_fatal(&self) -> bool {
1831        self.class >= 20
1832    }
1833
1834    /// Check if this error indicates the batch was aborted (severity >= 16).
1835    #[must_use]
1836    pub const fn is_batch_abort(&self) -> bool {
1837        self.class >= 16
1838    }
1839}
1840
1841impl ServerInfo {
1842    /// Decode an INFO token from bytes.
1843    ///
1844    /// INFO tokens have the same structure as ERROR tokens but with lower severity.
1845    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1846        if src.remaining() < 2 {
1847            return Err(ProtocolError::UnexpectedEof);
1848        }
1849
1850        let _length = src.get_u16_le();
1851
1852        if src.remaining() < 6 {
1853            return Err(ProtocolError::UnexpectedEof);
1854        }
1855
1856        let number = src.get_i32_le();
1857        let state = src.get_u8();
1858        let class = src.get_u8();
1859
1860        let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1861        let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1862        let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1863
1864        if src.remaining() < 4 {
1865            return Err(ProtocolError::UnexpectedEof);
1866        }
1867        let line = src.get_i32_le();
1868
1869        Ok(Self {
1870            number,
1871            state,
1872            class,
1873            message,
1874            server,
1875            procedure,
1876            line,
1877        })
1878    }
1879}
1880
1881impl LoginAck {
1882    /// Decode a LOGINACK token from bytes.
1883    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1884        // LOGINACK: length (2) + interface (1) + tds_version (4) + prog_name (b_varchar) + prog_version (4)
1885        if src.remaining() < 2 {
1886            return Err(ProtocolError::UnexpectedEof);
1887        }
1888
1889        let _length = src.get_u16_le();
1890
1891        if src.remaining() < 5 {
1892            return Err(ProtocolError::UnexpectedEof);
1893        }
1894
1895        let interface = src.get_u8();
1896        let tds_version = src.get_u32_le();
1897        let prog_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1898
1899        if src.remaining() < 4 {
1900            return Err(ProtocolError::UnexpectedEof);
1901        }
1902        let prog_version = src.get_u32_le();
1903
1904        Ok(Self {
1905            interface,
1906            tds_version,
1907            prog_name,
1908            prog_version,
1909        })
1910    }
1911
1912    /// Get the TDS version as a `TdsVersion`.
1913    #[must_use]
1914    pub fn tds_version(&self) -> crate::version::TdsVersion {
1915        crate::version::TdsVersion::new(self.tds_version)
1916    }
1917}
1918
1919impl EnvChangeType {
1920    /// Create from raw byte value.
1921    pub fn from_u8(value: u8) -> Option<Self> {
1922        match value {
1923            1 => Some(Self::Database),
1924            2 => Some(Self::Language),
1925            3 => Some(Self::CharacterSet),
1926            4 => Some(Self::PacketSize),
1927            5 => Some(Self::UnicodeSortingLocalId),
1928            6 => Some(Self::UnicodeComparisonFlags),
1929            7 => Some(Self::SqlCollation),
1930            8 => Some(Self::BeginTransaction),
1931            9 => Some(Self::CommitTransaction),
1932            10 => Some(Self::RollbackTransaction),
1933            11 => Some(Self::EnlistDtcTransaction),
1934            12 => Some(Self::DefectTransaction),
1935            13 => Some(Self::RealTimeLogShipping),
1936            15 => Some(Self::PromoteTransaction),
1937            16 => Some(Self::TransactionManagerAddress),
1938            17 => Some(Self::TransactionEnded),
1939            18 => Some(Self::ResetConnectionCompletionAck),
1940            19 => Some(Self::UserInstanceStarted),
1941            20 => Some(Self::Routing),
1942            _ => None,
1943        }
1944    }
1945}
1946
1947impl EnvChange {
1948    /// Decode an ENVCHANGE token from bytes.
1949    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1950        if src.remaining() < 3 {
1951            return Err(ProtocolError::UnexpectedEof);
1952        }
1953
1954        let length = src.get_u16_le() as usize;
1955        if length == 0 {
1956            // The frame must at least contain the type byte; reading it from
1957            // outside a zero-length frame would consume the next token.
1958            return Err(ProtocolError::UnexpectedEof);
1959        }
1960        if src.remaining() < length {
1961            return Err(ProtocolError::IncompletePacket {
1962                expected: length,
1963                actual: src.remaining(),
1964            });
1965        }
1966
1967        // Frame-strict decoding (issue #145): the value decoders below only
1968        // bounds-check against the *buffer*, so on an under-declared frame
1969        // they could read past the declared length into the next token's
1970        // bytes. Slice exactly the declared frame and decode from that:
1971        // over-read attempts now hit frame end and take the lenient
1972        // empty-value fallbacks (preserving the #140 hostile-input
1973        // behavior), and the outer buffer always advances by exactly
1974        // `length`.
1975        let mut frame = src.copy_to_bytes(length);
1976        let src = &mut frame;
1977
1978        let env_type_byte = src.get_u8();
1979        let env_type = EnvChangeType::from_u8(env_type_byte)
1980            .ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
1981
1982        let (new_value, old_value) = match env_type {
1983            EnvChangeType::Routing => {
1984                // Routing has special format
1985                let new_value = Self::decode_routing_value(src)?;
1986                let old_value = EnvChangeValue::Binary(Bytes::new());
1987                (new_value, old_value)
1988            }
1989            EnvChangeType::BeginTransaction
1990            | EnvChangeType::CommitTransaction
1991            | EnvChangeType::RollbackTransaction
1992            | EnvChangeType::EnlistDtcTransaction
1993            | EnvChangeType::SqlCollation => {
1994                // These use binary format per MS-TDS spec:
1995                // - Transaction tokens: transaction descriptor (8 bytes)
1996                // - SqlCollation: collation info (5 bytes: LCID + sort flags)
1997                // The declared ENVCHANGE `length` can be shorter than this
1998                // branch needs (e.g. covers only the type byte), so the
1999                // length-prefix reads must be bounds-checked individually:
2000                // `get_u8` on an empty buffer panics. Match the branch's
2001                // existing graceful style — a missing prefix means empty.
2002                let new_len = if src.has_remaining() {
2003                    src.get_u8() as usize
2004                } else {
2005                    0
2006                };
2007                let new_value = if new_len > 0 && src.remaining() >= new_len {
2008                    EnvChangeValue::Binary(src.copy_to_bytes(new_len))
2009                } else {
2010                    EnvChangeValue::Binary(Bytes::new())
2011                };
2012
2013                let old_len = if src.has_remaining() {
2014                    src.get_u8() as usize
2015                } else {
2016                    0
2017                };
2018                let old_value = if old_len > 0 && src.remaining() >= old_len {
2019                    EnvChangeValue::Binary(src.copy_to_bytes(old_len))
2020                } else {
2021                    EnvChangeValue::Binary(Bytes::new())
2022                };
2023
2024                (new_value, old_value)
2025            }
2026            _ => {
2027                // String format for most env changes
2028                let new_value = read_b_varchar(src)
2029                    .map(EnvChangeValue::String)
2030                    .unwrap_or(EnvChangeValue::String(String::new()));
2031
2032                let old_value = read_b_varchar(src)
2033                    .map(EnvChangeValue::String)
2034                    .unwrap_or(EnvChangeValue::String(String::new()));
2035
2036                (new_value, old_value)
2037            }
2038        };
2039
2040        // No frame-boundary fixup needed: the whole declared frame was
2041        // consumed from the outer buffer up front, so decoders that
2042        // under-consume (e.g. Routing's implicit zero-length OldValue) just
2043        // leave bytes behind in the dropped sub-frame.
2044
2045        Ok(Self {
2046            env_type,
2047            new_value,
2048            old_value,
2049        })
2050    }
2051
2052    fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
2053        // Routing format: length (2) + protocol (1) + port (2) + server_len (2) + server (utf16)
2054        if src.remaining() < 2 {
2055            return Err(ProtocolError::UnexpectedEof);
2056        }
2057
2058        let _routing_len = src.get_u16_le();
2059
2060        if src.remaining() < 5 {
2061            return Err(ProtocolError::UnexpectedEof);
2062        }
2063
2064        let _protocol = src.get_u8();
2065        let port = src.get_u16_le();
2066        let server_len = src.get_u16_le() as usize;
2067
2068        // Read UTF-16LE server name
2069        if src.remaining() < server_len * 2 {
2070            return Err(ProtocolError::UnexpectedEof);
2071        }
2072
2073        let mut chars = Vec::with_capacity(server_len);
2074        for _ in 0..server_len {
2075            chars.push(src.get_u16_le());
2076        }
2077
2078        let host = String::from_utf16(&chars).map_err(|_| {
2079            ProtocolError::StringEncoding(
2080                #[cfg(feature = "std")]
2081                "invalid UTF-16 in routing hostname".to_string(),
2082                #[cfg(not(feature = "std"))]
2083                "invalid UTF-16 in routing hostname",
2084            )
2085        })?;
2086
2087        Ok(EnvChangeValue::Routing { host, port })
2088    }
2089
2090    /// Check if this is a routing redirect.
2091    #[must_use]
2092    pub fn is_routing(&self) -> bool {
2093        self.env_type == EnvChangeType::Routing
2094    }
2095
2096    /// Get routing information if this is a routing change.
2097    #[must_use]
2098    pub fn routing_info(&self) -> Option<(&str, u16)> {
2099        if let EnvChangeValue::Routing { host, port } = &self.new_value {
2100            Some((host, *port))
2101        } else {
2102            None
2103        }
2104    }
2105
2106    /// Get the new database name if this is a database change.
2107    #[must_use]
2108    pub fn new_database(&self) -> Option<&str> {
2109        if self.env_type == EnvChangeType::Database {
2110            if let EnvChangeValue::String(s) = &self.new_value {
2111                return Some(s);
2112            }
2113        }
2114        None
2115    }
2116}
2117
2118impl Order {
2119    /// Decode an ORDER token from bytes.
2120    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2121        if src.remaining() < 2 {
2122            return Err(ProtocolError::UnexpectedEof);
2123        }
2124
2125        let length = src.get_u16_le() as usize;
2126        let column_count = length / 2;
2127
2128        if src.remaining() < length {
2129            return Err(ProtocolError::IncompletePacket {
2130                expected: length,
2131                actual: src.remaining(),
2132            });
2133        }
2134
2135        let mut columns = Vec::with_capacity(column_count);
2136        for _ in 0..column_count {
2137            columns.push(src.get_u16_le());
2138        }
2139
2140        Ok(Self { columns })
2141    }
2142}
2143
2144impl FeatureExtAck {
2145    /// Feature terminator byte.
2146    pub const TERMINATOR: u8 = 0xFF;
2147
2148    /// Decode a FEATUREEXTACK token from bytes.
2149    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2150        let mut features = Vec::new();
2151
2152        loop {
2153            if !src.has_remaining() {
2154                return Err(ProtocolError::UnexpectedEof);
2155            }
2156
2157            let feature_id = src.get_u8();
2158            if feature_id == Self::TERMINATOR {
2159                break;
2160            }
2161
2162            if src.remaining() < 4 {
2163                return Err(ProtocolError::UnexpectedEof);
2164            }
2165
2166            let data_len = src.get_u32_le() as usize;
2167
2168            if src.remaining() < data_len {
2169                return Err(ProtocolError::IncompletePacket {
2170                    expected: data_len,
2171                    actual: src.remaining(),
2172                });
2173            }
2174
2175            let data = src.copy_to_bytes(data_len);
2176            features.push(FeatureAck { feature_id, data });
2177        }
2178
2179        Ok(Self { features })
2180    }
2181}
2182
2183impl SspiToken {
2184    /// Decode an SSPI token from bytes.
2185    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2186        if src.remaining() < 2 {
2187            return Err(ProtocolError::UnexpectedEof);
2188        }
2189
2190        let length = src.get_u16_le() as usize;
2191
2192        if src.remaining() < length {
2193            return Err(ProtocolError::IncompletePacket {
2194                expected: length,
2195                actual: src.remaining(),
2196            });
2197        }
2198
2199        let data = src.copy_to_bytes(length);
2200        Ok(Self { data })
2201    }
2202}
2203
2204impl FedAuthInfo {
2205    /// `FedAuthInfoID` for the STS URL (MS-TDS §2.2.7.12: %0x01 = STSURL).
2206    const ID_STSURL: u8 = 0x01;
2207    /// `FedAuthInfoID` for the service principal name (MS-TDS §2.2.7.12:
2208    /// %0x02 = SPN).
2209    const ID_SPN: u8 = 0x02;
2210    /// Size of one `FedAuthInfoOpt` header: ID (1) + DataLen (4) + DataOffset (4).
2211    const OPT_HEADER_LEN: usize = 9;
2212
2213    /// Decode a FEDAUTHINFO token from bytes.
2214    ///
2215    /// Wire layout per MS-TDS §2.2.7.12 (after the 0xEE token byte):
2216    /// `TokenLength` (DWORD) covering everything that follows, then
2217    /// `CountOfInfoIDs` (DWORD), then `CountOfInfoIDs` option headers of
2218    /// ID (BYTE) + `FedAuthInfoDataLen` (DWORD) + `FedAuthInfoDataOffset`
2219    /// (DWORD), then the data block. Offsets are relative to the start of
2220    /// the `CountOfInfoIDs` field, and the option data is UTF-16LE.
2221    ///
2222    /// Exactly `TokenLength` bytes are consumed, so tokens that follow
2223    /// FEDAUTHINFO in the login stream (LOGINACK, DONE) are preserved.
2224    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2225        if src.remaining() < 4 {
2226            return Err(ProtocolError::UnexpectedEof);
2227        }
2228        let token_len = src.get_u32_le() as usize;
2229        if src.remaining() < token_len {
2230            return Err(ProtocolError::UnexpectedEof);
2231        }
2232
2233        // Offsets in the option headers are relative to the start of this
2234        // region (the CountOfInfoIDs field), so address into it directly.
2235        let region = src.copy_to_bytes(token_len);
2236        if region.len() < 4 {
2237            return Err(ProtocolError::UnexpectedEof);
2238        }
2239        let count = u32::from_le_bytes([region[0], region[1], region[2], region[3]]) as usize;
2240
2241        // All headers must fit between the count field and the end of the
2242        // token. The checked math also rejects hostile counts that would
2243        // overflow the offset arithmetic.
2244        let headers_end = count
2245            .checked_mul(Self::OPT_HEADER_LEN)
2246            .and_then(|n| n.checked_add(4))
2247            .ok_or(ProtocolError::UnexpectedEof)?;
2248        if headers_end > region.len() {
2249            return Err(ProtocolError::UnexpectedEof);
2250        }
2251
2252        let mut sts_url = String::new();
2253        let mut spn = String::new();
2254
2255        for i in 0..count {
2256            let h = 4 + i * Self::OPT_HEADER_LEN;
2257            let info_id = region[h];
2258            let data_len =
2259                u32::from_le_bytes([region[h + 1], region[h + 2], region[h + 3], region[h + 4]])
2260                    as usize;
2261            let data_off =
2262                u32::from_le_bytes([region[h + 5], region[h + 6], region[h + 7], region[h + 8]])
2263                    as usize;
2264
2265            // Unknown IDs are skipped without validating their data, per the
2266            // spec's instruction to ignore unrecognized options.
2267            if info_id != Self::ID_SPN && info_id != Self::ID_STSURL {
2268                continue;
2269            }
2270
2271            let data_end = data_off
2272                .checked_add(data_len)
2273                .ok_or(ProtocolError::UnexpectedEof)?;
2274            if data_end > region.len() {
2275                return Err(ProtocolError::UnexpectedEof);
2276            }
2277            if data_len % 2 != 0 {
2278                return Err(ProtocolError::StringEncoding(
2279                    #[cfg(feature = "std")]
2280                    "FEDAUTHINFO option data has odd length, not UTF-16".to_string(),
2281                    #[cfg(not(feature = "std"))]
2282                    "FEDAUTHINFO option data has odd length, not UTF-16",
2283                ));
2284            }
2285
2286            let chars: Vec<u16> = region[data_off..data_end]
2287                .chunks_exact(2)
2288                .map(|b| u16::from_le_bytes([b[0], b[1]]))
2289                .collect();
2290            let value = String::from_utf16(&chars).map_err(|_| {
2291                ProtocolError::StringEncoding(
2292                    #[cfg(feature = "std")]
2293                    "invalid UTF-16 in FEDAUTHINFO option".to_string(),
2294                    #[cfg(not(feature = "std"))]
2295                    "invalid UTF-16 in FEDAUTHINFO option",
2296                )
2297            })?;
2298
2299            if info_id == Self::ID_SPN {
2300                spn = value;
2301            } else {
2302                sts_url = value;
2303            }
2304        }
2305
2306        Ok(Self { sts_url, spn })
2307    }
2308}
2309
2310// =============================================================================
2311// Token Parser
2312// =============================================================================
2313
2314/// Token stream parser.
2315///
2316/// Parses a stream of TDS tokens from a byte buffer.
2317///
2318/// # Basic vs Context-Aware Parsing
2319///
2320/// Some tokens (like `Done`, `Error`, `LoginAck`) can be parsed without context.
2321/// Use [`next_token()`](TokenParser::next_token) for these.
2322///
2323/// Other tokens (like `ColMetaData`, `Row`, `NbcRow`) require column metadata
2324/// to parse correctly. Use [`next_token_with_metadata()`](TokenParser::next_token_with_metadata)
2325/// for these.
2326///
2327/// # Example
2328///
2329/// ```rust,ignore
2330/// let mut parser = TokenParser::new(data);
2331/// let mut metadata = None;
2332///
2333/// while let Some(token) = parser.next_token_with_metadata(metadata.as_ref())? {
2334///     match token {
2335///         Token::ColMetaData(meta) => {
2336///             metadata = Some(meta);
2337///         }
2338///         Token::Row(row) => {
2339///             // Process row using metadata
2340///         }
2341///         Token::Done(done) => {
2342///             if !done.has_more() {
2343///                 break;
2344///             }
2345///         }
2346///         _ => {}
2347///     }
2348/// }
2349/// ```
2350pub struct TokenParser {
2351    data: Bytes,
2352    position: usize,
2353    /// Whether Always Encrypted was negotiated for this connection.
2354    /// When true, ColMetaData tokens are parsed with CekTable and per-column CryptoMetadata.
2355    encryption_enabled: bool,
2356}
2357
2358impl TokenParser {
2359    /// Create a new token parser from bytes.
2360    #[must_use]
2361    pub fn new(data: Bytes) -> Self {
2362        Self {
2363            data,
2364            position: 0,
2365            encryption_enabled: false,
2366        }
2367    }
2368
2369    /// Enable Always Encrypted metadata parsing.
2370    ///
2371    /// When enabled, ColMetaData tokens are parsed using the encrypted format
2372    /// which includes a CekTable and per-column CryptoMetadata.
2373    #[must_use]
2374    pub fn with_encryption(mut self, enabled: bool) -> Self {
2375        self.encryption_enabled = enabled;
2376        self
2377    }
2378
2379    /// Get remaining bytes in the buffer.
2380    #[must_use]
2381    pub fn remaining(&self) -> usize {
2382        self.data.len().saturating_sub(self.position)
2383    }
2384
2385    /// Check if there are more bytes to parse.
2386    #[must_use]
2387    pub fn has_remaining(&self) -> bool {
2388        self.position < self.data.len()
2389    }
2390
2391    /// Peek at the next token type without consuming it.
2392    #[must_use]
2393    pub fn peek_token_type(&self) -> Option<TokenType> {
2394        if self.position < self.data.len() {
2395            TokenType::from_u8(self.data[self.position])
2396        } else {
2397            None
2398        }
2399    }
2400
2401    /// Parse the next token from the stream.
2402    ///
2403    /// This method can only parse context-independent tokens. For tokens that
2404    /// require column metadata (ColMetaData, Row, NbcRow), use
2405    /// [`next_token_with_metadata()`](TokenParser::next_token_with_metadata).
2406    ///
2407    /// Returns `None` if no more tokens are available.
2408    pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
2409        self.next_token_with_metadata(None)
2410    }
2411
2412    /// Parse the next token with optional column metadata context.
2413    ///
2414    /// When `metadata` is provided, this method can parse Row and NbcRow tokens.
2415    /// Without metadata, those tokens will return an error.
2416    ///
2417    /// Returns `None` if no more tokens are available.
2418    pub fn next_token_with_metadata(
2419        &mut self,
2420        metadata: Option<&ColMetaData>,
2421    ) -> Result<Option<Token>, ProtocolError> {
2422        loop {
2423            if !self.has_remaining() {
2424                return Ok(None);
2425            }
2426
2427            let mut buf = &self.data[self.position..];
2428            let start_pos = self.position;
2429
2430            let token_type_byte = buf.get_u8();
2431            let token_type = TokenType::from_u8(token_type_byte);
2432
2433            let token = match token_type {
2434                Some(TokenType::Done) => {
2435                    let done = Done::decode(&mut buf)?;
2436                    Token::Done(done)
2437                }
2438                Some(TokenType::DoneProc) => {
2439                    let done = DoneProc::decode(&mut buf)?;
2440                    Token::DoneProc(done)
2441                }
2442                Some(TokenType::DoneInProc) => {
2443                    let done = DoneInProc::decode(&mut buf)?;
2444                    Token::DoneInProc(done)
2445                }
2446                Some(TokenType::Error) => {
2447                    let error = ServerError::decode(&mut buf)?;
2448                    Token::Error(error)
2449                }
2450                Some(TokenType::Info) => {
2451                    let info = ServerInfo::decode(&mut buf)?;
2452                    Token::Info(info)
2453                }
2454                Some(TokenType::LoginAck) => {
2455                    let login_ack = LoginAck::decode(&mut buf)?;
2456                    Token::LoginAck(login_ack)
2457                }
2458                Some(TokenType::EnvChange) => {
2459                    let env_change = EnvChange::decode(&mut buf)?;
2460                    Token::EnvChange(env_change)
2461                }
2462                Some(TokenType::Order) => {
2463                    let order = Order::decode(&mut buf)?;
2464                    Token::Order(order)
2465                }
2466                Some(TokenType::FeatureExtAck) => {
2467                    let ack = FeatureExtAck::decode(&mut buf)?;
2468                    Token::FeatureExtAck(ack)
2469                }
2470                Some(TokenType::Sspi) => {
2471                    let sspi = SspiToken::decode(&mut buf)?;
2472                    Token::Sspi(sspi)
2473                }
2474                Some(TokenType::FedAuthInfo) => {
2475                    let info = FedAuthInfo::decode(&mut buf)?;
2476                    Token::FedAuthInfo(info)
2477                }
2478                Some(TokenType::ReturnStatus) => {
2479                    if buf.remaining() < 4 {
2480                        return Err(ProtocolError::UnexpectedEof);
2481                    }
2482                    let status = buf.get_i32_le();
2483                    Token::ReturnStatus(status)
2484                }
2485                Some(TokenType::ColMetaData) => {
2486                    let col_meta = if self.encryption_enabled {
2487                        ColMetaData::decode_encrypted(&mut buf)?
2488                    } else {
2489                        ColMetaData::decode(&mut buf)?
2490                    };
2491                    Token::ColMetaData(col_meta)
2492                }
2493                Some(TokenType::Row) => {
2494                    let meta = metadata.ok_or_else(|| {
2495                        ProtocolError::StringEncoding(
2496                            #[cfg(feature = "std")]
2497                            "Row token requires column metadata".to_string(),
2498                            #[cfg(not(feature = "std"))]
2499                            "Row token requires column metadata",
2500                        )
2501                    })?;
2502                    let row = RawRow::decode(&mut buf, meta)?;
2503                    Token::Row(row)
2504                }
2505                Some(TokenType::NbcRow) => {
2506                    let meta = metadata.ok_or_else(|| {
2507                        ProtocolError::StringEncoding(
2508                            #[cfg(feature = "std")]
2509                            "NbcRow token requires column metadata".to_string(),
2510                            #[cfg(not(feature = "std"))]
2511                            "NbcRow token requires column metadata",
2512                        )
2513                    })?;
2514                    let row = NbcRow::decode(&mut buf, meta)?;
2515                    Token::NbcRow(row)
2516                }
2517                Some(TokenType::ReturnValue) => {
2518                    let ret_val = ReturnValue::decode(&mut buf)?;
2519                    Token::ReturnValue(ret_val)
2520                }
2521                Some(TokenType::SessionState) => {
2522                    let session = SessionState::decode(&mut buf)?;
2523                    Token::SessionState(session)
2524                }
2525                Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
2526                    // These tokens are rarely used and have complex formats.
2527                    // Skip them by reading the length and advancing.
2528                    if buf.remaining() < 2 {
2529                        return Err(ProtocolError::UnexpectedEof);
2530                    }
2531                    let length = buf.get_u16_le() as usize;
2532                    if buf.remaining() < length {
2533                        return Err(ProtocolError::IncompletePacket {
2534                            expected: length,
2535                            actual: buf.remaining(),
2536                        });
2537                    }
2538                    // Skip the data
2539                    buf.advance(length);
2540                    // #273: advance past the skipped token and iterate. The skip
2541                    // path must NOT recurse — a server-controlled flat run of these
2542                    // tokens would otherwise add one stack frame per token and
2543                    // overflow the stack (remote DoS).
2544                    self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
2545                    continue;
2546                }
2547                None => {
2548                    return Err(ProtocolError::InvalidTokenType(token_type_byte));
2549                }
2550            };
2551
2552            // Update position based on how much was consumed
2553            let consumed = self.data.len() - start_pos - buf.remaining();
2554            self.position = start_pos + consumed;
2555
2556            return Ok(Some(token));
2557        }
2558    }
2559
2560    /// Skip the current token without fully parsing it.
2561    ///
2562    /// This is useful for skipping unknown or uninteresting tokens.
2563    pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
2564        if !self.has_remaining() {
2565            return Ok(());
2566        }
2567
2568        let token_type_byte = self.data[self.position];
2569        let token_type = TokenType::from_u8(token_type_byte);
2570
2571        // Calculate how many bytes to skip based on token type
2572        let skip_amount = match token_type {
2573            // Fixed-size tokens
2574            Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
2575                1 + Done::SIZE // token type + 12 bytes
2576            }
2577            Some(TokenType::ReturnStatus) => {
2578                1 + 4 // token type + 4 bytes
2579            }
2580            // Variable-length tokens with 2-byte length prefix
2581            Some(TokenType::Error)
2582            | Some(TokenType::Info)
2583            | Some(TokenType::LoginAck)
2584            | Some(TokenType::EnvChange)
2585            | Some(TokenType::Order)
2586            | Some(TokenType::Sspi)
2587            | Some(TokenType::ColInfo)
2588            | Some(TokenType::TabName)
2589            | Some(TokenType::Offset)
2590            | Some(TokenType::ReturnValue) => {
2591                if self.remaining() < 3 {
2592                    return Err(ProtocolError::UnexpectedEof);
2593                }
2594                let length = u16::from_le_bytes([
2595                    self.data[self.position + 1],
2596                    self.data[self.position + 2],
2597                ]) as usize;
2598                1 + 2 + length // token type + length prefix + data
2599            }
2600            // Tokens with 4-byte length prefix
2601            Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
2602                if self.remaining() < 5 {
2603                    return Err(ProtocolError::UnexpectedEof);
2604                }
2605                let length = u32::from_le_bytes([
2606                    self.data[self.position + 1],
2607                    self.data[self.position + 2],
2608                    self.data[self.position + 3],
2609                    self.data[self.position + 4],
2610                ]) as usize;
2611                1 + 4 + length
2612            }
2613            // FeatureExtAck has no length prefix - must parse
2614            Some(TokenType::FeatureExtAck) => {
2615                // Parse to find end
2616                let mut buf = &self.data[self.position + 1..];
2617                let _ = FeatureExtAck::decode(&mut buf)?;
2618                self.data.len() - self.position - buf.remaining()
2619            }
2620            // ColMetaData, Row, NbcRow require context and can't be easily skipped
2621            Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
2622                return Err(ProtocolError::InvalidTokenType(token_type_byte));
2623            }
2624            None => {
2625                return Err(ProtocolError::InvalidTokenType(token_type_byte));
2626            }
2627        };
2628
2629        if self.remaining() < skip_amount {
2630            return Err(ProtocolError::UnexpectedEof);
2631        }
2632
2633        self.position += skip_amount;
2634        Ok(())
2635    }
2636
2637    /// Get the current position in the buffer.
2638    #[must_use]
2639    pub fn position(&self) -> usize {
2640        self.position
2641    }
2642
2643    /// Reset the parser to the beginning.
2644    pub fn reset(&mut self) {
2645        self.position = 0;
2646    }
2647}
2648
2649// =============================================================================
2650// Tests
2651// =============================================================================
2652
2653#[cfg(test)]
2654#[allow(clippy::unwrap_used, clippy::panic)]
2655mod tests {
2656    use super::*;
2657    use bytes::BytesMut;
2658
2659    #[test]
2660    fn test_done_roundtrip() {
2661        let done = Done {
2662            status: DoneStatus {
2663                more: false,
2664                error: false,
2665                in_xact: false,
2666                count: true,
2667                attn: false,
2668                srverror: false,
2669            },
2670            cur_cmd: 193, // SELECT
2671            row_count: 42,
2672        };
2673
2674        let mut buf = BytesMut::new();
2675        done.encode(&mut buf);
2676
2677        // Skip the token type byte
2678        let mut cursor = &buf[1..];
2679        let decoded = Done::decode(&mut cursor).unwrap();
2680
2681        assert_eq!(decoded.status.count, done.status.count);
2682        assert_eq!(decoded.cur_cmd, done.cur_cmd);
2683        assert_eq!(decoded.row_count, done.row_count);
2684    }
2685
2686    #[test]
2687    fn test_done_status_bits() {
2688        let status = DoneStatus {
2689            more: true,
2690            error: true,
2691            in_xact: true,
2692            count: true,
2693            attn: false,
2694            srverror: false,
2695        };
2696
2697        let bits = status.to_bits();
2698        let restored = DoneStatus::from_bits(bits);
2699
2700        assert_eq!(status.more, restored.more);
2701        assert_eq!(status.error, restored.error);
2702        assert_eq!(status.in_xact, restored.in_xact);
2703        assert_eq!(status.count, restored.count);
2704    }
2705
2706    #[test]
2707    fn test_token_parser_done() {
2708        // DONE token: type (1) + status (2) + curcmd (2) + rowcount (8)
2709        let data = Bytes::from_static(&[
2710            0xFD, // DONE token type
2711            0x10, 0x00, // status: DONE_COUNT
2712            0xC1, 0x00, // cur_cmd: 193 (SELECT)
2713            0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // row_count: 5
2714        ]);
2715
2716        let mut parser = TokenParser::new(data);
2717        let token = parser.next_token().unwrap().unwrap();
2718
2719        match token {
2720            Token::Done(done) => {
2721                assert!(done.status.count);
2722                assert!(!done.status.more);
2723                assert_eq!(done.cur_cmd, 193);
2724                assert_eq!(done.row_count, 5);
2725            }
2726            _ => panic!("Expected Done token"),
2727        }
2728
2729        // No more tokens
2730        assert!(parser.next_token().unwrap().is_none());
2731    }
2732
2733    #[test]
2734    fn test_env_change_type_from_u8() {
2735        assert_eq!(EnvChangeType::from_u8(1), Some(EnvChangeType::Database));
2736        assert_eq!(EnvChangeType::from_u8(20), Some(EnvChangeType::Routing));
2737        assert_eq!(EnvChangeType::from_u8(100), None);
2738    }
2739
2740    /// A spec-faithful Routing ENVCHANGE (MS-TDS 2.2.7.9) carries a
2741    /// zero-length OldValue (two bytes) after the routing data. The decoder
2742    /// reads only the NewValue, so it must skip to the declared frame
2743    /// boundary — otherwise the leftover `00 00` is misparsed as the next
2744    /// token type and the rest of the login response is garbage. Azure SQL
2745    /// Gateway redirects send exactly this shape.
2746    #[test]
2747    fn test_env_change_routing_consumes_declared_length() {
2748        let host = "redirect.example";
2749        let host_utf16: Vec<u16> = host.encode_utf16().collect();
2750
2751        let mut data = BytesMut::new();
2752        // RoutingDataValue: length + protocol + port + server_len + server
2753        let routing_len = 1 + 2 + 2 + host_utf16.len() * 2;
2754        // ENVCHANGE length: type byte + routing length prefix + routing data
2755        // + zero-length OldValue
2756        let env_len = 1 + 2 + routing_len + 2;
2757        data.put_u16_le(env_len as u16);
2758        data.put_u8(20); // Routing
2759        data.put_u16_le(routing_len as u16);
2760        data.put_u8(0); // protocol: TCP
2761        data.put_u16_le(11000); // port
2762        data.put_u16_le(host_utf16.len() as u16);
2763        for c in &host_utf16 {
2764            data.put_u16_le(*c);
2765        }
2766        data.put_u16_le(0); // OldValue: zero-length US_VARBYTE
2767        // A trailing DONE token type byte that must remain for the next read.
2768        data.put_u8(0xFD);
2769
2770        let mut buf: &[u8] = &data;
2771        let env = EnvChange::decode(&mut buf).unwrap();
2772        assert_eq!(env.routing_info(), Some((host, 11000)));
2773        assert_eq!(
2774            buf,
2775            &[0xFD],
2776            "decode must consume exactly the declared ENVCHANGE frame"
2777        );
2778    }
2779
2780    fn put_b_varchar(buf: &mut BytesMut, s: &str) {
2781        let utf16: Vec<u16> = s.encode_utf16().collect();
2782        buf.put_u8(utf16.len() as u8);
2783        for c in utf16 {
2784            buf.put_u16_le(c);
2785        }
2786    }
2787
2788    fn put_us_varchar(buf: &mut BytesMut, s: &str) {
2789        let utf16: Vec<u16> = s.encode_utf16().collect();
2790        buf.put_u16_le(utf16.len() as u16);
2791        for c in utf16 {
2792            buf.put_u16_le(c);
2793        }
2794    }
2795
2796    /// UDT_INFO regression (issue #154): per MS-TDS, DB_NAME, SCHEMA_NAME,
2797    /// and TYPE_NAME are B_VARCHAR (1-byte length); only
2798    /// ASSEMBLY_QUALIFIED_NAME is US_VARCHAR. Reading all four as US_VARCHAR
2799    /// misaligned the stream, so every query selecting a UDT column
2800    /// (geography, geometry, hierarchyid, CLR UDTs) failed with
2801    /// UnexpectedEof.
2802    #[test]
2803    fn test_udt_info_metadata_uses_b_varchar_names() {
2804        let mut data = BytesMut::new();
2805        data.put_u16_le(0xFFFF); // MAX_BYTE_SIZE
2806        put_b_varchar(&mut data, "master");
2807        put_b_varchar(&mut data, "dbo");
2808        put_b_varchar(&mut data, "hierarchyid");
2809        put_us_varchar(
2810            &mut data,
2811            "Microsoft.SqlServer.Types.SqlHierarchyId, Microsoft.SqlServer.Types",
2812        );
2813        // A trailing token type byte that must remain for the next read.
2814        data.put_u8(0xFD);
2815
2816        let mut buf: &[u8] = &data;
2817        let info = decode_type_info(&mut buf, TypeId::Udt, TypeId::Udt as u8).unwrap();
2818        assert_eq!(info.max_length, Some(0xFFFF));
2819        assert_eq!(
2820            buf,
2821            &[0xFD],
2822            "decode must consume exactly the UDT_INFO frame"
2823        );
2824    }
2825
2826    /// XML_INFO regression (issue #154): per MS-TDS §2.2.5.5.3, DBNAME and
2827    /// OWNING_SCHEMA are B_VARCHAR; only XML_SCHEMA_COLLECTION is US_VARCHAR.
2828    /// Schema-bound xml columns (SCHEMA_PRESENT=1) previously misparsed.
2829    #[test]
2830    fn test_xml_info_schema_bound_uses_b_varchar_names() {
2831        let mut data = BytesMut::new();
2832        data.put_u8(1); // SCHEMA_PRESENT
2833        put_b_varchar(&mut data, "master");
2834        put_b_varchar(&mut data, "dbo");
2835        put_us_varchar(&mut data, "MyXmlSchemaCollection");
2836        data.put_u8(0xFD);
2837
2838        let mut buf: &[u8] = &data;
2839        decode_type_info(&mut buf, TypeId::Xml, TypeId::Xml as u8).unwrap();
2840        assert_eq!(
2841            buf,
2842            &[0xFD],
2843            "decode must consume exactly the XML_INFO frame"
2844        );
2845    }
2846
2847    #[test]
2848    fn hostile_env_change_binary_truncated_is_not_panic() {
2849        // length=1 covers only the type byte (0x08 = BeginTransaction, a
2850        // binary-format type); the new_len/old_len prefix reads then hit an
2851        // empty buffer. Must decode gracefully, never panic (found by the
2852        // parse_env_change and parse_token fuzz targets).
2853        let data = [0x01, 0x00, 0x08];
2854        let mut buf: &[u8] = &data;
2855        let env = EnvChange::decode(&mut buf).unwrap();
2856        assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2857    }
2858
2859    /// Issue #145: an under-declared frame must not let the value decoders
2860    /// read past the declared length into the next token's bytes.
2861    #[test]
2862    fn hostile_env_change_under_declared_cannot_steal_following_bytes() {
2863        // length=1 covers only the type byte; the bytes after the frame are
2864        // shaped exactly like the transaction-descriptor payload the old
2865        // buffer-bounded decoder would have consumed (new_len=8, descriptor,
2866        // old_len=0).
2867        let mut data = BytesMut::new();
2868        data.put_u16_le(1); // declared frame: type byte only
2869        data.put_u8(0x08); // BeginTransaction
2870        let following: &[u8] = &[0x08, 1, 2, 3, 4, 5, 6, 7, 8, 0x00];
2871        data.extend_from_slice(following);
2872
2873        let mut buf: &[u8] = &data;
2874        let env = EnvChange::decode(&mut buf).unwrap();
2875        assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2876        match &env.new_value {
2877            EnvChangeValue::Binary(b) => {
2878                assert!(
2879                    b.is_empty(),
2880                    "under-declared frame yields the lenient empty value"
2881                );
2882            }
2883            other => panic!("expected empty Binary value, got {other:?}"),
2884        }
2885        assert_eq!(
2886            buf, following,
2887            "bytes beyond the declared frame belong to the next token"
2888        );
2889    }
2890
2891    /// Issue #145: a zero-length frame cannot supply a type byte; reading
2892    /// one from beyond the frame would consume the next token.
2893    #[test]
2894    fn hostile_env_change_zero_length_frame_errors() {
2895        let data = [0x00, 0x00, 0xFD];
2896        let mut buf: &[u8] = &data;
2897        assert!(EnvChange::decode(&mut buf).is_err());
2898    }
2899
2900    #[test]
2901    fn test_colmetadata_no_columns() {
2902        // No metadata marker (0xFFFF)
2903        let data = Bytes::from_static(&[0xFF, 0xFF]);
2904        let mut cursor: &[u8] = &data;
2905        let meta = ColMetaData::decode(&mut cursor).unwrap();
2906        assert!(meta.is_empty());
2907        assert_eq!(meta.column_count(), 0);
2908    }
2909
2910    #[test]
2911    fn test_colmetadata_single_int_column() {
2912        // COLMETADATA with 1 INT column
2913        // Format: column_count (2) + [user_type (4) + flags (2) + type_id (1) + name (b_varchar)]
2914        let mut data = BytesMut::new();
2915        data.extend_from_slice(&[0x01, 0x00]); // 1 column
2916        data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); // user_type = 0
2917        data.extend_from_slice(&[0x01, 0x00]); // flags (nullable)
2918        data.extend_from_slice(&[0x38]); // TypeId::Int4
2919        // Column name "id" as B_VARCHAR (1 byte length + UTF-16LE)
2920        data.extend_from_slice(&[0x02]); // 2 characters
2921        data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); // "id" in UTF-16LE
2922
2923        let mut cursor: &[u8] = &data;
2924        let meta = ColMetaData::decode(&mut cursor).unwrap();
2925
2926        assert_eq!(meta.column_count(), 1);
2927        assert_eq!(meta.columns[0].name, "id");
2928        assert_eq!(meta.columns[0].type_id, TypeId::Int4);
2929        assert!(meta.columns[0].is_nullable());
2930    }
2931
2932    #[test]
2933    fn test_colmetadata_nvarchar_column() {
2934        // COLMETADATA with 1 NVARCHAR(50) column
2935        let mut data = BytesMut::new();
2936        data.extend_from_slice(&[0x01, 0x00]); // 1 column
2937        data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); // user_type = 0
2938        data.extend_from_slice(&[0x01, 0x00]); // flags (nullable)
2939        data.extend_from_slice(&[0xE7]); // TypeId::NVarChar
2940        // Type info: max_length (2 bytes) + collation (5 bytes)
2941        data.extend_from_slice(&[0x64, 0x00]); // max_length = 100 (50 chars * 2)
2942        data.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); // collation
2943        // Column name "name"
2944        data.extend_from_slice(&[0x04]); // 4 characters
2945        data.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
2946
2947        let mut cursor: &[u8] = &data;
2948        let meta = ColMetaData::decode(&mut cursor).unwrap();
2949
2950        assert_eq!(meta.column_count(), 1);
2951        assert_eq!(meta.columns[0].name, "name");
2952        assert_eq!(meta.columns[0].type_id, TypeId::NVarChar);
2953        assert_eq!(meta.columns[0].type_info.max_length, Some(100));
2954        assert!(meta.columns[0].type_info.collation.is_some());
2955    }
2956
2957    #[test]
2958    fn test_raw_row_decode_int() {
2959        // Create metadata for a single INT column
2960        let metadata = ColMetaData {
2961            cek_table: None,
2962            columns: vec![ColumnData {
2963                name: "id".to_string(),
2964                type_id: TypeId::Int4,
2965                col_type: 0x38,
2966                flags: 0,
2967                user_type: 0,
2968                type_info: TypeInfo::default(),
2969                crypto_metadata: None,
2970            }],
2971        };
2972
2973        // Row data: just 4 bytes for the int value 42
2974        let data = Bytes::from_static(&[0x2A, 0x00, 0x00, 0x00]); // 42 in little-endian
2975        let mut cursor: &[u8] = &data;
2976        let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2977
2978        // The raw data should contain the 4 bytes
2979        assert_eq!(row.data.len(), 4);
2980        assert_eq!(&row.data[..], &[0x2A, 0x00, 0x00, 0x00]);
2981    }
2982
2983    #[test]
2984    fn test_raw_row_decode_nullable_int() {
2985        // Create metadata for a nullable INT column (IntN)
2986        let metadata = ColMetaData {
2987            cek_table: None,
2988            columns: vec![ColumnData {
2989                name: "id".to_string(),
2990                type_id: TypeId::IntN,
2991                col_type: 0x26,
2992                flags: 0x01, // nullable
2993                user_type: 0,
2994                type_info: TypeInfo {
2995                    max_length: Some(4),
2996                    ..Default::default()
2997                },
2998                crypto_metadata: None,
2999            }],
3000        };
3001
3002        // Row data with value: 1 byte length + 4 bytes value
3003        let data = Bytes::from_static(&[0x04, 0x2A, 0x00, 0x00, 0x00]); // length=4, value=42
3004        let mut cursor: &[u8] = &data;
3005        let row = RawRow::decode(&mut cursor, &metadata).unwrap();
3006
3007        assert_eq!(row.data.len(), 5);
3008        assert_eq!(row.data[0], 4); // length
3009        assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
3010    }
3011
3012    #[test]
3013    fn test_raw_row_decode_null_value() {
3014        // Create metadata for a nullable INT column (IntN)
3015        let metadata = ColMetaData {
3016            cek_table: None,
3017            columns: vec![ColumnData {
3018                name: "id".to_string(),
3019                type_id: TypeId::IntN,
3020                col_type: 0x26,
3021                flags: 0x01, // nullable
3022                user_type: 0,
3023                type_info: TypeInfo {
3024                    max_length: Some(4),
3025                    ..Default::default()
3026                },
3027                crypto_metadata: None,
3028            }],
3029        };
3030
3031        // NULL value: length = 0xFF (for bytelen types)
3032        let data = Bytes::from_static(&[0xFF]);
3033        let mut cursor: &[u8] = &data;
3034        let row = RawRow::decode(&mut cursor, &metadata).unwrap();
3035
3036        assert_eq!(row.data.len(), 1);
3037        assert_eq!(row.data[0], 0xFF); // NULL marker
3038    }
3039
3040    #[test]
3041    fn test_nbcrow_null_bitmap() {
3042        let row = NbcRow {
3043            null_bitmap: vec![0b00000101], // columns 0 and 2 are NULL
3044            data: Bytes::new(),
3045        };
3046
3047        assert!(row.is_null(0));
3048        assert!(!row.is_null(1));
3049        assert!(row.is_null(2));
3050        assert!(!row.is_null(3));
3051    }
3052
3053    #[test]
3054    fn test_token_parser_colmetadata() {
3055        // Build a COLMETADATA token with 1 INT column
3056        let mut data = BytesMut::new();
3057        data.extend_from_slice(&[0x81]); // COLMETADATA token type
3058        data.extend_from_slice(&[0x01, 0x00]); // 1 column
3059        data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); // user_type = 0
3060        data.extend_from_slice(&[0x01, 0x00]); // flags (nullable)
3061        data.extend_from_slice(&[0x38]); // TypeId::Int4
3062        data.extend_from_slice(&[0x02]); // column name length
3063        data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); // "id"
3064
3065        let mut parser = TokenParser::new(data.freeze());
3066        let token = parser.next_token().unwrap().unwrap();
3067
3068        match token {
3069            Token::ColMetaData(meta) => {
3070                assert_eq!(meta.column_count(), 1);
3071                assert_eq!(meta.columns[0].name, "id");
3072                assert_eq!(meta.columns[0].type_id, TypeId::Int4);
3073            }
3074            _ => panic!("Expected ColMetaData token"),
3075        }
3076    }
3077
3078    #[test]
3079    fn test_token_parser_row_with_metadata() {
3080        // Build metadata
3081        let metadata = ColMetaData {
3082            cek_table: None,
3083            columns: vec![ColumnData {
3084                name: "id".to_string(),
3085                type_id: TypeId::Int4,
3086                col_type: 0x38,
3087                flags: 0,
3088                user_type: 0,
3089                type_info: TypeInfo::default(),
3090                crypto_metadata: None,
3091            }],
3092        };
3093
3094        // Build ROW token
3095        let mut data = BytesMut::new();
3096        data.extend_from_slice(&[0xD1]); // ROW token type
3097        data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); // value = 42
3098
3099        let mut parser = TokenParser::new(data.freeze());
3100        let token = parser
3101            .next_token_with_metadata(Some(&metadata))
3102            .unwrap()
3103            .unwrap();
3104
3105        match token {
3106            Token::Row(row) => {
3107                assert_eq!(row.data.len(), 4);
3108            }
3109            _ => panic!("Expected Row token"),
3110        }
3111    }
3112
3113    #[test]
3114    fn test_token_parser_row_without_metadata_fails() {
3115        // Build ROW token
3116        let mut data = BytesMut::new();
3117        data.extend_from_slice(&[0xD1]); // ROW token type
3118        data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); // value = 42
3119
3120        let mut parser = TokenParser::new(data.freeze());
3121        let result = parser.next_token(); // No metadata provided
3122
3123        assert!(result.is_err());
3124    }
3125
3126    #[test]
3127    fn test_token_parser_peek() {
3128        let data = Bytes::from_static(&[
3129            0xFD, // DONE token type
3130            0x10, 0x00, // status
3131            0xC1, 0x00, // cur_cmd
3132            0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // row_count
3133        ]);
3134
3135        let parser = TokenParser::new(data);
3136        assert_eq!(parser.peek_token_type(), Some(TokenType::Done));
3137    }
3138
3139    #[test]
3140    fn test_column_data_fixed_size() {
3141        let col = ColumnData {
3142            name: String::new(),
3143            type_id: TypeId::Int4,
3144            col_type: 0x38,
3145            flags: 0,
3146            user_type: 0,
3147            type_info: TypeInfo::default(),
3148            crypto_metadata: None,
3149        };
3150        assert_eq!(col.fixed_size(), Some(4));
3151
3152        let col2 = ColumnData {
3153            name: String::new(),
3154            type_id: TypeId::NVarChar,
3155            col_type: 0xE7,
3156            flags: 0,
3157            user_type: 0,
3158            type_info: TypeInfo::default(),
3159            crypto_metadata: None,
3160        };
3161        assert_eq!(col2.fixed_size(), None);
3162    }
3163
3164    // ========================================================================
3165    // End-to-End Decode Tests (Wire → Stored → Verification)
3166    // ========================================================================
3167    //
3168    // These tests verify that RawRow::decode_column_value correctly stores
3169    // column values in a format that can be parsed back.
3170
3171    #[test]
3172    fn test_decode_nvarchar_then_intn_roundtrip() {
3173        // Simulate wire data for: "World" (NVarChar), 42 (IntN)
3174        // This tests the scenario from the MCP parameterized query
3175
3176        // Build wire data (what the server sends)
3177        let mut wire_data = BytesMut::new();
3178
3179        // Column 0: NVarChar "World" - 2-byte length prefix in bytes
3180        // "World" in UTF-16LE: W=0x0057, o=0x006F, r=0x0072, l=0x006C, d=0x0064
3181        let word = "World";
3182        let utf16: Vec<u16> = word.encode_utf16().collect();
3183        wire_data.put_u16_le((utf16.len() * 2) as u16); // byte length = 10
3184        for code_unit in &utf16 {
3185            wire_data.put_u16_le(*code_unit);
3186        }
3187
3188        // Column 1: IntN 42 - 1-byte length prefix
3189        wire_data.put_u8(4); // 4 bytes for INT
3190        wire_data.put_i32_le(42);
3191
3192        // Build column metadata
3193        let metadata = ColMetaData {
3194            cek_table: None,
3195            columns: vec![
3196                ColumnData {
3197                    name: "greeting".to_string(),
3198                    type_id: TypeId::NVarChar,
3199                    col_type: 0xE7,
3200                    flags: 0x01,
3201                    user_type: 0,
3202                    type_info: TypeInfo {
3203                        max_length: Some(10), // non-MAX
3204                        precision: None,
3205                        scale: None,
3206                        collation: None,
3207                    },
3208                    crypto_metadata: None,
3209                },
3210                ColumnData {
3211                    name: "number".to_string(),
3212                    type_id: TypeId::IntN,
3213                    col_type: 0x26,
3214                    flags: 0x01,
3215                    user_type: 0,
3216                    type_info: TypeInfo {
3217                        max_length: Some(4),
3218                        precision: None,
3219                        scale: None,
3220                        collation: None,
3221                    },
3222                    crypto_metadata: None,
3223                },
3224            ],
3225        };
3226
3227        // Decode the wire data into stored format
3228        let mut wire_cursor = wire_data.freeze();
3229        let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3230
3231        // Verify wire data was fully consumed
3232        assert_eq!(
3233            wire_cursor.remaining(),
3234            0,
3235            "wire data should be fully consumed"
3236        );
3237
3238        // Now parse the stored data
3239        let mut stored_cursor: &[u8] = &raw_row.data;
3240
3241        // Parse column 0 (NVarChar)
3242        // Stored format for non-MAX NVarChar: [2-byte len][data]
3243        assert!(
3244            stored_cursor.remaining() >= 2,
3245            "need at least 2 bytes for length"
3246        );
3247        let len0 = stored_cursor.get_u16_le() as usize;
3248        assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
3249        assert!(
3250            stored_cursor.remaining() >= len0,
3251            "need {len0} bytes for data"
3252        );
3253
3254        // Read UTF-16LE and convert to string
3255        let mut utf16_read = Vec::new();
3256        for _ in 0..(len0 / 2) {
3257            utf16_read.push(stored_cursor.get_u16_le());
3258        }
3259        let string0 = String::from_utf16(&utf16_read).unwrap();
3260        assert_eq!(string0, "World", "column 0 should be 'World'");
3261
3262        // Parse column 1 (IntN)
3263        // Stored format for IntN: [1-byte len][data]
3264        assert!(
3265            stored_cursor.remaining() >= 1,
3266            "need at least 1 byte for length"
3267        );
3268        let len1 = stored_cursor.get_u8();
3269        assert_eq!(len1, 4, "IntN length should be 4");
3270        assert!(stored_cursor.remaining() >= 4, "need 4 bytes for INT data");
3271        let int1 = stored_cursor.get_i32_le();
3272        assert_eq!(int1, 42, "column 1 should be 42");
3273
3274        // Verify stored data was fully consumed
3275        assert_eq!(
3276            stored_cursor.remaining(),
3277            0,
3278            "stored data should be fully consumed"
3279        );
3280    }
3281
3282    #[test]
3283    fn test_decode_nvarchar_max_then_intn_roundtrip() {
3284        // Test NVARCHAR(MAX) followed by IntN - uses PLP encoding
3285
3286        // Build wire data for PLP NVARCHAR(MAX) + IntN
3287        let mut wire_data = BytesMut::new();
3288
3289        // Column 0: NVARCHAR(MAX) "Hello" - PLP format
3290        // PLP: 8-byte total length, then chunks
3291        let word = "Hello";
3292        let utf16: Vec<u16> = word.encode_utf16().collect();
3293        let byte_len = (utf16.len() * 2) as u64;
3294
3295        wire_data.put_u64_le(byte_len); // total length = 10
3296        wire_data.put_u32_le(byte_len as u32); // chunk length = 10
3297        for code_unit in &utf16 {
3298            wire_data.put_u16_le(*code_unit);
3299        }
3300        wire_data.put_u32_le(0); // terminating zero-length chunk
3301
3302        // Column 1: IntN 99
3303        wire_data.put_u8(4);
3304        wire_data.put_i32_le(99);
3305
3306        // Build metadata with MAX type
3307        let metadata = ColMetaData {
3308            cek_table: None,
3309            columns: vec![
3310                ColumnData {
3311                    name: "text".to_string(),
3312                    type_id: TypeId::NVarChar,
3313                    col_type: 0xE7,
3314                    flags: 0x01,
3315                    user_type: 0,
3316                    type_info: TypeInfo {
3317                        max_length: Some(0xFFFF), // MAX indicator
3318                        precision: None,
3319                        scale: None,
3320                        collation: None,
3321                    },
3322                    crypto_metadata: None,
3323                },
3324                ColumnData {
3325                    name: "num".to_string(),
3326                    type_id: TypeId::IntN,
3327                    col_type: 0x26,
3328                    flags: 0x01,
3329                    user_type: 0,
3330                    type_info: TypeInfo {
3331                        max_length: Some(4),
3332                        precision: None,
3333                        scale: None,
3334                        collation: None,
3335                    },
3336                    crypto_metadata: None,
3337                },
3338            ],
3339        };
3340
3341        // Decode wire data
3342        let mut wire_cursor = wire_data.freeze();
3343        let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3344
3345        // Verify wire data was fully consumed
3346        assert_eq!(
3347            wire_cursor.remaining(),
3348            0,
3349            "wire data should be fully consumed"
3350        );
3351
3352        // Parse stored PLP data for column 0
3353        let mut stored_cursor: &[u8] = &raw_row.data;
3354
3355        // PLP stored format: [8-byte total][chunks...][4-byte 0]
3356        let total_len = stored_cursor.get_u64_le();
3357        assert_eq!(total_len, 10, "PLP total length should be 10");
3358
3359        let chunk_len = stored_cursor.get_u32_le();
3360        assert_eq!(chunk_len, 10, "PLP chunk length should be 10");
3361
3362        let mut utf16_read = Vec::new();
3363        for _ in 0..(chunk_len / 2) {
3364            utf16_read.push(stored_cursor.get_u16_le());
3365        }
3366        let string0 = String::from_utf16(&utf16_read).unwrap();
3367        assert_eq!(string0, "Hello", "column 0 should be 'Hello'");
3368
3369        let terminator = stored_cursor.get_u32_le();
3370        assert_eq!(terminator, 0, "PLP should end with 0");
3371
3372        // Parse IntN
3373        let len1 = stored_cursor.get_u8();
3374        assert_eq!(len1, 4);
3375        let int1 = stored_cursor.get_i32_le();
3376        assert_eq!(int1, 99, "column 1 should be 99");
3377
3378        // Verify fully consumed
3379        assert_eq!(
3380            stored_cursor.remaining(),
3381            0,
3382            "stored data should be fully consumed"
3383        );
3384    }
3385
3386    // ========================================================================
3387    // ReturnStatus Token Tests
3388    // ========================================================================
3389
3390    #[test]
3391    fn test_return_status_via_parser() {
3392        // RETURNSTATUS token: type (0x79) + value (i32 LE)
3393        let data = Bytes::from_static(&[
3394            0x79, // RETURNSTATUS token type
3395            0x00, 0x00, 0x00, 0x00, // return value = 0 (success)
3396        ]);
3397
3398        let mut parser = TokenParser::new(data);
3399        let token = parser.next_token().unwrap().unwrap();
3400
3401        match token {
3402            Token::ReturnStatus(status) => {
3403                assert_eq!(status, 0);
3404            }
3405            _ => panic!("Expected ReturnStatus token, got {token:?}"),
3406        }
3407
3408        assert!(parser.next_token().unwrap().is_none());
3409    }
3410
3411    #[test]
3412    fn test_return_status_nonzero() {
3413        // Return value = -6 (common for error returns)
3414        let mut buf = BytesMut::new();
3415        buf.put_u8(0x79); // RETURNSTATUS
3416        buf.put_i32_le(-6);
3417
3418        let mut parser = TokenParser::new(buf.freeze());
3419        let token = parser.next_token().unwrap().unwrap();
3420
3421        match token {
3422            Token::ReturnStatus(status) => {
3423                assert_eq!(status, -6);
3424            }
3425            _ => panic!("Expected ReturnStatus token"),
3426        }
3427    }
3428
3429    // ========================================================================
3430    // DoneProc Token Tests
3431    // ========================================================================
3432
3433    #[test]
3434    fn test_done_proc_roundtrip() {
3435        let done = DoneProc {
3436            status: DoneStatus {
3437                more: false,
3438                error: false,
3439                in_xact: false,
3440                count: true,
3441                attn: false,
3442                srverror: false,
3443            },
3444            cur_cmd: 0x00C6, // EXECUTE (198)
3445            row_count: 100,
3446        };
3447
3448        let mut buf = BytesMut::new();
3449        done.encode(&mut buf);
3450
3451        // Verify token type byte
3452        assert_eq!(buf[0], 0xFE);
3453
3454        // Skip token type byte and decode
3455        let mut cursor = &buf[1..];
3456        let decoded = DoneProc::decode(&mut cursor).unwrap();
3457
3458        assert!(decoded.status.count);
3459        assert!(!decoded.status.more);
3460        assert!(!decoded.status.error);
3461        assert_eq!(decoded.cur_cmd, 0x00C6);
3462        assert_eq!(decoded.row_count, 100);
3463    }
3464
3465    #[test]
3466    fn test_done_proc_via_parser() {
3467        let data = Bytes::from_static(&[
3468            0xFE, // DONEPROC token type
3469            0x00, 0x00, // status: no flags
3470            0xC6, 0x00, // cur_cmd: EXECUTE (198)
3471            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // row_count: 0
3472        ]);
3473
3474        let mut parser = TokenParser::new(data);
3475        let token = parser.next_token().unwrap().unwrap();
3476
3477        match token {
3478            Token::DoneProc(done) => {
3479                assert!(!done.status.count);
3480                assert!(!done.status.more);
3481                assert_eq!(done.cur_cmd, 198);
3482                assert_eq!(done.row_count, 0);
3483            }
3484            _ => panic!("Expected DoneProc token"),
3485        }
3486    }
3487
3488    #[test]
3489    fn test_done_proc_with_error_flag() {
3490        let mut buf = BytesMut::new();
3491        buf.put_u8(0xFE); // DONEPROC
3492        buf.put_u16_le(0x0002); // status: DONE_ERROR
3493        buf.put_u16_le(0x00C6); // cur_cmd: EXECUTE
3494        buf.put_u64_le(0); // row_count
3495
3496        let mut parser = TokenParser::new(buf.freeze());
3497        let token = parser.next_token().unwrap().unwrap();
3498
3499        match token {
3500            Token::DoneProc(done) => {
3501                assert!(done.status.error);
3502                assert!(!done.status.count);
3503                assert!(!done.status.more);
3504            }
3505            _ => panic!("Expected DoneProc token"),
3506        }
3507    }
3508
3509    // ========================================================================
3510    // DoneInProc Token Tests
3511    // ========================================================================
3512
3513    #[test]
3514    fn test_done_in_proc_roundtrip() {
3515        let done = DoneInProc {
3516            status: DoneStatus {
3517                more: true,
3518                error: false,
3519                in_xact: false,
3520                count: true,
3521                attn: false,
3522                srverror: false,
3523            },
3524            cur_cmd: 193, // SELECT
3525            row_count: 7,
3526        };
3527
3528        let mut buf = BytesMut::new();
3529        done.encode(&mut buf);
3530
3531        assert_eq!(buf[0], 0xFF);
3532
3533        let mut cursor = &buf[1..];
3534        let decoded = DoneInProc::decode(&mut cursor).unwrap();
3535
3536        assert!(decoded.status.more);
3537        assert!(decoded.status.count);
3538        assert!(!decoded.status.error);
3539        assert_eq!(decoded.cur_cmd, 193);
3540        assert_eq!(decoded.row_count, 7);
3541    }
3542
3543    #[test]
3544    fn test_done_in_proc_via_parser() {
3545        let data = Bytes::from_static(&[
3546            0xFF, // DONEINPROC token type
3547            0x11, 0x00, // status: MORE | COUNT
3548            0xC1, 0x00, // cur_cmd: SELECT (193)
3549            0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // row_count: 3
3550        ]);
3551
3552        let mut parser = TokenParser::new(data);
3553        let token = parser.next_token().unwrap().unwrap();
3554
3555        match token {
3556            Token::DoneInProc(done) => {
3557                assert!(done.status.more);
3558                assert!(done.status.count);
3559                assert_eq!(done.cur_cmd, 193);
3560                assert_eq!(done.row_count, 3);
3561            }
3562            _ => panic!("Expected DoneInProc token"),
3563        }
3564    }
3565
3566    // ========================================================================
3567    // ServerError Token Tests
3568    // ========================================================================
3569
3570    #[test]
3571    fn test_server_error_decode() {
3572        // Build a realistic ERROR token (without the 0xAA type byte,
3573        // since decode() is called after the parser strips it).
3574        let mut buf = BytesMut::new();
3575
3576        // Construct the message fields first to compute length
3577        let msg_utf16: Vec<u16> = "Invalid column name 'foo'.".encode_utf16().collect();
3578        let srv_utf16: Vec<u16> = "SQLDB01".encode_utf16().collect();
3579        let proc_utf16: Vec<u16> = "".encode_utf16().collect();
3580
3581        // Length = number(4) + state(1) + class(1)
3582        //        + us_varchar(message): 2 + msg_utf16.len()*2
3583        //        + b_varchar(server): 1 + srv_utf16.len()*2
3584        //        + b_varchar(procedure): 1 + proc_utf16.len()*2
3585        //        + line(4)
3586        let length: u16 = (4
3587            + 1
3588            + 1
3589            + 2
3590            + (msg_utf16.len() * 2)
3591            + 1
3592            + (srv_utf16.len() * 2)
3593            + 1
3594            + (proc_utf16.len() * 2)
3595            + 4) as u16;
3596
3597        buf.put_u16_le(length);
3598        buf.put_i32_le(207); // error number: Invalid column
3599        buf.put_u8(1); // state
3600        buf.put_u8(16); // class (severity 16)
3601
3602        // Message (US_VARCHAR: 2-byte char count + UTF-16LE)
3603        buf.put_u16_le(msg_utf16.len() as u16);
3604        for &c in &msg_utf16 {
3605            buf.put_u16_le(c);
3606        }
3607
3608        // Server (B_VARCHAR: 1-byte char count + UTF-16LE)
3609        buf.put_u8(srv_utf16.len() as u8);
3610        for &c in &srv_utf16 {
3611            buf.put_u16_le(c);
3612        }
3613
3614        // Procedure (B_VARCHAR: empty)
3615        buf.put_u8(proc_utf16.len() as u8);
3616
3617        // Line number
3618        buf.put_i32_le(42);
3619
3620        let mut cursor = buf.freeze();
3621        let error = ServerError::decode(&mut cursor).unwrap();
3622
3623        assert_eq!(error.number, 207);
3624        assert_eq!(error.state, 1);
3625        assert_eq!(error.class, 16);
3626        assert_eq!(error.message, "Invalid column name 'foo'.");
3627        assert_eq!(error.server, "SQLDB01");
3628        assert_eq!(error.procedure, "");
3629        assert_eq!(error.line, 42);
3630    }
3631
3632    #[test]
3633    fn test_server_error_severity_helpers() {
3634        let fatal = ServerError {
3635            number: 4014,
3636            state: 1,
3637            class: 20,
3638            message: "Fatal error".to_string(),
3639            server: String::new(),
3640            procedure: String::new(),
3641            line: 0,
3642        };
3643        assert!(fatal.is_fatal());
3644        assert!(fatal.is_batch_abort());
3645
3646        let batch_abort = ServerError {
3647            number: 547,
3648            state: 0,
3649            class: 16,
3650            message: "Constraint violation".to_string(),
3651            server: String::new(),
3652            procedure: String::new(),
3653            line: 1,
3654        };
3655        assert!(!batch_abort.is_fatal());
3656        assert!(batch_abort.is_batch_abort());
3657
3658        let informational = ServerError {
3659            number: 5701,
3660            state: 2,
3661            class: 10,
3662            message: "Changed db context".to_string(),
3663            server: String::new(),
3664            procedure: String::new(),
3665            line: 0,
3666        };
3667        assert!(!informational.is_fatal());
3668        assert!(!informational.is_batch_abort());
3669    }
3670
3671    #[test]
3672    fn test_server_error_via_parser() {
3673        // Build an ERROR token with the 0xAA type byte for the parser
3674        let mut buf = BytesMut::new();
3675        buf.put_u8(0xAA); // ERROR token type
3676
3677        let msg_utf16: Vec<u16> = "Syntax error".encode_utf16().collect();
3678        let srv_utf16: Vec<u16> = "SRV".encode_utf16().collect();
3679        let proc_utf16: Vec<u16> = "sp_test".encode_utf16().collect();
3680
3681        let length: u16 = (4
3682            + 1
3683            + 1
3684            + 2
3685            + (msg_utf16.len() * 2)
3686            + 1
3687            + (srv_utf16.len() * 2)
3688            + 1
3689            + (proc_utf16.len() * 2)
3690            + 4) as u16;
3691
3692        buf.put_u16_le(length);
3693        buf.put_i32_le(102); // Syntax error
3694        buf.put_u8(1);
3695        buf.put_u8(15);
3696
3697        buf.put_u16_le(msg_utf16.len() as u16);
3698        for &c in &msg_utf16 {
3699            buf.put_u16_le(c);
3700        }
3701        buf.put_u8(srv_utf16.len() as u8);
3702        for &c in &srv_utf16 {
3703            buf.put_u16_le(c);
3704        }
3705        buf.put_u8(proc_utf16.len() as u8);
3706        for &c in &proc_utf16 {
3707            buf.put_u16_le(c);
3708        }
3709        buf.put_i32_le(5);
3710
3711        let mut parser = TokenParser::new(buf.freeze());
3712        let token = parser.next_token().unwrap().unwrap();
3713
3714        match token {
3715            Token::Error(err) => {
3716                assert_eq!(err.number, 102);
3717                assert_eq!(err.class, 15);
3718                assert_eq!(err.message, "Syntax error");
3719                assert_eq!(err.server, "SRV");
3720                assert_eq!(err.procedure, "sp_test");
3721                assert_eq!(err.line, 5);
3722            }
3723            _ => panic!("Expected Error token"),
3724        }
3725    }
3726
3727    // ========================================================================
3728    // ReturnValue Token Tests
3729    // ========================================================================
3730
3731    /// Helper: build a ReturnValue token (without the 0xAC type byte)
3732    /// for an IntN output parameter.
3733    fn build_return_value_intn(
3734        ordinal: u16,
3735        name: &str,
3736        status: u8,
3737        value: Option<i32>,
3738    ) -> BytesMut {
3739        let mut inner = BytesMut::new();
3740
3741        // param_ordinal
3742        inner.put_u16_le(ordinal);
3743
3744        // param_name (B_VARCHAR)
3745        let name_utf16: Vec<u16> = name.encode_utf16().collect();
3746        inner.put_u8(name_utf16.len() as u8);
3747        for &c in &name_utf16 {
3748            inner.put_u16_le(c);
3749        }
3750
3751        // status
3752        inner.put_u8(status);
3753
3754        // user_type (4 bytes)
3755        inner.put_u32_le(0);
3756
3757        // flags (2 bytes)
3758        inner.put_u16_le(0x0001); // nullable
3759
3760        // type_id: IntN = 0x26
3761        inner.put_u8(0x26);
3762
3763        // type_info for IntN: 1-byte max_length
3764        inner.put_u8(4);
3765
3766        // value (TYPE_VARBYTE for IntN: 1-byte length + data)
3767        match value {
3768            Some(v) => {
3769                inner.put_u8(4); // length = 4
3770                inner.put_i32_le(v);
3771            }
3772            None => {
3773                inner.put_u8(0); // length = 0 means NULL
3774            }
3775        }
3776
3777        // RETURNVALUE has no outer length prefix (MS-TDS §2.2.7.18) — the
3778        // decoder walks the inner fields directly after the 0xAC token byte.
3779        inner
3780    }
3781
3782    #[test]
3783    fn test_return_value_int_output() {
3784        let buf = build_return_value_intn(1, "@result", 0x01, Some(42));
3785        let mut cursor = buf.freeze();
3786        let rv = ReturnValue::decode(&mut cursor).unwrap();
3787
3788        assert_eq!(rv.param_ordinal, 1);
3789        assert_eq!(rv.param_name, "@result");
3790        assert_eq!(rv.status, 0x01); // OUTPUT
3791        assert_eq!(rv.col_type, 0x26); // IntN
3792        assert_eq!(rv.type_info.max_length, Some(4));
3793        // Value should contain: length byte (4) + i32 LE (42)
3794        assert_eq!(rv.value.len(), 5);
3795        assert_eq!(rv.value[0], 4);
3796        assert_eq!(
3797            i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3798            42
3799        );
3800    }
3801
3802    #[test]
3803    fn test_return_value_null_output() {
3804        let buf = build_return_value_intn(2, "@count", 0x01, None);
3805        let mut cursor = buf.freeze();
3806        let rv = ReturnValue::decode(&mut cursor).unwrap();
3807
3808        assert_eq!(rv.param_ordinal, 2);
3809        assert_eq!(rv.param_name, "@count");
3810        assert_eq!(rv.status, 0x01);
3811        assert_eq!(rv.col_type, 0x26);
3812        // NULL value: length byte = 0
3813        assert_eq!(rv.value.len(), 1);
3814        assert_eq!(rv.value[0], 0);
3815    }
3816
3817    #[test]
3818    fn test_return_value_udf_status() {
3819        // UDF return value has status = 0x02
3820        let buf = build_return_value_intn(0, "@RETURN_VALUE", 0x02, Some(-1));
3821        let mut cursor = buf.freeze();
3822        let rv = ReturnValue::decode(&mut cursor).unwrap();
3823
3824        assert_eq!(rv.param_ordinal, 0);
3825        assert_eq!(rv.param_name, "@RETURN_VALUE");
3826        assert_eq!(rv.status, 0x02); // UDF return value
3827        assert_eq!(rv.value[0], 4);
3828        assert_eq!(
3829            i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3830            -1
3831        );
3832    }
3833
3834    #[test]
3835    fn test_return_value_nvarchar_output() {
3836        // Build a ReturnValue for NVARCHAR(100) output parameter
3837        let mut inner = BytesMut::new();
3838
3839        // param_ordinal
3840        inner.put_u16_le(1);
3841
3842        // param_name "@name"
3843        let name_utf16: Vec<u16> = "@name".encode_utf16().collect();
3844        inner.put_u8(name_utf16.len() as u8);
3845        for &c in &name_utf16 {
3846            inner.put_u16_le(c);
3847        }
3848
3849        // status = OUTPUT
3850        inner.put_u8(0x01);
3851        // user_type
3852        inner.put_u32_le(0);
3853        // flags (nullable)
3854        inner.put_u16_le(0x0001);
3855        // type_id: NVarChar = 0xE7
3856        inner.put_u8(0xE7);
3857        // type_info for NVarChar: 2-byte max_length + 5-byte collation
3858        inner.put_u16_le(200); // max 100 chars * 2 bytes
3859        inner.put_u32_le(0x0904D000); // collation LCID
3860        inner.put_u8(0x34); // collation sort_id
3861
3862        // value: "Hello" in UTF-16LE with 2-byte length prefix
3863        let val_utf16: Vec<u16> = "Hello".encode_utf16().collect();
3864        let byte_len = (val_utf16.len() * 2) as u16;
3865        inner.put_u16_le(byte_len);
3866        for &c in &val_utf16 {
3867            inner.put_u16_le(c);
3868        }
3869
3870        let mut cursor = inner.freeze();
3871        let rv = ReturnValue::decode(&mut cursor).unwrap();
3872
3873        assert_eq!(rv.param_ordinal, 1);
3874        assert_eq!(rv.param_name, "@name");
3875        assert_eq!(rv.status, 0x01);
3876        assert_eq!(rv.col_type, 0xE7); // NVarChar
3877        assert_eq!(rv.type_info.max_length, Some(200));
3878        assert!(rv.type_info.collation.is_some());
3879
3880        // Value: 2-byte length (10) + "Hello" in UTF-16LE
3881        assert_eq!(rv.value.len(), 12); // 2 + 10
3882        let val_len = u16::from_le_bytes([rv.value[0], rv.value[1]]);
3883        assert_eq!(val_len, 10);
3884    }
3885
3886    #[test]
3887    fn test_return_value_via_parser() {
3888        // Build a full ReturnValue token with the 0xAC type byte
3889        let mut data = BytesMut::new();
3890        data.put_u8(0xAC); // RETURNVALUE token type
3891        data.extend_from_slice(&build_return_value_intn(0, "@out", 0x01, Some(99)));
3892
3893        let mut parser = TokenParser::new(data.freeze());
3894        let token = parser.next_token().unwrap().unwrap();
3895
3896        match token {
3897            Token::ReturnValue(rv) => {
3898                assert_eq!(rv.param_name, "@out");
3899                assert_eq!(rv.param_ordinal, 0);
3900                assert_eq!(rv.status, 0x01);
3901                assert_eq!(rv.col_type, 0x26);
3902            }
3903            _ => panic!("Expected ReturnValue token"),
3904        }
3905    }
3906
3907    // ========================================================================
3908    // Multi-Token Stream Tests
3909    // ========================================================================
3910
3911    #[test]
3912    fn test_multi_token_stored_proc_response() {
3913        // Simulate a stored procedure response:
3914        // DoneInProc (result set done) → ReturnStatus → DoneProc
3915        let mut data = BytesMut::new();
3916
3917        // Token 1: DONEINPROC — result set with 3 rows
3918        data.put_u8(0xFF); // DONEINPROC
3919        data.put_u16_le(0x0010); // status: COUNT
3920        data.put_u16_le(0x00C1); // cur_cmd: SELECT
3921        data.put_u64_le(3); // row_count
3922
3923        // Token 2: RETURNSTATUS — procedure returned 0
3924        data.put_u8(0x79); // RETURNSTATUS
3925        data.put_i32_le(0);
3926
3927        // Token 3: DONEPROC — final
3928        data.put_u8(0xFE); // DONEPROC
3929        data.put_u16_le(0x0000); // status: no flags
3930        data.put_u16_le(0x00C6); // cur_cmd: EXECUTE
3931        data.put_u64_le(0);
3932
3933        let mut parser = TokenParser::new(data.freeze());
3934
3935        // Token 1: DoneInProc
3936        let t1 = parser.next_token().unwrap().unwrap();
3937        match t1 {
3938            Token::DoneInProc(done) => {
3939                assert!(done.status.count);
3940                assert_eq!(done.row_count, 3);
3941                assert_eq!(done.cur_cmd, 193);
3942            }
3943            _ => panic!("Expected DoneInProc, got {t1:?}"),
3944        }
3945
3946        // Token 2: ReturnStatus
3947        let t2 = parser.next_token().unwrap().unwrap();
3948        match t2 {
3949            Token::ReturnStatus(status) => {
3950                assert_eq!(status, 0);
3951            }
3952            _ => panic!("Expected ReturnStatus, got {t2:?}"),
3953        }
3954
3955        // Token 3: DoneProc
3956        let t3 = parser.next_token().unwrap().unwrap();
3957        match t3 {
3958            Token::DoneProc(done) => {
3959                assert!(!done.status.count);
3960                assert!(!done.status.more);
3961                assert_eq!(done.cur_cmd, 198);
3962            }
3963            _ => panic!("Expected DoneProc, got {t3:?}"),
3964        }
3965
3966        // No more tokens
3967        assert!(parser.next_token().unwrap().is_none());
3968    }
3969
3970    #[test]
3971    fn test_multi_token_error_in_stream() {
3972        // Simulate: ERROR → DONE (error during query)
3973        let mut data = BytesMut::new();
3974
3975        // Token 1: ERROR
3976        data.put_u8(0xAA);
3977
3978        let msg_utf16: Vec<u16> = "Deadlock".encode_utf16().collect();
3979        let srv_utf16: Vec<u16> = "DB1".encode_utf16().collect();
3980
3981        let length: u16 = (4 + 1 + 1
3982            + 2 + (msg_utf16.len() * 2)
3983            + 1 + (srv_utf16.len() * 2)
3984            + 1  // empty procedure
3985            + 4) as u16;
3986
3987        data.put_u16_le(length);
3988        data.put_i32_le(1205); // deadlock
3989        data.put_u8(51); // state
3990        data.put_u8(13); // class
3991
3992        data.put_u16_le(msg_utf16.len() as u16);
3993        for &c in &msg_utf16 {
3994            data.put_u16_le(c);
3995        }
3996        data.put_u8(srv_utf16.len() as u8);
3997        for &c in &srv_utf16 {
3998            data.put_u16_le(c);
3999        }
4000        data.put_u8(0); // empty procedure
4001        data.put_i32_le(0);
4002
4003        // Token 2: DONE with error flag
4004        data.put_u8(0xFD);
4005        data.put_u16_le(0x0002); // DONE_ERROR
4006        data.put_u16_le(0x00C1); // SELECT
4007        data.put_u64_le(0);
4008
4009        let mut parser = TokenParser::new(data.freeze());
4010
4011        // Token 1: Error
4012        let t1 = parser.next_token().unwrap().unwrap();
4013        match t1 {
4014            Token::Error(err) => {
4015                assert_eq!(err.number, 1205);
4016                assert_eq!(err.class, 13);
4017                assert_eq!(err.message, "Deadlock");
4018                assert_eq!(err.server, "DB1");
4019            }
4020            _ => panic!("Expected Error token, got {t1:?}"),
4021        }
4022
4023        // Token 2: Done with error
4024        let t2 = parser.next_token().unwrap().unwrap();
4025        match t2 {
4026            Token::Done(done) => {
4027                assert!(done.status.error);
4028                assert!(!done.status.count);
4029            }
4030            _ => panic!("Expected Done token, got {t2:?}"),
4031        }
4032
4033        assert!(parser.next_token().unwrap().is_none());
4034    }
4035
4036    #[test]
4037    fn test_multi_token_proc_with_return_value() {
4038        // Simulate stored proc: ReturnValue → ReturnStatus → DoneProc
4039        let mut data = BytesMut::new();
4040
4041        // Token 1: ReturnValue (@result = 42)
4042        data.put_u8(0xAC);
4043        data.extend_from_slice(&build_return_value_intn(1, "@result", 0x01, Some(42)));
4044
4045        // Token 2: ReturnStatus = 0
4046        data.put_u8(0x79);
4047        data.put_i32_le(0);
4048
4049        // Token 3: DoneProc
4050        data.put_u8(0xFE);
4051        data.put_u16_le(0x0000);
4052        data.put_u16_le(0x00C6);
4053        data.put_u64_le(0);
4054
4055        let mut parser = TokenParser::new(data.freeze());
4056
4057        let t1 = parser.next_token().unwrap().unwrap();
4058        match t1 {
4059            Token::ReturnValue(rv) => {
4060                assert_eq!(rv.param_name, "@result");
4061                assert_eq!(rv.param_ordinal, 1);
4062            }
4063            _ => panic!("Expected ReturnValue, got {t1:?}"),
4064        }
4065
4066        let t2 = parser.next_token().unwrap().unwrap();
4067        assert!(matches!(t2, Token::ReturnStatus(0)));
4068
4069        let t3 = parser.next_token().unwrap().unwrap();
4070        assert!(matches!(t3, Token::DoneProc(_)));
4071
4072        assert!(parser.next_token().unwrap().is_none());
4073    }
4074
4075    // ========================================================================
4076    // EOF / Truncation Edge Cases
4077    // ========================================================================
4078
4079    #[test]
4080    fn test_return_status_truncated() {
4081        // Only 3 bytes instead of 4 for i32
4082        let data = Bytes::from_static(&[0x79, 0x01, 0x02, 0x03]);
4083        let mut parser = TokenParser::new(data);
4084        assert!(parser.next_token().is_err());
4085    }
4086
4087    #[test]
4088    fn test_done_proc_truncated() {
4089        // Only 8 bytes instead of 12
4090        let data = Bytes::from_static(&[0xFE, 0x00, 0x00, 0xC1, 0x00, 0x01, 0x00, 0x00, 0x00]);
4091        let mut parser = TokenParser::new(data);
4092        assert!(parser.next_token().is_err());
4093    }
4094
4095    #[test]
4096    fn test_server_error_truncated() {
4097        // ERROR token with only the length field (body truncated)
4098        let data = Bytes::from_static(&[0xAA, 0x20, 0x00]);
4099        let mut parser = TokenParser::new(data);
4100        assert!(parser.next_token().is_err());
4101    }
4102
4103    // ========================================================================
4104    // FEDAUTHINFO (issue #189: parser must follow MS-TDS §2.2.7.12)
4105    // ========================================================================
4106
4107    /// Build a spec-exact FEDAUTHINFO token (including the 0xEE type byte):
4108    /// DWORD TokenLength, DWORD CountOfInfoIDs, option headers of
4109    /// ID/DataLen/DataOffset, then UTF-16LE data addressed by the offsets
4110    /// (relative to the start of the count field).
4111    fn build_fed_auth_info_token(options: &[(u8, &str)]) -> Vec<u8> {
4112        let headers_end = 4 + options.len() * 9;
4113        let mut data_block = Vec::new();
4114        let mut headers = Vec::new();
4115        for (id, value) in options {
4116            let encoded: Vec<u8> = value.encode_utf16().flat_map(u16::to_le_bytes).collect();
4117            let offset = headers_end + data_block.len();
4118            headers.push(*id);
4119            headers.extend_from_slice(&u32::try_from(encoded.len()).unwrap().to_le_bytes());
4120            headers.extend_from_slice(&u32::try_from(offset).unwrap().to_le_bytes());
4121            data_block.extend_from_slice(&encoded);
4122        }
4123
4124        let token_len = 4 + headers.len() + data_block.len();
4125        let mut out = vec![0xEE];
4126        out.extend_from_slice(&u32::try_from(token_len).unwrap().to_le_bytes());
4127        out.extend_from_slice(&u32::try_from(options.len()).unwrap().to_le_bytes());
4128        out.extend_from_slice(&headers);
4129        out.extend_from_slice(&data_block);
4130        out
4131    }
4132
4133    #[test]
4134    fn test_fed_auth_info_decodes_spec_layout() {
4135        const STS: &str = "https://login.microsoftonline.com/common";
4136        const SPN: &str = "https://database.windows.net/";
4137        // STSURL (0x01) listed first, SPN (0x02) second. Real Azure servers
4138        // list SPN first (see the captured-token test below); decoding must
4139        // not depend on option order.
4140        let token = build_fed_auth_info_token(&[(0x01, STS), (0x02, SPN)]);
4141
4142        let mut parser = TokenParser::new(Bytes::from(token));
4143        let parsed = parser.next_token().unwrap().unwrap();
4144        let Token::FedAuthInfo(info) = parsed else {
4145            panic!("expected FedAuthInfo, got {parsed:?}");
4146        };
4147        assert_eq!(info.sts_url, STS);
4148        assert_eq!(info.spn, SPN);
4149        assert!(parser.next_token().unwrap().is_none(), "exact consumption");
4150    }
4151
4152    #[test]
4153    fn test_fed_auth_info_preserves_following_tokens() {
4154        // The old parser looped over the whole remaining stream, swallowing
4155        // the tokens that follow FEDAUTHINFO during login. A DONE token
4156        // appended after it must survive.
4157        let mut stream = build_fed_auth_info_token(&[
4158            (0x01, "https://sts.example/"),
4159            (0x02, "https://db.example/"),
4160        ]);
4161        stream.push(0xFD); // DONE
4162        stream.extend_from_slice(&0u16.to_le_bytes()); // status
4163        stream.extend_from_slice(&0u16.to_le_bytes()); // curcmd
4164        stream.extend_from_slice(&0u64.to_le_bytes()); // rowcount
4165
4166        let mut parser = TokenParser::new(Bytes::from(stream));
4167        assert!(matches!(
4168            parser.next_token().unwrap(),
4169            Some(Token::FedAuthInfo(_))
4170        ));
4171        assert!(
4172            matches!(parser.next_token().unwrap(), Some(Token::Done(_))),
4173            "DONE after FEDAUTHINFO must not be swallowed"
4174        );
4175        assert!(parser.next_token().unwrap().is_none());
4176    }
4177
4178    #[test]
4179    fn test_fed_auth_info_unknown_ids_ignored() {
4180        // Spec: unrecognized FedAuthInfoIDs must be ignored.
4181        let token =
4182            build_fed_auth_info_token(&[(0x7F, "ignore-me"), (0x01, "https://sts.example/")]);
4183        let mut parser = TokenParser::new(Bytes::from(token));
4184        let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
4185            panic!("expected FedAuthInfo");
4186        };
4187        assert_eq!(info.sts_url, "https://sts.example/");
4188        assert_eq!(info.spn, "");
4189    }
4190
4191    #[test]
4192    fn test_fed_auth_info_hostile_inputs_error() {
4193        // TokenLength longer than the buffer.
4194        let mut truncated = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4195        truncated.truncate(truncated.len() - 4);
4196        assert!(
4197            TokenParser::new(Bytes::from(truncated))
4198                .next_token()
4199                .is_err()
4200        );
4201
4202        // CountOfInfoIDs claims more headers than the token holds
4203        // (also covers hostile counts whose header math would overflow).
4204        let mut bad_count = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4205        bad_count[5..9].copy_from_slice(&u32::MAX.to_le_bytes());
4206        assert!(
4207            TokenParser::new(Bytes::from(bad_count))
4208                .next_token()
4209                .is_err()
4210        );
4211
4212        // Data offset pointing past the end of the token.
4213        let mut bad_offset = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4214        bad_offset[14..18].copy_from_slice(&u32::MAX.to_le_bytes());
4215        assert!(
4216            TokenParser::new(Bytes::from(bad_offset))
4217                .next_token()
4218                .is_err()
4219        );
4220
4221        // Odd data length cannot be UTF-16.
4222        let mut odd_len = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4223        odd_len[10..14].copy_from_slice(&3u32.to_le_bytes());
4224        assert!(TokenParser::new(Bytes::from(odd_len)).next_token().is_err());
4225    }
4226
4227    #[test]
4228    fn test_fed_auth_info_parse_and_skip_agree() {
4229        // Issue #189: decode() and skip_token() must consume the same bytes
4230        // (the old decode ran past the token while skip honored the length).
4231        let token = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4232        let total = token.len();
4233
4234        let mut parser = TokenParser::new(Bytes::from(token.clone()));
4235        parser.next_token().unwrap();
4236        assert_eq!(parser.position(), total, "decode consumption");
4237
4238        let mut skipper = TokenParser::new(Bytes::from(token));
4239        skipper.skip_token().unwrap();
4240        assert_eq!(skipper.position(), total, "skip consumption");
4241    }
4242
4243    /// A FEDAUTHINFO token captured from a live Azure SQL Database login on
4244    /// 2026-06-12 (the client declared the ADAL library in LOGIN7; the server
4245    /// responded with this token). The tenant GUID inside the STS URL is
4246    /// replaced with an all-zero GUID of identical length, so every offset
4247    /// and length is byte-identical to the wire capture.
4248    ///
4249    /// This is the regression test deferred from PR #193, and it earns its
4250    /// keep: the real token proves FedAuthInfoID 0x01 = STSURL and
4251    /// 0x02 = SPN (Azure lists SPN first), which the synthetic tests
4252    /// originally had swapped.
4253    #[test]
4254    fn test_fed_auth_info_captured_from_azure() {
4255        const CAPTURED: &[u8] = &[
4256            0xEE, 0xCC, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x3A, 0x00, 0x00, 0x00,
4257            0x16, 0x00, 0x00, 0x00, 0x01, 0x7C, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68,
4258            0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F, 0x00, 0x2F,
4259            0x00, 0x64, 0x00, 0x61, 0x00, 0x74, 0x00, 0x61, 0x00, 0x62, 0x00, 0x61, 0x00, 0x73,
4260            0x00, 0x65, 0x00, 0x2E, 0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F,
4261            0x00, 0x77, 0x00, 0x73, 0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F,
4262            0x00, 0x68, 0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F,
4263            0x00, 0x2F, 0x00, 0x6C, 0x00, 0x6F, 0x00, 0x67, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x2E,
4264            0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F, 0x00, 0x77, 0x00, 0x73,
4265            0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F, 0x00, 0x30, 0x00, 0x30,
4266            0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D,
4267            0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30,
4268            0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
4269            0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
4270            0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00,
4271        ];
4272
4273        let mut parser = TokenParser::new(Bytes::from_static(CAPTURED));
4274        let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
4275            panic!("expected FedAuthInfo");
4276        };
4277        assert_eq!(
4278            info.sts_url,
4279            "https://login.windows.net/00000000-0000-0000-0000-000000000000"
4280        );
4281        assert_eq!(info.spn, "https://database.windows.net/");
4282        assert!(
4283            parser.next_token().unwrap().is_none(),
4284            "the captured token must be consumed exactly"
4285        );
4286    }
4287
4288    /// Regression test for #273 (remote DoS via unbounded token-skip recursion).
4289    ///
4290    /// COLINFO/TABNAME/OFFSET tokens are skipped, and the skip path used to
4291    /// self-recurse to fetch the next token — one real stack frame per skipped
4292    /// token (Rust guarantees no tail-call elimination). Because
4293    /// `Connection::read_message` reassembles every packet into one `Bytes`
4294    /// before tokenizing, a server-controlled multi-MB message of these 3-byte
4295    /// tokens drove ~10^5–10^6 frames and aborted the process.
4296    ///
4297    /// This feeds a flat run of 200_000 skip-tokens — far deeper than the
4298    /// (~2 MB) test-thread stack could survive by recursion — followed by a
4299    /// real DONE. Pre-fix, the recursive parser aborted (SIGABRT) on this
4300    /// input. Post-fix it must return that DONE with the whole buffer consumed,
4301    /// proving the skip path is bounded-stack and that the input was
4302    /// non-degenerate (all 200_000 tokens were actually traversed).
4303    #[test]
4304    fn skip_tokens_iterate_not_recurse_273() {
4305        const SKIP_COUNT: usize = 200_000;
4306        let mut buf = BytesMut::with_capacity(SKIP_COUNT * 3 + 13);
4307        for _ in 0..SKIP_COUNT {
4308            buf.put_u8(TokenType::ColInfo as u8);
4309            buf.put_u16_le(0); // zero-length body: 3 bytes total per skip-token
4310        }
4311        let done = Done {
4312            status: DoneStatus {
4313                more: false,
4314                error: false,
4315                in_xact: false,
4316                count: true,
4317                attn: false,
4318                srverror: false,
4319            },
4320            cur_cmd: 0xABCD,
4321            row_count: 99,
4322        };
4323        done.encode(&mut buf);
4324        let total_len = buf.len();
4325
4326        let mut parser = TokenParser::new(buf.freeze());
4327
4328        // A specific outcome — the DONE that follows the skip run — not a
4329        // generic is_err an unrelated parse failure could also satisfy.
4330        let Some(Token::Done(decoded)) = parser.next_token().unwrap() else {
4331            panic!("expected the DONE token after the skip run");
4332        };
4333        assert_eq!(decoded.cur_cmd, 0xABCD);
4334        assert_eq!(decoded.row_count, 99);
4335
4336        // Every skipped token was traversed: proof the input was non-trivial.
4337        assert_eq!(parser.position(), total_len);
4338        assert!(parser.next_token().unwrap().is_none());
4339    }
4340}