1use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::DEFAULT_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
/// FEDAUTH (Azure AD / Entra) login data handed to `build_login7`.
///
/// Borrowed access token plus the flag that is echoed back to the server in
/// the FEDAUTH feature extension. The type intentionally has no `Debug`
/// derive; keep it that way so the token is never formatted into logs.
#[derive(Copy, Clone)]
struct FedAuthLogin<'a> {
    /// Access token placed in the FEDAUTH SecurityToken feature data.
    token: &'a str,
    /// The server's PreLogin `fed_auth_required` flag, echoed back in Login7.
    echo: bool,
}
39
40impl Client<Disconnected> {
41 pub async fn connect(config: Config) -> Result<Client<Ready>> {
57 Self::validate_credential_support(&config)?;
58
59 let fed_auth_token = Self::resolve_fed_auth_token(&config).await?;
64
65 let retry = config.retry.clone();
66 let max_redirects = config.redirect.max_redirects;
67 let follow_redirects = config.redirect.follow_redirects;
68 let per_attempt = config.timeouts.connect_timeout
70 + config.timeouts.tls_timeout
71 + config.timeouts.login_timeout;
72 let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
73 let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
74 let initial_host = config.host.clone();
75 let initial_port = config.port;
76
77 let result = timeout(overall, async {
78 let mut last_error: Option<Error> = None;
79
80 for retry_attempt in 0..=retry.max_retries {
81 if retry_attempt > 0 {
82 let backoff = retry.backoff_for_attempt(retry_attempt);
83 tracing::info!(
84 retry_attempt,
85 backoff_ms = backoff.as_millis() as u64,
86 "retrying connection after transient error"
87 );
88 tokio::time::sleep(backoff).await;
89 }
90
91 let mut current_config = config.clone();
93 let mut redirect_count: u8 = 0;
94
95 let attempt_result = loop {
96 redirect_count += 1;
97 if redirect_count > max_redirects + 1 {
98 break Err(Error::TooManyRedirects { max: max_redirects });
99 }
100
101 match Self::try_connect(¤t_config, fed_auth_token.as_deref()).await {
102 Ok(client) => break Ok(client),
103 Err(Error::Routing { host, port }) => {
104 if !follow_redirects {
105 break Err(Error::Routing { host, port });
106 }
107 tracing::info!(
108 host = %host,
109 port = port,
110 redirect = redirect_count,
111 max_redirects = max_redirects,
112 "following Azure SQL routing redirect"
113 );
114 current_config = current_config.with_host(&host).with_port(port);
115 continue;
116 }
117 Err(e) => break Err(e),
118 }
119 };
120
121 match attempt_result {
122 Ok(client) => return Ok(client),
123 Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
124 tracing::warn!(
125 retry_attempt,
126 max_retries = retry.max_retries,
127 error = %e,
128 "transient connection error, will retry"
129 );
130 last_error = Some(attempt_result.unwrap_err());
131 }
132 Err(e) => return Err(e),
133 }
134 }
135
136 Err(last_error.expect("at least one attempt was made"))
138 })
139 .await;
140
141 match result {
142 Ok(inner) => inner,
143 Err(_elapsed) => Err(Error::ConnectTimeout {
144 host: initial_host,
145 port: initial_port,
146 }),
147 }
148 }
149
150 fn validate_credential_support(config: &Config) -> Result<()> {
156 match &config.credentials {
157 mssql_auth::Credentials::SqlServer { .. } => Ok(()),
158 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
159 mssql_auth::Credentials::Integrated => Ok(()),
160 creds if creds.is_azure_ad() => {
161 #[cfg(not(feature = "tls"))]
165 {
166 return Err(Error::Config(
167 "Azure AD / Entra (FEDAUTH) authentication requires TLS: \
168 enable the 'tls' feature."
169 .into(),
170 ));
171 }
172 #[cfg(feature = "tls")]
173 {
174 if config.no_tls {
175 return Err(Error::Config(
176 "Azure AD / Entra (FEDAUTH) authentication cannot be combined \
177 with Encrypt=no_tls: the access token would be sent in \
178 plaintext. Use Encrypt=mandatory or Encrypt=strict."
179 .into(),
180 ));
181 }
182 if matches!(&config.credentials,
183 mssql_auth::Credentials::AzureAccessToken { token } if token.is_empty())
184 {
185 return Err(Error::Config(
186 "Azure AD access token is empty (the FEDAUTH token length \
187 must not be zero)"
188 .into(),
189 ));
190 }
191 if !config.strict_mode && !config.tds_version.supports_fed_auth() {
192 return Err(Error::Config(format!(
193 "Azure AD / Entra (FEDAUTH) authentication requires TDS 7.4 \
194 or later (configured: {})",
195 config.tds_version
196 )));
197 }
198 Ok(())
199 }
200 }
201 _ => Err(Error::Config(
205 "client certificate (FEDAUTH) authentication is not yet supported \
206 (tracked in https://github.com/praxiomlabs/rust-mssql-driver/issues/155). \
207 Use SQL Server, integrated, or Azure AD / Entra authentication."
208 .into(),
209 )),
210 }
211 }
212
213 async fn resolve_fed_auth_token(config: &Config) -> Result<Option<String>> {
220 match &config.credentials {
221 mssql_auth::Credentials::AzureAccessToken { token } => Ok(Some(token.to_string())),
222 #[cfg(feature = "azure-identity")]
223 mssql_auth::Credentials::AzureManagedIdentity { client_id } => {
224 let auth = match client_id {
225 Some(id) => {
226 mssql_auth::ManagedIdentityAuth::user_assigned_client_id(id.to_string())?
227 }
228 None => mssql_auth::ManagedIdentityAuth::system_assigned()?,
229 };
230 tracing::debug!("acquiring Azure SQL access token via managed identity");
231 Ok(Some(auth.get_token().await?))
232 }
233 #[cfg(feature = "azure-identity")]
234 mssql_auth::Credentials::AzureServicePrincipal {
235 tenant_id,
236 client_id,
237 client_secret,
238 } => {
239 let auth = mssql_auth::ServicePrincipalAuth::new(
240 tenant_id.as_ref(),
241 client_id.to_string(),
242 client_secret.to_string(),
243 )?;
244 tracing::debug!(
245 client_id = %client_id,
246 "acquiring Azure SQL access token via service principal"
247 );
248 Ok(Some(auth.get_token().await?))
249 }
250 _ => Ok(None),
251 }
252 }
253
254 async fn try_connect(config: &Config, fed_auth_token: Option<&str>) -> Result<Client<Ready>> {
255 let port = if let Some(ref instance) = config.instance {
257 let resolved = crate::browser::resolve_instance(
258 &config.host,
259 instance,
260 Some(config.timeouts.connect_timeout),
261 )
262 .await?;
263 tracing::info!(
264 host = %config.host,
265 instance = %instance,
266 resolved_port = resolved,
267 database = ?config.database,
268 "connecting to named SQL Server instance"
269 );
270 resolved
271 } else {
272 tracing::info!(
273 host = %config.host,
274 port = config.port,
275 database = ?config.database,
276 "connecting to SQL Server"
277 );
278 config.port
279 };
280
281 let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
284 "127.0.0.1"
285 } else {
286 &config.host
287 };
288
289 let tcp_stream = if config.multi_subnet_failover {
291 Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
292 } else {
293 let addr = format!("{host}:{port}");
294 tracing::debug!("establishing TCP connection to {}", addr);
295 let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
296 .await
297 .map_err(|_| Error::ConnectTimeout {
298 host: config.host.clone(),
299 port: config.port,
300 })?
301 .map_err(Error::from)?;
302 stream.set_nodelay(true).map_err(Error::from)?;
303 stream
304 };
305
306 #[cfg(feature = "tls")]
307 {
308 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
310
311 if tls_mode.is_tls_first() {
313 return Self::connect_tds_8(config, tcp_stream, fed_auth_token).await;
314 }
315
316 Self::connect_tds_7x(config, tcp_stream, fed_auth_token).await
318 }
319
320 #[cfg(not(feature = "tls"))]
321 {
322 let _ = fed_auth_token;
325
326 if config.strict_mode {
328 return Err(Error::Config(
329 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
330 ));
331 }
332
333 if !config.no_tls {
334 return Err(Error::Config(
335 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
336 or use Encrypt=no_tls in your connection string for unencrypted connections."
337 .into(),
338 ));
339 }
340
341 Self::connect_no_tls(config, tcp_stream).await
343 }
344 }
345
346 async fn connect_parallel(
351 host: &str,
352 port: u16,
353 connect_timeout: std::time::Duration,
354 ) -> Result<TcpStream> {
355 let addr_str = format!("{host}:{port}");
356 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
357 .await
358 .map_err(Error::from)?
359 .collect();
360
361 if addrs.is_empty() {
362 return Err(Error::from(std::io::Error::new(
363 std::io::ErrorKind::AddrNotAvailable,
364 format!("no addresses resolved for {host}:{port}"),
365 )));
366 }
367
368 if addrs.len() == 1 {
370 tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
371 let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
372 .await
373 .map_err(|_| Error::ConnectTimeout {
374 host: host.to_string(),
375 port,
376 })?
377 .map_err(Error::from)?;
378 stream.set_nodelay(true).map_err(Error::from)?;
379 return Ok(stream);
380 }
381
382 let addr_count = addrs.len();
383 tracing::debug!(
384 host = host,
385 port = port,
386 resolved_count = addr_count,
387 "MultiSubnetFailover: racing parallel connections",
388 );
389
390 let mut join_set = tokio::task::JoinSet::new();
391
392 for addr in addrs {
393 let dur = connect_timeout;
394 join_set.spawn(async move {
395 let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
396 std::io::Error::new(
397 std::io::ErrorKind::TimedOut,
398 format!("connection to {addr} timed out"),
399 )
400 })??;
401 tcp.set_nodelay(true)?;
402 Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
403 });
404 }
405
406 let mut last_error: Option<std::io::Error> = None;
407
408 while let Some(result) = join_set.join_next().await {
409 match result {
410 Ok(Ok((stream, addr))) => {
411 tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
412 join_set.abort_all();
413 return Ok(stream);
414 }
415 Ok(Err(e)) => {
416 tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
417 last_error = Some(e);
418 }
419 Err(join_err) => {
420 tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
421 last_error = Some(std::io::Error::other(join_err.to_string()));
422 }
423 }
424 }
425
426 Err(Error::from(last_error.unwrap_or_else(|| {
428 std::io::Error::new(
429 std::io::ErrorKind::ConnectionRefused,
430 format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
431 )
432 })))
433 }
434
435 #[cfg(feature = "tls")]
439 async fn connect_tds_8(
440 config: &Config,
441 tcp_stream: TcpStream,
442 fed_auth_token: Option<&str>,
443 ) -> Result<Client<Ready>> {
444 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
445
446 let tls_config = connection_tls_config(config, true);
449
450 let tls_connector = TlsConnector::new(tls_config)?;
451
452 let tls_stream = timeout(
454 config.timeouts.tls_timeout,
455 tls_connector.connect(tcp_stream, &config.host),
456 )
457 .await
458 .map_err(|_| Error::TlsTimeout {
459 host: config.host.clone(),
460 port: config.port,
461 })??;
462
463 tracing::debug!("TLS handshake completed (strict mode)");
464
465 let mut connection = Connection::new(tls_stream);
467 connection.set_max_message_size(config.max_response_size);
468
469 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
471 Self::send_prelogin(&mut connection, &prelogin).await?;
472 let prelogin_response = Self::receive_prelogin(&mut connection).await?;
473
474 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
476 let negotiator = Self::create_negotiator(config)?;
477 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
478 let sspi_token = match negotiator {
479 Some(ref neg) => Some(neg.initialize()?),
480 None => None,
481 };
482 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
483 let sspi_token: Option<Vec<u8>> = None;
484
485 let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
487 token,
488 echo: prelogin_response.fed_auth_required,
489 });
490 let login = Self::build_login7(config, sspi_token, fed_auth);
491 Self::send_login7(&mut connection, &login).await?;
492
493 let (server_version, current_database, routing, server_collation) = timeout(
495 config.timeouts.login_timeout,
496 Self::process_login_response(
497 &mut connection,
498 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
499 negotiator.as_deref(),
500 ),
501 )
502 .await
503 .map_err(|_| Error::LoginTimeout {
504 host: config.host.clone(),
505 port: config.port,
506 })??;
507
508 if let Some((host, port)) = routing {
510 return Err(Error::Routing { host, port });
511 }
512
513 Ok(Client {
514 config: config.clone(),
515 _state: PhantomData,
516 connection: Some(ConnectionHandle::Tls(connection)),
517 server_version,
518 current_database: current_database.clone(),
519 server_collation,
520 statement_cache: StatementCache::with_default_size(),
521 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
525 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
526 .with_database(current_database.clone().unwrap_or_default()),
527 #[cfg(feature = "always-encrypted")]
528 encryption_context: config.column_encryption.clone().map(|cfg| {
529 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
530 }),
531 })
532 }
533
534 #[cfg(feature = "tls")]
542 async fn connect_tds_7x(
543 config: &Config,
544 mut tcp_stream: TcpStream,
545 fed_auth_token: Option<&str>,
546 ) -> Result<Client<Ready>> {
547 use bytes::BufMut;
548 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
549 use tokio::io::{AsyncReadExt, AsyncWriteExt};
550
551 tracing::debug!("using TDS 7.x flow (PreLogin first)");
552
553 let client_encryption = if config.no_tls {
556 tracing::warn!(
558 "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
559 Credentials and data will be transmitted in plaintext. \
560 This should only be used for development/testing with legacy SQL Server."
561 );
562 EncryptionLevel::NotSupported
563 } else if config.encrypt {
564 EncryptionLevel::On
565 } else {
566 EncryptionLevel::Off
567 };
568 let prelogin = Self::build_prelogin(config, client_encryption);
569 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
570 let prelogin_bytes = prelogin.encode();
571
572 let header = PacketHeader::new(
574 PacketType::PreLogin,
575 PacketStatus::END_OF_MESSAGE,
576 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
577 );
578
579 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
580 header.encode(&mut packet_buf);
581 packet_buf.put_slice(&prelogin_bytes);
582
583 tcp_stream
584 .write_all(&packet_buf)
585 .await
586 .map_err(Error::from)?;
587
588 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
590 tcp_stream
591 .read_exact(&mut header_buf)
592 .await
593 .map_err(Error::from)?;
594
595 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
596 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
597
598 let mut response_buf = vec![0u8; payload_length];
599 tcp_stream
600 .read_exact(&mut response_buf)
601 .await
602 .map_err(Error::from)?;
603
604 let prelogin_response = PreLogin::decode(&response_buf[..])?;
605
606 let client_tds_version = config.tds_version;
611 if let Some(ref server_version) = prelogin_response.server_version {
612 tracing::debug!(
613 requested_tds_version = %client_tds_version,
614 server_product_version = %server_version,
615 server_product = server_version.product_name(),
616 max_tds_version = %server_version.max_tds_version(),
617 "PreLogin response received"
618 );
619
620 let server_max_tds = server_version.max_tds_version();
622 if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
623 tracing::warn!(
624 requested_tds_version = %client_tds_version,
625 server_max_tds_version = %server_max_tds,
626 server_product = server_version.product_name(),
627 "Server supports lower TDS version than requested. \
628 Connection will use server's maximum: {}",
629 server_max_tds
630 );
631 }
632
633 if server_max_tds.is_legacy() {
635 tracing::warn!(
636 server_product = server_version.product_name(),
637 server_max_tds_version = %server_max_tds,
638 "Server uses legacy TDS version. Some features may not be available."
639 );
640 }
641 } else {
642 tracing::debug!(
643 requested_tds_version = %client_tds_version,
644 "PreLogin response received (no version info)"
645 );
646 }
647
648 let server_encryption = prelogin_response.encryption;
650 tracing::debug!(encryption = ?server_encryption, "server encryption level");
651
652 let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
654 token,
655 echo: prelogin_response.fed_auth_required,
656 });
657
658 let negotiated_encryption = match (client_encryption, server_encryption) {
664 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
665 EncryptionLevel::NotSupported
666 }
667 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
668 (EncryptionLevel::On, EncryptionLevel::Off)
669 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
670 return Err(Error::Protocol(
671 "Server does not support requested encryption level".to_string(),
672 ));
673 }
674 _ => EncryptionLevel::On,
675 };
676
677 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
680
681 if use_tls {
682 let tls_config = connection_tls_config(config, false);
687
688 let tls_connector = TlsConnector::new(tls_config)?;
689
690 let mut tls_stream = timeout(
692 config.timeouts.tls_timeout,
693 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
694 )
695 .await
696 .map_err(|_| Error::TlsTimeout {
697 host: config.host.clone(),
698 port: config.port,
699 })??;
700
701 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
702
703 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
705
706 if login_only_encryption {
707 use tokio::io::AsyncWriteExt;
715
716 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
721 let negotiator = Self::create_negotiator(config)?;
722 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
723 let sspi_token = match negotiator {
724 Some(ref neg) => Some(neg.initialize()?),
725 None => None,
726 };
727 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
728 let sspi_token: Option<Vec<u8>> = None;
729
730 let login = Self::build_login7(config, sspi_token, fed_auth);
732 let login_payload = login.encode();
733
734 let max_packet = DEFAULT_PACKET_SIZE;
740 let max_payload = max_packet - PACKET_HEADER_SIZE;
741 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
742 let total_chunks = chunks.len();
743
744 for (i, chunk) in chunks.into_iter().enumerate() {
745 let is_last = i == total_chunks - 1;
746 let status = if is_last {
747 PacketStatus::END_OF_MESSAGE
748 } else {
749 PacketStatus::NORMAL
750 };
751
752 let header = PacketHeader::new(
753 PacketType::Tds7Login,
754 status,
755 (PACKET_HEADER_SIZE + chunk.len()) as u16,
756 );
757
758 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
759 header.encode(&mut packet_buf);
760 packet_buf.put_slice(chunk);
761
762 tls_stream
763 .write_all(&packet_buf)
764 .await
765 .map_err(Error::from)?;
766 }
767
768 tls_stream.flush().await.map_err(Error::from)?;
770
771 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
772
773 let (wrapper, _client_conn) = tls_stream.into_inner();
777 let tcp_stream = wrapper.into_inner();
778
779 let mut connection = Connection::new(tcp_stream);
781 connection.set_max_message_size(config.max_response_size);
782
783 let (server_version, current_database, routing, server_collation) = timeout(
785 config.timeouts.login_timeout,
786 Self::process_login_response(
787 &mut connection,
788 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
789 negotiator.as_deref(),
790 ),
791 )
792 .await
793 .map_err(|_| Error::LoginTimeout {
794 host: config.host.clone(),
795 port: config.port,
796 })??;
797
798 if let Some((host, port)) = routing {
800 return Err(Error::Routing { host, port });
801 }
802
803 Ok(Client {
805 config: config.clone(),
806 _state: PhantomData,
807 connection: Some(ConnectionHandle::Plain(connection)),
808 server_version,
809 current_database: current_database.clone(),
810 server_collation,
811 statement_cache: StatementCache::with_default_size(),
812 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
816 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
817 .with_database(current_database.clone().unwrap_or_default()),
818 #[cfg(feature = "always-encrypted")]
819 encryption_context: config.column_encryption.clone().map(|cfg| {
820 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
821 }),
822 })
823 } else {
824 let mut connection = Connection::new(tls_stream);
827 connection.set_max_message_size(config.max_response_size);
828
829 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
831 let negotiator = Self::create_negotiator(config)?;
832 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
833 let sspi_token = match negotiator {
834 Some(ref neg) => Some(neg.initialize()?),
835 None => None,
836 };
837 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
838 let sspi_token: Option<Vec<u8>> = None;
839
840 let login = Self::build_login7(config, sspi_token, fed_auth);
842 Self::send_login7(&mut connection, &login).await?;
843
844 let (server_version, current_database, routing, server_collation) = timeout(
846 config.timeouts.login_timeout,
847 Self::process_login_response(
848 &mut connection,
849 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
850 negotiator.as_deref(),
851 ),
852 )
853 .await
854 .map_err(|_| Error::LoginTimeout {
855 host: config.host.clone(),
856 port: config.port,
857 })??;
858
859 if let Some((host, port)) = routing {
861 return Err(Error::Routing { host, port });
862 }
863
864 Ok(Client {
865 config: config.clone(),
866 _state: PhantomData,
867 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
868 server_version,
869 current_database: current_database.clone(),
870 server_collation,
871 statement_cache: StatementCache::with_default_size(),
872 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
876 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
877 .with_database(current_database.clone().unwrap_or_default()),
878 #[cfg(feature = "always-encrypted")]
879 encryption_context: config.column_encryption.clone().map(|cfg| {
880 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
881 }),
882 })
883 }
884 } else {
885 tracing::warn!(
887 "Connecting without TLS encryption. This is insecure and should only be \
888 used for development/testing on trusted networks."
889 );
890
891 let mut connection = Connection::new(tcp_stream);
892
893 connection.set_max_message_size(config.max_response_size);
894
895 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
897 let negotiator = Self::create_negotiator(config)?;
898 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
899 let sspi_token = match negotiator {
900 Some(ref neg) => Some(neg.initialize()?),
901 None => None,
902 };
903 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
904 let sspi_token: Option<Vec<u8>> = None;
905
906 let login = Self::build_login7(config, sspi_token, fed_auth);
910 Self::send_login7(&mut connection, &login).await?;
911
912 let (server_version, current_database, routing, server_collation) = timeout(
914 config.timeouts.login_timeout,
915 Self::process_login_response(
916 &mut connection,
917 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
918 negotiator.as_deref(),
919 ),
920 )
921 .await
922 .map_err(|_| Error::LoginTimeout {
923 host: config.host.clone(),
924 port: config.port,
925 })??;
926
927 if let Some((host, port)) = routing {
929 return Err(Error::Routing { host, port });
930 }
931
932 Ok(Client {
933 config: config.clone(),
934 _state: PhantomData,
935 connection: Some(ConnectionHandle::Plain(connection)),
936 server_version,
937 current_database: current_database.clone(),
938 server_collation,
939 statement_cache: StatementCache::with_default_size(),
940 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
944 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
945 .with_database(current_database.clone().unwrap_or_default()),
946 #[cfg(feature = "always-encrypted")]
947 encryption_context: config.column_encryption.clone().map(|cfg| {
948 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
949 }),
950 })
951 }
952 }
953
954 #[cfg(not(feature = "tls"))]
965 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
966 use bytes::BufMut;
967 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
968 use tokio::io::{AsyncReadExt, AsyncWriteExt};
969
970 tracing::warn!(
971 "⚠️ Connecting without TLS (tls feature disabled). \
972 Credentials and data will be transmitted in plaintext."
973 );
974
975 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
977 let prelogin_bytes = prelogin.encode();
978
979 let header = PacketHeader::new(
981 PacketType::PreLogin,
982 PacketStatus::END_OF_MESSAGE,
983 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
984 );
985
986 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
987 header.encode(&mut packet_buf);
988 packet_buf.put_slice(&prelogin_bytes);
989
990 tcp_stream
991 .write_all(&packet_buf)
992 .await
993 .map_err(Error::from)?;
994
995 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
997 tcp_stream
998 .read_exact(&mut header_buf)
999 .await
1000 .map_err(Error::from)?;
1001
1002 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
1003 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
1004
1005 let mut response_buf = vec![0u8; payload_length];
1006 tcp_stream
1007 .read_exact(&mut response_buf)
1008 .await
1009 .map_err(Error::from)?;
1010
1011 let prelogin_response = PreLogin::decode(&response_buf[..])?;
1012
1013 let server_encryption = prelogin_response.encryption;
1015 if server_encryption != EncryptionLevel::NotSupported {
1016 return Err(Error::Config(format!(
1017 "Server requires encryption (level: {:?}) but TLS feature is disabled. \
1018 Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
1019 server_encryption
1020 )));
1021 }
1022
1023 tracing::debug!("Server accepted unencrypted connection");
1024
1025 let mut connection = Connection::new(tcp_stream);
1026
1027 connection.set_max_message_size(config.max_response_size);
1028
1029 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1031 let negotiator = Self::create_negotiator(config)?;
1032 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1033 let sspi_token = match negotiator {
1034 Some(ref neg) => Some(neg.initialize()?),
1035 None => None,
1036 };
1037 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
1038 let sspi_token: Option<Vec<u8>> = None;
1039
1040 let login = Self::build_login7(config, sspi_token, None);
1043 Self::send_login7(&mut connection, &login).await?;
1044
1045 let (server_version, current_database, routing, server_collation) = timeout(
1047 config.timeouts.login_timeout,
1048 Self::process_login_response(
1049 &mut connection,
1050 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1051 negotiator.as_deref(),
1052 ),
1053 )
1054 .await
1055 .map_err(|_| Error::LoginTimeout {
1056 host: config.host.clone(),
1057 port: config.port,
1058 })??;
1059
1060 if let Some((host, port)) = routing {
1062 return Err(Error::Routing { host, port });
1063 }
1064
1065 Ok(Client {
1066 config: config.clone(),
1067 _state: PhantomData,
1068 connection: Some(ConnectionHandle::Plain(connection)),
1069 server_version,
1070 current_database: current_database.clone(),
1071 server_collation,
1072 statement_cache: StatementCache::with_default_size(),
1073 transaction_descriptor: 0,
1074 needs_reset: false,
1075 in_flight: false,
1076 #[cfg(feature = "otel")]
1077 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
1078 .with_database(current_database.clone().unwrap_or_default()),
1079 #[cfg(feature = "always-encrypted")]
1080 encryption_context: config.column_encryption.clone().map(|cfg| {
1081 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
1082 }),
1083 })
1084 }
1085
1086 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
1088 let version = if config.strict_mode {
1090 tds_protocol::version::TdsVersion::V8_0
1091 } else {
1092 config.tds_version
1093 };
1094
1095 let mut prelogin = PreLogin::new()
1096 .with_version(version)
1097 .with_encryption(encryption);
1098
1099 if config.mars {
1100 prelogin = prelogin.with_mars(true);
1101 }
1102
1103 if let Some(ref instance) = config.instance {
1104 prelogin = prelogin.with_instance(instance);
1105 }
1106
1107 if config.credentials.is_azure_ad() {
1110 prelogin = prelogin.with_fed_auth_required(true);
1111 }
1112
1113 prelogin
1114 }
1115
1116 fn resolve_workstation_id(config: &Config) -> String {
1124 if let Some(ref id) = config.workstation_id {
1125 return id.clone();
1126 }
1127 std::env::var("COMPUTERNAME")
1130 .or_else(|_| std::env::var("HOSTNAME"))
1131 .unwrap_or_default()
1132 }
1133
1134 fn build_login7(
1144 config: &Config,
1145 sspi_token: Option<Vec<u8>>,
1146 fed_auth: Option<FedAuthLogin<'_>>,
1147 ) -> Login7 {
1148 let version = if config.strict_mode {
1150 tds_protocol::version::TdsVersion::V8_0
1151 } else {
1152 config.tds_version
1153 };
1154
1155 let mut login = Login7::new()
1156 .with_tds_version(version)
1157 .with_packet_size(config.packet_size as u32)
1158 .with_app_name(&config.application_name)
1159 .with_server_name(&config.host)
1160 .with_hostname(Self::resolve_workstation_id(config));
1161
1162 if let Some(ref database) = config.database {
1163 login = login.with_database(database);
1164 }
1165
1166 if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
1168 login = login.with_read_only_intent(true);
1169 }
1170
1171 if let Some(ref lang) = config.language {
1173 login = login.with_language(lang);
1174 }
1175
1176 if let Some(token) = sspi_token {
1178 login = login.with_integrated_auth(token);
1180 } else if let Some(fed) = fed_auth {
1181 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1184 feature_id: tds_protocol::login7::FeatureId::FedAuth,
1185 data: mssql_auth::azure_ad::build_security_token_feature_data(fed.token, fed.echo),
1186 });
1187 tracing::debug!(
1188 fed_auth_echo = fed.echo,
1189 "Login7: adding FEDAUTH feature extension (SecurityToken workflow)"
1190 );
1191 } else if let mssql_auth::Credentials::SqlServer { username, password } =
1192 &config.credentials
1193 {
1194 login = login.with_sql_auth(username.as_ref(), password.as_ref());
1195 }
1196
1197 #[cfg(feature = "always-encrypted")]
1200 if config.column_encryption.is_some() {
1201 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1202 feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1203 data: bytes::Bytes::from_static(&[0x01]), });
1205 tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1206 }
1207
1208 login
1209 }
1210
1211 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1221 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1222 #[allow(clippy::match_like_matches_macro)]
1223 let is_integrated = match &config.credentials {
1224 mssql_auth::Credentials::Integrated => true,
1225 _ => false,
1226 };
1227
1228 if !is_integrated {
1229 return Ok(None);
1230 }
1231
1232 #[cfg(all(windows, feature = "sspi-auth"))]
1237 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1238 Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1239
1240 #[cfg(all(not(windows), feature = "sspi-auth"))]
1242 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1243 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1244
1245 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1246 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1247 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1248
1249 Ok(Some(negotiator))
1250 }
1251
1252 #[cfg(feature = "tls")]
1254 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1255 where
1256 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1257 {
1258 let payload = prelogin.encode();
1259 let max_packet = tds_protocol::packet::MAX_PACKET_SIZE;
1263
1264 connection
1265 .send_message(PacketType::PreLogin, payload, max_packet)
1266 .await?;
1267 Ok(())
1268 }
1269
1270 #[cfg(feature = "tls")]
1272 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1273 where
1274 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1275 {
1276 let message = connection
1277 .read_message()
1278 .await?
1279 .ok_or(Error::ConnectionClosed)?;
1280
1281 Ok(PreLogin::decode(&message.payload[..])?)
1282 }
1283
1284 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1286 where
1287 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1288 {
1289 let payload = login.encode();
1290 let max_packet = DEFAULT_PACKET_SIZE;
1295
1296 connection
1297 .send_message(PacketType::Tds7Login, payload, max_packet)
1298 .await?;
1299 Ok(())
1300 }
1301
1302 #[allow(clippy::never_loop)] async fn process_login_response<T>(
1313 connection: &mut Connection<T>,
1314 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
1315 &dyn mssql_auth::SspiNegotiator,
1316 >,
1317 ) -> Result<(
1318 Option<u32>,
1319 Option<String>,
1320 Option<(String, u16)>,
1321 Option<tds_protocol::token::Collation>,
1322 )>
1323 where
1324 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1325 {
1326 let mut server_version = None;
1327 let mut database = None;
1328 let mut routing = None;
1329 let mut collation = None;
1330
1331 'outer: loop {
1332 let message = connection
1333 .read_message()
1334 .await?
1335 .ok_or(Error::ConnectionClosed)?;
1336
1337 let response_bytes = message.payload;
1338 let mut parser = TokenParser::new(response_bytes);
1339
1340 while let Some(token) = parser.next_token()? {
1341 match token {
1342 Token::LoginAck(ack) => {
1343 tracing::info!(
1344 version = ack.tds_version,
1345 interface = ack.interface,
1346 prog_name = %ack.prog_name,
1347 "login acknowledged"
1348 );
1349 server_version = Some(ack.tds_version);
1350 }
1351 Token::EnvChange(env) => {
1352 Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
1353 }
1354 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1355 Token::Sspi(sspi_token) => {
1356 let neg = negotiator.ok_or_else(|| {
1357 Error::Protocol(
1358 "server sent SSPI challenge but no negotiator is configured"
1359 .to_string(),
1360 )
1361 })?;
1362
1363 tracing::debug!(
1364 challenge_len = sspi_token.data.len(),
1365 "received SSPI challenge from server"
1366 );
1367
1368 if let Some(response) = neg.step(&sspi_token.data)? {
1369 tracing::debug!(response_len = response.len(), "sending SSPI response");
1370 connection
1371 .send_message(
1372 PacketType::Sspi,
1373 bytes::Bytes::from(response),
1374 tds_protocol::packet::MAX_PACKET_SIZE,
1375 )
1376 .await?;
1377 }
1378
1379 continue 'outer;
1381 }
1382 Token::Error(err) => {
1383 return Err(Error::Server {
1384 number: err.number,
1385 state: err.state,
1386 class: err.class,
1387 message: err.message.clone(),
1388 server: if err.server.is_empty() {
1389 None
1390 } else {
1391 Some(err.server.clone())
1392 },
1393 procedure: if err.procedure.is_empty() {
1394 None
1395 } else {
1396 Some(err.procedure.clone())
1397 },
1398 line: err.line as u32,
1399 });
1400 }
1401 Token::Info(info) => {
1402 tracing::info!(
1403 number = info.number,
1404 message = %info.message,
1405 "server info message"
1406 );
1407 }
1408 Token::FeatureExtAck(ack) => {
1409 for feature in &ack.features {
1410 tracing::debug!(
1411 feature_id = feature.feature_id,
1412 data_len = feature.data.len(),
1413 "server acknowledged feature extension"
1414 );
1415 }
1416 }
1417 Token::Done(done) => {
1418 if done.status.error {
1419 return Err(Error::Protocol("login failed".to_string()));
1420 }
1421 break 'outer;
1422 }
1423 _ => {}
1424 }
1425 }
1426
1427 break;
1429 }
1430
1431 Ok((server_version, database, routing, collation))
1432 }
1433
1434 fn process_env_change(
1436 env: &EnvChange,
1437 database: &mut Option<String>,
1438 routing: &mut Option<(String, u16)>,
1439 collation: &mut Option<tds_protocol::token::Collation>,
1440 ) {
1441 use tds_protocol::token::EnvChangeValue;
1442
1443 match env.env_type {
1444 EnvChangeType::Database => {
1445 if let EnvChangeValue::String(ref new_value) = env.new_value {
1446 tracing::debug!(database = %new_value, "database changed");
1447 *database = Some(new_value.clone());
1448 }
1449 }
1450 EnvChangeType::Routing => {
1451 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1452 tracing::info!(host = %host, port = port, "routing redirect received");
1453 *routing = Some((host.clone(), port));
1454 }
1455 }
1456 EnvChangeType::SqlCollation => {
1457 if let EnvChangeValue::Binary(ref data) = env.new_value {
1458 if data.len() >= 5 {
1459 let c = tds_protocol::token::Collation::from_bytes(
1460 data[..5].try_into().unwrap(),
1461 );
1462 tracing::debug!(
1463 lcid = c.lcid,
1464 sort_id = c.sort_id,
1465 "server collation received"
1466 );
1467 *collation = Some(c);
1468 }
1469 }
1470 }
1471 _ => {
1472 if let EnvChangeValue::String(ref new_value) = env.new_value {
1473 tracing::debug!(
1474 env_type = ?env.env_type,
1475 new_value = %new_value,
1476 "environment change"
1477 );
1478 }
1479 }
1480 }
1481 }
1482}
1483
#[cfg(feature = "tls")]
/// Derives the TLS configuration used for a single connection attempt.
///
/// Starts from a clone of `config.tls` and replaces its
/// `trust_server_certificate` flag with the top-level
/// `config.trust_server_certificate`. When `strict` is true, strict mode is
/// enabled and `tds/8.0` is offered as the only ALPN protocol.
fn connection_tls_config(config: &Config, strict: bool) -> TlsConfig {
    const TDS8_ALPN: &[u8] = b"tds/8.0";

    let base = config
        .tls
        .clone()
        .trust_server_certificate(config.trust_server_certificate);

    if !strict {
        return base;
    }

    base.strict_mode(true)
        .with_alpn_protocols(vec![TDS8_ALPN.to_vec()])
}
1517
#[cfg(all(test, feature = "tls"))]
mod tls_config_tests {
    use super::*;
    use mssql_tls::CertificateDer;

