#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
mod connect;
mod params;
mod response;
use std::marker::PhantomData;
use mssql_codec::connection::Connection;
#[cfg(feature = "tls")]
use mssql_tls::TlsStream;
use tds_protocol::packet::PacketType;
use tds_protocol::rpc::RpcRequest;
use tds_protocol::token::{EnvChange, EnvChangeType};
use tokio::net::TcpStream;
use tokio::time::timeout;
use crate::config::Config;
use crate::error::{Error, Result};
#[cfg(feature = "otel")]
use crate::instrumentation::InstrumentationContext;
use crate::state::{ConnectionState, InTransaction, Ready};
use crate::statement_cache::StatementCache;
use crate::stream::{MultiResultStream, QueryStream};
use crate::transaction::SavePoint;
/// SQL Server client whose connection state is tracked at the type level.
///
/// `S` is a [`ConnectionState`] marker such as [`Ready`] or [`InTransaction`].
/// State transitions (`begin_transaction`, `commit`, `rollback`, ...) consume
/// the client and rebuild it with a different `S`, moving every field across.
pub struct Client<S: ConnectionState> {
    /// Connection configuration (host, port, database, packet size, ...).
    config: Config,
    /// Zero-sized marker carrying the typestate `S`.
    _state: PhantomData<S>,
    /// Underlying transport; when `None`, sending a request fails with
    /// `Error::ConnectionClosed`.
    connection: Option<ConnectionHandle>,
    // NOTE(review): presumably the version reported by the server at login;
    // it is populated outside this file — confirm.
    server_version: Option<u32>,
    // NOTE(review): presumably the session's current database as reported by
    // the server; it is populated outside this file — confirm.
    current_database: Option<String>,
    /// Cache carried across state transitions; used outside this file.
    statement_cache: StatementCache,
    /// Transaction descriptor attached to every SQL batch and RPC;
    /// `0` means "no active transaction" (see `is_in_transaction`).
    transaction_descriptor: u64,
    /// When `true`, the next SQL batch or RPC is sent with the
    /// RESETCONNECTION flag.
    needs_reset: bool,
    /// Tracing/metrics context used to create query and transaction spans.
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
/// Transport-specific wrapper around the TDS [`Connection`]; each variant
/// owns a connection over a different stream type.
// NOTE(review): `dead_code` is allowed presumably because some variants are
// only constructed on particular connect paths / feature combinations — confirm.
#[allow(dead_code)]
enum ConnectionHandle {
    /// Connection over a raw `TcpStream` (no TLS layer).
    Plain(Connection<TcpStream>),
    /// TLS stream layered directly over TCP.
    #[cfg(feature = "tls")]
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS stream layered over `mssql_tls::TlsPreloginWrapper<TcpStream>`.
    #[cfg(feature = "tls")]
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
}
impl<S: ConnectionState> Client<S> {
    /// Applies a transaction-related ENVCHANGE token to `transaction_descriptor`.
    ///
    /// * `BeginTransaction` with a binary value of at least 8 bytes stores the
    ///   first 8 bytes, read as a little-endian `u64`, as the new descriptor.
    ///   Shorter or non-binary values are ignored.
    /// * `CommitTransaction` / `RollbackTransaction` reset the descriptor to `0`.
    /// * Every other ENVCHANGE type is ignored.
    ///
    /// This keeps the client in sync with transactions that were started or
    /// ended by raw SQL rather than through the typed transaction API.
    fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
        use tds_protocol::token::EnvChangeValue;

