Skip to main content

mssql_client/client/
connect.rs

1//! Connection establishment for SQL Server.
2//!
3//! This module contains the `impl Client<Disconnected>` block, handling
4//! TCP connection, TLS negotiation, PreLogin exchange, and Login7 authentication.
5
6use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::DEFAULT_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
/// Federated authentication (FEDAUTH) inputs for a single LOGIN7 attempt.
///
/// Note: this type has no `Debug` impl; if one is ever added it should redact
/// `token`, which is a bearer credential.
#[derive(Clone, Copy)]
struct FedAuthLogin<'a> {
    /// Mirrors the server's PRELOGIN FEDAUTHREQUIRED response; sent back as
    /// the `fFedAuthEcho` bit (MS-TDS §2.2.6.4).
    echo: bool,
    /// Access token carried in the LOGIN7 FEDAUTH feature extension.
    token: &'a str,
}
39
40impl Client<Disconnected> {
41    /// Connect to SQL Server.
42    ///
43    /// This establishes a connection, performs TLS negotiation (if required),
44    /// and authenticates with the server.
45    ///
46    /// # Example
47    ///
48    /// ```rust,no_run
49    /// # use mssql_client::Client;
50    /// # async fn ex(config: mssql_client::Config) -> Result<(), mssql_client::Error> {
51    /// let client = Client::connect(config).await?;
52    /// # let _ = client;
53    /// # Ok(())
54    /// # }
55    /// ```
56    pub async fn connect(config: Config) -> Result<Client<Ready>> {
57        Self::validate_credential_support(&config)?;
58
59        // Azure AD / Entra credentials use the FEDAUTH SecurityToken workflow
60        // (MS-TDS §2.2.6.4): the access token is acquired client-side before
61        // any TDS traffic and sent in the LOGIN7 FEDAUTH feature extension.
62        // Acquired once here so retries and Azure gateway redirects reuse it.
63        let fed_auth_token = Self::resolve_fed_auth_token(&config).await?;
64
65        let retry = config.retry.clone();
66        let max_redirects = config.redirect.max_redirects;
67        let follow_redirects = config.redirect.follow_redirects;
68        // Overall timeout accounts for retries + redirects per attempt, capped at 5 min.
69        let per_attempt = config.timeouts.connect_timeout
70            + config.timeouts.tls_timeout
71            + config.timeouts.login_timeout;
72        let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
73        let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
74        let initial_host = config.host.clone();
75        let initial_port = config.port;
76
77        let result = timeout(overall, async {
78            let mut last_error: Option<Error> = None;
79
80            for retry_attempt in 0..=retry.max_retries {
81                if retry_attempt > 0 {
82                    let backoff = retry.backoff_for_attempt(retry_attempt);
83                    tracing::info!(
84                        retry_attempt,
85                        backoff_ms = backoff.as_millis() as u64,
86                        "retrying connection after transient error"
87                    );
88                    tokio::time::sleep(backoff).await;
89                }
90
91                // Each retry starts fresh with original host/port
92                let mut current_config = config.clone();
93                let mut redirect_count: u8 = 0;
94
95                let attempt_result = loop {
96                    redirect_count += 1;
97                    if redirect_count > max_redirects + 1 {
98                        break Err(Error::TooManyRedirects { max: max_redirects });
99                    }
100
101                    match Self::try_connect(&current_config, fed_auth_token.as_deref()).await {
102                        Ok(client) => break Ok(client),
103                        Err(Error::Routing { host, port }) => {
104                            if !follow_redirects {
105                                break Err(Error::Routing { host, port });
106                            }
107                            tracing::info!(
108                                host = %host,
109                                port = port,
110                                redirect = redirect_count,
111                                max_redirects = max_redirects,
112                                "following Azure SQL routing redirect"
113                            );
114                            current_config = current_config.with_host(&host).with_port(port);
115                            continue;
116                        }
117                        Err(e) => break Err(e),
118                    }
119                };
120
121                match attempt_result {
122                    Ok(client) => return Ok(client),
123                    Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
124                        tracing::warn!(
125                            retry_attempt,
126                            max_retries = retry.max_retries,
127                            error = %e,
128                            "transient connection error, will retry"
129                        );
130                        last_error = Some(attempt_result.unwrap_err());
131                    }
132                    Err(e) => return Err(e),
133                }
134            }
135
136            // All retries exhausted — return last error
137            Err(last_error.expect("at least one attempt was made"))
138        })
139        .await;
140
141        match result {
142            Ok(inner) => inner,
143            Err(_elapsed) => Err(Error::ConnectTimeout {
144                host: initial_host,
145                port: initial_port,
146            }),
147        }
148    }
149
150    /// Validate that the configured credentials can complete a login.
151    ///
152    /// Fails fast with an actionable error instead of sending a login the
153    /// server would reject with an opaque error 18456 (or worse, leaking a
154    /// bearer token over plaintext).
155    fn validate_credential_support(config: &Config) -> Result<()> {
156        match &config.credentials {
157            mssql_auth::Credentials::SqlServer { .. } => Ok(()),
158            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
159            mssql_auth::Credentials::Integrated => Ok(()),
160            creds if creds.is_azure_ad() => {
161                // A FEDAUTH login carries a bearer token; sending it over a
162                // plaintext connection would hand the token to any on-path
163                // observer. Azure SQL always requires TLS anyway.
164                #[cfg(not(feature = "tls"))]
165                {
166                    return Err(Error::Config(
167                        "Azure AD / Entra (FEDAUTH) authentication requires TLS: \
168                         enable the 'tls' feature."
169                            .into(),
170                    ));
171                }
172                #[cfg(feature = "tls")]
173                {
174                    if config.no_tls {
175                        return Err(Error::Config(
176                            "Azure AD / Entra (FEDAUTH) authentication cannot be combined \
177                             with Encrypt=no_tls: the access token would be sent in \
178                             plaintext. Use Encrypt=mandatory or Encrypt=strict."
179                                .into(),
180                        ));
181                    }
182                    if matches!(&config.credentials,
183                        mssql_auth::Credentials::AzureAccessToken { token } if token.is_empty())
184                    {
185                        return Err(Error::Config(
186                            "Azure AD access token is empty (the FEDAUTH token length \
187                             must not be zero)"
188                                .into(),
189                        ));
190                    }
191                    if !config.strict_mode && !config.tds_version.supports_fed_auth() {
192                        return Err(Error::Config(format!(
193                            "Azure AD / Entra (FEDAUTH) authentication requires TDS 7.4 \
194                             or later (configured: {})",
195                            config.tds_version
196                        )));
197                    }
198                    Ok(())
199                }
200            }
201            // Remaining credential types (client certificate) cannot complete
202            // a login yet: certificate-acquired tokens are not wired into the
203            // login sequence. Tracked in issue #155.
204            _ => Err(Error::Config(
205                "client certificate (FEDAUTH) authentication is not yet supported \
206                 (tracked in https://github.com/praxiomlabs/rust-mssql-driver/issues/155). \
207                 Use SQL Server, integrated, or Azure AD / Entra authentication."
208                    .into(),
209            )),
210        }
211    }
212
213    /// Resolve the federated authentication access token, if these
214    /// credentials use FEDAUTH.
215    ///
216    /// Pre-acquired tokens are passed through; managed identity and service
217    /// principal credentials acquire a token from Entra ID (network I/O).
218    /// Returns `None` for non-FEDAUTH credentials.
219    async fn resolve_fed_auth_token(config: &Config) -> Result<Option<String>> {
220        match &config.credentials {
221            mssql_auth::Credentials::AzureAccessToken { token } => Ok(Some(token.to_string())),
222            #[cfg(feature = "azure-identity")]
223            mssql_auth::Credentials::AzureManagedIdentity { client_id } => {
224                let auth = match client_id {
225                    Some(id) => {
226                        mssql_auth::ManagedIdentityAuth::user_assigned_client_id(id.to_string())?
227                    }
228                    None => mssql_auth::ManagedIdentityAuth::system_assigned()?,
229                };
230                tracing::debug!("acquiring Azure SQL access token via managed identity");
231                Ok(Some(auth.get_token().await?))
232            }
233            #[cfg(feature = "azure-identity")]
234            mssql_auth::Credentials::AzureServicePrincipal {
235                tenant_id,
236                client_id,
237                client_secret,
238            } => {
239                let auth = mssql_auth::ServicePrincipalAuth::new(
240                    tenant_id.as_ref(),
241                    client_id.to_string(),
242                    client_secret.to_string(),
243                )?;
244                tracing::debug!(
245                    client_id = %client_id,
246                    "acquiring Azure SQL access token via service principal"
247                );
248                Ok(Some(auth.get_token().await?))
249            }
250            _ => Ok(None),
251        }
252    }
253
254    async fn try_connect(config: &Config, fed_auth_token: Option<&str>) -> Result<Client<Ready>> {
255        // If a named instance is specified, resolve the TCP port via SQL Browser
256        let port = if let Some(ref instance) = config.instance {
257            let resolved = crate::browser::resolve_instance(
258                &config.host,
259                instance,
260                Some(config.timeouts.connect_timeout),
261            )
262            .await?;
263            tracing::info!(
264                host = %config.host,
265                instance = %instance,
266                resolved_port = resolved,
267                database = ?config.database,
268                "connecting to named SQL Server instance"
269            );
270            resolved
271        } else {
272            tracing::info!(
273                host = %config.host,
274                port = config.port,
275                database = ?config.database,
276                "connecting to SQL Server"
277            );
278            config.port
279        };
280
281        // Normalize "." and "(local)" to localhost for TCP.
282        // These are standard ADO.NET aliases for the local machine.
283        let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
284            "127.0.0.1"
285        } else {
286            &config.host
287        };
288
289        // Step 1: Establish TCP connection
290        let tcp_stream = if config.multi_subnet_failover {
291            Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
292        } else {
293            let addr = format!("{host}:{port}");
294            tracing::debug!("establishing TCP connection to {}", addr);
295            let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
296                .await
297                .map_err(|_| Error::ConnectTimeout {
298                    host: config.host.clone(),
299                    port: config.port,
300                })?
301                .map_err(Error::from)?;
302            stream.set_nodelay(true).map_err(Error::from)?;
303            stream
304        };
305
306        #[cfg(feature = "tls")]
307        {
308            // Determine TLS negotiation mode
309            let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
310
311            // Step 2: Handle TDS 8.0 strict mode (TLS before any TDS traffic)
312            if tls_mode.is_tls_first() {
313                return Self::connect_tds_8(config, tcp_stream, fed_auth_token).await;
314            }
315
316            // Step 3: TDS 7.x flow - PreLogin first, then TLS, then Login7
317            Self::connect_tds_7x(config, tcp_stream, fed_auth_token).await
318        }
319
320        #[cfg(not(feature = "tls"))]
321        {
322            // FEDAUTH credentials were rejected by validate_credential_support
323            // (no TLS feature means no way to protect the bearer token).
324            let _ = fed_auth_token;
325
326            // When TLS feature is disabled, only no_tls connections are supported
327            if config.strict_mode {
328                return Err(Error::Config(
329                    "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
330                ));
331            }
332
333            if !config.no_tls {
334                return Err(Error::Config(
335                    "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
336                     or use Encrypt=no_tls in your connection string for unencrypted connections."
337                        .into(),
338                ));
339            }
340
341            // Proceed with no-TLS connection
342            Self::connect_no_tls(config, tcp_stream).await
343        }
344    }
345
346    /// Resolve hostname to all IPs and race parallel TCP connections.
347    ///
348    /// Used when `MultiSubnetFailover=True` for AlwaysOn AG listeners that
349    /// span multiple subnets. First successful TCP connection wins.
350    async fn connect_parallel(
351        host: &str,
352        port: u16,
353        connect_timeout: std::time::Duration,
354    ) -> Result<TcpStream> {
355        let addr_str = format!("{host}:{port}");
356        let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
357            .await
358            .map_err(Error::from)?
359            .collect();
360
361        if addrs.is_empty() {
362            return Err(Error::from(std::io::Error::new(
363                std::io::ErrorKind::AddrNotAvailable,
364                format!("no addresses resolved for {host}:{port}"),
365            )));
366        }
367
368        // Single address — no need to spawn tasks
369        if addrs.len() == 1 {
370            tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
371            let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
372                .await
373                .map_err(|_| Error::ConnectTimeout {
374                    host: host.to_string(),
375                    port,
376                })?
377                .map_err(Error::from)?;
378            stream.set_nodelay(true).map_err(Error::from)?;
379            return Ok(stream);
380        }
381
382        let addr_count = addrs.len();
383        tracing::debug!(
384            host = host,
385            port = port,
386            resolved_count = addr_count,
387            "MultiSubnetFailover: racing parallel connections",
388        );
389
390        let mut join_set = tokio::task::JoinSet::new();
391
392        for addr in addrs {
393            let dur = connect_timeout;
394            join_set.spawn(async move {
395                let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
396                    std::io::Error::new(
397                        std::io::ErrorKind::TimedOut,
398                        format!("connection to {addr} timed out"),
399                    )
400                })??;
401                tcp.set_nodelay(true)?;
402                Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
403            });
404        }
405
406        let mut last_error: Option<std::io::Error> = None;
407
408        while let Some(result) = join_set.join_next().await {
409            match result {
410                Ok(Ok((stream, addr))) => {
411                    tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
412                    join_set.abort_all();
413                    return Ok(stream);
414                }
415                Ok(Err(e)) => {
416                    tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
417                    last_error = Some(e);
418                }
419                Err(join_err) => {
420                    tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
421                    last_error = Some(std::io::Error::other(join_err.to_string()));
422                }
423            }
424        }
425
426        // All connections failed
427        Err(Error::from(last_error.unwrap_or_else(|| {
428            std::io::Error::new(
429                std::io::ErrorKind::ConnectionRefused,
430                format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
431            )
432        })))
433    }
434
435    /// Connect using TDS 8.0 strict mode.
436    ///
437    /// Flow: TCP -> TLS -> PreLogin (encrypted) -> Login7 (encrypted)
438    #[cfg(feature = "tls")]
439    async fn connect_tds_8(
440        config: &Config,
441        tcp_stream: TcpStream,
442        fed_auth_token: Option<&str>,
443    ) -> Result<Client<Ready>> {
444        tracing::debug!("using TDS 8.0 strict mode (TLS first)");
445
446        // Build TLS configuration from the user's `config.tls` plus the
447        // TDS 8.0 strict-mode requirements (see `connection_tls_config`).
448        let tls_config = connection_tls_config(config, true);
449
450        let tls_connector = TlsConnector::new(tls_config)?;
451
452        // Perform TLS handshake before any TDS traffic
453        let tls_stream = timeout(
454            config.timeouts.tls_timeout,
455            tls_connector.connect(tcp_stream, &config.host),
456        )
457        .await
458        .map_err(|_| Error::TlsTimeout {
459            host: config.host.clone(),
460            port: config.port,
461        })??;
462
463        tracing::debug!("TLS handshake completed (strict mode)");
464
465        // Create connection wrapper
466        let mut connection = Connection::new(tls_stream);
467        connection.set_max_message_size(config.max_response_size);
468
469        // Send PreLogin (encrypted in strict mode)
470        let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
471        Self::send_prelogin(&mut connection, &prelogin).await?;
472        let prelogin_response = Self::receive_prelogin(&mut connection).await?;
473
474        // Create SSPI negotiator if integrated auth
475        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
476        let negotiator = Self::create_negotiator(config)?;
477        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
478        let sspi_token = match negotiator {
479            Some(ref neg) => Some(neg.initialize()?),
480            None => None,
481        };
482        #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
483        let sspi_token: Option<Vec<u8>> = None;
484
485        // Send Login7
486        let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
487            token,
488            echo: prelogin_response.fed_auth_required,
489        });
490        let login = Self::build_login7(config, sspi_token, fed_auth);
491        Self::send_login7(&mut connection, &login).await?;
492
493        // Process login response (with timeout to prevent hangs during redirect)
494        let (server_version, current_database, routing, server_collation) = timeout(
495            config.timeouts.login_timeout,
496            Self::process_login_response(
497                &mut connection,
498                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
499                negotiator.as_deref(),
500            ),
501        )
502        .await
503        .map_err(|_| Error::LoginTimeout {
504            host: config.host.clone(),
505            port: config.port,
506        })??;
507
508        // Handle routing redirect
509        if let Some((host, port)) = routing {
510            return Err(Error::Routing { host, port });
511        }
512
513        Ok(Client {
514            config: config.clone(),
515            _state: PhantomData,
516            connection: Some(ConnectionHandle::Tls(connection)),
517            server_version,
518            current_database: current_database.clone(),
519            server_collation,
520            statement_cache: StatementCache::with_default_size(),
521            transaction_descriptor: 0, // Auto-commit mode initially
522            needs_reset: false,        // Fresh connection, no reset needed
523            in_flight: false,          // No request pending
524            #[cfg(feature = "otel")]
525            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
526                .with_database(current_database.clone().unwrap_or_default()),
527            #[cfg(feature = "always-encrypted")]
528            encryption_context: config.column_encryption.clone().map(|cfg| {
529                std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
530            }),
531        })
532    }
533
534    /// Connect using TDS 7.x flow.
535    ///
536    /// Flow: TCP -> PreLogin (clear) -> TLS -> Login7 (encrypted)
537    ///
538    /// Note: For TDS 7.x, the PreLogin exchange happens over raw TCP before
539    /// upgrading to TLS. We use low-level I/O for this initial exchange
540    /// since the Connection struct splits the stream immediately.
541    #[cfg(feature = "tls")]
542    async fn connect_tds_7x(
543        config: &Config,
544        mut tcp_stream: TcpStream,
545        fed_auth_token: Option<&str>,
546    ) -> Result<Client<Ready>> {
547        use bytes::BufMut;
548        use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
549        use tokio::io::{AsyncReadExt, AsyncWriteExt};
550
551        tracing::debug!("using TDS 7.x flow (PreLogin first)");
552
553        // Build PreLogin packet
554        // Determine client encryption level based on configuration
555        let client_encryption = if config.no_tls {
556            // no_tls: Completely disable TLS
557            tracing::warn!(
558                "⚠️  no_tls mode enabled. Connection will be UNENCRYPTED. \
559                 Credentials and data will be transmitted in plaintext. \
560                 This should only be used for development/testing with legacy SQL Server."
561            );
562            EncryptionLevel::NotSupported
563        } else if config.encrypt {
564            EncryptionLevel::On
565        } else {
566            EncryptionLevel::Off
567        };
568        let prelogin = Self::build_prelogin(config, client_encryption);
569        tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
570        let prelogin_bytes = prelogin.encode();
571
572        // Manually create and send the PreLogin packet over raw TCP
573        let header = PacketHeader::new(
574            PacketType::PreLogin,
575            PacketStatus::END_OF_MESSAGE,
576            (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
577        );
578
579        let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
580        header.encode(&mut packet_buf);
581        packet_buf.put_slice(&prelogin_bytes);
582
583        tcp_stream
584            .write_all(&packet_buf)
585            .await
586            .map_err(Error::from)?;
587
588        // Read PreLogin response
589        let mut header_buf = [0u8; PACKET_HEADER_SIZE];
590        tcp_stream
591            .read_exact(&mut header_buf)
592            .await
593            .map_err(Error::from)?;
594
595        let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
596        let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
597
598        let mut response_buf = vec![0u8; payload_length];
599        tcp_stream
600            .read_exact(&mut response_buf)
601            .await
602            .map_err(Error::from)?;
603
604        let prelogin_response = PreLogin::decode(&response_buf[..])?;
605
606        // Log PreLogin response
607        // Note: The server sends its SQL Server product version in PreLogin,
608        // NOT the TDS protocol version. The actual TDS version is negotiated
609        // in the LOGINACK token after login.
610        let client_tds_version = config.tds_version;
611        if let Some(ref server_version) = prelogin_response.server_version {
612            tracing::debug!(
613                requested_tds_version = %client_tds_version,
614                server_product_version = %server_version,
615                server_product = server_version.product_name(),
616                max_tds_version = %server_version.max_tds_version(),
617                "PreLogin response received"
618            );
619
620            // Warn if the server's max TDS version is lower than requested
621            let server_max_tds = server_version.max_tds_version();
622            if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
623                tracing::warn!(
624                    requested_tds_version = %client_tds_version,
625                    server_max_tds_version = %server_max_tds,
626                    server_product = server_version.product_name(),
627                    "Server supports lower TDS version than requested. \
628                     Connection will use server's maximum: {}",
629                    server_max_tds
630                );
631            }
632
633            // Warn about legacy SQL Server versions (2005 and earlier)
634            if server_max_tds.is_legacy() {
635                tracing::warn!(
636                    server_product = server_version.product_name(),
637                    server_max_tds_version = %server_max_tds,
638                    "Server uses legacy TDS version. Some features may not be available."
639                );
640            }
641        } else {
642            tracing::debug!(
643                requested_tds_version = %client_tds_version,
644                "PreLogin response received (no version info)"
645            );
646        }
647
648        // Check server encryption response
649        let server_encryption = prelogin_response.encryption;
650        tracing::debug!(encryption = ?server_encryption, "server encryption level");
651
652        // FEDAUTH: echo the server's FEDAUTHREQUIRED response (fFedAuthEcho).
653        let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
654            token,
655            echo: prelogin_response.fed_auth_required,
656        });
657
658        // Determine negotiated encryption level (follows TDS 7.x rules)
659        // - NotSupported + NotSupported = NotSupported (no TLS at all)
660        // - Off + Off = Off (TLS for login only, then plain)
661        // - On + anything supported = On (full TLS)
662        // - Required = On with failure if not possible
663        let negotiated_encryption = match (client_encryption, server_encryption) {
664            (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
665                EncryptionLevel::NotSupported
666            }
667            (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
668            (EncryptionLevel::On, EncryptionLevel::Off)
669            | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
670                return Err(Error::Protocol(
671                    "Server does not support requested encryption level".to_string(),
672                ));
673            }
674            _ => EncryptionLevel::On,
675        };
676
677        // TLS is required unless negotiated encryption is NotSupported
678        // Even with "Off", TLS is used to protect login credentials (per TDS 7.x spec)
679        let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
680
681        if use_tls {
682            // Upgrade to TLS with PreLogin wrapping (TDS 7.x style).
683            // In TDS 7.x, the TLS handshake is wrapped inside TDS PreLogin
684            // packets. Honor the user's `config.tls` (custom root certs,
685            // client auth) without the TDS 8.0 strict ALPN.
686            let tls_config = connection_tls_config(config, false);
687
688            let tls_connector = TlsConnector::new(tls_config)?;
689
690            // Use PreLogin-wrapped TLS connection for TDS 7.x
691            let mut tls_stream = timeout(
692                config.timeouts.tls_timeout,
693                tls_connector.connect_with_prelogin(tcp_stream, &config.host),
694            )
695            .await
696            .map_err(|_| Error::TlsTimeout {
697                host: config.host.clone(),
698                port: config.port,
699            })??;
700
701            tracing::debug!("TLS handshake completed (PreLogin wrapped)");
702
703            // Check if we need full encryption or login-only encryption
704            let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
705
706            if login_only_encryption {
707                // Login-Only Encryption (ENCRYPT_OFF + ENCRYPT_OFF per MS-TDS spec):
708                // - Login7 is sent through TLS to protect credentials
709                // - Server responds in PLAINTEXT after receiving Login7
710                // - All subsequent communication is plaintext
711                //
712                // We must NOT use Connection with TLS stream because Connection splits
713                // the stream and we need to extract the underlying TCP afterward.
714                use tokio::io::AsyncWriteExt;
715
716                // Create SSPI negotiator if integrated auth
717                // Note: SSPI handshake over login-only encryption is limited —
718                // the server response comes in plaintext, so multi-step SSPI
719                // may not work. We include the initial token but don't loop.
720                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
721                let negotiator = Self::create_negotiator(config)?;
722                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
723                let sspi_token = match negotiator {
724                    Some(ref neg) => Some(neg.initialize()?),
725                    None => None,
726                };
727                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
728                let sspi_token: Option<Vec<u8>> = None;
729
730                // Build and send Login7 directly through TLS
731                let login = Self::build_login7(config, sspi_token, fed_auth);
732                let login_payload = login.encode();
733
734                // Create TDS packet manually for Login7. LOGIN7 is sent before
735                // packet-size negotiation completes, so it MUST be split at the
736                // 4096-byte TDS default — large FEDAUTH tokens (managed identity,
737                // AAD tokens with many claims) push LOGIN7 over 4096 and the
738                // server resets a single oversized packet.
739                let max_packet = DEFAULT_PACKET_SIZE;
740                let max_payload = max_packet - PACKET_HEADER_SIZE;
741                let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
742                let total_chunks = chunks.len();
743
744                for (i, chunk) in chunks.into_iter().enumerate() {
745                    let is_last = i == total_chunks - 1;
746                    let status = if is_last {
747                        PacketStatus::END_OF_MESSAGE
748                    } else {
749                        PacketStatus::NORMAL
750                    };
751
752                    let header = PacketHeader::new(
753                        PacketType::Tds7Login,
754                        status,
755                        (PACKET_HEADER_SIZE + chunk.len()) as u16,
756                    );
757
758                    let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
759                    header.encode(&mut packet_buf);
760                    packet_buf.put_slice(chunk);
761
762                    tls_stream
763                        .write_all(&packet_buf)
764                        .await
765                        .map_err(Error::from)?;
766                }
767
768                // Flush TLS to ensure all data is sent
769                tls_stream.flush().await.map_err(Error::from)?;
770
771                tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
772
773                // Extract the underlying TCP stream from the TLS layer
774                // TlsStream::into_inner() returns (IO, ClientConnection)
775                // where IO is our TlsPreloginWrapper<TcpStream>
776                let (wrapper, _client_conn) = tls_stream.into_inner();
777                let tcp_stream = wrapper.into_inner();
778
779                // Create Connection from plain TCP for reading response
780                let mut connection = Connection::new(tcp_stream);
781                connection.set_max_message_size(config.max_response_size);
782
783                // Process login response (comes in plaintext, with timeout)
784                let (server_version, current_database, routing, server_collation) = timeout(
785                    config.timeouts.login_timeout,
786                    Self::process_login_response(
787                        &mut connection,
788                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
789                        negotiator.as_deref(),
790                    ),
791                )
792                .await
793                .map_err(|_| Error::LoginTimeout {
794                    host: config.host.clone(),
795                    port: config.port,
796                })??;
797
798                // Handle routing redirect
799                if let Some((host, port)) = routing {
800                    return Err(Error::Routing { host, port });
801                }
802
803                // Store plain TCP connection for subsequent operations
804                Ok(Client {
805                    config: config.clone(),
806                    _state: PhantomData,
807                    connection: Some(ConnectionHandle::Plain(connection)),
808                    server_version,
809                    current_database: current_database.clone(),
810                    server_collation,
811                    statement_cache: StatementCache::with_default_size(),
812                    transaction_descriptor: 0, // Auto-commit mode initially
813                    needs_reset: false,        // Fresh connection, no reset needed
814                    in_flight: false,          // No request pending
815                    #[cfg(feature = "otel")]
816                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
817                        .with_database(current_database.clone().unwrap_or_default()),
818                    #[cfg(feature = "always-encrypted")]
819                    encryption_context: config.column_encryption.clone().map(|cfg| {
820                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
821                    }),
822                })
823            } else {
824                // Full Encryption (ENCRYPT_ON per MS-TDS spec):
825                // - All communication after TLS handshake goes through TLS
826                let mut connection = Connection::new(tls_stream);
827                connection.set_max_message_size(config.max_response_size);
828
829                // Create SSPI negotiator if integrated auth
830                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
831                let negotiator = Self::create_negotiator(config)?;
832                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
833                let sspi_token = match negotiator {
834                    Some(ref neg) => Some(neg.initialize()?),
835                    None => None,
836                };
837                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
838                let sspi_token: Option<Vec<u8>> = None;
839
840                // Send Login7
841                let login = Self::build_login7(config, sspi_token, fed_auth);
842                Self::send_login7(&mut connection, &login).await?;
843
844                // Process login response (with timeout)
845                let (server_version, current_database, routing, server_collation) = timeout(
846                    config.timeouts.login_timeout,
847                    Self::process_login_response(
848                        &mut connection,
849                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
850                        negotiator.as_deref(),
851                    ),
852                )
853                .await
854                .map_err(|_| Error::LoginTimeout {
855                    host: config.host.clone(),
856                    port: config.port,
857                })??;
858
859                // Handle routing redirect
860                if let Some((host, port)) = routing {
861                    return Err(Error::Routing { host, port });
862                }
863
864                Ok(Client {
865                    config: config.clone(),
866                    _state: PhantomData,
867                    connection: Some(ConnectionHandle::TlsPrelogin(connection)),
868                    server_version,
869                    current_database: current_database.clone(),
870                    server_collation,
871                    statement_cache: StatementCache::with_default_size(),
872                    transaction_descriptor: 0, // Auto-commit mode initially
873                    needs_reset: false,        // Fresh connection, no reset needed
874                    in_flight: false,          // No request pending
875                    #[cfg(feature = "otel")]
876                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
877                        .with_database(current_database.clone().unwrap_or_default()),
878                    #[cfg(feature = "always-encrypted")]
879                    encryption_context: config.column_encryption.clone().map(|cfg| {
880                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
881                    }),
882                })
883            }
884        } else {
885            // Server does not require encryption and client doesn't either
886            tracing::warn!(
887                "Connecting without TLS encryption. This is insecure and should only be \
888                 used for development/testing on trusted networks."
889            );
890
891            let mut connection = Connection::new(tcp_stream);
892
893            connection.set_max_message_size(config.max_response_size);
894
895            // Create SSPI negotiator if integrated auth
896            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
897            let negotiator = Self::create_negotiator(config)?;
898            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
899            let sspi_token = match negotiator {
900                Some(ref neg) => Some(neg.initialize()?),
901                None => None,
902            };
903            #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
904            let sspi_token: Option<Vec<u8>> = None;
905
906            // Build and send Login7. `fed_auth` is provably None here: a
907            // plaintext connection requires Encrypt=no_tls, which
908            // validate_credential_support rejects for FEDAUTH credentials.
909            let login = Self::build_login7(config, sspi_token, fed_auth);
910            Self::send_login7(&mut connection, &login).await?;
911
912            // Process login response (with timeout)
913            let (server_version, current_database, routing, server_collation) = timeout(
914                config.timeouts.login_timeout,
915                Self::process_login_response(
916                    &mut connection,
917                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
918                    negotiator.as_deref(),
919                ),
920            )
921            .await
922            .map_err(|_| Error::LoginTimeout {
923                host: config.host.clone(),
924                port: config.port,
925            })??;
926
927            // Handle routing redirect
928            if let Some((host, port)) = routing {
929                return Err(Error::Routing { host, port });
930            }
931
932            Ok(Client {
933                config: config.clone(),
934                _state: PhantomData,
935                connection: Some(ConnectionHandle::Plain(connection)),
936                server_version,
937                current_database: current_database.clone(),
938                server_collation,
939                statement_cache: StatementCache::with_default_size(),
940                transaction_descriptor: 0, // Auto-commit mode initially
941                needs_reset: false,        // Fresh connection, no reset needed
942                in_flight: false,          // No request pending
943                #[cfg(feature = "otel")]
944                instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
945                    .with_database(current_database.clone().unwrap_or_default()),
946                #[cfg(feature = "always-encrypted")]
947                encryption_context: config.column_encryption.clone().map(|cfg| {
948                    std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
949                }),
950            })
951        }
952    }
953
954    /// Connect without TLS encryption (no_tls mode).
955    ///
956    /// This method is used when the `tls` feature is disabled and only supports
957    /// unencrypted connections via `Encrypt=no_tls`.
958    ///
959    /// # Security Warning
960    ///
961    /// This transmits all data including credentials in plaintext. Only use this
962    /// for development, testing, or on trusted internal networks where TLS is not
963    /// required.
964    #[cfg(not(feature = "tls"))]
965    async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
966        use bytes::BufMut;
967        use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
968        use tokio::io::{AsyncReadExt, AsyncWriteExt};
969
970        tracing::warn!(
971            "⚠️  Connecting without TLS (tls feature disabled). \
972             Credentials and data will be transmitted in plaintext."
973        );
974
975        // Build PreLogin packet with NotSupported encryption
976        let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
977        let prelogin_bytes = prelogin.encode();
978
979        // Manually create and send the PreLogin packet over raw TCP
980        let header = PacketHeader::new(
981            PacketType::PreLogin,
982            PacketStatus::END_OF_MESSAGE,
983            (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
984        );
985
986        let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
987        header.encode(&mut packet_buf);
988        packet_buf.put_slice(&prelogin_bytes);
989
990        tcp_stream
991            .write_all(&packet_buf)
992            .await
993            .map_err(Error::from)?;
994
995        // Read PreLogin response
996        let mut header_buf = [0u8; PACKET_HEADER_SIZE];
997        tcp_stream
998            .read_exact(&mut header_buf)
999            .await
1000            .map_err(Error::from)?;
1001
1002        let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
1003        let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
1004
1005        let mut response_buf = vec![0u8; payload_length];
1006        tcp_stream
1007            .read_exact(&mut response_buf)
1008            .await
1009            .map_err(Error::from)?;
1010
1011        let prelogin_response = PreLogin::decode(&response_buf[..])?;
1012
1013        // Check server encryption response - must accept NotSupported
1014        let server_encryption = prelogin_response.encryption;
1015        if server_encryption != EncryptionLevel::NotSupported {
1016            return Err(Error::Config(format!(
1017                "Server requires encryption (level: {:?}) but TLS feature is disabled. \
1018                     Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
1019                server_encryption
1020            )));
1021        }
1022
1023        tracing::debug!("Server accepted unencrypted connection");
1024
1025        let mut connection = Connection::new(tcp_stream);
1026
1027        connection.set_max_message_size(config.max_response_size);
1028
1029        // Create SSPI negotiator if integrated auth
1030        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1031        let negotiator = Self::create_negotiator(config)?;
1032        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1033        let sspi_token = match negotiator {
1034            Some(ref neg) => Some(neg.initialize()?),
1035            None => None,
1036        };
1037        #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
1038        let sspi_token: Option<Vec<u8>> = None;
1039
1040        // Build and send Login7 (FEDAUTH credentials were rejected by
1041        // validate_credential_support: no TLS feature, no token protection).
1042        let login = Self::build_login7(config, sspi_token, None);
1043        Self::send_login7(&mut connection, &login).await?;
1044
1045        // Process login response (with timeout)
1046        let (server_version, current_database, routing, server_collation) = timeout(
1047            config.timeouts.login_timeout,
1048            Self::process_login_response(
1049                &mut connection,
1050                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1051                negotiator.as_deref(),
1052            ),
1053        )
1054        .await
1055        .map_err(|_| Error::LoginTimeout {
1056            host: config.host.clone(),
1057            port: config.port,
1058        })??;
1059
1060        // Handle routing redirect
1061        if let Some((host, port)) = routing {
1062            return Err(Error::Routing { host, port });
1063        }
1064
1065        Ok(Client {
1066            config: config.clone(),
1067            _state: PhantomData,
1068            connection: Some(ConnectionHandle::Plain(connection)),
1069            server_version,
1070            current_database: current_database.clone(),
1071            server_collation,
1072            statement_cache: StatementCache::with_default_size(),
1073            transaction_descriptor: 0,
1074            needs_reset: false,
1075            in_flight: false,
1076            #[cfg(feature = "otel")]
1077            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
1078                .with_database(current_database.clone().unwrap_or_default()),
1079            #[cfg(feature = "always-encrypted")]
1080            encryption_context: config.column_encryption.clone().map(|cfg| {
1081                std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
1082            }),
1083        })
1084    }
1085
1086    /// Build a PreLogin packet.
1087    fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
1088        // Use the configured TDS version (strict_mode overrides to V8_0)
1089        let version = if config.strict_mode {
1090            tds_protocol::version::TdsVersion::V8_0
1091        } else {
1092            config.tds_version
1093        };
1094
1095        let mut prelogin = PreLogin::new()
1096            .with_version(version)
1097            .with_encryption(encryption);
1098
1099        if config.mars {
1100            prelogin = prelogin.with_mars(true);
1101        }
1102
1103        if let Some(ref instance) = config.instance {
1104            prelogin = prelogin.with_instance(instance);
1105        }
1106
1107        // Advertise federated authentication so the server's response carries
1108        // the FEDAUTHREQUIRED value we must echo in LOGIN7 (fFedAuthEcho).
1109        if config.credentials.is_azure_ad() {
1110            prelogin = prelogin.with_fed_auth_required(true);
1111        }
1112
1113        prelogin
1114    }
1115
1116    /// Resolve the workstation ID for the LOGIN7 HostName field.
1117    ///
1118    /// Per MS-TDS, the LOGIN7 HostName field contains the client machine name
1119    /// (not the server name). Priority:
1120    /// 1. `Config::workstation_id` (explicit override)
1121    /// 2. Machine hostname from environment (`COMPUTERNAME` on Windows, `HOSTNAME` on Linux)
1122    /// 3. Empty string (fallback)
1123    fn resolve_workstation_id(config: &Config) -> String {
1124        if let Some(ref id) = config.workstation_id {
1125            return id.clone();
1126        }
1127        // COMPUTERNAME is set on Windows; HOSTNAME is set on most Linux systems.
1128        // This avoids adding a dependency for a simple lookup.
1129        std::env::var("COMPUTERNAME")
1130            .or_else(|_| std::env::var("HOSTNAME"))
1131            .unwrap_or_default()
1132    }
1133
1134    /// Build a Login7 packet.
1135    ///
1136    /// When `sspi_token` is provided (integrated auth), the Login7 packet is
1137    /// built with the integrated security flag and the initial SSPI blob.
1138    ///
1139    /// When `fed_auth` is provided (Azure AD / Entra), the packet carries the
1140    /// FEDAUTH feature extension (SecurityToken workflow) and no username or
1141    /// password — per MS-TDS §2.2.6.4, `fIntSecurity` must be 0 and the
1142    /// credential fields stay empty.
1143    fn build_login7(
1144        config: &Config,
1145        sspi_token: Option<Vec<u8>>,
1146        fed_auth: Option<FedAuthLogin<'_>>,
1147    ) -> Login7 {
1148        // Use the configured TDS version (strict_mode overrides to V8_0)
1149        let version = if config.strict_mode {
1150            tds_protocol::version::TdsVersion::V8_0
1151        } else {
1152            config.tds_version
1153        };
1154
1155        let mut login = Login7::new()
1156            .with_tds_version(version)
1157            .with_packet_size(config.packet_size as u32)
1158            .with_app_name(&config.application_name)
1159            .with_server_name(&config.host)
1160            .with_hostname(Self::resolve_workstation_id(config));
1161
1162        if let Some(ref database) = config.database {
1163            login = login.with_database(database);
1164        }
1165
1166        // ApplicationIntent → LOGIN7 TypeFlags READONLY_INTENT bit
1167        if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
1168            login = login.with_read_only_intent(true);
1169        }
1170
1171        // Session language → LOGIN7 Language field
1172        if let Some(ref lang) = config.language {
1173            login = login.with_language(lang);
1174        }
1175
1176        // Set credentials
1177        if let Some(token) = sspi_token {
1178            // Integrated auth: set SSPI data and integrated security flag
1179            login = login.with_integrated_auth(token);
1180        } else if let Some(fed) = fed_auth {
1181            // Azure AD / Entra: FEDAUTH feature extension, SecurityToken
1182            // workflow. Username/password stay empty.
1183            login = login.with_feature(tds_protocol::login7::FeatureExtension {
1184                feature_id: tds_protocol::login7::FeatureId::FedAuth,
1185                data: mssql_auth::azure_ad::build_security_token_feature_data(fed.token, fed.echo),
1186            });
1187            tracing::debug!(
1188                fed_auth_echo = fed.echo,
1189                "Login7: adding FEDAUTH feature extension (SecurityToken workflow)"
1190            );
1191        } else if let mssql_auth::Credentials::SqlServer { username, password } =
1192            &config.credentials
1193        {
1194            login = login.with_sql_auth(username.as_ref(), password.as_ref());
1195        }
1196
1197        // When Always Encrypted is configured, add the ColumnEncryption feature extension.
1198        // Version 1 = client supports column encryption without enclave computations.
1199        #[cfg(feature = "always-encrypted")]
1200        if config.column_encryption.is_some() {
1201            login = login.with_feature(tds_protocol::login7::FeatureExtension {
1202                feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1203                data: bytes::Bytes::from_static(&[0x01]), // Version 1
1204            });
1205            tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1206        }
1207
1208        login
1209    }
1210
1211    /// Create an SSPI/GSSAPI negotiator if integrated auth is configured.
1212    ///
1213    /// Returns `None` for non-integrated credential types.
1214    ///
1215    /// On Windows with `sspi-auth`, uses native Windows SSPI (`secur32.dll`) which
1216    /// supports all account types including Microsoft Accounts. Falls back to sspi-rs
1217    /// on non-Windows platforms.
1218    ///
1219    /// With `integrated-auth` (Linux/macOS), uses GSSAPI/Kerberos.
1220    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1221    fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1222        #[allow(clippy::match_like_matches_macro)]
1223        let is_integrated = match &config.credentials {
1224            mssql_auth::Credentials::Integrated => true,
1225            _ => false,
1226        };
1227
1228        if !is_integrated {
1229            return Ok(None);
1230        }
1231
1232        // On Windows: prefer native SSPI (secur32.dll) for integrated auth.
1233        // This handles all Windows account types including Microsoft Accounts,
1234        // domain accounts, and local accounts — unlike sspi-rs which requires
1235        // explicit credentials.
1236        #[cfg(all(windows, feature = "sspi-auth"))]
1237        let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1238            Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1239
1240        // On non-Windows: use sspi-rs (pure Rust SSPI implementation)
1241        #[cfg(all(not(windows), feature = "sspi-auth"))]
1242        let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1243            Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1244
1245        #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1246        let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1247            Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1248
1249        Ok(Some(negotiator))
1250    }
1251
1252    /// Send a PreLogin packet (for use with Connection).
1253    #[cfg(feature = "tls")]
1254    async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1255    where
1256        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1257    {
1258        let payload = prelogin.encode();
1259        // PRELOGIN is tiny and never approaches the packet limit; keep the
1260        // pre-fix behavior here (fully-qualified so the import stays lean for
1261        // no-default-features builds, matching the SSPI send site below).
1262        let max_packet = tds_protocol::packet::MAX_PACKET_SIZE;
1263
1264        connection
1265            .send_message(PacketType::PreLogin, payload, max_packet)
1266            .await?;
1267        Ok(())
1268    }
1269
1270    /// Receive a PreLogin response (for use with Connection).
1271    #[cfg(feature = "tls")]
1272    async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1273    where
1274        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1275    {
1276        let message = connection
1277            .read_message()
1278            .await?
1279            .ok_or(Error::ConnectionClosed)?;
1280
1281        Ok(PreLogin::decode(&message.payload[..])?)
1282    }
1283
1284    /// Send a Login7 packet.
1285    async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1286    where
1287        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1288    {
1289        let payload = login.encode();
1290        // LOGIN7 precedes packet-size negotiation, so it must be split at the
1291        // 4096-byte TDS default, not MAX_PACKET_SIZE: a large FEDAUTH token
1292        // makes LOGIN7 exceed 4096 and a single oversized packet is reset by
1293        // the server (a managed-identity token is ~1900 chars → ~4100 bytes).
1294        let max_packet = DEFAULT_PACKET_SIZE;
1295
1296        connection
1297            .send_message(PacketType::Tds7Login, payload, max_packet)
1298            .await?;
1299        Ok(())
1300    }
1301
1302    /// Process the login response tokens, handling SSPI challenge/response if needed.
1303    ///
1304    /// When a `negotiator` is provided and the server sends an SSPI challenge token,
1305    /// this method will automatically perform the multi-step SSPI handshake by:
1306    /// 1. Calling `negotiator.step(challenge)` to generate a response
1307    /// 2. Sending the response via an SSPI packet
1308    /// 3. Reading the next server message and continuing
1309    ///
1310    /// Returns: (server_version, database, routing_info)
1311    #[allow(clippy::never_loop)] // Loop is used when integrated-auth/sspi-auth features are enabled
1312    async fn process_login_response<T>(
1313        connection: &mut Connection<T>,
1314        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
1315            &dyn mssql_auth::SspiNegotiator,
1316        >,
1317    ) -> Result<(
1318        Option<u32>,
1319        Option<String>,
1320        Option<(String, u16)>,
1321        Option<tds_protocol::token::Collation>,
1322    )>
1323    where
1324        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1325    {
1326        let mut server_version = None;
1327        let mut database = None;
1328        let mut routing = None;
1329        let mut collation = None;
1330
1331        'outer: loop {
1332            let message = connection
1333                .read_message()
1334                .await?
1335                .ok_or(Error::ConnectionClosed)?;
1336
1337            let response_bytes = message.payload;
1338            let mut parser = TokenParser::new(response_bytes);
1339
1340            while let Some(token) = parser.next_token()? {
1341                match token {
1342                    Token::LoginAck(ack) => {
1343                        tracing::info!(
1344                            version = ack.tds_version,
1345                            interface = ack.interface,
1346                            prog_name = %ack.prog_name,
1347                            "login acknowledged"
1348                        );
1349                        server_version = Some(ack.tds_version);
1350                    }
1351                    Token::EnvChange(env) => {
1352                        Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
1353                    }
1354                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1355                    Token::Sspi(sspi_token) => {
1356                        let neg = negotiator.ok_or_else(|| {
1357                            Error::Protocol(
1358                                "server sent SSPI challenge but no negotiator is configured"
1359                                    .to_string(),
1360                            )
1361                        })?;
1362
1363                        tracing::debug!(
1364                            challenge_len = sspi_token.data.len(),
1365                            "received SSPI challenge from server"
1366                        );
1367
1368                        if let Some(response) = neg.step(&sspi_token.data)? {
1369                            tracing::debug!(response_len = response.len(), "sending SSPI response");
1370                            connection
1371                                .send_message(
1372                                    PacketType::Sspi,
1373                                    bytes::Bytes::from(response),
1374                                    tds_protocol::packet::MAX_PACKET_SIZE,
1375                                )
1376                                .await?;
1377                        }
1378
1379                        // After sending the SSPI response, read the next server message
1380                        continue 'outer;
1381                    }
1382                    Token::Error(err) => {
1383                        return Err(Error::Server {
1384                            number: err.number,
1385                            state: err.state,
1386                            class: err.class,
1387                            message: err.message.clone(),
1388                            server: if err.server.is_empty() {
1389                                None
1390                            } else {
1391                                Some(err.server.clone())
1392                            },
1393                            procedure: if err.procedure.is_empty() {
1394                                None
1395                            } else {
1396                                Some(err.procedure.clone())
1397                            },
1398                            line: err.line as u32,
1399                        });
1400                    }
1401                    Token::Info(info) => {
1402                        tracing::info!(
1403                            number = info.number,
1404                            message = %info.message,
1405                            "server info message"
1406                        );
1407                    }
1408                    Token::FeatureExtAck(ack) => {
1409                        for feature in &ack.features {
1410                            tracing::debug!(
1411                                feature_id = feature.feature_id,
1412                                data_len = feature.data.len(),
1413                                "server acknowledged feature extension"
1414                            );
1415                        }
1416                    }
1417                    Token::Done(done) => {
1418                        if done.status.error {
1419                            return Err(Error::Protocol("login failed".to_string()));
1420                        }
1421                        break 'outer;
1422                    }
1423                    _ => {}
1424                }
1425            }
1426
1427            // If we consumed all tokens without a Done or SSPI, break
1428            break;
1429        }
1430
1431        Ok((server_version, database, routing, collation))
1432    }
1433
1434    /// Process an EnvChange token.
1435    fn process_env_change(
1436        env: &EnvChange,
1437        database: &mut Option<String>,
1438        routing: &mut Option<(String, u16)>,
1439        collation: &mut Option<tds_protocol::token::Collation>,
1440    ) {
1441        use tds_protocol::token::EnvChangeValue;
1442
1443        match env.env_type {
1444            EnvChangeType::Database => {
1445                if let EnvChangeValue::String(ref new_value) = env.new_value {
1446                    tracing::debug!(database = %new_value, "database changed");
1447                    *database = Some(new_value.clone());
1448                }
1449            }
1450            EnvChangeType::Routing => {
1451                if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1452                    tracing::info!(host = %host, port = port, "routing redirect received");
1453                    *routing = Some((host.clone(), port));
1454                }
1455            }
1456            EnvChangeType::SqlCollation => {
1457                if let EnvChangeValue::Binary(ref data) = env.new_value {
1458                    if data.len() >= 5 {
1459                        let c = tds_protocol::token::Collation::from_bytes(
1460                            data[..5].try_into().unwrap(),
1461                        );
1462                        tracing::debug!(
1463                            lcid = c.lcid,
1464                            sort_id = c.sort_id,
1465                            "server collation received"
1466                        );
1467                        *collation = Some(c);
1468                    }
1469                }
1470            }
1471            _ => {
1472                if let EnvChangeValue::String(ref new_value) = env.new_value {
1473                    tracing::debug!(
1474                        env_type = ?env.env_type,
1475                        new_value = %new_value,
1476                        "environment change"
1477                    );
1478                }
1479            }
1480        }
1481    }
1482}
1483
/// Derive the [`TlsConfig`] used for one outbound connection.
///
/// The user's [`Config::tls`] is the starting point, so custom root
/// certificates, client authentication, and protocol-version bounds all carry
/// through. On top of it:
///
/// * `trust_server_certificate` is always overwritten with the top-level
///   [`Config`] field. That field is the authoritative one: the builder and
///   the connection-string parser both set it, but the parser never copies it
///   into `config.tls`, so `TrustServerCertificate=...` connection strings
///   only take effect because the flag is read here.
/// * When `strict` is `true` (TDS 8.0, TLS-first), strict mode is enabled and
///   the `tds/8.0` ALPN protocol is offered. When `strict` is `false` (TDS 7.x,
///   where TLS is wrapped inside PreLogin), neither setting is touched.
///
/// Because of that overwrite, enabling trust *only* through
/// `config.tls = TlsConfig::new().trust_server_certificate(true)` while the
/// top-level field keeps its `false` default does **not** trust the server.
/// Use `TrustServerCertificate=true` in the connection string or
/// `Config::trust_server_certificate(true)`, which is the supported path.
#[cfg(feature = "tls")]
fn connection_tls_config(config: &Config, strict: bool) -> TlsConfig {
    let base = config
        .tls
        .clone()
        .trust_server_certificate(config.trust_server_certificate);
    if !strict {
        return base;
    }
    let tds8_alpn = vec![b"tds/8.0".to_vec()];
    base.strict_mode(true).with_alpn_protocols(tds8_alpn)
}
1517
#[cfg(all(test, feature = "tls"))]
mod tls_config_tests {
    //! Pins how `connection_tls_config` combines `Config::tls` with the
    //! top-level `Config` fields and the strict/non-strict connection mode.

