1use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::DEFAULT_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
/// Borrowed FEDAUTH (Azure AD / Entra) parameters handed to `build_login7`
/// when the Login7 packet must carry a SecurityToken feature extension.
///
/// The type intentionally has no `Debug` derive; keep it that way so the
/// access token cannot be printed through `{:?}` formatting.
#[derive(Copy, Clone)]
struct FedAuthLogin<'a> {
    /// Copied from the server PreLogin response's `fed_auth_required` flag
    /// and echoed back inside the FEDAUTH feature data.
    echo: bool,
    /// The access token placed in the FEDAUTH feature extension.
    token: &'a str,
}
39
40impl Client<Disconnected> {
41 pub async fn connect(config: Config) -> Result<Client<Ready>> {
57 Self::validate_credential_support(&config)?;
58
59 let fed_auth_token = Self::resolve_fed_auth_token(&config).await?;
64
65 let retry = config.retry.clone();
66 let max_redirects = config.redirect.max_redirects;
67 let follow_redirects = config.redirect.follow_redirects;
68 let per_attempt = config.timeouts.connect_timeout
70 + config.timeouts.tls_timeout
71 + config.timeouts.login_timeout;
72 let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
73 let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
74 let initial_host = config.host.clone();
75 let initial_port = config.port;
76
77 let result = timeout(overall, async {
78 let mut last_error: Option<Error> = None;
79
80 for retry_attempt in 0..=retry.max_retries {
81 if retry_attempt > 0 {
82 let backoff = retry.backoff_for_attempt(retry_attempt);
83 tracing::info!(
84 retry_attempt,
85 backoff_ms = backoff.as_millis() as u64,
86 "retrying connection after transient error"
87 );
88 tokio::time::sleep(backoff).await;
89 }
90
91 let mut current_config = config.clone();
93 let mut redirect_count: u8 = 0;
94
95 let attempt_result = loop {
96 redirect_count += 1;
97 if redirect_count > max_redirects + 1 {
98 break Err(Error::TooManyRedirects { max: max_redirects });
99 }
100
101 match Self::try_connect(¤t_config, fed_auth_token.as_deref()).await {
102 Ok(client) => break Ok(client),
103 Err(Error::Routing { host, port }) => {
104 if !follow_redirects {
105 break Err(Error::Routing { host, port });
106 }
107 tracing::info!(
108 host = %host,
109 port = port,
110 redirect = redirect_count,
111 max_redirects = max_redirects,
112 "following Azure SQL routing redirect"
113 );
114 current_config = current_config.with_host(&host).with_port(port);
115 continue;
116 }
117 Err(e) => break Err(e),
118 }
119 };
120
121 match attempt_result {
122 Ok(client) => return Ok(client),
123 Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
124 tracing::warn!(
125 retry_attempt,
126 max_retries = retry.max_retries,
127 error = %e,
128 "transient connection error, will retry"
129 );
130 last_error = Some(attempt_result.unwrap_err());
131 }
132 Err(e) => return Err(e),
133 }
134 }
135
136 Err(last_error.expect("at least one attempt was made"))
138 })
139 .await;
140
141 match result {
142 Ok(inner) => inner,
143 Err(_elapsed) => Err(Error::ConnectTimeout {
144 host: initial_host,
145 port: initial_port,
146 }),
147 }
148 }
149
150 fn validate_credential_support(config: &Config) -> Result<()> {
156 match &config.credentials {
157 mssql_auth::Credentials::SqlServer { .. } => Ok(()),
158 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
159 mssql_auth::Credentials::Integrated => Ok(()),
160 creds if creds.is_azure_ad() => {
161 #[cfg(not(feature = "tls"))]
165 {
166 Err(Error::Config(
167 "Azure AD / Entra (FEDAUTH) authentication requires TLS: \
168 enable the 'tls' feature."
169 .into(),
170 ))
171 }
172 #[cfg(feature = "tls")]
173 {
174 if config.no_tls {
175 return Err(Error::Config(
176 "Azure AD / Entra (FEDAUTH) authentication cannot be combined \
177 with Encrypt=no_tls: the access token would be sent in \
178 plaintext. Use Encrypt=mandatory or Encrypt=strict."
179 .into(),
180 ));
181 }
182 if matches!(&config.credentials,
183 mssql_auth::Credentials::AzureAccessToken { token } if token.is_empty())
184 {
185 return Err(Error::Config(
186 "Azure AD access token is empty (the FEDAUTH token length \
187 must not be zero)"
188 .into(),
189 ));
190 }
191 if !config.strict_mode && !config.tds_version.supports_fed_auth() {
192 return Err(Error::Config(format!(
193 "Azure AD / Entra (FEDAUTH) authentication requires TDS 7.4 \
194 or later (configured: {})",
195 config.tds_version
196 )));
197 }
198 Ok(())
199 }
200 }
201 _ => Err(Error::Config(
205 "client certificate (FEDAUTH) authentication is not yet supported \
206 (tracked in https://github.com/praxiomlabs/rust-mssql-driver/issues/155). \
207 Use SQL Server, integrated, or Azure AD / Entra authentication."
208 .into(),
209 )),
210 }
211 }
212
213 async fn resolve_fed_auth_token(config: &Config) -> Result<Option<String>> {
220 match &config.credentials {
221 mssql_auth::Credentials::AzureAccessToken { token } => Ok(Some(token.to_string())),
222 #[cfg(feature = "azure-identity")]
223 mssql_auth::Credentials::AzureManagedIdentity { client_id } => {
224 let auth = match client_id {
225 Some(id) => {
226 mssql_auth::ManagedIdentityAuth::user_assigned_client_id(id.to_string())?
227 }
228 None => mssql_auth::ManagedIdentityAuth::system_assigned()?,
229 };
230 tracing::debug!("acquiring Azure SQL access token via managed identity");
231 Ok(Some(auth.get_token().await?))
232 }
233 #[cfg(feature = "azure-identity")]
234 mssql_auth::Credentials::AzureServicePrincipal {
235 tenant_id,
236 client_id,
237 client_secret,
238 } => {
239 let auth = mssql_auth::ServicePrincipalAuth::new(
240 tenant_id.as_ref(),
241 client_id.to_string(),
242 client_secret.to_string(),
243 )?;
244 tracing::debug!(
245 client_id = %client_id,
246 "acquiring Azure SQL access token via service principal"
247 );
248 Ok(Some(auth.get_token().await?))
249 }
250 _ => Ok(None),
251 }
252 }
253
254 async fn try_connect(config: &Config, fed_auth_token: Option<&str>) -> Result<Client<Ready>> {
255 let port = if let Some(ref instance) = config.instance {
257 let resolved = crate::browser::resolve_instance(
258 &config.host,
259 instance,
260 Some(config.timeouts.connect_timeout),
261 )
262 .await?;
263 tracing::info!(
264 host = %config.host,
265 instance = %instance,
266 resolved_port = resolved,
267 database = ?config.database,
268 "connecting to named SQL Server instance"
269 );
270 resolved
271 } else {
272 tracing::info!(
273 host = %config.host,
274 port = config.port,
275 database = ?config.database,
276 "connecting to SQL Server"
277 );
278 config.port
279 };
280
281 let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
284 "127.0.0.1"
285 } else {
286 &config.host
287 };
288
289 let tcp_stream = if config.multi_subnet_failover {
291 Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
292 } else {
293 let addr = format!("{host}:{port}");
294 tracing::debug!("establishing TCP connection to {}", addr);
295 let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
296 .await
297 .map_err(|_| Error::ConnectTimeout {
298 host: config.host.clone(),
299 port: config.port,
300 })?
301 .map_err(Error::from)?;
302 stream.set_nodelay(true).map_err(Error::from)?;
303 stream
304 };
305
306 #[cfg(feature = "tls")]
307 {
308 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
310
311 if tls_mode.is_tls_first() {
313 return Self::connect_tds_8(config, tcp_stream, fed_auth_token).await;
314 }
315
316 Self::connect_tds_7x(config, tcp_stream, fed_auth_token).await
318 }
319
320 #[cfg(not(feature = "tls"))]
321 {
322 let _ = fed_auth_token;
325
326 if config.strict_mode {
328 return Err(Error::Config(
329 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
330 ));
331 }
332
333 if !config.no_tls {
334 return Err(Error::Config(
335 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
336 or use Encrypt=no_tls in your connection string for unencrypted connections."
337 .into(),
338 ));
339 }
340
341 Self::connect_no_tls(config, tcp_stream).await
343 }
344 }
345
346 async fn connect_parallel(
351 host: &str,
352 port: u16,
353 connect_timeout: std::time::Duration,
354 ) -> Result<TcpStream> {
355 let addr_str = format!("{host}:{port}");
356 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
357 .await
358 .map_err(Error::from)?
359 .collect();
360
361 if addrs.is_empty() {
362 return Err(Error::from(std::io::Error::new(
363 std::io::ErrorKind::AddrNotAvailable,
364 format!("no addresses resolved for {host}:{port}"),
365 )));
366 }
367
368 if addrs.len() == 1 {
370 tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
371 let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
372 .await
373 .map_err(|_| Error::ConnectTimeout {
374 host: host.to_string(),
375 port,
376 })?
377 .map_err(Error::from)?;
378 stream.set_nodelay(true).map_err(Error::from)?;
379 return Ok(stream);
380 }
381
382 let addr_count = addrs.len();
383 tracing::debug!(
384 host = host,
385 port = port,
386 resolved_count = addr_count,
387 "MultiSubnetFailover: racing parallel connections",
388 );
389
390 let mut join_set = tokio::task::JoinSet::new();
391
392 for addr in addrs {
393 let dur = connect_timeout;
394 join_set.spawn(async move {
395 let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
396 std::io::Error::new(
397 std::io::ErrorKind::TimedOut,
398 format!("connection to {addr} timed out"),
399 )
400 })??;
401 tcp.set_nodelay(true)?;
402 Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
403 });
404 }
405
406 let mut last_error: Option<std::io::Error> = None;
407
408 while let Some(result) = join_set.join_next().await {
409 match result {
410 Ok(Ok((stream, addr))) => {
411 tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
412 join_set.abort_all();
413 return Ok(stream);
414 }
415 Ok(Err(e)) => {
416 tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
417 last_error = Some(e);
418 }
419 Err(join_err) => {
420 tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
421 last_error = Some(std::io::Error::other(join_err.to_string()));
422 }
423 }
424 }
425
426 Err(Error::from(last_error.unwrap_or_else(|| {
428 std::io::Error::new(
429 std::io::ErrorKind::ConnectionRefused,
430 format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
431 )
432 })))
433 }
434
/// Establishes a TDS 8.0 ("strict", TLS-first) session over an already
/// connected TCP stream.
///
/// Steps, in order:
/// 1. TLS handshake directly on the TCP stream, bounded by `tls_timeout`.
/// 2. PreLogin exchange inside the TLS tunnel, advertising
///    `EncryptionLevel::Required`.
/// 3. Login7 carrying an SSPI token, a FEDAUTH feature extension, or SQL
///    credentials (chosen by `build_login7`).
/// 4. Login response processing, bounded by `login_timeout`.
///
/// # Errors
///
/// - `Error::TlsTimeout` / `Error::LoginTimeout` when those phases exceed their budget.
/// - `Error::Routing` when the server asks the client to reconnect elsewhere;
///   `connect` follows these redirects.
/// - Any TLS, I/O, protocol or authentication error from the steps above.
///
/// NOTE(review): the PreLogin exchange and the Login7 send are not wrapped in a
/// per-phase timeout here. Only the overall deadline in `connect` bounds them.
#[cfg(feature = "tls")]
async fn connect_tds_8(
    config: &Config,
    tcp_stream: TcpStream,
    fed_auth_token: Option<&str>,
) -> Result<Client<Ready>> {
    tracing::debug!("using TDS 8.0 strict mode (TLS first)");

    // NOTE(review): the `true` flag presumably selects the strict/TLS-first
    // variant of the TLS configuration; confirm against `connection_tls_config`.
    let tls_config = connection_tls_config(config, true);

    let tls_connector = TlsConnector::new(tls_config)?;

    // In strict mode TLS is negotiated before any TDS traffic is exchanged.
    let tls_stream = timeout(
        config.timeouts.tls_timeout,
        tls_connector.connect(tcp_stream, &config.host),
    )
    .await
    .map_err(|_| Error::TlsTimeout {
        host: config.host.clone(),
        port: config.port,
    })??;

    tracing::debug!("TLS handshake completed (strict mode)");

    let mut connection = Connection::new(tls_stream);
    connection.set_max_message_size(config.max_response_size);

    // PreLogin happens inside the tunnel. Its response is used only for the
    // server's `fed_auth_required` flag below.
    let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
    Self::send_prelogin(&mut connection, &prelogin).await?;
    let prelogin_response = Self::receive_prelogin(&mut connection).await?;

    // With an SSPI feature enabled, integrated credentials produce an initial
    // security token; other credential kinds yield no negotiator.
    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
    let negotiator = Self::create_negotiator(config)?;
    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
    let sspi_token = match negotiator {
        Some(ref neg) => Some(neg.initialize()?),
        None => None,
    };
    #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
    let sspi_token: Option<Vec<u8>> = None;

    // The FEDAUTH feature echoes the server's `fed_auth_required` PreLogin flag.
    let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
        token,
        echo: prelogin_response.fed_auth_required,
    });
    let login = Self::build_login7(config, sspi_token, fed_auth);
    Self::send_login7(&mut connection, &login).await?;

    // The negotiator (if any) is passed along so that multi-step SSPI
    // exchanges can continue while the login response is processed.
    let (server_version, current_database, routing, server_collation) = timeout(
        config.timeouts.login_timeout,
        Self::process_login_response(
            &mut connection,
            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            negotiator.as_deref(),
        ),
    )
    .await
    .map_err(|_| Error::LoginTimeout {
        host: config.host.clone(),
        port: config.port,
    })??;

    // A routing ENVCHANGE means "reconnect there"; `connect` handles it.
    if let Some((host, port)) = routing {
        return Err(Error::Routing { host, port });
    }

    // Fresh session state. NOTE(review): `transaction_descriptor: 0`
    // presumably means "no open transaction"; confirm against the Client docs.
    Ok(Client {
        config: config.clone(),
        _state: PhantomData,
        connection: Some(ConnectionHandle::Tls(connection)),
        server_version,
        current_database: current_database.clone(),
        server_collation,
        statement_cache: StatementCache::with_default_size(),
        transaction_descriptor: 0,
        needs_reset: false,
        in_flight: false,
        #[cfg(feature = "otel")]
        instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
            .with_database(current_database.clone().unwrap_or_default()),
        #[cfg(feature = "always-encrypted")]
        encryption_context: config.column_encryption.clone().map(|cfg| {
            std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
        }),
    })
}
533
/// Establishes a TDS 7.x session (PreLogin in plaintext first, then optional TLS).
///
/// Steps, in order:
/// 1. Send a raw PreLogin packet advertising the client's encryption level:
///    `NotSupported` for `no_tls`, `On` when `encrypt` is set, otherwise `Off`.
/// 2. Read the single-packet PreLogin response and log TDS version information.
/// 3. Negotiate encryption from the (client, server) levels (see the match below).
/// 4. Then one of three paths:
///    - **Login-only encryption** (`Off`/`Off`): a TLS handshake wrapped in
///      PreLogin packets, Login7 written through TLS, then the TLS layer is
///      discarded and the rest of the session runs in plaintext
///      (`ConnectionHandle::Plain`).
///    - **Full encryption**: TLS for the whole session (`ConnectionHandle::TlsPrelogin`).
///    - **No encryption** (`NotSupported`/`NotSupported`): plaintext for
///      everything (`ConnectionHandle::Plain`).
///
/// # Errors
///
/// - `Error::Protocol` when the client demands encryption the server refuses.
/// - `Error::TlsTimeout` / `Error::LoginTimeout` when those phases exceed their budget.
/// - `Error::Routing` for server redirects (followed by `connect`).
/// - Any I/O, TLS or authentication error.
#[cfg(feature = "tls")]
async fn connect_tds_7x(
    config: &Config,
    mut tcp_stream: TcpStream,
    fed_auth_token: Option<&str>,
) -> Result<Client<Ready>> {
    use bytes::BufMut;
    use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    tracing::debug!("using TDS 7.x flow (PreLogin first)");

    // Encryption level advertised to the server in PreLogin.
    let client_encryption = if config.no_tls {
        tracing::warn!(
            "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
             Credentials and data will be transmitted in plaintext. \
             This should only be used for development/testing with legacy SQL Server."
        );
        EncryptionLevel::NotSupported
    } else if config.encrypt {
        EncryptionLevel::On
    } else {
        EncryptionLevel::Off
    };
    let prelogin = Self::build_prelogin(config, client_encryption);
    tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
    let prelogin_bytes = prelogin.encode();

    // The PreLogin request is framed by hand and written straight to the
    // socket, because the stream may later be handed to the TLS connector.
    let header = PacketHeader::new(
        PacketType::PreLogin,
        PacketStatus::END_OF_MESSAGE,
        (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
    );

    let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
    header.encode(&mut packet_buf);
    packet_buf.put_slice(&prelogin_bytes);

    tcp_stream
        .write_all(&packet_buf)
        .await
        .map_err(Error::from)?;

    // Read exactly one response packet. Bytes 2..4 of the header hold the
    // big-endian total packet length (header included).
    // NOTE(review): the packet type and EOM status are not checked; the response
    // is assumed to fit in one packet.
    let mut header_buf = [0u8; PACKET_HEADER_SIZE];
    tcp_stream
        .read_exact(&mut header_buf)
        .await
        .map_err(Error::from)?;

    let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
    let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);

    let mut response_buf = vec![0u8; payload_length];
    tcp_stream
        .read_exact(&mut response_buf)
        .await
        .map_err(Error::from)?;

    let prelogin_response = PreLogin::decode(&response_buf[..])?;

    // Version diagnostics only: nothing below changes behaviour based on them.
    let client_tds_version = config.tds_version;
    if let Some(ref server_version) = prelogin_response.server_version {
        tracing::debug!(
            requested_tds_version = %client_tds_version,
            server_product_version = %server_version,
            server_product = server_version.product_name(),
            max_tds_version = %server_version.max_tds_version(),
            "PreLogin response received"
        );

        let server_max_tds = server_version.max_tds_version();
        if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
            tracing::warn!(
                requested_tds_version = %client_tds_version,
                server_max_tds_version = %server_max_tds,
                server_product = server_version.product_name(),
                "Server supports lower TDS version than requested. \
                 Connection will use server's maximum: {}",
                server_max_tds
            );
        }

        if server_max_tds.is_legacy() {
            tracing::warn!(
                server_product = server_version.product_name(),
                server_max_tds_version = %server_max_tds,
                "Server uses legacy TDS version. Some features may not be available."
            );
        }
    } else {
        tracing::debug!(
            requested_tds_version = %client_tds_version,
            "PreLogin response received (no version info)"
        );
    }

    let server_encryption = prelogin_response.encryption;
    tracing::debug!(encryption = ?server_encryption, "server encryption level");

    // The FEDAUTH feature echoes the server's `fed_auth_required` PreLogin flag.
    let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
        token,
        echo: prelogin_response.fed_auth_required,
    });

    // Negotiation as implemented here:
    // - both NotSupported            -> no encryption at all
    // - both Off                     -> Off (TLS for Login7 only, see below)
    // - client On, server Off/NotSup -> hard error
    // - anything else                -> On (TLS for the whole session)
    let negotiated_encryption = match (client_encryption, server_encryption) {
        (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
            EncryptionLevel::NotSupported
        }
        (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
        (EncryptionLevel::On, EncryptionLevel::Off)
        | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
            return Err(Error::Protocol(
                "Server does not support requested encryption level".to_string(),
            ));
        }
        _ => EncryptionLevel::On,
    };

    let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;

    if use_tls {
        let tls_config = connection_tls_config(config, false);

        let tls_connector = TlsConnector::new(tls_config)?;

        // In TDS 7.x the TLS handshake records travel inside PreLogin packets.
        let mut tls_stream = timeout(
            config.timeouts.tls_timeout,
            tls_connector.connect_with_prelogin(tcp_stream, &config.host),
        )
        .await
        .map_err(|_| Error::TlsTimeout {
            host: config.host.clone(),
            port: config.port,
        })??;

        tracing::debug!("TLS handshake completed (PreLogin wrapped)");

        // `Off` negotiated: only the Login7 packet is encrypted.
        let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;

        if login_only_encryption {
            // NOTE(review): redundant; `AsyncWriteExt` is already imported at the
            // top of this function.
            use tokio::io::AsyncWriteExt;

            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let negotiator = Self::create_negotiator(config)?;
            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let sspi_token = match negotiator {
                Some(ref neg) => Some(neg.initialize()?),
                None => None,
            };
            #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
            let sspi_token: Option<Vec<u8>> = None;

            let login = Self::build_login7(config, sspi_token, fed_auth);
            let login_payload = login.encode();

            // Login7 is split into packets by hand and written straight into the
            // TLS stream, because the TLS layer is discarded right afterwards.
            // NOTE(review): the default packet size is presumably used because no
            // other size has been negotiated yet at this point.
            let max_packet = DEFAULT_PACKET_SIZE;
            let max_payload = max_packet - PACKET_HEADER_SIZE;
            let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
            let total_chunks = chunks.len();

            for (i, chunk) in chunks.into_iter().enumerate() {
                let is_last = i == total_chunks - 1;
                let status = if is_last {
                    PacketStatus::END_OF_MESSAGE
                } else {
                    PacketStatus::NORMAL
                };

                let header = PacketHeader::new(
                    PacketType::Tds7Login,
                    status,
                    (PACKET_HEADER_SIZE + chunk.len()) as u16,
                );

                let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
                header.encode(&mut packet_buf);
                packet_buf.put_slice(chunk);

                tls_stream
                    .write_all(&packet_buf)
                    .await
                    .map_err(Error::from)?;
            }

            // Flush so that every encrypted Login7 record is on the wire before
            // the TLS layer is dropped.
            tls_stream.flush().await.map_err(Error::from)?;

            tracing::debug!("Login7 sent through TLS, switching to plaintext for response");

            // Unwrap TLS stream -> PreLogin wrapper -> raw TCP stream.
            let (wrapper, _client_conn) = tls_stream.into_inner();
            let tcp_stream = wrapper.into_inner();

            let mut connection = Connection::new(tcp_stream);
            connection.set_max_message_size(config.max_response_size);

            let (server_version, current_database, routing, server_collation) = timeout(
                config.timeouts.login_timeout,
                Self::process_login_response(
                    &mut connection,
                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                    negotiator.as_deref(),
                ),
            )
            .await
            .map_err(|_| Error::LoginTimeout {
                host: config.host.clone(),
                port: config.port,
            })??;

            if let Some((host, port)) = routing {
                return Err(Error::Routing { host, port });
            }

            Ok(Client {
                config: config.clone(),
                _state: PhantomData,
                connection: Some(ConnectionHandle::Plain(connection)),
                server_version,
                current_database: current_database.clone(),
                server_collation,
                statement_cache: StatementCache::with_default_size(),
                transaction_descriptor: 0,
                needs_reset: false,
                in_flight: false,
                #[cfg(feature = "otel")]
                instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                    .with_database(current_database.clone().unwrap_or_default()),
                #[cfg(feature = "always-encrypted")]
                encryption_context: config.column_encryption.clone().map(|cfg| {
                    std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                }),
            })
        } else {
            // Full encryption: the TLS stream carries the rest of the session.
            let mut connection = Connection::new(tls_stream);
            connection.set_max_message_size(config.max_response_size);

            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let negotiator = Self::create_negotiator(config)?;
            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let sspi_token = match negotiator {
                Some(ref neg) => Some(neg.initialize()?),
                None => None,
            };
            #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
            let sspi_token: Option<Vec<u8>> = None;

            let login = Self::build_login7(config, sspi_token, fed_auth);
            Self::send_login7(&mut connection, &login).await?;

            let (server_version, current_database, routing, server_collation) = timeout(
                config.timeouts.login_timeout,
                Self::process_login_response(
                    &mut connection,
                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                    negotiator.as_deref(),
                ),
            )
            .await
            .map_err(|_| Error::LoginTimeout {
                host: config.host.clone(),
                port: config.port,
            })??;

            if let Some((host, port)) = routing {
                return Err(Error::Routing { host, port });
            }

            Ok(Client {
                config: config.clone(),
                _state: PhantomData,
                connection: Some(ConnectionHandle::TlsPrelogin(connection)),
                server_version,
                current_database: current_database.clone(),
                server_collation,
                statement_cache: StatementCache::with_default_size(),
                transaction_descriptor: 0,
                needs_reset: false,
                in_flight: false,
                #[cfg(feature = "otel")]
                instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                    .with_database(current_database.clone().unwrap_or_default()),
                #[cfg(feature = "always-encrypted")]
                encryption_context: config.column_encryption.clone().map(|cfg| {
                    std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                }),
            })
        }
    } else {
        // Both sides said NotSupported: the whole session is plaintext.
        tracing::warn!(
            "Connecting without TLS encryption. This is insecure and should only be \
             used for development/testing on trusted networks."
        );

        let mut connection = Connection::new(tcp_stream);

        connection.set_max_message_size(config.max_response_size);

        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
        let negotiator = Self::create_negotiator(config)?;
        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
        let sspi_token = match negotiator {
            Some(ref neg) => Some(neg.initialize()?),
            None => None,
        };
        #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
        let sspi_token: Option<Vec<u8>> = None;

        let login = Self::build_login7(config, sspi_token, fed_auth);
        Self::send_login7(&mut connection, &login).await?;

        let (server_version, current_database, routing, server_collation) = timeout(
            config.timeouts.login_timeout,
            Self::process_login_response(
                &mut connection,
                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                negotiator.as_deref(),
            ),
        )
        .await
        .map_err(|_| Error::LoginTimeout {
            host: config.host.clone(),
            port: config.port,
        })??;

        if let Some((host, port)) = routing {
            return Err(Error::Routing { host, port });
        }

        Ok(Client {
            config: config.clone(),
            _state: PhantomData,
            connection: Some(ConnectionHandle::Plain(connection)),
            server_version,
            current_database: current_database.clone(),
            server_collation,
            statement_cache: StatementCache::with_default_size(),
            transaction_descriptor: 0,
            needs_reset: false,
            in_flight: false,
            #[cfg(feature = "otel")]
            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                .with_database(current_database.clone().unwrap_or_default()),
            #[cfg(feature = "always-encrypted")]
            encryption_context: config.column_encryption.clone().map(|cfg| {
                std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
            }),
        })
    }
}
953
954 #[cfg(not(feature = "tls"))]
965 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
966 use bytes::BufMut;
967 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
968 use tokio::io::{AsyncReadExt, AsyncWriteExt};
969
970 tracing::warn!(
971 "⚠️ Connecting without TLS (tls feature disabled). \
972 Credentials and data will be transmitted in plaintext."
973 );
974
975 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
977 let prelogin_bytes = prelogin.encode();
978
979 let header = PacketHeader::new(
981 PacketType::PreLogin,
982 PacketStatus::END_OF_MESSAGE,
983 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
984 );
985
986 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
987 header.encode(&mut packet_buf);
988 packet_buf.put_slice(&prelogin_bytes);
989
990 tcp_stream
991 .write_all(&packet_buf)
992 .await
993 .map_err(Error::from)?;
994
995 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
997 tcp_stream
998 .read_exact(&mut header_buf)
999 .await
1000 .map_err(Error::from)?;
1001
1002 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
1003 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
1004
1005 let mut response_buf = vec![0u8; payload_length];
1006 tcp_stream
1007 .read_exact(&mut response_buf)
1008 .await
1009 .map_err(Error::from)?;
1010
1011 let prelogin_response = PreLogin::decode(&response_buf[..])?;
1012
1013 let server_encryption = prelogin_response.encryption;
1015 if server_encryption != EncryptionLevel::NotSupported {
1016 return Err(Error::Config(format!(
1017 "Server requires encryption (level: {server_encryption:?}) but TLS feature is disabled. \
1018 Either enable the 'tls' feature or configure the server to allow unencrypted connections."
1019 )));
1020 }
1021
1022 tracing::debug!("Server accepted unencrypted connection");
1023
1024 let mut connection = Connection::new(tcp_stream);
1025
1026 connection.set_max_message_size(config.max_response_size);
1027
1028 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1030 let negotiator = Self::create_negotiator(config)?;
1031 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1032 let sspi_token = match negotiator {
1033 Some(ref neg) => Some(neg.initialize()?),
1034 None => None,
1035 };
1036 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
1037 let sspi_token: Option<Vec<u8>> = None;
1038
1039 let login = Self::build_login7(config, sspi_token, None);
1042 Self::send_login7(&mut connection, &login).await?;
1043
1044 let (server_version, current_database, routing, server_collation) = timeout(
1046 config.timeouts.login_timeout,
1047 Self::process_login_response(
1048 &mut connection,
1049 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1050 negotiator.as_deref(),
1051 ),
1052 )
1053 .await
1054 .map_err(|_| Error::LoginTimeout {
1055 host: config.host.clone(),
1056 port: config.port,
1057 })??;
1058
1059 if let Some((host, port)) = routing {
1061 return Err(Error::Routing { host, port });
1062 }
1063
1064 Ok(Client {
1065 config: config.clone(),
1066 _state: PhantomData,
1067 connection: Some(ConnectionHandle::Plain(connection)),
1068 server_version,
1069 current_database: current_database.clone(),
1070 server_collation,
1071 statement_cache: StatementCache::with_default_size(),
1072 transaction_descriptor: 0,
1073 needs_reset: false,
1074 in_flight: false,
1075 #[cfg(feature = "otel")]
1076 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
1077 .with_database(current_database.clone().unwrap_or_default()),
1078 #[cfg(feature = "always-encrypted")]
1079 encryption_context: config.column_encryption.clone().map(|cfg| {
1080 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
1081 }),
1082 })
1083 }
1084
1085 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
1087 let version = if config.strict_mode {
1089 tds_protocol::version::TdsVersion::V8_0
1090 } else {
1091 config.tds_version
1092 };
1093
1094 let mut prelogin = PreLogin::new()
1095 .with_version(version)
1096 .with_encryption(encryption);
1097
1098 if config.mars {
1099 prelogin = prelogin.with_mars(true);
1100 }
1101
1102 if let Some(ref instance) = config.instance {
1103 prelogin = prelogin.with_instance(instance);
1104 }
1105
1106 if config.credentials.is_azure_ad() {
1109 prelogin = prelogin.with_fed_auth_required(true);
1110 }
1111
1112 prelogin
1113 }
1114
1115 fn resolve_workstation_id(config: &Config) -> String {
1123 if let Some(ref id) = config.workstation_id {
1124 return id.clone();
1125 }
1126 std::env::var("COMPUTERNAME")
1129 .or_else(|_| std::env::var("HOSTNAME"))
1130 .unwrap_or_default()
1131 }
1132
1133 fn build_login7(
1143 config: &Config,
1144 sspi_token: Option<Vec<u8>>,
1145 fed_auth: Option<FedAuthLogin<'_>>,
1146 ) -> Login7 {
1147 let version = if config.strict_mode {
1149 tds_protocol::version::TdsVersion::V8_0
1150 } else {
1151 config.tds_version
1152 };
1153
1154 let mut login = Login7::new()
1155 .with_tds_version(version)
1156 .with_packet_size(config.packet_size as u32)
1157 .with_app_name(&config.application_name)
1158 .with_server_name(&config.host)
1159 .with_hostname(Self::resolve_workstation_id(config));
1160
1161 if let Some(ref database) = config.database {
1162 login = login.with_database(database);
1163 }
1164
1165 if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
1167 login = login.with_read_only_intent(true);
1168 }
1169
1170 if let Some(ref lang) = config.language {
1172 login = login.with_language(lang);
1173 }
1174
1175 if let Some(token) = sspi_token {
1177 login = login.with_integrated_auth(token);
1179 } else if let Some(fed) = fed_auth {
1180 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1183 feature_id: tds_protocol::login7::FeatureId::FedAuth,
1184 data: mssql_auth::azure_ad::build_security_token_feature_data(fed.token, fed.echo),
1185 });
1186 tracing::debug!(
1187 fed_auth_echo = fed.echo,
1188 "Login7: adding FEDAUTH feature extension (SecurityToken workflow)"
1189 );
1190 } else if let mssql_auth::Credentials::SqlServer { username, password } =
1191 &config.credentials
1192 {
1193 login = login.with_sql_auth(username.as_ref(), password.as_ref());
1194 }
1195
1196 #[cfg(feature = "always-encrypted")]
1199 if config.column_encryption.is_some() {
1200 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1201 feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1202 data: bytes::Bytes::from_static(&[0x01]), });
1204 tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1205 }
1206
1207 login
1208 }
1209
1210 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1220 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1221 #[allow(clippy::match_like_matches_macro)]
1222 let is_integrated = match &config.credentials {
1223 mssql_auth::Credentials::Integrated => true,
1224 _ => false,
1225 };
1226
1227 if !is_integrated {
1228 return Ok(None);
1229 }
1230
1231 #[cfg(all(windows, feature = "sspi-auth"))]
1236 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1237 Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1238
1239 #[cfg(all(not(windows), feature = "sspi-auth"))]
1241 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1242 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1243
1244 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1245 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1246 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1247
1248 Ok(Some(negotiator))
1249 }
1250
1251 #[cfg(feature = "tls")]
1253 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1254 where
1255 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1256 {
1257 let payload = prelogin.encode();
1258 let max_packet = tds_protocol::packet::MAX_PACKET_SIZE;
1262
1263 connection
1264 .send_message(PacketType::PreLogin, payload, max_packet)
1265 .await?;
1266 Ok(())
1267 }
1268
1269 #[cfg(feature = "tls")]
1271 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1272 where
1273 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1274 {
1275 let message = connection
1276 .read_message()
1277 .await?
1278 .ok_or(Error::ConnectionClosed)?;
1279
1280 Ok(PreLogin::decode(&message.payload[..])?)
1281 }
1282
1283 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1285 where
1286 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1287 {
1288 let payload = login.encode();
1289 let max_packet = DEFAULT_PACKET_SIZE;
1294
1295 connection
1296 .send_message(PacketType::Tds7Login, payload, max_packet)
1297 .await?;
1298 Ok(())
1299 }
1300
1301 #[allow(clippy::never_loop)] async fn process_login_response<T>(
1312 connection: &mut Connection<T>,
1313 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
1314 &dyn mssql_auth::SspiNegotiator,
1315 >,
1316 ) -> Result<(
1317 Option<u32>,
1318 Option<String>,
1319 Option<(String, u16)>,
1320 Option<tds_protocol::token::Collation>,
1321 )>
1322 where
1323 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1324 {
1325 let mut server_version = None;
1326 let mut database = None;
1327 let mut routing = None;
1328 let mut collation = None;
1329
1330 'outer: loop {
1331 let message = connection
1332 .read_message()
1333 .await?
1334 .ok_or(Error::ConnectionClosed)?;
1335
1336 let response_bytes = message.payload;
1337 let mut parser = TokenParser::new(response_bytes);
1338
1339 while let Some(token) = parser.next_token()? {
1340 match token {
1341 Token::LoginAck(ack) => {
1342 tracing::info!(
1343 version = ack.tds_version,
1344 interface = ack.interface,
1345 prog_name = %ack.prog_name,
1346 "login acknowledged"
1347 );
1348 server_version = Some(ack.tds_version);
1349 }
1350 Token::EnvChange(env) => {
1351 Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
1352 }
1353 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1354 Token::Sspi(sspi_token) => {
1355 let neg = negotiator.ok_or_else(|| {
1356 Error::Protocol(
1357 "server sent SSPI challenge but no negotiator is configured"
1358 .to_string(),
1359 )
1360 })?;
1361
1362 tracing::debug!(
1363 challenge_len = sspi_token.data.len(),
1364 "received SSPI challenge from server"
1365 );
1366
1367 if let Some(response) = neg.step(&sspi_token.data)? {
1368 tracing::debug!(response_len = response.len(), "sending SSPI response");
1369 connection
1370 .send_message(
1371 PacketType::Sspi,
1372 bytes::Bytes::from(response),
1373 tds_protocol::packet::MAX_PACKET_SIZE,
1374 )
1375 .await?;
1376 }
1377
1378 continue 'outer;
1380 }
1381 Token::Error(err) => {
1382 return Err(Error::Server {
1383 number: err.number,
1384 state: err.state,
1385 class: err.class,
1386 message: err.message.clone(),
1387 server: if err.server.is_empty() {
1388 None
1389 } else {
1390 Some(err.server.clone())
1391 },
1392 procedure: if err.procedure.is_empty() {
1393 None
1394 } else {
1395 Some(err.procedure.clone())
1396 },
1397 line: err.line as u32,
1398 });
1399 }
1400 Token::Info(info) => {
1401 tracing::info!(
1402 number = info.number,
1403 message = %info.message,
1404 "server info message"
1405 );
1406 }
1407 Token::FeatureExtAck(ack) => {
1408 for feature in &ack.features {
1409 tracing::debug!(
1410 feature_id = feature.feature_id,
1411 data_len = feature.data.len(),
1412 "server acknowledged feature extension"
1413 );
1414 }
1415 }
1416 Token::Done(done) => {
1417 if done.status.error {
1418 return Err(Error::Protocol("login failed".to_string()));
1419 }
1420 break 'outer;
1421 }
1422 _ => {}
1423 }
1424 }
1425
1426 break;
1428 }
1429
1430 Ok((server_version, database, routing, collation))
1431 }
1432
1433 fn process_env_change(
1435 env: &EnvChange,
1436 database: &mut Option<String>,
1437 routing: &mut Option<(String, u16)>,
1438 collation: &mut Option<tds_protocol::token::Collation>,
1439 ) {
1440 use tds_protocol::token::EnvChangeValue;
1441
1442 match env.env_type {
1443 EnvChangeType::Database => {
1444 if let EnvChangeValue::String(ref new_value) = env.new_value {
1445 tracing::debug!(database = %new_value, "database changed");
1446 *database = Some(new_value.clone());
1447 }
1448 }
1449 EnvChangeType::Routing => {
1450 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1451 tracing::info!(host = %host, port = port, "routing redirect received");
1452 *routing = Some((host.clone(), port));
1453 }
1454 }
1455 EnvChangeType::SqlCollation => {
1456 if let EnvChangeValue::Binary(ref data) = env.new_value {
1457 if data.len() >= 5 {
1458 let c = tds_protocol::token::Collation::from_bytes(
1459 data[..5].try_into().unwrap(),
1460 );
1461 tracing::debug!(
1462 lcid = c.lcid,
1463 sort_id = c.sort_id,
1464 "server collation received"
1465 );
1466 *collation = Some(c);
1467 }
1468 }
1469 }
1470 _ => {
1471 if let EnvChangeValue::String(ref new_value) = env.new_value {
1472 tracing::debug!(
1473 env_type = ?env.env_type,
1474 new_value = %new_value,
1475 "environment change"
1476 );
1477 }
1478 }
1479 }
1480 }
1481}
1482
/// Build the `TlsConfig` handed to the TLS connector for a connection attempt.
///
/// Starts from a clone of `config.tls`, so settings such as extra root certificates
/// carry over. It then overrides the trust flag with the top-level
/// `config.trust_server_certificate`. When `strict` is set, it additionally enables
/// the TLS config's strict mode and advertises the `tds/8.0` ALPN protocol.
#[cfg(feature = "tls")]
fn connection_tls_config(config: &Config, strict: bool) -> TlsConfig {
    let base = config.tls.clone();
    let base = base.trust_server_certificate(config.trust_server_certificate);

    if !strict {
        return base;
    }

    base.strict_mode(true)
        .with_alpn_protocols(vec![b"tds/8.0".to_vec()])
}
1516
#[cfg(all(test, feature = "tls"))]
mod tls_config_tests {
    use super::*;
    use mssql_tls::CertificateDer;

