use bytes::{BufMut, Bytes, BytesMut};
use crate::codec::write_utf16_string;
use crate::prelude::*;
use crate::version::TdsVersion;
/// Size in bytes of the fixed-length LOGIN7 header.
///
/// [`Login7::encode`] writes exactly this many bytes before the variable-length
/// data, and every offset it emits is counted from the start of the header, so
/// the first variable-length field begins at offset 94.
pub const LOGIN7_HEADER_SIZE: usize = 94;
/// First option-flags byte of the LOGIN7 header.
///
/// Each boolean maps to one bit of the byte produced by [`OptionFlags1::to_byte`].
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags1 {
    /// Bit 0x01.
    pub byte_order_be: bool,
    /// Bit 0x02.
    pub char_ebcdic: bool,
    /// Not encoded: `to_byte` never reads this field, so bits 0x04/0x08 stay 0
    /// (presumably the MS-TDS "IEEE 754" floating-point value — confirm).
    pub float_ieee: bool,
    /// Bit 0x10.
    pub dump_load_off: bool,
    /// Bit 0x20.
    pub use_db_notify: bool,
    /// Bit 0x40.
    pub database_fatal: bool,
    /// Bit 0x80.
    pub set_lang_warn: bool,
}

impl OptionFlags1 {
    /// Packs the flags into the single wire byte.
    ///
    /// Bits 2 and 3 are always zero because `float_ieee` is not consulted.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        u8::from(self.byte_order_be)
            | (u8::from(self.char_ebcdic) << 1)
            | (u8::from(self.dump_load_off) << 4)
            | (u8::from(self.use_db_notify) << 5)
            | (u8::from(self.database_fatal) << 6)
            | (u8::from(self.set_lang_warn) << 7)
    }
}
/// Second option-flags byte of the LOGIN7 header.
///
/// Four single-bit flags occupy the low nibble, a 3-bit `user_type` sits in
/// bits 4-6, and `integrated_security` is the top bit.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags2 {
    /// Bit 0x01.
    pub language_fatal: bool,
    /// Bit 0x02.
    pub odbc: bool,
    /// Bit 0x04.
    pub tran_boundary: bool,
    /// Bit 0x08.
    pub cache_connect: bool,
    /// Only the low three bits are used; they are shifted into bits 4-6.
    pub user_type: u8,
    /// Bit 0x80. Set by `Login7::with_integrated_auth`, cleared by
    /// `Login7::with_sql_auth`.
    pub integrated_security: bool,
}

impl OptionFlags2 {
    /// Packs the flags into the single wire byte.
    ///
    /// Bits of `user_type` above the low three are silently discarded.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let single_bits = [
            (self.language_fatal, 0x01_u8),
            (self.odbc, 0x02),
            (self.tran_boundary, 0x04),
            (self.cache_connect, 0x08),
            (self.integrated_security, 0x80),
        ];
        let user_type_field = (self.user_type & 0x07) << 4;
        single_bits
            .iter()
            .filter(|&&(enabled, _)| enabled)
            .fold(user_type_field, |packed, &(_, mask)| packed | mask)
    }
}
/// Type-flags byte of the LOGIN7 header.
///
/// A 4-bit `sql_type` occupies the low nibble; two boolean flags follow.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeFlags {
    /// Only the low four bits are encoded.
    pub sql_type: u8,
    /// Bit 0x10.
    pub oledb: bool,
    /// Bit 0x20. Set by `Login7::with_read_only_intent`.
    pub read_only_intent: bool,
}

impl TypeFlags {
    /// Packs the fields into the single wire byte.
    ///
    /// Bits 6 and 7 are always zero; high bits of `sql_type` are discarded.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let sql_type_bits = self.sql_type & 0x0F;
        let oledb_bit = if self.oledb { 0x10 } else { 0x00 };
        let read_only_bit = if self.read_only_intent { 0x20 } else { 0x00 };
        sql_type_bits | oledb_bit | read_only_bit
    }
}
/// Third option-flags byte of the LOGIN7 header.
///
/// Fields map, in declaration order, to bits 0 through 4; bits 5-7 stay zero.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags3 {
    /// Bit 0x01.
    pub change_password: bool,
    /// Bit 0x02.
    pub user_instance: bool,
    /// Bit 0x04.
    pub send_yukon_binary_xml: bool,
    /// Bit 0x08.
    pub unknown_collation_handling: bool,
    /// Bit 0x10. When set, `Login7::encode` emits the FeatureExt block.
    pub extension: bool,
}

