1use bytes::{BufMut, Bytes, BytesMut};
19
20use crate::codec::write_utf16_string;
21use crate::prelude::*;
22use crate::version::TdsVersion;
23
/// Size in bytes of the fixed-length LOGIN7 header that precedes the
/// variable-length data section.
///
/// Breakdown: six `u32` fields (24) + four option/type flag bytes (4) +
/// client time zone and LCID (8) + nine offset/length pairs (36) +
/// client ID (6) + three more offset/length pairs (12) + `cbSSPILong` (4).
pub const LOGIN7_HEADER_SIZE: usize = 6 * 4 + 4 + 2 * 4 + 9 * 4 + 6 + 3 * 4 + 4;
26
/// First option-flags byte (`OptionFlags1`) of the LOGIN7 header.
///
/// Each `true` field sets one bit in the value returned by
/// [`OptionFlags1::to_byte`].
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags1 {
    /// Bit `0x01` (`fByteOrder`).
    pub byte_order_be: bool,
    /// Bit `0x02` (`fChar`).
    pub char_ebcdic: bool,
    /// Floating-point representation (`fFloat`, bits `0x04`/`0x08`).
    ///
    /// Not encoded by [`OptionFlags1::to_byte`]: the `fFloat` bits are always
    /// emitted as zero, which the protocol defines as IEEE 754.
    pub float_ieee: bool,
    /// Bit `0x10` (`fDumpLoad`).
    pub dump_load_off: bool,
    /// Bit `0x20` (`fUseDB`).
    pub use_db_notify: bool,
    /// Bit `0x40` (`fDatabase`).
    pub database_fatal: bool,
    /// Bit `0x80` (`fSetLang`).
    pub set_lang_warn: bool,
}

impl OptionFlags1 {
    /// Pack the flags into the single `OptionFlags1` header byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        // (flag, bit mask) pairs; `float_ieee` is intentionally absent because
        // IEEE 754 is represented by zero bits.
        let masks: [(bool, u8); 6] = [
            (self.byte_order_be, 0x01),
            (self.char_ebcdic, 0x02),
            (self.dump_load_off, 0x10),
            (self.use_db_notify, 0x20),
            (self.database_fatal, 0x40),
            (self.set_lang_warn, 0x80),
        ];
        masks
            .iter()
            .filter(|&&(enabled, _)| enabled)
            .fold(0u8, |acc, &(_, mask)| acc | mask)
    }
}
83
/// Second option-flags byte (`OptionFlags2`) of the LOGIN7 header.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags2 {
    /// Bit `0x01` (`fLanguage`).
    pub language_fatal: bool,
    /// Bit `0x02` (`fODBC`).
    pub odbc: bool,
    /// Bit `0x04` (`fTranBoundary`).
    pub tran_boundary: bool,
    /// Bit `0x08` (`fCacheConnect`).
    pub cache_connect: bool,
    /// User type (`fUserType`); only the low three bits are used and they are
    /// placed in bits 4-6 of the byte.
    pub user_type: u8,
    /// Bit `0x80` (`fIntSecurity`): integrated (SSPI) authentication.
    pub integrated_security: bool,
}

impl OptionFlags2 {
    /// Pack the flags into the single `OptionFlags2` header byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let bit = |enabled: bool, mask: u8| if enabled { mask } else { 0 };
        bit(self.language_fatal, 0x01)
            | bit(self.odbc, 0x02)
            | bit(self.tran_boundary, 0x04)
            | bit(self.cache_connect, 0x08)
            | ((self.user_type & 0x07) << 4)
            | bit(self.integrated_security, 0x80)
    }
}
125
/// Type-flags byte (`TypeFlags`) of the LOGIN7 header.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeFlags {
    /// SQL type (`fSQLType`); only the low nibble is encoded.
    pub sql_type: u8,
    /// Bit `0x10` (`fOLEDB`).
    pub oledb: bool,
    /// Bit `0x20` (`fReadOnlyIntent`): read-only application intent.
    pub read_only_intent: bool,
}

