Skip to main content

mssql_client/client/
mod.rs

1//! SQL Server client implementation.
2
3// Allow unwrap/expect for chrono date construction with known-valid constant dates
4// and for regex patterns that are compile-time constants
5#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7mod connect;
8mod params;
9mod response;
10
11use std::marker::PhantomData;
12
13use mssql_codec::connection::Connection;
14#[cfg(feature = "tls")]
15use mssql_tls::TlsStream;
16use tds_protocol::packet::PacketType;
17use tds_protocol::rpc::RpcRequest;
18use tds_protocol::token::{EnvChange, EnvChangeType};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21
22use crate::config::Config;
23use crate::error::{Error, Result};
24#[cfg(feature = "otel")]
25use crate::instrumentation::InstrumentationContext;
26use crate::state::{ConnectionState, InTransaction, Ready};
27use crate::statement_cache::StatementCache;
28use crate::stream::{MultiResultStream, QueryStream};
29use crate::transaction::SavePoint;
30
/// SQL Server client with type-state connection management.
///
/// The generic parameter `S` represents the current connection state,
/// ensuring at compile time that certain operations are only available
/// in appropriate states. For example, `query`/`execute` are implemented
/// separately for `Client<Ready>` and `Client<InTransaction>`, and
/// `begin_transaction` consumes a `Client<Ready>` to produce a
/// `Client<InTransaction>`.
pub struct Client<S: ConnectionState> {
    /// Connection configuration (host, port, database, packet size, ...).
    config: Config,
    /// Zero-sized marker carrying the type-state `S`; holds no runtime data.
    _state: PhantomData<S>,
    /// The underlying connection (present only when connected).
    /// Request helpers return `Error::ConnectionClosed` when this is `None`.
    connection: Option<ConnectionHandle>,
    /// Server version from LoginAck (raw u32 TDS version)
    server_version: Option<u32>,
    /// Current database from EnvChange.
    /// NOTE(review): populated by response handling outside this file — confirm
    /// it is updated on every Database EnvChange.
    current_database: Option<String>,
    /// Prepared statement cache for query optimization
    statement_cache: StatementCache,
    /// Transaction descriptor from BeginTransaction EnvChange.
    /// Per MS-TDS spec, this value must be included in ALL_HEADERS for subsequent
    /// requests within an explicit transaction. 0 indicates auto-commit mode.
    transaction_descriptor: u64,
    /// Whether this connection needs a reset on next use.
    /// Set by connection pool on checkin, cleared after first query/execute.
    /// When true, the RESETCONNECTION flag is set on the first TDS packet
    /// of the next SQL batch or RPC request.
    needs_reset: bool,
    /// OpenTelemetry instrumentation context (when otel feature is enabled)
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
    /// Always Encrypted context for column decryption (when always-encrypted feature is enabled)
    #[cfg(feature = "always-encrypted")]
    pub(crate) encryption_context: Option<std::sync::Arc<crate::encryption::EncryptionContext>>,
}
62
/// Internal connection handle wrapping the actual connection.
///
/// This is an enum to support different connection types:
/// - TLS (TDS 8.0 strict mode) - requires `tls` feature
/// - TLS with PreLogin wrapping (TDS 7.x style) - requires `tls` feature
/// - Plain TCP (for internal networks or when `tls` feature is disabled)
///
/// Code that talks to the server matches on the variant and calls the same
/// method on the wrapped `Connection` in every arm (see `send_sql_batch`,
/// `send_rpc` and `cancel_handle`).
// NOTE(review): the connection is already used for query execution, so the
// original justification for this `allow` is stale. It may still be needed if a
// variant is never constructed under some feature combination — verify with
// `cargo check` across feature sets before removing it.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TLS connection (TDS 8.0 strict mode - TLS before any TDS traffic)
    #[cfg(feature = "tls")]
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS connection with PreLogin wrapping (TDS 7.x style)
    #[cfg(feature = "tls")]
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP connection (for internal networks or when `tls` feature is disabled)
    Plain(Connection<TcpStream>),
}
80
81// Private helper methods available to all connection states
82impl<S: ConnectionState> Client<S> {
83    /// Process transaction-related EnvChange tokens.
84    ///
85    /// This handles BeginTransaction, CommitTransaction, and RollbackTransaction
86    /// EnvChange tokens, updating the transaction descriptor accordingly.
87    ///
88    /// This enables executing BEGIN TRANSACTION, COMMIT, and ROLLBACK via raw SQL
89    /// while still having the transaction descriptor tracked correctly.
90    fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
91        use tds_protocol::token::EnvChangeValue;
92
93        match env.env_type {
94            EnvChangeType::BeginTransaction => {
95                if let EnvChangeValue::Binary(ref data) = env.new_value {
96                    if data.len() >= 8 {
97                        let descriptor = u64::from_le_bytes([
98                            data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
99                        ]);
100                        tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
101                        *transaction_descriptor = descriptor;
102                    }
103                }
104            }
105            EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
106                tracing::debug!(
107                    env_type = ?env.env_type,
108                    "transaction ended via raw SQL"
109                );
110                *transaction_descriptor = 0;
111            }
112            _ => {}
113        }
114    }
115
116    /// Send a SQL batch to the server.
117    ///
118    /// Uses the client's current transaction descriptor in ALL_HEADERS.
119    /// Per MS-TDS spec, when in an explicit transaction, the descriptor
120    /// returned by BeginTransaction must be included.
121    ///
122    /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
123    /// is included in the first packet to reset connection state.
124    async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
125        let payload =
126            tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
127        let max_packet = self.config.packet_size as usize;
128
129        // Check if we need to reset the connection on this request
130        let reset = self.needs_reset;
131        if reset {
132            self.needs_reset = false; // Clear flag before sending
133            tracing::debug!("sending SQL batch with RESETCONNECTION flag");
134        }
135
136        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
137
138        match connection {
139            #[cfg(feature = "tls")]
140            ConnectionHandle::Tls(conn) => {
141                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
142                    .await?;
143            }
144            #[cfg(feature = "tls")]
145            ConnectionHandle::TlsPrelogin(conn) => {
146                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
147                    .await?;
148            }
149            ConnectionHandle::Plain(conn) => {
150                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
151                    .await?;
152            }
153        }
154
155        Ok(())
156    }
157
158    /// Send an RPC request to the server.
159    ///
160    /// Uses the client's current transaction descriptor in ALL_HEADERS.
161    ///
162    /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
163    /// is included in the first packet to reset connection state.
164    pub(crate) async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
165        let payload = rpc.encode_with_transaction(self.transaction_descriptor);
166        let max_packet = self.config.packet_size as usize;
167
168        // Check if we need to reset the connection on this request
169        let reset = self.needs_reset;
170        if reset {
171            self.needs_reset = false; // Clear flag before sending
172            tracing::debug!("sending RPC with RESETCONNECTION flag");
173        }
174
175        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
176
177        match connection {
178            #[cfg(feature = "tls")]
179            ConnectionHandle::Tls(conn) => {
180                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
181                    .await?;
182            }
183            #[cfg(feature = "tls")]
184            ConnectionHandle::TlsPrelogin(conn) => {
185                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
186                    .await?;
187            }
188            ConnectionHandle::Plain(conn) => {
189                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
190                    .await?;
191            }
192        }
193
194        Ok(())
195    }
196
197    /// Start building a stored procedure call with full control over parameters.
198    ///
199    /// Returns a [`crate::procedure::ProcedureBuilder`] that allows adding named input and output
200    /// parameters before executing the call.
201    ///
202    /// The procedure name is validated to prevent SQL injection. It may be
203    /// schema-qualified (e.g., `"dbo.MyProc"`).
204    ///
205    /// # Example
206    ///
207    /// ```rust,ignore
208    /// let result = client.procedure("dbo.CalculateSum")?
209    ///     .input("@a", &10i32)
210    ///     .input("@b", &20i32)
211    ///     .output_int("@result")
212    ///     .execute().await?;
213    ///
214    /// let sum = result.get_output("@result").unwrap();
215    /// ```
216    pub fn procedure(
217        &mut self,
218        proc_name: &str,
219    ) -> Result<crate::procedure::ProcedureBuilder<'_, S>> {
220        crate::validation::validate_qualified_identifier(proc_name)?;
221        Ok(crate::procedure::ProcedureBuilder::new(self, proc_name))
222    }
223
224    /// Execute a stored procedure with positional input parameters.
225    ///
226    /// This is a convenience method for the common case of calling a procedure
227    /// with input-only parameters. For output parameters or named parameters,
228    /// use [`procedure()`](Client::procedure) instead.
229    ///
230    /// # Example
231    ///
232    /// ```rust,ignore
233    /// let result = client.call_procedure("dbo.GetUser", &[&1i32]).await?;
234    /// assert_eq!(result.return_value, 0);
235    ///
236    /// if let Some(rs) = result.first_result_set() {
237    ///     println!("columns: {:?}", rs.columns());
238    /// }
239    /// ```
240    pub async fn call_procedure(
241        &mut self,
242        proc_name: &str,
243        params: &[&(dyn crate::ToSql + Sync)],
244    ) -> Result<crate::stream::ProcedureResult> {
245        crate::validation::validate_qualified_identifier(proc_name)?;
246
247        tracing::debug!(
248            proc_name = proc_name,
249            params_count = params.len(),
250            "executing stored procedure"
251        );
252
253        let rpc_params = Self::convert_params(params)?;
254        let mut rpc = RpcRequest::named(proc_name);
255        for param in rpc_params {
256            rpc = rpc.param(param);
257        }
258
259        self.send_rpc(&rpc).await?;
260        self.read_procedure_result().await
261    }
262}
263
264impl Client<Ready> {
265    /// Mark this connection as needing a reset on next use.
266    ///
267    /// Called by the connection pool when a connection is returned.
268    /// The next SQL batch or RPC will include the RESETCONNECTION flag
269    /// in the TDS packet header, causing SQL Server to reset connection
270    /// state (temp tables, SET options, transaction isolation level, etc.)
271    /// before executing the command.
272    ///
273    /// This is more efficient than calling `sp_reset_connection` as a
274    /// separate command because it's handled at the TDS protocol level.
275    pub fn mark_needs_reset(&mut self) {
276        self.needs_reset = true;
277    }
278
279    /// Check if this connection needs a reset.
280    ///
281    /// Returns true if `mark_needs_reset()` was called and the reset
282    /// hasn't been performed yet.
283    #[must_use]
284    pub fn needs_reset(&self) -> bool {
285        self.needs_reset
286    }
287
288    /// Execute a query and return a streaming result set.
289    ///
290    /// Per ADR-007, results are streamed by default for memory efficiency.
291    /// Use `.collect_all()` on the stream if you need all rows in memory.
292    ///
293    /// # Example
294    ///
295    /// ```rust,ignore
296    /// use futures::StreamExt;
297    ///
298    /// // Streaming (memory-efficient)
299    /// let mut stream = client.query("SELECT * FROM users WHERE id = @p1", &[&1]).await?;
300    /// while let Some(row) = stream.next().await {
301    ///     let row = row?;
302    ///     process(&row);
303    /// }
304    ///
305    /// // Buffered (loads all into memory)
306    /// let rows: Vec<Row> = client
307    ///     .query("SELECT * FROM small_table", &[])
308    ///     .await?
309    ///     .collect_all()
310    ///     .await?;
311    /// ```
312    pub async fn query<'a>(
313        &'a mut self,
314        sql: &str,
315        params: &[&(dyn crate::ToSql + Sync)],
316    ) -> Result<QueryStream<'a>> {
317        tracing::debug!(sql = sql, params_count = params.len(), "executing query");
318
319        #[cfg(feature = "otel")]
320        let instrumentation = self.instrumentation.clone();
321        #[cfg(feature = "otel")]
322        let mut span = instrumentation.query_span(sql);
323
324        let result = async {
325            if params.is_empty() {
326                // Simple query without parameters - use SQL batch
327                self.send_sql_batch(sql).await?;
328            } else {
329                // Parameterized query - use sp_executesql via RPC
330                let rpc_params = Self::convert_params(params)?;
331                let rpc = RpcRequest::execute_sql(sql, rpc_params);
332                self.send_rpc(&rpc).await?;
333            }
334
335            // Read complete response including columns and rows
336            self.read_query_response().await
337        }
338        .await;
339
340        #[cfg(feature = "otel")]
341        match &result {
342            Ok(_) => InstrumentationContext::record_success(&mut span, None),
343            Err(e) => InstrumentationContext::record_error(&mut span, e),
344        }
345
346        // Drop the span before returning
347        #[cfg(feature = "otel")]
348        drop(span);
349
350        let (columns, rows) = result?;
351        Ok(QueryStream::new(columns, rows))
352    }
353
354    /// Execute a query with a specific timeout.
355    ///
356    /// This overrides the default `command_timeout` from the connection configuration
357    /// for this specific query. If the query does not complete within the specified
358    /// duration, an error is returned.
359    ///
360    /// # Arguments
361    ///
362    /// * `sql` - The SQL query to execute
363    /// * `params` - Query parameters
364    /// * `timeout_duration` - Maximum time to wait for the query to complete
365    ///
366    /// # Example
367    ///
368    /// ```rust,ignore
369    /// use std::time::Duration;
370    ///
371    /// // Execute with a 5-second timeout
372    /// let rows = client
373    ///     .query_with_timeout(
374    ///         "SELECT * FROM large_table",
375    ///         &[],
376    ///         Duration::from_secs(5),
377    ///     )
378    ///     .await?;
379    /// ```
380    pub async fn query_with_timeout<'a>(
381        &'a mut self,
382        sql: &str,
383        params: &[&(dyn crate::ToSql + Sync)],
384        timeout_duration: std::time::Duration,
385    ) -> Result<QueryStream<'a>> {
386        timeout(timeout_duration, self.query(sql, params))
387            .await
388            .map_err(|_| Error::CommandTimeout)?
389    }
390
391    /// Execute a batch that may return multiple result sets.
392    ///
393    /// This is useful for stored procedures or SQL batches that contain
394    /// multiple SELECT statements.
395    ///
396    /// # Example
397    ///
398    /// ```rust,ignore
399    /// // Execute a batch with multiple SELECTs
400    /// let mut results = client.query_multiple(
401    ///     "SELECT 1 AS a; SELECT 2 AS b, 3 AS c;",
402    ///     &[]
403    /// ).await?;
404    ///
405    /// // Process first result set
406    /// while let Some(row) = results.next_row().await? {
407    ///     println!("Result 1: {:?}", row);
408    /// }
409    ///
410    /// // Move to second result set
411    /// if results.next_result().await? {
412    ///     while let Some(row) = results.next_row().await? {
413    ///         println!("Result 2: {:?}", row);
414    ///     }
415    /// }
416    /// ```
417    pub async fn query_multiple<'a>(
418        &'a mut self,
419        sql: &str,
420        params: &[&(dyn crate::ToSql + Sync)],
421    ) -> Result<MultiResultStream<'a>> {
422        tracing::debug!(
423            sql = sql,
424            params_count = params.len(),
425            "executing multi-result query"
426        );
427
428        if params.is_empty() {
429            // Simple batch without parameters - use SQL batch
430            self.send_sql_batch(sql).await?;
431        } else {
432            // Parameterized query - use sp_executesql via RPC
433            let rpc_params = Self::convert_params(params)?;
434            let rpc = RpcRequest::execute_sql(sql, rpc_params);
435            self.send_rpc(&rpc).await?;
436        }
437
438        // Read all result sets
439        let result_sets = self.read_multi_result_response().await?;
440        Ok(MultiResultStream::new(result_sets))
441    }
442
443    /// Execute a query that doesn't return rows.
444    ///
445    /// Returns the number of affected rows.
446    pub async fn execute(
447        &mut self,
448        sql: &str,
449        params: &[&(dyn crate::ToSql + Sync)],
450    ) -> Result<u64> {
451        tracing::debug!(
452            sql = sql,
453            params_count = params.len(),
454            "executing statement"
455        );
456
457        #[cfg(feature = "otel")]
458        let instrumentation = self.instrumentation.clone();
459        #[cfg(feature = "otel")]
460        let mut span = instrumentation.query_span(sql);
461
462        let result = async {
463            if params.is_empty() {
464                // Simple statement without parameters - use SQL batch
465                self.send_sql_batch(sql).await?;
466            } else {
467                // Parameterized statement - use sp_executesql via RPC
468                let rpc_params = Self::convert_params(params)?;
469                let rpc = RpcRequest::execute_sql(sql, rpc_params);
470                self.send_rpc(&rpc).await?;
471            }
472
473            // Read response and get row count
474            self.read_execute_result().await
475        }
476        .await;
477
478        #[cfg(feature = "otel")]
479        match &result {
480            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
481            Err(e) => InstrumentationContext::record_error(&mut span, e),
482        }
483
484        // Drop the span before returning
485        #[cfg(feature = "otel")]
486        drop(span);
487
488        result
489    }
490
491    /// Execute a statement with a specific timeout.
492    ///
493    /// This overrides the default `command_timeout` from the connection configuration
494    /// for this specific statement. If the statement does not complete within the
495    /// specified duration, an error is returned.
496    ///
497    /// # Arguments
498    ///
499    /// * `sql` - The SQL statement to execute
500    /// * `params` - Statement parameters
501    /// * `timeout_duration` - Maximum time to wait for the statement to complete
502    ///
503    /// # Example
504    ///
505    /// ```rust,ignore
506    /// use std::time::Duration;
507    ///
508    /// // Execute with a 10-second timeout
509    /// let rows_affected = client
510    ///     .execute_with_timeout(
511    ///         "UPDATE large_table SET status = @p1",
512    ///         &[&"processed"],
513    ///         Duration::from_secs(10),
514    ///     )
515    ///     .await?;
516    /// ```
517    pub async fn execute_with_timeout(
518        &mut self,
519        sql: &str,
520        params: &[&(dyn crate::ToSql + Sync)],
521        timeout_duration: std::time::Duration,
522    ) -> Result<u64> {
523        timeout(timeout_duration, self.execute(sql, params))
524            .await
525            .map_err(|_| Error::CommandTimeout)?
526    }
527
528    /// Begin a transaction.
529    ///
530    /// This transitions the client from `Ready` to `InTransaction` state.
531    /// Per MS-TDS spec, the server returns a transaction descriptor in the
532    /// BeginTransaction EnvChange token that must be included in subsequent
533    /// ALL_HEADERS sections.
534    pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
535        tracing::debug!("beginning transaction");
536
537        #[cfg(feature = "otel")]
538        let instrumentation = self.instrumentation.clone();
539        #[cfg(feature = "otel")]
540        let mut span = instrumentation.transaction_span("BEGIN");
541
542        // Execute BEGIN TRANSACTION and extract the transaction descriptor
543        let result = async {
544            self.send_sql_batch("BEGIN TRANSACTION").await?;
545            self.read_transaction_begin_result().await
546        }
547        .await;
548
549        #[cfg(feature = "otel")]
550        match &result {
551            Ok(_) => InstrumentationContext::record_success(&mut span, None),
552            Err(e) => InstrumentationContext::record_error(&mut span, e),
553        }
554
555        // Drop the span before moving instrumentation
556        #[cfg(feature = "otel")]
557        drop(span);
558
559        let transaction_descriptor = result?;
560
561        Ok(Client {
562            config: self.config,
563            _state: PhantomData,
564            connection: self.connection,
565            server_version: self.server_version,
566            current_database: self.current_database,
567            statement_cache: self.statement_cache,
568            transaction_descriptor, // Store the descriptor from server
569            needs_reset: self.needs_reset,
570            #[cfg(feature = "otel")]
571            instrumentation: self.instrumentation,
572            #[cfg(feature = "always-encrypted")]
573            encryption_context: self.encryption_context,
574        })
575    }
576
577    /// Begin a transaction with a specific isolation level.
578    ///
579    /// This transitions the client from `Ready` to `InTransaction` state
580    /// with the specified isolation level.
581    ///
582    /// # Example
583    ///
584    /// ```rust,ignore
585    /// use mssql_client::IsolationLevel;
586    ///
587    /// let tx = client.begin_transaction_with_isolation(IsolationLevel::Serializable).await?;
588    /// // All operations in this transaction use SERIALIZABLE isolation
589    /// tx.commit().await?;
590    /// ```
591    pub async fn begin_transaction_with_isolation(
592        mut self,
593        isolation_level: crate::transaction::IsolationLevel,
594    ) -> Result<Client<InTransaction>> {
595        tracing::debug!(
596            isolation_level = %isolation_level.name(),
597            "beginning transaction with isolation level"
598        );
599
600        #[cfg(feature = "otel")]
601        let instrumentation = self.instrumentation.clone();
602        #[cfg(feature = "otel")]
603        let mut span = instrumentation.transaction_span("BEGIN");
604
605        // First set the isolation level
606        let result = async {
607            self.send_sql_batch(isolation_level.as_sql()).await?;
608            self.read_execute_result().await?;
609
610            // Then begin the transaction
611            self.send_sql_batch("BEGIN TRANSACTION").await?;
612            self.read_transaction_begin_result().await
613        }
614        .await;
615
616        #[cfg(feature = "otel")]
617        match &result {
618            Ok(_) => InstrumentationContext::record_success(&mut span, None),
619            Err(e) => InstrumentationContext::record_error(&mut span, e),
620        }
621
622        #[cfg(feature = "otel")]
623        drop(span);
624
625        let transaction_descriptor = result?;
626
627        Ok(Client {
628            config: self.config,
629            _state: PhantomData,
630            connection: self.connection,
631            server_version: self.server_version,
632            current_database: self.current_database,
633            statement_cache: self.statement_cache,
634            transaction_descriptor,
635            needs_reset: self.needs_reset,
636            #[cfg(feature = "otel")]
637            instrumentation: self.instrumentation,
638            #[cfg(feature = "always-encrypted")]
639            encryption_context: self.encryption_context,
640        })
641    }
642
643    /// Execute a simple query without parameters.
644    ///
645    /// This is useful for DDL statements and simple queries where you
646    /// don't need to retrieve the affected row count.
647    pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
648        tracing::debug!(sql = sql, "executing simple query");
649
650        // Send SQL batch
651        self.send_sql_batch(sql).await?;
652
653        // Read and discard response
654        let _ = self.read_execute_result().await?;
655
656        Ok(())
657    }
658
659    /// Close the connection gracefully.
660    pub async fn close(self) -> Result<()> {
661        tracing::debug!("closing connection");
662        Ok(())
663    }
664
665    /// Get the current database name.
666    #[must_use]
667    pub fn database(&self) -> Option<&str> {
668        self.config.database.as_deref()
669    }
670
671    /// Get the server host.
672    #[must_use]
673    pub fn host(&self) -> &str {
674        &self.config.host
675    }
676
677    /// Get the server port.
678    #[must_use]
679    pub fn port(&self) -> u16 {
680        self.config.port
681    }
682
683    /// Check if the connection is currently in a transaction.
684    ///
685    /// This returns `true` if a transaction was started via raw SQL
686    /// (`BEGIN TRANSACTION`) and has not yet been committed or rolled back.
687    ///
688    /// Note: This only tracks transactions started via raw SQL. Transactions
689    /// started via the type-state API (`begin_transaction()`) result in a
690    /// `Client<InTransaction>` which is a different type.
691    ///
692    /// # Example
693    ///
694    /// ```rust,ignore
695    /// client.execute("BEGIN TRANSACTION", &[]).await?;
696    /// assert!(client.is_in_transaction());
697    ///
698    /// client.execute("COMMIT", &[]).await?;
699    /// assert!(!client.is_in_transaction());
700    /// ```
701    #[must_use]
702    pub fn is_in_transaction(&self) -> bool {
703        self.transaction_descriptor != 0
704    }
705
706    /// Get a handle for cancelling the current query.
707    ///
708    /// The cancel handle can be cloned and sent to other tasks, enabling
709    /// cancellation of long-running queries from a separate async context.
710    ///
711    /// # Example
712    ///
713    /// ```rust,ignore
714    /// use std::time::Duration;
715    ///
716    /// let cancel_handle = client.cancel_handle();
717    ///
718    /// // Spawn a task to cancel after 10 seconds
719    /// let handle = tokio::spawn(async move {
720    ///     tokio::time::sleep(Duration::from_secs(10)).await;
721    ///     let _ = cancel_handle.cancel().await;
722    /// });
723    ///
724    /// // This query will be cancelled if it runs longer than 10 seconds
725    /// let result = client.query("SELECT * FROM very_large_table", &[]).await;
726    /// ```
727    #[must_use]
728    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
729        let connection = self
730            .connection
731            .as_ref()
732            .expect("connection should be present");
733        match connection {
734            #[cfg(feature = "tls")]
735            ConnectionHandle::Tls(conn) => {
736                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
737            }
738            #[cfg(feature = "tls")]
739            ConnectionHandle::TlsPrelogin(conn) => {
740                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
741            }
742            ConnectionHandle::Plain(conn) => {
743                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
744            }
745        }
746    }
747}
748
749/// # Drop Behavior
750///
751/// **`Client<InTransaction>` has no automatic rollback on drop.** If the client is
752/// dropped without calling [`commit()`](Client::commit) or [`rollback()`](Client::rollback),
753/// the transaction remains open on the server until the TCP connection closes
754/// (at which point SQL Server automatically rolls back).
755///
756/// This is because `Drop` is synchronous and cannot perform the async I/O needed
757/// to send a `ROLLBACK TRANSACTION` command.
758///
759/// ## Consequences of dropping without commit/rollback
760///
761/// - **Direct connections:** The transaction leaks until the OS TCP timeout
762///   (potentially 30+ minutes), holding locks on any modified rows.
763/// - **Pooled connections:** The pool detects the active transaction descriptor
764///   and discards the connection rather than returning it to the idle pool
765///   (see `PooledConnection::drop` in `mssql-driver-pool`).
766///
767/// ## Best practice
768///
769/// Always ensure `commit()` or `rollback()` is called. Use helper patterns
770/// for error paths:
771///
772/// ```rust,ignore
/// let mut tx = client.begin_transaction().await?;
/// match do_work(&mut tx).await {
///     Ok(_) => { tx.commit().await?; }
///     Err(e) => { tx.rollback().await?; return Err(e); }
/// }
778/// ```
779impl Client<InTransaction> {
780    /// Execute a query within the transaction and return a streaming result set.
781    ///
782    /// See [`Client<Ready>::query`] for usage examples.
783    pub async fn query<'a>(
784        &'a mut self,
785        sql: &str,
786        params: &[&(dyn crate::ToSql + Sync)],
787    ) -> Result<QueryStream<'a>> {
788        tracing::debug!(
789            sql = sql,
790            params_count = params.len(),
791            "executing query in transaction"
792        );
793
794        #[cfg(feature = "otel")]
795        let instrumentation = self.instrumentation.clone();
796        #[cfg(feature = "otel")]
797        let mut span = instrumentation.query_span(sql);
798
799        let result = async {
800            if params.is_empty() {
801                // Simple query without parameters - use SQL batch
802                self.send_sql_batch(sql).await?;
803            } else {
804                // Parameterized query - use sp_executesql via RPC
805                let rpc_params = Self::convert_params(params)?;
806                let rpc = RpcRequest::execute_sql(sql, rpc_params);
807                self.send_rpc(&rpc).await?;
808            }
809
810            // Read complete response including columns and rows
811            self.read_query_response().await
812        }
813        .await;
814
815        #[cfg(feature = "otel")]
816        match &result {
817            Ok(_) => InstrumentationContext::record_success(&mut span, None),
818            Err(e) => InstrumentationContext::record_error(&mut span, e),
819        }
820
821        // Drop the span before returning
822        #[cfg(feature = "otel")]
823        drop(span);
824
825        let (columns, rows) = result?;
826        Ok(QueryStream::new(columns, rows))
827    }
828
829    /// Execute a statement within the transaction.
830    ///
831    /// Returns the number of affected rows.
832    pub async fn execute(
833        &mut self,
834        sql: &str,
835        params: &[&(dyn crate::ToSql + Sync)],
836    ) -> Result<u64> {
837        tracing::debug!(
838            sql = sql,
839            params_count = params.len(),
840            "executing statement in transaction"
841        );
842
843        #[cfg(feature = "otel")]
844        let instrumentation = self.instrumentation.clone();
845        #[cfg(feature = "otel")]
846        let mut span = instrumentation.query_span(sql);
847
848        let result = async {
849            if params.is_empty() {
850                // Simple statement without parameters - use SQL batch
851                self.send_sql_batch(sql).await?;
852            } else {
853                // Parameterized statement - use sp_executesql via RPC
854                let rpc_params = Self::convert_params(params)?;
855                let rpc = RpcRequest::execute_sql(sql, rpc_params);
856                self.send_rpc(&rpc).await?;
857            }
858
859            // Read response and get row count
860            self.read_execute_result().await
861        }
862        .await;
863
864        #[cfg(feature = "otel")]
865        match &result {
866            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
867            Err(e) => InstrumentationContext::record_error(&mut span, e),
868        }
869
870        // Drop the span before returning
871        #[cfg(feature = "otel")]
872        drop(span);
873
874        result
875    }
876
877    /// Execute a query within the transaction with a specific timeout.
878    ///
879    /// See [`Client<Ready>::query_with_timeout`] for details.
880    pub async fn query_with_timeout<'a>(
881        &'a mut self,
882        sql: &str,
883        params: &[&(dyn crate::ToSql + Sync)],
884        timeout_duration: std::time::Duration,
885    ) -> Result<QueryStream<'a>> {
886        timeout(timeout_duration, self.query(sql, params))
887            .await
888            .map_err(|_| Error::CommandTimeout)?
889    }
890
891    /// Execute a statement within the transaction with a specific timeout.
892    ///
893    /// See [`Client<Ready>::execute_with_timeout`] for details.
894    pub async fn execute_with_timeout(
895        &mut self,
896        sql: &str,
897        params: &[&(dyn crate::ToSql + Sync)],
898        timeout_duration: std::time::Duration,
899    ) -> Result<u64> {
900        timeout(timeout_duration, self.execute(sql, params))
901            .await
902            .map_err(|_| Error::CommandTimeout)?
903    }
904
905    /// Open a FILESTREAM BLOB for async reading and/or writing.
906    ///
907    /// This method queries the server for the transaction context, then opens
908    /// the FILESTREAM handle using the native Win32 `OpenSqlFilestream` API.
909    ///
910    /// # Arguments
911    ///
912    /// * `path` — The UNC path obtained from the T-SQL `column.PathName()` function.
913    ///   Query this yourself before calling `open_filestream`:
914    ///   ```sql
915    ///   SELECT Content.PathName() FROM dbo.Documents WHERE Id = @p1
916    ///   ```
917    /// * `access` — Read, write, or read/write access mode.
918    ///
919    /// # Requirements
920    ///
921    /// - SQL Server must have FILESTREAM enabled (`sp_configure 'filestream access level', 2`)
922    /// - The Microsoft OLE DB Driver for SQL Server must be installed on the client
923    /// - The `FileStream` must be dropped before calling [`commit`] or [`rollback`]
924    ///
925    /// # Example
926    ///
927    /// ```rust,ignore
928    /// use mssql_client::FileStreamAccess;
929    /// use tokio::io::AsyncReadExt;
930    ///
931    /// let mut tx = client.begin_transaction().await?;
932    ///
933    /// // Get the FILESTREAM path
934    /// let rows = tx.query(
935    ///     "SELECT Content.PathName() FROM dbo.Documents WHERE Id = @p1",
936    ///     &[&doc_id],
937    /// ).await?;
938    /// let path: String = rows.into_iter().next().unwrap()?.get(0)?;
939    ///
940    /// // Open and read the BLOB
941    /// let mut stream = tx.open_filestream(&path, FileStreamAccess::Read).await?;
942    /// let mut data = Vec::new();
943    /// stream.read_to_end(&mut data).await?;
944    /// drop(stream);
945    ///
946    /// tx.commit().await?;
947    /// ```
948    #[cfg(all(windows, feature = "filestream"))]
949    pub async fn open_filestream(
950        &mut self,
951        path: &str,
952        access: crate::filestream::FileStreamAccess,
953    ) -> Result<crate::filestream::FileStream> {
954        tracing::debug!(path = path, ?access, "opening FILESTREAM BLOB");
955
956        // Get the transaction context from SQL Server.
957        // This binds the file access to the current SQL transaction.
958        let txn_context: Vec<u8> = {
959            let rows = self
960                .query("SELECT GET_FILESTREAM_TRANSACTION_CONTEXT()", &[])
961                .await?;
962            let mut ctx = None;
963            for result in rows {
964                let row = result?;
965                ctx = Some(row.get::<Vec<u8>>(0)?);
966            }
967            ctx.ok_or_else(|| {
968                Error::FileStream("GET_FILESTREAM_TRANSACTION_CONTEXT() returned no rows".into())
969            })?
970        };
971
972        crate::filestream::FileStream::open(path, access, &txn_context)
973    }
974
975    /// Commit the transaction.
976    ///
977    /// This transitions the client back to `Ready` state.
978    pub async fn commit(mut self) -> Result<Client<Ready>> {
979        tracing::debug!("committing transaction");
980
981        #[cfg(feature = "otel")]
982        let instrumentation = self.instrumentation.clone();
983        #[cfg(feature = "otel")]
984        let mut span = instrumentation.transaction_span("COMMIT");
985
986        // Execute COMMIT TRANSACTION
987        let result = async {
988            self.send_sql_batch("COMMIT TRANSACTION").await?;
989            self.read_execute_result().await
990        }
991        .await;
992
993        #[cfg(feature = "otel")]
994        match &result {
995            Ok(_) => InstrumentationContext::record_success(&mut span, None),
996            Err(e) => InstrumentationContext::record_error(&mut span, e),
997        }
998
999        // Drop the span before moving instrumentation
1000        #[cfg(feature = "otel")]
1001        drop(span);
1002
1003        result?;
1004
1005        Ok(Client {
1006            config: self.config,
1007            _state: PhantomData,
1008            connection: self.connection,
1009            server_version: self.server_version,
1010            current_database: self.current_database,
1011            statement_cache: self.statement_cache,
1012            transaction_descriptor: 0, // Reset to auto-commit mode
1013            needs_reset: self.needs_reset,
1014            #[cfg(feature = "otel")]
1015            instrumentation: self.instrumentation,
1016            #[cfg(feature = "always-encrypted")]
1017            encryption_context: self.encryption_context,
1018        })
1019    }
1020
1021    /// Rollback the transaction.
1022    ///
1023    /// This transitions the client back to `Ready` state.
1024    pub async fn rollback(mut self) -> Result<Client<Ready>> {
1025        tracing::debug!("rolling back transaction");
1026
1027        #[cfg(feature = "otel")]
1028        let instrumentation = self.instrumentation.clone();
1029        #[cfg(feature = "otel")]
1030        let mut span = instrumentation.transaction_span("ROLLBACK");
1031
1032        // Execute ROLLBACK TRANSACTION
1033        let result = async {
1034            self.send_sql_batch("ROLLBACK TRANSACTION").await?;
1035            self.read_execute_result().await
1036        }
1037        .await;
1038
1039        #[cfg(feature = "otel")]
1040        match &result {
1041            Ok(_) => InstrumentationContext::record_success(&mut span, None),
1042            Err(e) => InstrumentationContext::record_error(&mut span, e),
1043        }
1044
1045        // Drop the span before moving instrumentation
1046        #[cfg(feature = "otel")]
1047        drop(span);
1048
1049        result?;
1050
1051        Ok(Client {
1052            config: self.config,
1053            _state: PhantomData,
1054            connection: self.connection,
1055            server_version: self.server_version,
1056            current_database: self.current_database,
1057            statement_cache: self.statement_cache,
1058            transaction_descriptor: 0, // Reset to auto-commit mode
1059            needs_reset: self.needs_reset,
1060            #[cfg(feature = "otel")]
1061            instrumentation: self.instrumentation,
1062            #[cfg(feature = "always-encrypted")]
1063            encryption_context: self.encryption_context,
1064        })
1065    }
1066
1067    /// Create a savepoint and return a handle for later rollback.
1068    ///
1069    /// The returned `SavePoint` handle contains the validated savepoint name.
1070    /// Use it with `rollback_to()` to partially undo transaction work.
1071    ///
1072    /// # Example
1073    ///
1074    /// ```rust,ignore
1075    /// let tx = client.begin_transaction().await?;
1076    /// tx.execute("INSERT INTO orders ...").await?;
1077    /// let sp = tx.save_point("before_items").await?;
1078    /// tx.execute("INSERT INTO items ...").await?;
1079    /// // Oops, rollback just the items
1080    /// tx.rollback_to(&sp).await?;
1081    /// tx.commit().await?;
1082    /// ```
1083    pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
1084        crate::validation::validate_identifier(name)?;
1085        tracing::debug!(name = name, "creating savepoint");
1086
1087        // Execute SAVE TRANSACTION <name>
1088        // Note: name is validated by validate_identifier() to prevent SQL injection
1089        let sql = format!("SAVE TRANSACTION {name}");
1090        self.send_sql_batch(&sql).await?;
1091        self.read_execute_result().await?;
1092
1093        Ok(SavePoint::new(name.to_string()))
1094    }
1095
1096    /// Rollback to a savepoint.
1097    ///
1098    /// This rolls back all changes made after the savepoint was created,
1099    /// but keeps the transaction active. The savepoint remains valid and
1100    /// can be rolled back to again.
1101    ///
1102    /// # Example
1103    ///
1104    /// ```rust,ignore
1105    /// let sp = tx.save_point("checkpoint").await?;
1106    /// // ... do some work ...
1107    /// tx.rollback_to(&sp).await?;  // Undo changes since checkpoint
1108    /// // Transaction is still active, savepoint is still valid
1109    /// ```
1110    pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
1111        tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
1112
1113        // Execute ROLLBACK TRANSACTION <name>
1114        // Note: savepoint name was validated during creation
1115        let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
1116        self.send_sql_batch(&sql).await?;
1117        self.read_execute_result().await?;
1118
1119        Ok(())
1120    }
1121
1122    /// Release a savepoint (optional cleanup).
1123    ///
1124    /// Note: SQL Server doesn't have explicit savepoint release, but this
1125    /// method is provided for API completeness. The savepoint is automatically
1126    /// released when the transaction commits or rolls back.
1127    pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
1128        tracing::debug!(name = savepoint.name(), "releasing savepoint");
1129
1130        // SQL Server doesn't require explicit savepoint release
1131        // The savepoint is implicitly released on commit/rollback
1132        // This method exists for API completeness
1133        drop(savepoint);
1134        Ok(())
1135    }
1136
1137    /// Get a handle for cancelling the current query within the transaction.
1138    ///
1139    /// See [`Client<Ready>::cancel_handle`] for usage examples.
1140    #[must_use]
1141    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
1142        let connection = self
1143            .connection
1144            .as_ref()
1145            .expect("connection should be present");
1146        match connection {
1147            #[cfg(feature = "tls")]
1148            ConnectionHandle::Tls(conn) => {
1149                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
1150            }
1151            #[cfg(feature = "tls")]
1152            ConnectionHandle::TlsPrelogin(conn) => {
1153                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
1154            }
1155            ConnectionHandle::Plain(conn) => {
1156                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
1157            }
1158        }
1159    }
1160}
1161
1162impl<S: ConnectionState> std::fmt::Debug for Client<S> {
1163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1164        f.debug_struct("Client")
1165            .field("host", &self.config.host)
1166            .field("port", &self.config.port)
1167            .field("database", &self.config.database)
1168            .finish()
1169    }
1170}