use std::ops::{Deref, DerefMut};
use bytes::{Bytes, BytesMut};
use sqlx_rt::{AsyncWriteExt, TcpStream};
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::{BufStream, Encode};
use crate::mssql::connection::tls_prelogin_stream_wrapper::TlsPreloginWrapper;
use crate::mssql::protocol::col_meta_data::ColMetaData;
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error as ProtocolError;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::{Message, MessageType};
use crate::mssql::protocol::order::Order;
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status, PACKET_HEADER_SIZE};
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::return_value::ReturnValue;
use crate::mssql::protocol::row::Row;
use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError};
use crate::net::{MaybeTlsStream, TlsConfig};
use crate::HashMap;
use std::sync::Arc;
/// A TDS connection to SQL Server: the buffered (optionally TLS) socket plus
/// the per-connection protocol state that is updated while decoding responses.
pub(crate) struct MssqlStream {
    /// Buffered socket; created as `MaybeTlsStream::Raw` and upgradable to TLS.
    inner: BufStream<MaybeTlsStream<TlsPreloginWrapper<TcpStream>>>,
    /// Options the connection was opened with; read again when setting up TLS.
    options: MssqlConnectOptions,
    /// Largest outgoing packet (header included) used when framing messages;
    /// replaced when the server reports a `PacketSize` env-change.
    pub(crate) max_packet_size: usize,
    /// The packet whose payload is currently being decoded into messages.
    response: Option<(PacketHeader, Bytes)>,
    /// Column metadata of the current result set, used to decode rows.
    pub(crate) columns: Arc<Vec<MssqlColumn>>,
    /// Column-name lookup filled alongside `columns` (presumably name -> index).
    pub(crate) column_names: Arc<HashMap<UStr, usize>>,
    /// Outstanding responses; `wait_until_ready` reads messages until this is 0.
    pub(crate) pending_done_count: usize,
    /// Descriptor from the server's `BeginTransaction` env-change; reset to 0
    /// on commit/rollback.
    pub(crate) transaction_descriptor: u64,
    /// Presumably the nesting depth of open transactions; maintained by callers.
    pub(crate) transaction_depth: usize,
}
impl MssqlStream {
pub(super) async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
let port = match (options.port, &options.instance) {
(Some(port), _) => {
log::debug!(
"using explicitly specified port {} for host '{}'",
port,
options.host
);
port
}
(None, Some(instance)) => {
super::ssrp::resolve_instance_port(&options.host, instance).await?
}
(None, None) => {
const DEFAULT_PORT: u16 = 1433;
log::debug!(
"using default port {} for host '{}'",
DEFAULT_PORT,
options.host
);
DEFAULT_PORT
}
};
log::debug!("establishing TCP connection to {}:{}", options.host, port);
let tcp_stream = TcpStream::connect((&*options.host, port)).await?;
log::debug!("TCP connection established to {}:{}", options.host, port);
let wrapped_stream = TlsPreloginWrapper::new(tcp_stream);
let inner = BufStream::new(MaybeTlsStream::Raw(wrapped_stream));
Ok(Self {
inner,
columns: Default::default(),
column_names: Default::default(),
response: None,
pending_done_count: 0,
transaction_descriptor: 0,
transaction_depth: 0,
max_packet_size: options
.requested_packet_size
.try_into()
.unwrap_or(usize::MAX),
options: options.clone(),
})
}
pub(crate) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) {
write_packets(&mut self.inner.wbuf, self.max_packet_size, ty, payload)
}
pub(crate) async fn write_packet_and_flush<'en, T: Encode<'en>>(
&mut self,
ty: PacketType,
payload: T,
) -> Result<(), Error> {
if !self.inner.wbuf.is_empty() {
self.flush().await?;
}
self.write_packet(ty, payload);
self.flush().await?;
Ok(())
}
pub(crate) async fn flush(&mut self) -> Result<(), Error> {
if self.inner.wbuf.len() > self.max_packet_size {
for chunk in self.inner.wbuf.chunks(self.max_packet_size) {
self.inner.stream.write_all(chunk).await?;
self.inner.stream.flush().await?;
}
self.inner.wbuf.clear();
} else {
self.inner.flush().await?;
}
Ok(())
}
pub(super) async fn recv_packet(&mut self) -> Result<(PacketHeader, Bytes), Error> {
let mut header: PacketHeader = self.inner.read(8).await?;
if !matches!(header.r#type, PacketType::TabularResult) {
return Err(err_protocol!(
"received unexpected packet: {:?}",
header.r#type
));
}
let mut payload = BytesMut::new();
loop {
self.inner
.read_raw_into(&mut payload, (header.length - 8) as usize)
.await?;
if header.status.contains(Status::END_OF_MESSAGE) {
break;
}
header = self.inner.read(8).await?;
}
Ok((header, payload.freeze()))
}
pub(super) async fn recv_message(&mut self) -> Result<Message, Error> {
loop {
while self.response.as_ref().is_some_and(|r| !r.1.is_empty()) {
let buf = if let Some((_, buf)) = self.response.as_mut() {
buf
} else {
break;
};
let ty = MessageType::get(buf)?;
let message = match ty {
MessageType::EnvChange => {
match EnvChange::get(buf)? {
EnvChange::BeginTransaction(desc) => {
self.transaction_descriptor = desc;
}
EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => {
self.transaction_descriptor = 0;
}
EnvChange::PacketSize(size) => {
self.max_packet_size = size.clamp(512, 32767).try_into().unwrap();
}
_ => {}
}
continue;
}
MessageType::Info => {
let _ = Info::get(buf)?;
continue;
}
MessageType::Row => Message::Row(Row::get(buf, false, &self.columns)?),
MessageType::NbcRow => Message::Row(Row::get(buf, true, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
MessageType::ReturnValue => Message::ReturnValue(ReturnValue::get(buf)?),
MessageType::Done => Message::Done(Done::get(buf)?),
MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?),
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
MessageType::Order => Message::Order(Order::get(buf)?),
MessageType::Error => {
let error = ProtocolError::get(buf)?;
return self.handle_error(error);
}
MessageType::ColMetaData => {
ColMetaData::get(
buf,
Arc::make_mut(&mut self.columns),
Arc::make_mut(&mut self.column_names),
)?;
continue;
}
};
return Ok(message);
}
self.response = Some(self.recv_packet().await?);
}
}
pub(crate) fn handle_done(&mut self, _done: &Done) {
self.pending_done_count -= 1;
}
pub(crate) fn handle_error<T>(&mut self, error: ProtocolError) -> Result<T, Error> {
Err(MssqlDatabaseError(error).into())
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.wbuf.is_empty() {
self.flush().await?;
}
while self.pending_done_count > 0 {
let message = self.recv_message().await?;
if let Message::DoneProc(done) | Message::Done(done) = message {
if !done.status.contains(DoneStatus::DONE_MORE) {
self.handle_done(&done);
}
}
}
Ok(())
}
pub(crate) async fn setup_encryption(&mut self) -> Result<(), Error> {
let tls_config = TlsConfig {
accept_invalid_certs: self.options.trust_server_certificate,
hostname: self
.options
.hostname_in_certificate
.as_deref()
.unwrap_or(&self.options.host),
accept_invalid_hostnames: self.options.hostname_in_certificate.is_none(),
root_cert_path: self.options.ssl_root_cert.as_ref(),
client_cert_path: None,
client_key_path: None,
};
self.inner.deref_mut().start_handshake();
self.inner.upgrade(tls_config).await?;
self.inner.deref_mut().handshake_complete();
Ok(())
}
pub(crate) async fn disable_encryption(&mut self) -> Result<(), Error> {
self.inner.downgrade()?;
Ok(())
}
}
/// Encodes `payload` and frames it into `buffer` as one or more TDS packets of
/// type `ty`.
///
/// Every packet except the last carries `max_packet_size - PACKET_HEADER_SIZE`
/// payload bytes. Packets are laid out back-to-back. Each header has
/// `server_process_id` 0 and `packet_id` 1. Only the final packet has
/// `END_OF_MESSAGE` set. An empty payload still produces one header-only packet,
/// so the message is always terminated.
///
/// `max_packet_size` is capped at `u16::MAX`, the largest length the header's
/// 16-bit length field can express.
///
/// # Panics
/// Panics if `buffer` is not empty, or if `max_packet_size` is not larger than
/// `PACKET_HEADER_SIZE` (no room for any payload).
pub(crate) fn write_packets<'en, T: Encode<'en>>(
    buffer: &mut Vec<u8>,
    max_packet_size: usize,
    ty: PacketType,
    payload: T,
) {
    assert!(buffer.is_empty());
    assert!(
        max_packet_size > PACKET_HEADER_SIZE,
        "max_packet_size ({}) must be larger than the {}-byte packet header",
        max_packet_size,
        PACKET_HEADER_SIZE
    );
    let max_packet_size = max_packet_size.min(usize::from(u16::MAX));
    let max_contents_size = max_packet_size - PACKET_HEADER_SIZE;

    // Encode the payload once, right after a placeholder for the first header.
    buffer.resize(PACKET_HEADER_SIZE, 0);
    payload.encode(buffer);
    let len = buffer.len() - PACKET_HEADER_SIZE;

    let packet_count = if len == 0 {
        1
    } else {
        (len + max_contents_size - 1) / max_contents_size
    };
    buffer.resize(len + PACKET_HEADER_SIZE * packet_count, 0);

    let mut header: Vec<u8> = Vec::with_capacity(PACKET_HEADER_SIZE);
    // Walk back-to-front: each chunk only moves right, into space not yet
    // occupied by earlier (still unmoved) chunks, so nothing is overwritten.
    for packet_index in (0..packet_count).rev() {
        let contents_offset = packet_index * max_contents_size;
        let contents_size = max_contents_size.min(len - contents_offset);
        let current_start = PACKET_HEADER_SIZE + contents_offset;
        let header_start = packet_index * max_packet_size;
        let target_start = header_start + PACKET_HEADER_SIZE;
        if current_start != target_start {
            assert!(current_start < target_start);
            buffer.copy_within(current_start..current_start + contents_size, target_start);
        }

        let is_last = packet_index + 1 == packet_count;
        header.clear();
        PacketHeader {
            r#type: ty,
            status: if is_last {
                Status::END_OF_MESSAGE
            } else {
                Status::NORMAL
            },
            length: u16::try_from(contents_size + PACKET_HEADER_SIZE)
                .expect("packet size is capped at u16::MAX"),
            server_process_id: 0,
            packet_id: 1,
        }
        .encode(&mut header);
        assert_eq!(header.len(), PACKET_HEADER_SIZE);
        buffer[header_start..target_start].copy_from_slice(&header);
    }
}
#[test]
fn test_write_packets() {
    /// Frames `payload` as RPC packets that each hold at most 4 payload bytes.
    fn framed(payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::<u8>::new();
        write_packets(&mut out, PACKET_HEADER_SIZE + 4, PacketType::Rpc, payload);
        out
    }

    // 9 bytes -> two full packets plus a 1-byte final packet.
    assert_eq!(
        framed(&b"123456789"[..]),
        b"\
\x03\x00\x00\x0C\x00\x00\x01\x00\
1234\
\x03\x00\x00\x0C\x00\x00\x01\x00\
5678\
\x03\x01\x00\x09\x00\x00\x01\x00\
9"
    );

    // 8 bytes -> exactly two full packets; the second ends the message.
    assert_eq!(
        framed(&b"12345678"[..]),
        b"\
\x03\x00\x00\x0C\x00\x00\x01\x00\
1234\
\x03\x01\x00\x0C\x00\x00\x01\x00\
5678"
    );
}
impl Deref for MssqlStream {
type Target = BufStream<MaybeTlsStream<TlsPreloginWrapper<TcpStream>>>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for MssqlStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}