        match env.env_type {
            EnvChangeType::BeginTransaction => {
                if let EnvChangeValue::Binary(ref data) = env.new_value {
                    if data.len() >= 8 {
                        let mut raw = [0u8; 8];
                        raw.copy_from_slice(&data[..8]);
                        let descriptor = u64::from_le_bytes(raw);
                        tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
                        *transaction_descriptor = descriptor;
                    }
                }
            }
            EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
                tracing::debug!(
                    env_type = ?env.env_type,
                    "transaction ended via raw SQL"
                );
                *transaction_descriptor = 0;
            }
            _ => {}
        }
    }

    /// Sends `sql` as a SQL batch tagged with the current transaction descriptor.
    ///
    /// If a connection reset is pending (`needs_reset`), the batch carries the
    /// RESETCONNECTION flag, and the pending flag is cleared once the send has
    /// succeeded.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ConnectionClosed`] when there is no connection, or the
    /// error produced while writing the message.
    async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
        let payload =
            tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
        let max_packet = self.config.packet_size as usize;
        let reset = self.needs_reset;
        if reset {
            tracing::debug!("sending SQL batch with RESETCONNECTION flag");
        }

        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
                    .await?;
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
                    .await?;
            }
            ConnectionHandle::Plain(conn) => {
                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
                    .await?;
            }
        }

        // Clear the pending reset only after the flagged message went out, so a
        // missing connection or a failed send does not silently drop the reset.
        if reset {
            self.needs_reset = false;
        }
        Ok(())
    }

    /// Sends `rpc` as an RPC request tagged with the current transaction descriptor.
    ///
    /// If a connection reset is pending (`needs_reset`), the request carries the
    /// RESETCONNECTION flag, and the pending flag is cleared once the send has
    /// succeeded.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ConnectionClosed`] when there is no connection, or the
    /// error produced while writing the message.
    async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
        let payload = rpc.encode_with_transaction(self.transaction_descriptor);
        let max_packet = self.config.packet_size as usize;
        let reset = self.needs_reset;
        if reset {
            tracing::debug!("sending RPC with RESETCONNECTION flag");
        }

        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
                    .await?;
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
                    .await?;
            }
            ConnectionHandle::Plain(conn) => {
                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
                    .await?;
            }
        }

        // See `send_sql_batch`: only consume the reset once it was actually sent.
        if reset {
            self.needs_reset = false;
        }
        Ok(())
    }
}
impl Client<Ready> {
    /// Schedules a connection reset for the next request.
    ///
    /// The next SQL batch or RPC sent by this client carries the TDS
    /// RESETCONNECTION flag; sending that request consumes the pending flag.
    pub fn mark_needs_reset(&mut self) {
        self.needs_reset = true;
    }

    /// Returns `true` while a connection reset is still pending.
    #[must_use]
    pub fn needs_reset(&self) -> bool {
        self.needs_reset
    }

