1use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::TdsVersion;
15
/// Option tokens that can appear in the PRELOGIN option table.
///
/// Each discriminant is the single token byte written on the wire, so a
/// variant can be turned into its wire value with `as u8`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreLoginOption {
    /// Version and sub-build (`0x00`).
    Version = 0x00,
    /// Encryption level byte (`0x01`).
    Encryption = 0x01,
    /// Null-terminated instance name (`0x02`).
    Instance = 0x02,
    /// 32-bit thread identifier (`0x03`).
    ThreadId = 0x03,
    /// MARS flag byte (`0x04`).
    Mars = 0x04,
    /// Trace identifier block (`0x05`).
    TraceId = 0x05,
    /// Federated-authentication-required flag byte (`0x06`).
    FedAuthRequired = 0x06,
    /// 32-byte nonce (`0x07`).
    Nonce = 0x07,
    /// Marks the end of the option table (`0xFF`).
    Terminator = 0xFF,
}
39
40impl PreLoginOption {
41 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
43 match value {
44 0x00 => Ok(Self::Version),
45 0x01 => Ok(Self::Encryption),
46 0x02 => Ok(Self::Instance),
47 0x03 => Ok(Self::ThreadId),
48 0x04 => Ok(Self::Mars),
49 0x05 => Ok(Self::TraceId),
50 0x06 => Ok(Self::FedAuthRequired),
51 0x07 => Ok(Self::Nonce),
52 0xFF => Ok(Self::Terminator),
53 _ => Err(ProtocolError::InvalidPreloginOption(value)),
54 }
55 }
56}
57
/// Encryption level carried by the PRELOGIN `Encryption` option.
///
/// Each discriminant is the byte written on the wire. [`Default`] yields
/// [`EncryptionLevel::Required`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EncryptionLevel {
    /// Wire value `0x00`.
    Off = 0x00,
    /// Wire value `0x01`.
    On = 0x01,
    /// Wire value `0x02`.
    NotSupported = 0x02,
    /// Wire value `0x03`; this is the default level.
    #[default]
    Required = 0x03,
    /// Wire value `0x80`.
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Maps a raw byte to an encryption level.
    ///
    /// This conversion never fails: any byte that is not a known level is
    /// reported as [`EncryptionLevel::Off`].
    pub fn from_u8(value: u8) -> Self {
        match value {
            0x01 => Self::On,
            0x02 => Self::NotSupported,
            0x03 => Self::Required,
            0x80 => Self::ClientCertAuth,
            // `0x00` and every unrecognized byte both resolve to `Off`.
            _ => Self::Off,
        }
    }

    /// Returns `true` for `On`, `Required` and `ClientCertAuth`, and `false`
    /// for `Off` and `NotSupported`.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        match self {
            Self::On | Self::Required | Self::ClientCertAuth => true,
            Self::Off | Self::NotSupported => false,
        }
    }
}
94
95#[derive(Debug, Clone, Default)]
97pub struct PreLogin {
98 pub version: TdsVersion,
100 pub sub_build: u16,
102 pub encryption: EncryptionLevel,
104 pub instance: Option<String>,
106 pub thread_id: Option<u32>,
108 pub mars: bool,
110 pub trace_id: Option<TraceId>,
112 pub fed_auth_required: bool,
114 pub nonce: Option<[u8; 32]>,
116}
117
/// Activity information carried by the PRELOGIN `TraceId` option.
#[derive(Debug, Clone, Copy)]
pub struct TraceId {
    /// Raw 16-byte activity identifier.
    pub activity_id: [u8; 16],
    /// Sequence number that accompanies the activity identifier.
    pub activity_sequence: u32,
}
126
127impl PreLogin {
128 #[must_use]
130 pub fn new() -> Self {
131 Self {
132 version: TdsVersion::V7_4,
133 sub_build: 0,
134 encryption: EncryptionLevel::Required,
135 instance: None,
136 thread_id: None,
137 mars: false,
138 trace_id: None,
139 fed_auth_required: false,
140 nonce: None,
141 }
142 }
143
144 #[must_use]
146 pub fn with_version(mut self, version: TdsVersion) -> Self {
147 self.version = version;
148 self
149 }
150
151 #[must_use]
153 pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
154 self.encryption = level;
155 self
156 }
157
158 #[must_use]
160 pub fn with_mars(mut self, enabled: bool) -> Self {
161 self.mars = enabled;
162 self
163 }
164
165 #[must_use]
167 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
168 self.instance = Some(instance.into());
169 self
170 }
171
172 #[must_use]
174 pub fn encode(&self) -> Bytes {
175 let mut buf = BytesMut::with_capacity(256);
176
177 let mut option_count = 3; if self.instance.is_some() {
182 option_count += 1;
183 }
184 if self.thread_id.is_some() {
185 option_count += 1;
186 }
187 if self.trace_id.is_some() {
188 option_count += 1;
189 }
190 if self.fed_auth_required {
191 option_count += 1;
192 }
193 if self.nonce.is_some() {
194 option_count += 1;
195 }
196
197 let header_size = option_count * 5 + 1; let mut data_offset = header_size as u16;
199 let mut data_buf = BytesMut::new();
200
201 buf.put_u8(PreLoginOption::Version as u8);
203 buf.put_u16(data_offset);
204 buf.put_u16(6);
205 let version_raw = self.version.raw();
206 data_buf.put_u8((version_raw >> 24) as u8);
207 data_buf.put_u8((version_raw >> 16) as u8);
208 data_buf.put_u8((version_raw >> 8) as u8);
209 data_buf.put_u8(version_raw as u8);
210 data_buf.put_u16_le(self.sub_build);
211 data_offset += 6;
212
213 buf.put_u8(PreLoginOption::Encryption as u8);
215 buf.put_u16(data_offset);
216 buf.put_u16(1);
217 data_buf.put_u8(self.encryption as u8);
218 data_offset += 1;
219
220 if let Some(ref instance) = self.instance {
222 let instance_bytes = instance.as_bytes();
223 let len = instance_bytes.len() as u16 + 1; buf.put_u8(PreLoginOption::Instance as u8);
225 buf.put_u16(data_offset);
226 buf.put_u16(len);
227 data_buf.put_slice(instance_bytes);
228 data_buf.put_u8(0); data_offset += len;
230 }
231
232 if let Some(thread_id) = self.thread_id {
234 buf.put_u8(PreLoginOption::ThreadId as u8);
235 buf.put_u16(data_offset);
236 buf.put_u16(4);
237 data_buf.put_u32(thread_id);
238 data_offset += 4;
239 }
240
241 buf.put_u8(PreLoginOption::Mars as u8);
243 buf.put_u16(data_offset);
244 buf.put_u16(1);
245 data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
246 data_offset += 1;
247
248 if let Some(ref trace_id) = self.trace_id {
250 buf.put_u8(PreLoginOption::TraceId as u8);
251 buf.put_u16(data_offset);
252 buf.put_u16(36);
253 data_buf.put_slice(&trace_id.activity_id);
254 data_buf.put_u32_le(trace_id.activity_sequence);
255 data_buf.put_slice(&[0u8; 16]);
257 data_offset += 36;
258 }
259
260 if self.fed_auth_required {
262 buf.put_u8(PreLoginOption::FedAuthRequired as u8);
263 buf.put_u16(data_offset);
264 buf.put_u16(1);
265 data_buf.put_u8(0x01);
266 data_offset += 1;
267 }
268
269 if let Some(ref nonce) = self.nonce {
271 buf.put_u8(PreLoginOption::Nonce as u8);
272 buf.put_u16(data_offset);
273 buf.put_u16(32);
274 data_buf.put_slice(nonce);
275 let _ = data_offset; }
277
278 buf.put_u8(PreLoginOption::Terminator as u8);
280
281 buf.put_slice(&data_buf);
283
284 buf.freeze()
285 }
286
287 pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
296 let mut prelogin = Self::default();
297
298 let mut options = Vec::new();
300 loop {
301 if src.remaining() < 1 {
302 return Err(ProtocolError::UnexpectedEof);
303 }
304
305 let option_type = src.get_u8();
306 if option_type == PreLoginOption::Terminator as u8 {
307 break;
308 }
309
310 if src.remaining() < 4 {
311 return Err(ProtocolError::UnexpectedEof);
312 }
313
314 let offset = src.get_u16();
315 let length = src.get_u16();
316 options.push((PreLoginOption::from_u8(option_type)?, offset, length));
317 }
318
319 let data = src.copy_to_bytes(src.remaining());
321
322 let header_size = options.len() * 5 + 1;
324
325 for (option, packet_offset, length) in options {
326 let packet_offset = packet_offset as usize;
327 let length = length as usize;
328
329 if packet_offset < header_size {
332 continue;
334 }
335 let data_offset = packet_offset - header_size;
336
337 if data_offset + length > data.len() {
339 continue;
340 }
341
342 match option {
343 PreLoginOption::Version if length >= 6 => {
344 let version_bytes = &data[data_offset..data_offset + 4];
346 let version_raw = u32::from_be_bytes([
347 version_bytes[0],
348 version_bytes[1],
349 version_bytes[2],
350 version_bytes[3],
351 ]);
352 prelogin.version = TdsVersion::new(version_raw);
353
354 if length >= 6 {
355 let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
356 prelogin.sub_build =
357 u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]]);
358 }
359 }
360 PreLoginOption::Encryption if length >= 1 => {
361 prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
362 }
363 PreLoginOption::Mars if length >= 1 => {
364 prelogin.mars = data[data_offset] != 0;
365 }
366 PreLoginOption::Instance if length > 0 => {
367 let instance_data = &data[data_offset..data_offset + length];
369 if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
370 if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
371 if !s.is_empty() {
372 prelogin.instance = Some(s.to_string());
373 }
374 }
375 }
376 }
377 PreLoginOption::ThreadId if length >= 4 => {
378 let bytes = &data[data_offset..data_offset + 4];
379 prelogin.thread_id =
380 Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
381 }
382 PreLoginOption::FedAuthRequired if length >= 1 => {
383 prelogin.fed_auth_required = data[data_offset] != 0;
384 }
385 PreLoginOption::Nonce if length >= 32 => {
386 let mut nonce = [0u8; 32];
387 nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
388 prelogin.nonce = Some(nonce);
389 }
390 _ => {}
391 }
392 }
393
394 Ok(prelogin)
395 }
396}
397
#[cfg(test)]
// `unwrap` is acceptable in tests: a failure should abort the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The default message has exactly three options, so the layout is fixed:
    /// 15 bytes of entries, the terminator, then 6 + 1 + 1 payload bytes.
    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
        assert_eq!(encoded[5], PreLoginOption::Encryption as u8);
        assert_eq!(encoded[10], PreLoginOption::Mars as u8);
        assert_eq!(encoded[15], PreLoginOption::Terminator as u8);
        assert_eq!(encoded.len(), 16 + 6 + 1 + 1);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    /// Unknown bytes fall back to `Off` rather than failing.
    #[test]
    fn test_encryption_level_from_u8() {
        assert_eq!(EncryptionLevel::from_u8(0x03), EncryptionLevel::Required);
        assert_eq!(EncryptionLevel::from_u8(0x80), EncryptionLevel::ClientCertAuth);
        assert_eq!(EncryptionLevel::from_u8(0x42), EncryptionLevel::Off);
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    /// Optional options shift every later offset; make sure they survive a
    /// round trip together with the mandatory ones.
    #[test]
    fn test_prelogin_decode_roundtrip_optional_fields() {
        let mut original = PreLogin::new()
            .with_encryption(EncryptionLevel::Off)
            .with_instance("SQLEXPRESS")
            .with_mars(true);
        original.sub_build = 0x0102;
        original.thread_id = Some(0x1234_5678);
        original.fed_auth_required = true;
        original.nonce = Some([0xAB; 32]);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.sub_build, original.sub_build);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.instance, original.instance);
        assert_eq!(decoded.thread_id, original.thread_id);
        assert_eq!(decoded.mars, original.mars);
        assert_eq!(decoded.fed_auth_required, original.fed_auth_required);
        assert_eq!(decoded.nonce, original.nonce);
    }

    /// Options may arrive in any table order; payloads are located by offset.
    #[test]
    fn test_prelogin_decode_encryption_offset() {
        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();

        // Two entries of five bytes each plus the terminator.
        let header_size: u16 = 11;

        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size);
        buf.put_u16(1);

        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1);
        buf.put_u16(6);

        buf.put_u8(PreLoginOption::Terminator as u8);

        // Encryption payload.
        buf.put_u8(0x01);

        // Version payload followed by the sub-build.
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000);

        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }

    #[test]
    fn test_prelogin_decode_truncated_input() {
        // No terminator at all.
        assert!(matches!(
            PreLogin::decode(&b""[..]),
            Err(ProtocolError::UnexpectedEof)
        ));
        // Entry cut off after the token and one offset byte.
        assert!(matches!(
            PreLogin::decode(&[0x00u8, 0x00][..]),
            Err(ProtocolError::UnexpectedEof)
        ));
        // Complete entry but the table never terminates.
        assert!(matches!(
            PreLogin::decode(&[0x04u8, 0x00, 0x06, 0x00, 0x01][..]),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_prelogin_decode_unknown_option() {
        let packet = [0x42u8, 0x00, 0x06, 0x00, 0x01, 0xFF, 0x01];
        assert!(matches!(
            PreLogin::decode(&packet[..]),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    /// An entry whose offset lies outside the payload area is skipped and the
    /// field keeps its default.
    #[test]
    fn test_prelogin_decode_ignores_out_of_range_offset() {
        // Mars entry pointing far past the end of the message.
        let past_end = [0x04u8, 0x01, 0xF4, 0x00, 0x01, 0xFF];
        let decoded = PreLogin::decode(&past_end[..]).unwrap();
        assert!(!decoded.mars);

        // Mars entry pointing into the option table itself.
        let into_header = [0x04u8, 0x00, 0x00, 0x00, 0x01, 0xFF, 0x01];
        let decoded = PreLogin::decode(&into_header[..]).unwrap();
        assert!(!decoded.mars);
    }
}