impl OptionFlags3 {
    /// Packs the flags into the single wire byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        // Position in this array == bit index in the output byte.
        let ordered = [
            self.change_password,
            self.user_instance,
            self.send_yukon_binary_xml,
            self.unknown_collation_handling,
            self.extension,
        ];
        ordered
            .iter()
            .enumerate()
            .filter(|&(_, &enabled)| enabled)
            .fold(0u8, |packed, (bit, _)| packed | (1u8 << bit))
    }
}
/// Identifier byte for an entry in the LOGIN7 FeatureExt block.
///
/// `Login7::encode` writes `feature_id as u8` for every [`FeatureExtension`]
/// and closes the list with [`FeatureId::Terminator`]. The discriminants
/// presumably mirror the MS-TDS FeatureId table — confirm against the spec
/// before adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FeatureId {
/// Session recovery feature (0x01).
SessionRecovery = 0x01,
/// Federated authentication feature (0x02).
FedAuth = 0x02,
/// Column encryption feature (0x04).
ColumnEncryption = 0x04,
/// Global transactions feature (0x05).
GlobalTransactions = 0x05,
/// Azure SQL support feature (0x08).
AzureSqlSupport = 0x08,
/// Data classification feature (0x09).
DataClassification = 0x09,
/// UTF-8 support feature (0x0A).
Utf8Support = 0x0A,
/// Azure SQL DNS caching feature (0x0B).
AzureSqlDnsCaching = 0x0B,
/// End-of-list marker written after the last feature entry (0xFF).
Terminator = 0xFF,
}
/// A LOGIN7 request: the first message a client sends to authenticate.
///
/// Build one with [`Login7::new`] plus the `with_*` builders, then serialize it
/// with [`Login7::encode`]. String fields are written as UTF-16 and their
/// header lengths are UTF-16 code-unit counts.
#[derive(Debug, Clone)]
pub struct Login7 {
/// Protocol version; written as the header's second `u32` via `TdsVersion::raw`.
pub tds_version: TdsVersion,
/// Requested packet size (written as `u32` LE; `Default` uses 4096).
pub packet_size: u32,
/// Client program version (written as `u32` LE).
pub client_prog_version: u32,
/// Client process id; `Default` uses `std::process::id()` when the `std`
/// feature is enabled and 0 otherwise.
pub client_pid: u32,
/// Connection id (written as `u32` LE).
pub connection_id: u32,
/// First option-flags byte, packed with `OptionFlags1::to_byte`.
pub option_flags1: OptionFlags1,
/// Second option-flags byte, packed with `OptionFlags2::to_byte`.
pub option_flags2: OptionFlags2,
/// Type-flags byte, packed with `TypeFlags::to_byte`.
pub type_flags: TypeFlags,
/// Third option-flags byte; its `extension` bit controls FeatureExt emission.
pub option_flags3: OptionFlags3,
/// Client time zone (written as `i32` LE; presumably minutes from UTC — confirm).
pub client_timezone: i32,
/// Client locale id (written as `u32` LE; `Default` uses 0x0409).
pub client_lcid: u32,
/// Client host name.
pub hostname: String,
/// Login name for SQL authentication.
pub username: String,
/// Password; obfuscated by `write_obfuscated_password` before being written.
pub password: String,
/// Application name.
pub app_name: String,
/// Server name.
pub server_name: String,
/// Written into the ibUnused slot only when `option_flags3.extension` is false;
/// ignored otherwise because that slot then carries the extension offset.
pub unused: String,
/// Client library name.
pub library_name: String,
/// Initial language.
pub language: String,
/// Initial database.
pub database: String,
/// Six raw bytes copied verbatim into the header (presumably the client MAC
/// address — confirm).
pub client_id: [u8; 6],
/// Raw SSPI bytes copied verbatim into the variable data; length is written
/// as a `u16`.
pub sspi_data: Vec<u8>,
/// Database file to attach.
pub attach_db_file: String,
/// New password; written (obfuscated like `password`) only when non-empty.
pub new_password: String,
/// Feature extensions; serialized only when `option_flags3.extension` is set
/// (see `Login7::with_feature`).
pub features: Vec<FeatureExtension>,
}
/// One entry of the LOGIN7 FeatureExt block.
///
/// Encoded by `Login7::encode` as the `feature_id` byte, a `u32` little-endian
/// length of `data`, and then `data` itself.
#[derive(Debug, Clone)]
pub struct FeatureExtension {
/// Which feature this entry describes.
pub feature_id: FeatureId,
/// Feature-specific payload, copied verbatim.
pub data: Bytes,
}
impl Default for Login7 {
fn default() -> Self {
#[cfg(feature = "std")]
let client_pid = std::process::id();
#[cfg(not(feature = "std"))]
let client_pid = 0;
Self {
tds_version: TdsVersion::V7_4,
packet_size: 4096,
client_prog_version: 0,
client_pid,
connection_id: 0,
option_flags1: OptionFlags1 {
use_db_notify: true,
database_fatal: true,
..Default::default()
},
option_flags2: OptionFlags2 {
language_fatal: true,
odbc: true,
..Default::default()
},
type_flags: TypeFlags::default(), option_flags3: OptionFlags3 {
unknown_collation_handling: true,
..Default::default()
},
client_timezone: 0,
client_lcid: 0x0409, hostname: String::new(),
username: String::new(),
password: String::new(),
app_name: String::from("rust-mssql-driver"),
server_name: String::new(),
unused: String::new(),
library_name: String::from("rust-mssql-driver"),
language: String::new(),
database: String::new(),
client_id: [0u8; 6],
sspi_data: Vec::new(),
attach_db_file: String::new(),
new_password: String::new(),
features: Vec::new(),
}
}
}
impl Login7 {
    /// Creates a LOGIN7 request populated with the [`Default`] settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the TDS protocol version advertised in the header.
    #[must_use]
    pub fn with_tds_version(mut self, version: TdsVersion) -> Self {
        self.tds_version = version;
        self
    }