    /// Returns a default `Config` whose TLS settings include `cert` as an extra
    /// root certificate.
    fn config_with_root(cert: Vec<u8>) -> Config {
        let mut config = Config::new();
        let root = CertificateDer::from(cert);
        config.tls = config.tls.clone().add_root_certificate(root);
        config
    }

    /// A root added to `config.tls` must survive into the per-connection TLS
    /// config in both strict and non-strict mode.
    #[test]
    fn custom_root_certificate_reaches_connector_config() {
        let expected = [0xCA_u8; 32];
        let config = config_with_root(expected.to_vec());

        for strict in [true, false] {
            let derived = connection_tls_config(&config, strict);
            let roots = &derived.root_certificates;
            assert_eq!(
                roots.len(),
                1,
                "custom root must reach the connector (strict={strict})"
            );
            assert_eq!(roots[0].as_ref(), &expected[..]);
        }
    }

    /// The top-level `trust_server_certificate` field overrides the value
    /// stored inside `config.tls`.
    #[test]
    fn trust_server_certificate_taken_from_top_level_field() {
        let mut config = Config::new();
        config.trust_server_certificate = true;
        // Precondition: the nested TLS config still has the flag cleared.
        assert!(!config.tls.trust_server_certificate);

        let derived = connection_tls_config(&config, false);
        assert!(
            derived.trust_server_certificate,
            "top-level trust flag must win"
        );
    }

