Skip to main content

mssql_client/client/
connect.rs

1//! Connection establishment for SQL Server.
2//!
3//! This module contains the `impl Client<Disconnected>` block, handling
4//! TCP connection, TLS negotiation, PreLogin exchange, and Login7 authentication.
5
6use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::MAX_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
30impl Client<Disconnected> {
31    /// Connect to SQL Server.
32    ///
33    /// This establishes a connection, performs TLS negotiation (if required),
34    /// and authenticates with the server.
35    ///
36    /// # Example
37    ///
38    /// ```rust,ignore
39    /// let client = Client::connect(config).await?;
40    /// ```
41    pub async fn connect(config: Config) -> Result<Client<Ready>> {
42        let retry = config.retry.clone();
43        let max_redirects = config.redirect.max_redirects;
44        let follow_redirects = config.redirect.follow_redirects;
45        // Overall timeout accounts for retries + redirects per attempt, capped at 5 min.
46        let per_attempt = config.timeouts.connect_timeout
47            + config.timeouts.tls_timeout
48            + config.timeouts.login_timeout;
49        let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
50        let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
51        let initial_host = config.host.clone();
52        let initial_port = config.port;
53
54        let result = timeout(overall, async {
55            let mut last_error: Option<Error> = None;
56
57            for retry_attempt in 0..=retry.max_retries {
58                if retry_attempt > 0 {
59                    let backoff = retry.backoff_for_attempt(retry_attempt);
60                    tracing::info!(
61                        retry_attempt,
62                        backoff_ms = backoff.as_millis() as u64,
63                        "retrying connection after transient error"
64                    );
65                    tokio::time::sleep(backoff).await;
66                }
67
68                // Each retry starts fresh with original host/port
69                let mut current_config = config.clone();
70                let mut redirect_count: u8 = 0;
71
72                let attempt_result = loop {
73                    redirect_count += 1;
74                    if redirect_count > max_redirects + 1 {
75                        break Err(Error::TooManyRedirects { max: max_redirects });
76                    }
77
78                    match Self::try_connect(&current_config).await {
79                        Ok(client) => break Ok(client),
80                        Err(Error::Routing { host, port }) => {
81                            if !follow_redirects {
82                                break Err(Error::Routing { host, port });
83                            }
84                            tracing::info!(
85                                host = %host,
86                                port = port,
87                                redirect = redirect_count,
88                                max_redirects = max_redirects,
89                                "following Azure SQL routing redirect"
90                            );
91                            current_config = current_config.with_host(&host).with_port(port);
92                            continue;
93                        }
94                        Err(e) => break Err(e),
95                    }
96                };
97
98                match attempt_result {
99                    Ok(client) => return Ok(client),
100                    Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
101                        tracing::warn!(
102                            retry_attempt,
103                            max_retries = retry.max_retries,
104                            error = %e,
105                            "transient connection error, will retry"
106                        );
107                        last_error = Some(attempt_result.unwrap_err());
108                    }
109                    Err(e) => return Err(e),
110                }
111            }
112
113            // All retries exhausted — return last error
114            Err(last_error.expect("at least one attempt was made"))
115        })
116        .await;
117
118        match result {
119            Ok(inner) => inner,
120            Err(_elapsed) => Err(Error::ConnectTimeout {
121                host: initial_host,
122                port: initial_port,
123            }),
124        }
125    }
126
127    async fn try_connect(config: &Config) -> Result<Client<Ready>> {
128        // If a named instance is specified, resolve the TCP port via SQL Browser
129        let port = if let Some(ref instance) = config.instance {
130            let resolved = crate::browser::resolve_instance(
131                &config.host,
132                instance,
133                Some(config.timeouts.connect_timeout),
134            )
135            .await?;
136            tracing::info!(
137                host = %config.host,
138                instance = %instance,
139                resolved_port = resolved,
140                database = ?config.database,
141                "connecting to named SQL Server instance"
142            );
143            resolved
144        } else {
145            tracing::info!(
146                host = %config.host,
147                port = config.port,
148                database = ?config.database,
149                "connecting to SQL Server"
150            );
151            config.port
152        };
153
154        // Normalize "." and "(local)" to localhost for TCP.
155        // These are standard ADO.NET aliases for the local machine.
156        let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
157            "127.0.0.1"
158        } else {
159            &config.host
160        };
161
162        // Step 1: Establish TCP connection
163        let tcp_stream = if config.multi_subnet_failover {
164            Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
165        } else {
166            let addr = format!("{host}:{port}");
167            tracing::debug!("establishing TCP connection to {}", addr);
168            let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
169                .await
170                .map_err(|_| Error::ConnectTimeout {
171                    host: config.host.clone(),
172                    port: config.port,
173                })?
174                .map_err(Error::from)?;
175            stream.set_nodelay(true).map_err(Error::from)?;
176            stream
177        };
178
179        #[cfg(feature = "tls")]
180        {
181            // Determine TLS negotiation mode
182            let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
183
184            // Step 2: Handle TDS 8.0 strict mode (TLS before any TDS traffic)
185            if tls_mode.is_tls_first() {
186                return Self::connect_tds_8(config, tcp_stream).await;
187            }
188
189            // Step 3: TDS 7.x flow - PreLogin first, then TLS, then Login7
190            Self::connect_tds_7x(config, tcp_stream).await
191        }
192
193        #[cfg(not(feature = "tls"))]
194        {
195            // When TLS feature is disabled, only no_tls connections are supported
196            if config.strict_mode {
197                return Err(Error::Config(
198                    "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
199                ));
200            }
201
202            if !config.no_tls {
203                return Err(Error::Config(
204                    "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
205                     or use Encrypt=no_tls in your connection string for unencrypted connections."
206                        .into(),
207                ));
208            }
209
210            // Proceed with no-TLS connection
211            Self::connect_no_tls(config, tcp_stream).await
212        }
213    }
214
215    /// Resolve hostname to all IPs and race parallel TCP connections.
216    ///
217    /// Used when `MultiSubnetFailover=True` for AlwaysOn AG listeners that
218    /// span multiple subnets. First successful TCP connection wins.
219    async fn connect_parallel(
220        host: &str,
221        port: u16,
222        connect_timeout: std::time::Duration,
223    ) -> Result<TcpStream> {
224        let addr_str = format!("{host}:{port}");
225        let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
226            .await
227            .map_err(Error::from)?
228            .collect();
229
230        if addrs.is_empty() {
231            return Err(Error::from(std::io::Error::new(
232                std::io::ErrorKind::AddrNotAvailable,
233                format!("no addresses resolved for {host}:{port}"),
234            )));
235        }
236
237        // Single address — no need to spawn tasks
238        if addrs.len() == 1 {
239            tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
240            let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
241                .await
242                .map_err(|_| Error::ConnectTimeout {
243                    host: host.to_string(),
244                    port,
245                })?
246                .map_err(Error::from)?;
247            stream.set_nodelay(true).map_err(Error::from)?;
248            return Ok(stream);
249        }
250
251        let addr_count = addrs.len();
252        tracing::debug!(
253            host = host,
254            port = port,
255            resolved_count = addr_count,
256            "MultiSubnetFailover: racing parallel connections",
257        );
258
259        let mut join_set = tokio::task::JoinSet::new();
260
261        for addr in addrs {
262            let dur = connect_timeout;
263            join_set.spawn(async move {
264                let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
265                    std::io::Error::new(
266                        std::io::ErrorKind::TimedOut,
267                        format!("connection to {addr} timed out"),
268                    )
269                })??;
270                tcp.set_nodelay(true)?;
271                Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
272            });
273        }
274
275        let mut last_error: Option<std::io::Error> = None;
276
277        while let Some(result) = join_set.join_next().await {
278            match result {
279                Ok(Ok((stream, addr))) => {
280                    tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
281                    join_set.abort_all();
282                    return Ok(stream);
283                }
284                Ok(Err(e)) => {
285                    tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
286                    last_error = Some(e);
287                }
288                Err(join_err) => {
289                    tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
290                    last_error = Some(std::io::Error::other(join_err.to_string()));
291                }
292            }
293        }
294
295        // All connections failed
296        Err(Error::from(last_error.unwrap_or_else(|| {
297            std::io::Error::new(
298                std::io::ErrorKind::ConnectionRefused,
299                format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
300            )
301        })))
302    }
303
304    /// Connect using TDS 8.0 strict mode.
305    ///
306    /// Flow: TCP -> TLS -> PreLogin (encrypted) -> Login7 (encrypted)
307    #[cfg(feature = "tls")]
308    async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
309        tracing::debug!("using TDS 8.0 strict mode (TLS first)");
310
311        // Build TLS configuration with TDS 8.0 ALPN protocol
312        let tls_config = TlsConfig::new()
313            .strict_mode(true)
314            .trust_server_certificate(config.trust_server_certificate)
315            .with_alpn_protocols(vec![b"tds/8.0".to_vec()]);
316
317        let tls_connector = TlsConnector::new(tls_config)?;
318
319        // Perform TLS handshake before any TDS traffic
320        let tls_stream = timeout(
321            config.timeouts.tls_timeout,
322            tls_connector.connect(tcp_stream, &config.host),
323        )
324        .await
325        .map_err(|_| Error::TlsTimeout {
326            host: config.host.clone(),
327            port: config.port,
328        })??;
329
330        tracing::debug!("TLS handshake completed (strict mode)");
331
332        // Create connection wrapper
333        let mut connection = Connection::new(tls_stream);
334
335        // Send PreLogin (encrypted in strict mode)
336        let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
337        Self::send_prelogin(&mut connection, &prelogin).await?;
338        let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
339
340        // Create SSPI negotiator if integrated auth
341        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
342        let negotiator = Self::create_negotiator(config)?;
343        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
344        let sspi_token = match negotiator {
345            Some(ref neg) => Some(neg.initialize()?),
346            None => None,
347        };
348        #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
349        let sspi_token: Option<Vec<u8>> = None;
350
351        // Send Login7
352        let login = Self::build_login7(config, sspi_token);
353        Self::send_login7(&mut connection, &login).await?;
354
355        // Process login response (with timeout to prevent hangs during redirect)
356        let (server_version, current_database, routing, server_collation) = timeout(
357            config.timeouts.login_timeout,
358            Self::process_login_response(
359                &mut connection,
360                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
361                negotiator.as_deref(),
362            ),
363        )
364        .await
365        .map_err(|_| Error::LoginTimeout {
366            host: config.host.clone(),
367            port: config.port,
368        })??;
369
370        // Handle routing redirect
371        if let Some((host, port)) = routing {
372            return Err(Error::Routing { host, port });
373        }
374
375        Ok(Client {
376            config: config.clone(),
377            _state: PhantomData,
378            connection: Some(ConnectionHandle::Tls(connection)),
379            server_version,
380            current_database: current_database.clone(),
381            server_collation,
382            statement_cache: StatementCache::with_default_size(),
383            transaction_descriptor: 0, // Auto-commit mode initially
384            needs_reset: false,        // Fresh connection, no reset needed
385            in_flight: false,          // No request pending
386            #[cfg(feature = "otel")]
387            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
388                .with_database(current_database.clone().unwrap_or_default()),
389            #[cfg(feature = "always-encrypted")]
390            encryption_context: config.column_encryption.clone().map(|cfg| {
391                std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
392            }),
393        })
394    }
395
396    /// Connect using TDS 7.x flow.
397    ///
398    /// Flow: TCP -> PreLogin (clear) -> TLS -> Login7 (encrypted)
399    ///
400    /// Note: For TDS 7.x, the PreLogin exchange happens over raw TCP before
401    /// upgrading to TLS. We use low-level I/O for this initial exchange
402    /// since the Connection struct splits the stream immediately.
403    #[cfg(feature = "tls")]
404    async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
405        use bytes::BufMut;
406        use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
407        use tokio::io::{AsyncReadExt, AsyncWriteExt};
408
409        tracing::debug!("using TDS 7.x flow (PreLogin first)");
410
411        // Build PreLogin packet
412        // Determine client encryption level based on configuration
413        let client_encryption = if config.no_tls {
414            // no_tls: Completely disable TLS
415            tracing::warn!(
416                "⚠️  no_tls mode enabled. Connection will be UNENCRYPTED. \
417                 Credentials and data will be transmitted in plaintext. \
418                 This should only be used for development/testing with legacy SQL Server."
419            );
420            EncryptionLevel::NotSupported
421        } else if config.encrypt {
422            EncryptionLevel::On
423        } else {
424            EncryptionLevel::Off
425        };
426        let prelogin = Self::build_prelogin(config, client_encryption);
427        tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
428        let prelogin_bytes = prelogin.encode();
429
430        // Manually create and send the PreLogin packet over raw TCP
431        let header = PacketHeader::new(
432            PacketType::PreLogin,
433            PacketStatus::END_OF_MESSAGE,
434            (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
435        );
436
437        let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
438        header.encode(&mut packet_buf);
439        packet_buf.put_slice(&prelogin_bytes);
440
441        tcp_stream
442            .write_all(&packet_buf)
443            .await
444            .map_err(Error::from)?;
445
446        // Read PreLogin response
447        let mut header_buf = [0u8; PACKET_HEADER_SIZE];
448        tcp_stream
449            .read_exact(&mut header_buf)
450            .await
451            .map_err(Error::from)?;
452
453        let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
454        let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
455
456        let mut response_buf = vec![0u8; payload_length];
457        tcp_stream
458            .read_exact(&mut response_buf)
459            .await
460            .map_err(Error::from)?;
461
462        let prelogin_response = PreLogin::decode(&response_buf[..])?;
463
464        // Log PreLogin response
465        // Note: The server sends its SQL Server product version in PreLogin,
466        // NOT the TDS protocol version. The actual TDS version is negotiated
467        // in the LOGINACK token after login.
468        let client_tds_version = config.tds_version;
469        if let Some(ref server_version) = prelogin_response.server_version {
470            tracing::debug!(
471                requested_tds_version = %client_tds_version,
472                server_product_version = %server_version,
473                server_product = server_version.product_name(),
474                max_tds_version = %server_version.max_tds_version(),
475                "PreLogin response received"
476            );
477
478            // Warn if the server's max TDS version is lower than requested
479            let server_max_tds = server_version.max_tds_version();
480            if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
481                tracing::warn!(
482                    requested_tds_version = %client_tds_version,
483                    server_max_tds_version = %server_max_tds,
484                    server_product = server_version.product_name(),
485                    "Server supports lower TDS version than requested. \
486                     Connection will use server's maximum: {}",
487                    server_max_tds
488                );
489            }
490
491            // Warn about legacy SQL Server versions (2005 and earlier)
492            if server_max_tds.is_legacy() {
493                tracing::warn!(
494                    server_product = server_version.product_name(),
495                    server_max_tds_version = %server_max_tds,
496                    "Server uses legacy TDS version. Some features may not be available."
497                );
498            }
499        } else {
500            tracing::debug!(
501                requested_tds_version = %client_tds_version,
502                "PreLogin response received (no version info)"
503            );
504        }
505
506        // Check server encryption response
507        let server_encryption = prelogin_response.encryption;
508        tracing::debug!(encryption = ?server_encryption, "server encryption level");
509
510        // Determine negotiated encryption level (follows TDS 7.x rules)
511        // - NotSupported + NotSupported = NotSupported (no TLS at all)
512        // - Off + Off = Off (TLS for login only, then plain)
513        // - On + anything supported = On (full TLS)
514        // - Required = On with failure if not possible
515        let negotiated_encryption = match (client_encryption, server_encryption) {
516            (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
517                EncryptionLevel::NotSupported
518            }
519            (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
520            (EncryptionLevel::On, EncryptionLevel::Off)
521            | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
522                return Err(Error::Protocol(
523                    "Server does not support requested encryption level".to_string(),
524                ));
525            }
526            _ => EncryptionLevel::On,
527        };
528
529        // TLS is required unless negotiated encryption is NotSupported
530        // Even with "Off", TLS is used to protect login credentials (per TDS 7.x spec)
531        let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
532
533        if use_tls {
534            // Upgrade to TLS with PreLogin wrapping (TDS 7.x style)
535            // In TDS 7.x, the TLS handshake is wrapped inside TDS PreLogin packets
536            let tls_config =
537                TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
538
539            let tls_connector = TlsConnector::new(tls_config)?;
540
541            // Use PreLogin-wrapped TLS connection for TDS 7.x
542            let mut tls_stream = timeout(
543                config.timeouts.tls_timeout,
544                tls_connector.connect_with_prelogin(tcp_stream, &config.host),
545            )
546            .await
547            .map_err(|_| Error::TlsTimeout {
548                host: config.host.clone(),
549                port: config.port,
550            })??;
551
552            tracing::debug!("TLS handshake completed (PreLogin wrapped)");
553
554            // Check if we need full encryption or login-only encryption
555            let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
556
557            if login_only_encryption {
558                // Login-Only Encryption (ENCRYPT_OFF + ENCRYPT_OFF per MS-TDS spec):
559                // - Login7 is sent through TLS to protect credentials
560                // - Server responds in PLAINTEXT after receiving Login7
561                // - All subsequent communication is plaintext
562                //
563                // We must NOT use Connection with TLS stream because Connection splits
564                // the stream and we need to extract the underlying TCP afterward.
565                use tokio::io::AsyncWriteExt;
566
567                // Create SSPI negotiator if integrated auth
568                // Note: SSPI handshake over login-only encryption is limited —
569                // the server response comes in plaintext, so multi-step SSPI
570                // may not work. We include the initial token but don't loop.
571                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
572                let negotiator = Self::create_negotiator(config)?;
573                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
574                let sspi_token = match negotiator {
575                    Some(ref neg) => Some(neg.initialize()?),
576                    None => None,
577                };
578                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
579                let sspi_token: Option<Vec<u8>> = None;
580
581                // Build and send Login7 directly through TLS
582                let login = Self::build_login7(config, sspi_token);
583                let login_payload = login.encode();
584
585                // Create TDS packet manually for Login7
586                let max_packet = MAX_PACKET_SIZE;
587                let max_payload = max_packet - PACKET_HEADER_SIZE;
588                let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
589                let total_chunks = chunks.len();
590
591                for (i, chunk) in chunks.into_iter().enumerate() {
592                    let is_last = i == total_chunks - 1;
593                    let status = if is_last {
594                        PacketStatus::END_OF_MESSAGE
595                    } else {
596                        PacketStatus::NORMAL
597                    };
598
599                    let header = PacketHeader::new(
600                        PacketType::Tds7Login,
601                        status,
602                        (PACKET_HEADER_SIZE + chunk.len()) as u16,
603                    );
604
605                    let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
606                    header.encode(&mut packet_buf);
607                    packet_buf.put_slice(chunk);
608
609                    tls_stream
610                        .write_all(&packet_buf)
611                        .await
612                        .map_err(Error::from)?;
613                }
614
615                // Flush TLS to ensure all data is sent
616                tls_stream.flush().await.map_err(Error::from)?;
617
618                tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
619
620                // Extract the underlying TCP stream from the TLS layer
621                // TlsStream::into_inner() returns (IO, ClientConnection)
622                // where IO is our TlsPreloginWrapper<TcpStream>
623                let (wrapper, _client_conn) = tls_stream.into_inner();
624                let tcp_stream = wrapper.into_inner();
625
626                // Create Connection from plain TCP for reading response
627                let mut connection = Connection::new(tcp_stream);
628
629                // Process login response (comes in plaintext, with timeout)
630                let (server_version, current_database, routing, server_collation) = timeout(
631                    config.timeouts.login_timeout,
632                    Self::process_login_response(
633                        &mut connection,
634                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
635                        negotiator.as_deref(),
636                    ),
637                )
638                .await
639                .map_err(|_| Error::LoginTimeout {
640                    host: config.host.clone(),
641                    port: config.port,
642                })??;
643
644                // Handle routing redirect
645                if let Some((host, port)) = routing {
646                    return Err(Error::Routing { host, port });
647                }
648
649                // Store plain TCP connection for subsequent operations
650                Ok(Client {
651                    config: config.clone(),
652                    _state: PhantomData,
653                    connection: Some(ConnectionHandle::Plain(connection)),
654                    server_version,
655                    current_database: current_database.clone(),
656                    server_collation,
657                    statement_cache: StatementCache::with_default_size(),
658                    transaction_descriptor: 0, // Auto-commit mode initially
659                    needs_reset: false,        // Fresh connection, no reset needed
660                    in_flight: false,          // No request pending
661                    #[cfg(feature = "otel")]
662                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
663                        .with_database(current_database.clone().unwrap_or_default()),
664                    #[cfg(feature = "always-encrypted")]
665                    encryption_context: config.column_encryption.clone().map(|cfg| {
666                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
667                    }),
668                })
669            } else {
670                // Full Encryption (ENCRYPT_ON per MS-TDS spec):
671                // - All communication after TLS handshake goes through TLS
672                let mut connection = Connection::new(tls_stream);
673
674                // Create SSPI negotiator if integrated auth
675                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
676                let negotiator = Self::create_negotiator(config)?;
677                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
678                let sspi_token = match negotiator {
679                    Some(ref neg) => Some(neg.initialize()?),
680                    None => None,
681                };
682                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
683                let sspi_token: Option<Vec<u8>> = None;
684
685                // Send Login7
686                let login = Self::build_login7(config, sspi_token);
687                Self::send_login7(&mut connection, &login).await?;
688
689                // Process login response (with timeout)
690                let (server_version, current_database, routing, server_collation) = timeout(
691                    config.timeouts.login_timeout,
692                    Self::process_login_response(
693                        &mut connection,
694                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
695                        negotiator.as_deref(),
696                    ),
697                )
698                .await
699                .map_err(|_| Error::LoginTimeout {
700                    host: config.host.clone(),
701                    port: config.port,
702                })??;
703
704                // Handle routing redirect
705                if let Some((host, port)) = routing {
706                    return Err(Error::Routing { host, port });
707                }
708
709                Ok(Client {
710                    config: config.clone(),
711                    _state: PhantomData,
712                    connection: Some(ConnectionHandle::TlsPrelogin(connection)),
713                    server_version,
714                    current_database: current_database.clone(),
715                    server_collation,
716                    statement_cache: StatementCache::with_default_size(),
717                    transaction_descriptor: 0, // Auto-commit mode initially
718                    needs_reset: false,        // Fresh connection, no reset needed
719                    in_flight: false,          // No request pending
720                    #[cfg(feature = "otel")]
721                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
722                        .with_database(current_database.clone().unwrap_or_default()),
723                    #[cfg(feature = "always-encrypted")]
724                    encryption_context: config.column_encryption.clone().map(|cfg| {
725                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
726                    }),
727                })
728            }
729        } else {
730            // Server does not require encryption and client doesn't either
731            tracing::warn!(
732                "Connecting without TLS encryption. This is insecure and should only be \
733                 used for development/testing on trusted networks."
734            );
735
736            let mut connection = Connection::new(tcp_stream);
737
738            // Create SSPI negotiator if integrated auth
739            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
740            let negotiator = Self::create_negotiator(config)?;
741            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
742            let sspi_token = match negotiator {
743                Some(ref neg) => Some(neg.initialize()?),
744                None => None,
745            };
746            #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
747            let sspi_token: Option<Vec<u8>> = None;
748
749            // Build and send Login7
750            let login = Self::build_login7(config, sspi_token);
751            Self::send_login7(&mut connection, &login).await?;
752
753            // Process login response (with timeout)
754            let (server_version, current_database, routing, server_collation) = timeout(
755                config.timeouts.login_timeout,
756                Self::process_login_response(
757                    &mut connection,
758                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
759                    negotiator.as_deref(),
760                ),
761            )
762            .await
763            .map_err(|_| Error::LoginTimeout {
764                host: config.host.clone(),
765                port: config.port,
766            })??;
767
768            // Handle routing redirect
769            if let Some((host, port)) = routing {
770                return Err(Error::Routing { host, port });
771            }
772
773            Ok(Client {
774                config: config.clone(),
775                _state: PhantomData,
776                connection: Some(ConnectionHandle::Plain(connection)),
777                server_version,
778                current_database: current_database.clone(),
779                server_collation,
780                statement_cache: StatementCache::with_default_size(),
781                transaction_descriptor: 0, // Auto-commit mode initially
782                needs_reset: false,        // Fresh connection, no reset needed
783                in_flight: false,          // No request pending
784                #[cfg(feature = "otel")]
785                instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
786                    .with_database(current_database.clone().unwrap_or_default()),
787                #[cfg(feature = "always-encrypted")]
788                encryption_context: config.column_encryption.clone().map(|cfg| {
789                    std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
790                }),
791            })
792        }
793    }
794
795    /// Connect without TLS encryption (no_tls mode).
796    ///
797    /// This method is used when the `tls` feature is disabled and only supports
798    /// unencrypted connections via `Encrypt=no_tls`.
799    ///
800    /// # Security Warning
801    ///
802    /// This transmits all data including credentials in plaintext. Only use this
803    /// for development, testing, or on trusted internal networks where TLS is not
804    /// required.
805    #[cfg(not(feature = "tls"))]
806    async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
807        use bytes::BufMut;
808        use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
809        use tokio::io::{AsyncReadExt, AsyncWriteExt};
810
811        tracing::warn!(
812            "⚠️  Connecting without TLS (tls feature disabled). \
813             Credentials and data will be transmitted in plaintext."
814        );
815
816        // Build PreLogin packet with NotSupported encryption
817        let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
818        let prelogin_bytes = prelogin.encode();
819
820        // Manually create and send the PreLogin packet over raw TCP
821        let header = PacketHeader::new(
822            PacketType::PreLogin,
823            PacketStatus::END_OF_MESSAGE,
824            (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
825        );
826
827        let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
828        header.encode(&mut packet_buf);
829        packet_buf.put_slice(&prelogin_bytes);
830
831        tcp_stream
832            .write_all(&packet_buf)
833            .await
834            .map_err(Error::from)?;
835
836        // Read PreLogin response
837        let mut header_buf = [0u8; PACKET_HEADER_SIZE];
838        tcp_stream
839            .read_exact(&mut header_buf)
840            .await
841            .map_err(Error::from)?;
842
843        let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
844        let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
845
846        let mut response_buf = vec![0u8; payload_length];
847        tcp_stream
848            .read_exact(&mut response_buf)
849            .await
850            .map_err(Error::from)?;
851
852        let prelogin_response = PreLogin::decode(&response_buf[..])?;
853
854        // Check server encryption response - must accept NotSupported
855        let server_encryption = prelogin_response.encryption;
856        if server_encryption != EncryptionLevel::NotSupported {
857            return Err(Error::Config(format!(
858                "Server requires encryption (level: {:?}) but TLS feature is disabled. \
859                     Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
860                server_encryption
861            )));
862        }
863
864        tracing::debug!("Server accepted unencrypted connection");
865
866        let mut connection = Connection::new(tcp_stream);
867
868        // Create SSPI negotiator if integrated auth
869        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
870        let negotiator = Self::create_negotiator(config)?;
871        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
872        let sspi_token = match negotiator {
873            Some(ref neg) => Some(neg.initialize()?),
874            None => None,
875        };
876        #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
877        let sspi_token: Option<Vec<u8>> = None;
878
879        // Build and send Login7
880        let login = Self::build_login7(config, sspi_token);
881        Self::send_login7(&mut connection, &login).await?;
882
883        // Process login response (with timeout)
884        let (server_version, current_database, routing, server_collation) = timeout(
885            config.timeouts.login_timeout,
886            Self::process_login_response(
887                &mut connection,
888                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
889                negotiator.as_deref(),
890            ),
891        )
892        .await
893        .map_err(|_| Error::LoginTimeout {
894            host: config.host.clone(),
895            port: config.port,
896        })??;
897
898        // Handle routing redirect
899        if let Some((host, port)) = routing {
900            return Err(Error::Routing { host, port });
901        }
902
903        Ok(Client {
904            config: config.clone(),
905            _state: PhantomData,
906            connection: Some(ConnectionHandle::Plain(connection)),
907            server_version,
908            current_database: current_database.clone(),
909            server_collation,
910            statement_cache: StatementCache::with_default_size(),
911            transaction_descriptor: 0,
912            needs_reset: false,
913            in_flight: false,
914            #[cfg(feature = "otel")]
915            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
916                .with_database(current_database.clone().unwrap_or_default()),
917            #[cfg(feature = "always-encrypted")]
918            encryption_context: config.column_encryption.clone().map(|cfg| {
919                std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
920            }),
921        })
922    }
923
924    /// Build a PreLogin packet.
925    fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
926        // Use the configured TDS version (strict_mode overrides to V8_0)
927        let version = if config.strict_mode {
928            tds_protocol::version::TdsVersion::V8_0
929        } else {
930            config.tds_version
931        };
932
933        let mut prelogin = PreLogin::new()
934            .with_version(version)
935            .with_encryption(encryption);
936
937        if config.mars {
938            prelogin = prelogin.with_mars(true);
939        }
940
941        if let Some(ref instance) = config.instance {
942            prelogin = prelogin.with_instance(instance);
943        }
944
945        prelogin
946    }
947
948    /// Resolve the workstation ID for the LOGIN7 HostName field.
949    ///
950    /// Per MS-TDS, the LOGIN7 HostName field contains the client machine name
951    /// (not the server name). Priority:
952    /// 1. `Config::workstation_id` (explicit override)
953    /// 2. Machine hostname from environment (`COMPUTERNAME` on Windows, `HOSTNAME` on Linux)
954    /// 3. Empty string (fallback)
955    fn resolve_workstation_id(config: &Config) -> String {
956        if let Some(ref id) = config.workstation_id {
957            return id.clone();
958        }
959        // COMPUTERNAME is set on Windows; HOSTNAME is set on most Linux systems.
960        // This avoids adding a dependency for a simple lookup.
961        std::env::var("COMPUTERNAME")
962            .or_else(|_| std::env::var("HOSTNAME"))
963            .unwrap_or_default()
964    }
965
966    /// Build a Login7 packet.
967    ///
968    /// When `sspi_token` is provided (integrated auth), the Login7 packet is
969    /// built with the integrated security flag and the initial SSPI blob.
970    fn build_login7(config: &Config, sspi_token: Option<Vec<u8>>) -> Login7 {
971        // Use the configured TDS version (strict_mode overrides to V8_0)
972        let version = if config.strict_mode {
973            tds_protocol::version::TdsVersion::V8_0
974        } else {
975            config.tds_version
976        };
977
978        let mut login = Login7::new()
979            .with_tds_version(version)
980            .with_packet_size(config.packet_size as u32)
981            .with_app_name(&config.application_name)
982            .with_server_name(&config.host)
983            .with_hostname(Self::resolve_workstation_id(config));
984
985        if let Some(ref database) = config.database {
986            login = login.with_database(database);
987        }
988
989        // ApplicationIntent → LOGIN7 TypeFlags READONLY_INTENT bit
990        if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
991            login = login.with_read_only_intent(true);
992        }
993
994        // Session language → LOGIN7 Language field
995        if let Some(ref lang) = config.language {
996            login = login.with_language(lang);
997        }
998
999        // Set credentials
1000        if let Some(token) = sspi_token {
1001            // Integrated auth: set SSPI data and integrated security flag
1002            login = login.with_integrated_auth(token);
1003        } else if let mssql_auth::Credentials::SqlServer { username, password } =
1004            &config.credentials
1005        {
1006            login = login.with_sql_auth(username.as_ref(), password.as_ref());
1007        }
1008
1009        // When Always Encrypted is configured, add the ColumnEncryption feature extension.
1010        // Version 1 = client supports column encryption without enclave computations.
1011        #[cfg(feature = "always-encrypted")]
1012        if config.column_encryption.is_some() {
1013            login = login.with_feature(tds_protocol::login7::FeatureExtension {
1014                feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1015                data: bytes::Bytes::from_static(&[0x01]), // Version 1
1016            });
1017            tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1018        }
1019
1020        login
1021    }
1022
1023    /// Create an SSPI/GSSAPI negotiator if integrated auth is configured.
1024    ///
1025    /// Returns `None` for non-integrated credential types.
1026    ///
1027    /// On Windows with `sspi-auth`, uses native Windows SSPI (`secur32.dll`) which
1028    /// supports all account types including Microsoft Accounts. Falls back to sspi-rs
1029    /// on non-Windows platforms.
1030    ///
1031    /// With `integrated-auth` (Linux/macOS), uses GSSAPI/Kerberos.
1032    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1033    fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1034        #[allow(clippy::match_like_matches_macro)]
1035        let is_integrated = match &config.credentials {
1036            mssql_auth::Credentials::Integrated => true,
1037            _ => false,
1038        };
1039
1040        if !is_integrated {
1041            return Ok(None);
1042        }
1043
1044        // On Windows: prefer native SSPI (secur32.dll) for integrated auth.
1045        // This handles all Windows account types including Microsoft Accounts,
1046        // domain accounts, and local accounts — unlike sspi-rs which requires
1047        // explicit credentials.
1048        #[cfg(all(windows, feature = "sspi-auth"))]
1049        let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1050            Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1051
1052        // On non-Windows: use sspi-rs (pure Rust SSPI implementation)
1053        #[cfg(all(not(windows), feature = "sspi-auth"))]
1054        let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1055            Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1056
1057        #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1058        let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1059            Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1060
1061        Ok(Some(negotiator))
1062    }
1063
1064    /// Send a PreLogin packet (for use with Connection).
1065    #[cfg(feature = "tls")]
1066    async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1067    where
1068        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1069    {
1070        let payload = prelogin.encode();
1071        let max_packet = MAX_PACKET_SIZE;
1072
1073        connection
1074            .send_message(PacketType::PreLogin, payload, max_packet)
1075            .await?;
1076        Ok(())
1077    }
1078
1079    /// Receive a PreLogin response (for use with Connection).
1080    #[cfg(feature = "tls")]
1081    async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1082    where
1083        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1084    {
1085        let message = connection
1086            .read_message()
1087            .await?
1088            .ok_or(Error::ConnectionClosed)?;
1089
1090        Ok(PreLogin::decode(&message.payload[..])?)
1091    }
1092
1093    /// Send a Login7 packet.
1094    async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1095    where
1096        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1097    {
1098        let payload = login.encode();
1099        let max_packet = MAX_PACKET_SIZE;
1100
1101        connection
1102            .send_message(PacketType::Tds7Login, payload, max_packet)
1103            .await?;
1104        Ok(())
1105    }
1106
1107    /// Process the login response tokens, handling SSPI challenge/response if needed.
1108    ///
1109    /// When a `negotiator` is provided and the server sends an SSPI challenge token,
1110    /// this method will automatically perform the multi-step SSPI handshake by:
1111    /// 1. Calling `negotiator.step(challenge)` to generate a response
1112    /// 2. Sending the response via an SSPI packet
1113    /// 3. Reading the next server message and continuing
1114    ///
1115    /// Returns: (server_version, database, routing_info)
1116    #[allow(clippy::never_loop)] // Loop is used when integrated-auth/sspi-auth features are enabled
1117    async fn process_login_response<T>(
1118        connection: &mut Connection<T>,
1119        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
1120            &dyn mssql_auth::SspiNegotiator,
1121        >,
1122    ) -> Result<(
1123        Option<u32>,
1124        Option<String>,
1125        Option<(String, u16)>,
1126        Option<tds_protocol::token::Collation>,
1127    )>
1128    where
1129        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1130    {
1131        let mut server_version = None;
1132        let mut database = None;
1133        let mut routing = None;
1134        let mut collation = None;
1135
1136        'outer: loop {
1137            let message = connection
1138                .read_message()
1139                .await?
1140                .ok_or(Error::ConnectionClosed)?;
1141
1142            let response_bytes = message.payload;
1143            let mut parser = TokenParser::new(response_bytes);
1144
1145            while let Some(token) = parser.next_token()? {
1146                match token {
1147                    Token::LoginAck(ack) => {
1148                        tracing::info!(
1149                            version = ack.tds_version,
1150                            interface = ack.interface,
1151                            prog_name = %ack.prog_name,
1152                            "login acknowledged"
1153                        );
1154                        server_version = Some(ack.tds_version);
1155                    }
1156                    Token::EnvChange(env) => {
1157                        Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
1158                    }
1159                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1160                    Token::Sspi(sspi_token) => {
1161                        let neg = negotiator.ok_or_else(|| {
1162                            Error::Protocol(
1163                                "server sent SSPI challenge but no negotiator is configured"
1164                                    .to_string(),
1165                            )
1166                        })?;
1167
1168                        tracing::debug!(
1169                            challenge_len = sspi_token.data.len(),
1170                            "received SSPI challenge from server"
1171                        );
1172
1173                        if let Some(response) = neg.step(&sspi_token.data)? {
1174                            tracing::debug!(response_len = response.len(), "sending SSPI response");
1175                            connection
1176                                .send_message(
1177                                    PacketType::Sspi,
1178                                    bytes::Bytes::from(response),
1179                                    tds_protocol::packet::MAX_PACKET_SIZE,
1180                                )
1181                                .await?;
1182                        }
1183
1184                        // After sending the SSPI response, read the next server message
1185                        continue 'outer;
1186                    }
1187                    Token::Error(err) => {
1188                        return Err(Error::Server {
1189                            number: err.number,
1190                            state: err.state,
1191                            class: err.class,
1192                            message: err.message.clone(),
1193                            server: if err.server.is_empty() {
1194                                None
1195                            } else {
1196                                Some(err.server.clone())
1197                            },
1198                            procedure: if err.procedure.is_empty() {
1199                                None
1200                            } else {
1201                                Some(err.procedure.clone())
1202                            },
1203                            line: err.line as u32,
1204                        });
1205                    }
1206                    Token::Info(info) => {
1207                        tracing::info!(
1208                            number = info.number,
1209                            message = %info.message,
1210                            "server info message"
1211                        );
1212                    }
1213                    Token::Done(done) => {
1214                        if done.status.error {
1215                            return Err(Error::Protocol("login failed".to_string()));
1216                        }
1217                        break 'outer;
1218                    }
1219                    _ => {}
1220                }
1221            }
1222
1223            // If we consumed all tokens without a Done or SSPI, break
1224            break;
1225        }
1226
1227        Ok((server_version, database, routing, collation))
1228    }
1229
1230    /// Process an EnvChange token.
1231    fn process_env_change(
1232        env: &EnvChange,
1233        database: &mut Option<String>,
1234        routing: &mut Option<(String, u16)>,
1235        collation: &mut Option<tds_protocol::token::Collation>,
1236    ) {
1237        use tds_protocol::token::EnvChangeValue;
1238
1239        match env.env_type {
1240            EnvChangeType::Database => {
1241                if let EnvChangeValue::String(ref new_value) = env.new_value {
1242                    tracing::debug!(database = %new_value, "database changed");
1243                    *database = Some(new_value.clone());
1244                }
1245            }
1246            EnvChangeType::Routing => {
1247                if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1248                    tracing::info!(host = %host, port = port, "routing redirect received");
1249                    *routing = Some((host.clone(), port));
1250                }
1251            }
1252            EnvChangeType::SqlCollation => {
1253                if let EnvChangeValue::Binary(ref data) = env.new_value {
1254                    if data.len() >= 5 {
1255                        let c = tds_protocol::token::Collation::from_bytes(
1256                            data[..5].try_into().unwrap(),
1257                        );
1258                        tracing::debug!(
1259                            lcid = c.lcid,
1260                            sort_id = c.sort_id,
1261                            "server collation received"
1262                        );
1263                        *collation = Some(c);
1264                    }
1265                }
1266            }
1267            _ => {
1268                if let EnvChangeValue::String(ref new_value) = env.new_value {
1269                    tracing::debug!(
1270                        env_type = ?env.env_type,
1271                        new_value = %new_value,
1272                        "environment change"
1273                    );
1274                }
1275            }
1276        }
1277    }
1278}