use super::guid::reorder_bytes;
use super::{Decode, Encode};
use crate::{tds, Error, Result};
use byteorder::{BigEndian, LittleEndian, ReadBytesExt, WriteBytesExt};
use bytes::{BufMut, BytesMut};
use std::convert::TryFrom;
use std::io::{Cursor, Read};
use tds::EncryptionLevel;
use uuid::Uuid;
/// Activity identifier carried by the TRACEID prelogin option: a 16-byte
/// GUID followed by a little-endian `u32` sequence number (see `decode`).
// NOTE(review): in this file the fields are only written by `decode` and never
// read, which is presumably why `unused` is allowed here — confirm before removing.
#[allow(unused)]
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub struct ActivityId {
    /// GUID built from the first 16 payload bytes after `reorder_bytes` has been applied.
    id: Uuid,
    /// Sequence number read (little-endian) right after the GUID bytes.
    sequence: u32,
}
/// In-memory form of a TDS PRELOGIN message, used both for the request the
/// client encodes and for the response it decodes.
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub struct PreloginMessage {
    /// VERSION option: 32-bit version, written/read big-endian.
    pub version: u32,
    /// VERSION option: 16-bit sub-build that follows `version`.
    pub sub_build: u16,
    /// ENCRYPTION option, encoded as a single byte.
    pub encryption: EncryptionLevel,
    /// INSTOPT option; only populated by `decode` (not written by `encode`).
    pub instance_name: Option<String>,
    /// THREADID option, big-endian `u32`.
    pub thread_id: u32,
    /// MARS option, a single byte where any non-zero value means `true`.
    pub mars: bool,
    /// TRACEID option; only populated by `decode` (not written by `encode`).
    pub activity_id: Option<ActivityId>,
    /// FEDAUTHREQUIRED option; `encode` only emits it when this is `true`.
    pub fed_auth_required: bool,
    /// NONCEOPT option (32 bytes); only populated by `decode` (not written by `encode`).
    pub nonce: Option<[u8; 32]>,
}
impl PreloginMessage {
    /// Builds a prelogin message that advertises this driver's version.
    ///
    /// `version` is `crate::get_driver_version()` cast to `u32`, and `sub_build`
    /// is that same value shifted right by 32 bits and cast to `u16`. Every
    /// other option starts out unset: no encryption support, no instance name,
    /// thread id 0, MARS off, no activity id, no federated auth, no nonce.
    pub fn new() -> PreloginMessage {
        let packed_version = crate::get_driver_version();
        let version = packed_version as u32;
        let sub_build = (packed_version >> 32) as u16;

        PreloginMessage {
            version,
            sub_build,
            encryption: EncryptionLevel::NotSupported,
            instance_name: None,
            thread_id: 0,
            mars: false,
            activity_id: None,
            fed_auth_required: false,
            nonce: None,
        }
    }