    /// A default `Config` whose TLS settings carry one extra root certificate.
    fn config_with_root(cert: Vec<u8>) -> Config {
        let mut config = Config::new();
        let tls_with_root = config
            .tls
            .clone()
            .add_root_certificate(CertificateDer::from(cert));
        config.tls = tls_with_root;
        config
    }

    /// A custom root added to `config.tls` must survive into the connector config,
    /// in both strict and non-strict mode.
    #[test]
    fn custom_root_certificate_reaches_connector_config() {
        let expected: [u8; 32] = [0xCA; 32];
        let config = config_with_root(expected.to_vec());

        for strict in [true, false] {
            let built = connection_tls_config(&config, strict);
            assert_eq!(
                built.root_certificates.len(),
                1,
                "custom root must reach the connector (strict={strict})"
            );
            assert_eq!(built.root_certificates[0].as_ref(), &expected[..]);
        }
    }

    /// Only the top-level trust flag is set here; the connector config must still honour it.
    #[test]
    fn trust_server_certificate_taken_from_top_level_field() {
        let mut config = Config::new();
        config.trust_server_certificate = true;
        assert!(!config.tls.trust_server_certificate);

        let built = connection_tls_config(&config, false);
        assert!(built.trust_server_certificate, "top-level trust flag must win");
    }

