tds_protocol/
prelogin.rs

1//! TDS pre-login packet handling.
2//!
3//! The pre-login packet is the first message exchanged between client and server
4//! in TDS 7.x connections. It negotiates protocol version, encryption, and other
5//! connection parameters.
6//!
7//! Note: TDS 8.0 (strict mode) does not use pre-login negotiation; TLS is
8//! established before any TDS traffic.
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::version::TdsVersion;
14
/// Option tokens that can appear in a pre-login option header.
///
/// The discriminant of each variant is the one-byte token value used on the
/// wire, so `option as u8` yields the byte to write.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum PreLoginOption {
    /// Version information (token `0x00`).
    Version = 0x00,
    /// Encryption negotiation (token `0x01`).
    Encryption = 0x01,
    /// Instance name, used when connecting to a named instance (token `0x02`).
    Instance = 0x02,
    /// Thread ID (token `0x03`).
    ThreadId = 0x03,
    /// MARS (Multiple Active Result Sets) support flag (token `0x04`).
    Mars = 0x04,
    /// Trace ID used for distributed tracing (token `0x05`).
    TraceId = 0x05,
    /// Federated-authentication-required flag (token `0x06`).
    FedAuthRequired = 0x06,
    /// Nonce used for encryption (token `0x07`).
    Nonce = 0x07,
    /// Terminator marking the end of the option headers (token `0xFF`).
    Terminator = 0xFF,
}
38
39impl PreLoginOption {
40    /// Create from raw byte value.
41    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
42        match value {
43            0x00 => Ok(Self::Version),
44            0x01 => Ok(Self::Encryption),
45            0x02 => Ok(Self::Instance),
46            0x03 => Ok(Self::ThreadId),
47            0x04 => Ok(Self::Mars),
48            0x05 => Ok(Self::TraceId),
49            0x06 => Ok(Self::FedAuthRequired),
50            0x07 => Ok(Self::Nonce),
51            0xFF => Ok(Self::Terminator),
52            _ => Err(ProtocolError::InvalidPreloginOption(value)),
53        }
54    }
55}
56
/// Encryption setting exchanged in the pre-login `Encryption` option.
///
/// The discriminant of each variant is the byte written on the wire. The
/// default level is [`EncryptionLevel::Required`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
pub enum EncryptionLevel {
    /// Encryption is off (`0x00`).
    Off = 0x00,
    /// Encryption is on (`0x01`).
    On = 0x01,
    /// Encryption is not supported (`0x02`).
    NotSupported = 0x02,
    /// Encryption is required (`0x03`). This is the default.
    #[default]
    Required = 0x03,
    /// Client certificate authentication (TDS 8.0+) (`0x80`).
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Convert a raw byte into an encryption level.
    ///
    /// This conversion never fails: `0x00` and every byte that is not a known
    /// level are both reported as [`EncryptionLevel::Off`].
    pub fn from_u8(value: u8) -> Self {
        match value {
            0x01 => Self::On,
            0x02 => Self::NotSupported,
            0x03 => Self::Required,
            0x80 => Self::ClientCertAuth,
            // 0x00 itself, plus anything unrecognised.
            _ => Self::Off,
        }
    }

