1use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
/// Option tokens that identify the entries of a PRELOGIN option table.
///
/// Each variant's discriminant is the token byte that is written to, and
/// read from, the wire (`option as u8`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum PreLoginOption {
    /// Version option (token `0x00`).
    Version = 0x00,
    /// Encryption option (token `0x01`).
    Encryption = 0x01,
    /// Instance-name option (token `0x02`).
    Instance = 0x02,
    /// Thread-id option (token `0x03`).
    ThreadId = 0x03,
    /// MARS option (token `0x04`).
    Mars = 0x04,
    /// Trace-id option (token `0x05`).
    TraceId = 0x05,
    /// Federated-authentication-required option (token `0x06`).
    FedAuthRequired = 0x06,
    /// Nonce option (token `0x07`).
    Nonce = 0x07,
    /// Marks the end of the option table (token `0xFF`).
    Terminator = 0xFF,
}
40
41impl PreLoginOption {
42 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
44 match value {
45 0x00 => Ok(Self::Version),
46 0x01 => Ok(Self::Encryption),
47 0x02 => Ok(Self::Instance),
48 0x03 => Ok(Self::ThreadId),
49 0x04 => Ok(Self::Mars),
50 0x05 => Ok(Self::TraceId),
51 0x06 => Ok(Self::FedAuthRequired),
52 0x07 => Ok(Self::Nonce),
53 0xFF => Ok(Self::Terminator),
54 _ => Err(ProtocolError::InvalidPreloginOption(value)),
55 }
56 }
57}
58
/// Encryption setting carried in the PRELOGIN Encryption option.
///
/// Each variant's discriminant is the byte used on the wire. The `Default`
/// value is [`EncryptionLevel::Required`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
#[non_exhaustive]
pub enum EncryptionLevel {
    /// Wire value `0x00`.
    Off = 0x00,
    /// Wire value `0x01`.
    On = 0x01,
    /// Wire value `0x02`.
    NotSupported = 0x02,
    /// Wire value `0x03`; this is the `Default` variant.
    #[default]
    Required = 0x03,
    /// Wire value `0x80`.
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Converts a wire byte into an [`EncryptionLevel`].
    ///
    /// This conversion is infallible: `0x00` and every byte that is not a
    /// known level both produce [`EncryptionLevel::Off`].
    pub fn from_u8(value: u8) -> Self {
        match value {
            0x01 => Self::On,
            0x02 => Self::NotSupported,
            0x03 => Self::Required,
            0x80 => Self::ClientCertAuth,
            // Covers the explicit 0x00 as well as any unrecognised byte.
            _ => Self::Off,
        }
    }

    /// Returns `true` for `On`, `Required` and `ClientCertAuth`, and `false`
    /// for `Off` and `NotSupported`.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        match self {
            Self::On | Self::Required | Self::ClientCertAuth => true,
            Self::Off | Self::NotSupported => false,
        }
    }
}
96
/// In-memory form of a PRELOGIN message.
///
/// `encode` serialises the fields into an option table followed by the
/// option payloads; `decode` parses that layout back into this struct.
#[derive(Debug, Clone, Default)]
pub struct PreLogin {
    /// TDS version. `new` sets it to `TdsVersion::V7_4`; `encode` writes its
    /// raw value as the first four bytes of the Version option and `decode`
    /// rebuilds it from those same four bytes.
    pub version: TdsVersion,

    /// Server version built by `decode` from the Version option payload
    /// (four version bytes plus an optional two-byte sub-build). It stays
    /// `None` otherwise and is not written by `encode`.
    pub server_version: Option<SqlServerVersion>,