    use super::*;
    use mssql_tls::CertificateDer;

    /// Default `Config` whose `tls` section carries one extra root
    /// certificate with the given DER bytes.
    fn config_with_root(cert: Vec<u8>) -> Config {
        let mut config = Config::new();
        config.tls = config
            .tls
            .clone()
            .add_root_certificate(CertificateDer::from(cert));
        config
    }

    #[test]
    fn custom_root_certificate_reaches_connector_config() {
        // Regression: connect used to build a fresh TlsConfig and drop
        // config.tls, so a custom CA was unreachable. The root must survive
        // into the connection's TLS config on both the strict and the
        // non-strict path.
        let config = config_with_root(vec![0xCA; 32]);

        for strict in [true, false] {
            let tls = connection_tls_config(&config, strict);
            assert_eq!(
                tls.root_certificates.len(),
                1,
                "custom root must reach the connector (strict={strict})"
            );
            assert_eq!(tls.root_certificates[0].as_ref(), &[0xCA; 32][..]);
        }
    }

    #[test]
    fn trust_server_certificate_taken_from_top_level_field() {
        // Mirrors the connection-string path, which sets the top-level field
        // without updating config.tls.
        let mut config = Config::new();
        config.trust_server_certificate = true;
        // config.tls still has the default (false) trust flag.
        assert!(!config.tls.trust_server_certificate);

        let tls = connection_tls_config(&config, false);
        assert!(
            tls.trust_server_certificate,
            "top-level trust flag must win"
        );
    }