    /// Strict mode enables TLS strict mode and offers the `tds/8.0` ALPN id;
    /// non-strict mode leaves strict mode off.
    #[test]
    fn strict_mode_adds_tds8_alpn() {
        let config = Config::new();

        let strict_tls = connection_tls_config(&config, true);
        assert!(strict_tls.strict_mode);
        let offers_tds8 = strict_tls
            .alpn_protocols
            .iter()
            .any(|protocol| protocol == b"tds/8.0");
        assert!(offers_tds8);

        let relaxed_tls = connection_tls_config(&config, false);
        assert!(!relaxed_tls.strict_mode);
    }
}
1577
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod fed_auth_login_tests {
    use super::*;
    use tds_protocol::prelogin::EncryptionLevel;

    /// Byte offset of the `ibExtension` field in an encoded LOGIN7. It follows
    /// the 36-byte fixed header and five 4-byte offset/length pairs (HostName,
    /// UserName, Password, AppName, ServerName).
    const EXTENSION_SLOT: usize = 36 + 5 * 4;

    fn azure_config(token: &str) -> Config {
        Config::new().credentials(mssql_auth::Credentials::azure_token(token.to_string()))
    }

    /// Returns the offset of the first FeatureExt block in an encoded LOGIN7.
    ///
    /// `ibExtension` points at a little-endian DWORD, which in turn holds the
    /// offset of the FeatureExt data. Both hops are followed here.
    fn feature_ext_offset(encoded: &[u8]) -> usize {
        let ib_extension =
            u16::from_le_bytes([encoded[EXTENSION_SLOT], encoded[EXTENSION_SLOT + 1]]) as usize;
        u32::from_le_bytes(encoded[ib_extension..ib_extension + 4].try_into().unwrap()) as usize
    }

    #[test]
    fn login7_fed_auth_feature_block_wire_exact() {
        let config = azure_config("AB");
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: "AB",
                echo: true,
            }),
        );

        assert!(
            !login.option_flags2.integrated_security,
            "fIntSecurity MUST be 0 when FEDAUTH is present"
        );
        assert!(
            login.username.is_empty() && login.password.is_empty(),
            "FEDAUTH logins must not carry username/password"
        );

        let encoded = login.encode();

        // OptionFlags3 lives at byte 27; 0x10 is its fExtension bit.
        assert_eq!(encoded[27] & 0x10, 0x10, "fExtension bit must be set");

        let feature_off = feature_ext_offset(&encoded);

        assert_eq!(
            encoded[feature_off], 0x02,
            "FeatureId must be FEDAUTH (0x02)"
        );
        let data_len = u32::from_le_bytes(
            encoded[feature_off + 1..feature_off + 5]
                .try_into()
                .unwrap(),
        ) as usize;
        // 1 options byte + 4-byte token length + 4 bytes of UTF-16LE "AB".
        assert_eq!(data_len, 9, "FeatureDataLen must cover options + token");

        let data = &encoded[feature_off + 5..feature_off + 5 + data_len];
        assert_eq!(
            data,
            &[0x03, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "options must be (SecurityToken << 1) | echo, then DWORD-LE \
             token byte length, then UTF-16LE token"
        );
        assert_eq!(
            encoded[feature_off + 5 + data_len],
            0xFF,
            "FeatureExt terminator must follow"
        );
    }

    #[test]
    fn login7_fed_auth_echo_clear() {
        let config = azure_config("AB");
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: "AB",
                echo: false,
            }),
        );
        let encoded = login.encode();

        let feature_off = feature_ext_offset(&encoded);
        assert_eq!(encoded[feature_off], 0x02);
        assert_eq!(
            encoded[feature_off + 5],
            0x02,
            "options byte must have fFedAuthEcho clear"
        );
    }

    #[test]
    fn prelogin_advertises_fed_auth_for_azure_credentials() {
        let azure = azure_config("tok");
        let prelogin = Client::<Disconnected>::build_prelogin(&azure, EncryptionLevel::On);
        assert!(prelogin.fed_auth_required);

        let sql = Config::new().credentials(mssql_auth::Credentials::sql_server("u", "p"));
        let prelogin = Client::<Disconnected>::build_prelogin(&sql, EncryptionLevel::On);
        assert!(!prelogin.fed_auth_required);
    }

    #[tokio::test]
    async fn login7_large_fed_auth_token_is_split_at_default_packet_size() {
        use tds_protocol::packet::PACKET_HEADER_SIZE;
        use tokio::io::AsyncReadExt;

        let token = "A".repeat(2000);
        let config = azure_config(&token);
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: &token,
                echo: true,
            }),
        );
        let encoded = login.encode();
        assert!(
            encoded.len() > DEFAULT_PACKET_SIZE,
            "precondition: LOGIN7 ({}) must exceed the default packet size to exercise splitting",
            encoded.len()
        );

        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
        let mut connection = Connection::new(client_io);
        Client::<Disconnected>::send_login7(&mut connection, &login)
            .await
            .unwrap();
        // Dropping the client half lets `read_to_end` on the server half see EOF.
        drop(connection);
        let mut raw = Vec::new();
        server_io.read_to_end(&mut raw).await.unwrap();

        // Walk the TDS packet headers (status at byte 1, big-endian length at
        // bytes 2..4) and reassemble the payloads.
        let mut offset = 0;
        let mut packets = 0;
        let mut reassembled = Vec::new();
        let mut saw_eom = false;
        while offset < raw.len() {
            assert!(
                raw.len() - offset >= PACKET_HEADER_SIZE,
                "truncated packet header at offset {offset}"
            );
            let status = raw[offset + 1];
            let len = u16::from_be_bytes([raw[offset + 2], raw[offset + 3]]) as usize;
            assert!(
                len >= PACKET_HEADER_SIZE && offset + len <= raw.len(),
                "packet {packets} has invalid length {len}"
            );
            assert!(
                len <= DEFAULT_PACKET_SIZE,
                "packet {packets} length {len} exceeds the 4096-byte default"
            );
            assert!(!saw_eom, "no packet may follow the END_OF_MESSAGE packet");
            saw_eom = status & 0x01 == 0x01;
            reassembled.extend_from_slice(&raw[offset + PACKET_HEADER_SIZE..offset + len]);
            offset += len;
            packets += 1;
        }
        assert!(
            packets >= 2,
            "an oversized LOGIN7 must span multiple packets, got {packets}"
        );
        assert!(saw_eom, "the final packet must carry END_OF_MESSAGE");
        assert_eq!(
            reassembled,
            encoded.as_ref(),
            "reassembled packet payloads must equal the LOGIN7 encoding"
        );
    }
}