#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
mod connect;
mod params;
mod response;
use std::marker::PhantomData;
use mssql_codec::connection::Connection;
#[cfg(feature = "tls")]
use mssql_tls::TlsStream;
use tds_protocol::packet::PacketType;
use tds_protocol::rpc::RpcRequest;
use tds_protocol::token::{EnvChange, EnvChangeType};
use tokio::net::TcpStream;
use tokio::time::timeout;
use crate::config::Config;
use crate::error::{Error, Result};
#[cfg(feature = "otel")]
use crate::instrumentation::InstrumentationContext;
use crate::state::{ConnectionState, InTransaction, Ready};
use crate::statement_cache::StatementCache;
use crate::stream::{MultiResultStream, QueryStream};
use crate::transaction::SavePoint;
/// SQL Server client handle whose type parameter `S` records the connection state
/// (`Ready` or `InTransaction`) at compile time.
///
/// State transitions (`begin_transaction`, `commit`, `rollback`) consume the client and
/// rebuild it with the same connection and session bookkeeping under the new state type.
pub struct Client<S: ConnectionState> {
// Configuration the client was built from; read for the packet size on every send,
// the `host`/`port`/`database` accessors and the unicode-parameter setting.
config: Config,
// Zero-sized type-state marker; carries no runtime data.
_state: PhantomData<S>,
// Transport to the server. `None` is reported as `Error::ConnectionClosed` by the send
// paths.
connection: Option<ConnectionHandle>,
// Server version; not read in this part of the file (presumably filled in during
// login — TODO confirm).
server_version: Option<u32>,
// Current database as tracked by the client; not read in this part of the file
// (presumably updated from ENVCHANGE tokens — TODO confirm).
current_database: Option<String>,
// Server collation, passed to parameter conversion for non-unicode string parameters.
server_collation: Option<tds_protocol::token::Collation>,
// Statement cache carried unchanged across state transitions.
statement_cache: StatementCache,
// Transaction descriptor sent with every SQL batch and RPC; 0 means "no transaction".
transaction_descriptor: u64,
// Set when a request has been written to the server and its response has not been
// fully consumed yet.
in_flight: bool,
// When set, the next SQL batch or RPC carries the RESETCONNECTION flag and the flag
// is cleared.
needs_reset: bool,
// OpenTelemetry context used to create query/transaction spans.
#[cfg(feature = "otel")]
instrumentation: InstrumentationContext,
// Always Encrypted key-provider context, shared via `Arc`.
#[cfg(feature = "always-encrypted")]
pub(crate) encryption_context: Option<std::sync::Arc<crate::encryption::EncryptionContext>>,
}
/// The concrete transport under a `Client`, one variant per stream type.
///
/// Every send/read path matches on this enum and performs the same operation on each
/// variant.
// NOTE(review): `dead_code` is allowed presumably because not every variant is
// constructed under every feature/connect-path combination — confirm and narrow.
#[allow(dead_code)] enum ConnectionHandle {
// TLS over a plain TCP stream.
#[cfg(feature = "tls")]
Tls(Connection<TlsStream<TcpStream>>),
// TLS whose handshake runs through `TlsPreloginWrapper` (presumably the TDS 7.x mode
// where TLS is wrapped in PRELOGIN packets — TODO confirm).
#[cfg(feature = "tls")]
TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
// Unencrypted TCP.
Plain(Connection<TcpStream>),
}
impl<S: ConnectionState> Client<S> {
    /// Applies a transaction-related ENVCHANGE token to `transaction_descriptor`.
    ///
    /// This lets the client follow transactions that the server reports as begun or
    /// ended by SQL the caller ran directly (e.g. `BEGIN TRANSACTION` inside a batch):
    ///
    /// * `BeginTransaction` carrying a binary value of at least 8 bytes: the first 8
    ///   bytes, read little-endian, become the new descriptor. Non-binary or shorter
    ///   values are ignored.
    /// * `CommitTransaction` / `RollbackTransaction`: the descriptor is reset to 0
    ///   ("no transaction").
    /// * Every other change type is ignored.
    fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
        use tds_protocol::token::EnvChangeValue;
        match env.env_type {
            EnvChangeType::BeginTransaction => {
                if let EnvChangeValue::Binary(ref data) = env.new_value {
                    if data.len() >= 8 {
                        let descriptor = u64::from_le_bytes([
                            data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
                        ]);
                        tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
                        *transaction_descriptor = descriptor;
                    }
                }
            }
            EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
                tracing::debug!(
                    env_type = ?env.env_type,
                    "transaction ended via raw SQL"
                );
                *transaction_descriptor = 0;
            }
            _ => {}
        }
    }

    /// Sends `sql` as one SQL batch message tagged with the current transaction
    /// descriptor.
    ///
    /// A pending reset request (`needs_reset`) is consumed: the message carries the
    /// RESETCONNECTION flag and the request is cleared. The client is marked in flight
    /// before the message is written (and stays so if the write fails); reading the
    /// response is the caller's job.
    ///
    /// # Errors
    ///
    /// `Error::ConnectionClosed` if there is no connection, in which case no client
    /// state is modified; otherwise any error from writing the message.
    async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
        let payload =
            tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
        let max_packet = self.config.packet_size as usize;
        // Resolve the connection before touching any flags so a closed client neither
        // loses a pending reset request nor reports a request in flight.
        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
        let reset = std::mem::take(&mut self.needs_reset);
        if reset {
            tracing::debug!("sending SQL batch with RESETCONNECTION flag");
        }
        self.in_flight = true;
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
                    .await?;
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
                    .await?;
            }
            ConnectionHandle::Plain(conn) => {
                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
                    .await?;
            }
        }
        Ok(())
    }

    /// Sends `rpc` as one RPC message tagged with the current transaction descriptor.
    ///
    /// Reset and in-flight handling mirror `send_sql_batch`.
    ///
    /// # Errors
    ///
    /// `Error::ConnectionClosed` if there is no connection, in which case no client
    /// state is modified; otherwise any error from writing the message.
    pub(crate) async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
        let payload = rpc.encode_with_transaction(self.transaction_descriptor);
        let max_packet = self.config.packet_size as usize;
        // See `send_sql_batch`: check the connection before mutating any flag.
        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
        let reset = std::mem::take(&mut self.needs_reset);
        if reset {
            tracing::debug!("sending RPC with RESETCONNECTION flag");
        }
        self.in_flight = true;
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
                    .await?;
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
                    .await?;
            }
            ConnectionHandle::Plain(conn) => {
                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
                    .await?;
            }
        }
        Ok(())
    }

    /// Starts building a call to the stored procedure `proc_name`.
    ///
    /// # Errors
    ///
    /// Returns the validation error if `proc_name` is not a valid (optionally
    /// schema-qualified) identifier.
    pub fn procedure(
        &mut self,
        proc_name: &str,
    ) -> Result<crate::procedure::ProcedureBuilder<'_, S>> {
        crate::validation::validate_qualified_identifier(proc_name)?;
        Ok(crate::procedure::ProcedureBuilder::new(self, proc_name))
    }

    /// Calls the stored procedure `proc_name` with positional `params` and reads its
    /// complete result.
    ///
    /// # Errors
    ///
    /// Identifier validation errors, parameter conversion errors, connection/I/O errors,
    /// and errors returned while reading the procedure result.
    pub async fn call_procedure(
        &mut self,
        proc_name: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<crate::stream::ProcedureResult> {
        crate::validation::validate_qualified_identifier(proc_name)?;
        tracing::debug!(
            proc_name = proc_name,
            params_count = params.len(),
            "executing stored procedure"
        );
        let rpc_params =
            Self::convert_params_positional(params, self.send_unicode(), self.server_collation())?;
        let rpc = rpc_params
            .into_iter()
            .fold(RpcRequest::named(proc_name), |rpc, param| rpc.param(param));
        self.send_rpc(&rpc).await?;
        self.read_procedure_result().await
    }

    /// Starts a bulk insert after discovering the target table's server-side schema.
    ///
    /// Steps:
    /// 1. Runs `SELECT TOP 0 * FROM <table>` and scans the first response message up to
    ///    its first DONE token, keeping the parsed COLMETADATA and its raw byte range.
    /// 2. Rejects the table if any column is TEXT, NTEXT or IMAGE.
    /// 3. Sends the builder's INSERT BULK statement and reads its result.
    /// 4. Returns a writer seeded with the raw COLMETADATA bytes and parsed columns.
    ///
    /// NOTE(review): the table name is interpolated verbatim into SQL here; this assumes
    /// `BulkInsertBuilder` validated/quoted it — confirm.
    ///
    /// # Errors
    ///
    /// `Error::ConnectionClosed`, token parse errors, an `UnsupportedType` error for
    /// TEXT/NTEXT/IMAGE columns, and errors from the INSERT BULK statement.
    pub async fn bulk_insert(
        &mut self,
        builder: &crate::bulk::BulkInsertBuilder,
    ) -> Result<crate::bulk::BulkWriter<'_, S>> {
        use tds_protocol::token::{ColMetaData, Token};
        tracing::debug!(
            table = builder.table_name(),
            columns = builder.columns().len(),
            "starting bulk insert"
        );
        let meta_query = format!("SELECT TOP 0 * FROM {}", builder.table_name());
        self.send_sql_batch(&meta_query).await?;
        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
        let message = match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => conn.read_message().await?,
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => conn.read_message().await?,
            ConnectionHandle::Plain(conn) => conn.read_message().await?,
        }
        .ok_or(Error::ConnectionClosed)?;
        self.in_flight = false;

        // Keep an untouched copy of the payload so the COLMETADATA token's exact bytes
        // can be replayed to the server during the bulk load.
        let raw_payload = message.payload.clone();
        let mut parser = self.create_parser(message.payload);
        let mut server_metadata: Option<ColMetaData> = None;
        let mut meta_start: usize = 0;
        let mut meta_end: usize = 0;
        loop {
            // Byte offsets are derived from how much of the payload the parser has not
            // consumed yet, bracketing the token parsed in this iteration.
            let pos_before = raw_payload.len() - parser.remaining();
            let token = parser.next_token_with_metadata(server_metadata.as_ref())?;
            let pos_after = raw_payload.len() - parser.remaining();
            let Some(token) = token else { break };
            match token {
                Token::ColMetaData(meta) => {
                    meta_start = pos_before;
                    meta_end = pos_after;
                    server_metadata = Some(meta);
                }
                Token::Done(_) => break,
                _ => {}
            }
        }

        if let Some(meta) = &server_metadata {
            use tds_protocol::types::TypeId;
            for col in &meta.columns {
                let (sql_type, replacement) = match col.type_id {
                    TypeId::Text => ("TEXT", "VARCHAR(MAX)"),
                    TypeId::NText => ("NTEXT", "NVARCHAR(MAX)"),
                    TypeId::Image => ("IMAGE", "VARBINARY(MAX)"),
                    _ => continue,
                };
                return Err(Error::from(mssql_types::TypeError::UnsupportedType {
                    sql_type: sql_type.to_string(),
                    reason: format!(
                        "column `{}` in table `{}` is {} — TEXT/NTEXT/IMAGE \
                         are not supported. Alter the column to {} instead \
                         (Microsoft deprecated TEXT/NTEXT/IMAGE in SQL \
                         Server 2005).",
                        col.name,
                        builder.table_name(),
                        sql_type,
                        replacement,
                    ),
                }));
            }
        }

        let stmt = builder.build_insert_bulk_statement()?;
        self.send_sql_batch(&stmt).await?;
        self.read_execute_result().await?;

        // An empty range means no COLMETADATA token was seen.
        let raw_meta = (meta_end > meta_start).then(|| raw_payload.slice(meta_start..meta_end));
        let server_cols = server_metadata.as_ref().map(|m| m.columns.as_slice());
        let bulk = crate::bulk::BulkInsert::new_with_server_metadata(
            builder.columns().to_vec(),
            builder.options().batch_size,
            raw_meta,
            server_cols,
        );
        Ok(crate::bulk::BulkWriter::new(self, bulk))
    }

    /// Starts a bulk insert using only the builder's column definitions.
    ///
    /// No metadata query is issued, so the TEXT/NTEXT/IMAGE check done by `bulk_insert`
    /// is skipped and the writer has no server-provided COLMETADATA.
    ///
    /// # Errors
    ///
    /// Statement-building errors, connection/I/O errors, and errors from the INSERT BULK
    /// statement.
    pub async fn bulk_insert_without_schema_discovery(
        &mut self,
        builder: &crate::bulk::BulkInsertBuilder,
    ) -> Result<crate::bulk::BulkWriter<'_, S>> {
        tracing::debug!(
            table = builder.table_name(),
            columns = builder.columns().len(),
            "starting bulk insert (no schema discovery)"
        );
        let stmt = builder.build_insert_bulk_statement()?;
        self.send_sql_batch(&stmt).await?;
        self.read_execute_result().await?;
        let bulk =
            crate::bulk::BulkInsert::new(builder.columns().to_vec(), builder.options().batch_size);
        Ok(crate::bulk::BulkWriter::new(self, bulk))
    }

    /// Sends an already-encoded bulk-load `payload` and returns the count reported by
    /// the server's execute result.
    ///
    /// Unlike `send_sql_batch`/`send_rpc`, this never sets RESETCONNECTION and does not
    /// touch `needs_reset`.
    ///
    /// # Errors
    ///
    /// `Error::ConnectionClosed` (with no state change), write errors, and errors
    /// returned while reading the result.
    pub(crate) async fn send_and_read_bulk_load(&mut self, payload: bytes::Bytes) -> Result<u64> {
        let max_packet = self.config.packet_size as usize;
        // Only mark the client in flight once we know there is something to send on.
        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
        self.in_flight = true;
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                conn.send_message(PacketType::BulkLoad, payload, max_packet)
                    .await?;
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                conn.send_message(PacketType::BulkLoad, payload, max_packet)
                    .await?;
            }
            ConnectionHandle::Plain(conn) => {
                conn.send_message(PacketType::BulkLoad, payload, max_packet)
                    .await?;
            }
        }
        self.read_execute_result().await
    }

    /// Runs `sql` with named parameters and returns a row stream.
    ///
    /// With no `params` the text is sent as a plain SQL batch; otherwise it is sent via
    /// `RpcRequest::execute_sql` with the converted named parameters.
    ///
    /// # Errors
    ///
    /// Parameter conversion errors, connection/I/O errors, and errors returned while
    /// reading the response.
    pub async fn query_named<'a>(
        &'a mut self,
        sql: &str,
        params: &[crate::to_params::NamedParam],
    ) -> Result<QueryStream<'a>> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing query with named parameters"
        );
        if params.is_empty() {
            self.send_sql_batch(sql).await?;
        } else {
            let rpc_params =
                Self::convert_named_params(params, self.send_unicode(), self.server_collation())?;
            let rpc = RpcRequest::execute_sql(sql, rpc_params);
            self.send_rpc(&rpc).await?;
        }
        let resp = self.read_query_response().await?;
        #[cfg(feature = "always-encrypted")]
        {
            Ok(QueryStream::from_raw(
                resp.columns,
                resp.pending_rows,
                resp.meta,
                resp.decryptor,
            ))
        }
        #[cfg(not(feature = "always-encrypted"))]
        {
            Ok(QueryStream::from_raw(
                resp.columns,
                resp.pending_rows,
                resp.meta,
            ))
        }
    }

    /// Executes `sql` with named parameters and returns the count produced by the
    /// execute-result reader.
    ///
    /// Dispatch (batch vs. `execute_sql` RPC) follows `query_named`.
    ///
    /// # Errors
    ///
    /// Parameter conversion errors, connection/I/O errors, and errors returned while
    /// reading the result.
    pub async fn execute_named(
        &mut self,
        sql: &str,
        params: &[crate::to_params::NamedParam],
    ) -> Result<u64> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing statement with named parameters"
        );
        if params.is_empty() {
            self.send_sql_batch(sql).await?;
        } else {
            let rpc_params =
                Self::convert_named_params(params, self.send_unicode(), self.server_collation())?;
            let rpc = RpcRequest::execute_sql(sql, rpc_params);
            self.send_rpc(&rpc).await?;
        }
        self.read_execute_result().await
    }

    /// Whether string parameters are sent as Unicode (the config's
    /// `send_string_parameters_as_unicode` setting).
    pub(crate) fn send_unicode(&self) -> bool {
        self.config.send_string_parameters_as_unicode
    }

    /// The server collation, if known; used when converting non-unicode string
    /// parameters.
    pub(crate) fn server_collation(&self) -> Option<&tds_protocol::token::Collation> {
        self.server_collation.as_ref()
    }
}
impl Client<Ready> {
pub fn mark_needs_reset(&mut self) {
self.needs_reset = true;
}
#[must_use]
pub fn needs_reset(&self) -> bool {
self.needs_reset
}
pub async fn query<'a>(
&'a mut self,
sql: &str,
params: &[&(dyn crate::ToSql + Sync)],
) -> Result<QueryStream<'a>> {
tracing::debug!(sql = sql, params_count = params.len(), "executing query");
#[cfg(feature = "otel")]
let instrumentation = self.instrumentation.clone();
#[cfg(feature = "otel")]
let mut span = instrumentation.query_span(sql);
let result = async {
if params.is_empty() {
self.send_sql_batch(sql).await?;
} else {
let rpc_params =
Self::convert_params(params, self.send_unicode(), self.server_collation())?;
let rpc = RpcRequest::execute_sql(sql, rpc_params);
self.send_rpc(&rpc).await?;
}
self.read_query_response().await
}
.await;
#[cfg(feature = "otel")]
match &result {
Ok(_) => InstrumentationContext::record_success(&mut span, None),
Err(e) => InstrumentationContext::record_error(&mut span, e),
}
#[cfg(feature = "otel")]
drop(span);
let resp = result?;
#[cfg(feature = "always-encrypted")]
{
Ok(QueryStream::from_raw(
resp.columns,
resp.pending_rows,
resp.meta,
resp.decryptor,
))
}
#[cfg(not(feature = "always-encrypted"))]
{
Ok(QueryStream::from_raw(
resp.columns,
resp.pending_rows,
resp.meta,
))
}
}
pub async fn query_with_timeout<'a>(
&'a mut self,
sql: &str,
params: &[&(dyn crate::ToSql + Sync)],
timeout_duration: std::time::Duration,
) -> Result<QueryStream<'a>> {
timeout(timeout_duration, self.query(sql, params))
.await
.map_err(|_| Error::CommandTimeout)?
}
pub async fn query_multiple<'a>(
&'a mut self,
sql: &str,
params: &[&(dyn crate::ToSql + Sync)],
) -> Result<MultiResultStream<'a>> {
tracing::debug!(
sql = sql,
params_count = params.len(),
"executing multi-result query"
);
if params.is_empty() {
self.send_sql_batch(sql).await?;
} else {
let rpc_params =
Self::convert_params(params, self.send_unicode(), self.server_collation())?;
let rpc = RpcRequest::execute_sql(sql, rpc_params);
self.send_rpc(&rpc).await?;
}
let result_sets = self.read_multi_result_response().await?;
Ok(MultiResultStream::new(result_sets))
}
pub async fn execute(
&mut self,
sql: &str,
params: &[&(dyn crate::ToSql + Sync)],
) -> Result<u64> {
tracing::debug!(
sql = sql,
params_count = params.len(),
"executing statement"
);
#[cfg(feature = "otel")]
let instrumentation = self.instrumentation.clone();
#[cfg(feature = "otel")]
let mut span = instrumentation.query_span(sql);
let result = async {
if params.is_empty() {
self.send_sql_batch(sql).await?;
} else {
let rpc_params =
Self::convert_params(params, self.send_unicode(), self.server_collation())?;
let rpc = RpcRequest::execute_sql(sql, rpc_params);
self.send_rpc(&rpc).await?;
}
self.read_execute_result().await
}
.await;
#[cfg(feature = "otel")]
match &result {
Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
Err(e) => InstrumentationContext::record_error(&mut span, e),
}
#[cfg(feature = "otel")]
drop(span);
result
}
pub async fn execute_with_timeout(
&mut self,
sql: &str,
params: &[&(dyn crate::ToSql + Sync)],
timeout_duration: std::time::Duration,
) -> Result<u64> {
timeout(timeout_duration, self.execute(sql, params))
.await
.map_err(|_| Error::CommandTimeout)?
}
pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
tracing::debug!("beginning transaction");
#[cfg(feature = "otel")]
let instrumentation = self.instrumentation.clone();
#[cfg(feature = "otel")]
let mut span = instrumentation.transaction_span("BEGIN");
let result = async {
self.send_sql_batch("BEGIN TRANSACTION").await?;
self.read_transaction_begin_result().await
}
.await;
#[cfg(feature = "otel")]
match &result {
Ok(_) => InstrumentationContext::record_success(&mut span, None),
Err(e) => InstrumentationContext::record_error(&mut span, e),
}
#[cfg(feature = "otel")]
drop(span);
let transaction_descriptor = result?;
Ok(Client {
config: self.config,
_state: PhantomData,
connection: self.connection,
server_version: self.server_version,
current_database: self.current_database,
server_collation: self.server_collation,
statement_cache: self.statement_cache,
transaction_descriptor, needs_reset: self.needs_reset,
in_flight: self.in_flight,
#[cfg(feature = "otel")]
instrumentation: self.instrumentation,
#[cfg(feature = "always-encrypted")]
encryption_context: self.encryption_context,
})
}
pub async fn begin_transaction_with_isolation(
mut self,
isolation_level: crate::transaction::IsolationLevel,
) -> Result<Client<InTransaction>> {
tracing::debug!(
isolation_level = %isolation_level.name(),
"beginning transaction with isolation level"
);
#[cfg(feature = "otel")]
let instrumentation = self.instrumentation.clone();
#[cfg(feature = "otel")]
let mut span = instrumentation.transaction_span("BEGIN");
let result = async {
self.send_sql_batch(isolation_level.as_sql()).await?;
self.read_execute_result().await?;
self.send_sql_batch("BEGIN TRANSACTION").await?;
self.read_transaction_begin_result().await
}
.await;
#[cfg(feature = "otel")]
match &result {
Ok(_) => InstrumentationContext::record_success(&mut span, None),
Err(e) => InstrumentationContext::record_error(&mut span, e),
}
#[cfg(feature = "otel")]
drop(span);
let transaction_descriptor = result?;
Ok(Client {
config: self.config,
_state: PhantomData,
connection: self.connection,
server_version: self.server_version,
current_database: self.current_database,
server_collation: self.server_collation,
statement_cache: self.statement_cache,
transaction_descriptor,
needs_reset: self.needs_reset,
in_flight: self.in_flight,
#[cfg(feature = "otel")]
instrumentation: self.instrumentation,
#[cfg(feature = "always-encrypted")]
encryption_context: self.encryption_context,
})
}
pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
tracing::debug!(sql = sql, "executing simple query");
self.send_sql_batch(sql).await?;
let _ = self.read_execute_result().await?;
Ok(())
}
pub async fn close(self) -> Result<()> {
tracing::debug!("closing connection");
Ok(())
}
#[must_use]
pub fn database(&self) -> Option<&str> {
self.config.database.as_deref()
}
#[must_use]
pub fn host(&self) -> &str {
&self.config.host
}
#[must_use]
pub fn port(&self) -> u16 {
self.config.port
}
#[must_use]
pub fn is_in_transaction(&self) -> bool {
self.transaction_descriptor != 0
}
#[must_use]
pub fn is_in_flight(&self) -> bool {
self.in_flight
}
#[cfg(feature = "always-encrypted")]
#[must_use]
pub fn has_encryption_provider(&self, name: &str) -> bool {
self.encryption_context
.as_ref()
.is_some_and(|ctx| ctx.has_provider(name))
}
#[must_use]
pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
let connection = self
.connection
.as_ref()
.expect("connection should be present");
match connection {
#[cfg(feature = "tls")]
ConnectionHandle::Tls(conn) => {
crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
}
#[cfg(feature = "tls")]
ConnectionHandle::TlsPrelogin(conn) => {
crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
}
ConnectionHandle::Plain(conn) => {
crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
}
}
}
}
impl Client<InTransaction> {
    /// Sends `sql` either as a plain SQL batch (no `params`) or as an
    /// `RpcRequest::execute_sql` call with positionally converted parameters; the
    /// current transaction descriptor is attached by the underlying send.
    ///
    /// Shared by `query` and `execute`; reading the response is left to the caller.
    async fn send_positional_statement(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<()> {
        if params.is_empty() {
            self.send_sql_batch(sql).await
        } else {
            let rpc_params =
                Self::convert_params(params, self.send_unicode(), self.server_collation())?;
            let rpc = RpcRequest::execute_sql(sql, rpc_params);
            self.send_rpc(&rpc).await
        }
    }

    /// Moves all connection and session state back into a `Ready` client with the
    /// transaction descriptor cleared to 0.
    fn return_to_ready(self) -> Client<Ready> {
        Client {
            config: self.config,
            _state: PhantomData,
            connection: self.connection,
            server_version: self.server_version,
            current_database: self.current_database,
            server_collation: self.server_collation,
            statement_cache: self.statement_cache,
            transaction_descriptor: 0,
            needs_reset: self.needs_reset,
            in_flight: self.in_flight,
            #[cfg(feature = "otel")]
            instrumentation: self.instrumentation,
            #[cfg(feature = "always-encrypted")]
            encryption_context: self.encryption_context,
        }
    }

    /// Runs `sql` inside the current transaction and returns a stream over the rows of
    /// the response.
    ///
    /// # Errors
    ///
    /// Parameter conversion errors, connection/I/O errors, and errors returned while
    /// reading the response.
    pub async fn query<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<QueryStream<'a>> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing query in transaction"
        );
        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);
        let result = async {
            self.send_positional_statement(sql, params).await?;
            self.read_query_response().await
        }
        .await;
        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);
        let resp = result?;
        #[cfg(feature = "always-encrypted")]
        {
            Ok(QueryStream::from_raw(
                resp.columns,
                resp.pending_rows,
                resp.meta,
                resp.decryptor,
            ))
        }
        #[cfg(not(feature = "always-encrypted"))]
        {
            Ok(QueryStream::from_raw(
                resp.columns,
                resp.pending_rows,
                resp.meta,
            ))
        }
    }

    /// Executes `sql` inside the current transaction and returns the count produced by
    /// the execute-result reader.
    ///
    /// # Errors
    ///
    /// Parameter conversion errors, connection/I/O errors, and errors returned while
    /// reading the result.
    pub async fn execute(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
    ) -> Result<u64> {
        tracing::debug!(
            sql = sql,
            params_count = params.len(),
            "executing statement in transaction"
        );
        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.query_span(sql);
        let result = async {
            self.send_positional_statement(sql, params).await?;
            self.read_execute_result().await
        }
        .await;
        #[cfg(feature = "otel")]
        match &result {
            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);
        result
    }

    /// `query` bounded by `timeout_duration`.
    ///
    /// On timeout the in-progress future is dropped; no cancel is sent here, so the
    /// connection may be left with a request in flight.
    ///
    /// # Errors
    ///
    /// `Error::CommandTimeout` if the deadline elapses, otherwise whatever `query`
    /// returns.
    pub async fn query_with_timeout<'a>(
        &'a mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
        timeout_duration: std::time::Duration,
    ) -> Result<QueryStream<'a>> {
        timeout(timeout_duration, self.query(sql, params))
            .await
            .map_err(|_| Error::CommandTimeout)?
    }

    /// `execute` bounded by `timeout_duration`; same timeout caveats as
    /// `query_with_timeout`.
    ///
    /// # Errors
    ///
    /// `Error::CommandTimeout` if the deadline elapses, otherwise whatever `execute`
    /// returns.
    pub async fn execute_with_timeout(
        &mut self,
        sql: &str,
        params: &[&(dyn crate::ToSql + Sync)],
        timeout_duration: std::time::Duration,
    ) -> Result<u64> {
        timeout(timeout_duration, self.execute(sql, params))
            .await
            .map_err(|_| Error::CommandTimeout)?
    }

    /// Opens a FILESTREAM BLOB at `path` using this transaction's FILESTREAM context.
    ///
    /// The context is read with `SELECT GET_FILESTREAM_TRANSACTION_CONTEXT()`; if
    /// several rows come back, the first column of the last row is used.
    ///
    /// # Errors
    ///
    /// Query/row errors, `Error::FileStream` if the context query returns no rows, and
    /// errors from opening the file.
    #[cfg(all(windows, feature = "filestream"))]
    pub async fn open_filestream(
        &mut self,
        path: &str,
        access: crate::filestream::FileStreamAccess,
    ) -> Result<crate::filestream::FileStream> {
        tracing::debug!(path = path, ?access, "opening FILESTREAM BLOB");
        let txn_context: Vec<u8> = {
            let rows = self
                .query("SELECT GET_FILESTREAM_TRANSACTION_CONTEXT()", &[])
                .await?;
            let mut ctx = None;
            for result in rows {
                let row = result?;
                ctx = Some(row.get::<Vec<u8>>(0)?);
            }
            ctx.ok_or_else(|| {
                Error::FileStream("GET_FILESTREAM_TRANSACTION_CONTEXT() returned no rows".into())
            })?
        };
        crate::filestream::FileStream::open(path, access, &txn_context)
    }

    /// Sends `COMMIT TRANSACTION` and returns the client to the `Ready` state.
    ///
    /// # Errors
    ///
    /// Any error from sending or reading the COMMIT result. The client is consumed, so
    /// on error it (and its connection) is dropped.
    pub async fn commit(mut self) -> Result<Client<Ready>> {
        tracing::debug!("committing transaction");
        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.transaction_span("COMMIT");
        let result = async {
            self.send_sql_batch("COMMIT TRANSACTION").await?;
            self.read_execute_result().await
        }
        .await;
        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);
        result?;
        Ok(self.return_to_ready())
    }

    /// Sends `ROLLBACK TRANSACTION` and returns the client to the `Ready` state.
    ///
    /// # Errors
    ///
    /// Any error from sending or reading the ROLLBACK result. The client is consumed,
    /// so on error it (and its connection) is dropped.
    pub async fn rollback(mut self) -> Result<Client<Ready>> {
        tracing::debug!("rolling back transaction");
        #[cfg(feature = "otel")]
        let instrumentation = self.instrumentation.clone();
        #[cfg(feature = "otel")]
        let mut span = instrumentation.transaction_span("ROLLBACK");
        let result = async {
            self.send_sql_batch("ROLLBACK TRANSACTION").await?;
            self.read_execute_result().await
        }
        .await;
        #[cfg(feature = "otel")]
        match &result {
            Ok(_) => InstrumentationContext::record_success(&mut span, None),
            Err(e) => InstrumentationContext::record_error(&mut span, e),
        }
        #[cfg(feature = "otel")]
        drop(span);
        result?;
        Ok(self.return_to_ready())
    }

    /// Creates a savepoint named `name` via `SAVE TRANSACTION`.
    ///
    /// # Errors
    ///
    /// Identifier validation errors (checked before any SQL is built), connection/I/O
    /// errors, and errors returned while reading the result.
    pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
        crate::validation::validate_identifier(name)?;
        tracing::debug!(name = name, "creating savepoint");
        let sql = format!("SAVE TRANSACTION {name}");
        self.send_sql_batch(&sql).await?;
        self.read_execute_result().await?;
        Ok(SavePoint::new(name.to_string()))
    }

    /// Rolls the transaction back to `savepoint` via `ROLLBACK TRANSACTION <name>`.
    ///
    /// # Errors
    ///
    /// Identifier validation errors, connection/I/O errors, and errors returned while
    /// reading the result.
    pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
        // The name is spliced into SQL text; re-validate it rather than trusting that
        // every `SavePoint` came from `save_point`.
        crate::validation::validate_identifier(savepoint.name())?;
        tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
        let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
        self.send_sql_batch(&sql).await?;
        self.read_execute_result().await?;
        Ok(())
    }

    /// Releases `savepoint`. No statement is sent to the server; the value is simply
    /// dropped (presumably because SQL Server has no RELEASE SAVEPOINT statement).
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
        tracing::debug!(name = savepoint.name(), "releasing savepoint");
        drop(savepoint);
        Ok(())
    }

    /// Returns a handle that can cancel the currently running request on this
    /// connection.
    ///
    /// # Panics
    ///
    /// Panics if the client has no connection.
    #[must_use]
    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
        let connection = self
            .connection
            .as_ref()
            .expect("connection should be present");
        match connection {
            #[cfg(feature = "tls")]
            ConnectionHandle::Tls(conn) => {
                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
            }
            #[cfg(feature = "tls")]
            ConnectionHandle::TlsPrelogin(conn) => {
                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
            }
            ConnectionHandle::Plain(conn) => {
                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
            }
        }
    }
}
impl<S: ConnectionState> std::fmt::Debug for Client<S> {
    /// Shows only the connection target (host, port and configured database); every
    /// other piece of client or config state is omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut out = f.debug_struct("Client");
        out.field("host", &cfg.host);
        out.field("port", &cfg.port);
        out.field("database", &cfg.database);
        out.finish()
    }
}