    #[test]
    fn top_level_false_overrides_tls_section_trust() {
        // The documented asymmetry: enabling trust only inside config.tls,
        // while the top-level field keeps its `false` default, must NOT trust
        // the server on either connection path.
        let mut config = Config::new();
        config.tls = config.tls.clone().trust_server_certificate(true);
        assert!(config.tls.trust_server_certificate);
        assert!(!config.trust_server_certificate);

        for strict in [true, false] {
            let tls = connection_tls_config(&config, strict);
            assert!(
                !tls.trust_server_certificate,
                "top-level `false` must override config.tls (strict={strict})"
            );
        }
    }

    #[test]
    fn strict_mode_adds_tds8_alpn() {
        let config = Config::new();
        let strict = connection_tls_config(&config, true);
        assert!(strict.strict_mode);
        assert!(strict.alpn_protocols.iter().any(|p| p == b"tds/8.0"));

        // TDS 7.x wraps TLS inside PreLogin: no strict mode and no tds/8.0
        // ALPN may leak into the non-strict configuration.
        let non_strict = connection_tls_config(&config, false);
        assert!(!non_strict.strict_mode);
        assert!(
            !non_strict.alpn_protocols.iter().any(|p| p == b"tds/8.0"),
            "non-strict connections must not advertise tds/8.0"
        );
    }
}
1577
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod fed_auth_login_tests {
    //! Wire-level checks for federated-authentication (FEDAUTH) LOGIN7 and
    //! PRELOGIN construction, and for LOGIN7 packet splitting.

    use super::*;
    use tds_protocol::prelogin::EncryptionLevel;

    /// Offset of the ibExtension slot in an encoded LOGIN7: the 6th
    /// (offset, length) pair of the offset table that starts at byte 36.
    const EXTENSION_SLOT: usize = 36 + 5 * 4;

    /// Packet-header type byte of a TDS7 LOGIN message (MS-TDS §2.2.3.1.1).
    const TDS7_LOGIN_PACKET_TYPE: u8 = 0x10;

    /// Config carrying an Azure AD access-token credential.
    fn azure_config(token: &str) -> Config {
        Config::new().credentials(mssql_auth::Credentials::azure_token(token.to_string()))
    }

    /// Follow the ibExtension indirection of an encoded LOGIN7 to the absolute
    /// offset of its FeatureExt block. The slot holds a u16-LE offset to a
    /// u32-LE, which in turn holds the block's offset (MS-TDS §2.2.6.4).
    fn feature_ext_offset(encoded: &[u8]) -> usize {
        let ib_extension =
            u16::from_le_bytes([encoded[EXTENSION_SLOT], encoded[EXTENSION_SLOT + 1]]) as usize;
        u32::from_le_bytes(encoded[ib_extension..ib_extension + 4].try_into().unwrap()) as usize
    }

    /// Wire-exact assembly of the FEDAUTH feature extension inside the
    /// encoded LOGIN7, located through the ibExtension pointer indirection
    /// (MS-TDS §2.2.6.4): FeatureId 0x02, DWORD-LE data length, options byte
    /// `(SecurityToken << 1) | echo`, DWORD-LE token byte length, UTF-16LE
    /// token, 0xFF terminator. Username/password must stay empty and
    /// fIntSecurity clear.
    #[test]
    fn login7_fed_auth_feature_block_wire_exact() {
        let config = azure_config("AB");
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: "AB",
                echo: true,
            }),
        );

        assert!(
            !login.option_flags2.integrated_security,
            "fIntSecurity MUST be 0 when FEDAUTH is present"
        );
        assert!(
            login.username.is_empty() && login.password.is_empty(),
            "FEDAUTH logins must not carry username/password"
        );

        let encoded = login.encode();

        // OptionFlags3 (byte 27) must have fExtension (0x10) set.
        assert_eq!(encoded[27] & 0x10, 0x10, "fExtension bit must be set");

        let feature_off = feature_ext_offset(&encoded);

        assert_eq!(
            encoded[feature_off], 0x02,
            "FeatureId must be FEDAUTH (0x02)"
        );
        let data_len = u32::from_le_bytes(
            encoded[feature_off + 1..feature_off + 5]
                .try_into()
                .unwrap(),
        ) as usize;
        // options(1) + token length DWORD(4) + "AB" as UTF-16LE(4)
        assert_eq!(data_len, 9, "FeatureDataLen must cover options + token");

        let data = &encoded[feature_off + 5..feature_off + 5 + data_len];
        assert_eq!(
            data,
            &[0x03, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "options must be (SecurityToken << 1) | echo, then DWORD-LE \
             token byte length, then UTF-16LE token"
        );
        assert_eq!(
            encoded[feature_off + 5 + data_len],
            0xFF,
            "FeatureExt terminator must follow"
        );
    }

    /// The echo bit mirrors the server's PRELOGIN FEDAUTHREQUIRED response;
    /// when the server sent 0x00 the options byte must be 0x02 (echo clear).
    #[test]
    fn login7_fed_auth_echo_clear() {
        let config = azure_config("AB");
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: "AB",
                echo: false,
            }),
        );
        let encoded = login.encode();

        let feature_off = feature_ext_offset(&encoded);
        assert_eq!(encoded[feature_off], 0x02);
        assert_eq!(
            encoded[feature_off + 5],
            0x02,
            "options byte must have fFedAuthEcho clear"
        );
    }

    /// PRELOGIN must advertise FEDAUTHREQUIRED for Azure AD credentials and
    /// must not for SQL authentication.
    #[test]
    fn prelogin_advertises_fed_auth_for_azure_credentials() {
        let azure = azure_config("tok");
        let prelogin = Client::<Disconnected>::build_prelogin(&azure, EncryptionLevel::On);
        assert!(prelogin.fed_auth_required);

        let sql = Config::new().credentials(mssql_auth::Credentials::sql_server("u", "p"));
        let prelogin = Client::<Disconnected>::build_prelogin(&sql, EncryptionLevel::On);
        assert!(!prelogin.fed_auth_required);
    }

    /// Regression: a LOGIN7 carrying a large FEDAUTH token exceeds the 4096-byte
    /// TDS default packet size and MUST be split across multiple packets, each
    /// within 4096 bytes. Before the fix, `send_login7` passed MAX_PACKET_SIZE
    /// (65535) to `send_message` and emitted a single oversized packet, which
    /// Azure SQL reset — a managed-identity token (~1900 chars → ~4100-byte
    /// LOGIN7) tripped this every time, while smaller service-principal tokens
    /// stayed under 4096 and masked the bug. Verified live against Azure SQL.
    #[tokio::test]
    async fn login7_large_fed_auth_token_is_split_at_default_packet_size() {
        use tds_protocol::packet::PACKET_HEADER_SIZE;
        use tokio::io::AsyncReadExt;

        // ~2000-char token -> LOGIN7 comfortably over the 4096 default.
        let token = "A".repeat(2000);
        let config = azure_config(&token);
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: &token,
                echo: true,
            }),
        );
        let encoded = login.encode();
        assert!(
            encoded.len() > DEFAULT_PACKET_SIZE,
            "precondition: LOGIN7 ({}) must exceed the default packet size to exercise splitting",
            encoded.len()
        );

        // Capture exactly what send_login7 writes to the transport.
        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
        let mut connection = Connection::new(client_io);
        Client::<Disconnected>::send_login7(&mut connection, &login)
            .await
            .unwrap();
        drop(connection); // close the write half so read_to_end observes EOF
        let mut raw = Vec::new();
        server_io.read_to_end(&mut raw).await.unwrap();

        // Walk the TDS packets: 8-byte header with the type at [0], status at
        // [1] (EOM = 0x01) and the total length (header included) big-endian
        // at [2..4]. Each header is validated before it is used so a malformed
        // stream fails with a clear message rather than an opaque slice panic.
        let mut offset = 0;
        let mut packets = 0;
        let mut reassembled = Vec::new();
        let mut saw_eom = false;
        while offset < raw.len() {
            assert!(
                raw.len() - offset >= PACKET_HEADER_SIZE,
                "packet {packets} header is truncated ({} bytes left)",
                raw.len() - offset
            );
            let packet_type = raw[offset];
            let status = raw[offset + 1];
            let len = u16::from_be_bytes([raw[offset + 2], raw[offset + 3]]) as usize;
            assert_eq!(
                packet_type, TDS7_LOGIN_PACKET_TYPE,
                "packet {packets} must be a TDS7 LOGIN packet"
            );
            assert!(
                len >= PACKET_HEADER_SIZE,
                "packet {packets} length {len} is shorter than its own header"
            );
            assert!(
                len <= DEFAULT_PACKET_SIZE,
                "packet {packets} length {len} exceeds the 4096-byte default"
            );
            assert!(
                offset + len <= raw.len(),
                "packet {packets} length {len} runs past the captured bytes"
            );
            assert!(!saw_eom, "no packet may follow the END_OF_MESSAGE packet");
            saw_eom = status & 0x01 == 0x01;
            reassembled.extend_from_slice(&raw[offset + PACKET_HEADER_SIZE..offset + len]);
            offset += len;
            packets += 1;
        }
        assert!(
            packets >= 2,
            "an oversized LOGIN7 must span multiple packets, got {packets}"
        );
        assert!(saw_eom, "the final packet must carry END_OF_MESSAGE");
        assert_eq!(
            reassembled,
            encoded.as_ref(),
            "reassembled packet payloads must equal the LOGIN7 encoding"
        );
    }
}