    /// Configures SQL authentication with `username` and `password`.
    ///
    /// Clears the integrated-security bit in `option_flags2`.
    #[must_use]
    pub fn with_sql_auth(
        mut self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        self.username = username.into();
        self.password = password.into();
        self.option_flags2.integrated_security = false;
        self
    }

    /// Configures integrated (SSPI) authentication with the given token bytes.
    ///
    /// Sets the integrated-security bit in `option_flags2`; `username` and
    /// `password` are left as they are.
    #[must_use]
    pub fn with_integrated_auth(mut self, sspi_data: Vec<u8>) -> Self {
        self.sspi_data = sspi_data;
        self.option_flags2.integrated_security = true;
        self
    }

    /// Sets the initial database.
    #[must_use]
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = database.into();
        self
    }

    /// Sets the client host name.
    #[must_use]
    pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self {
        self.hostname = hostname.into();
        self
    }

    /// Sets the application name.
    #[must_use]
    pub fn with_app_name(mut self, app_name: impl Into<String>) -> Self {
        self.app_name = app_name.into();
        self
    }

    /// Sets the server name.
    #[must_use]
    pub fn with_server_name(mut self, server_name: impl Into<String>) -> Self {
        self.server_name = server_name.into();
        self
    }

    /// Sets the initial language.
    #[must_use]
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = language.into();
        self
    }

    /// Sets the requested packet size.
    #[must_use]
    pub fn with_packet_size(mut self, packet_size: u32) -> Self {
        self.packet_size = packet_size;
        self
    }

    /// Sets or clears the read-only-intent bit in `type_flags`.
    #[must_use]
    pub fn with_read_only_intent(mut self, read_only: bool) -> Self {
        self.type_flags.read_only_intent = read_only;
        self
    }

    /// Appends a feature extension and sets `option_flags3.extension` so that
    /// [`Login7::encode`] emits the FeatureExt block.
    #[must_use]
    pub fn with_feature(mut self, feature: FeatureExtension) -> Self {
        self.option_flags3.extension = true;
        self.features.push(feature);
        self
    }

    /// Serializes this request into a LOGIN7 message body.
    ///
    /// Layout: the fixed [`LOGIN7_HEADER_SIZE`]-byte header (total length,
    /// version, sizes, flag bytes and offset/length pairs) followed by the
    /// variable-length data those offsets point into. Offsets are counted from
    /// the start of the header; string lengths are UTF-16 code-unit counts.
    ///
    /// When `option_flags3.extension` is set, the ibUnused/cbUnused slot is
    /// reused as ibExtension/cbExtension: it points at a 4-byte little-endian
    /// value (hence a length of 4) holding the offset of the FeatureExt block.
    /// That block is appended after all other variable data and ends with
    /// [`FeatureId::Terminator`]. The `unused` string is not written then.
    #[must_use]
    pub fn encode(&self) -> Bytes {
        /// Number of UTF-16 code units in `s`, as stored in the header.
        fn utf16_len(s: &str) -> u16 {
            // NOTE(review): `as u16` silently truncates beyond u16::MAX code
            // units, which would produce a malformed packet.
            s.encode_utf16().count() as u16
        }

        let hostname_len = utf16_len(&self.hostname);
        let username_len = utf16_len(&self.username);
        let password_len = utf16_len(&self.password);
        let app_name_len = utf16_len(&self.app_name);
        let server_name_len = utf16_len(&self.server_name);
        let unused_len = utf16_len(&self.unused);
        let library_name_len = utf16_len(&self.library_name);
        let language_len = utf16_len(&self.language);
        let database_len = utf16_len(&self.database);
        // NOTE(review): SSPI blobs of 64 KiB or more are truncated here;
        // cbSSPILong (written as 0 below) is never used.
        let sspi_len = self.sspi_data.len() as u16;
        let attach_db_len = utf16_len(&self.attach_db_file);
        let new_password_len = utf16_len(&self.new_password);
        let use_extension = self.option_flags3.extension;

        // Variable-length data, plus the running header-relative offset of
        // the next byte to be appended.
        let mut var_data = BytesMut::new();
        let mut offset = LOGIN7_HEADER_SIZE as u16;

        let hostname_offset = offset;
        write_utf16_string(&mut var_data, &self.hostname);
        offset += hostname_len * 2;

        let username_offset = offset;
        write_utf16_string(&mut var_data, &self.username);
        offset += username_len * 2;

        let password_offset = offset;
        Self::write_obfuscated_password(&mut var_data, &self.password);
        offset += password_len * 2;

        let app_name_offset = offset;
        write_utf16_string(&mut var_data, &self.app_name);
        offset += app_name_len * 2;

        let server_name_offset = offset;
        write_utf16_string(&mut var_data, &self.server_name);
        offset += server_name_len * 2;

        // ibUnused/cbUnused, or ibExtension/cbExtension when fExtension is set.
        let (unused_or_ext_offset, unused_or_ext_len) = if use_extension {
            // The header must point at the 4-byte slot itself...
            let slot_offset = offset;
            offset += 4;
            // ...and the slot holds where FeatureExt will start: after every
            // remaining field (`unused` is not written on this path).
            let feature_ext_offset = offset
                + library_name_len * 2
                + language_len * 2
                + database_len * 2
                + sspi_len
                + attach_db_len * 2
                + new_password_len * 2;
            var_data.put_u32_le(u32::from(feature_ext_offset));
            (slot_offset, 4)
        } else {
            let unused_offset = offset;
            write_utf16_string(&mut var_data, &self.unused);
            offset += unused_len * 2;
            (unused_offset, unused_len)
        };

        let library_name_offset = offset;
        write_utf16_string(&mut var_data, &self.library_name);
        offset += library_name_len * 2;

        let language_offset = offset;
        write_utf16_string(&mut var_data, &self.language);
        offset += language_len * 2;

        let database_offset = offset;
        write_utf16_string(&mut var_data, &self.database);
        offset += database_len * 2;

        let sspi_offset = offset;
        var_data.put_slice(&self.sspi_data);
        offset += sspi_len;

        let attach_db_offset = offset;
        write_utf16_string(&mut var_data, &self.attach_db_file);
        offset += attach_db_len * 2;

        // An empty new password writes nothing, so no emptiness check is needed.
        let new_password_offset = offset;
        Self::write_obfuscated_password(&mut var_data, &self.new_password);

        if use_extension {
            for feature in &self.features {
                var_data.put_u8(feature.feature_id as u8);
                var_data.put_u32_le(feature.data.len() as u32);
                var_data.put_slice(&feature.data);
            }
            var_data.put_u8(FeatureId::Terminator as u8);
        }

        let total_length = LOGIN7_HEADER_SIZE + var_data.len();
        let mut buf = BytesMut::with_capacity(total_length);

        buf.put_u32_le(total_length as u32);
        buf.put_u32_le(self.tds_version.raw());
        buf.put_u32_le(self.packet_size);
        buf.put_u32_le(self.client_prog_version);
        buf.put_u32_le(self.client_pid);
        buf.put_u32_le(self.connection_id);
        buf.put_u8(self.option_flags1.to_byte());
        buf.put_u8(self.option_flags2.to_byte());
        buf.put_u8(self.type_flags.to_byte());
        buf.put_u8(self.option_flags3.to_byte());
        buf.put_i32_le(self.client_timezone);
        buf.put_u32_le(self.client_lcid);

        for (field_offset, field_len) in [
            (hostname_offset, hostname_len),
            (username_offset, username_len),
            (password_offset, password_len),
            (app_name_offset, app_name_len),
            (server_name_offset, server_name_len),
            (unused_or_ext_offset, unused_or_ext_len),
            (library_name_offset, library_name_len),
            (language_offset, language_len),
            (database_offset, database_len),
        ] {
            buf.put_u16_le(field_offset);
            buf.put_u16_le(field_len);
        }

        buf.put_slice(&self.client_id);

        for (field_offset, field_len) in [
            (sspi_offset, sspi_len),
            (attach_db_offset, attach_db_len),
            (new_password_offset, new_password_len),
        ] {
            buf.put_u16_le(field_offset);
            buf.put_u16_le(field_len);
        }

        // cbSSPILong: always 0 (see the NOTE on `sspi_len`).
        buf.put_u32_le(0);

        buf.put_slice(&var_data);
        buf.freeze()
    }

    /// Writes `password` as UTF-16 code units, low byte first, obfuscating each
    /// byte by swapping its high and low nibbles and XOR-ing with 0xA5.
    fn write_obfuscated_password(dst: &mut impl BufMut, password: &str) {
        for unit in password.encode_utf16() {
            for byte in unit.to_le_bytes() {
                dst.put_u8(byte.rotate_right(4) ^ 0xA5);
            }
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `Login7::new` exposes the documented defaults.
    #[test]
    fn test_login7_default() {
        let Login7 {
            tds_version,
            packet_size,
            option_flags2,
            ..
        } = Login7::new();
        assert_eq!(tds_version, TdsVersion::V7_4);
        assert_eq!(packet_size, 4096);
        assert!(option_flags2.odbc);
    }

    /// Encoding produces at least a full header with the version at bytes 4..8.
    #[test]
    fn test_login7_encode() {
        let encoded = Login7::new()
            .with_hostname("TESTHOST")
            .with_sql_auth("testuser", "testpass")
            .with_database("testdb")
            .with_app_name("TestApp")
            .encode();
        assert!(encoded.len() >= LOGIN7_HEADER_SIZE);

        let mut version_bytes = [0u8; 4];
        version_bytes.copy_from_slice(&encoded[4..8]);
        assert_eq!(u32::from_le_bytes(version_bytes), TdsVersion::V7_4.raw());
    }

    /// 'a' (0x0061) obfuscates to the nibble-swapped, 0xA5-XOR-ed pair.
    #[test]
    fn test_password_obfuscation() {
        let mut obfuscated = BytesMut::new();
        Login7::write_obfuscated_password(&mut obfuscated, "a");
        assert_eq!(obfuscated.to_vec(), vec![0xB3, 0xA5]);
    }

    /// Spot-check the flag-byte packing.
    #[test]
    fn test_option_flags() {
        assert_eq!(OptionFlags1::default().to_byte(), 0x00);

        let odbc_integrated = OptionFlags2 {
            odbc: true,
            integrated_security: true,
            ..OptionFlags2::default()
        };
        assert_eq!(odbc_integrated.to_byte(), 0x82);

        let extension_only = OptionFlags3 {
            extension: true,
            ..OptionFlags3::default()
        };
        assert_eq!(extension_only.to_byte(), 0x10);
    }
}