    /// Encryption level carried in the Encryption option (always encoded).
    pub encryption: EncryptionLevel,
    /// Instance name; when `Some`, `encode` writes it as NUL-terminated bytes.
    pub instance: Option<String>,
    /// Thread id; when `Some`, `encode` writes it as four big-endian bytes.
    pub thread_id: Option<u32>,
    /// MARS flag; always encoded as a single byte (`0x01` or `0x00`).
    pub mars: bool,
    /// Trace id; written by `encode` when `Some`. `decode` does not populate it.
    pub trace_id: Option<TraceId>,
    /// When `true`, `encode` emits the FedAuthRequired option with payload `0x01`.
    pub fed_auth_required: bool,
    /// 32-byte nonce; written by `encode` when `Some`.
    pub nonce: Option<[u8; 32]>,
}
137
/// Payload data for the PRELOGIN TraceId option.
///
/// `PreLogin::encode` writes `activity_id`, then `activity_sequence` in
/// little-endian order, then 16 zero bytes (36 bytes in total).
#[derive(Debug, Clone, Copy)]
pub struct TraceId {
    /// 16-byte activity identifier, written verbatim.
    pub activity_id: [u8; 16],
    /// Activity sequence number, written as a little-endian `u32`.
    pub activity_sequence: u32,
}
146
147impl PreLogin {
148 #[must_use]
150 pub fn new() -> Self {
151 Self {
152 version: TdsVersion::V7_4,
153 server_version: None,
154 encryption: EncryptionLevel::Required,
155 instance: None,
156 thread_id: None,
157 mars: false,
158 trace_id: None,
159 fed_auth_required: false,
160 nonce: None,
161 }
162 }
163
164 #[must_use]
166 pub fn with_version(mut self, version: TdsVersion) -> Self {
167 self.version = version;
168 self
169 }
170
171 #[must_use]
173 pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
174 self.encryption = level;
175 self
176 }
177
178 #[must_use]
180 pub fn with_mars(mut self, enabled: bool) -> Self {
181 self.mars = enabled;
182 self
183 }
184
185 #[must_use]
187 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
188 self.instance = Some(instance.into());
189 self
190 }
191
192 #[must_use]
194 pub fn encode(&self) -> Bytes {
195 let mut buf = BytesMut::with_capacity(256);
196
197 let mut option_count = 3; if self.instance.is_some() {
202 option_count += 1;
203 }
204 if self.thread_id.is_some() {
205 option_count += 1;
206 }
207 if self.trace_id.is_some() {
208 option_count += 1;
209 }
210 if self.fed_auth_required {
211 option_count += 1;
212 }
213 if self.nonce.is_some() {
214 option_count += 1;
215 }
216
217 let header_size = option_count * 5 + 1; let mut data_offset = header_size as u16;
219 let mut data_buf = BytesMut::new();
220
221 buf.put_u8(PreLoginOption::Version as u8);
223 buf.put_u16(data_offset);
224 buf.put_u16(6);
225 let version_raw = self.version.raw();
226 data_buf.put_u8((version_raw >> 24) as u8);
227 data_buf.put_u8((version_raw >> 16) as u8);
228 data_buf.put_u8((version_raw >> 8) as u8);
229 data_buf.put_u8(version_raw as u8);
230 data_buf.put_u16_le(0);
233 data_offset += 6;
234
235 buf.put_u8(PreLoginOption::Encryption as u8);
237 buf.put_u16(data_offset);
238 buf.put_u16(1);
239 data_buf.put_u8(self.encryption as u8);
240 data_offset += 1;
241
242 if let Some(ref instance) = self.instance {
244 let instance_bytes = instance.as_bytes();
245 let len = instance_bytes.len() as u16 + 1; buf.put_u8(PreLoginOption::Instance as u8);
247 buf.put_u16(data_offset);
248 buf.put_u16(len);
249 data_buf.put_slice(instance_bytes);
250 data_buf.put_u8(0); data_offset += len;
252 }
253
254 if let Some(thread_id) = self.thread_id {
256 buf.put_u8(PreLoginOption::ThreadId as u8);
257 buf.put_u16(data_offset);
258 buf.put_u16(4);
259 data_buf.put_u32(thread_id);
260 data_offset += 4;
261 }
262
263 buf.put_u8(PreLoginOption::Mars as u8);
265 buf.put_u16(data_offset);
266 buf.put_u16(1);
267 data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
268 data_offset += 1;
269
270 if let Some(ref trace_id) = self.trace_id {
272 buf.put_u8(PreLoginOption::TraceId as u8);
273 buf.put_u16(data_offset);
274 buf.put_u16(36);
275 data_buf.put_slice(&trace_id.activity_id);
276 data_buf.put_u32_le(trace_id.activity_sequence);
277 data_buf.put_slice(&[0u8; 16]);
279 data_offset += 36;
280 }
281
282 if self.fed_auth_required {
284 buf.put_u8(PreLoginOption::FedAuthRequired as u8);
285 buf.put_u16(data_offset);
286 buf.put_u16(1);
287 data_buf.put_u8(0x01);
288 data_offset += 1;
289 }
290
291 if let Some(ref nonce) = self.nonce {
293 buf.put_u8(PreLoginOption::Nonce as u8);
294 buf.put_u16(data_offset);
295 buf.put_u16(32);
296 data_buf.put_slice(nonce);
297 let _ = data_offset; }
299
300 buf.put_u8(PreLoginOption::Terminator as u8);
302
303 buf.put_slice(&data_buf);
305
306 buf.freeze()
307 }
308
309 pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
318 let mut prelogin = Self::default();
319
320 let mut options = Vec::new();
322 loop {
323 if src.remaining() < 1 {
324 return Err(ProtocolError::UnexpectedEof);
325 }
326
327 let option_type = src.get_u8();
328 if option_type == PreLoginOption::Terminator as u8 {
329 break;
330 }
331
332 if src.remaining() < 4 {
333 return Err(ProtocolError::UnexpectedEof);
334 }
335
336 let offset = src.get_u16();
337 let length = src.get_u16();
338 options.push((PreLoginOption::from_u8(option_type)?, offset, length));
339 }
340
341 let data = src.copy_to_bytes(src.remaining());
343
344 let header_size = options.len() * 5 + 1;
346
347 for (option, packet_offset, length) in options {
348 let packet_offset = packet_offset as usize;
349 let length = length as usize;
350
351 if packet_offset < header_size {
354 continue;
356 }
357 let data_offset = packet_offset - header_size;
358
359 if data_offset + length > data.len() {
361 continue;
362 }
363
364 match option {
365 PreLoginOption::Version if length >= 4 => {
366 let version_bytes = &data[data_offset..data_offset + 4];
374 let version_raw = u32::from_be_bytes([
375 version_bytes[0],
376 version_bytes[1],
377 version_bytes[2],
378 version_bytes[3],
379 ]);
380
381 let sub_build = if length >= 6 {
383 let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
384 u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
385 } else {
386 0
387 };
388
389 prelogin.server_version =
391 Some(SqlServerVersion::from_raw(version_raw, sub_build));
392
393 prelogin.version = TdsVersion::new(version_raw);
395 }
396 PreLoginOption::Encryption if length >= 1 => {
397 prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
398 }
399 PreLoginOption::Mars if length >= 1 => {
400 prelogin.mars = data[data_offset] != 0;
401 }
402 PreLoginOption::Instance if length > 0 => {
403 let instance_data = &data[data_offset..data_offset + length];
405 if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
406 if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
407 if !s.is_empty() {
408 prelogin.instance = Some(s.to_string());
409 }
410 }
411 }
412 }
413 PreLoginOption::ThreadId if length >= 4 => {
414 let bytes = &data[data_offset..data_offset + 4];
415 prelogin.thread_id =
416 Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
417 }
418 PreLoginOption::FedAuthRequired if length >= 1 => {
419 prelogin.fed_auth_required = data[data_offset] != 0;
420 }
421 PreLoginOption::Nonce if length >= 32 => {
422 let mut nonce = [0u8; 32];
423 nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
424 prelogin.nonce = Some(nonce);
425 }
426 _ => {}
427 }
428 }
429
430 Ok(prelogin)
431 }
432}
433
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The three mandatory options produce a 16-byte table and 8 payload bytes.
    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        assert_eq!(encoded.len(), 24);

        // Version entry: token, offset 16 (big-endian), length 6.
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
        assert_eq!(&encoded[1..5], &[0x00, 0x10, 0x00, 0x06]);
        // Encryption entry: offset 22, length 1.
        assert_eq!(encoded[5], PreLoginOption::Encryption as u8);
        assert_eq!(&encoded[6..10], &[0x00, 0x16, 0x00, 0x01]);
        // Mars entry: offset 23, length 1.
        assert_eq!(encoded[10], PreLoginOption::Mars as u8);
        assert_eq!(&encoded[11..15], &[0x00, 0x17, 0x00, 0x01]);
        assert_eq!(encoded[15], PreLoginOption::Terminator as u8);

        // Payload bytes for Encryption and Mars.
        assert_eq!(encoded[22], EncryptionLevel::Required as u8);
        assert_eq!(encoded[23], 0x00);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let encoded = original.encode();

        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    /// Optional options shift every later offset; they must still round-trip.
    #[test]
    fn test_prelogin_decode_roundtrip_optional_fields() {
        let mut original = PreLogin::new().with_instance("SQLEXPRESS");
        original.thread_id = Some(0x0102_0304);
        original.fed_auth_required = true;
        original.nonce = Some([0xAB; 32]);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("SQLEXPRESS"));
        assert_eq!(decoded.thread_id, Some(0x0102_0304));
        assert!(decoded.fed_auth_required);
        assert_eq!(decoded.nonce, Some([0xAB; 32]));
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    /// Options may appear in any order; offsets, not positions, locate the data.
    #[test]
    fn test_prelogin_decode_encryption_offset() {
        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();

        // Two 5-byte entries plus the terminator.
        let header_size: u16 = 11;

        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size);
        buf.put_u16(1);

        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1);
        buf.put_u16(6);

        buf.put_u8(PreLoginOption::Terminator as u8);

        // Encryption payload.
        buf.put_u8(0x01);

        // Version payload: four version bytes and a zero sub-build.
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000);

        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }

    /// Input that ends before the terminator, or mid-entry, is an error.
    #[test]
    fn test_prelogin_decode_truncated() {
        let empty: &[u8] = &[];
        assert!(PreLogin::decode(empty).is_err());

        let partial_entry: &[u8] = &[PreLoginOption::Version as u8, 0x00];
        assert!(PreLogin::decode(partial_entry).is_err());
    }

    /// An entry pointing past the end of the data is ignored, not an error.
    #[test]
    fn test_prelogin_decode_skips_out_of_range_option() {
        let message: &[u8] = &[
            PreLoginOption::Encryption as u8,
            0x00,
            0xFF, // offset 255: beyond the single data byte below
            0x00,
            0x01,
            PreLoginOption::Terminator as u8,
            0x00,
        ];

        let decoded = PreLogin::decode(message).unwrap();
        assert_eq!(decoded.encryption, EncryptionLevel::default());
    }
}