Skip to main content

mssql_client/client/
mod.rs

1//! SQL Server client implementation.
2
3// Allow unwrap/expect for chrono date construction with known-valid constant dates
4// and for regex patterns that are compile-time constants
5#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7mod connect;
8mod params;
9mod response;
10
11use std::marker::PhantomData;
12
13use mssql_codec::connection::Connection;
14#[cfg(feature = "tls")]
15use mssql_tls::TlsStream;
16use tds_protocol::packet::PacketType;
17use tds_protocol::rpc::RpcRequest;
18use tds_protocol::token::{EnvChange, EnvChangeType};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21
22use crate::config::Config;
23use crate::error::{Error, Result};
24#[cfg(feature = "otel")]
25use crate::instrumentation::InstrumentationContext;
26use crate::state::{ConnectionState, InTransaction, Ready};
27use crate::statement_cache::StatementCache;
28use crate::stream::{MultiResultStream, QueryStream};
29use crate::transaction::SavePoint;
30
/// SQL Server client with type-state connection management.
///
/// The generic parameter `S` represents the current connection state,
/// ensuring at compile time that certain operations are only available
/// in appropriate states. For example, `begin_transaction` exists only on
/// `Client<Ready>`, while `commit` exists only on `Client<InTransaction>`.
///
/// State transitions consume the client and rebuild it field by field, so
/// every field below is carried across `Ready` <-> `InTransaction` transitions.
pub struct Client<S: ConnectionState> {
    /// Connection configuration (read here for host, port, database and packet size).
    config: Config,
    /// Zero-sized marker recording the connection state type `S`.
    _state: PhantomData<S>,
    /// The underlying connection (present only when connected)
    connection: Option<ConnectionHandle>,
    /// Server version from LoginAck (raw u32 TDS version)
    server_version: Option<u32>,
    /// Current database from EnvChange
    current_database: Option<String>,
    /// Prepared statement cache for query optimization
    statement_cache: StatementCache,
    /// Transaction descriptor from BeginTransaction EnvChange.
    /// Per MS-TDS spec, this value must be included in ALL_HEADERS for subsequent
    /// requests within an explicit transaction. 0 indicates auto-commit mode.
    transaction_descriptor: u64,
    /// Whether this connection needs a reset on next use.
    /// Set by connection pool on checkin, cleared after first query/execute.
    /// When true, the RESETCONNECTION flag is set on the first TDS packet.
    needs_reset: bool,
    /// OpenTelemetry instrumentation context (when otel feature is enabled)
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
59
/// Internal connection handle wrapping the actual connection.
///
/// This is an enum to support different connection types:
/// - TLS (TDS 8.0 strict mode) - requires `tls` feature
/// - TLS with PreLogin wrapping (TDS 7.x style) - requires `tls` feature
/// - Plain TCP (for internal networks or when `tls` feature is disabled)
///
/// Every request path matches on all variants and forwards to the same
/// `Connection` method, so the variants differ only in the transport stream type.
// NOTE(review): this allow was added "until query execution is implemented",
// but `send_sql_batch`/`send_rpc` already use this enum; confirm whether it is still needed.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TLS connection (TDS 8.0 strict mode - TLS before any TDS traffic)
    #[cfg(feature = "tls")]
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS connection with PreLogin wrapping (TDS 7.x style)
    #[cfg(feature = "tls")]
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP connection (for internal networks or when `tls` feature is disabled)
    Plain(Connection<TcpStream>),
}
77
78// Private helper methods available to all connection states
79impl<S: ConnectionState> Client<S> {
80    /// Process transaction-related EnvChange tokens.
81    ///
82    /// This handles BeginTransaction, CommitTransaction, and RollbackTransaction
83    /// EnvChange tokens, updating the transaction descriptor accordingly.
84    ///
85    /// This enables executing BEGIN TRANSACTION, COMMIT, and ROLLBACK via raw SQL
86    /// while still having the transaction descriptor tracked correctly.
87    fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
88        use tds_protocol::token::EnvChangeValue;
89
90        match env.env_type {
91            EnvChangeType::BeginTransaction => {
92                if let EnvChangeValue::Binary(ref data) = env.new_value {
93                    if data.len() >= 8 {
94                        let descriptor = u64::from_le_bytes([
95                            data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
96                        ]);
97                        tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
98                        *transaction_descriptor = descriptor;
99                    }
100                }
101            }
102            EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
103                tracing::debug!(
104                    env_type = ?env.env_type,
105                    "transaction ended via raw SQL"
106                );
107                *transaction_descriptor = 0;
108            }
109            _ => {}
110        }
111    }
112
113    /// Send a SQL batch to the server.
114    ///
115    /// Uses the client's current transaction descriptor in ALL_HEADERS.
116    /// Per MS-TDS spec, when in an explicit transaction, the descriptor
117    /// returned by BeginTransaction must be included.
118    ///
119    /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
120    /// is included in the first packet to reset connection state.
121    async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
122        let payload =
123            tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
124        let max_packet = self.config.packet_size as usize;
125
126        // Check if we need to reset the connection on this request
127        let reset = self.needs_reset;
128        if reset {
129            self.needs_reset = false; // Clear flag before sending
130            tracing::debug!("sending SQL batch with RESETCONNECTION flag");
131        }
132
133        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
134
135        match connection {
136            #[cfg(feature = "tls")]
137            ConnectionHandle::Tls(conn) => {
138                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
139                    .await?;
140            }
141            #[cfg(feature = "tls")]
142            ConnectionHandle::TlsPrelogin(conn) => {
143                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
144                    .await?;
145            }
146            ConnectionHandle::Plain(conn) => {
147                conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
148                    .await?;
149            }
150        }
151
152        Ok(())
153    }
154
155    /// Send an RPC request to the server.
156    ///
157    /// Uses the client's current transaction descriptor in ALL_HEADERS.
158    ///
159    /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
160    /// is included in the first packet to reset connection state.
161    async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
162        let payload = rpc.encode_with_transaction(self.transaction_descriptor);
163        let max_packet = self.config.packet_size as usize;
164
165        // Check if we need to reset the connection on this request
166        let reset = self.needs_reset;
167        if reset {
168            self.needs_reset = false; // Clear flag before sending
169            tracing::debug!("sending RPC with RESETCONNECTION flag");
170        }
171
172        let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
173
174        match connection {
175            #[cfg(feature = "tls")]
176            ConnectionHandle::Tls(conn) => {
177                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
178                    .await?;
179            }
180            #[cfg(feature = "tls")]
181            ConnectionHandle::TlsPrelogin(conn) => {
182                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
183                    .await?;
184            }
185            ConnectionHandle::Plain(conn) => {
186                conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
187                    .await?;
188            }
189        }
190
191        Ok(())
192    }
193}
194
195impl Client<Ready> {
196    /// Mark this connection as needing a reset on next use.
197    ///
198    /// Called by the connection pool when a connection is returned.
199    /// The next SQL batch or RPC will include the RESETCONNECTION flag
200    /// in the TDS packet header, causing SQL Server to reset connection
201    /// state (temp tables, SET options, transaction isolation level, etc.)
202    /// before executing the command.
203    ///
204    /// This is more efficient than calling `sp_reset_connection` as a
205    /// separate command because it's handled at the TDS protocol level.
206    pub fn mark_needs_reset(&mut self) {
207        self.needs_reset = true;
208    }
209
210    /// Check if this connection needs a reset.
211    ///
212    /// Returns true if `mark_needs_reset()` was called and the reset
213    /// hasn't been performed yet.
214    #[must_use]
215    pub fn needs_reset(&self) -> bool {
216        self.needs_reset
217    }
218
219    /// Execute a query and return a streaming result set.
220    ///
221    /// Per ADR-007, results are streamed by default for memory efficiency.
222    /// Use `.collect_all()` on the stream if you need all rows in memory.
223    ///
224    /// # Example
225    ///
226    /// ```rust,ignore
227    /// use futures::StreamExt;
228    ///
229    /// // Streaming (memory-efficient)
230    /// let mut stream = client.query("SELECT * FROM users WHERE id = @p1", &[&1]).await?;
231    /// while let Some(row) = stream.next().await {
232    ///     let row = row?;
233    ///     process(&row);
234    /// }
235    ///
236    /// // Buffered (loads all into memory)
237    /// let rows: Vec<Row> = client
238    ///     .query("SELECT * FROM small_table", &[])
239    ///     .await?
240    ///     .collect_all()
241    ///     .await?;
242    /// ```
243    pub async fn query<'a>(
244        &'a mut self,
245        sql: &str,
246        params: &[&(dyn crate::ToSql + Sync)],
247    ) -> Result<QueryStream<'a>> {
248        tracing::debug!(sql = sql, params_count = params.len(), "executing query");
249
250        #[cfg(feature = "otel")]
251        let instrumentation = self.instrumentation.clone();
252        #[cfg(feature = "otel")]
253        let mut span = instrumentation.query_span(sql);
254
255        let result = async {
256            if params.is_empty() {
257                // Simple query without parameters - use SQL batch
258                self.send_sql_batch(sql).await?;
259            } else {
260                // Parameterized query - use sp_executesql via RPC
261                let rpc_params = Self::convert_params(params)?;
262                let rpc = RpcRequest::execute_sql(sql, rpc_params);
263                self.send_rpc(&rpc).await?;
264            }
265
266            // Read complete response including columns and rows
267            self.read_query_response().await
268        }
269        .await;
270
271        #[cfg(feature = "otel")]
272        match &result {
273            Ok(_) => InstrumentationContext::record_success(&mut span, None),
274            Err(e) => InstrumentationContext::record_error(&mut span, e),
275        }
276
277        // Drop the span before returning
278        #[cfg(feature = "otel")]
279        drop(span);
280
281        let (columns, rows) = result?;
282        Ok(QueryStream::new(columns, rows))
283    }
284
285    /// Execute a query with a specific timeout.
286    ///
287    /// This overrides the default `command_timeout` from the connection configuration
288    /// for this specific query. If the query does not complete within the specified
289    /// duration, an error is returned.
290    ///
291    /// # Arguments
292    ///
293    /// * `sql` - The SQL query to execute
294    /// * `params` - Query parameters
295    /// * `timeout_duration` - Maximum time to wait for the query to complete
296    ///
297    /// # Example
298    ///
299    /// ```rust,ignore
300    /// use std::time::Duration;
301    ///
302    /// // Execute with a 5-second timeout
303    /// let rows = client
304    ///     .query_with_timeout(
305    ///         "SELECT * FROM large_table",
306    ///         &[],
307    ///         Duration::from_secs(5),
308    ///     )
309    ///     .await?;
310    /// ```
311    pub async fn query_with_timeout<'a>(
312        &'a mut self,
313        sql: &str,
314        params: &[&(dyn crate::ToSql + Sync)],
315        timeout_duration: std::time::Duration,
316    ) -> Result<QueryStream<'a>> {
317        timeout(timeout_duration, self.query(sql, params))
318            .await
319            .map_err(|_| Error::CommandTimeout)?
320    }
321
322    /// Execute a batch that may return multiple result sets.
323    ///
324    /// This is useful for stored procedures or SQL batches that contain
325    /// multiple SELECT statements.
326    ///
327    /// # Example
328    ///
329    /// ```rust,ignore
330    /// // Execute a batch with multiple SELECTs
331    /// let mut results = client.query_multiple(
332    ///     "SELECT 1 AS a; SELECT 2 AS b, 3 AS c;",
333    ///     &[]
334    /// ).await?;
335    ///
336    /// // Process first result set
337    /// while let Some(row) = results.next_row().await? {
338    ///     println!("Result 1: {:?}", row);
339    /// }
340    ///
341    /// // Move to second result set
342    /// if results.next_result().await? {
343    ///     while let Some(row) = results.next_row().await? {
344    ///         println!("Result 2: {:?}", row);
345    ///     }
346    /// }
347    /// ```
348    pub async fn query_multiple<'a>(
349        &'a mut self,
350        sql: &str,
351        params: &[&(dyn crate::ToSql + Sync)],
352    ) -> Result<MultiResultStream<'a>> {
353        tracing::debug!(
354            sql = sql,
355            params_count = params.len(),
356            "executing multi-result query"
357        );
358
359        if params.is_empty() {
360            // Simple batch without parameters - use SQL batch
361            self.send_sql_batch(sql).await?;
362        } else {
363            // Parameterized query - use sp_executesql via RPC
364            let rpc_params = Self::convert_params(params)?;
365            let rpc = RpcRequest::execute_sql(sql, rpc_params);
366            self.send_rpc(&rpc).await?;
367        }
368
369        // Read all result sets
370        let result_sets = self.read_multi_result_response().await?;
371        Ok(MultiResultStream::new(result_sets))
372    }
373
374    /// Execute a query that doesn't return rows.
375    ///
376    /// Returns the number of affected rows.
377    pub async fn execute(
378        &mut self,
379        sql: &str,
380        params: &[&(dyn crate::ToSql + Sync)],
381    ) -> Result<u64> {
382        tracing::debug!(
383            sql = sql,
384            params_count = params.len(),
385            "executing statement"
386        );
387
388        #[cfg(feature = "otel")]
389        let instrumentation = self.instrumentation.clone();
390        #[cfg(feature = "otel")]
391        let mut span = instrumentation.query_span(sql);
392
393        let result = async {
394            if params.is_empty() {
395                // Simple statement without parameters - use SQL batch
396                self.send_sql_batch(sql).await?;
397            } else {
398                // Parameterized statement - use sp_executesql via RPC
399                let rpc_params = Self::convert_params(params)?;
400                let rpc = RpcRequest::execute_sql(sql, rpc_params);
401                self.send_rpc(&rpc).await?;
402            }
403
404            // Read response and get row count
405            self.read_execute_result().await
406        }
407        .await;
408
409        #[cfg(feature = "otel")]
410        match &result {
411            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
412            Err(e) => InstrumentationContext::record_error(&mut span, e),
413        }
414
415        // Drop the span before returning
416        #[cfg(feature = "otel")]
417        drop(span);
418
419        result
420    }
421
422    /// Execute a statement with a specific timeout.
423    ///
424    /// This overrides the default `command_timeout` from the connection configuration
425    /// for this specific statement. If the statement does not complete within the
426    /// specified duration, an error is returned.
427    ///
428    /// # Arguments
429    ///
430    /// * `sql` - The SQL statement to execute
431    /// * `params` - Statement parameters
432    /// * `timeout_duration` - Maximum time to wait for the statement to complete
433    ///
434    /// # Example
435    ///
436    /// ```rust,ignore
437    /// use std::time::Duration;
438    ///
439    /// // Execute with a 10-second timeout
440    /// let rows_affected = client
441    ///     .execute_with_timeout(
442    ///         "UPDATE large_table SET status = @p1",
443    ///         &[&"processed"],
444    ///         Duration::from_secs(10),
445    ///     )
446    ///     .await?;
447    /// ```
448    pub async fn execute_with_timeout(
449        &mut self,
450        sql: &str,
451        params: &[&(dyn crate::ToSql + Sync)],
452        timeout_duration: std::time::Duration,
453    ) -> Result<u64> {
454        timeout(timeout_duration, self.execute(sql, params))
455            .await
456            .map_err(|_| Error::CommandTimeout)?
457    }
458
459    /// Begin a transaction.
460    ///
461    /// This transitions the client from `Ready` to `InTransaction` state.
462    /// Per MS-TDS spec, the server returns a transaction descriptor in the
463    /// BeginTransaction EnvChange token that must be included in subsequent
464    /// ALL_HEADERS sections.
465    pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
466        tracing::debug!("beginning transaction");
467
468        #[cfg(feature = "otel")]
469        let instrumentation = self.instrumentation.clone();
470        #[cfg(feature = "otel")]
471        let mut span = instrumentation.transaction_span("BEGIN");
472
473        // Execute BEGIN TRANSACTION and extract the transaction descriptor
474        let result = async {
475            self.send_sql_batch("BEGIN TRANSACTION").await?;
476            self.read_transaction_begin_result().await
477        }
478        .await;
479
480        #[cfg(feature = "otel")]
481        match &result {
482            Ok(_) => InstrumentationContext::record_success(&mut span, None),
483            Err(e) => InstrumentationContext::record_error(&mut span, e),
484        }
485
486        // Drop the span before moving instrumentation
487        #[cfg(feature = "otel")]
488        drop(span);
489
490        let transaction_descriptor = result?;
491
492        Ok(Client {
493            config: self.config,
494            _state: PhantomData,
495            connection: self.connection,
496            server_version: self.server_version,
497            current_database: self.current_database,
498            statement_cache: self.statement_cache,
499            transaction_descriptor, // Store the descriptor from server
500            needs_reset: self.needs_reset,
501            #[cfg(feature = "otel")]
502            instrumentation: self.instrumentation,
503        })
504    }
505
506    /// Begin a transaction with a specific isolation level.
507    ///
508    /// This transitions the client from `Ready` to `InTransaction` state
509    /// with the specified isolation level.
510    ///
511    /// # Example
512    ///
513    /// ```rust,ignore
514    /// use mssql_client::IsolationLevel;
515    ///
516    /// let tx = client.begin_transaction_with_isolation(IsolationLevel::Serializable).await?;
517    /// // All operations in this transaction use SERIALIZABLE isolation
518    /// tx.commit().await?;
519    /// ```
520    pub async fn begin_transaction_with_isolation(
521        mut self,
522        isolation_level: crate::transaction::IsolationLevel,
523    ) -> Result<Client<InTransaction>> {
524        tracing::debug!(
525            isolation_level = %isolation_level.name(),
526            "beginning transaction with isolation level"
527        );
528
529        #[cfg(feature = "otel")]
530        let instrumentation = self.instrumentation.clone();
531        #[cfg(feature = "otel")]
532        let mut span = instrumentation.transaction_span("BEGIN");
533
534        // First set the isolation level
535        let result = async {
536            self.send_sql_batch(isolation_level.as_sql()).await?;
537            self.read_execute_result().await?;
538
539            // Then begin the transaction
540            self.send_sql_batch("BEGIN TRANSACTION").await?;
541            self.read_transaction_begin_result().await
542        }
543        .await;
544
545        #[cfg(feature = "otel")]
546        match &result {
547            Ok(_) => InstrumentationContext::record_success(&mut span, None),
548            Err(e) => InstrumentationContext::record_error(&mut span, e),
549        }
550
551        #[cfg(feature = "otel")]
552        drop(span);
553
554        let transaction_descriptor = result?;
555
556        Ok(Client {
557            config: self.config,
558            _state: PhantomData,
559            connection: self.connection,
560            server_version: self.server_version,
561            current_database: self.current_database,
562            statement_cache: self.statement_cache,
563            transaction_descriptor,
564            needs_reset: self.needs_reset,
565            #[cfg(feature = "otel")]
566            instrumentation: self.instrumentation,
567        })
568    }
569
570    /// Execute a simple query without parameters.
571    ///
572    /// This is useful for DDL statements and simple queries where you
573    /// don't need to retrieve the affected row count.
574    pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
575        tracing::debug!(sql = sql, "executing simple query");
576
577        // Send SQL batch
578        self.send_sql_batch(sql).await?;
579
580        // Read and discard response
581        let _ = self.read_execute_result().await?;
582
583        Ok(())
584    }
585
586    /// Close the connection gracefully.
587    pub async fn close(self) -> Result<()> {
588        tracing::debug!("closing connection");
589        Ok(())
590    }
591
592    /// Get the current database name.
593    #[must_use]
594    pub fn database(&self) -> Option<&str> {
595        self.config.database.as_deref()
596    }
597
598    /// Get the server host.
599    #[must_use]
600    pub fn host(&self) -> &str {
601        &self.config.host
602    }
603
604    /// Get the server port.
605    #[must_use]
606    pub fn port(&self) -> u16 {
607        self.config.port
608    }
609
610    /// Check if the connection is currently in a transaction.
611    ///
612    /// This returns `true` if a transaction was started via raw SQL
613    /// (`BEGIN TRANSACTION`) and has not yet been committed or rolled back.
614    ///
615    /// Note: This only tracks transactions started via raw SQL. Transactions
616    /// started via the type-state API (`begin_transaction()`) result in a
617    /// `Client<InTransaction>` which is a different type.
618    ///
619    /// # Example
620    ///
621    /// ```rust,ignore
622    /// client.execute("BEGIN TRANSACTION", &[]).await?;
623    /// assert!(client.is_in_transaction());
624    ///
625    /// client.execute("COMMIT", &[]).await?;
626    /// assert!(!client.is_in_transaction());
627    /// ```
628    #[must_use]
629    pub fn is_in_transaction(&self) -> bool {
630        self.transaction_descriptor != 0
631    }
632
633    /// Get a handle for cancelling the current query.
634    ///
635    /// The cancel handle can be cloned and sent to other tasks, enabling
636    /// cancellation of long-running queries from a separate async context.
637    ///
638    /// # Example
639    ///
640    /// ```rust,ignore
641    /// use std::time::Duration;
642    ///
643    /// let cancel_handle = client.cancel_handle();
644    ///
645    /// // Spawn a task to cancel after 10 seconds
646    /// let handle = tokio::spawn(async move {
647    ///     tokio::time::sleep(Duration::from_secs(10)).await;
648    ///     let _ = cancel_handle.cancel().await;
649    /// });
650    ///
651    /// // This query will be cancelled if it runs longer than 10 seconds
652    /// let result = client.query("SELECT * FROM very_large_table", &[]).await;
653    /// ```
654    #[must_use]
655    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
656        let connection = self
657            .connection
658            .as_ref()
659            .expect("connection should be present");
660        match connection {
661            #[cfg(feature = "tls")]
662            ConnectionHandle::Tls(conn) => {
663                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
664            }
665            #[cfg(feature = "tls")]
666            ConnectionHandle::TlsPrelogin(conn) => {
667                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
668            }
669            ConnectionHandle::Plain(conn) => {
670                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
671            }
672        }
673    }
674}
675
676/// # Drop Behavior
677///
678/// **`Client<InTransaction>` has no automatic rollback on drop.** If the client is
679/// dropped without calling [`commit()`](Client::commit) or [`rollback()`](Client::rollback),
680/// the transaction remains open on the server until the TCP connection closes
681/// (at which point SQL Server automatically rolls back).
682///
683/// This is because `Drop` is synchronous and cannot perform the async I/O needed
684/// to send a `ROLLBACK TRANSACTION` command.
685///
686/// ## Consequences of dropping without commit/rollback
687///
688/// - **Direct connections:** The transaction leaks until the OS TCP timeout
689///   (potentially 30+ minutes), holding locks on any modified rows.
690/// - **Pooled connections:** The pool detects the active transaction descriptor
691///   and discards the connection rather than returning it to the idle pool
692///   (see `PooledConnection::drop` in `mssql-driver-pool`).
693///
694/// ## Best practice
695///
696/// Always ensure `commit()` or `rollback()` is called. Use helper patterns
697/// for error paths:
698///
699/// ```rust,ignore
700/// let tx = client.begin_transaction().await?;
701/// match do_work(&tx).await {
702///     Ok(_) => { tx.commit().await?; }
703///     Err(e) => { tx.rollback().await?; return Err(e); }
704/// }
705/// ```
706impl Client<InTransaction> {
707    /// Execute a query within the transaction and return a streaming result set.
708    ///
709    /// See [`Client<Ready>::query`] for usage examples.
710    pub async fn query<'a>(
711        &'a mut self,
712        sql: &str,
713        params: &[&(dyn crate::ToSql + Sync)],
714    ) -> Result<QueryStream<'a>> {
715        tracing::debug!(
716            sql = sql,
717            params_count = params.len(),
718            "executing query in transaction"
719        );
720
721        #[cfg(feature = "otel")]
722        let instrumentation = self.instrumentation.clone();
723        #[cfg(feature = "otel")]
724        let mut span = instrumentation.query_span(sql);
725
726        let result = async {
727            if params.is_empty() {
728                // Simple query without parameters - use SQL batch
729                self.send_sql_batch(sql).await?;
730            } else {
731                // Parameterized query - use sp_executesql via RPC
732                let rpc_params = Self::convert_params(params)?;
733                let rpc = RpcRequest::execute_sql(sql, rpc_params);
734                self.send_rpc(&rpc).await?;
735            }
736
737            // Read complete response including columns and rows
738            self.read_query_response().await
739        }
740        .await;
741
742        #[cfg(feature = "otel")]
743        match &result {
744            Ok(_) => InstrumentationContext::record_success(&mut span, None),
745            Err(e) => InstrumentationContext::record_error(&mut span, e),
746        }
747
748        // Drop the span before returning
749        #[cfg(feature = "otel")]
750        drop(span);
751
752        let (columns, rows) = result?;
753        Ok(QueryStream::new(columns, rows))
754    }
755
756    /// Execute a statement within the transaction.
757    ///
758    /// Returns the number of affected rows.
759    pub async fn execute(
760        &mut self,
761        sql: &str,
762        params: &[&(dyn crate::ToSql + Sync)],
763    ) -> Result<u64> {
764        tracing::debug!(
765            sql = sql,
766            params_count = params.len(),
767            "executing statement in transaction"
768        );
769
770        #[cfg(feature = "otel")]
771        let instrumentation = self.instrumentation.clone();
772        #[cfg(feature = "otel")]
773        let mut span = instrumentation.query_span(sql);
774
775        let result = async {
776            if params.is_empty() {
777                // Simple statement without parameters - use SQL batch
778                self.send_sql_batch(sql).await?;
779            } else {
780                // Parameterized statement - use sp_executesql via RPC
781                let rpc_params = Self::convert_params(params)?;
782                let rpc = RpcRequest::execute_sql(sql, rpc_params);
783                self.send_rpc(&rpc).await?;
784            }
785
786            // Read response and get row count
787            self.read_execute_result().await
788        }
789        .await;
790
791        #[cfg(feature = "otel")]
792        match &result {
793            Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
794            Err(e) => InstrumentationContext::record_error(&mut span, e),
795        }
796
797        // Drop the span before returning
798        #[cfg(feature = "otel")]
799        drop(span);
800
801        result
802    }
803
804    /// Execute a query within the transaction with a specific timeout.
805    ///
806    /// See [`Client<Ready>::query_with_timeout`] for details.
807    pub async fn query_with_timeout<'a>(
808        &'a mut self,
809        sql: &str,
810        params: &[&(dyn crate::ToSql + Sync)],
811        timeout_duration: std::time::Duration,
812    ) -> Result<QueryStream<'a>> {
813        timeout(timeout_duration, self.query(sql, params))
814            .await
815            .map_err(|_| Error::CommandTimeout)?
816    }
817
818    /// Execute a statement within the transaction with a specific timeout.
819    ///
820    /// See [`Client<Ready>::execute_with_timeout`] for details.
821    pub async fn execute_with_timeout(
822        &mut self,
823        sql: &str,
824        params: &[&(dyn crate::ToSql + Sync)],
825        timeout_duration: std::time::Duration,
826    ) -> Result<u64> {
827        timeout(timeout_duration, self.execute(sql, params))
828            .await
829            .map_err(|_| Error::CommandTimeout)?
830    }
831
832    /// Commit the transaction.
833    ///
834    /// This transitions the client back to `Ready` state.
835    pub async fn commit(mut self) -> Result<Client<Ready>> {
836        tracing::debug!("committing transaction");
837
838        #[cfg(feature = "otel")]
839        let instrumentation = self.instrumentation.clone();
840        #[cfg(feature = "otel")]
841        let mut span = instrumentation.transaction_span("COMMIT");
842
843        // Execute COMMIT TRANSACTION
844        let result = async {
845            self.send_sql_batch("COMMIT TRANSACTION").await?;
846            self.read_execute_result().await
847        }
848        .await;
849
850        #[cfg(feature = "otel")]
851        match &result {
852            Ok(_) => InstrumentationContext::record_success(&mut span, None),
853            Err(e) => InstrumentationContext::record_error(&mut span, e),
854        }
855
856        // Drop the span before moving instrumentation
857        #[cfg(feature = "otel")]
858        drop(span);
859
860        result?;
861
862        Ok(Client {
863            config: self.config,
864            _state: PhantomData,
865            connection: self.connection,
866            server_version: self.server_version,
867            current_database: self.current_database,
868            statement_cache: self.statement_cache,
869            transaction_descriptor: 0, // Reset to auto-commit mode
870            needs_reset: self.needs_reset,
871            #[cfg(feature = "otel")]
872            instrumentation: self.instrumentation,
873        })
874    }
875
876    /// Rollback the transaction.
877    ///
878    /// This transitions the client back to `Ready` state.
879    pub async fn rollback(mut self) -> Result<Client<Ready>> {
880        tracing::debug!("rolling back transaction");
881
882        #[cfg(feature = "otel")]
883        let instrumentation = self.instrumentation.clone();
884        #[cfg(feature = "otel")]
885        let mut span = instrumentation.transaction_span("ROLLBACK");
886
887        // Execute ROLLBACK TRANSACTION
888        let result = async {
889            self.send_sql_batch("ROLLBACK TRANSACTION").await?;
890            self.read_execute_result().await
891        }
892        .await;
893
894        #[cfg(feature = "otel")]
895        match &result {
896            Ok(_) => InstrumentationContext::record_success(&mut span, None),
897            Err(e) => InstrumentationContext::record_error(&mut span, e),
898        }
899
900        // Drop the span before moving instrumentation
901        #[cfg(feature = "otel")]
902        drop(span);
903
904        result?;
905
906        Ok(Client {
907            config: self.config,
908            _state: PhantomData,
909            connection: self.connection,
910            server_version: self.server_version,
911            current_database: self.current_database,
912            statement_cache: self.statement_cache,
913            transaction_descriptor: 0, // Reset to auto-commit mode
914            needs_reset: self.needs_reset,
915            #[cfg(feature = "otel")]
916            instrumentation: self.instrumentation,
917        })
918    }
919
920    /// Create a savepoint and return a handle for later rollback.
921    ///
922    /// The returned `SavePoint` handle contains the validated savepoint name.
923    /// Use it with `rollback_to()` to partially undo transaction work.
924    ///
925    /// # Example
926    ///
927    /// ```rust,ignore
928    /// let tx = client.begin_transaction().await?;
929    /// tx.execute("INSERT INTO orders ...").await?;
930    /// let sp = tx.save_point("before_items").await?;
931    /// tx.execute("INSERT INTO items ...").await?;
932    /// // Oops, rollback just the items
933    /// tx.rollback_to(&sp).await?;
934    /// tx.commit().await?;
935    /// ```
936    pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
937        validate_identifier(name)?;
938        tracing::debug!(name = name, "creating savepoint");
939
940        // Execute SAVE TRANSACTION <name>
941        // Note: name is validated by validate_identifier() to prevent SQL injection
942        let sql = format!("SAVE TRANSACTION {name}");
943        self.send_sql_batch(&sql).await?;
944        self.read_execute_result().await?;
945
946        Ok(SavePoint::new(name.to_string()))
947    }
948
949    /// Rollback to a savepoint.
950    ///
951    /// This rolls back all changes made after the savepoint was created,
952    /// but keeps the transaction active. The savepoint remains valid and
953    /// can be rolled back to again.
954    ///
955    /// # Example
956    ///
957    /// ```rust,ignore
958    /// let sp = tx.save_point("checkpoint").await?;
959    /// // ... do some work ...
960    /// tx.rollback_to(&sp).await?;  // Undo changes since checkpoint
961    /// // Transaction is still active, savepoint is still valid
962    /// ```
963    pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
964        tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
965
966        // Execute ROLLBACK TRANSACTION <name>
967        // Note: savepoint name was validated during creation
968        let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
969        self.send_sql_batch(&sql).await?;
970        self.read_execute_result().await?;
971
972        Ok(())
973    }
974
975    /// Release a savepoint (optional cleanup).
976    ///
977    /// Note: SQL Server doesn't have explicit savepoint release, but this
978    /// method is provided for API completeness. The savepoint is automatically
979    /// released when the transaction commits or rolls back.
980    pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
981        tracing::debug!(name = savepoint.name(), "releasing savepoint");
982
983        // SQL Server doesn't require explicit savepoint release
984        // The savepoint is implicitly released on commit/rollback
985        // This method exists for API completeness
986        drop(savepoint);
987        Ok(())
988    }
989
990    /// Get a handle for cancelling the current query within the transaction.
991    ///
992    /// See [`Client<Ready>::cancel_handle`] for usage examples.
993    #[must_use]
994    pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
995        let connection = self
996            .connection
997            .as_ref()
998            .expect("connection should be present");
999        match connection {
1000            #[cfg(feature = "tls")]
1001            ConnectionHandle::Tls(conn) => {
1002                crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
1003            }
1004            #[cfg(feature = "tls")]
1005            ConnectionHandle::TlsPrelogin(conn) => {
1006                crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
1007            }
1008            ConnectionHandle::Plain(conn) => {
1009                crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
1010            }
1011        }
1012    }
1013}
1014
1015/// Validate an identifier (table name, savepoint name, etc.) to prevent SQL injection.
1016fn validate_identifier(name: &str) -> Result<()> {
1017    use once_cell::sync::Lazy;
1018    use regex::Regex;
1019
1020    static IDENTIFIER_RE: Lazy<Regex> =
1021        Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
1022
1023    if name.is_empty() {
1024        return Err(Error::InvalidIdentifier(
1025            "identifier cannot be empty".into(),
1026        ));
1027    }
1028
1029    if !IDENTIFIER_RE.is_match(name) {
1030        return Err(Error::InvalidIdentifier(format!(
1031            "invalid identifier '{name}': must start with letter/underscore, \
1032             contain only alphanumerics/_/@/#/$, and be 1-128 characters"
1033        )));
1034    }
1035
1036    Ok(())
1037}
1038
1039impl<S: ConnectionState> std::fmt::Debug for Client<S> {
1040    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041        f.debug_struct("Client")
1042            .field("host", &self.config.host)
1043            .field("port", &self.config.port)
1044            .field("database", &self.config.database)
1045            .finish()
1046    }
1047}
1048
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
        // Single-character identifier and the special characters allowed after the first.
        assert!(validate_identifier("_").is_ok());
        assert!(validate_identifier("a@b#c$d").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
        // `@`, `#`, `$` are only permitted after the first character.
        assert!(validate_identifier("@local").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$money").is_err());
        // Quotes, trailing newlines and non-ASCII letters must all be rejected.
        assert!(validate_identifier("a'b").is_err());
        assert!(validate_identifier("name\n").is_err());
        assert!(validate_identifier("tablé").is_err());
    }

    #[test]
    fn test_validate_identifier_length_limit() {
        let max = format!("a{}", "b".repeat(127));
        assert_eq!(max.len(), 128);
        assert!(validate_identifier(&max).is_ok());

        let too_long = format!("{max}b");
        assert!(validate_identifier(&too_long).is_err());
    }

    #[test]
    fn test_validate_identifier_error_variant() {
        assert!(matches!(
            validate_identifier(""),
            Err(Error::InvalidIdentifier(_))
        ));
        assert!(matches!(
            validate_identifier("1x"),
            Err(Error::InvalidIdentifier(_))
        ));
    }
}