mssql_client/client/mod.rs
1//! SQL Server client implementation.
2
3// Allow unwrap/expect for chrono date construction with known-valid constant dates
4// and for regex patterns that are compile-time constants
5#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7mod connect;
8mod params;
9mod response;
10
11use std::marker::PhantomData;
12
13use mssql_codec::connection::Connection;
14#[cfg(feature = "tls")]
15use mssql_tls::TlsStream;
16use tds_protocol::packet::PacketType;
17use tds_protocol::rpc::RpcRequest;
18use tds_protocol::token::{EnvChange, EnvChangeType};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21
22use crate::config::Config;
23use crate::error::{Error, Result};
24#[cfg(feature = "otel")]
25use crate::instrumentation::InstrumentationContext;
26use crate::state::{ConnectionState, InTransaction, Ready};
27use crate::statement_cache::StatementCache;
28use crate::stream::{MultiResultStream, QueryStream};
29use crate::transaction::SavePoint;
30
/// SQL Server client with type-state connection management.
///
/// The generic parameter `S` represents the current connection state,
/// ensuring at compile time that certain operations are only available
/// in appropriate states (for example, `begin_transaction` is defined on
/// `Client<Ready>` and yields a `Client<InTransaction>`).
///
/// `S` exists only at the type level (`PhantomData<S>`): a state transition
/// consumes the client and rebuilds it field-by-field under the new state
/// parameter.
pub struct Client<S: ConnectionState> {
    /// Connection configuration (read for e.g. `packet_size` and
    /// `send_string_parameters_as_unicode`).
    config: Config,
    /// Zero-sized marker tying this client to its type-state `S`.
    _state: PhantomData<S>,
    /// The underlying connection (present only when connected).
    /// Request helpers return `Error::ConnectionClosed` when this is `None`.
    connection: Option<ConnectionHandle>,
    /// Server version from LoginAck (raw u32 TDS version)
    server_version: Option<u32>,
    /// Current database from EnvChange
    current_database: Option<String>,
    /// Server's default collation from SqlCollation EnvChange during login.
    /// Used when `SendStringParametersAsUnicode=false` to encode VARCHAR
    /// parameters with the correct character encoding and collation bytes.
    server_collation: Option<tds_protocol::token::Collation>,
    /// Prepared statement cache for query optimization
    statement_cache: StatementCache,
    /// Transaction descriptor from BeginTransaction EnvChange.
    /// Per MS-TDS spec, this value must be included in ALL_HEADERS for subsequent
    /// requests within an explicit transaction. 0 indicates auto-commit mode.
    /// `process_transaction_env_change` is the helper that updates a descriptor
    /// from transaction EnvChange tokens issued by raw SQL.
    transaction_descriptor: u64,
    /// Whether a request has been sent and the response has not yet been fully read.
    /// Set to `true` by the send helpers just before transmitting; cleared once the
    /// response has been consumed (e.g. in `bulk_insert`, presumably also in the
    /// `read_*` helpers defined elsewhere).
    /// Used by the connection pool to detect dirty connections after cancel/timeout.
    in_flight: bool,
    /// Whether this connection needs a reset on next use.
    /// Set by connection pool on checkin, cleared after first query/execute.
    /// When true, the RESETCONNECTION flag is set on the first TDS packet.
    needs_reset: bool,
    /// OpenTelemetry instrumentation context (when otel feature is enabled)
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
    /// Always Encrypted context for column decryption (when always-encrypted feature is enabled)
    #[cfg(feature = "always-encrypted")]
    pub(crate) encryption_context: Option<std::sync::Arc<crate::encryption::EncryptionContext>>,
}
69
/// Internal connection handle wrapping the actual connection.
///
/// This is an enum to support different connection types:
/// - TLS (TDS 8.0 strict mode) - requires `tls` feature
/// - TLS with PreLogin wrapping (TDS 7.x style) - requires `tls` feature
/// - Plain TCP (for internal networks or when `tls` feature is disabled)
///
/// The request helpers on `Client` match on the variant and forward the same
/// call (`send_message`, `send_message_with_reset`, `read_message`) to the
/// wrapped `Connection`.
// NOTE(review): the previous justification for this allow ("Connection will be
// used once query execution is implemented") is stale — every variant is matched
// by the request helpers in this module. Verify whether the allow is still needed
// (e.g. a variant never constructed under some feature combination) and remove
// it if not.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TLS connection (TDS 8.0 strict mode - TLS before any TDS traffic)
    #[cfg(feature = "tls")]
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS connection with PreLogin wrapping (TDS 7.x style)
    #[cfg(feature = "tls")]
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP connection (for internal networks or when `tls` feature is disabled)
    Plain(Connection<TcpStream>),
}
87
88// Private helper methods available to all connection states
89impl<S: ConnectionState> Client<S> {
90 /// Process transaction-related EnvChange tokens.
91 ///
92 /// This handles BeginTransaction, CommitTransaction, and RollbackTransaction
93 /// EnvChange tokens, updating the transaction descriptor accordingly.
94 ///
95 /// This enables executing BEGIN TRANSACTION, COMMIT, and ROLLBACK via raw SQL
96 /// while still having the transaction descriptor tracked correctly.
97 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
98 use tds_protocol::token::EnvChangeValue;
99
100 match env.env_type {
101 EnvChangeType::BeginTransaction => {
102 if let EnvChangeValue::Binary(ref data) = env.new_value {
103 if data.len() >= 8 {
104 let descriptor = u64::from_le_bytes([
105 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
106 ]);
107 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
108 *transaction_descriptor = descriptor;
109 }
110 }
111 }
112 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
113 tracing::debug!(
114 env_type = ?env.env_type,
115 "transaction ended via raw SQL"
116 );
117 *transaction_descriptor = 0;
118 }
119 _ => {}
120 }
121 }
122
123 /// Send a SQL batch to the server.
124 ///
125 /// Uses the client's current transaction descriptor in ALL_HEADERS.
126 /// Per MS-TDS spec, when in an explicit transaction, the descriptor
127 /// returned by BeginTransaction must be included.
128 ///
129 /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
130 /// is included in the first packet to reset connection state.
131 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
132 let payload =
133 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
134 let max_packet = self.config.packet_size as usize;
135
136 // Check if we need to reset the connection on this request
137 let reset = self.needs_reset;
138 if reset {
139 self.needs_reset = false; // Clear flag before sending
140 tracing::debug!("sending SQL batch with RESETCONNECTION flag");
141 }
142
143 self.in_flight = true;
144 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
145
146 match connection {
147 #[cfg(feature = "tls")]
148 ConnectionHandle::Tls(conn) => {
149 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
150 .await?;
151 }
152 #[cfg(feature = "tls")]
153 ConnectionHandle::TlsPrelogin(conn) => {
154 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
155 .await?;
156 }
157 ConnectionHandle::Plain(conn) => {
158 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
159 .await?;
160 }
161 }
162
163 Ok(())
164 }
165
166 /// Send an RPC request to the server.
167 ///
168 /// Uses the client's current transaction descriptor in ALL_HEADERS.
169 ///
170 /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
171 /// is included in the first packet to reset connection state.
172 pub(crate) async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
173 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
174 let max_packet = self.config.packet_size as usize;
175
176 // Check if we need to reset the connection on this request
177 let reset = self.needs_reset;
178 if reset {
179 self.needs_reset = false; // Clear flag before sending
180 tracing::debug!("sending RPC with RESETCONNECTION flag");
181 }
182
183 self.in_flight = true;
184 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
185
186 match connection {
187 #[cfg(feature = "tls")]
188 ConnectionHandle::Tls(conn) => {
189 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
190 .await?;
191 }
192 #[cfg(feature = "tls")]
193 ConnectionHandle::TlsPrelogin(conn) => {
194 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
195 .await?;
196 }
197 ConnectionHandle::Plain(conn) => {
198 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
199 .await?;
200 }
201 }
202
203 Ok(())
204 }
205
206 /// Start building a stored procedure call with full control over parameters.
207 ///
208 /// Returns a [`crate::procedure::ProcedureBuilder`] that allows adding named input and output
209 /// parameters before executing the call.
210 ///
211 /// The procedure name is validated to prevent SQL injection. It may be
212 /// schema-qualified (e.g., `"dbo.MyProc"`).
213 ///
214 /// # Example
215 ///
216 /// ```rust,ignore
217 /// let result = client.procedure("dbo.CalculateSum")?
218 /// .input("@a", &10i32)
219 /// .input("@b", &20i32)
220 /// .output_int("@result")
221 /// .execute().await?;
222 ///
223 /// let sum = result.get_output("@result").unwrap();
224 /// ```
225 pub fn procedure(
226 &mut self,
227 proc_name: &str,
228 ) -> Result<crate::procedure::ProcedureBuilder<'_, S>> {
229 crate::validation::validate_qualified_identifier(proc_name)?;
230 Ok(crate::procedure::ProcedureBuilder::new(self, proc_name))
231 }
232
233 /// Execute a stored procedure with positional input parameters.
234 ///
235 /// This is a convenience method for the common case of calling a procedure
236 /// with input-only parameters. For output parameters or named parameters,
237 /// use [`procedure()`](Client::procedure) instead.
238 ///
239 /// # Example
240 ///
241 /// ```rust,ignore
242 /// let result = client.call_procedure("dbo.GetUser", &[&1i32]).await?;
243 /// assert_eq!(result.return_value, 0);
244 ///
245 /// if let Some(rs) = result.first_result_set() {
246 /// println!("columns: {:?}", rs.columns());
247 /// }
248 /// ```
249 pub async fn call_procedure(
250 &mut self,
251 proc_name: &str,
252 params: &[&(dyn crate::ToSql + Sync)],
253 ) -> Result<crate::stream::ProcedureResult> {
254 crate::validation::validate_qualified_identifier(proc_name)?;
255
256 tracing::debug!(
257 proc_name = proc_name,
258 params_count = params.len(),
259 "executing stored procedure"
260 );
261
262 let rpc_params =
263 Self::convert_params_positional(params, self.send_unicode(), self.server_collation())?;
264 let mut rpc = RpcRequest::named(proc_name);
265 for param in rpc_params {
266 rpc = rpc.param(param);
267 }
268
269 self.send_rpc(&rpc).await?;
270 self.read_procedure_result().await
271 }
272
273 /// Start a bulk insert operation for the specified table.
274 ///
275 /// Sends the `INSERT BULK` statement to the server and returns a
276 /// [`crate::bulk::BulkWriter`] for streaming rows. The writer holds
277 /// a mutable borrow on the client, preventing other operations while
278 /// the bulk insert is in progress.
279 ///
280 /// # Example
281 ///
282 /// ```rust,ignore
283 /// use mssql_client::{BulkInsertBuilder, BulkColumn};
284 ///
285 /// let builder = BulkInsertBuilder::new("dbo.Users")
286 /// .with_typed_columns(vec![
287 /// BulkColumn::new("id", "INT", 0)?,
288 /// BulkColumn::new("name", "NVARCHAR(100)", 1)?,
289 /// ]);
290 ///
291 /// let mut writer = client.bulk_insert(&builder).await?;
292 /// writer.send_row(&[&1i32, &"Alice"])?;
293 /// writer.send_row(&[&2i32, &"Bob"])?;
294 /// let result = writer.finish().await?;
295 /// println!("Inserted {} rows", result.rows_affected);
296 /// ```
297 pub async fn bulk_insert(
298 &mut self,
299 builder: &crate::bulk::BulkInsertBuilder,
300 ) -> Result<crate::bulk::BulkWriter<'_, S>> {
301 use tds_protocol::token::{ColMetaData, Token};
302
303 tracing::debug!(
304 table = builder.table_name(),
305 columns = builder.columns().len(),
306 "starting bulk insert"
307 );
308
309 // Step 1: Query the server for column metadata.
310 // This gives us the exact type encoding the server expects for BulkLoad,
311 // following the pattern established by Tiberius.
312 let meta_query = format!("SELECT TOP 0 * FROM {}", builder.table_name());
313 self.send_sql_batch(&meta_query).await?;
314
315 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
316 let message = match connection {
317 #[cfg(feature = "tls")]
318 ConnectionHandle::Tls(conn) => conn.read_message().await?,
319 #[cfg(feature = "tls")]
320 ConnectionHandle::TlsPrelogin(conn) => conn.read_message().await?,
321 ConnectionHandle::Plain(conn) => conn.read_message().await?,
322 }
323 .ok_or(Error::ConnectionClosed)?;
324 self.in_flight = false;
325
326 // Capture both the raw COLMETADATA bytes and parsed column info
327 let raw_payload = message.payload.clone();
328 let mut parser = self.create_parser(message.payload);
329 let mut server_metadata: Option<ColMetaData> = None;
330 let mut meta_start: usize = 0;
331 let mut meta_end: usize = 0;
332
333 loop {
334 let pos_before = raw_payload.len() - parser.remaining();
335 let token = parser.next_token_with_metadata(server_metadata.as_ref())?;
336 let pos_after = raw_payload.len() - parser.remaining();
337 let Some(token) = token else { break };
338
339 match token {
340 Token::ColMetaData(meta) => {
341 meta_start = pos_before;
342 meta_end = pos_after;
343 server_metadata = Some(meta);
344 }
345 Token::Done(_) => break,
346 _ => {}
347 }
348 }
349
350 // Reject deprecated TEXT/NTEXT/IMAGE columns reported by the server.
351 // These types require a legacy TEXTPTR wire format that this driver
352 // does not support — users should migrate the column to VARCHAR(MAX) /
353 // NVARCHAR(MAX) / VARBINARY(MAX).
354 if let Some(ref meta) = server_metadata {
355 use tds_protocol::types::TypeId;
356 for col in meta.columns.iter() {
357 let (rejected, replacement) = match col.type_id {
358 TypeId::Text => (Some("TEXT"), "VARCHAR(MAX)"),
359 TypeId::NText => (Some("NTEXT"), "NVARCHAR(MAX)"),
360 TypeId::Image => (Some("IMAGE"), "VARBINARY(MAX)"),
361 _ => (None, ""),
362 };
363 if let Some(sql_type) = rejected {
364 return Err(Error::from(mssql_types::TypeError::UnsupportedType {
365 sql_type: sql_type.to_string(),
366 reason: format!(
367 "column `{}` in table `{}` is {} — TEXT/NTEXT/IMAGE \
368 are not supported. Alter the column to {} instead \
369 (Microsoft deprecated TEXT/NTEXT/IMAGE in SQL \
370 Server 2005).",
371 col.name,
372 builder.table_name(),
373 sql_type,
374 replacement,
375 ),
376 }));
377 }
378 }
379 }
380
381 // Step 2: Send INSERT BULK statement to put server in bulk load mode
382 let stmt = builder.build_insert_bulk_statement()?;
383 self.send_sql_batch(&stmt).await?;
384 self.read_execute_result().await?;
385
386 // Step 3: Create bulk writer with server's metadata
387 let raw_meta = if meta_end > meta_start {
388 Some(raw_payload.slice(meta_start..meta_end))
389 } else {
390 None
391 };
392
393 let server_cols = server_metadata.as_ref().map(|m| m.columns.as_slice());
394 let bulk = crate::bulk::BulkInsert::new_with_server_metadata(
395 builder.columns().to_vec(),
396 builder.options().batch_size,
397 raw_meta,
398 server_cols,
399 );
400
401 Ok(crate::bulk::BulkWriter::new(self, bulk))
402 }
403
404 /// Start a bulk insert without querying the server for column metadata.
405 ///
406 /// Unlike [`bulk_insert()`](Self::bulk_insert), this method does not send
407 /// `SELECT TOP 0 * FROM table` to discover column types. Instead, the
408 /// column metadata is constructed from the `BulkColumn` types provided
409 /// on the builder. This saves a round-trip when the schema is known.
410 ///
411 /// # Caveats
412 ///
413 /// The caller must ensure `BulkColumn` entries match the target table's
414 /// column definitions exactly. Mismatched types, lengths, precision/scale,
415 /// or column ordering will cause the server to reject the BulkLoad packet.
416 ///
417 /// For most use cases, prefer [`bulk_insert()`](Self::bulk_insert) — the
418 /// extra round-trip is usually negligible and the server-supplied metadata
419 /// is guaranteed correct.
420 pub async fn bulk_insert_without_schema_discovery(
421 &mut self,
422 builder: &crate::bulk::BulkInsertBuilder,
423 ) -> Result<crate::bulk::BulkWriter<'_, S>> {
424 tracing::debug!(
425 table = builder.table_name(),
426 columns = builder.columns().len(),
427 "starting bulk insert (no schema discovery)"
428 );
429
430 // Send INSERT BULK statement to put server in bulk load mode
431 let stmt = builder.build_insert_bulk_statement()?;
432 self.send_sql_batch(&stmt).await?;
433 self.read_execute_result().await?;
434
435 // Create bulk writer with hand-crafted metadata
436 let bulk =
437 crate::bulk::BulkInsert::new(builder.columns().to_vec(), builder.options().batch_size);
438
439 Ok(crate::bulk::BulkWriter::new(self, bulk))
440 }
441
442 /// Send bulk load data as a BulkLoad (0x07) message and read the server response.
443 ///
444 /// Used internally by [`crate::bulk::BulkWriter::finish()`] to transmit accumulated
445 /// row data after the `INSERT BULK` statement has been acknowledged.
446 pub(crate) async fn send_and_read_bulk_load(&mut self, payload: bytes::Bytes) -> Result<u64> {
447 let max_packet = self.config.packet_size as usize;
448
449 self.in_flight = true;
450 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
451
452 match connection {
453 #[cfg(feature = "tls")]
454 ConnectionHandle::Tls(conn) => {
455 conn.send_message(PacketType::BulkLoad, payload, max_packet)
456 .await?;
457 }
458 #[cfg(feature = "tls")]
459 ConnectionHandle::TlsPrelogin(conn) => {
460 conn.send_message(PacketType::BulkLoad, payload, max_packet)
461 .await?;
462 }
463 ConnectionHandle::Plain(conn) => {
464 conn.send_message(PacketType::BulkLoad, payload, max_packet)
465 .await?;
466 }
467 }
468
469 // Read the server's Done response with row count
470 self.read_execute_result().await
471 }
472
473 /// Execute a query with named parameters and return a streaming result set.
474 ///
475 /// This method accepts [`NamedParam`](crate::to_params::NamedParam) values,
476 /// making it compatible with the [`ToParams`](crate::to_params::ToParams) trait
477 /// and the `#[derive(ToParams)]` macro.
478 ///
479 /// # Example
480 ///
481 /// ```rust,ignore
482 /// use mssql_client::{NamedParam, ToParams};
483 ///
484 /// // With derive macro:
485 /// #[derive(ToParams)]
486 /// struct UserQuery { name: String }
487 ///
488 /// let q = UserQuery { name: "Alice".into() };
489 /// let rows = client.query_named(
490 /// "SELECT * FROM users WHERE name = @name",
491 /// &q.to_params()?,
492 /// ).await?;
493 ///
494 /// // Or manually:
495 /// let params = vec![NamedParam::from_value("name", &"Alice")?];
496 /// let rows = client.query_named(
497 /// "SELECT * FROM users WHERE name = @name",
498 /// ¶ms,
499 /// ).await?;
500 /// ```
501 pub async fn query_named<'a>(
502 &'a mut self,
503 sql: &str,
504 params: &[crate::to_params::NamedParam],
505 ) -> Result<QueryStream<'a>> {
506 tracing::debug!(
507 sql = sql,
508 params_count = params.len(),
509 "executing query with named parameters"
510 );
511
512 if params.is_empty() {
513 self.send_sql_batch(sql).await?;
514 } else {
515 let rpc_params =
516 Self::convert_named_params(params, self.send_unicode(), self.server_collation())?;
517 let rpc = RpcRequest::execute_sql(sql, rpc_params);
518 self.send_rpc(&rpc).await?;
519 }
520
521 let resp = self.read_query_response().await?;
522 #[cfg(feature = "always-encrypted")]
523 {
524 Ok(QueryStream::from_raw(
525 resp.columns,
526 resp.pending_rows,
527 resp.meta,
528 resp.decryptor,
529 ))
530 }
531 #[cfg(not(feature = "always-encrypted"))]
532 {
533 Ok(QueryStream::from_raw(
534 resp.columns,
535 resp.pending_rows,
536 resp.meta,
537 ))
538 }
539 }
540
541 /// Execute a statement with named parameters.
542 ///
543 /// Returns the number of affected rows. This is the named-parameter
544 /// counterpart of [`execute()`](Client::execute), compatible with the
545 /// [`ToParams`](crate::to_params::ToParams) trait.
546 ///
547 /// # Example
548 ///
549 /// ```rust,ignore
550 /// use mssql_client::NamedParam;
551 ///
552 /// let params = vec![
553 /// NamedParam::from_value("name", &"Alice")?,
554 /// NamedParam::from_value("email", &"alice@example.com")?,
555 /// ];
556 /// let rows_affected = client.execute_named(
557 /// "INSERT INTO users (name, email) VALUES (@name, @email)",
558 /// ¶ms,
559 /// ).await?;
560 /// ```
561 pub async fn execute_named(
562 &mut self,
563 sql: &str,
564 params: &[crate::to_params::NamedParam],
565 ) -> Result<u64> {
566 tracing::debug!(
567 sql = sql,
568 params_count = params.len(),
569 "executing statement with named parameters"
570 );
571
572 if params.is_empty() {
573 self.send_sql_batch(sql).await?;
574 } else {
575 let rpc_params =
576 Self::convert_named_params(params, self.send_unicode(), self.server_collation())?;
577 let rpc = RpcRequest::execute_sql(sql, rpc_params);
578 self.send_rpc(&rpc).await?;
579 }
580
581 self.read_execute_result().await
582 }
583
584 /// Whether string parameters are sent as NVARCHAR (Unicode).
585 pub(crate) fn send_unicode(&self) -> bool {
586 self.config.send_string_parameters_as_unicode
587 }
588
589 /// Server's default collation, captured from ENVCHANGE during login.
590 pub(crate) fn server_collation(&self) -> Option<&tds_protocol::token::Collation> {
591 self.server_collation.as_ref()
592 }
593}
594
595impl Client<Ready> {
596 /// Mark this connection as needing a reset on next use.
597 ///
598 /// Called by the connection pool when a connection is returned.
599 /// The next SQL batch or RPC will include the RESETCONNECTION flag
600 /// in the TDS packet header, causing SQL Server to reset connection
601 /// state (temp tables, SET options, transaction isolation level, etc.)
602 /// before executing the command.
603 ///
604 /// This is more efficient than calling `sp_reset_connection` as a
605 /// separate command because it's handled at the TDS protocol level.
606 pub fn mark_needs_reset(&mut self) {
607 self.needs_reset = true;
608 }
609
610 /// Check if this connection needs a reset.
611 ///
612 /// Returns true if `mark_needs_reset()` was called and the reset
613 /// hasn't been performed yet.
614 #[must_use]
615 pub fn needs_reset(&self) -> bool {
616 self.needs_reset
617 }
618
619 /// Execute a query and return a streaming result set.
620 ///
621 /// Per ADR-007, results are streamed by default for memory efficiency.
622 /// Use `.collect_all()` on the stream if you need all rows in memory.
623 ///
624 /// # Example
625 ///
626 /// ```rust,ignore
627 /// use futures::StreamExt;
628 ///
629 /// // Streaming (memory-efficient)
630 /// let mut stream = client.query("SELECT * FROM users WHERE id = @p1", &[&1]).await?;
631 /// while let Some(row) = stream.next().await {
632 /// let row = row?;
633 /// process(&row);
634 /// }
635 ///
636 /// // Buffered (loads all into memory)
637 /// let rows: Vec<Row> = client
638 /// .query("SELECT * FROM small_table", &[])
639 /// .await?
640 /// .collect_all()
641 /// .await?;
642 /// ```
643 pub async fn query<'a>(
644 &'a mut self,
645 sql: &str,
646 params: &[&(dyn crate::ToSql + Sync)],
647 ) -> Result<QueryStream<'a>> {
648 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
649
650 #[cfg(feature = "otel")]
651 let instrumentation = self.instrumentation.clone();
652 #[cfg(feature = "otel")]
653 let mut span = instrumentation.query_span(sql);
654
655 let result = async {
656 if params.is_empty() {
657 // Simple query without parameters - use SQL batch
658 self.send_sql_batch(sql).await?;
659 } else {
660 // Parameterized query - use sp_executesql via RPC
661 let rpc_params =
662 Self::convert_params(params, self.send_unicode(), self.server_collation())?;
663 let rpc = RpcRequest::execute_sql(sql, rpc_params);
664 self.send_rpc(&rpc).await?;
665 }
666
667 // Read complete response including columns and rows
668 self.read_query_response().await
669 }
670 .await;
671
672 #[cfg(feature = "otel")]
673 match &result {
674 Ok(_) => InstrumentationContext::record_success(&mut span, None),
675 Err(e) => InstrumentationContext::record_error(&mut span, e),
676 }
677
678 // Drop the span before returning
679 #[cfg(feature = "otel")]
680 drop(span);
681
682 let resp = result?;
683 #[cfg(feature = "always-encrypted")]
684 {
685 Ok(QueryStream::from_raw(
686 resp.columns,
687 resp.pending_rows,
688 resp.meta,
689 resp.decryptor,
690 ))
691 }
692 #[cfg(not(feature = "always-encrypted"))]
693 {
694 Ok(QueryStream::from_raw(
695 resp.columns,
696 resp.pending_rows,
697 resp.meta,
698 ))
699 }
700 }
701
702 /// Execute a query with a specific timeout.
703 ///
704 /// This overrides the default `command_timeout` from the connection configuration
705 /// for this specific query. If the query does not complete within the specified
706 /// duration, an error is returned.
707 ///
708 /// # Arguments
709 ///
710 /// * `sql` - The SQL query to execute
711 /// * `params` - Query parameters
712 /// * `timeout_duration` - Maximum time to wait for the query to complete
713 ///
714 /// # Example
715 ///
716 /// ```rust,ignore
717 /// use std::time::Duration;
718 ///
719 /// // Execute with a 5-second timeout
720 /// let rows = client
721 /// .query_with_timeout(
722 /// "SELECT * FROM large_table",
723 /// &[],
724 /// Duration::from_secs(5),
725 /// )
726 /// .await?;
727 /// ```
728 pub async fn query_with_timeout<'a>(
729 &'a mut self,
730 sql: &str,
731 params: &[&(dyn crate::ToSql + Sync)],
732 timeout_duration: std::time::Duration,
733 ) -> Result<QueryStream<'a>> {
734 timeout(timeout_duration, self.query(sql, params))
735 .await
736 .map_err(|_| Error::CommandTimeout)?
737 }
738
739 /// Execute a batch that may return multiple result sets.
740 ///
741 /// This is useful for stored procedures or SQL batches that contain
742 /// multiple SELECT statements.
743 ///
744 /// # Example
745 ///
746 /// ```rust,ignore
747 /// // Execute a batch with multiple SELECTs
748 /// let mut results = client.query_multiple(
749 /// "SELECT 1 AS a; SELECT 2 AS b, 3 AS c;",
750 /// &[]
751 /// ).await?;
752 ///
753 /// // Process first result set
754 /// while let Some(row) = results.next_row().await? {
755 /// println!("Result 1: {:?}", row);
756 /// }
757 ///
758 /// // Move to second result set
759 /// if results.next_result().await? {
760 /// while let Some(row) = results.next_row().await? {
761 /// println!("Result 2: {:?}", row);
762 /// }
763 /// }
764 /// ```
765 pub async fn query_multiple<'a>(
766 &'a mut self,
767 sql: &str,
768 params: &[&(dyn crate::ToSql + Sync)],
769 ) -> Result<MultiResultStream<'a>> {
770 tracing::debug!(
771 sql = sql,
772 params_count = params.len(),
773 "executing multi-result query"
774 );
775
776 if params.is_empty() {
777 // Simple batch without parameters - use SQL batch
778 self.send_sql_batch(sql).await?;
779 } else {
780 // Parameterized query - use sp_executesql via RPC
781 let rpc_params =
782 Self::convert_params(params, self.send_unicode(), self.server_collation())?;
783 let rpc = RpcRequest::execute_sql(sql, rpc_params);
784 self.send_rpc(&rpc).await?;
785 }
786
787 // Read all result sets
788 let result_sets = self.read_multi_result_response().await?;
789 Ok(MultiResultStream::new(result_sets))
790 }
791
792 /// Execute a query that doesn't return rows.
793 ///
794 /// Returns the number of affected rows.
795 pub async fn execute(
796 &mut self,
797 sql: &str,
798 params: &[&(dyn crate::ToSql + Sync)],
799 ) -> Result<u64> {
800 tracing::debug!(
801 sql = sql,
802 params_count = params.len(),
803 "executing statement"
804 );
805
806 #[cfg(feature = "otel")]
807 let instrumentation = self.instrumentation.clone();
808 #[cfg(feature = "otel")]
809 let mut span = instrumentation.query_span(sql);
810
811 let result = async {
812 if params.is_empty() {
813 // Simple statement without parameters - use SQL batch
814 self.send_sql_batch(sql).await?;
815 } else {
816 // Parameterized statement - use sp_executesql via RPC
817 let rpc_params =
818 Self::convert_params(params, self.send_unicode(), self.server_collation())?;
819 let rpc = RpcRequest::execute_sql(sql, rpc_params);
820 self.send_rpc(&rpc).await?;
821 }
822
823 // Read response and get row count
824 self.read_execute_result().await
825 }
826 .await;
827
828 #[cfg(feature = "otel")]
829 match &result {
830 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
831 Err(e) => InstrumentationContext::record_error(&mut span, e),
832 }
833
834 // Drop the span before returning
835 #[cfg(feature = "otel")]
836 drop(span);
837
838 result
839 }
840
841 /// Execute a statement with a specific timeout.
842 ///
843 /// This overrides the default `command_timeout` from the connection configuration
844 /// for this specific statement. If the statement does not complete within the
845 /// specified duration, an error is returned.
846 ///
847 /// # Arguments
848 ///
849 /// * `sql` - The SQL statement to execute
850 /// * `params` - Statement parameters
851 /// * `timeout_duration` - Maximum time to wait for the statement to complete
852 ///
853 /// # Example
854 ///
855 /// ```rust,ignore
856 /// use std::time::Duration;
857 ///
858 /// // Execute with a 10-second timeout
859 /// let rows_affected = client
860 /// .execute_with_timeout(
861 /// "UPDATE large_table SET status = @p1",
862 /// &[&"processed"],
863 /// Duration::from_secs(10),
864 /// )
865 /// .await?;
866 /// ```
867 pub async fn execute_with_timeout(
868 &mut self,
869 sql: &str,
870 params: &[&(dyn crate::ToSql + Sync)],
871 timeout_duration: std::time::Duration,
872 ) -> Result<u64> {
873 timeout(timeout_duration, self.execute(sql, params))
874 .await
875 .map_err(|_| Error::CommandTimeout)?
876 }
877
878 /// Begin a transaction.
879 ///
880 /// This transitions the client from `Ready` to `InTransaction` state.
881 /// Per MS-TDS spec, the server returns a transaction descriptor in the
882 /// BeginTransaction EnvChange token that must be included in subsequent
883 /// ALL_HEADERS sections.
884 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
885 tracing::debug!("beginning transaction");
886
887 #[cfg(feature = "otel")]
888 let instrumentation = self.instrumentation.clone();
889 #[cfg(feature = "otel")]
890 let mut span = instrumentation.transaction_span("BEGIN");
891
892 // Execute BEGIN TRANSACTION and extract the transaction descriptor
893 let result = async {
894 self.send_sql_batch("BEGIN TRANSACTION").await?;
895 self.read_transaction_begin_result().await
896 }
897 .await;
898
899 #[cfg(feature = "otel")]
900 match &result {
901 Ok(_) => InstrumentationContext::record_success(&mut span, None),
902 Err(e) => InstrumentationContext::record_error(&mut span, e),
903 }
904
905 // Drop the span before moving instrumentation
906 #[cfg(feature = "otel")]
907 drop(span);
908
909 let transaction_descriptor = result?;
910
911 Ok(Client {
912 config: self.config,
913 _state: PhantomData,
914 connection: self.connection,
915 server_version: self.server_version,
916 current_database: self.current_database,
917 server_collation: self.server_collation,
918 statement_cache: self.statement_cache,
919 transaction_descriptor, // Store the descriptor from server
920 needs_reset: self.needs_reset,
921 in_flight: self.in_flight,
922 #[cfg(feature = "otel")]
923 instrumentation: self.instrumentation,
924 #[cfg(feature = "always-encrypted")]
925 encryption_context: self.encryption_context,
926 })
927 }
928
929 /// Begin a transaction with a specific isolation level.
930 ///
931 /// This transitions the client from `Ready` to `InTransaction` state
932 /// with the specified isolation level.
933 ///
934 /// # Example
935 ///
936 /// ```rust,ignore
937 /// use mssql_client::IsolationLevel;
938 ///
939 /// let tx = client.begin_transaction_with_isolation(IsolationLevel::Serializable).await?;
940 /// // All operations in this transaction use SERIALIZABLE isolation
941 /// tx.commit().await?;
942 /// ```
943 pub async fn begin_transaction_with_isolation(
944 mut self,
945 isolation_level: crate::transaction::IsolationLevel,
946 ) -> Result<Client<InTransaction>> {
947 tracing::debug!(
948 isolation_level = %isolation_level.name(),
949 "beginning transaction with isolation level"
950 );
951
952 #[cfg(feature = "otel")]
953 let instrumentation = self.instrumentation.clone();
954 #[cfg(feature = "otel")]
955 let mut span = instrumentation.transaction_span("BEGIN");
956
957 // First set the isolation level
958 let result = async {
959 self.send_sql_batch(isolation_level.as_sql()).await?;
960 self.read_execute_result().await?;
961
962 // Then begin the transaction
963 self.send_sql_batch("BEGIN TRANSACTION").await?;
964 self.read_transaction_begin_result().await
965 }
966 .await;
967
968 #[cfg(feature = "otel")]
969 match &result {
970 Ok(_) => InstrumentationContext::record_success(&mut span, None),
971 Err(e) => InstrumentationContext::record_error(&mut span, e),
972 }
973
974 #[cfg(feature = "otel")]
975 drop(span);
976
977 let transaction_descriptor = result?;
978
979 Ok(Client {
980 config: self.config,
981 _state: PhantomData,
982 connection: self.connection,
983 server_version: self.server_version,
984 current_database: self.current_database,
985 server_collation: self.server_collation,
986 statement_cache: self.statement_cache,
987 transaction_descriptor,
988 needs_reset: self.needs_reset,
989 in_flight: self.in_flight,
990 #[cfg(feature = "otel")]
991 instrumentation: self.instrumentation,
992 #[cfg(feature = "always-encrypted")]
993 encryption_context: self.encryption_context,
994 })
995 }
996
997 /// Execute a simple query without parameters.
998 ///
999 /// This is useful for DDL statements and simple queries where you
1000 /// don't need to retrieve the affected row count.
1001 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
1002 tracing::debug!(sql = sql, "executing simple query");
1003
1004 // Send SQL batch
1005 self.send_sql_batch(sql).await?;
1006
1007 // Read and discard response
1008 let _ = self.read_execute_result().await?;
1009
1010 Ok(())
1011 }
1012
1013 /// Close the connection gracefully.
1014 pub async fn close(self) -> Result<()> {
1015 tracing::debug!("closing connection");
1016 Ok(())
1017 }
1018
1019 /// Get the current database name.
1020 #[must_use]
1021 pub fn database(&self) -> Option<&str> {
1022 self.config.database.as_deref()
1023 }
1024
1025 /// Get the server host.
1026 #[must_use]
1027 pub fn host(&self) -> &str {
1028 &self.config.host
1029 }
1030
1031 /// Get the server port.
1032 #[must_use]
1033 pub fn port(&self) -> u16 {
1034 self.config.port
1035 }
1036
1037 /// Check if the connection is currently in a transaction.
1038 ///
1039 /// This returns `true` if a transaction was started via raw SQL
1040 /// (`BEGIN TRANSACTION`) and has not yet been committed or rolled back.
1041 ///
1042 /// Note: This only tracks transactions started via raw SQL. Transactions
1043 /// started via the type-state API (`begin_transaction()`) result in a
1044 /// `Client<InTransaction>` which is a different type.
1045 ///
1046 /// # Example
1047 ///
1048 /// ```rust,ignore
1049 /// client.execute("BEGIN TRANSACTION", &[]).await?;
1050 /// assert!(client.is_in_transaction());
1051 ///
1052 /// client.execute("COMMIT", &[]).await?;
1053 /// assert!(!client.is_in_transaction());
1054 /// ```
1055 #[must_use]
1056 pub fn is_in_transaction(&self) -> bool {
1057 self.transaction_descriptor != 0
1058 }
1059
1060 /// Check if a request is in-flight (sent but response not fully read).
1061 ///
1062 /// Used by the connection pool to detect dirty connections that were
1063 /// interrupted mid-query (e.g., by `tokio::select!` or a timeout).
1064 /// A connection with an in-flight request has unread data in the TCP
1065 /// buffer and must be discarded rather than returned to the pool.
1066 #[must_use]
1067 pub fn is_in_flight(&self) -> bool {
1068 self.in_flight
1069 }
1070
1071 /// Report whether an Always Encrypted key-store provider with the given
1072 /// name is currently reachable through this client's encryption context.
1073 ///
1074 /// Returns `false` when the `always-encrypted` feature isn't enabled, when
1075 /// the connection was opened without `column_encryption` configured, or
1076 /// when no matching provider was registered.
1077 #[cfg(feature = "always-encrypted")]
1078 #[must_use]
1079 pub fn has_encryption_provider(&self, name: &str) -> bool {
1080 self.encryption_context
1081 .as_ref()
1082 .is_some_and(|ctx| ctx.has_provider(name))
1083 }
1084
1085 /// Get a handle for cancelling the current query.
1086 ///
1087 /// The cancel handle can be cloned and sent to other tasks, enabling
1088 /// cancellation of long-running queries from a separate async context.
1089 ///
1090 /// # Example
1091 ///
1092 /// ```rust,ignore
1093 /// use std::time::Duration;
1094 ///
1095 /// let cancel_handle = client.cancel_handle();
1096 ///
1097 /// // Spawn a task to cancel after 10 seconds
1098 /// let handle = tokio::spawn(async move {
1099 /// tokio::time::sleep(Duration::from_secs(10)).await;
1100 /// let _ = cancel_handle.cancel().await;
1101 /// });
1102 ///
1103 /// // This query will be cancelled if it runs longer than 10 seconds
1104 /// let result = client.query("SELECT * FROM very_large_table", &[]).await;
1105 /// ```
1106 #[must_use]
1107 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
1108 let connection = self
1109 .connection
1110 .as_ref()
1111 .expect("connection should be present");
1112 match connection {
1113 #[cfg(feature = "tls")]
1114 ConnectionHandle::Tls(conn) => {
1115 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
1116 }
1117 #[cfg(feature = "tls")]
1118 ConnectionHandle::TlsPrelogin(conn) => {
1119 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
1120 }
1121 ConnectionHandle::Plain(conn) => {
1122 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
1123 }
1124 }
1125 }
1126}
1127
1128/// # Drop Behavior
1129///
1130/// **`Client<InTransaction>` has no automatic rollback on drop.** If the client is
1131/// dropped without calling [`commit()`](Client::commit) or [`rollback()`](Client::rollback),
1132/// the transaction remains open on the server until the TCP connection closes
1133/// (at which point SQL Server automatically rolls back).
1134///
1135/// This is because `Drop` is synchronous and cannot perform the async I/O needed
1136/// to send a `ROLLBACK TRANSACTION` command.
1137///
1138/// ## Consequences of dropping without commit/rollback
1139///
1140/// - **Direct connections:** The transaction leaks until the OS TCP timeout
1141/// (potentially 30+ minutes), holding locks on any modified rows.
1142/// - **Pooled connections:** The pool detects the active transaction descriptor
1143/// and discards the connection rather than returning it to the idle pool
1144/// (see `PooledConnection::drop` in `mssql-driver-pool`).
1145///
1146/// ## Best practice
1147///
1148/// Always ensure `commit()` or `rollback()` is called. Use helper patterns
1149/// for error paths:
1150///
1151/// ```rust,ignore
1152/// let tx = client.begin_transaction().await?;
1153/// match do_work(&tx).await {
1154/// Ok(_) => { tx.commit().await?; }
1155/// Err(e) => { tx.rollback().await?; return Err(e); }
1156/// }
1157/// ```
1158impl Client<InTransaction> {
1159 /// Execute a query within the transaction and return a streaming result set.
1160 ///
1161 /// See [`Client<Ready>::query`] for usage examples.
1162 pub async fn query<'a>(
1163 &'a mut self,
1164 sql: &str,
1165 params: &[&(dyn crate::ToSql + Sync)],
1166 ) -> Result<QueryStream<'a>> {
1167 tracing::debug!(
1168 sql = sql,
1169 params_count = params.len(),
1170 "executing query in transaction"
1171 );
1172
1173 #[cfg(feature = "otel")]
1174 let instrumentation = self.instrumentation.clone();
1175 #[cfg(feature = "otel")]
1176 let mut span = instrumentation.query_span(sql);
1177
1178 let result = async {
1179 if params.is_empty() {
1180 // Simple query without parameters - use SQL batch
1181 self.send_sql_batch(sql).await?;
1182 } else {
1183 // Parameterized query - use sp_executesql via RPC
1184 let rpc_params =
1185 Self::convert_params(params, self.send_unicode(), self.server_collation())?;
1186 let rpc = RpcRequest::execute_sql(sql, rpc_params);
1187 self.send_rpc(&rpc).await?;
1188 }
1189
1190 // Read complete response including columns and rows
1191 self.read_query_response().await
1192 }
1193 .await;
1194
1195 #[cfg(feature = "otel")]
1196 match &result {
1197 Ok(_) => InstrumentationContext::record_success(&mut span, None),
1198 Err(e) => InstrumentationContext::record_error(&mut span, e),
1199 }
1200
1201 // Drop the span before returning
1202 #[cfg(feature = "otel")]
1203 drop(span);
1204
1205 let resp = result?;
1206 #[cfg(feature = "always-encrypted")]
1207 {
1208 Ok(QueryStream::from_raw(
1209 resp.columns,
1210 resp.pending_rows,
1211 resp.meta,
1212 resp.decryptor,
1213 ))
1214 }
1215 #[cfg(not(feature = "always-encrypted"))]
1216 {
1217 Ok(QueryStream::from_raw(
1218 resp.columns,
1219 resp.pending_rows,
1220 resp.meta,
1221 ))
1222 }
1223 }
1224
1225 /// Execute a statement within the transaction.
1226 ///
1227 /// Returns the number of affected rows.
1228 pub async fn execute(
1229 &mut self,
1230 sql: &str,
1231 params: &[&(dyn crate::ToSql + Sync)],
1232 ) -> Result<u64> {
1233 tracing::debug!(
1234 sql = sql,
1235 params_count = params.len(),
1236 "executing statement in transaction"
1237 );
1238
1239 #[cfg(feature = "otel")]
1240 let instrumentation = self.instrumentation.clone();
1241 #[cfg(feature = "otel")]
1242 let mut span = instrumentation.query_span(sql);
1243
1244 let result = async {
1245 if params.is_empty() {
1246 // Simple statement without parameters - use SQL batch
1247 self.send_sql_batch(sql).await?;
1248 } else {
1249 // Parameterized statement - use sp_executesql via RPC
1250 let rpc_params =
1251 Self::convert_params(params, self.send_unicode(), self.server_collation())?;
1252 let rpc = RpcRequest::execute_sql(sql, rpc_params);
1253 self.send_rpc(&rpc).await?;
1254 }
1255
1256 // Read response and get row count
1257 self.read_execute_result().await
1258 }
1259 .await;
1260
1261 #[cfg(feature = "otel")]
1262 match &result {
1263 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
1264 Err(e) => InstrumentationContext::record_error(&mut span, e),
1265 }
1266
1267 // Drop the span before returning
1268 #[cfg(feature = "otel")]
1269 drop(span);
1270
1271 result
1272 }
1273
1274 /// Execute a query within the transaction with a specific timeout.
1275 ///
1276 /// See [`Client<Ready>::query_with_timeout`] for details.
1277 pub async fn query_with_timeout<'a>(
1278 &'a mut self,
1279 sql: &str,
1280 params: &[&(dyn crate::ToSql + Sync)],
1281 timeout_duration: std::time::Duration,
1282 ) -> Result<QueryStream<'a>> {
1283 timeout(timeout_duration, self.query(sql, params))
1284 .await
1285 .map_err(|_| Error::CommandTimeout)?
1286 }
1287
1288 /// Execute a statement within the transaction with a specific timeout.
1289 ///
1290 /// See [`Client<Ready>::execute_with_timeout`] for details.
1291 pub async fn execute_with_timeout(
1292 &mut self,
1293 sql: &str,
1294 params: &[&(dyn crate::ToSql + Sync)],
1295 timeout_duration: std::time::Duration,
1296 ) -> Result<u64> {
1297 timeout(timeout_duration, self.execute(sql, params))
1298 .await
1299 .map_err(|_| Error::CommandTimeout)?
1300 }
1301
1302 /// Open a FILESTREAM BLOB for async reading and/or writing.
1303 ///
1304 /// This method queries the server for the transaction context, then opens
1305 /// the FILESTREAM handle using the native Win32 `OpenSqlFilestream` API.
1306 ///
1307 /// # Arguments
1308 ///
1309 /// * `path` — The UNC path obtained from the T-SQL `column.PathName()` function.
1310 /// Query this yourself before calling `open_filestream`:
1311 /// ```sql
1312 /// SELECT Content.PathName() FROM dbo.Documents WHERE Id = @p1
1313 /// ```
1314 /// * `access` — Read, write, or read/write access mode.
1315 ///
1316 /// # Requirements
1317 ///
1318 /// - SQL Server must have FILESTREAM enabled (`sp_configure 'filestream access level', 2`)
1319 /// - The Microsoft OLE DB Driver for SQL Server must be installed on the client
1320 /// - The `FileStream` must be dropped before calling [`commit`] or [`rollback`]
1321 ///
1322 /// # Example
1323 ///
1324 /// ```rust,ignore
1325 /// use mssql_client::FileStreamAccess;
1326 /// use tokio::io::AsyncReadExt;
1327 ///
1328 /// let mut tx = client.begin_transaction().await?;
1329 ///
1330 /// // Get the FILESTREAM path
1331 /// let rows = tx.query(
1332 /// "SELECT Content.PathName() FROM dbo.Documents WHERE Id = @p1",
1333 /// &[&doc_id],
1334 /// ).await?;
1335 /// let path: String = rows.into_iter().next().unwrap()?.get(0)?;
1336 ///
1337 /// // Open and read the BLOB
1338 /// let mut stream = tx.open_filestream(&path, FileStreamAccess::Read).await?;
1339 /// let mut data = Vec::new();
1340 /// stream.read_to_end(&mut data).await?;
1341 /// drop(stream);
1342 ///
1343 /// tx.commit().await?;
1344 /// ```
1345 #[cfg(all(windows, feature = "filestream"))]
1346 pub async fn open_filestream(
1347 &mut self,
1348 path: &str,
1349 access: crate::filestream::FileStreamAccess,
1350 ) -> Result<crate::filestream::FileStream> {
1351 tracing::debug!(path = path, ?access, "opening FILESTREAM BLOB");
1352
1353 // Get the transaction context from SQL Server.
1354 // This binds the file access to the current SQL transaction.
1355 let txn_context: Vec<u8> = {
1356 let rows = self
1357 .query("SELECT GET_FILESTREAM_TRANSACTION_CONTEXT()", &[])
1358 .await?;
1359 let mut ctx = None;
1360 for result in rows {
1361 let row = result?;
1362 ctx = Some(row.get::<Vec<u8>>(0)?);
1363 }
1364 ctx.ok_or_else(|| {
1365 Error::FileStream("GET_FILESTREAM_TRANSACTION_CONTEXT() returned no rows".into())
1366 })?
1367 };
1368
1369 crate::filestream::FileStream::open(path, access, &txn_context)
1370 }
1371
1372 /// Commit the transaction.
1373 ///
1374 /// This transitions the client back to `Ready` state.
1375 pub async fn commit(mut self) -> Result<Client<Ready>> {
1376 tracing::debug!("committing transaction");
1377
1378 #[cfg(feature = "otel")]
1379 let instrumentation = self.instrumentation.clone();
1380 #[cfg(feature = "otel")]
1381 let mut span = instrumentation.transaction_span("COMMIT");
1382
1383 // Execute COMMIT TRANSACTION
1384 let result = async {
1385 self.send_sql_batch("COMMIT TRANSACTION").await?;
1386 self.read_execute_result().await
1387 }
1388 .await;
1389
1390 #[cfg(feature = "otel")]
1391 match &result {
1392 Ok(_) => InstrumentationContext::record_success(&mut span, None),
1393 Err(e) => InstrumentationContext::record_error(&mut span, e),
1394 }
1395
1396 // Drop the span before moving instrumentation
1397 #[cfg(feature = "otel")]
1398 drop(span);
1399
1400 result?;
1401
1402 Ok(Client {
1403 config: self.config,
1404 _state: PhantomData,
1405 connection: self.connection,
1406 server_version: self.server_version,
1407 current_database: self.current_database,
1408 server_collation: self.server_collation,
1409 statement_cache: self.statement_cache,
1410 transaction_descriptor: 0, // Reset to auto-commit mode
1411 needs_reset: self.needs_reset,
1412 in_flight: self.in_flight,
1413 #[cfg(feature = "otel")]
1414 instrumentation: self.instrumentation,
1415 #[cfg(feature = "always-encrypted")]
1416 encryption_context: self.encryption_context,
1417 })
1418 }
1419
1420 /// Rollback the transaction.
1421 ///
1422 /// This transitions the client back to `Ready` state.
1423 pub async fn rollback(mut self) -> Result<Client<Ready>> {
1424 tracing::debug!("rolling back transaction");
1425
1426 #[cfg(feature = "otel")]
1427 let instrumentation = self.instrumentation.clone();
1428 #[cfg(feature = "otel")]
1429 let mut span = instrumentation.transaction_span("ROLLBACK");
1430
1431 // Execute ROLLBACK TRANSACTION
1432 let result = async {
1433 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
1434 self.read_execute_result().await
1435 }
1436 .await;
1437
1438 #[cfg(feature = "otel")]
1439 match &result {
1440 Ok(_) => InstrumentationContext::record_success(&mut span, None),
1441 Err(e) => InstrumentationContext::record_error(&mut span, e),
1442 }
1443
1444 // Drop the span before moving instrumentation
1445 #[cfg(feature = "otel")]
1446 drop(span);
1447
1448 result?;
1449
1450 Ok(Client {
1451 config: self.config,
1452 _state: PhantomData,
1453 connection: self.connection,
1454 server_version: self.server_version,
1455 current_database: self.current_database,
1456 server_collation: self.server_collation,
1457 statement_cache: self.statement_cache,
1458 transaction_descriptor: 0, // Reset to auto-commit mode
1459 needs_reset: self.needs_reset,
1460 in_flight: self.in_flight,
1461 #[cfg(feature = "otel")]
1462 instrumentation: self.instrumentation,
1463 #[cfg(feature = "always-encrypted")]
1464 encryption_context: self.encryption_context,
1465 })
1466 }
1467
1468 /// Create a savepoint and return a handle for later rollback.
1469 ///
1470 /// The returned `SavePoint` handle contains the validated savepoint name.
1471 /// Use it with `rollback_to()` to partially undo transaction work.
1472 ///
1473 /// # Example
1474 ///
1475 /// ```rust,ignore
1476 /// let tx = client.begin_transaction().await?;
1477 /// tx.execute("INSERT INTO orders ...").await?;
1478 /// let sp = tx.save_point("before_items").await?;
1479 /// tx.execute("INSERT INTO items ...").await?;
1480 /// // Oops, rollback just the items
1481 /// tx.rollback_to(&sp).await?;
1482 /// tx.commit().await?;
1483 /// ```
1484 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
1485 crate::validation::validate_identifier(name)?;
1486 tracing::debug!(name = name, "creating savepoint");
1487
1488 // Execute SAVE TRANSACTION <name>
1489 // Note: name is validated by validate_identifier() to prevent SQL injection
1490 let sql = format!("SAVE TRANSACTION {name}");
1491 self.send_sql_batch(&sql).await?;
1492 self.read_execute_result().await?;
1493
1494 Ok(SavePoint::new(name.to_string()))
1495 }
1496
1497 /// Rollback to a savepoint.
1498 ///
1499 /// This rolls back all changes made after the savepoint was created,
1500 /// but keeps the transaction active. The savepoint remains valid and
1501 /// can be rolled back to again.
1502 ///
1503 /// # Example
1504 ///
1505 /// ```rust,ignore
1506 /// let sp = tx.save_point("checkpoint").await?;
1507 /// // ... do some work ...
1508 /// tx.rollback_to(&sp).await?; // Undo changes since checkpoint
1509 /// // Transaction is still active, savepoint is still valid
1510 /// ```
1511 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
1512 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
1513
1514 // Execute ROLLBACK TRANSACTION <name>
1515 // Note: savepoint name was validated during creation
1516 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
1517 self.send_sql_batch(&sql).await?;
1518 self.read_execute_result().await?;
1519
1520 Ok(())
1521 }
1522
1523 /// Release a savepoint (optional cleanup).
1524 ///
1525 /// Note: SQL Server doesn't have explicit savepoint release, but this
1526 /// method is provided for API completeness. The savepoint is automatically
1527 /// released when the transaction commits or rolls back.
1528 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
1529 tracing::debug!(name = savepoint.name(), "releasing savepoint");
1530
1531 // SQL Server doesn't require explicit savepoint release
1532 // The savepoint is implicitly released on commit/rollback
1533 // This method exists for API completeness
1534 drop(savepoint);
1535 Ok(())
1536 }
1537
1538 /// Get a handle for cancelling the current query within the transaction.
1539 ///
1540 /// See [`Client<Ready>::cancel_handle`] for usage examples.
1541 #[must_use]
1542 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
1543 let connection = self
1544 .connection
1545 .as_ref()
1546 .expect("connection should be present");
1547 match connection {
1548 #[cfg(feature = "tls")]
1549 ConnectionHandle::Tls(conn) => {
1550 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
1551 }
1552 #[cfg(feature = "tls")]
1553 ConnectionHandle::TlsPrelogin(conn) => {
1554 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
1555 }
1556 ConnectionHandle::Plain(conn) => {
1557 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
1558 }
1559 }
1560 }
1561}
1562
1563impl<S: ConnectionState> std::fmt::Debug for Client<S> {
1564 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1565 f.debug_struct("Client")
1566 .field("host", &self.config.host)
1567 .field("port", &self.config.port)
1568 .field("database", &self.config.database)
1569 .finish()
1570 }
1571}