mssql_client/client/mod.rs
1//! SQL Server client implementation.
2
3// Allow unwrap/expect for chrono date construction with known-valid constant dates
4// and for regex patterns that are compile-time constants
5#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7mod connect;
8mod params;
9mod response;
10
11use std::marker::PhantomData;
12
13use mssql_codec::connection::Connection;
14#[cfg(feature = "tls")]
15use mssql_tls::TlsStream;
16use tds_protocol::packet::PacketType;
17use tds_protocol::rpc::RpcRequest;
18use tds_protocol::token::{EnvChange, EnvChangeType};
19use tokio::net::TcpStream;
20use tokio::time::timeout;
21
22use crate::config::Config;
23use crate::error::{Error, Result};
24#[cfg(feature = "otel")]
25use crate::instrumentation::InstrumentationContext;
26use crate::state::{ConnectionState, InTransaction, Ready};
27use crate::statement_cache::StatementCache;
28use crate::stream::{MultiResultStream, QueryStream};
29use crate::transaction::SavePoint;
30
31/// SQL Server client with type-state connection management.
32///
33/// The generic parameter `S` represents the current connection state,
34/// ensuring at compile time that certain operations are only available
35/// in appropriate states.
36pub struct Client<S: ConnectionState> {
37 config: Config,
38 _state: PhantomData<S>,
39 /// The underlying connection (present only when connected)
40 connection: Option<ConnectionHandle>,
41 /// Server version from LoginAck (raw u32 TDS version)
42 server_version: Option<u32>,
43 /// Current database from EnvChange
44 current_database: Option<String>,
45 /// Prepared statement cache for query optimization
46 statement_cache: StatementCache,
47 /// Transaction descriptor from BeginTransaction EnvChange.
48 /// Per MS-TDS spec, this value must be included in ALL_HEADERS for subsequent
49 /// requests within an explicit transaction. 0 indicates auto-commit mode.
50 transaction_descriptor: u64,
51 /// Whether this connection needs a reset on next use.
52 /// Set by connection pool on checkin, cleared after first query/execute.
53 /// When true, the RESETCONNECTION flag is set on the first TDS packet.
54 needs_reset: bool,
55 /// OpenTelemetry instrumentation context (when otel feature is enabled)
56 #[cfg(feature = "otel")]
57 instrumentation: InstrumentationContext,
58}
59
60/// Internal connection handle wrapping the actual connection.
61///
62/// This is an enum to support different connection types:
63/// - TLS (TDS 8.0 strict mode) - requires `tls` feature
64/// - TLS with PreLogin wrapping (TDS 7.x style) - requires `tls` feature
65/// - Plain TCP (for internal networks or when `tls` feature is disabled)
66#[allow(dead_code)] // Connection will be used once query execution is implemented
67enum ConnectionHandle {
68 /// TLS connection (TDS 8.0 strict mode - TLS before any TDS traffic)
69 #[cfg(feature = "tls")]
70 Tls(Connection<TlsStream<TcpStream>>),
71 /// TLS connection with PreLogin wrapping (TDS 7.x style)
72 #[cfg(feature = "tls")]
73 TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
74 /// Plain TCP connection (for internal networks or when `tls` feature is disabled)
75 Plain(Connection<TcpStream>),
76}
77
78// Private helper methods available to all connection states
79impl<S: ConnectionState> Client<S> {
80 /// Process transaction-related EnvChange tokens.
81 ///
82 /// This handles BeginTransaction, CommitTransaction, and RollbackTransaction
83 /// EnvChange tokens, updating the transaction descriptor accordingly.
84 ///
85 /// This enables executing BEGIN TRANSACTION, COMMIT, and ROLLBACK via raw SQL
86 /// while still having the transaction descriptor tracked correctly.
87 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
88 use tds_protocol::token::EnvChangeValue;
89
90 match env.env_type {
91 EnvChangeType::BeginTransaction => {
92 if let EnvChangeValue::Binary(ref data) = env.new_value {
93 if data.len() >= 8 {
94 let descriptor = u64::from_le_bytes([
95 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
96 ]);
97 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
98 *transaction_descriptor = descriptor;
99 }
100 }
101 }
102 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
103 tracing::debug!(
104 env_type = ?env.env_type,
105 "transaction ended via raw SQL"
106 );
107 *transaction_descriptor = 0;
108 }
109 _ => {}
110 }
111 }
112
113 /// Send a SQL batch to the server.
114 ///
115 /// Uses the client's current transaction descriptor in ALL_HEADERS.
116 /// Per MS-TDS spec, when in an explicit transaction, the descriptor
117 /// returned by BeginTransaction must be included.
118 ///
119 /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
120 /// is included in the first packet to reset connection state.
121 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
122 let payload =
123 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
124 let max_packet = self.config.packet_size as usize;
125
126 // Check if we need to reset the connection on this request
127 let reset = self.needs_reset;
128 if reset {
129 self.needs_reset = false; // Clear flag before sending
130 tracing::debug!("sending SQL batch with RESETCONNECTION flag");
131 }
132
133 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
134
135 match connection {
136 #[cfg(feature = "tls")]
137 ConnectionHandle::Tls(conn) => {
138 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
139 .await?;
140 }
141 #[cfg(feature = "tls")]
142 ConnectionHandle::TlsPrelogin(conn) => {
143 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
144 .await?;
145 }
146 ConnectionHandle::Plain(conn) => {
147 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
148 .await?;
149 }
150 }
151
152 Ok(())
153 }
154
155 /// Send an RPC request to the server.
156 ///
157 /// Uses the client's current transaction descriptor in ALL_HEADERS.
158 ///
159 /// If `needs_reset` is set (from pool return), the RESETCONNECTION flag
160 /// is included in the first packet to reset connection state.
161 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
162 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
163 let max_packet = self.config.packet_size as usize;
164
165 // Check if we need to reset the connection on this request
166 let reset = self.needs_reset;
167 if reset {
168 self.needs_reset = false; // Clear flag before sending
169 tracing::debug!("sending RPC with RESETCONNECTION flag");
170 }
171
172 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
173
174 match connection {
175 #[cfg(feature = "tls")]
176 ConnectionHandle::Tls(conn) => {
177 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
178 .await?;
179 }
180 #[cfg(feature = "tls")]
181 ConnectionHandle::TlsPrelogin(conn) => {
182 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
183 .await?;
184 }
185 ConnectionHandle::Plain(conn) => {
186 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
187 .await?;
188 }
189 }
190
191 Ok(())
192 }
193}
194
195impl Client<Ready> {
196 /// Mark this connection as needing a reset on next use.
197 ///
198 /// Called by the connection pool when a connection is returned.
199 /// The next SQL batch or RPC will include the RESETCONNECTION flag
200 /// in the TDS packet header, causing SQL Server to reset connection
201 /// state (temp tables, SET options, transaction isolation level, etc.)
202 /// before executing the command.
203 ///
204 /// This is more efficient than calling `sp_reset_connection` as a
205 /// separate command because it's handled at the TDS protocol level.
206 pub fn mark_needs_reset(&mut self) {
207 self.needs_reset = true;
208 }
209
210 /// Check if this connection needs a reset.
211 ///
212 /// Returns true if `mark_needs_reset()` was called and the reset
213 /// hasn't been performed yet.
214 #[must_use]
215 pub fn needs_reset(&self) -> bool {
216 self.needs_reset
217 }
218
219 /// Execute a query and return a streaming result set.
220 ///
221 /// Per ADR-007, results are streamed by default for memory efficiency.
222 /// Use `.collect_all()` on the stream if you need all rows in memory.
223 ///
224 /// # Example
225 ///
226 /// ```rust,ignore
227 /// use futures::StreamExt;
228 ///
229 /// // Streaming (memory-efficient)
230 /// let mut stream = client.query("SELECT * FROM users WHERE id = @p1", &[&1]).await?;
231 /// while let Some(row) = stream.next().await {
232 /// let row = row?;
233 /// process(&row);
234 /// }
235 ///
236 /// // Buffered (loads all into memory)
237 /// let rows: Vec<Row> = client
238 /// .query("SELECT * FROM small_table", &[])
239 /// .await?
240 /// .collect_all()
241 /// .await?;
242 /// ```
243 pub async fn query<'a>(
244 &'a mut self,
245 sql: &str,
246 params: &[&(dyn crate::ToSql + Sync)],
247 ) -> Result<QueryStream<'a>> {
248 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
249
250 #[cfg(feature = "otel")]
251 let instrumentation = self.instrumentation.clone();
252 #[cfg(feature = "otel")]
253 let mut span = instrumentation.query_span(sql);
254
255 let result = async {
256 if params.is_empty() {
257 // Simple query without parameters - use SQL batch
258 self.send_sql_batch(sql).await?;
259 } else {
260 // Parameterized query - use sp_executesql via RPC
261 let rpc_params = Self::convert_params(params)?;
262 let rpc = RpcRequest::execute_sql(sql, rpc_params);
263 self.send_rpc(&rpc).await?;
264 }
265
266 // Read complete response including columns and rows
267 self.read_query_response().await
268 }
269 .await;
270
271 #[cfg(feature = "otel")]
272 match &result {
273 Ok(_) => InstrumentationContext::record_success(&mut span, None),
274 Err(e) => InstrumentationContext::record_error(&mut span, e),
275 }
276
277 // Drop the span before returning
278 #[cfg(feature = "otel")]
279 drop(span);
280
281 let (columns, rows) = result?;
282 Ok(QueryStream::new(columns, rows))
283 }
284
285 /// Execute a query with a specific timeout.
286 ///
287 /// This overrides the default `command_timeout` from the connection configuration
288 /// for this specific query. If the query does not complete within the specified
289 /// duration, an error is returned.
290 ///
291 /// # Arguments
292 ///
293 /// * `sql` - The SQL query to execute
294 /// * `params` - Query parameters
295 /// * `timeout_duration` - Maximum time to wait for the query to complete
296 ///
297 /// # Example
298 ///
299 /// ```rust,ignore
300 /// use std::time::Duration;
301 ///
302 /// // Execute with a 5-second timeout
303 /// let rows = client
304 /// .query_with_timeout(
305 /// "SELECT * FROM large_table",
306 /// &[],
307 /// Duration::from_secs(5),
308 /// )
309 /// .await?;
310 /// ```
311 pub async fn query_with_timeout<'a>(
312 &'a mut self,
313 sql: &str,
314 params: &[&(dyn crate::ToSql + Sync)],
315 timeout_duration: std::time::Duration,
316 ) -> Result<QueryStream<'a>> {
317 timeout(timeout_duration, self.query(sql, params))
318 .await
319 .map_err(|_| Error::CommandTimeout)?
320 }
321
322 /// Execute a batch that may return multiple result sets.
323 ///
324 /// This is useful for stored procedures or SQL batches that contain
325 /// multiple SELECT statements.
326 ///
327 /// # Example
328 ///
329 /// ```rust,ignore
330 /// // Execute a batch with multiple SELECTs
331 /// let mut results = client.query_multiple(
332 /// "SELECT 1 AS a; SELECT 2 AS b, 3 AS c;",
333 /// &[]
334 /// ).await?;
335 ///
336 /// // Process first result set
337 /// while let Some(row) = results.next_row().await? {
338 /// println!("Result 1: {:?}", row);
339 /// }
340 ///
341 /// // Move to second result set
342 /// if results.next_result().await? {
343 /// while let Some(row) = results.next_row().await? {
344 /// println!("Result 2: {:?}", row);
345 /// }
346 /// }
347 /// ```
348 pub async fn query_multiple<'a>(
349 &'a mut self,
350 sql: &str,
351 params: &[&(dyn crate::ToSql + Sync)],
352 ) -> Result<MultiResultStream<'a>> {
353 tracing::debug!(
354 sql = sql,
355 params_count = params.len(),
356 "executing multi-result query"
357 );
358
359 if params.is_empty() {
360 // Simple batch without parameters - use SQL batch
361 self.send_sql_batch(sql).await?;
362 } else {
363 // Parameterized query - use sp_executesql via RPC
364 let rpc_params = Self::convert_params(params)?;
365 let rpc = RpcRequest::execute_sql(sql, rpc_params);
366 self.send_rpc(&rpc).await?;
367 }
368
369 // Read all result sets
370 let result_sets = self.read_multi_result_response().await?;
371 Ok(MultiResultStream::new(result_sets))
372 }
373
374 /// Execute a query that doesn't return rows.
375 ///
376 /// Returns the number of affected rows.
377 pub async fn execute(
378 &mut self,
379 sql: &str,
380 params: &[&(dyn crate::ToSql + Sync)],
381 ) -> Result<u64> {
382 tracing::debug!(
383 sql = sql,
384 params_count = params.len(),
385 "executing statement"
386 );
387
388 #[cfg(feature = "otel")]
389 let instrumentation = self.instrumentation.clone();
390 #[cfg(feature = "otel")]
391 let mut span = instrumentation.query_span(sql);
392
393 let result = async {
394 if params.is_empty() {
395 // Simple statement without parameters - use SQL batch
396 self.send_sql_batch(sql).await?;
397 } else {
398 // Parameterized statement - use sp_executesql via RPC
399 let rpc_params = Self::convert_params(params)?;
400 let rpc = RpcRequest::execute_sql(sql, rpc_params);
401 self.send_rpc(&rpc).await?;
402 }
403
404 // Read response and get row count
405 self.read_execute_result().await
406 }
407 .await;
408
409 #[cfg(feature = "otel")]
410 match &result {
411 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
412 Err(e) => InstrumentationContext::record_error(&mut span, e),
413 }
414
415 // Drop the span before returning
416 #[cfg(feature = "otel")]
417 drop(span);
418
419 result
420 }
421
422 /// Execute a statement with a specific timeout.
423 ///
424 /// This overrides the default `command_timeout` from the connection configuration
425 /// for this specific statement. If the statement does not complete within the
426 /// specified duration, an error is returned.
427 ///
428 /// # Arguments
429 ///
430 /// * `sql` - The SQL statement to execute
431 /// * `params` - Statement parameters
432 /// * `timeout_duration` - Maximum time to wait for the statement to complete
433 ///
434 /// # Example
435 ///
436 /// ```rust,ignore
437 /// use std::time::Duration;
438 ///
439 /// // Execute with a 10-second timeout
440 /// let rows_affected = client
441 /// .execute_with_timeout(
442 /// "UPDATE large_table SET status = @p1",
443 /// &[&"processed"],
444 /// Duration::from_secs(10),
445 /// )
446 /// .await?;
447 /// ```
448 pub async fn execute_with_timeout(
449 &mut self,
450 sql: &str,
451 params: &[&(dyn crate::ToSql + Sync)],
452 timeout_duration: std::time::Duration,
453 ) -> Result<u64> {
454 timeout(timeout_duration, self.execute(sql, params))
455 .await
456 .map_err(|_| Error::CommandTimeout)?
457 }
458
459 /// Begin a transaction.
460 ///
461 /// This transitions the client from `Ready` to `InTransaction` state.
462 /// Per MS-TDS spec, the server returns a transaction descriptor in the
463 /// BeginTransaction EnvChange token that must be included in subsequent
464 /// ALL_HEADERS sections.
465 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
466 tracing::debug!("beginning transaction");
467
468 #[cfg(feature = "otel")]
469 let instrumentation = self.instrumentation.clone();
470 #[cfg(feature = "otel")]
471 let mut span = instrumentation.transaction_span("BEGIN");
472
473 // Execute BEGIN TRANSACTION and extract the transaction descriptor
474 let result = async {
475 self.send_sql_batch("BEGIN TRANSACTION").await?;
476 self.read_transaction_begin_result().await
477 }
478 .await;
479
480 #[cfg(feature = "otel")]
481 match &result {
482 Ok(_) => InstrumentationContext::record_success(&mut span, None),
483 Err(e) => InstrumentationContext::record_error(&mut span, e),
484 }
485
486 // Drop the span before moving instrumentation
487 #[cfg(feature = "otel")]
488 drop(span);
489
490 let transaction_descriptor = result?;
491
492 Ok(Client {
493 config: self.config,
494 _state: PhantomData,
495 connection: self.connection,
496 server_version: self.server_version,
497 current_database: self.current_database,
498 statement_cache: self.statement_cache,
499 transaction_descriptor, // Store the descriptor from server
500 needs_reset: self.needs_reset,
501 #[cfg(feature = "otel")]
502 instrumentation: self.instrumentation,
503 })
504 }
505
506 /// Begin a transaction with a specific isolation level.
507 ///
508 /// This transitions the client from `Ready` to `InTransaction` state
509 /// with the specified isolation level.
510 ///
511 /// # Example
512 ///
513 /// ```rust,ignore
514 /// use mssql_client::IsolationLevel;
515 ///
516 /// let tx = client.begin_transaction_with_isolation(IsolationLevel::Serializable).await?;
517 /// // All operations in this transaction use SERIALIZABLE isolation
518 /// tx.commit().await?;
519 /// ```
520 pub async fn begin_transaction_with_isolation(
521 mut self,
522 isolation_level: crate::transaction::IsolationLevel,
523 ) -> Result<Client<InTransaction>> {
524 tracing::debug!(
525 isolation_level = %isolation_level.name(),
526 "beginning transaction with isolation level"
527 );
528
529 #[cfg(feature = "otel")]
530 let instrumentation = self.instrumentation.clone();
531 #[cfg(feature = "otel")]
532 let mut span = instrumentation.transaction_span("BEGIN");
533
534 // First set the isolation level
535 let result = async {
536 self.send_sql_batch(isolation_level.as_sql()).await?;
537 self.read_execute_result().await?;
538
539 // Then begin the transaction
540 self.send_sql_batch("BEGIN TRANSACTION").await?;
541 self.read_transaction_begin_result().await
542 }
543 .await;
544
545 #[cfg(feature = "otel")]
546 match &result {
547 Ok(_) => InstrumentationContext::record_success(&mut span, None),
548 Err(e) => InstrumentationContext::record_error(&mut span, e),
549 }
550
551 #[cfg(feature = "otel")]
552 drop(span);
553
554 let transaction_descriptor = result?;
555
556 Ok(Client {
557 config: self.config,
558 _state: PhantomData,
559 connection: self.connection,
560 server_version: self.server_version,
561 current_database: self.current_database,
562 statement_cache: self.statement_cache,
563 transaction_descriptor,
564 needs_reset: self.needs_reset,
565 #[cfg(feature = "otel")]
566 instrumentation: self.instrumentation,
567 })
568 }
569
570 /// Execute a simple query without parameters.
571 ///
572 /// This is useful for DDL statements and simple queries where you
573 /// don't need to retrieve the affected row count.
574 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
575 tracing::debug!(sql = sql, "executing simple query");
576
577 // Send SQL batch
578 self.send_sql_batch(sql).await?;
579
580 // Read and discard response
581 let _ = self.read_execute_result().await?;
582
583 Ok(())
584 }
585
586 /// Close the connection gracefully.
587 pub async fn close(self) -> Result<()> {
588 tracing::debug!("closing connection");
589 Ok(())
590 }
591
592 /// Get the current database name.
593 #[must_use]
594 pub fn database(&self) -> Option<&str> {
595 self.config.database.as_deref()
596 }
597
598 /// Get the server host.
599 #[must_use]
600 pub fn host(&self) -> &str {
601 &self.config.host
602 }
603
604 /// Get the server port.
605 #[must_use]
606 pub fn port(&self) -> u16 {
607 self.config.port
608 }
609
610 /// Check if the connection is currently in a transaction.
611 ///
612 /// This returns `true` if a transaction was started via raw SQL
613 /// (`BEGIN TRANSACTION`) and has not yet been committed or rolled back.
614 ///
615 /// Note: This only tracks transactions started via raw SQL. Transactions
616 /// started via the type-state API (`begin_transaction()`) result in a
617 /// `Client<InTransaction>` which is a different type.
618 ///
619 /// # Example
620 ///
621 /// ```rust,ignore
622 /// client.execute("BEGIN TRANSACTION", &[]).await?;
623 /// assert!(client.is_in_transaction());
624 ///
625 /// client.execute("COMMIT", &[]).await?;
626 /// assert!(!client.is_in_transaction());
627 /// ```
628 #[must_use]
629 pub fn is_in_transaction(&self) -> bool {
630 self.transaction_descriptor != 0
631 }
632
633 /// Get a handle for cancelling the current query.
634 ///
635 /// The cancel handle can be cloned and sent to other tasks, enabling
636 /// cancellation of long-running queries from a separate async context.
637 ///
638 /// # Example
639 ///
640 /// ```rust,ignore
641 /// use std::time::Duration;
642 ///
643 /// let cancel_handle = client.cancel_handle();
644 ///
645 /// // Spawn a task to cancel after 10 seconds
646 /// let handle = tokio::spawn(async move {
647 /// tokio::time::sleep(Duration::from_secs(10)).await;
648 /// let _ = cancel_handle.cancel().await;
649 /// });
650 ///
651 /// // This query will be cancelled if it runs longer than 10 seconds
652 /// let result = client.query("SELECT * FROM very_large_table", &[]).await;
653 /// ```
654 #[must_use]
655 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
656 let connection = self
657 .connection
658 .as_ref()
659 .expect("connection should be present");
660 match connection {
661 #[cfg(feature = "tls")]
662 ConnectionHandle::Tls(conn) => {
663 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
664 }
665 #[cfg(feature = "tls")]
666 ConnectionHandle::TlsPrelogin(conn) => {
667 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
668 }
669 ConnectionHandle::Plain(conn) => {
670 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
671 }
672 }
673 }
674}
675
676/// # Drop Behavior
677///
678/// **`Client<InTransaction>` has no automatic rollback on drop.** If the client is
679/// dropped without calling [`commit()`](Client::commit) or [`rollback()`](Client::rollback),
680/// the transaction remains open on the server until the TCP connection closes
681/// (at which point SQL Server automatically rolls back).
682///
683/// This is because `Drop` is synchronous and cannot perform the async I/O needed
684/// to send a `ROLLBACK TRANSACTION` command.
685///
686/// ## Consequences of dropping without commit/rollback
687///
688/// - **Direct connections:** The transaction leaks until the OS TCP timeout
689/// (potentially 30+ minutes), holding locks on any modified rows.
690/// - **Pooled connections:** The pool detects the active transaction descriptor
691/// and discards the connection rather than returning it to the idle pool
692/// (see `PooledConnection::drop` in `mssql-driver-pool`).
693///
694/// ## Best practice
695///
696/// Always ensure `commit()` or `rollback()` is called. Use helper patterns
697/// for error paths:
698///
699/// ```rust,ignore
700/// let tx = client.begin_transaction().await?;
701/// match do_work(&tx).await {
702/// Ok(_) => { tx.commit().await?; }
703/// Err(e) => { tx.rollback().await?; return Err(e); }
704/// }
705/// ```
706impl Client<InTransaction> {
707 /// Execute a query within the transaction and return a streaming result set.
708 ///
709 /// See [`Client<Ready>::query`] for usage examples.
710 pub async fn query<'a>(
711 &'a mut self,
712 sql: &str,
713 params: &[&(dyn crate::ToSql + Sync)],
714 ) -> Result<QueryStream<'a>> {
715 tracing::debug!(
716 sql = sql,
717 params_count = params.len(),
718 "executing query in transaction"
719 );
720
721 #[cfg(feature = "otel")]
722 let instrumentation = self.instrumentation.clone();
723 #[cfg(feature = "otel")]
724 let mut span = instrumentation.query_span(sql);
725
726 let result = async {
727 if params.is_empty() {
728 // Simple query without parameters - use SQL batch
729 self.send_sql_batch(sql).await?;
730 } else {
731 // Parameterized query - use sp_executesql via RPC
732 let rpc_params = Self::convert_params(params)?;
733 let rpc = RpcRequest::execute_sql(sql, rpc_params);
734 self.send_rpc(&rpc).await?;
735 }
736
737 // Read complete response including columns and rows
738 self.read_query_response().await
739 }
740 .await;
741
742 #[cfg(feature = "otel")]
743 match &result {
744 Ok(_) => InstrumentationContext::record_success(&mut span, None),
745 Err(e) => InstrumentationContext::record_error(&mut span, e),
746 }
747
748 // Drop the span before returning
749 #[cfg(feature = "otel")]
750 drop(span);
751
752 let (columns, rows) = result?;
753 Ok(QueryStream::new(columns, rows))
754 }
755
756 /// Execute a statement within the transaction.
757 ///
758 /// Returns the number of affected rows.
759 pub async fn execute(
760 &mut self,
761 sql: &str,
762 params: &[&(dyn crate::ToSql + Sync)],
763 ) -> Result<u64> {
764 tracing::debug!(
765 sql = sql,
766 params_count = params.len(),
767 "executing statement in transaction"
768 );
769
770 #[cfg(feature = "otel")]
771 let instrumentation = self.instrumentation.clone();
772 #[cfg(feature = "otel")]
773 let mut span = instrumentation.query_span(sql);
774
775 let result = async {
776 if params.is_empty() {
777 // Simple statement without parameters - use SQL batch
778 self.send_sql_batch(sql).await?;
779 } else {
780 // Parameterized statement - use sp_executesql via RPC
781 let rpc_params = Self::convert_params(params)?;
782 let rpc = RpcRequest::execute_sql(sql, rpc_params);
783 self.send_rpc(&rpc).await?;
784 }
785
786 // Read response and get row count
787 self.read_execute_result().await
788 }
789 .await;
790
791 #[cfg(feature = "otel")]
792 match &result {
793 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
794 Err(e) => InstrumentationContext::record_error(&mut span, e),
795 }
796
797 // Drop the span before returning
798 #[cfg(feature = "otel")]
799 drop(span);
800
801 result
802 }
803
804 /// Execute a query within the transaction with a specific timeout.
805 ///
806 /// See [`Client<Ready>::query_with_timeout`] for details.
807 pub async fn query_with_timeout<'a>(
808 &'a mut self,
809 sql: &str,
810 params: &[&(dyn crate::ToSql + Sync)],
811 timeout_duration: std::time::Duration,
812 ) -> Result<QueryStream<'a>> {
813 timeout(timeout_duration, self.query(sql, params))
814 .await
815 .map_err(|_| Error::CommandTimeout)?
816 }
817
818 /// Execute a statement within the transaction with a specific timeout.
819 ///
820 /// See [`Client<Ready>::execute_with_timeout`] for details.
821 pub async fn execute_with_timeout(
822 &mut self,
823 sql: &str,
824 params: &[&(dyn crate::ToSql + Sync)],
825 timeout_duration: std::time::Duration,
826 ) -> Result<u64> {
827 timeout(timeout_duration, self.execute(sql, params))
828 .await
829 .map_err(|_| Error::CommandTimeout)?
830 }
831
832 /// Commit the transaction.
833 ///
834 /// This transitions the client back to `Ready` state.
835 pub async fn commit(mut self) -> Result<Client<Ready>> {
836 tracing::debug!("committing transaction");
837
838 #[cfg(feature = "otel")]
839 let instrumentation = self.instrumentation.clone();
840 #[cfg(feature = "otel")]
841 let mut span = instrumentation.transaction_span("COMMIT");
842
843 // Execute COMMIT TRANSACTION
844 let result = async {
845 self.send_sql_batch("COMMIT TRANSACTION").await?;
846 self.read_execute_result().await
847 }
848 .await;
849
850 #[cfg(feature = "otel")]
851 match &result {
852 Ok(_) => InstrumentationContext::record_success(&mut span, None),
853 Err(e) => InstrumentationContext::record_error(&mut span, e),
854 }
855
856 // Drop the span before moving instrumentation
857 #[cfg(feature = "otel")]
858 drop(span);
859
860 result?;
861
862 Ok(Client {
863 config: self.config,
864 _state: PhantomData,
865 connection: self.connection,
866 server_version: self.server_version,
867 current_database: self.current_database,
868 statement_cache: self.statement_cache,
869 transaction_descriptor: 0, // Reset to auto-commit mode
870 needs_reset: self.needs_reset,
871 #[cfg(feature = "otel")]
872 instrumentation: self.instrumentation,
873 })
874 }
875
876 /// Rollback the transaction.
877 ///
878 /// This transitions the client back to `Ready` state.
879 pub async fn rollback(mut self) -> Result<Client<Ready>> {
880 tracing::debug!("rolling back transaction");
881
882 #[cfg(feature = "otel")]
883 let instrumentation = self.instrumentation.clone();
884 #[cfg(feature = "otel")]
885 let mut span = instrumentation.transaction_span("ROLLBACK");
886
887 // Execute ROLLBACK TRANSACTION
888 let result = async {
889 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
890 self.read_execute_result().await
891 }
892 .await;
893
894 #[cfg(feature = "otel")]
895 match &result {
896 Ok(_) => InstrumentationContext::record_success(&mut span, None),
897 Err(e) => InstrumentationContext::record_error(&mut span, e),
898 }
899
900 // Drop the span before moving instrumentation
901 #[cfg(feature = "otel")]
902 drop(span);
903
904 result?;
905
906 Ok(Client {
907 config: self.config,
908 _state: PhantomData,
909 connection: self.connection,
910 server_version: self.server_version,
911 current_database: self.current_database,
912 statement_cache: self.statement_cache,
913 transaction_descriptor: 0, // Reset to auto-commit mode
914 needs_reset: self.needs_reset,
915 #[cfg(feature = "otel")]
916 instrumentation: self.instrumentation,
917 })
918 }
919
920 /// Create a savepoint and return a handle for later rollback.
921 ///
922 /// The returned `SavePoint` handle contains the validated savepoint name.
923 /// Use it with `rollback_to()` to partially undo transaction work.
924 ///
925 /// # Example
926 ///
927 /// ```rust,ignore
928 /// let tx = client.begin_transaction().await?;
929 /// tx.execute("INSERT INTO orders ...").await?;
930 /// let sp = tx.save_point("before_items").await?;
931 /// tx.execute("INSERT INTO items ...").await?;
932 /// // Oops, rollback just the items
933 /// tx.rollback_to(&sp).await?;
934 /// tx.commit().await?;
935 /// ```
936 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
937 validate_identifier(name)?;
938 tracing::debug!(name = name, "creating savepoint");
939
940 // Execute SAVE TRANSACTION <name>
941 // Note: name is validated by validate_identifier() to prevent SQL injection
942 let sql = format!("SAVE TRANSACTION {name}");
943 self.send_sql_batch(&sql).await?;
944 self.read_execute_result().await?;
945
946 Ok(SavePoint::new(name.to_string()))
947 }
948
949 /// Rollback to a savepoint.
950 ///
951 /// This rolls back all changes made after the savepoint was created,
952 /// but keeps the transaction active. The savepoint remains valid and
953 /// can be rolled back to again.
954 ///
955 /// # Example
956 ///
957 /// ```rust,ignore
958 /// let sp = tx.save_point("checkpoint").await?;
959 /// // ... do some work ...
960 /// tx.rollback_to(&sp).await?; // Undo changes since checkpoint
961 /// // Transaction is still active, savepoint is still valid
962 /// ```
963 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
964 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
965
966 // Execute ROLLBACK TRANSACTION <name>
967 // Note: savepoint name was validated during creation
968 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
969 self.send_sql_batch(&sql).await?;
970 self.read_execute_result().await?;
971
972 Ok(())
973 }
974
975 /// Release a savepoint (optional cleanup).
976 ///
977 /// Note: SQL Server doesn't have explicit savepoint release, but this
978 /// method is provided for API completeness. The savepoint is automatically
979 /// released when the transaction commits or rolls back.
980 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
981 tracing::debug!(name = savepoint.name(), "releasing savepoint");
982
983 // SQL Server doesn't require explicit savepoint release
984 // The savepoint is implicitly released on commit/rollback
985 // This method exists for API completeness
986 drop(savepoint);
987 Ok(())
988 }
989
990 /// Get a handle for cancelling the current query within the transaction.
991 ///
992 /// See [`Client<Ready>::cancel_handle`] for usage examples.
993 #[must_use]
994 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
995 let connection = self
996 .connection
997 .as_ref()
998 .expect("connection should be present");
999 match connection {
1000 #[cfg(feature = "tls")]
1001 ConnectionHandle::Tls(conn) => {
1002 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
1003 }
1004 #[cfg(feature = "tls")]
1005 ConnectionHandle::TlsPrelogin(conn) => {
1006 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
1007 }
1008 ConnectionHandle::Plain(conn) => {
1009 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
1010 }
1011 }
1012 }
1013}
1014
1015/// Validate an identifier (table name, savepoint name, etc.) to prevent SQL injection.
1016fn validate_identifier(name: &str) -> Result<()> {
1017 use once_cell::sync::Lazy;
1018 use regex::Regex;
1019
1020 static IDENTIFIER_RE: Lazy<Regex> =
1021 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
1022
1023 if name.is_empty() {
1024 return Err(Error::InvalidIdentifier(
1025 "identifier cannot be empty".into(),
1026 ));
1027 }
1028
1029 if !IDENTIFIER_RE.is_match(name) {
1030 return Err(Error::InvalidIdentifier(format!(
1031 "invalid identifier '{name}': must start with letter/underscore, \
1032 contain only alphanumerics/_/@/#/$, and be 1-128 characters"
1033 )));
1034 }
1035
1036 Ok(())
1037}
1038
1039impl<S: ConnectionState> std::fmt::Debug for Client<S> {
1040 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041 f.debug_struct("Client")
1042 .field("host", &self.config.host)
1043 .field("port", &self.config.port)
1044 .field("database", &self.config.database)
1045 .finish()
1046 }
1047}
1048
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
    }

    #[test]
    fn test_validate_identifier_allows_special_chars_after_first() {
        // `@`, `#`, and `$` are permitted anywhere except the first position.
        assert!(validate_identifier("a@b#c$d").is_ok());
        assert!(validate_identifier("_1").is_ok());
        assert!(validate_identifier("x").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
    }

    #[test]
    fn test_validate_identifier_error_variant() {
        assert!(matches!(
            validate_identifier(""),
            Err(Error::InvalidIdentifier(_))
        ));
        assert!(matches!(
            validate_identifier("1bad"),
            Err(Error::InvalidIdentifier(_))
        ));
    }

    #[test]
    fn test_validate_identifier_rejects_special_leading_chars() {
        assert!(validate_identifier("@var").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$x").is_err());
    }

    #[test]
    fn test_validate_identifier_length_limit() {
        assert!(validate_identifier(&"a".repeat(128)).is_ok());
        assert!(validate_identifier(&"a".repeat(129)).is_err());
    }

    #[test]
    fn test_validate_identifier_rejects_non_ascii_and_injection() {
        assert!(validate_identifier("tablé").is_err());
        assert!(validate_identifier("name\n").is_err());
        assert!(validate_identifier("a'b").is_err());
        assert!(validate_identifier("x]; DROP TABLE y; --").is_err());
    }
}