    /// Strict mode turns on `strict_mode` and advertises the TDS 8.0 ALPN id.
    #[test]
    fn strict_mode_adds_tds8_alpn() {
        let config = Config::new();

        let strict_tls = connection_tls_config(&config, true);
        assert!(strict_tls.strict_mode);
        assert!(strict_tls
            .alpn_protocols
            .iter()
            .any(|proto| proto == b"tds/8.0"));

        let relaxed_tls = connection_tls_config(&config, false);
        assert!(!relaxed_tls.strict_mode);
    }
}
1576
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod fed_auth_login_tests {
    use super::*;
    use tds_protocol::prelogin::EncryptionLevel;

    /// Offset of LOGIN7's `ibExtension` field: the 36-byte fixed prefix followed by
    /// five 4-byte offset/length pairs (HostName, UserName, Password, AppName, ServerName).
    const EXTENSION_SLOT: usize = 36 + 5 * 4;

    fn azure_config(token: &str) -> Config {
        Config::new().credentials(mssql_auth::Credentials::azure_token(token.to_string()))
    }

    /// Build a LOGIN7 carrying a FEDAUTH SecurityToken for `token`.
    fn fed_auth_login(token: &str, echo: bool) -> Login7 {
        let config = azure_config(token);
        Client::<Disconnected>::build_login7(&config, None, Some(FedAuthLogin { token, echo }))
    }

    /// Offset of the first FeatureExt entry in an encoded LOGIN7.
    ///
    /// The WORD at `EXTENSION_SLOT` points at a DWORD, which in turn holds the offset
    /// of the FeatureExt block within the encoded record.
    fn feature_ext_offset(encoded: &[u8]) -> usize {
        let ib_extension =
            u16::from_le_bytes([encoded[EXTENSION_SLOT], encoded[EXTENSION_SLOT + 1]]) as usize;
        u32::from_le_bytes(encoded[ib_extension..ib_extension + 4].try_into().unwrap()) as usize
    }

    /// The FEDAUTH SecurityToken feature block must match the expected bytes exactly.
    #[test]
    fn login7_fed_auth_feature_block_wire_exact() {
        let login = fed_auth_login("AB", true);

        assert!(
            !login.option_flags2.integrated_security,
            "fIntSecurity MUST be 0 when FEDAUTH is present"
        );
        assert!(
            login.username.is_empty() && login.password.is_empty(),
            "FEDAUTH logins must not carry username/password"
        );

        let encoded = login.encode();

        // OptionFlags3 is byte 27; 0x10 is fExtension.
        assert_eq!(encoded[27] & 0x10, 0x10, "fExtension bit must be set");

        let feature_off = feature_ext_offset(&encoded);
        assert_eq!(
            encoded[feature_off], 0x02,
            "FeatureId must be FEDAUTH (0x02)"
        );

        let data_len = u32::from_le_bytes(
            encoded[feature_off + 1..feature_off + 5]
                .try_into()
                .unwrap(),
        ) as usize;
        // 1 options byte + 4-byte token length + "AB" as UTF-16LE (4 bytes).
        assert_eq!(data_len, 9, "FeatureDataLen must cover options + token");

        let data = &encoded[feature_off + 5..feature_off + 5 + data_len];
        assert_eq!(
            data,
            &[0x03, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "options must be (SecurityToken << 1) | echo, then DWORD-LE \
             token byte length, then UTF-16LE token"
        );
        assert_eq!(
            encoded[feature_off + 5 + data_len],
            0xFF,
            "FeatureExt terminator must follow"
        );
    }

    /// With `echo: false` the options byte carries only the SecurityToken workflow bits.
    #[test]
    fn login7_fed_auth_echo_clear() {
        let login = fed_auth_login("AB", false);
        let encoded = login.encode();

        let feature_off = feature_ext_offset(&encoded);
        assert_eq!(encoded[feature_off], 0x02);
        assert_eq!(
            encoded[feature_off + 5],
            0x02,
            "options byte must have fFedAuthEcho clear"
        );
    }

    /// PRELOGIN advertises FEDAUTH only for Azure token credentials.
    #[test]
    fn prelogin_advertises_fed_auth_for_azure_credentials() {
        let azure = azure_config("tok");
        let prelogin = Client::<Disconnected>::build_prelogin(&azure, EncryptionLevel::On);
        assert!(prelogin.fed_auth_required);

        let sql = Config::new().credentials(mssql_auth::Credentials::sql_server("u", "p"));
        let prelogin = Client::<Disconnected>::build_prelogin(&sql, EncryptionLevel::On);
        assert!(!prelogin.fed_auth_required);
    }

    /// A LOGIN7 larger than the default packet size must be split into packets of at
    /// most `DEFAULT_PACKET_SIZE` bytes whose payloads reassemble to the encoding.
    #[tokio::test]
    async fn login7_large_fed_auth_token_is_split_at_default_packet_size() {
        use tds_protocol::packet::PACKET_HEADER_SIZE;
        use tokio::io::AsyncReadExt;

        let token = "A".repeat(2000);
        let login = fed_auth_login(&token, true);
        let encoded = login.encode();
        assert!(
            encoded.len() > DEFAULT_PACKET_SIZE,
            "precondition: LOGIN7 ({}) must exceed the default packet size to exercise splitting",
            encoded.len()
        );

        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
        let mut connection = Connection::new(client_io);
        Client::<Disconnected>::send_login7(&mut connection, &login)
            .await
            .unwrap();
        // Dropping the client half closes the pipe so `read_to_end` reaches EOF.
        drop(connection);

        let mut raw = Vec::new();
        server_io.read_to_end(&mut raw).await.unwrap();

        // Walk the packets: byte 1 is the status (0x01 = END_OF_MESSAGE), bytes 2..4 the
        // big-endian packet length including the header.
        let mut offset = 0;
        let mut packets = 0;
        let mut reassembled = Vec::new();
        let mut saw_eom = false;
        while offset < raw.len() {
            assert!(
                offset + PACKET_HEADER_SIZE <= raw.len(),
                "packet {packets} header is truncated"
            );
            let status = raw[offset + 1];
            let len = u16::from_be_bytes([raw[offset + 2], raw[offset + 3]]) as usize;
            assert!(
                len >= PACKET_HEADER_SIZE && offset + len <= raw.len(),
                "packet {packets} length {len} is shorter than its header or overruns the stream"
            );
            assert!(
                len <= DEFAULT_PACKET_SIZE,
                "packet {packets} length {len} exceeds the 4096-byte default"
            );
            assert!(!saw_eom, "no packet may follow the END_OF_MESSAGE packet");
            saw_eom = status & 0x01 == 0x01;
            reassembled.extend_from_slice(&raw[offset + PACKET_HEADER_SIZE..offset + len]);
            offset += len;
            packets += 1;
        }
        assert!(
            packets >= 2,
            "an oversized LOGIN7 must span multiple packets, got {packets}"
        );
        assert!(saw_eom, "the final packet must carry END_OF_MESSAGE");
        assert_eq!(
            reassembled,
            encoded.as_ref(),
            "reassembled packet payloads must equal the LOGIN7 encoding"
        );
    }
}