mssql_client/client/mod.rs
1//! SQL Server client implementation.
2
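// Allow unwrap/expect for chrono date construction with known-valid constant dates
// and for regex patterns that are compile-time constants. Because this is an inner
// attribute on `client/mod.rs`, it also applies to the `connect`, `params`, and
// `response` submodules, and to the `expect` in `Client::<Ready>::cancel_handle`.
// NOTE(review): the reason for allowing `clippy::needless_range_loop` is not
// recorded here — document it or narrow the allowance to the offending item.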
5#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7mod connect;
8mod params;
9mod response;
10
11use std::marker::PhantomData;
12
13use mssql_codec::connection::Connection;
14#[cfg(feature = "tls")]
15use mssql_tls::TlsStream;
16use tds_protocol::packet::PacketType;
17use tds_protocol::rpc::RpcRequest;
18use tds_protocol::token::{EnvChange, EnvChangeType};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21
22use crate::config::Config;
23use crate::error::{Error, Result};
24#[cfg(feature = "otel")]
25use crate::instrumentation::InstrumentationContext;
26use crate::state::{ConnectionState, InTransaction, Ready};
27use crate::statement_cache::StatementCache;
28use crate::stream::{MultiResultStream, QueryStream};
29use crate::transaction::SavePoint;
30
/// SQL Server client with type-state connection management.
///
/// The generic parameter `S` represents the current connection state,
/// ensuring at compile time that certain operations are only available
/// in appropriate states (in this module, for example, `begin_transaction`
/// is defined on `Client<Ready>` and `commit` on `Client<InTransaction>`).
///
/// State transitions consume the client and rebuild it with a different `S`,
/// moving every field across, so any new field must be threaded through
/// every transition constructor.
pub struct Client<S: ConnectionState> {
    /// Connection configuration (host, port, database, packet size, ...).
    config: Config,
    /// Zero-sized marker carrying the type-state `S`.
    _state: PhantomData<S>,
    /// The underlying connection (present only when connected).
    connection: Option<ConnectionHandle>,
    /// Server version from LoginAck (raw u32 TDS version).
    server_version: Option<u32>,
    /// Current database from EnvChange.
    current_database: Option<String>,
    /// Prepared statement cache for query optimization.
    statement_cache: StatementCache,
    /// Transaction descriptor from BeginTransaction EnvChange.
    /// Per MS-TDS spec, this value must be included in ALL_HEADERS for subsequent
    /// requests within an explicit transaction. 0 indicates auto-commit mode.
    /// It is also updated when raw SQL begins, commits, or rolls back a
    /// transaction (see `process_transaction_env_change`).
    transaction_descriptor: u64,
    /// Whether this connection needs a reset on next use.
    /// Set by the connection pool on checkin (via `mark_needs_reset`) and
    /// cleared by `send_sql_batch` / `send_rpc` when they handle the flag.
    /// When true, the RESETCONNECTION flag is set on the first TDS packet.
    needs_reset: bool,
    /// OpenTelemetry instrumentation context (when otel feature is enabled).
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
    /// Always Encrypted context for column decryption (when always-encrypted feature is enabled).
    #[cfg(feature = "always-encrypted")]
    pub(crate) encryption_context: Option<std::sync::Arc<crate::encryption::EncryptionContext>>,
}
62
/// Internal connection handle wrapping the actual connection.
///
/// This is an enum to support different connection types:
/// - TLS (TDS 8.0 strict mode) - requires `tls` feature
/// - TLS with PreLogin wrapping (TDS 7.x style) - requires `tls` feature
/// - Plain TCP (for internal networks or when `tls` feature is disabled)
///
/// Every operation on the handle (`send_sql_batch`, `send_rpc`, `cancel_handle`)
/// matches on all variants, so adding a variant requires updating each of them.
// NOTE(review): this allow was originally justified as "connection will be used once
// query execution is implemented"; the handle is now used for sending and cancellation.
// It may still be needed if some variant is never constructed under a given feature
// set — confirm with a build of each feature combination before removing it.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TLS connection (TDS 8.0 strict mode - TLS before any TDS traffic)
    #[cfg(feature = "tls")]
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS connection with PreLogin wrapping (TDS 7.x style)
    #[cfg(feature = "tls")]
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP connection (for internal networks or when `tls` feature is disabled)
    Plain(Connection<TcpStream>),
}
80
81// Private helper methods available to all connection states
82impl<S: ConnectionState> Client<S> {
83 /// Process transaction-related EnvChange tokens.
84 ///
85 /// This handles BeginTransaction, CommitTransaction, and RollbackTransaction
86 /// EnvChange tokens, updating the transaction descriptor accordingly.
87 ///
88 /// This enables executing BEGIN TRANSACTION, COMMIT, and ROLLBACK via raw SQL
89 /// while still having the transaction descriptor tracked correctly.
90 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
91 use tds_protocol::token::EnvChangeValue;
92
93 match env.env_type {
94 EnvChangeType::BeginTransaction => {
95 if let EnvChangeValue::Binary(ref data) = env.new_value {
96 if data.len() >= 8 {
97 let descriptor = u64::from_le_bytes([
98 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
99 ]);
100 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
101 *transaction_descriptor = descriptor;
102 }
103 }
104 }
105 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
106 tracing::debug!(
107 env_type = ?env.env_type,
108 "transaction ended via raw SQL"
109 );
110 *transaction_descriptor = 0;
111 }
112 _ => {}
113 }
114 }
115
116 /// Send a SQL batch to the server.
117 ///
118 /// Uses the client's current transaction descriptor in ALL_HEADERS.
119 /// Per MS-TDS spec, when in an explicit transaction, the descriptor
120 /// returned by BeginTransaction must be included.
121 ///
122 /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
123 /// is included in the first packet to reset connection state.
124 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
125 let payload =
126 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
127 let max_packet = self.config.packet_size as usize;
128
129 // Check if we need to reset the connection on this request
130 let reset = self.needs_reset;
131 if reset {
132 self.needs_reset = false; // Clear flag before sending
133 tracing::debug!("sending SQL batch with RESETCONNECTION flag");
134 }
135
136 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
137
138 match connection {
139 #[cfg(feature = "tls")]
140 ConnectionHandle::Tls(conn) => {
141 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
142 .await?;
143 }
144 #[cfg(feature = "tls")]
145 ConnectionHandle::TlsPrelogin(conn) => {
146 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
147 .await?;
148 }
149 ConnectionHandle::Plain(conn) => {
150 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
151 .await?;
152 }
153 }
154
155 Ok(())
156 }
157
158 /// Send an RPC request to the server.
159 ///
160 /// Uses the client's current transaction descriptor in ALL_HEADERS.
161 ///
162 /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
163 /// is included in the first packet to reset connection state.
164 pub(crate) async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
165 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
166 let max_packet = self.config.packet_size as usize;
167
168 // Check if we need to reset the connection on this request
169 let reset = self.needs_reset;
170 if reset {
171 self.needs_reset = false; // Clear flag before sending
172 tracing::debug!("sending RPC with RESETCONNECTION flag");
173 }
174
175 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
176
177 match connection {
178 #[cfg(feature = "tls")]
179 ConnectionHandle::Tls(conn) => {
180 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
181 .await?;
182 }
183 #[cfg(feature = "tls")]
184 ConnectionHandle::TlsPrelogin(conn) => {
185 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
186 .await?;
187 }
188 ConnectionHandle::Plain(conn) => {
189 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
190 .await?;
191 }
192 }
193
194 Ok(())
195 }
196
197 /// Start building a stored procedure call with full control over parameters.
198 ///
199 /// Returns a [`crate::procedure::ProcedureBuilder`] that allows adding named input and output
200 /// parameters before executing the call.
201 ///
202 /// The procedure name is validated to prevent SQL injection. It may be
203 /// schema-qualified (e.g., `"dbo.MyProc"`).
204 ///
205 /// # Example
206 ///
207 /// ```rust,ignore
208 /// let result = client.procedure("dbo.CalculateSum")?
209 /// .input("@a", &10i32)
210 /// .input("@b", &20i32)
211 /// .output_int("@result")
212 /// .execute().await?;
213 ///
214 /// let sum = result.get_output("@result").unwrap();
215 /// ```
216 pub fn procedure(
217 &mut self,
218 proc_name: &str,
219 ) -> Result<crate::procedure::ProcedureBuilder<'_, S>> {
220 crate::validation::validate_qualified_identifier(proc_name)?;
221 Ok(crate::procedure::ProcedureBuilder::new(self, proc_name))
222 }
223
224 /// Execute a stored procedure with positional input parameters.
225 ///
226 /// This is a convenience method for the common case of calling a procedure
227 /// with input-only parameters. For output parameters or named parameters,
228 /// use [`procedure()`](Client::procedure) instead.
229 ///
230 /// # Example
231 ///
232 /// ```rust,ignore
233 /// let result = client.call_procedure("dbo.GetUser", &[&1i32]).await?;
234 /// assert_eq!(result.return_value, 0);
235 ///
236 /// if let Some(rs) = result.first_result_set() {
237 /// println!("columns: {:?}", rs.columns());
238 /// }
239 /// ```
240 pub async fn call_procedure(
241 &mut self,
242 proc_name: &str,
243 params: &[&(dyn crate::ToSql + Sync)],
244 ) -> Result<crate::stream::ProcedureResult> {
245 crate::validation::validate_qualified_identifier(proc_name)?;
246
247 tracing::debug!(
248 proc_name = proc_name,
249 params_count = params.len(),
250 "executing stored procedure"
251 );
252
253 let rpc_params = Self::convert_params(params)?;
254 let mut rpc = RpcRequest::named(proc_name);
255 for param in rpc_params {
256 rpc = rpc.param(param);
257 }
258
259 self.send_rpc(&rpc).await?;
260 self.read_procedure_result().await
261 }
262}
263
264impl Client<Ready> {
265 /// Mark this connection as needing a reset on next use.
266 ///
267 /// Called by the connection pool when a connection is returned.
268 /// The next SQL batch or RPC will include the RESETCONNECTION flag
269 /// in the TDS packet header, causing SQL Server to reset connection
270 /// state (temp tables, SET options, transaction isolation level, etc.)
271 /// before executing the command.
272 ///
273 /// This is more efficient than calling `sp_reset_connection` as a
274 /// separate command because it's handled at the TDS protocol level.
275 pub fn mark_needs_reset(&mut self) {
276 self.needs_reset = true;
277 }
278
279 /// Check if this connection needs a reset.
280 ///
281 /// Returns true if `mark_needs_reset()` was called and the reset
282 /// hasn't been performed yet.
283 #[must_use]
284 pub fn needs_reset(&self) -> bool {
285 self.needs_reset
286 }
287
288 /// Execute a query and return a streaming result set.
289 ///
290 /// Per ADR-007, results are streamed by default for memory efficiency.
291 /// Use `.collect_all()` on the stream if you need all rows in memory.
292 ///
293 /// # Example
294 ///
295 /// ```rust,ignore
296 /// use futures::StreamExt;
297 ///
298 /// // Streaming (memory-efficient)
299 /// let mut stream = client.query("SELECT * FROM users WHERE id = @p1", &[&1]).await?;
300 /// while let Some(row) = stream.next().await {
301 /// let row = row?;
302 /// process(&row);
303 /// }
304 ///
305 /// // Buffered (loads all into memory)
306 /// let rows: Vec<Row> = client
307 /// .query("SELECT * FROM small_table", &[])
308 /// .await?
309 /// .collect_all()
310 /// .await?;
311 /// ```
312 pub async fn query<'a>(
313 &'a mut self,
314 sql: &str,
315 params: &[&(dyn crate::ToSql + Sync)],
316 ) -> Result<QueryStream<'a>> {
317 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
318
319 #[cfg(feature = "otel")]
320 let instrumentation = self.instrumentation.clone();
321 #[cfg(feature = "otel")]
322 let mut span = instrumentation.query_span(sql);
323
324 let result = async {
325 if params.is_empty() {
326 // Simple query without parameters - use SQL batch
327 self.send_sql_batch(sql).await?;
328 } else {
329 // Parameterized query - use sp_executesql via RPC
330 let rpc_params = Self::convert_params(params)?;
331 let rpc = RpcRequest::execute_sql(sql, rpc_params);
332 self.send_rpc(&rpc).await?;
333 }
334
335 // Read complete response including columns and rows
336 self.read_query_response().await
337 }
338 .await;
339
340 #[cfg(feature = "otel")]
341 match &result {
342 Ok(_) => InstrumentationContext::record_success(&mut span, None),
343 Err(e) => InstrumentationContext::record_error(&mut span, e),
344 }
345
346 // Drop the span before returning
347 #[cfg(feature = "otel")]
348 drop(span);
349
350 let (columns, rows) = result?;
351 Ok(QueryStream::new(columns, rows))
352 }
353
354 /// Execute a query with a specific timeout.
355 ///
356 /// This overrides the default `command_timeout` from the connection configuration
357 /// for this specific query. If the query does not complete within the specified
358 /// duration, an error is returned.
359 ///
360 /// # Arguments
361 ///
362 /// * `sql` - The SQL query to execute
363 /// * `params` - Query parameters
364 /// * `timeout_duration` - Maximum time to wait for the query to complete
365 ///
366 /// # Example
367 ///
368 /// ```rust,ignore
369 /// use std::time::Duration;
370 ///
371 /// // Execute with a 5-second timeout
372 /// let rows = client
373 /// .query_with_timeout(
374 /// "SELECT * FROM large_table",
375 /// &[],
376 /// Duration::from_secs(5),
377 /// )
378 /// .await?;
379 /// ```
380 pub async fn query_with_timeout<'a>(
381 &'a mut self,
382 sql: &str,
383 params: &[&(dyn crate::ToSql + Sync)],
384 timeout_duration: std::time::Duration,
385 ) -> Result<QueryStream<'a>> {
386 timeout(timeout_duration, self.query(sql, params))
387 .await
388 .map_err(|_| Error::CommandTimeout)?
389 }
390
391 /// Execute a batch that may return multiple result sets.
392 ///
393 /// This is useful for stored procedures or SQL batches that contain
394 /// multiple SELECT statements.
395 ///
396 /// # Example
397 ///
398 /// ```rust,ignore
399 /// // Execute a batch with multiple SELECTs
400 /// let mut results = client.query_multiple(
401 /// "SELECT 1 AS a; SELECT 2 AS b, 3 AS c;",
402 /// &[]
403 /// ).await?;
404 ///
405 /// // Process first result set
406 /// while let Some(row) = results.next_row().await? {
407 /// println!("Result 1: {:?}", row);
408 /// }
409 ///
410 /// // Move to second result set
411 /// if results.next_result().await? {
412 /// while let Some(row) = results.next_row().await? {
413 /// println!("Result 2: {:?}", row);
414 /// }
415 /// }
416 /// ```
417 pub async fn query_multiple<'a>(
418 &'a mut self,
419 sql: &str,
420 params: &[&(dyn crate::ToSql + Sync)],
421 ) -> Result<MultiResultStream<'a>> {
422 tracing::debug!(
423 sql = sql,
424 params_count = params.len(),
425 "executing multi-result query"
426 );
427
428 if params.is_empty() {
429 // Simple batch without parameters - use SQL batch
430 self.send_sql_batch(sql).await?;
431 } else {
432 // Parameterized query - use sp_executesql via RPC
433 let rpc_params = Self::convert_params(params)?;
434 let rpc = RpcRequest::execute_sql(sql, rpc_params);
435 self.send_rpc(&rpc).await?;
436 }
437
438 // Read all result sets
439 let result_sets = self.read_multi_result_response().await?;
440 Ok(MultiResultStream::new(result_sets))
441 }
442
443 /// Execute a query that doesn't return rows.
444 ///
445 /// Returns the number of affected rows.
446 pub async fn execute(
447 &mut self,
448 sql: &str,
449 params: &[&(dyn crate::ToSql + Sync)],
450 ) -> Result<u64> {
451 tracing::debug!(
452 sql = sql,
453 params_count = params.len(),
454 "executing statement"
455 );
456
457 #[cfg(feature = "otel")]
458 let instrumentation = self.instrumentation.clone();
459 #[cfg(feature = "otel")]
460 let mut span = instrumentation.query_span(sql);
461
462 let result = async {
463 if params.is_empty() {
464 // Simple statement without parameters - use SQL batch
465 self.send_sql_batch(sql).await?;
466 } else {
467 // Parameterized statement - use sp_executesql via RPC
468 let rpc_params = Self::convert_params(params)?;
469 let rpc = RpcRequest::execute_sql(sql, rpc_params);
470 self.send_rpc(&rpc).await?;
471 }
472
473 // Read response and get row count
474 self.read_execute_result().await
475 }
476 .await;
477
478 #[cfg(feature = "otel")]
479 match &result {
480 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
481 Err(e) => InstrumentationContext::record_error(&mut span, e),
482 }
483
484 // Drop the span before returning
485 #[cfg(feature = "otel")]
486 drop(span);
487
488 result
489 }
490
491 /// Execute a statement with a specific timeout.
492 ///
493 /// This overrides the default `command_timeout` from the connection configuration
494 /// for this specific statement. If the statement does not complete within the
495 /// specified duration, an error is returned.
496 ///
497 /// # Arguments
498 ///
499 /// * `sql` - The SQL statement to execute
500 /// * `params` - Statement parameters
501 /// * `timeout_duration` - Maximum time to wait for the statement to complete
502 ///
503 /// # Example
504 ///
505 /// ```rust,ignore
506 /// use std::time::Duration;
507 ///
508 /// // Execute with a 10-second timeout
509 /// let rows_affected = client
510 /// .execute_with_timeout(
511 /// "UPDATE large_table SET status = @p1",
512 /// &[&"processed"],
513 /// Duration::from_secs(10),
514 /// )
515 /// .await?;
516 /// ```
517 pub async fn execute_with_timeout(
518 &mut self,
519 sql: &str,
520 params: &[&(dyn crate::ToSql + Sync)],
521 timeout_duration: std::time::Duration,
522 ) -> Result<u64> {
523 timeout(timeout_duration, self.execute(sql, params))
524 .await
525 .map_err(|_| Error::CommandTimeout)?
526 }
527
528 /// Begin a transaction.
529 ///
530 /// This transitions the client from `Ready` to `InTransaction` state.
531 /// Per MS-TDS spec, the server returns a transaction descriptor in the
532 /// BeginTransaction EnvChange token that must be included in subsequent
533 /// ALL_HEADERS sections.
534 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
535 tracing::debug!("beginning transaction");
536
537 #[cfg(feature = "otel")]
538 let instrumentation = self.instrumentation.clone();
539 #[cfg(feature = "otel")]
540 let mut span = instrumentation.transaction_span("BEGIN");
541
542 // Execute BEGIN TRANSACTION and extract the transaction descriptor
543 let result = async {
544 self.send_sql_batch("BEGIN TRANSACTION").await?;
545 self.read_transaction_begin_result().await
546 }
547 .await;
548
549 #[cfg(feature = "otel")]
550 match &result {
551 Ok(_) => InstrumentationContext::record_success(&mut span, None),
552 Err(e) => InstrumentationContext::record_error(&mut span, e),
553 }
554
555 // Drop the span before moving instrumentation
556 #[cfg(feature = "otel")]
557 drop(span);
558
559 let transaction_descriptor = result?;
560
561 Ok(Client {
562 config: self.config,
563 _state: PhantomData,
564 connection: self.connection,
565 server_version: self.server_version,
566 current_database: self.current_database,
567 statement_cache: self.statement_cache,
568 transaction_descriptor, // Store the descriptor from server
569 needs_reset: self.needs_reset,
570 #[cfg(feature = "otel")]
571 instrumentation: self.instrumentation,
572 #[cfg(feature = "always-encrypted")]
573 encryption_context: self.encryption_context,
574 })
575 }
576
577 /// Begin a transaction with a specific isolation level.
578 ///
579 /// This transitions the client from `Ready` to `InTransaction` state
580 /// with the specified isolation level.
581 ///
582 /// # Example
583 ///
584 /// ```rust,ignore
585 /// use mssql_client::IsolationLevel;
586 ///
587 /// let tx = client.begin_transaction_with_isolation(IsolationLevel::Serializable).await?;
588 /// // All operations in this transaction use SERIALIZABLE isolation
589 /// tx.commit().await?;
590 /// ```
591 pub async fn begin_transaction_with_isolation(
592 mut self,
593 isolation_level: crate::transaction::IsolationLevel,
594 ) -> Result<Client<InTransaction>> {
595 tracing::debug!(
596 isolation_level = %isolation_level.name(),
597 "beginning transaction with isolation level"
598 );
599
600 #[cfg(feature = "otel")]
601 let instrumentation = self.instrumentation.clone();
602 #[cfg(feature = "otel")]
603 let mut span = instrumentation.transaction_span("BEGIN");
604
605 // First set the isolation level
606 let result = async {
607 self.send_sql_batch(isolation_level.as_sql()).await?;
608 self.read_execute_result().await?;
609
610 // Then begin the transaction
611 self.send_sql_batch("BEGIN TRANSACTION").await?;
612 self.read_transaction_begin_result().await
613 }
614 .await;
615
616 #[cfg(feature = "otel")]
617 match &result {
618 Ok(_) => InstrumentationContext::record_success(&mut span, None),
619 Err(e) => InstrumentationContext::record_error(&mut span, e),
620 }
621
622 #[cfg(feature = "otel")]
623 drop(span);
624
625 let transaction_descriptor = result?;
626
627 Ok(Client {
628 config: self.config,
629 _state: PhantomData,
630 connection: self.connection,
631 server_version: self.server_version,
632 current_database: self.current_database,
633 statement_cache: self.statement_cache,
634 transaction_descriptor,
635 needs_reset: self.needs_reset,
636 #[cfg(feature = "otel")]
637 instrumentation: self.instrumentation,
638 #[cfg(feature = "always-encrypted")]
639 encryption_context: self.encryption_context,
640 })
641 }
642
643 /// Execute a simple query without parameters.
644 ///
645 /// This is useful for DDL statements and simple queries where you
646 /// don't need to retrieve the affected row count.
647 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
648 tracing::debug!(sql = sql, "executing simple query");
649
650 // Send SQL batch
651 self.send_sql_batch(sql).await?;
652
653 // Read and discard response
654 let _ = self.read_execute_result().await?;
655
656 Ok(())
657 }
658
659 /// Close the connection gracefully.
660 pub async fn close(self) -> Result<()> {
661 tracing::debug!("closing connection");
662 Ok(())
663 }
664
665 /// Get the current database name.
666 #[must_use]
667 pub fn database(&self) -> Option<&str> {
668 self.config.database.as_deref()
669 }
670
671 /// Get the server host.
672 #[must_use]
673 pub fn host(&self) -> &str {
674 &self.config.host
675 }
676
677 /// Get the server port.
678 #[must_use]
679 pub fn port(&self) -> u16 {
680 self.config.port
681 }
682
683 /// Check if the connection is currently in a transaction.
684 ///
685 /// This returns `true` if a transaction was started via raw SQL
686 /// (`BEGIN TRANSACTION`) and has not yet been committed or rolled back.
687 ///
688 /// Note: This only tracks transactions started via raw SQL. Transactions
689 /// started via the type-state API (`begin_transaction()`) result in a
690 /// `Client<InTransaction>` which is a different type.
691 ///
692 /// # Example
693 ///
694 /// ```rust,ignore
695 /// client.execute("BEGIN TRANSACTION", &[]).await?;
696 /// assert!(client.is_in_transaction());
697 ///
698 /// client.execute("COMMIT", &[]).await?;
699 /// assert!(!client.is_in_transaction());
700 /// ```
701 #[must_use]
702 pub fn is_in_transaction(&self) -> bool {
703 self.transaction_descriptor != 0
704 }
705
706 /// Get a handle for cancelling the current query.
707 ///
708 /// The cancel handle can be cloned and sent to other tasks, enabling
709 /// cancellation of long-running queries from a separate async context.
710 ///
711 /// # Example
712 ///
713 /// ```rust,ignore
714 /// use std::time::Duration;
715 ///
716 /// let cancel_handle = client.cancel_handle();
717 ///
718 /// // Spawn a task to cancel after 10 seconds
719 /// let handle = tokio::spawn(async move {
720 /// tokio::time::sleep(Duration::from_secs(10)).await;
721 /// let _ = cancel_handle.cancel().await;
722 /// });
723 ///
724 /// // This query will be cancelled if it runs longer than 10 seconds
725 /// let result = client.query("SELECT * FROM very_large_table", &[]).await;
726 /// ```
727 #[must_use]
728 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
729 let connection = self
730 .connection
731 .as_ref()
732 .expect("connection should be present");
733 match connection {
734 #[cfg(feature = "tls")]
735 ConnectionHandle::Tls(conn) => {
736 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
737 }
738 #[cfg(feature = "tls")]
739 ConnectionHandle::TlsPrelogin(conn) => {
740 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
741 }
742 ConnectionHandle::Plain(conn) => {
743 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
744 }
745 }
746 }
747}
748
/// # Drop Behavior
///
/// **`Client<InTransaction>` has no automatic rollback on drop.** If the client is
/// dropped without calling [`commit()`](Client::commit) or [`rollback()`](Client::rollback),
/// no `ROLLBACK TRANSACTION` is sent; the transaction remains open on the server until
/// the server observes the connection closing (at which point SQL Server automatically
/// rolls back).
///
/// This is because `Drop` is synchronous and cannot perform the async I/O needed
/// to send a `ROLLBACK TRANSACTION` command.
///
/// ## Consequences of dropping without commit/rollback
///
/// - **Direct connections:** The open transaction keeps holding locks on any modified
///   rows until the server sees the disconnect. If that close is not delivered promptly
///   (for example, something keeps the socket open, or the network path fails silently),
///   this can take until TCP timeouts expire (potentially 30+ minutes).
/// - **Pooled connections:** The pool detects the active transaction descriptor
///   and discards the connection rather than returning it to the idle pool
///   (see `PooledConnection::drop` in `mssql-driver-pool`).
///
/// ## Best practice
///
/// Always ensure `commit()` or `rollback()` is called. Use helper patterns
/// for error paths (note that queries need `&mut` access to the transaction):
///
/// ```rust,ignore
/// let mut tx = client.begin_transaction().await?;
/// match do_work(&mut tx).await {
///     Ok(_) => { tx.commit().await?; }
///     Err(e) => { tx.rollback().await?; return Err(e); }
/// }
/// ```
779impl Client<InTransaction> {
780 /// Execute a query within the transaction and return a streaming result set.
781 ///
782 /// See [`Client<Ready>::query`] for usage examples.
783 pub async fn query<'a>(
784 &'a mut self,
785 sql: &str,
786 params: &[&(dyn crate::ToSql + Sync)],
787 ) -> Result<QueryStream<'a>> {
788 tracing::debug!(
789 sql = sql,
790 params_count = params.len(),
791 "executing query in transaction"
792 );
793
794 #[cfg(feature = "otel")]
795 let instrumentation = self.instrumentation.clone();
796 #[cfg(feature = "otel")]
797 let mut span = instrumentation.query_span(sql);
798
799 let result = async {
800 if params.is_empty() {
801 // Simple query without parameters - use SQL batch
802 self.send_sql_batch(sql).await?;
803 } else {
804 // Parameterized query - use sp_executesql via RPC
805 let rpc_params = Self::convert_params(params)?;
806 let rpc = RpcRequest::execute_sql(sql, rpc_params);
807 self.send_rpc(&rpc).await?;
808 }
809
810 // Read complete response including columns and rows
811 self.read_query_response().await
812 }
813 .await;
814
815 #[cfg(feature = "otel")]
816 match &result {
817 Ok(_) => InstrumentationContext::record_success(&mut span, None),
818 Err(e) => InstrumentationContext::record_error(&mut span, e),
819 }
820
821 // Drop the span before returning
822 #[cfg(feature = "otel")]
823 drop(span);
824
825 let (columns, rows) = result?;
826 Ok(QueryStream::new(columns, rows))
827 }
828
829 /// Execute a statement within the transaction.
830 ///
831 /// Returns the number of affected rows.
832 pub async fn execute(
833 &mut self,
834 sql: &str,
835 params: &[&(dyn crate::ToSql + Sync)],
836 ) -> Result<u64> {
837 tracing::debug!(
838 sql = sql,
839 params_count = params.len(),
840 "executing statement in transaction"
841 );
842
843 #[cfg(feature = "otel")]
844 let instrumentation = self.instrumentation.clone();
845 #[cfg(feature = "otel")]
846 let mut span = instrumentation.query_span(sql);
847
848 let result = async {
849 if params.is_empty() {
850 // Simple statement without parameters - use SQL batch
851 self.send_sql_batch(sql).await?;
852 } else {
853 // Parameterized statement - use sp_executesql via RPC
854 let rpc_params = Self::convert_params(params)?;
855 let rpc = RpcRequest::execute_sql(sql, rpc_params);
856 self.send_rpc(&rpc).await?;
857 }
858
859 // Read response and get row count
860 self.read_execute_result().await
861 }
862 .await;
863
864 #[cfg(feature = "otel")]
865 match &result {
866 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
867 Err(e) => InstrumentationContext::record_error(&mut span, e),
868 }
869
870 // Drop the span before returning
871 #[cfg(feature = "otel")]
872 drop(span);
873
874 result
875 }
876
877 /// Execute a query within the transaction with a specific timeout.
878 ///
879 /// See [`Client<Ready>::query_with_timeout`] for details.
880 pub async fn query_with_timeout<'a>(
881 &'a mut self,
882 sql: &str,
883 params: &[&(dyn crate::ToSql + Sync)],
884 timeout_duration: std::time::Duration,
885 ) -> Result<QueryStream<'a>> {
886 timeout(timeout_duration, self.query(sql, params))
887 .await
888 .map_err(|_| Error::CommandTimeout)?
889 }
890
891 /// Execute a statement within the transaction with a specific timeout.
892 ///
893 /// See [`Client<Ready>::execute_with_timeout`] for details.
894 pub async fn execute_with_timeout(
895 &mut self,
896 sql: &str,
897 params: &[&(dyn crate::ToSql + Sync)],
898 timeout_duration: std::time::Duration,
899 ) -> Result<u64> {
900 timeout(timeout_duration, self.execute(sql, params))
901 .await
902 .map_err(|_| Error::CommandTimeout)?
903 }
904
905 /// Open a FILESTREAM BLOB for async reading and/or writing.
906 ///
907 /// This method queries the server for the transaction context, then opens
908 /// the FILESTREAM handle using the native Win32 `OpenSqlFilestream` API.
909 ///
910 /// # Arguments
911 ///
912 /// * `path` — The UNC path obtained from the T-SQL `column.PathName()` function.
913 /// Query this yourself before calling `open_filestream`:
914 /// ```sql
915 /// SELECT Content.PathName() FROM dbo.Documents WHERE Id = @p1
916 /// ```
917 /// * `access` — Read, write, or read/write access mode.
918 ///
919 /// # Requirements
920 ///
921 /// - SQL Server must have FILESTREAM enabled (`sp_configure 'filestream access level', 2`)
922 /// - The Microsoft OLE DB Driver for SQL Server must be installed on the client
923 /// - The `FileStream` must be dropped before calling [`commit`] or [`rollback`]
924 ///
925 /// # Example
926 ///
927 /// ```rust,ignore
928 /// use mssql_client::FileStreamAccess;
929 /// use tokio::io::AsyncReadExt;
930 ///
931 /// let mut tx = client.begin_transaction().await?;
932 ///
933 /// // Get the FILESTREAM path
934 /// let rows = tx.query(
935 /// "SELECT Content.PathName() FROM dbo.Documents WHERE Id = @p1",
936 /// &[&doc_id],
937 /// ).await?;
938 /// let path: String = rows.into_iter().next().unwrap()?.get(0)?;
939 ///
940 /// // Open and read the BLOB
941 /// let mut stream = tx.open_filestream(&path, FileStreamAccess::Read).await?;
942 /// let mut data = Vec::new();
943 /// stream.read_to_end(&mut data).await?;
944 /// drop(stream);
945 ///
946 /// tx.commit().await?;
947 /// ```
948 #[cfg(all(windows, feature = "filestream"))]
949 pub async fn open_filestream(
950 &mut self,
951 path: &str,
952 access: crate::filestream::FileStreamAccess,
953 ) -> Result<crate::filestream::FileStream> {
954 tracing::debug!(path = path, ?access, "opening FILESTREAM BLOB");
955
956 // Get the transaction context from SQL Server.
957 // This binds the file access to the current SQL transaction.
958 let txn_context: Vec<u8> = {
959 let rows = self
960 .query("SELECT GET_FILESTREAM_TRANSACTION_CONTEXT()", &[])
961 .await?;
962 let mut ctx = None;
963 for result in rows {
964 let row = result?;
965 ctx = Some(row.get::<Vec<u8>>(0)?);
966 }
967 ctx.ok_or_else(|| {
968 Error::FileStream("GET_FILESTREAM_TRANSACTION_CONTEXT() returned no rows".into())
969 })?
970 };
971
972 crate::filestream::FileStream::open(path, access, &txn_context)
973 }
974
975 /// Commit the transaction.
976 ///
977 /// This transitions the client back to `Ready` state.
978 pub async fn commit(mut self) -> Result<Client<Ready>> {
979 tracing::debug!("committing transaction");
980
981 #[cfg(feature = "otel")]
982 let instrumentation = self.instrumentation.clone();
983 #[cfg(feature = "otel")]
984 let mut span = instrumentation.transaction_span("COMMIT");
985
986 // Execute COMMIT TRANSACTION
987 let result = async {
988 self.send_sql_batch("COMMIT TRANSACTION").await?;
989 self.read_execute_result().await
990 }
991 .await;
992
993 #[cfg(feature = "otel")]
994 match &result {
995 Ok(_) => InstrumentationContext::record_success(&mut span, None),
996 Err(e) => InstrumentationContext::record_error(&mut span, e),
997 }
998
999 // Drop the span before moving instrumentation
1000 #[cfg(feature = "otel")]
1001 drop(span);
1002
1003 result?;
1004
1005 Ok(Client {
1006 config: self.config,
1007 _state: PhantomData,
1008 connection: self.connection,
1009 server_version: self.server_version,
1010 current_database: self.current_database,
1011 statement_cache: self.statement_cache,
1012 transaction_descriptor: 0, // Reset to auto-commit mode
1013 needs_reset: self.needs_reset,
1014 #[cfg(feature = "otel")]
1015 instrumentation: self.instrumentation,
1016 #[cfg(feature = "always-encrypted")]
1017 encryption_context: self.encryption_context,
1018 })
1019 }
1020
1021 /// Rollback the transaction.
1022 ///
1023 /// This transitions the client back to `Ready` state.
1024 pub async fn rollback(mut self) -> Result<Client<Ready>> {
1025 tracing::debug!("rolling back transaction");
1026
1027 #[cfg(feature = "otel")]
1028 let instrumentation = self.instrumentation.clone();
1029 #[cfg(feature = "otel")]
1030 let mut span = instrumentation.transaction_span("ROLLBACK");
1031
1032 // Execute ROLLBACK TRANSACTION
1033 let result = async {
1034 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
1035 self.read_execute_result().await
1036 }
1037 .await;
1038
1039 #[cfg(feature = "otel")]
1040 match &result {
1041 Ok(_) => InstrumentationContext::record_success(&mut span, None),
1042 Err(e) => InstrumentationContext::record_error(&mut span, e),
1043 }
1044
1045 // Drop the span before moving instrumentation
1046 #[cfg(feature = "otel")]
1047 drop(span);
1048
1049 result?;
1050
1051 Ok(Client {
1052 config: self.config,
1053 _state: PhantomData,
1054 connection: self.connection,
1055 server_version: self.server_version,
1056 current_database: self.current_database,
1057 statement_cache: self.statement_cache,
1058 transaction_descriptor: 0, // Reset to auto-commit mode
1059 needs_reset: self.needs_reset,
1060 #[cfg(feature = "otel")]
1061 instrumentation: self.instrumentation,
1062 #[cfg(feature = "always-encrypted")]
1063 encryption_context: self.encryption_context,
1064 })
1065 }
1066
1067 /// Create a savepoint and return a handle for later rollback.
1068 ///
1069 /// The returned `SavePoint` handle contains the validated savepoint name.
1070 /// Use it with `rollback_to()` to partially undo transaction work.
1071 ///
1072 /// # Example
1073 ///
1074 /// ```rust,ignore
1075 /// let tx = client.begin_transaction().await?;
1076 /// tx.execute("INSERT INTO orders ...").await?;
1077 /// let sp = tx.save_point("before_items").await?;
1078 /// tx.execute("INSERT INTO items ...").await?;
1079 /// // Oops, rollback just the items
1080 /// tx.rollback_to(&sp).await?;
1081 /// tx.commit().await?;
1082 /// ```
1083 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
1084 crate::validation::validate_identifier(name)?;
1085 tracing::debug!(name = name, "creating savepoint");
1086
1087 // Execute SAVE TRANSACTION <name>
1088 // Note: name is validated by validate_identifier() to prevent SQL injection
1089 let sql = format!("SAVE TRANSACTION {name}");
1090 self.send_sql_batch(&sql).await?;
1091 self.read_execute_result().await?;
1092
1093 Ok(SavePoint::new(name.to_string()))
1094 }
1095
1096 /// Rollback to a savepoint.
1097 ///
1098 /// This rolls back all changes made after the savepoint was created,
1099 /// but keeps the transaction active. The savepoint remains valid and
1100 /// can be rolled back to again.
1101 ///
1102 /// # Example
1103 ///
1104 /// ```rust,ignore
1105 /// let sp = tx.save_point("checkpoint").await?;
1106 /// // ... do some work ...
1107 /// tx.rollback_to(&sp).await?; // Undo changes since checkpoint
1108 /// // Transaction is still active, savepoint is still valid
1109 /// ```
1110 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
1111 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
1112
1113 // Execute ROLLBACK TRANSACTION <name>
1114 // Note: savepoint name was validated during creation
1115 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
1116 self.send_sql_batch(&sql).await?;
1117 self.read_execute_result().await?;
1118
1119 Ok(())
1120 }
1121
1122 /// Release a savepoint (optional cleanup).
1123 ///
1124 /// Note: SQL Server doesn't have explicit savepoint release, but this
1125 /// method is provided for API completeness. The savepoint is automatically
1126 /// released when the transaction commits or rolls back.
1127 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
1128 tracing::debug!(name = savepoint.name(), "releasing savepoint");
1129
1130 // SQL Server doesn't require explicit savepoint release
1131 // The savepoint is implicitly released on commit/rollback
1132 // This method exists for API completeness
1133 drop(savepoint);
1134 Ok(())
1135 }
1136
1137 /// Get a handle for cancelling the current query within the transaction.
1138 ///
1139 /// See [`Client<Ready>::cancel_handle`] for usage examples.
1140 #[must_use]
1141 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
1142 let connection = self
1143 .connection
1144 .as_ref()
1145 .expect("connection should be present");
1146 match connection {
1147 #[cfg(feature = "tls")]
1148 ConnectionHandle::Tls(conn) => {
1149 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
1150 }
1151 #[cfg(feature = "tls")]
1152 ConnectionHandle::TlsPrelogin(conn) => {
1153 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
1154 }
1155 ConnectionHandle::Plain(conn) => {
1156 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
1157 }
1158 }
1159 }
1160}
1161
1162impl<S: ConnectionState> std::fmt::Debug for Client<S> {
1163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1164 f.debug_struct("Client")
1165 .field("host", &self.config.host)
1166 .field("port", &self.config.port)
1167 .field("database", &self.config.database)
1168 .finish()
1169 }
1170}