impl TypeFlags {
    /// Pack the flags into the single `TypeFlags` header byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let sql_type_bits = self.sql_type & 0x0F;
        let oledb_bit = if self.oledb { 0x10 } else { 0x00 };
        let read_only_bit = if self.read_only_intent { 0x20 } else { 0x00 };
        sql_type_bits | oledb_bit | read_only_bit
    }
}
152
/// Third option-flags byte (`OptionFlags3`) of the LOGIN7 header.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags3 {
    /// Bit `0x01` (`fChangePassword`).
    pub change_password: bool,
    /// Bit `0x02` (`fUserInstance`).
    pub user_instance: bool,
    /// Bit `0x04` (`fSendYukonBinaryXML`).
    pub send_yukon_binary_xml: bool,
    /// Bit `0x08` (`fUnknownCollationHandling`).
    pub unknown_collation_handling: bool,
    /// Bit `0x10` (`fExtension`): a `FeatureExt` block is present.
    pub extension: bool,
}

impl OptionFlags3 {
    /// Pack the flags into the single `OptionFlags3` header byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        u8::from(self.change_password)
            | (u8::from(self.user_instance) << 1)
            | (u8::from(self.send_yukon_binary_xml) << 2)
            | (u8::from(self.unknown_collation_handling) << 3)
            | (u8::from(self.extension) << 4)
    }
}
191
/// Feature identifiers used in the LOGIN7 `FeatureExt` block.
///
/// The discriminant is the `FeatureId` byte written on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FeatureId {
    /// Session recovery.
    SessionRecovery = 0x01,
    /// Federated authentication.
    FedAuth = 0x02,
    /// Column encryption.
    ColumnEncryption = 0x04,
    /// Global transactions.
    GlobalTransactions = 0x05,
    /// Azure SQL support.
    AzureSqlSupport = 0x08,
    /// Data classification.
    DataClassification = 0x09,
    /// UTF-8 support.
    Utf8Support = 0x0A,
    /// Azure SQL DNS caching.
    AzureSqlDnsCaching = 0x0B,
    /// Marks the end of the feature list.
    Terminator = 0xFF,
}
216
217#[derive(Debug, Clone)]
219pub struct Login7 {
220 pub tds_version: TdsVersion,
222 pub packet_size: u32,
224 pub client_prog_version: u32,
226 pub client_pid: u32,
228 pub connection_id: u32,
230 pub option_flags1: OptionFlags1,
232 pub option_flags2: OptionFlags2,
234 pub type_flags: TypeFlags,
236 pub option_flags3: OptionFlags3,
238 pub client_timezone: i32,
240 pub client_lcid: u32,
242 pub hostname: String,
244 pub username: String,
246 pub password: String,
248 pub app_name: String,
250 pub server_name: String,
252 pub unused: String,
254 pub library_name: String,
256 pub language: String,
258 pub database: String,
260 pub client_id: [u8; 6],
262 pub sspi_data: Vec<u8>,
264 pub attach_db_file: String,
266 pub new_password: String,
268 pub features: Vec<FeatureExtension>,
270}
271
272#[derive(Debug, Clone)]
274pub struct FeatureExtension {
275 pub feature_id: FeatureId,
277 pub data: Bytes,
279}
280
281impl Default for Login7 {
282 fn default() -> Self {
283 #[cfg(feature = "std")]
284 let client_pid = std::process::id();
285 #[cfg(not(feature = "std"))]
286 let client_pid = 0;
287
288 Self {
289 tds_version: TdsVersion::V7_4,
290 packet_size: 4096,
291 client_prog_version: 0,
292 client_pid,
293 connection_id: 0,
294 option_flags1: OptionFlags1 {
296 use_db_notify: true,
297 database_fatal: true,
298 ..Default::default()
299 },
300 option_flags2: OptionFlags2 {
301 language_fatal: true,
302 odbc: true,
303 ..Default::default()
304 },
305 type_flags: TypeFlags::default(), option_flags3: OptionFlags3 {
307 unknown_collation_handling: true,
308 ..Default::default()
309 },
310 client_timezone: 0,
311 client_lcid: 0x0409, hostname: String::new(),
313 username: String::new(),
314 password: String::new(),
315 app_name: String::from("rust-mssql-driver"),
316 server_name: String::new(),
317 unused: String::new(),
318 library_name: String::from("rust-mssql-driver"),
319 language: String::new(),
320 database: String::new(),
321 client_id: [0u8; 6],
322 sspi_data: Vec::new(),
323 attach_db_file: String::new(),
324 new_password: String::new(),
325 features: Vec::new(),
326 }
327 }
328}
329
330impl Login7 {
331 #[must_use]
333 pub fn new() -> Self {
334 Self::default()
335 }
336
337 #[must_use]
339 pub fn with_tds_version(mut self, version: TdsVersion) -> Self {
340 self.tds_version = version;
341 self
342 }
343
344 #[must_use]
346 pub fn with_sql_auth(
347 mut self,
348 username: impl Into<String>,
349 password: impl Into<String>,
350 ) -> Self {
351 self.username = username.into();
352 self.password = password.into();
353 self.option_flags2.integrated_security = false;
354 self
355 }
356
357 #[must_use]
359 pub fn with_integrated_auth(mut self, sspi_data: Vec<u8>) -> Self {
360 self.sspi_data = sspi_data;
361 self.option_flags2.integrated_security = true;
362 self
363 }
364
365 #[must_use]
367 pub fn with_database(mut self, database: impl Into<String>) -> Self {
368 self.database = database.into();
369 self
370 }
371
372 #[must_use]
374 pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self {
375 self.hostname = hostname.into();
376 self
377 }
378
379 #[must_use]
381 pub fn with_app_name(mut self, app_name: impl Into<String>) -> Self {
382 self.app_name = app_name.into();
383 self
384 }
385
386 #[must_use]
388 pub fn with_server_name(mut self, server_name: impl Into<String>) -> Self {
389 self.server_name = server_name.into();
390 self
391 }
392
393 #[must_use]
395 pub fn with_packet_size(mut self, packet_size: u32) -> Self {
396 self.packet_size = packet_size;
397 self
398 }
399
400 #[must_use]
402 pub fn with_read_only_intent(mut self, read_only: bool) -> Self {
403 self.type_flags.read_only_intent = read_only;
404 self
405 }
406
407 #[must_use]
409 pub fn with_feature(mut self, feature: FeatureExtension) -> Self {
410 self.option_flags3.extension = true;
411 self.features.push(feature);
412 self
413 }
414
415 #[must_use]
417 pub fn encode(&self) -> Bytes {
418 let mut buf = BytesMut::with_capacity(512);
419
420 let mut offset = LOGIN7_HEADER_SIZE as u16;
423
424 let hostname_len = self.hostname.encode_utf16().count() as u16;
426 let username_len = self.username.encode_utf16().count() as u16;
427 let password_len = self.password.encode_utf16().count() as u16;
428 let app_name_len = self.app_name.encode_utf16().count() as u16;
429 let server_name_len = self.server_name.encode_utf16().count() as u16;
430 let unused_len = self.unused.encode_utf16().count() as u16;
431 let library_name_len = self.library_name.encode_utf16().count() as u16;
432 let language_len = self.language.encode_utf16().count() as u16;
433 let database_len = self.database.encode_utf16().count() as u16;
434 let sspi_len = self.sspi_data.len() as u16;
435 let attach_db_len = self.attach_db_file.encode_utf16().count() as u16;
436 let new_password_len = self.new_password.encode_utf16().count() as u16;
437
438 let mut var_data = BytesMut::new();
440
441 let hostname_offset = offset;
443 write_utf16_string(&mut var_data, &self.hostname);
444 offset += hostname_len * 2;
445
446 let username_offset = offset;
448 write_utf16_string(&mut var_data, &self.username);
449 offset += username_len * 2;
450
451 let password_offset = offset;
453 Self::write_obfuscated_password(&mut var_data, &self.password);
454 offset += password_len * 2;
455
456 let app_name_offset = offset;
458 write_utf16_string(&mut var_data, &self.app_name);
459 offset += app_name_len * 2;
460
461 let server_name_offset = offset;
463 write_utf16_string(&mut var_data, &self.server_name);
464 offset += server_name_len * 2;
465
466 let extension_offset = if self.option_flags3.extension {
468 let base = offset
470 + unused_len * 2
471 + library_name_len * 2
472 + language_len * 2
473 + database_len * 2
474 + sspi_len
475 + attach_db_len * 2
476 + new_password_len * 2;
477 var_data.put_u32_le(base as u32);
479 offset += 4;
480 base
481 } else {
482 let unused_offset = offset;
483 write_utf16_string(&mut var_data, &self.unused);
484 offset += unused_len * 2;
485 unused_offset
486 };
487
488 let library_name_offset = offset;
490 write_utf16_string(&mut var_data, &self.library_name);
491 offset += library_name_len * 2;
492
493 let language_offset = offset;
495 write_utf16_string(&mut var_data, &self.language);
496 offset += language_len * 2;
497
498 let database_offset = offset;
500 write_utf16_string(&mut var_data, &self.database);
501 offset += database_len * 2;
502
503 let sspi_offset = offset;
508 var_data.put_slice(&self.sspi_data);
509 offset += sspi_len;
510
511 let attach_db_offset = offset;
513 write_utf16_string(&mut var_data, &self.attach_db_file);
514 offset += attach_db_len * 2;
515
516 let new_password_offset = offset;
518 if !self.new_password.is_empty() {
519 Self::write_obfuscated_password(&mut var_data, &self.new_password);
520 }
521 #[allow(unused_assignments)]
522 {
523 offset += new_password_len * 2;
524 }
525
526 if self.option_flags3.extension {
528 for feature in &self.features {
529 var_data.put_u8(feature.feature_id as u8);
530 var_data.put_u32_le(feature.data.len() as u32);
531 var_data.put_slice(&feature.data);
532 }
533 var_data.put_u8(FeatureId::Terminator as u8);
534 }
535
536 let total_length = LOGIN7_HEADER_SIZE + var_data.len();
538
539 buf.put_u32_le(total_length as u32); buf.put_u32_le(self.tds_version.raw()); buf.put_u32_le(self.packet_size); buf.put_u32_le(self.client_prog_version); buf.put_u32_le(self.client_pid); buf.put_u32_le(self.connection_id); buf.put_u8(self.option_flags1.to_byte());
549 buf.put_u8(self.option_flags2.to_byte());
550 buf.put_u8(self.type_flags.to_byte());
551 buf.put_u8(self.option_flags3.to_byte());
552
553 buf.put_i32_le(self.client_timezone); buf.put_u32_le(self.client_lcid); buf.put_u16_le(hostname_offset);
558 buf.put_u16_le(hostname_len);
559 buf.put_u16_le(username_offset);
560 buf.put_u16_le(username_len);
561 buf.put_u16_le(password_offset);
562 buf.put_u16_le(password_len);
563 buf.put_u16_le(app_name_offset);
564 buf.put_u16_le(app_name_len);
565 buf.put_u16_le(server_name_offset);
566 buf.put_u16_le(server_name_len);
567
568 if self.option_flags3.extension {
570 buf.put_u16_le(extension_offset as u16);
571 buf.put_u16_le(4); } else {
573 buf.put_u16_le(extension_offset as u16);
574 buf.put_u16_le(unused_len);
575 }
576
577 buf.put_u16_le(library_name_offset);
578 buf.put_u16_le(library_name_len);
579 buf.put_u16_le(language_offset);
580 buf.put_u16_le(language_len);
581 buf.put_u16_le(database_offset);
582 buf.put_u16_le(database_len);
583
584 buf.put_slice(&self.client_id);
586
587 buf.put_u16_le(sspi_offset);
588 buf.put_u16_le(sspi_len);
589 buf.put_u16_le(attach_db_offset);
590 buf.put_u16_le(attach_db_len);
591 buf.put_u16_le(new_password_offset);
592 buf.put_u16_le(new_password_len);
593
594 buf.put_u32_le(0);
596
597 buf.put_slice(&var_data);
599
600 buf.freeze()
601 }
602
603 fn write_obfuscated_password(dst: &mut impl BufMut, password: &str) {
608 for c in password.encode_utf16() {
609 let low = (c & 0xFF) as u8;
610 let high = ((c >> 8) & 0xFF) as u8;
611
612 let low_enc = low.rotate_right(4) ^ 0xA5;
615 let high_enc = high.rotate_right(4) ^ 0xA5;
616
617 dst.put_u8(low_enc);
618 dst.put_u8(high_enc);
619 }
620 }
621}
622
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Read a little-endian `u16` starting at byte `at`.
    fn u16_at(buf: &[u8], at: usize) -> u16 {
        u16::from_le_bytes([buf[at], buf[at + 1]])
    }

    /// Read a little-endian `u32` starting at byte `at`.
    fn u32_at(buf: &[u8], at: usize) -> u32 {
        u32::from_le_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]])
    }

    #[test]
    fn test_login7_default() {
        let login = Login7::new();
        assert_eq!(login.tds_version, TdsVersion::V7_4);
        assert_eq!(login.packet_size, 4096);
        assert!(login.option_flags2.odbc);
    }

    #[test]
    fn test_login7_encode() {
        let login = Login7::new()
            .with_hostname("TESTHOST")
            .with_sql_auth("testuser", "testpass")
            .with_database("testdb")
            .with_app_name("TestApp");

        let encoded = login.encode();

        assert!(encoded.len() >= LOGIN7_HEADER_SIZE);

        // The Length field covers the whole message.
        assert_eq!(u32_at(&encoded, 0) as usize, encoded.len());

        assert_eq!(u32_at(&encoded, 4), TdsVersion::V7_4.raw());

        // OptionFlags1, OptionFlags2, TypeFlags, OptionFlags3 from the defaults.
        assert_eq!(&encoded[24..28], &[0x60u8, 0x03, 0x00, 0x08][..]);

        // HostName is the first variable field, right after the fixed header.
        assert_eq!(u16_at(&encoded, 36) as usize, LOGIN7_HEADER_SIZE);
        assert_eq!(u16_at(&encoded, 38), 8);
    }

    #[test]
    fn test_login7_encode_with_feature() {
        let login = Login7::new().with_feature(FeatureExtension {
            feature_id: FeatureId::Utf8Support,
            data: Bytes::from_static(&[0x01]),
        });

        let encoded = login.encode();

        assert_eq!(u32_at(&encoded, 0) as usize, encoded.len());
        // Extension bit set on top of unknown-collation handling.
        assert_eq!(encoded[27], 0x18);
        // cbExtension is the size of the ibFeatureExtLong DWORD.
        assert_eq!(u16_at(&encoded, 58), 4);
        // FeatureExt tail: id, u32 length, payload, terminator.
        let tail = &encoded[encoded.len() - 7..];
        assert_eq!(tail, &[0x0Au8, 0x01, 0x00, 0x00, 0x00, 0x01, 0xFF][..]);
    }

    #[test]
    fn test_password_obfuscation() {
        let mut buf = BytesMut::new();
        Login7::write_obfuscated_password(&mut buf, "a");

        // 'a' = 0x0061: 0x61 -> nibble swap 0x16 -> ^0xA5 = 0xB3; 0x00 -> 0xA5.
        assert_eq!(buf.len(), 2);
        assert_eq!(buf[0], 0xB3);
        assert_eq!(buf[1], 0xA5);

        let mut buf = BytesMut::new();
        Login7::write_obfuscated_password(&mut buf, "\u{e9}");
        // 0xE9 -> nibble swap 0x9E -> ^0xA5 = 0x3B; high byte 0x00 -> 0xA5.
        assert_eq!(&buf[..], &[0x3Bu8, 0xA5][..]);
    }

    #[test]
    fn test_option_flags() {
        let flags1 = OptionFlags1::default();
        assert_eq!(flags1.to_byte(), 0x00);

        let flags2 = OptionFlags2 {
            odbc: true,
            integrated_security: true,
            ..Default::default()
        };
        assert_eq!(flags2.to_byte(), 0x82);

        let flags3 = OptionFlags3 {
            extension: true,
            ..Default::default()
        };
        assert_eq!(flags3.to_byte(), 0x10);
    }
}