Skip to main content

tds_protocol/
prelogin.rs

1//! TDS pre-login packet handling.
2//!
3//! The pre-login packet is the first message exchanged between client and server
4//! in TDS 7.x connections. It negotiates protocol version, encryption, and other
5//! connection parameters.
6//!
7//! Note: TDS 8.0 (strict mode) does not use pre-login negotiation; TLS is
8//! established before any TDS traffic.
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
/// Option tokens that may appear in a pre-login message's option table.
///
/// The discriminant of each variant is the exact token byte used on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
#[non_exhaustive]
pub enum PreLoginOption {
    /// VERSION: version information.
    Version = 0x00,
    /// ENCRYPTION: encryption negotiation.
    Encryption = 0x01,
    /// INSTANCE: instance name, used for named instances.
    Instance = 0x02,
    /// THREADID: client thread identifier.
    ThreadId = 0x03,
    /// MARS: Multiple Active Result Sets support.
    Mars = 0x04,
    /// TRACEID: trace identifier for distributed tracing.
    TraceId = 0x05,
    /// FEDAUTHREQUIRED: federated authentication is required.
    FedAuthRequired = 0x06,
    /// NONCE: nonce used for encryption.
    Nonce = 0x07,
    /// Marks the end of the option table.
    Terminator = 0xFF,
}
40
41impl PreLoginOption {
42    /// Create from raw byte value.
43    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
44        match value {
45            0x00 => Ok(Self::Version),
46            0x01 => Ok(Self::Encryption),
47            0x02 => Ok(Self::Instance),
48            0x03 => Ok(Self::ThreadId),
49            0x04 => Ok(Self::Mars),
50            0x05 => Ok(Self::TraceId),
51            0x06 => Ok(Self::FedAuthRequired),
52            0x07 => Ok(Self::Nonce),
53            0xFF => Ok(Self::Terminator),
54            _ => Err(ProtocolError::InvalidPreloginOption(value)),
55        }
56    }
57}
58
/// Encryption setting exchanged in the ENCRYPTION pre-login option.
///
/// The discriminant of each variant is the byte carried on the wire.
/// [`Default`] yields [`EncryptionLevel::Required`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EncryptionLevel {
    /// Encryption off.
    Off = 0x00,
    /// Encryption on.
    On = 0x01,
    /// Encryption not supported.
    NotSupported = 0x02,
    /// Encryption required (the default).
    #[default]
    Required = 0x03,
    /// Client certificate authentication (documented by the original author as TDS 8.0+).
    ClientCertAuth = 0x80,
}
76
77impl EncryptionLevel {
78    /// Create from raw byte value.
79    pub fn from_u8(value: u8) -> Self {
80        match value {
81            0x00 => Self::Off,
82            0x01 => Self::On,
83            0x02 => Self::NotSupported,
84            0x03 => Self::Required,
85            0x80 => Self::ClientCertAuth,
86            _ => Self::Off,
87        }
88    }
89
90    /// Check if encryption is required.
91    #[must_use]
92    pub const fn is_required(&self) -> bool {
93        matches!(self, Self::On | Self::Required | Self::ClientCertAuth)
94    }
95}
96
97/// Pre-login message builder and parser.
98///
99/// This struct is used for both client requests and server responses:
100/// - **Client → Server**: Set `version` to the requested TDS version
101/// - **Server → Client**: `server_version` contains the SQL Server product version
102///
103/// Note: The VERSION field has different semantics in each direction:
104/// - Client sends: TDS protocol version (e.g., 7.4)
105/// - Server sends: SQL Server product version (e.g., 13.0.6300 for SQL Server 2016)
106#[derive(Debug, Clone, Default)]
107pub struct PreLogin {
108    /// TDS version (client request).
109    ///
110    /// This is the TDS protocol version the client requests. When sending a
111    /// PreLogin, set this to the desired TDS version.
112    pub version: TdsVersion,
113
114    /// SQL Server product version (server response).
115    ///
116    /// When decoding a PreLogin response from the server, this contains the
117    /// SQL Server product version (e.g., 13.0.6300 for SQL Server 2016).
118    /// This is NOT the TDS version - the actual TDS version is negotiated
119    /// in the LOGINACK token after login.
120    pub server_version: Option<SqlServerVersion>,
121
122    /// Encryption level.
123    pub encryption: EncryptionLevel,
124    /// Instance name (for named instances).
125    pub instance: Option<String>,
126    /// Thread ID.
127    pub thread_id: Option<u32>,
128    /// MARS enabled.
129    pub mars: bool,
130    /// Trace ID (Activity ID and Sequence).
131    pub trace_id: Option<TraceId>,
132    /// Federated authentication required.
133    pub fed_auth_required: bool,
134    /// Nonce for encryption.
135    pub nonce: Option<[u8; 32]>,
136}
137
/// Distributed tracing ID carried in the TRACEID pre-login option.
///
/// Only the activity part is modelled here. The pre-login encoder writes zeros for
/// the connection-ID portion of the option and writes `activity_sequence`
/// little-endian.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TraceId {
    /// Activity ID (GUID), as raw bytes.
    pub activity_id: [u8; 16],
    /// Activity sequence number.
    pub activity_sequence: u32,
}
146
147impl PreLogin {
148    /// Create a new pre-login message with default values.
149    #[must_use]
150    pub fn new() -> Self {
151        Self {
152            version: TdsVersion::V7_4,
153            server_version: None,
154            encryption: EncryptionLevel::Required,
155            instance: None,
156            thread_id: None,
157            mars: false,
158            trace_id: None,
159            fed_auth_required: false,
160            nonce: None,
161        }
162    }
163
164    /// Set the TDS version.
165    #[must_use]
166    pub fn with_version(mut self, version: TdsVersion) -> Self {
167        self.version = version;
168        self
169    }
170
171    /// Set the encryption level.
172    #[must_use]
173    pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
174        self.encryption = level;
175        self
176    }
177
178    /// Enable MARS.
179    #[must_use]
180    pub fn with_mars(mut self, enabled: bool) -> Self {
181        self.mars = enabled;
182        self
183    }
184
185    /// Set the instance name.
186    #[must_use]
187    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
188        self.instance = Some(instance.into());
189        self
190    }
191
192    /// Encode the pre-login message to bytes.
193    #[must_use]
194    pub fn encode(&self) -> Bytes {
195        let mut buf = BytesMut::with_capacity(256);
196
197        // Calculate option data offsets
198        // Each option entry is 5 bytes: type (1) + offset (2) + length (2)
199        // Plus 1 byte for terminator
200        let mut option_count = 3; // Version, Encryption, MARS are always present
201        if self.instance.is_some() {
202            option_count += 1;
203        }
204        if self.thread_id.is_some() {
205            option_count += 1;
206        }
207        if self.trace_id.is_some() {
208            option_count += 1;
209        }
210        if self.fed_auth_required {
211            option_count += 1;
212        }
213        if self.nonce.is_some() {
214            option_count += 1;
215        }
216
217        let header_size = option_count * 5 + 1; // +1 for terminator
218        let mut data_offset = header_size as u16;
219        let mut data_buf = BytesMut::new();
220
221        // VERSION option (6 bytes: 4 bytes version + 2 bytes sub-build)
222        buf.put_u8(PreLoginOption::Version as u8);
223        buf.put_u16(data_offset);
224        buf.put_u16(6);
225        let version_raw = self.version.raw();
226        data_buf.put_u8((version_raw >> 24) as u8);
227        data_buf.put_u8((version_raw >> 16) as u8);
228        data_buf.put_u8((version_raw >> 8) as u8);
229        data_buf.put_u8(version_raw as u8);
230        // Sub-build is always 0 for client-sent PreLogin; server sub-build
231        // lives in server_version after decode.
232        data_buf.put_u16_le(0);
233        data_offset += 6;
234
235        // ENCRYPTION option (1 byte)
236        buf.put_u8(PreLoginOption::Encryption as u8);
237        buf.put_u16(data_offset);
238        buf.put_u16(1);
239        data_buf.put_u8(self.encryption as u8);
240        data_offset += 1;
241
242        // INSTANCE option (if set)
243        if let Some(ref instance) = self.instance {
244            let instance_bytes = instance.as_bytes();
245            let len = instance_bytes.len() as u16 + 1; // +1 for null terminator
246            buf.put_u8(PreLoginOption::Instance as u8);
247            buf.put_u16(data_offset);
248            buf.put_u16(len);
249            data_buf.put_slice(instance_bytes);
250            data_buf.put_u8(0); // null terminator
251            data_offset += len;
252        }
253
254        // THREADID option (if set)
255        if let Some(thread_id) = self.thread_id {
256            buf.put_u8(PreLoginOption::ThreadId as u8);
257            buf.put_u16(data_offset);
258            buf.put_u16(4);
259            data_buf.put_u32(thread_id);
260            data_offset += 4;
261        }
262
263        // MARS option (1 byte)
264        buf.put_u8(PreLoginOption::Mars as u8);
265        buf.put_u16(data_offset);
266        buf.put_u16(1);
267        data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
268        data_offset += 1;
269
270        // TRACEID option (if set)
271        if let Some(ref trace_id) = self.trace_id {
272            buf.put_u8(PreLoginOption::TraceId as u8);
273            buf.put_u16(data_offset);
274            buf.put_u16(36);
275            data_buf.put_slice(&trace_id.activity_id);
276            data_buf.put_u32_le(trace_id.activity_sequence);
277            // Connection ID (16 bytes, typically zeros for client)
278            data_buf.put_slice(&[0u8; 16]);
279            data_offset += 36;
280        }
281
282        // FEDAUTHREQUIRED option (if set)
283        if self.fed_auth_required {
284            buf.put_u8(PreLoginOption::FedAuthRequired as u8);
285            buf.put_u16(data_offset);
286            buf.put_u16(1);
287            data_buf.put_u8(0x01);
288            data_offset += 1;
289        }
290
291        // NONCE option (if set)
292        if let Some(ref nonce) = self.nonce {
293            buf.put_u8(PreLoginOption::Nonce as u8);
294            buf.put_u16(data_offset);
295            buf.put_u16(32);
296            data_buf.put_slice(nonce);
297            let _ = data_offset; // Suppress unused warning
298        }
299
300        // Terminator
301        buf.put_u8(PreLoginOption::Terminator as u8);
302
303        // Append data section
304        buf.put_slice(&data_buf);
305
306        buf.freeze()
307    }
308
309    /// Decode a pre-login response from the server.
310    ///
311    /// Per MS-TDS spec 2.2.6.4, PreLogin message structure:
312    /// - Option headers: each 5 bytes (type:1 + offset:2 + length:2)
313    /// - Terminator: 1 byte (0xFF)
314    /// - Option data: variable length, positioned at offsets specified in headers
315    ///
316    /// Offsets in headers are absolute from the start of the PreLogin packet payload.
317    pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
318        let mut prelogin = Self::default();
319
320        // Parse option headers first, collecting (option_type, offset, length)
321        let mut options = Vec::new();
322        loop {
323            if src.remaining() < 1 {
324                return Err(ProtocolError::UnexpectedEof);
325            }
326
327            let option_type = src.get_u8();
328            if option_type == PreLoginOption::Terminator as u8 {
329                break;
330            }
331
332            if src.remaining() < 4 {
333                return Err(ProtocolError::UnexpectedEof);
334            }
335
336            let offset = src.get_u16();
337            let length = src.get_u16();
338            options.push((PreLoginOption::from_u8(option_type)?, offset, length));
339        }
340
341        // Get remaining data as bytes for random access
342        let data = src.copy_to_bytes(src.remaining());
343
344        // Calculate header size: each option is 5 bytes + 1 byte terminator
345        let header_size = options.len() * 5 + 1;
346
347        for (option, packet_offset, length) in options {
348            let packet_offset = packet_offset as usize;
349            let length = length as usize;
350
351            // Convert absolute packet offset to offset within data buffer
352            // The data buffer starts after the headers, so we subtract header_size
353            if packet_offset < header_size {
354                // Invalid: offset points inside the headers
355                continue;
356            }
357            let data_offset = packet_offset - header_size;
358
359            // Bounds check
360            if data_offset + length > data.len() {
361                continue;
362            }
363
364            match option {
365                PreLoginOption::Version if length >= 4 => {
366                    // Per MS-TDS 2.2.6.4: The server sends its SQL Server product version
367                    // in the VERSION field, NOT the TDS protocol version.
368                    //
369                    // Format: UL_VERSION (4 bytes big-endian) + US_SUBBUILD (2 bytes little-endian)
370                    // UL_VERSION contains: [major][minor][build_hi][build_lo]
371                    //
372                    // For example, SQL Server 2016 sends 13.0.xxxx (major=13, minor=0)
373                    let version_bytes = &data[data_offset..data_offset + 4];
374                    let version_raw = u32::from_be_bytes([
375                        version_bytes[0],
376                        version_bytes[1],
377                        version_bytes[2],
378                        version_bytes[3],
379                    ]);
380
381                    // Extract sub_build if present
382                    let sub_build = if length >= 6 {
383                        let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
384                        u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
385                    } else {
386                        0
387                    };
388
389                    // Populate the new SqlServerVersion field (correct semantics)
390                    prelogin.server_version =
391                        Some(SqlServerVersion::from_raw(version_raw, sub_build));
392
393                    // Also set version for backward compatibility
394                    prelogin.version = TdsVersion::new(version_raw);
395                }
396                PreLoginOption::Encryption if length >= 1 => {
397                    prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
398                }
399                PreLoginOption::Mars if length >= 1 => {
400                    prelogin.mars = data[data_offset] != 0;
401                }
402                PreLoginOption::Instance if length > 0 => {
403                    // Instance name is null-terminated string
404                    let instance_data = &data[data_offset..data_offset + length];
405                    if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
406                        if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
407                            if !s.is_empty() {
408                                prelogin.instance = Some(s.to_string());
409                            }
410                        }
411                    }
412                }
413                PreLoginOption::ThreadId if length >= 4 => {
414                    let bytes = &data[data_offset..data_offset + 4];
415                    prelogin.thread_id =
416                        Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
417                }
418                PreLoginOption::FedAuthRequired if length >= 1 => {
419                    prelogin.fed_auth_required = data[data_offset] != 0;
420                }
421                PreLoginOption::Nonce if length >= 32 => {
422                    let mut nonce = [0u8; 32];
423                    nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
424                    prelogin.nonce = Some(nonce);
425                }
426                _ => {}
427            }
428        }
429
430        Ok(prelogin)
431    }
432}
433
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        // First byte should be VERSION option type
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    #[test]
    fn test_encryption_level_from_u8_known_values() {
        // Every declared level must survive a byte roundtrip.
        for level in [
            EncryptionLevel::Off,
            EncryptionLevel::On,
            EncryptionLevel::NotSupported,
            EncryptionLevel::Required,
            EncryptionLevel::ClientCertAuth,
        ] {
            assert_eq!(EncryptionLevel::from_u8(level as u8), level);
        }
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        // Create a PreLogin with various options
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        // Encode it
        let encoded = original.encode();

        // Decode it back
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        // Verify the critical fields match
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    #[test]
    fn test_prelogin_roundtrip_optional_options() {
        // Exercise the options that are only written when set.
        let mut original = PreLogin::new().with_instance("MSSQLSERVER").with_mars(true);
        original.thread_id = Some(0x0102_0304);
        original.fed_auth_required = true;
        original.nonce = Some([0xAB; 32]);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("MSSQLSERVER"));
        assert_eq!(decoded.thread_id, Some(0x0102_0304));
        assert!(decoded.fed_auth_required);
        assert_eq!(decoded.nonce, Some([0xAB; 32]));
        assert!(decoded.mars);
    }

    #[test]
    fn test_prelogin_decode_truncated_input() {
        // No bytes at all.
        let empty: &[u8] = &[];
        assert!(matches!(
            PreLogin::decode(empty),
            Err(ProtocolError::UnexpectedEof)
        ));

        // Option token present, but its offset/length fields are cut short.
        let truncated: &[u8] = &[PreLoginOption::Encryption as u8, 0x00, 0x06];
        assert!(matches!(
            PreLogin::decode(truncated),
            Err(ProtocolError::UnexpectedEof)
        ));

        // A complete header but no terminator.
        let unterminated: &[u8] = &[PreLoginOption::Mars as u8, 0x00, 0x06, 0x00, 0x01];
        assert!(matches!(
            PreLogin::decode(unterminated),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_prelogin_decode_ignores_bad_offsets() {
        // Two headers (10 bytes) + terminator = 11-byte header table.
        let mut buf = BytesMut::new();

        // ENCRYPTION data offset points past the end of the payload.
        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(100);
        buf.put_u16(1);

        // MARS data offset points into the header table itself.
        buf.put_u8(PreLoginOption::Mars as u8);
        buf.put_u16(0);
        buf.put_u16(1);

        buf.put_u8(PreLoginOption::Terminator as u8);
        buf.put_u8(0x01);

        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        // Both fields keep their defaults instead of reading stray bytes.
        assert_eq!(decoded.encryption, EncryptionLevel::default());
        assert!(!decoded.mars);
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        // Manually construct a PreLogin packet with options in non-standard order
        // to verify offset handling works correctly
        //
        // Structure:
        // - ENCRYPTION header at offset pointing to encryption data
        // - VERSION header at offset pointing to version data
        // - Terminator
        // - Data section

        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();

        // Header section: each option is 5 bytes (type:1 + offset:2 + length:2)
        // We'll have 2 options + terminator = 11 bytes header
        let header_size: u16 = 11;

        // ENCRYPTION option header (put this first to test that we read from correct offset)
        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size); // offset to encryption data
        buf.put_u16(1); // length

        // VERSION option header
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1); // offset to version data (after encryption)
        buf.put_u16(6); // length

        // Terminator
        buf.put_u8(PreLoginOption::Terminator as u8);

        // Data section
        // Encryption data (1 byte): ENCRYPT_ON = 0x01
        buf.put_u8(0x01);

        // Version data (6 bytes): TDS 7.4 = 0x74000004 big-endian + sub-build 0x0000 little-endian
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000); // sub-build

        // Decode
        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        // Verify encryption was read from correct offset (not from index 0)
        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}