    /// Executes `sql` and returns the rows of its result as a [`QueryStream`].
    ///
    /// With no `params` the statement is sent as a plain SQL batch; otherwise it
    /// is sent as a parameterized RPC built by `RpcRequest::execute_sql`.
    ///
    /// # Errors
    ///
    /// Returns an error if a parameter cannot be converted, the request cannot
    /// be sent, or the response cannot be read.
    pub async fn query<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<QueryStream<'a>> {
        tracing::debug!(sql = sql, params_count = params.len(), "executing query");

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);

        let result = async {
            if params.is_empty() {
                self.send_sql_batch(sql).await?;
            } else {
                let rpc_params = Self::convert_params(params)?;
                let rpc = RpcRequest::execute_sql(sql, rpc_params);
                self.send_rpc(&rpc).await?;
            }
            self.read_query_response().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        let (columns, rows) = result?;
        Ok(QueryStream::new(columns, rows))
    }

    /// Like [`query`](Self::query), but fails with [`Error::CommandTimeout`] if
    /// the whole operation does not finish within `timeout_duration`.
    ///
    /// On timeout the in-flight `query` future is dropped. This method does not
    /// itself send a cancel request to the server (see
    /// [`cancel_handle`](Self::cancel_handle)).
    ///
    /// # Errors
    ///
    /// Returns [`Error::CommandTimeout`] on timeout, otherwise whatever
    /// [`query`](Self::query) returns.
    pub async fn query_with_timeout<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
        timeout_duration: std::time::Duration,
    ) -> Result<QueryStream<'a>> {
        timeout(timeout_duration, self.query(sql, params))
            .await
            .map_err(|_| Error::CommandTimeout)?
    }

    /// Executes `sql` that may produce several result sets and collects them
    /// into a [`MultiResultStream`].
    ///
    /// Parameters are handled as in [`query`](Self::query).
    ///
    /// # Errors
    ///
    /// Returns an error if a parameter cannot be converted, the request cannot
    /// be sent, or the response cannot be read.
    pub async fn query_multiple<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<MultiResultStream<'a>> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing multi-result query"
        );

        // Record a span like every other query path does.
        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);

        let result = async {
            if params.is_empty() {
                self.send_sql_batch(sql).await?;
            } else {
                let rpc_params = Self::convert_params(params)?;
                let rpc = RpcRequest::execute_sql(sql, rpc_params);
                self.send_rpc(&rpc).await?;
            }
            self.read_multi_result_response().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        let result_sets = result?;
        Ok(MultiResultStream::new(result_sets))
    }

    /// Executes a statement and returns the row count reported by the server.
    ///
    /// Parameters are handled as in [`query`](Self::query).
    ///
    /// # Errors
    ///
    /// Returns an error if a parameter cannot be converted, the request cannot
    /// be sent, or the response cannot be read.
    pub async fn execute(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<u64> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing statement"
        );

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);

        let result = async {
            if params.is_empty() {
                self.send_sql_batch(sql).await?;
            } else {
                let rpc_params = Self::convert_params(params)?;
                let rpc = RpcRequest::execute_sql(sql, rpc_params);
                self.send_rpc(&rpc).await?;
            }
            self.read_execute_result().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        result
    }

    /// Like [`execute`](Self::execute), but fails with [`Error::CommandTimeout`]
    /// if the whole operation does not finish within `timeout_duration`.
    ///
    /// On timeout the in-flight `execute` future is dropped. This method does
    /// not itself send a cancel request to the server.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CommandTimeout`] on timeout, otherwise whatever
    /// [`execute`](Self::execute) returns.
    pub async fn execute_with_timeout(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
        timeout_duration: std::time::Duration,
    ) -> Result<u64> {
        timeout(timeout_duration, self.execute(sql, params))
            .await
            .map_err(|_| Error::CommandTimeout)?
    }

    /// Starts a transaction with `BEGIN TRANSACTION` and returns the client in
    /// the [`InTransaction`] state, carrying the descriptor returned by the server.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails. The client is consumed either
    /// way, so on failure it is dropped together with its connection.
    pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
        tracing::debug!("beginning transaction");

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.transaction_span("BEGIN");

        let result = async {
            self.send_sql_batch("BEGIN TRANSACTION").await?;
            self.read_transaction_begin_result().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        let transaction_descriptor = result?;
        Ok(Client {
            config: self.config,
            _state: PhantomData,
            connection: self.connection,
            server_version: self.server_version,
            current_database: self.current_database,
            statement_cache: self.statement_cache,
            transaction_descriptor,
            needs_reset: self.needs_reset,
            #[cfg(feature = "otel")]
            instrumentation: self.instrumentation,
        })
    }

    /// Sends `isolation_level.as_sql()` as its own batch, then starts a
    /// transaction exactly like [`begin_transaction`](Self::begin_transaction).
    ///
    /// NOTE(review): if `as_sql()` is a `SET TRANSACTION ISOLATION LEVEL`
    /// statement, the level stays in effect for the session after this
    /// transaction ends — confirm whether callers expect that.
    ///
    /// # Errors
    ///
    /// Returns an error if either request fails. The client is consumed either
    /// way, so on failure it is dropped together with its connection.
    pub async fn begin_transaction_with_isolation(
        mut self,
        isolation_level: crate::transaction::IsolationLevel,
    ) -> Result<Client<InTransaction>> {
        tracing::debug!(
            isolation_level = %isolation_level.name(),
            "beginning transaction with isolation level"
        );

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.transaction_span("BEGIN");

        let result = async {
            self.send_sql_batch(isolation_level.as_sql()).await?;
            self.read_execute_result().await?;
            self.send_sql_batch("BEGIN TRANSACTION").await?;
            self.read_transaction_begin_result().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        let transaction_descriptor = result?;
        Ok(Client {
            config: self.config,
            _state: PhantomData,
            connection: self.connection,
            server_version: self.server_version,
            current_database: self.current_database,
            statement_cache: self.statement_cache,
            transaction_descriptor,
            needs_reset: self.needs_reset,
            #[cfg(feature = "otel")]
            instrumentation: self.instrumentation,
        })
    }

    /// Executes `sql` as a plain SQL batch, discarding the reported row count.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the response cannot be read.
    pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
        tracing::debug!(sql = sql, "executing simple query");
        self.send_sql_batch(sql).await?;
        self.read_execute_result().await?;
        Ok(())
    }

    /// Consumes the client, releasing the connection by dropping it.
    ///
    /// No request is sent to the server and this currently always returns `Ok`.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` leaves room for a graceful shutdown.
    pub async fn close(self) -> Result<()> {
        tracing::debug!("closing connection");
        Ok(())
    }

    /// Returns the name of the session's current database, if known.
    ///
    /// Prefers the database name the client has recorded for the session and
    /// falls back to the database from the connection configuration.
    // NOTE(review): presumably `current_database` is kept up to date from the
    // server's database ENVCHANGE (e.g. after `USE`) — confirm where it is set.
    #[must_use]
    pub fn database(&self) -> Option<&str> {
        self.current_database
            .as_deref()
            .or(self.config.database.as_deref())
    }

    /// Returns the configured server host.
    #[must_use]
    pub fn host(&self) -> &str {
        &self.config.host
    }

    /// Returns the configured server port.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.config.port
    }

    /// Returns `true` if a non-zero transaction descriptor is recorded.
    ///
    /// In the `Ready` state this can happen when a transaction was started with
    /// raw SQL instead of [`begin_transaction`](Self::begin_transaction).
    #[must_use]
    pub fn is_in_transaction(&self) -> bool {
        self.transaction_descriptor != 0
    }

    /// Returns a handle that can be used to cancel requests on this connection.
    ///
    /// # Panics
    ///
    /// Panics if the client has no connection.
    #[must_use]
    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
        let connection = self
            .connection
            .as_ref()
            .expect("connection should be present");
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
            }
            ConnectionHandle::Plain(conn) => {
                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
            }
        }
    }
}
impl Client<InTransaction> {
    /// Executes `sql` inside the transaction and returns its rows as a
    /// [`QueryStream`].
    ///
    /// With no `params` the statement is sent as a plain SQL batch; otherwise it
    /// is sent as a parameterized RPC built by `RpcRequest::execute_sql`. Both
    /// carry the current transaction descriptor.
    ///
    /// # Errors
    ///
    /// Returns an error if a parameter cannot be converted, the request cannot
    /// be sent, or the response cannot be read.
    pub async fn query<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<QueryStream<'a>> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing query in transaction"
        );

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);

        let result = async {
            if params.is_empty() {
                self.send_sql_batch(sql).await?;
            } else {
                let rpc_params = Self::convert_params(params)?;
                let rpc = RpcRequest::execute_sql(sql, rpc_params);
                self.send_rpc(&rpc).await?;
            }
            self.read_query_response().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        let (columns, rows) = result?;
        Ok(QueryStream::new(columns, rows))
    }

    /// Executes a statement inside the transaction and returns the row count
    /// reported by the server.
    ///
    /// Parameters are handled as in [`query`](Self::query).
    ///
    /// # Errors
    ///
    /// Returns an error if a parameter cannot be converted, the request cannot
    /// be sent, or the response cannot be read.
    pub async fn execute(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<u64> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing statement in transaction"
        );

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);

        let result = async {
            if params.is_empty() {
                self.send_sql_batch(sql).await?;
            } else {
                let rpc_params = Self::convert_params(params)?;
                let rpc = RpcRequest::execute_sql(sql, rpc_params);
                self.send_rpc(&rpc).await?;
            }
            self.read_execute_result().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        result
    }

    /// Like [`query`](Self::query), but fails with [`Error::CommandTimeout`] if
    /// the whole operation does not finish within `timeout_duration`.
    ///
    /// On timeout the in-flight `query` future is dropped. This method does not
    /// itself send a cancel request to the server.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CommandTimeout`] on timeout, otherwise whatever
    /// [`query`](Self::query) returns.
    pub async fn query_with_timeout<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
        timeout_duration: std::time::Duration,
    ) -> Result<QueryStream<'a>> {
        timeout(timeout_duration, self.query(sql, params))
            .await
            .map_err(|_| Error::CommandTimeout)?
    }

    /// Like [`execute`](Self::execute), but fails with [`Error::CommandTimeout`]
    /// if the whole operation does not finish within `timeout_duration`.
    ///
    /// On timeout the in-flight `execute` future is dropped. This method does
    /// not itself send a cancel request to the server.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CommandTimeout`] on timeout, otherwise whatever
    /// [`execute`](Self::execute) returns.
    pub async fn execute_with_timeout(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
        timeout_duration: std::time::Duration,
    ) -> Result<u64> {
        timeout(timeout_duration, self.execute(sql, params))
            .await
            .map_err(|_| Error::CommandTimeout)?
    }

    /// Sends `COMMIT TRANSACTION` and, on success, returns the client in the
    /// [`Ready`] state with the transaction descriptor reset to `0`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails. The client is consumed either
    /// way, so on failure it is dropped together with its connection.
    pub async fn commit(mut self) -> Result<Client<Ready>> {
        tracing::debug!("committing transaction");

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.transaction_span("COMMIT");

        let result = async {
            self.send_sql_batch("COMMIT TRANSACTION").await?;
            self.read_execute_result().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        result?;
        Ok(Client {
            config: self.config,
            _state: PhantomData,
            connection: self.connection,
            server_version: self.server_version,
            current_database: self.current_database,
            statement_cache: self.statement_cache,
            transaction_descriptor: 0,
            needs_reset: self.needs_reset,
            #[cfg(feature = "otel")]
            instrumentation: self.instrumentation,
        })
    }

    /// Sends `ROLLBACK TRANSACTION` and, on success, returns the client in the
    /// [`Ready`] state with the transaction descriptor reset to `0`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails. The client is consumed either
    /// way, so on failure it is dropped together with its connection.
    pub async fn rollback(mut self) -> Result<Client<Ready>> {
        tracing::debug!("rolling back transaction");

        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.transaction_span("ROLLBACK");

        let result = async {
            self.send_sql_batch("ROLLBACK TRANSACTION").await?;
            self.read_execute_result().await
        }
        .await;

        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);

        result?;
        Ok(Client {
            config: self.config,
            _state: PhantomData,
            connection: self.connection,
            server_version: self.server_version,
            current_database: self.current_database,
            statement_cache: self.statement_cache,
            transaction_descriptor: 0,
            needs_reset: self.needs_reset,
            #[cfg(feature = "otel")]
            instrumentation: self.instrumentation,
        })
    }

    /// Creates a savepoint with `SAVE TRANSACTION <name>`.
    ///
    /// `name` must start with an ASCII letter or `_`, contain only ASCII
    /// alphanumerics or `_`, `@`, `#`, `$`, and be at most 128 characters. It is
    /// spliced directly into the SQL text, so this check is what prevents
    /// injection.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidIdentifier` for an invalid name, or an error if
    /// the request fails.
    pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
        validate_identifier(name)?;
        tracing::debug!(name = name, "creating savepoint");
        let sql = format!("SAVE TRANSACTION {name}");
        self.send_sql_batch(&sql).await?;
        self.read_execute_result().await?;
        Ok(SavePoint::new(name.to_string()))
    }

    /// Rolls back to `savepoint` with `ROLLBACK TRANSACTION <name>`; the client
    /// stays in the [`InTransaction`] state.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidIdentifier` if the savepoint's name is not a safe
    /// identifier, or an error if the request fails.
    pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
        // The name is spliced into the SQL text, so re-check it here instead of
        // relying on how the `SavePoint` value was constructed.
        validate_identifier(savepoint.name())?;
        tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
        let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
        self.send_sql_batch(&sql).await?;
        self.read_execute_result().await?;
        Ok(())
    }

    /// Releases a savepoint handle.
    ///
    /// No request is sent to the server; the handle is simply dropped
    /// (T-SQL has no `RELEASE SAVEPOINT` statement).
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
        tracing::debug!(name = savepoint.name(), "releasing savepoint");
        drop(savepoint);
        Ok(())
    }

    /// Returns a handle that can be used to cancel requests on this connection.
    ///
    /// # Panics
    ///
    /// Panics if the client has no connection.
    #[must_use]
    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
        let connection = self
            .connection
            .as_ref()
            .expect("connection should be present");
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
            }
            ConnectionHandle::Plain(conn) => {
                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
            }
        }
    }
}
fn validate_identifier(name: &str) -> Result<()> {
use once_cell::sync::Lazy;
use regex::Regex;
static IDENTIFIER_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
if name.is_empty() {
return Err(Error::InvalidIdentifier(
"identifier cannot be empty".into(),
));
}
if !IDENTIFIER_RE.is_match(name) {
return Err(Error::InvalidIdentifier(format!(
"invalid identifier '{name}': must start with letter/underscore, \
contain only alphanumerics/_/@/#/$, and be 1-128 characters"
)));
}
Ok(())
}
impl<S: ConnectionState> std::fmt::Debug for Client<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Client")
.field("host", &self.config.host)
.field("port", &self.config.port)
.field("database", &self.config.database)
.finish()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Plain identifiers made of letters, digits and underscores are accepted.
    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
    }

    /// Empty names, leading digits, and separators/punctuation are rejected.
    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
    }

    /// The length limit is inclusive at 128 characters.
    #[test]
    fn test_validate_identifier_length_bounds() {
        let max = format!("a{}", "b".repeat(127));
        assert_eq!(max.len(), 128);
        assert!(validate_identifier(&max).is_ok());

        let too_long = format!("a{}", "b".repeat(128));
        assert!(validate_identifier(&too_long).is_err());
    }

    /// `@`, `#`, `$` are allowed after the first character only; non-ASCII,
    /// trailing newlines and quote-based injection attempts are rejected.
    #[test]
    fn test_validate_identifier_special_chars() {
        assert!(validate_identifier("sp@name#1$").is_ok());
        assert!(validate_identifier("@local").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$x").is_err());
        assert!(validate_identifier("café").is_err());
        assert!(validate_identifier("name\n").is_err());
        assert!(validate_identifier("x'; --").is_err());
    }
}