    /// Report whether this level means the connection is to be encrypted.
    ///
    /// Returns `true` for `On`, `Required` and `ClientCertAuth`, and `false`
    /// for `Off` and `NotSupported`.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        match *self {
            Self::On | Self::Required | Self::ClientCertAuth => true,
            Self::Off | Self::NotSupported => false,
        }
    }
}
93
94/// Pre-login message builder and parser.
95#[derive(Debug, Clone, Default)]
96pub struct PreLogin {
97    /// TDS version.
98    pub version: TdsVersion,
99    /// Sub-build version.
100    pub sub_build: u16,
101    /// Encryption level.
102    pub encryption: EncryptionLevel,
103    /// Instance name (for named instances).
104    pub instance: Option<String>,
105    /// Thread ID.
106    pub thread_id: Option<u32>,
107    /// MARS enabled.
108    pub mars: bool,
109    /// Trace ID (Activity ID and Sequence).
110    pub trace_id: Option<TraceId>,
111    /// Federated authentication required.
112    pub fed_auth_required: bool,
113    /// Nonce for encryption.
114    pub nonce: Option<[u8; 32]>,
115}
116
/// Identifier used for distributed tracing.
#[derive(Clone, Copy, Debug)]
pub struct TraceId {
    /// Activity ID (a GUID), stored as 16 raw bytes.
    pub activity_id: [u8; 16],
    /// Sequence number of the activity.
    pub activity_sequence: u32,
}
125
126impl PreLogin {
127    /// Create a new pre-login message with default values.
128    #[must_use]
129    pub fn new() -> Self {
130        Self {
131            version: TdsVersion::V7_4,
132            sub_build: 0,
133            encryption: EncryptionLevel::Required,
134            instance: None,
135            thread_id: None,
136            mars: false,
137            trace_id: None,
138            fed_auth_required: false,
139            nonce: None,
140        }
141    }
142
143    /// Set the TDS version.
144    #[must_use]
145    pub fn with_version(mut self, version: TdsVersion) -> Self {
146        self.version = version;
147        self
148    }
149
150    /// Set the encryption level.
151    #[must_use]
152    pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
153        self.encryption = level;
154        self
155    }
156
157    /// Enable MARS.
158    #[must_use]
159    pub fn with_mars(mut self, enabled: bool) -> Self {
160        self.mars = enabled;
161        self
162    }
163
164    /// Set the instance name.
165    #[must_use]
166    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
167        self.instance = Some(instance.into());
168        self
169    }
170
171    /// Encode the pre-login message to bytes.
172    #[must_use]
173    pub fn encode(&self) -> Bytes {
174        let mut buf = BytesMut::with_capacity(256);
175
176        // Calculate option data offsets
177        // Each option entry is 5 bytes: type (1) + offset (2) + length (2)
178        // Plus 1 byte for terminator
179        let mut option_count = 3; // Version, Encryption, MARS are always present
180        if self.instance.is_some() {
181            option_count += 1;
182        }
183        if self.thread_id.is_some() {
184            option_count += 1;
185        }
186        if self.trace_id.is_some() {
187            option_count += 1;
188        }
189        if self.fed_auth_required {
190            option_count += 1;
191        }
192        if self.nonce.is_some() {
193            option_count += 1;
194        }
195
196        let header_size = option_count * 5 + 1; // +1 for terminator
197        let mut data_offset = header_size as u16;
198        let mut data_buf = BytesMut::new();
199
200        // VERSION option (6 bytes: 4 bytes version + 2 bytes sub-build)
201        buf.put_u8(PreLoginOption::Version as u8);
202        buf.put_u16(data_offset);
203        buf.put_u16(6);
204        let version_raw = self.version.raw();
205        data_buf.put_u8((version_raw >> 24) as u8);
206        data_buf.put_u8((version_raw >> 16) as u8);
207        data_buf.put_u8((version_raw >> 8) as u8);
208        data_buf.put_u8(version_raw as u8);
209        data_buf.put_u16_le(self.sub_build);
210        data_offset += 6;
211
212        // ENCRYPTION option (1 byte)
213        buf.put_u8(PreLoginOption::Encryption as u8);
214        buf.put_u16(data_offset);
215        buf.put_u16(1);
216        data_buf.put_u8(self.encryption as u8);
217        data_offset += 1;
218
219        // INSTANCE option (if set)
220        if let Some(ref instance) = self.instance {
221            let instance_bytes = instance.as_bytes();
222            let len = instance_bytes.len() as u16 + 1; // +1 for null terminator
223            buf.put_u8(PreLoginOption::Instance as u8);
224            buf.put_u16(data_offset);
225            buf.put_u16(len);
226            data_buf.put_slice(instance_bytes);
227            data_buf.put_u8(0); // null terminator
228            data_offset += len;
229        }
230
231        // THREADID option (if set)
232        if let Some(thread_id) = self.thread_id {
233            buf.put_u8(PreLoginOption::ThreadId as u8);
234            buf.put_u16(data_offset);
235            buf.put_u16(4);
236            data_buf.put_u32(thread_id);
237            data_offset += 4;
238        }
239
240        // MARS option (1 byte)
241        buf.put_u8(PreLoginOption::Mars as u8);
242        buf.put_u16(data_offset);
243        buf.put_u16(1);
244        data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
245        data_offset += 1;
246
247        // TRACEID option (if set)
248        if let Some(ref trace_id) = self.trace_id {
249            buf.put_u8(PreLoginOption::TraceId as u8);
250            buf.put_u16(data_offset);
251            buf.put_u16(36);
252            data_buf.put_slice(&trace_id.activity_id);
253            data_buf.put_u32_le(trace_id.activity_sequence);
254            // Connection ID (16 bytes, typically zeros for client)
255            data_buf.put_slice(&[0u8; 16]);
256            data_offset += 36;
257        }
258
259        // FEDAUTHREQUIRED option (if set)
260        if self.fed_auth_required {
261            buf.put_u8(PreLoginOption::FedAuthRequired as u8);
262            buf.put_u16(data_offset);
263            buf.put_u16(1);
264            data_buf.put_u8(0x01);
265            data_offset += 1;
266        }
267
268        // NONCE option (if set)
269        if let Some(ref nonce) = self.nonce {
270            buf.put_u8(PreLoginOption::Nonce as u8);
271            buf.put_u16(data_offset);
272            buf.put_u16(32);
273            data_buf.put_slice(nonce);
274            let _ = data_offset; // Suppress unused warning
275        }
276
277        // Terminator
278        buf.put_u8(PreLoginOption::Terminator as u8);
279
280        // Append data section
281        buf.put_slice(&data_buf);
282
283        buf.freeze()
284    }
285
286    /// Decode a pre-login response from the server.
287    ///
288    /// Per MS-TDS spec 2.2.6.4, PreLogin message structure:
289    /// - Option headers: each 5 bytes (type:1 + offset:2 + length:2)
290    /// - Terminator: 1 byte (0xFF)
291    /// - Option data: variable length, positioned at offsets specified in headers
292    ///
293    /// Offsets in headers are absolute from the start of the PreLogin packet payload.
294    pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
295        let mut prelogin = Self::default();
296
297        // Parse option headers first, collecting (option_type, offset, length)
298        let mut options = Vec::new();
299        loop {
300            if src.remaining() < 1 {
301                return Err(ProtocolError::UnexpectedEof);
302            }
303
304            let option_type = src.get_u8();
305            if option_type == PreLoginOption::Terminator as u8 {
306                break;
307            }
308
309            if src.remaining() < 4 {
310                return Err(ProtocolError::UnexpectedEof);
311            }
312
313            let offset = src.get_u16();
314            let length = src.get_u16();
315            options.push((PreLoginOption::from_u8(option_type)?, offset, length));
316        }
317
318        // Get remaining data as bytes for random access
319        let data = src.copy_to_bytes(src.remaining());
320
321        // Calculate header size: each option is 5 bytes + 1 byte terminator
322        let header_size = options.len() * 5 + 1;
323
324        for (option, packet_offset, length) in options {
325            let packet_offset = packet_offset as usize;
326            let length = length as usize;
327
328            // Convert absolute packet offset to offset within data buffer
329            // The data buffer starts after the headers, so we subtract header_size
330            if packet_offset < header_size {
331                // Invalid: offset points inside the headers
332                continue;
333            }
334            let data_offset = packet_offset - header_size;
335
336            // Bounds check
337            if data_offset + length > data.len() {
338                continue;
339            }
340
341            match option {
342                PreLoginOption::Version if length >= 6 => {
343                    // Per MS-TDS: UL_VERSION is 4 bytes big-endian, US_SUBBUILD is 2 bytes little-endian
344                    let version_bytes = &data[data_offset..data_offset + 4];
345                    let version_raw = u32::from_be_bytes([
346                        version_bytes[0],
347                        version_bytes[1],
348                        version_bytes[2],
349                        version_bytes[3],
350                    ]);
351                    prelogin.version = TdsVersion::new(version_raw);
352
353                    if length >= 6 {
354                        let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
355                        prelogin.sub_build =
356                            u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]]);
357                    }
358                }
359                PreLoginOption::Encryption if length >= 1 => {
360                    prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
361                }
362                PreLoginOption::Mars if length >= 1 => {
363                    prelogin.mars = data[data_offset] != 0;
364                }
365                PreLoginOption::Instance if length > 0 => {
366                    // Instance name is null-terminated string
367                    let instance_data = &data[data_offset..data_offset + length];
368                    if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
369                        if let Ok(s) = std::str::from_utf8(&instance_data[..null_pos]) {
370                            if !s.is_empty() {
371                                prelogin.instance = Some(s.to_string());
372                            }
373                        }
374                    }
375                }
376                PreLoginOption::ThreadId if length >= 4 => {
377                    let bytes = &data[data_offset..data_offset + 4];
378                    prelogin.thread_id =
379                        Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
380                }
381                PreLoginOption::FedAuthRequired if length >= 1 => {
382                    prelogin.fed_auth_required = data[data_offset] != 0;
383                }
384                PreLoginOption::Nonce if length >= 32 => {
385                    let mut nonce = [0u8; 32];
386                    nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
387                    prelogin.nonce = Some(nonce);
388                }
389                _ => {}
390            }
391        }
392
393        Ok(prelogin)
394    }
395}
396
397#[cfg(not(feature = "std"))]
398use alloc::string::String;
399#[cfg(not(feature = "std"))]
400use alloc::vec::Vec;
401
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// An encoded message is non-empty and starts with the VERSION header.
    #[test]
    fn test_prelogin_encode() {
        let wire = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required)
            .encode();

        assert!(!wire.is_empty());
        // The first header token is the VERSION option.
        assert_eq!(wire[0], PreLoginOption::Version as u8);
    }

    /// `is_required` is true for `Required`/`On`, false for `Off`/`NotSupported`.
    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    /// Version, encryption and MARS survive an encode/decode round trip.
    #[test]
    fn test_prelogin_decode_roundtrip() {
        let sent = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let wire = sent.encode();
        let received = PreLogin::decode(wire.as_ref()).unwrap();

        assert_eq!(received.version, sent.version);
        assert_eq!(received.encryption, sent.encryption);
        assert_eq!(received.mars, sent.mars);
    }

    /// Options listed in a non-standard order are still read from the offsets
    /// named in their headers.
    #[test]
    fn test_prelogin_decode_encryption_offset() {
        // Two 5-byte headers + terminator = 11 header bytes, so the payload
        // section starts at absolute offset 11 (0x000B).
        let packet: [u8; 18] = [
            // ENCRYPTION header first: offset 11, length 1
            0x01, 0x00, 0x0B, 0x00, 0x01,
            // VERSION header: offset 12 (right after the encryption byte), length 6
            0x00, 0x00, 0x0C, 0x00, 0x06,
            // Terminator
            0xFF,
            // Encryption payload: ENCRYPT_ON
            0x01,
            // Version payload: TDS 7.4 = 0x74000004 big-endian, sub-build 0 little-endian
            0x74, 0x00, 0x00, 0x04, 0x00, 0x00,
        ];

        let decoded = PreLogin::decode(&packet[..]).unwrap();

        // Encryption must come from its payload offset, not from index 0.
        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}
497}