    /// Resolves the encryption level to use, given the level the client asked
    /// for (`expected`) and the level stored in this message.
    ///
    /// This message is presumably the server's prelogin response (per the panic
    /// text below). The result is:
    /// * `NotSupported` when both sides say `NotSupported`;
    /// * `Off` when both sides say `Off`;
    /// * `On` for every other combination that does not panic.
    ///
    /// # Panics
    ///
    /// Panics when `expected` is `On` but this message carries `Off` or
    /// `NotSupported`.
    #[cfg(any(
        feature = "rustls",
        feature = "native-tls",
        feature = "vendored-openssl"
    ))]
    pub fn negotiated_encryption(&self, expected: EncryptionLevel) -> EncryptionLevel {
        let offered = self.encryption;

        match expected {
            EncryptionLevel::NotSupported
                if matches!(offered, EncryptionLevel::NotSupported) =>
            {
                EncryptionLevel::NotSupported
            }
            EncryptionLevel::Off if matches!(offered, EncryptionLevel::Off) => EncryptionLevel::Off,
            EncryptionLevel::On
                if matches!(
                    offered,
                    EncryptionLevel::Off | EncryptionLevel::NotSupported
                ) =>
            {
                panic!("Server does not allow the requested encryption level.")
            }
            _ => EncryptionLevel::On,
        }
    }

    /// Without a TLS backend compiled in, encryption can never be negotiated,
    /// so this always reports `NotSupported` regardless of the request.
    #[cfg(not(any(
        feature = "rustls",
        feature = "native-tls",
        feature = "vendored-openssl"
    )))]
    pub fn negotiated_encryption(&self, _requested: EncryptionLevel) -> EncryptionLevel {
        EncryptionLevel::NotSupported
    }
}
/// Option token: version (`u32` version followed by `u16` sub-build, big-endian).
const PRELOGIN_VERSION: u8 = 0x00;
/// Option token: encryption level (one byte).
const PRELOGIN_ENCRYPTION: u8 = 0x01;
/// Option token: instance name (nul-terminated bytes, decode only).
const PRELOGIN_INSTOPT: u8 = 0x02;
/// Option token: thread id (big-endian `u32`).
const PRELOGIN_THREADID: u8 = 0x03;
/// Option token: MARS flag (one byte).
const PRELOGIN_MARS: u8 = 0x04;
/// Option token: trace/activity id (16-byte GUID + little-endian `u32` sequence).
const PRELOGIN_TRACEID: u8 = 0x05;
/// Option token: federated-authentication-required flag (one byte).
const PRELOGIN_FEDAUTHREQUIRED: u8 = 0x06;
/// Option token: 32-byte nonce.
const PRELOGIN_NONCEOPT: u8 = 0x07;
/// Marks the end of the option header list.
const PRELOGIN_TERMINATOR: u8 = 0xFF;
impl Encode<BytesMut> for PreloginMessage {
    /// Serializes this message into `dst`.
    ///
    /// Output layout: one 5-byte header per option (token byte, big-endian
    /// `u16` payload offset, big-endian `u16` payload length), then
    /// `PRELOGIN_TERMINATOR`, then all option payloads back to back in header
    /// order. Offsets count from the first byte written by this call.
    ///
    /// VERSION, ENCRYPTION, THREADID and MARS are always written.
    /// FEDAUTHREQUIRED is written only when `fed_auth_required` is `true`.
    /// `instance_name`, `activity_id` and `nonce` are never written.
    fn encode(self, dst: &mut BytesMut) -> Result<()> {
        let mut headers: Vec<(u8, u16)> = Vec::new();
        let mut payload: Vec<u8> = Vec::with_capacity(512);

        // VERSION: u32 version followed by u16 sub-build.
        headers.push((PRELOGIN_VERSION, 0x04 + 0x02));
        payload.write_u32::<BigEndian>(self.version)?;
        payload.write_u16::<BigEndian>(self.sub_build)?;

        // ENCRYPTION: single byte.
        headers.push((PRELOGIN_ENCRYPTION, 0x01));
        payload.write_u8(self.encryption as u8)?;

        // THREADID: u32.
        headers.push((PRELOGIN_THREADID, 0x04));
        payload.write_u32::<BigEndian>(self.thread_id)?;

        // MARS: single byte, 1 for enabled.
        headers.push((PRELOGIN_MARS, 0x01));
        payload.write_u8(u8::from(self.mars))?;

        // FEDAUTHREQUIRED is optional and only announced when requested.
        if self.fed_auth_required {
            headers.push((PRELOGIN_FEDAUTHREQUIRED, 0x01));
            payload.write_u8(0x01)?;
        }

        // Payload data begins right after every header plus the terminator byte.
        let header_block_len = headers.len() * 5 + 1;
        let mut next_offset = header_block_len as u16;

        for (token, len) in headers {
            dst.put_u8(token);
            dst.put_u16(next_offset);
            dst.put_u16(len);
            next_offset += len;
        }

        dst.put_u8(PRELOGIN_TERMINATOR);
        dst.extend(payload);

        Ok(())
    }
}
impl Decode<BytesMut> for PreloginMessage {
fn decode(src: &mut BytesMut) -> Result<Self>
where
Self: Sized,
{
let mut cursor = Cursor::new(src);
let mut ret = PreloginMessage::new();
loop {
let token = cursor.read_u8()?;
if token == 0xff {
break;
}
let offset = cursor.read_u16::<BigEndian>()?;
let length = cursor.read_u16::<BigEndian>()?;
let old_pos = cursor.position();
cursor.set_position(offset as u64);
match token {
PRELOGIN_VERSION => {
ret.version = cursor.read_u32::<BigEndian>()?;
ret.sub_build = cursor.read_u16::<BigEndian>()?;
}
PRELOGIN_ENCRYPTION => {
let encrypt = cursor.read_u8()?;
ret.encryption = tds::EncryptionLevel::try_from(encrypt).map_err(|_| {
Error::Protocol(format!("invalid encryption value: {}", encrypt).into())
})?;
}
PRELOGIN_INSTOPT => {
let mut bytes = Vec::new();
let mut next_byte = cursor.read_u8()?;
while next_byte != 0x00 {
bytes.push(next_byte);
next_byte = cursor.read_u8()?;
}
if !bytes.is_empty() {
ret.instance_name = Some(String::from_utf8_lossy(&bytes).into_owned());
}
}
PRELOGIN_THREADID => {
ret.thread_id = if length == 0 {
0
} else if length == 4 {
cursor.read_u32::<BigEndian>()?
} else {
panic!("should never happen")
}
}
PRELOGIN_MARS => {
ret.mars = cursor.read_u8()? != 0;
}
PRELOGIN_TRACEID => {
let mut data = [0u8; 16];
cursor.read_exact(&mut data)?;
reorder_bytes(&mut data);
ret.activity_id = Some(ActivityId {
id: Uuid::from_bytes(data),
sequence: cursor.read_u32::<LittleEndian>()?,
});
}
PRELOGIN_FEDAUTHREQUIRED => {
ret.fed_auth_required = cursor.read_u8()? != 0;
}
PRELOGIN_NONCEOPT => {
let mut data = [0u8; 32];
for item in data.iter_mut() {
*item = cursor.read_u8()?;
}
ret.nonce = Some(data);
}
_ => panic!("unsupported prelogin token: {}", token),
}
cursor.set_position(old_pos);
}
Ok(ret)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a clone of `msg`, then decodes the produced bytes back.
    fn roundtrip(msg: &PreloginMessage) -> PreloginMessage {
        let mut payload = BytesMut::new();
        msg.clone()
            .encode(&mut payload)
            .expect("encode should succeed");
        PreloginMessage::decode(&mut payload).expect("decode should succeed")
    }

    #[test]
    fn prelogin_roundtrip() {
        let prelogin = PreloginMessage::new();
        assert_eq!(prelogin, roundtrip(&prelogin));
    }

    #[test]
    fn prelogin_with_fedauth_roundtrip() {
        let mut prelogin = PreloginMessage::new();
        prelogin.fed_auth_required = true;
        assert_eq!(prelogin, roundtrip(&prelogin));
    }

    /// The default message has thread id 0 and MARS off, so it cannot catch an
    /// encoder that writes zeros. Use distinctive values for every written field.
    #[test]
    fn prelogin_with_custom_fields_roundtrip() {
        let mut prelogin = PreloginMessage::new();
        prelogin.version = 0x0102_0304;
        prelogin.sub_build = 0x0506;
        prelogin.thread_id = 0xdead_beef;
        prelogin.mars = true;
        assert_eq!(prelogin, roundtrip(&prelogin));
    }
}