1use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::DEFAULT_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
/// Federated-authentication (FEDAUTH) inputs handed to `build_login7`.
///
/// Only `Clone` and `Copy` are derived. Avoid adding `Debug`, which would make
/// the bearer token printable.
#[derive(Copy, Clone)]
struct FedAuthLogin<'a> {
    /// Access token placed into the FEDAUTH feature-extension data.
    token: &'a str,
    /// Echo flag for the FEDAUTH feature data; callers set it from the
    /// server's PreLogin `fed_auth_required` value.
    echo: bool,
}
39
40impl Client<Disconnected> {
41 pub async fn connect(config: Config) -> Result<Client<Ready>> {
57 Self::validate_credential_support(&config)?;
58
59 let fed_auth_token = Self::resolve_fed_auth_token(&config).await?;
64
65 let retry = config.retry.clone();
66 let max_redirects = config.redirect.max_redirects;
67 let follow_redirects = config.redirect.follow_redirects;
68 let per_attempt = config.timeouts.connect_timeout
70 + config.timeouts.tls_timeout
71 + config.timeouts.login_timeout;
72 let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
73 let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
74 let initial_host = config.host.clone();
75 let initial_port = config.port;
76
77 let result = timeout(overall, async {
78 let mut last_error: Option<Error> = None;
79
80 for retry_attempt in 0..=retry.max_retries {
81 if retry_attempt > 0 {
82 let backoff = retry.backoff_for_attempt(retry_attempt);
83 tracing::info!(
84 retry_attempt,
85 backoff_ms = backoff.as_millis() as u64,
86 "retrying connection after transient error"
87 );
88 tokio::time::sleep(backoff).await;
89 }
90
91 let mut current_config = config.clone();
93 let mut redirect_count: u8 = 0;
94
95 let attempt_result = loop {
96 redirect_count += 1;
97 if redirect_count > max_redirects + 1 {
98 break Err(Error::TooManyRedirects { max: max_redirects });
99 }
100
101 match Self::try_connect(¤t_config, fed_auth_token.as_deref()).await {
102 Ok(client) => break Ok(client),
103 Err(Error::Routing { host, port }) => {
104 if !follow_redirects {
105 break Err(Error::Routing { host, port });
106 }
107 tracing::info!(
108 host = %host,
109 port = port,
110 redirect = redirect_count,
111 max_redirects = max_redirects,
112 "following Azure SQL routing redirect"
113 );
114 current_config = current_config.with_host(&host).with_port(port);
115 continue;
116 }
117 Err(e) => break Err(e),
118 }
119 };
120
121 match attempt_result {
122 Ok(client) => return Ok(client),
123 Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
124 tracing::warn!(
125 retry_attempt,
126 max_retries = retry.max_retries,
127 error = %e,
128 "transient connection error, will retry"
129 );
130 last_error = Some(attempt_result.unwrap_err());
131 }
132 Err(e) => return Err(e),
133 }
134 }
135
136 Err(last_error.expect("at least one attempt was made"))
138 })
139 .await;
140
141 match result {
142 Ok(inner) => inner,
143 Err(_elapsed) => Err(Error::ConnectTimeout {
144 host: initial_host,
145 port: initial_port,
146 }),
147 }
148 }
149
150 fn validate_credential_support(config: &Config) -> Result<()> {
156 match &config.credentials {
157 mssql_auth::Credentials::SqlServer { .. } => Ok(()),
158 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
159 mssql_auth::Credentials::Integrated => Ok(()),
160 creds if creds.is_azure_ad() => {
161 #[cfg(not(feature = "tls"))]
165 {
166 Err(Error::Config(
167 "Azure AD / Entra (FEDAUTH) authentication requires TLS: \
168 enable the 'tls' feature."
169 .into(),
170 ))
171 }
172 #[cfg(feature = "tls")]
173 {
174 if config.no_tls {
175 return Err(Error::Config(
176 "Azure AD / Entra (FEDAUTH) authentication cannot be combined \
177 with Encrypt=no_tls: the access token would be sent in \
178 plaintext. Use Encrypt=mandatory or Encrypt=strict."
179 .into(),
180 ));
181 }
182 if matches!(&config.credentials,
183 mssql_auth::Credentials::AzureAccessToken { token } if token.is_empty())
184 {
185 return Err(Error::Config(
186 "Azure AD access token is empty (the FEDAUTH token length \
187 must not be zero)"
188 .into(),
189 ));
190 }
191 if !config.strict_mode && !config.tds_version.supports_fed_auth() {
192 return Err(Error::Config(format!(
193 "Azure AD / Entra (FEDAUTH) authentication requires TDS 7.4 \
194 or later (configured: {})",
195 config.tds_version
196 )));
197 }
198 Ok(())
199 }
200 }
201 #[cfg(feature = "cert-auth")]
205 mssql_auth::Credentials::Certificate { .. } => {
206 #[cfg(not(feature = "tls"))]
207 {
208 Err(Error::Config(
209 "client certificate (FEDAUTH) authentication requires TLS: \
210 enable the 'tls' feature."
211 .into(),
212 ))
213 }
214 #[cfg(feature = "tls")]
215 {
216 if config.no_tls {
217 return Err(Error::Config(
218 "client certificate (FEDAUTH) authentication cannot be combined \
219 with Encrypt=no_tls: the access token would be sent in \
220 plaintext. Use Encrypt=mandatory or Encrypt=strict."
221 .into(),
222 ));
223 }
224 if !config.strict_mode && !config.tds_version.supports_fed_auth() {
225 return Err(Error::Config(format!(
226 "client certificate (FEDAUTH) authentication requires TDS 7.4 \
227 or later (configured: {})",
228 config.tds_version
229 )));
230 }
231 Ok(())
232 }
233 }
234 _ => Err(Error::Config(
236 "this credential type is not supported by Client::connect. \
237 Use SQL Server, integrated, or Azure AD / Entra authentication."
238 .into(),
239 )),
240 }
241 }
242
243 async fn resolve_fed_auth_token(config: &Config) -> Result<Option<String>> {
250 match &config.credentials {
251 mssql_auth::Credentials::AzureAccessToken { token } => Ok(Some(token.to_string())),
252 #[cfg(feature = "azure-identity")]
253 mssql_auth::Credentials::AzureDefault => {
254 let auth = mssql_auth::DefaultAzureAuth::new()?;
255 tracing::debug!(
256 "acquiring Azure SQL access token via the default credential chain"
257 );
258 Ok(Some(auth.get_token().await?))
259 }
260 #[cfg(feature = "azure-identity")]
261 mssql_auth::Credentials::AzureManagedIdentity { client_id } => {
262 let auth = match client_id {
263 Some(id) => {
264 mssql_auth::ManagedIdentityAuth::user_assigned_client_id(id.to_string())?
265 }
266 None => mssql_auth::ManagedIdentityAuth::system_assigned()?,
267 };
268 tracing::debug!("acquiring Azure SQL access token via managed identity");
269 Ok(Some(auth.get_token().await?))
270 }
271 #[cfg(feature = "azure-identity")]
272 mssql_auth::Credentials::AzureServicePrincipal {
273 tenant_id,
274 client_id,
275 client_secret,
276 } => {
277 let auth = mssql_auth::ServicePrincipalAuth::new(
278 tenant_id.as_ref(),
279 client_id.to_string(),
280 client_secret.to_string(),
281 )?;
282 tracing::debug!(
283 client_id = %client_id,
284 "acquiring Azure SQL access token via service principal"
285 );
286 Ok(Some(auth.get_token().await?))
287 }
288 #[cfg(feature = "cert-auth")]
289 mssql_auth::Credentials::Certificate {
290 tenant_id,
291 client_id,
292 cert_path,
293 password,
294 } => {
295 let cert_bytes = std::fs::read(cert_path.as_ref()).map_err(|e| {
296 Error::Config(format!(
297 "client certificate authentication: failed to read certificate \
298 file '{cert_path}': {e}"
299 ))
300 })?;
301 let password = password.as_deref();
302 let auth = if cert_bytes.windows(10).any(|w| w == b"-----BEGIN") {
306 mssql_auth::CertificateAuth::from_pem(
307 tenant_id.as_ref(),
308 client_id.to_string(),
309 &cert_bytes,
310 &cert_bytes,
311 password,
312 )?
313 } else {
314 mssql_auth::CertificateAuth::new(
315 tenant_id.as_ref(),
316 client_id.to_string(),
317 &cert_bytes,
318 password,
319 )?
320 };
321 tracing::debug!(
322 client_id = %client_id,
323 "acquiring Azure SQL access token via client certificate"
324 );
325 Ok(Some(auth.get_token().await?))
326 }
327 _ => Ok(None),
328 }
329 }
330
331 async fn try_connect(config: &Config, fed_auth_token: Option<&str>) -> Result<Client<Ready>> {
332 let port = if let Some(ref instance) = config.instance {
334 let resolved = crate::browser::resolve_instance(
335 &config.host,
336 instance,
337 Some(config.timeouts.connect_timeout),
338 )
339 .await?;
340 tracing::info!(
341 host = %config.host,
342 instance = %instance,
343 resolved_port = resolved,
344 database = ?config.database,
345 "connecting to named SQL Server instance"
346 );
347 resolved
348 } else {
349 tracing::info!(
350 host = %config.host,
351 port = config.port,
352 database = ?config.database,
353 "connecting to SQL Server"
354 );
355 config.port
356 };
357
358 let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
361 "127.0.0.1"
362 } else {
363 &config.host
364 };
365
366 let tcp_stream = if config.multi_subnet_failover {
368 Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
369 } else {
370 let addr = format!("{host}:{port}");
371 tracing::debug!("establishing TCP connection to {}", addr);
372 let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
373 .await
374 .map_err(|_| Error::ConnectTimeout {
375 host: config.host.clone(),
376 port: config.port,
377 })?
378 .map_err(Error::from)?;
379 stream.set_nodelay(true).map_err(Error::from)?;
380 stream
381 };
382
383 #[cfg(feature = "tls")]
384 {
385 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
387
388 if tls_mode.is_tls_first() {
390 return Self::connect_tds_8(config, tcp_stream, fed_auth_token).await;
391 }
392
393 Self::connect_tds_7x(config, tcp_stream, fed_auth_token).await
395 }
396
397 #[cfg(not(feature = "tls"))]
398 {
399 let _ = fed_auth_token;
402
403 if config.strict_mode {
405 return Err(Error::Config(
406 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
407 ));
408 }
409
410 if !config.no_tls {
411 return Err(Error::Config(
412 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
413 or use Encrypt=no_tls in your connection string for unencrypted connections."
414 .into(),
415 ));
416 }
417
418 Self::connect_no_tls(config, tcp_stream).await
420 }
421 }
422
423 async fn connect_parallel(
428 host: &str,
429 port: u16,
430 connect_timeout: std::time::Duration,
431 ) -> Result<TcpStream> {
432 let addr_str = format!("{host}:{port}");
433 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
434 .await
435 .map_err(Error::from)?
436 .collect();
437
438 if addrs.is_empty() {
439 return Err(Error::from(std::io::Error::new(
440 std::io::ErrorKind::AddrNotAvailable,
441 format!("no addresses resolved for {host}:{port}"),
442 )));
443 }
444
445 if addrs.len() == 1 {
447 tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
448 let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
449 .await
450 .map_err(|_| Error::ConnectTimeout {
451 host: host.to_string(),
452 port,
453 })?
454 .map_err(Error::from)?;
455 stream.set_nodelay(true).map_err(Error::from)?;
456 return Ok(stream);
457 }
458
459 let addr_count = addrs.len();
460 tracing::debug!(
461 host = host,
462 port = port,
463 resolved_count = addr_count,
464 "MultiSubnetFailover: racing parallel connections",
465 );
466
467 let mut join_set = tokio::task::JoinSet::new();
468
469 for addr in addrs {
470 let dur = connect_timeout;
471 join_set.spawn(async move {
472 let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
473 std::io::Error::new(
474 std::io::ErrorKind::TimedOut,
475 format!("connection to {addr} timed out"),
476 )
477 })??;
478 tcp.set_nodelay(true)?;
479 Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
480 });
481 }
482
483 let mut last_error: Option<std::io::Error> = None;
484
485 while let Some(result) = join_set.join_next().await {
486 match result {
487 Ok(Ok((stream, addr))) => {
488 tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
489 join_set.abort_all();
490 return Ok(stream);
491 }
492 Ok(Err(e)) => {
493 tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
494 last_error = Some(e);
495 }
496 Err(join_err) => {
497 tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
498 last_error = Some(std::io::Error::other(join_err.to_string()));
499 }
500 }
501 }
502
503 Err(Error::from(last_error.unwrap_or_else(|| {
505 std::io::Error::new(
506 std::io::ErrorKind::ConnectionRefused,
507 format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
508 )
509 })))
510 }
511
512 #[cfg(feature = "tls")]
516 async fn connect_tds_8(
517 config: &Config,
518 tcp_stream: TcpStream,
519 fed_auth_token: Option<&str>,
520 ) -> Result<Client<Ready>> {
521 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
522
523 let tls_config = connection_tls_config(config, true);
526
527 let tls_connector = TlsConnector::new(tls_config)?;
528
529 let tls_stream = timeout(
531 config.timeouts.tls_timeout,
532 tls_connector.connect(tcp_stream, &config.host),
533 )
534 .await
535 .map_err(|_| Error::TlsTimeout {
536 host: config.host.clone(),
537 port: config.port,
538 })??;
539
540 tracing::debug!("TLS handshake completed (strict mode)");
541
542 let mut connection = Connection::new(tls_stream);
544 connection.set_max_message_size(config.max_response_size);
545
546 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
548 Self::send_prelogin(&mut connection, &prelogin).await?;
549 let prelogin_response = Self::receive_prelogin(&mut connection).await?;
550
551 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
553 let negotiator = Self::create_negotiator(config)?;
554 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
555 let sspi_token = match negotiator {
556 Some(ref neg) => Some(neg.initialize()?),
557 None => None,
558 };
559 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
560 let sspi_token: Option<Vec<u8>> = None;
561
562 let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
564 token,
565 echo: prelogin_response.fed_auth_required,
566 });
567 let login = Self::build_login7(config, sspi_token, fed_auth);
568 Self::send_login7(&mut connection, &login).await?;
569
570 let (server_version, current_database, routing, server_collation) = timeout(
572 config.timeouts.login_timeout,
573 Self::process_login_response(
574 &mut connection,
575 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
576 negotiator.as_deref(),
577 ),
578 )
579 .await
580 .map_err(|_| Error::LoginTimeout {
581 host: config.host.clone(),
582 port: config.port,
583 })??;
584
585 if let Some((host, port)) = routing {
587 return Err(Error::Routing { host, port });
588 }
589
590 Ok(Client {
591 config: config.clone(),
592 _state: PhantomData,
593 connection: Some(ConnectionHandle::Tls(connection)),
594 server_version,
595 current_database: current_database.clone(),
596 server_collation,
597 statement_cache: StatementCache::with_default_size(),
598 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
602 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
603 .with_database(current_database.clone().unwrap_or_default()),
604 #[cfg(feature = "always-encrypted")]
605 encryption_context: config.column_encryption.clone().map(|cfg| {
606 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
607 }),
608 })
609 }
610
    /// Performs the TDS 7.x handshake (PreLogin first) on an open TCP socket.
    ///
    /// 1. Sends PreLogin in plaintext, advertising `NotSupported` (`no_tls`),
    ///    `On` (`encrypt`) or `Off`.
    /// 2. Reads the server's single PreLogin response packet and logs version info.
    /// 3. Negotiates the encryption level from the client and server values.
    /// 4. Takes one of three paths:
    ///    - full TLS: every later message goes through TLS;
    ///      returns `ConnectionHandle::TlsPrelogin`;
    ///    - login-only TLS (negotiated `Off`): only Login7 is written through
    ///      TLS, then the raw TCP stream is recovered and the login response is
    ///      read in plaintext; returns `ConnectionHandle::Plain`;
    ///    - no TLS (negotiated `NotSupported`): returns `ConnectionHandle::Plain`.
    ///
    /// # Errors
    ///
    /// - [`Error::Protocol`] if the client asked for `On` and the server answered
    ///   `Off` or `NotSupported`.
    /// - [`Error::TlsTimeout`] / [`Error::LoginTimeout`] on the respective timeouts.
    /// - [`Error::Routing`] when the server redirects.
    /// - Any I/O, TLS or decode error.
    #[cfg(feature = "tls")]
    async fn connect_tds_7x(
        config: &Config,
        mut tcp_stream: TcpStream,
        fed_auth_token: Option<&str>,
    ) -> Result<Client<Ready>> {
        use bytes::BufMut;
        use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        tracing::debug!("using TDS 7.x flow (PreLogin first)");

        // Encryption level the client advertises in PreLogin.
        let client_encryption = if config.no_tls {
            tracing::warn!(
                "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
                 Credentials and data will be transmitted in plaintext. \
                 This should only be used for development/testing with legacy SQL Server."
            );
            EncryptionLevel::NotSupported
        } else if config.encrypt {
            EncryptionLevel::On
        } else {
            EncryptionLevel::Off
        };
        let prelogin = Self::build_prelogin(config, client_encryption);
        tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
        let prelogin_bytes = prelogin.encode();

        // PreLogin is framed by hand as a single END_OF_MESSAGE packet and
        // written straight to the TCP socket.
        // NOTE(review): the `as u16` length cast would truncate a payload over
        // 65535 bytes; PreLogin is expected to be far smaller.
        let header = PacketHeader::new(
            PacketType::PreLogin,
            PacketStatus::END_OF_MESSAGE,
            (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
        );

        let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
        header.encode(&mut packet_buf);
        packet_buf.put_slice(&prelogin_bytes);

        tcp_stream
            .write_all(&packet_buf)
            .await
            .map_err(Error::from)?;

        // Read exactly one response packet: the header first, then the payload
        // whose size comes from the big-endian length in header bytes 2..4.
        // NOTE(review): a PreLogin response spread over several packets (or a
        // packet of another type) is not detected here; confirm servers never
        // send one.
        let mut header_buf = [0u8; PACKET_HEADER_SIZE];
        tcp_stream
            .read_exact(&mut header_buf)
            .await
            .map_err(Error::from)?;

        let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
        // `saturating_sub` keeps a malformed length smaller than the header
        // from underflowing.
        let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);

        let mut response_buf = vec![0u8; payload_length];
        tcp_stream
            .read_exact(&mut response_buf)
            .await
            .map_err(Error::from)?;

        let prelogin_response = PreLogin::decode(&response_buf[..])?;

        // Diagnostics only: compare the requested TDS version with what the
        // server reports. Nothing below changes behavior based on this.
        let client_tds_version = config.tds_version;
        if let Some(ref server_version) = prelogin_response.server_version {
            tracing::debug!(
                requested_tds_version = %client_tds_version,
                server_product_version = %server_version,
                server_product = server_version.product_name(),
                max_tds_version = %server_version.max_tds_version(),
                "PreLogin response received"
            );

            let server_max_tds = server_version.max_tds_version();
            if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
                tracing::warn!(
                    requested_tds_version = %client_tds_version,
                    server_max_tds_version = %server_max_tds,
                    server_product = server_version.product_name(),
                    "Server supports lower TDS version than requested. \
                     Connection will use server's maximum: {}",
                    server_max_tds
                );
            }

            if server_max_tds.is_legacy() {
                tracing::warn!(
                    server_product = server_version.product_name(),
                    server_max_tds_version = %server_max_tds,
                    "Server uses legacy TDS version. Some features may not be available."
                );
            }
        } else {
            tracing::debug!(
                requested_tds_version = %client_tds_version,
                "PreLogin response received (no version info)"
            );
        }

        let server_encryption = prelogin_response.encryption;
        tracing::debug!(encryption = ?server_encryption, "server encryption level");

        // FEDAUTH login data, echoing the server's `fed_auth_required` flag.
        // `validate_credential_support` rejects FEDAUTH credentials combined
        // with `no_tls`, so this should be `None` on the no-TLS path below.
        let fed_auth = fed_auth_token.map(|token| FedAuthLogin {
            token,
            echo: prelogin_response.fed_auth_required,
        });

        // (client, server) encryption → negotiated level:
        //   (NotSupported, NotSupported) → NotSupported (no TLS)
        //   (Off, Off)                   → Off (TLS for Login7 only)
        //   (On, Off | NotSupported)     → error
        //   anything else                → On (full TLS)
        // NOTE(review): the catch-all also maps (NotSupported, Off) and
        // (Off, NotSupported) to full TLS. The MS-TDS PRELOGIN encryption
        // table appears to describe both of those as "no encryption" —
        // confirm this is intended.
        let negotiated_encryption = match (client_encryption, server_encryption) {
            (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
                EncryptionLevel::NotSupported
            }
            (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
            (EncryptionLevel::On, EncryptionLevel::Off)
            | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
                return Err(Error::Protocol(
                    "Server does not support requested encryption level".to_string(),
                ));
            }
            _ => EncryptionLevel::On,
        };

        let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;

        if use_tls {
            let tls_config = connection_tls_config(config, false);

            let tls_connector = TlsConnector::new(tls_config)?;

            // TLS handshake via `connect_with_prelogin`, bounded by `tls_timeout`.
            // The outer `?` handles the timeout; the inner `?` handles the TLS error.
            let mut tls_stream = timeout(
                config.timeouts.tls_timeout,
                tls_connector.connect_with_prelogin(tcp_stream, &config.host),
            )
            .await
            .map_err(|_| Error::TlsTimeout {
                host: config.host.clone(),
                port: config.port,
            })??;

            tracing::debug!("TLS handshake completed (PreLogin wrapped)");

            // Negotiated `Off` means only Login7 is encrypted.
            let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;

            if login_only_encryption {
                // Duplicates the function-level `AsyncWriteExt` import.
                use tokio::io::AsyncWriteExt;

                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let negotiator = Self::create_negotiator(config)?;
                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let sspi_token = match negotiator {
                    Some(ref neg) => Some(neg.initialize()?),
                    None => None,
                };
                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
                let sspi_token: Option<Vec<u8>> = None;

                let login = Self::build_login7(config, sspi_token, fed_auth);
                let login_payload = login.encode();

                // Login7 is framed by hand into packets of at most
                // DEFAULT_PACKET_SIZE (not `config.packet_size`), presumably
                // because no packet size is agreed before login completes.
                // Every packet but the last is NORMAL; the last is END_OF_MESSAGE.
                let max_packet = DEFAULT_PACKET_SIZE;
                let max_payload = max_packet - PACKET_HEADER_SIZE;
                let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
                let total_chunks = chunks.len();

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let is_last = i == total_chunks - 1;
                    let status = if is_last {
                        PacketStatus::END_OF_MESSAGE
                    } else {
                        PacketStatus::NORMAL
                    };

                    let header = PacketHeader::new(
                        PacketType::Tds7Login,
                        status,
                        (PACKET_HEADER_SIZE + chunk.len()) as u16,
                    );

                    let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
                    header.encode(&mut packet_buf);
                    packet_buf.put_slice(chunk);

                    tls_stream
                        .write_all(&packet_buf)
                        .await
                        .map_err(Error::from)?;
                }

                // Flush so all encrypted Login7 bytes are written before the TLS
                // layer is discarded.
                tls_stream.flush().await.map_err(Error::from)?;

                tracing::debug!("Login7 sent through TLS, switching to plaintext for response");

                // Unwrap the TLS stream to get the raw TCP socket back. The TLS
                // session state (`_client_conn`) is dropped; `wrapper` is
                // presumably the PreLogin-framing adapter around the socket
                // (confirm in `mssql_tls`).
                let (wrapper, _client_conn) = tls_stream.into_inner();
                let tcp_stream = wrapper.into_inner();

                let mut connection = Connection::new(tcp_stream);
                connection.set_max_message_size(config.max_response_size);

                let (server_version, current_database, routing, server_collation) = timeout(
                    config.timeouts.login_timeout,
                    Self::process_login_response(
                        &mut connection,
                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                        negotiator.as_deref(),
                    ),
                )
                .await
                .map_err(|_| Error::LoginTimeout {
                    host: config.host.clone(),
                    port: config.port,
                })??;

                // Routing redirects are returned to the caller to follow.
                if let Some((host, port)) = routing {
                    return Err(Error::Routing { host, port });
                }

                Ok(Client {
                    config: config.clone(),
                    _state: PhantomData,
                    connection: Some(ConnectionHandle::Plain(connection)),
                    server_version,
                    current_database: current_database.clone(),
                    server_collation,
                    statement_cache: StatementCache::with_default_size(),
                    transaction_descriptor: 0,
                    needs_reset: false,
                    in_flight: false,
                    #[cfg(feature = "otel")]
                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                        .with_database(current_database.clone().unwrap_or_default()),
                    #[cfg(feature = "always-encrypted")]
                    encryption_context: config.column_encryption.clone().map(|cfg| {
                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                    }),
                })
            } else {
                // Full TLS: the connection keeps the TLS stream for all later traffic.
                let mut connection = Connection::new(tls_stream);
                connection.set_max_message_size(config.max_response_size);

                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let negotiator = Self::create_negotiator(config)?;
                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let sspi_token = match negotiator {
                    Some(ref neg) => Some(neg.initialize()?),
                    None => None,
                };
                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
                let sspi_token: Option<Vec<u8>> = None;

                let login = Self::build_login7(config, sspi_token, fed_auth);
                Self::send_login7(&mut connection, &login).await?;

                let (server_version, current_database, routing, server_collation) = timeout(
                    config.timeouts.login_timeout,
                    Self::process_login_response(
                        &mut connection,
                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                        negotiator.as_deref(),
                    ),
                )
                .await
                .map_err(|_| Error::LoginTimeout {
                    host: config.host.clone(),
                    port: config.port,
                })??;

                // Routing redirects are returned to the caller to follow.
                if let Some((host, port)) = routing {
                    return Err(Error::Routing { host, port });
                }

                Ok(Client {
                    config: config.clone(),
                    _state: PhantomData,
                    connection: Some(ConnectionHandle::TlsPrelogin(connection)),
                    server_version,
                    current_database: current_database.clone(),
                    server_collation,
                    statement_cache: StatementCache::with_default_size(),
                    transaction_descriptor: 0,
                    needs_reset: false,
                    in_flight: false,
                    #[cfg(feature = "otel")]
                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                        .with_database(current_database.clone().unwrap_or_default()),
                    #[cfg(feature = "always-encrypted")]
                    encryption_context: config.column_encryption.clone().map(|cfg| {
                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                    }),
                })
            }
        } else {
            // Negotiated `NotSupported` (only reachable when the client sent
            // `NotSupported`, i.e. `no_tls`): everything goes over plain TCP.
            tracing::warn!(
                "Connecting without TLS encryption. This is insecure and should only be \
                 used for development/testing on trusted networks."
            );

            let mut connection = Connection::new(tcp_stream);

            connection.set_max_message_size(config.max_response_size);

            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let negotiator = Self::create_negotiator(config)?;
            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let sspi_token = match negotiator {
                Some(ref neg) => Some(neg.initialize()?),
                None => None,
            };
            #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
            let sspi_token: Option<Vec<u8>> = None;

            let login = Self::build_login7(config, sspi_token, fed_auth);
            Self::send_login7(&mut connection, &login).await?;

            let (server_version, current_database, routing, server_collation) = timeout(
                config.timeouts.login_timeout,
                Self::process_login_response(
                    &mut connection,
                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                    negotiator.as_deref(),
                ),
            )
            .await
            .map_err(|_| Error::LoginTimeout {
                host: config.host.clone(),
                port: config.port,
            })??;

            // Routing redirects are returned to the caller to follow.
            if let Some((host, port)) = routing {
                return Err(Error::Routing { host, port });
            }

            Ok(Client {
                config: config.clone(),
                _state: PhantomData,
                connection: Some(ConnectionHandle::Plain(connection)),
                server_version,
                current_database: current_database.clone(),
                server_collation,
                statement_cache: StatementCache::with_default_size(),
                transaction_descriptor: 0,
                needs_reset: false,
                in_flight: false,
                #[cfg(feature = "otel")]
                instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                    .with_database(current_database.clone().unwrap_or_default()),
                #[cfg(feature = "always-encrypted")]
                encryption_context: config.column_encryption.clone().map(|cfg| {
                    std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                }),
            })
        }
    }
1030
1031 #[cfg(not(feature = "tls"))]
1042 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
1043 use bytes::BufMut;
1044 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
1045 use tokio::io::{AsyncReadExt, AsyncWriteExt};
1046
1047 tracing::warn!(
1048 "⚠️ Connecting without TLS (tls feature disabled). \
1049 Credentials and data will be transmitted in plaintext."
1050 );
1051
1052 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
1054 let prelogin_bytes = prelogin.encode();
1055
1056 let header = PacketHeader::new(
1058 PacketType::PreLogin,
1059 PacketStatus::END_OF_MESSAGE,
1060 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
1061 );
1062
1063 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
1064 header.encode(&mut packet_buf);
1065 packet_buf.put_slice(&prelogin_bytes);
1066
1067 tcp_stream
1068 .write_all(&packet_buf)
1069 .await
1070 .map_err(Error::from)?;
1071
1072 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
1074 tcp_stream
1075 .read_exact(&mut header_buf)
1076 .await
1077 .map_err(Error::from)?;
1078
1079 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
1080 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
1081
1082 let mut response_buf = vec![0u8; payload_length];
1083 tcp_stream
1084 .read_exact(&mut response_buf)
1085 .await
1086 .map_err(Error::from)?;
1087
1088 let prelogin_response = PreLogin::decode(&response_buf[..])?;
1089
1090 let server_encryption = prelogin_response.encryption;
1092 if server_encryption != EncryptionLevel::NotSupported {
1093 return Err(Error::Config(format!(
1094 "Server requires encryption (level: {server_encryption:?}) but TLS feature is disabled. \
1095 Either enable the 'tls' feature or configure the server to allow unencrypted connections."
1096 )));
1097 }
1098
1099 tracing::debug!("Server accepted unencrypted connection");
1100
1101 let mut connection = Connection::new(tcp_stream);
1102
1103 connection.set_max_message_size(config.max_response_size);
1104
1105 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1107 let negotiator = Self::create_negotiator(config)?;
1108 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1109 let sspi_token = match negotiator {
1110 Some(ref neg) => Some(neg.initialize()?),
1111 None => None,
1112 };
1113 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
1114 let sspi_token: Option<Vec<u8>> = None;
1115
1116 let login = Self::build_login7(config, sspi_token, None);
1119 Self::send_login7(&mut connection, &login).await?;
1120
1121 let (server_version, current_database, routing, server_collation) = timeout(
1123 config.timeouts.login_timeout,
1124 Self::process_login_response(
1125 &mut connection,
1126 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1127 negotiator.as_deref(),
1128 ),
1129 )
1130 .await
1131 .map_err(|_| Error::LoginTimeout {
1132 host: config.host.clone(),
1133 port: config.port,
1134 })??;
1135
1136 if let Some((host, port)) = routing {
1138 return Err(Error::Routing { host, port });
1139 }
1140
1141 Ok(Client {
1142 config: config.clone(),
1143 _state: PhantomData,
1144 connection: Some(ConnectionHandle::Plain(connection)),
1145 server_version,
1146 current_database: current_database.clone(),
1147 server_collation,
1148 statement_cache: StatementCache::with_default_size(),
1149 transaction_descriptor: 0,
1150 needs_reset: false,
1151 in_flight: false,
1152 #[cfg(feature = "otel")]
1153 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
1154 .with_database(current_database.clone().unwrap_or_default()),
1155 #[cfg(feature = "always-encrypted")]
1156 encryption_context: config.column_encryption.clone().map(|cfg| {
1157 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
1158 }),
1159 })
1160 }
1161
1162 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
1164 let version = if config.strict_mode {
1166 tds_protocol::version::TdsVersion::V8_0
1167 } else {
1168 config.tds_version
1169 };
1170
1171 let mut prelogin = PreLogin::new()
1172 .with_version(version)
1173 .with_encryption(encryption);
1174
1175 if config.mars {
1176 prelogin = prelogin.with_mars(true);
1177 }
1178
1179 if let Some(ref instance) = config.instance {
1180 prelogin = prelogin.with_instance(instance);
1181 }
1182
1183 if config.credentials.is_azure_ad() {
1186 prelogin = prelogin.with_fed_auth_required(true);
1187 }
1188
1189 prelogin
1190 }
1191
1192 fn resolve_workstation_id(config: &Config) -> String {
1200 if let Some(ref id) = config.workstation_id {
1201 return id.clone();
1202 }
1203 std::env::var("COMPUTERNAME")
1206 .or_else(|_| std::env::var("HOSTNAME"))
1207 .unwrap_or_default()
1208 }
1209
1210 fn build_login7(
1220 config: &Config,
1221 sspi_token: Option<Vec<u8>>,
1222 fed_auth: Option<FedAuthLogin<'_>>,
1223 ) -> Login7 {
1224 let version = if config.strict_mode {
1226 tds_protocol::version::TdsVersion::V8_0
1227 } else {
1228 config.tds_version
1229 };
1230
1231 let mut login = Login7::new()
1232 .with_tds_version(version)
1233 .with_packet_size(config.packet_size as u32)
1234 .with_app_name(&config.application_name)
1235 .with_server_name(&config.host)
1236 .with_hostname(Self::resolve_workstation_id(config));
1237
1238 if let Some(ref database) = config.database {
1239 login = login.with_database(database);
1240 }
1241
1242 if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
1244 login = login.with_read_only_intent(true);
1245 }
1246
1247 if let Some(ref lang) = config.language {
1249 login = login.with_language(lang);
1250 }
1251
1252 if let Some(token) = sspi_token {
1254 login = login.with_integrated_auth(token);
1256 } else if let Some(fed) = fed_auth {
1257 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1260 feature_id: tds_protocol::login7::FeatureId::FedAuth,
1261 data: mssql_auth::azure_ad::build_security_token_feature_data(fed.token, fed.echo),
1262 });
1263 tracing::debug!(
1264 fed_auth_echo = fed.echo,
1265 "Login7: adding FEDAUTH feature extension (SecurityToken workflow)"
1266 );
1267 } else if let mssql_auth::Credentials::SqlServer { username, password } =
1268 &config.credentials
1269 {
1270 login = login.with_sql_auth(username.as_ref(), password.as_ref());
1271 }
1272
1273 #[cfg(feature = "always-encrypted")]
1276 if config.column_encryption.is_some() {
1277 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1278 feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1279 data: bytes::Bytes::from_static(&[0x01]), });
1281 tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1282 }
1283
1284 login
1285 }
1286
1287 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1297 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1298 #[allow(clippy::match_like_matches_macro)]
1299 let is_integrated = match &config.credentials {
1300 mssql_auth::Credentials::Integrated => true,
1301 _ => false,
1302 };
1303
1304 if !is_integrated {
1305 return Ok(None);
1306 }
1307
1308 #[cfg(all(windows, feature = "sspi-auth"))]
1313 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1314 Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1315
1316 #[cfg(all(not(windows), feature = "sspi-auth"))]
1318 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1319 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1320
1321 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1322 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1323 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1324
1325 Ok(Some(negotiator))
1326 }
1327
1328 #[cfg(feature = "tls")]
1330 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1331 where
1332 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1333 {
1334 let payload = prelogin.encode();
1335 let max_packet = tds_protocol::packet::MAX_PACKET_SIZE;
1339
1340 connection
1341 .send_message(PacketType::PreLogin, payload, max_packet)
1342 .await?;
1343 Ok(())
1344 }
1345
1346 #[cfg(feature = "tls")]
1348 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1349 where
1350 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1351 {
1352 let message = connection
1353 .read_message()
1354 .await?
1355 .ok_or(Error::ConnectionClosed)?;
1356
1357 Ok(PreLogin::decode(&message.payload[..])?)
1358 }
1359
1360 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1362 where
1363 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1364 {
1365 let payload = login.encode();
1366 let max_packet = DEFAULT_PACKET_SIZE;
1371
1372 connection
1373 .send_message(PacketType::Tds7Login, payload, max_packet)
1374 .await?;
1375 Ok(())
1376 }
1377
1378 #[allow(clippy::never_loop)] async fn process_login_response<T>(
1389 connection: &mut Connection<T>,
1390 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
1391 &dyn mssql_auth::SspiNegotiator,
1392 >,
1393 ) -> Result<(
1394 Option<u32>,
1395 Option<String>,
1396 Option<(String, u16)>,
1397 Option<tds_protocol::token::Collation>,
1398 )>
1399 where
1400 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1401 {
1402 let mut server_version = None;
1403 let mut database = None;
1404 let mut routing = None;
1405 let mut collation = None;
1406
1407 'outer: loop {
1408 let message = connection
1409 .read_message()
1410 .await?
1411 .ok_or(Error::ConnectionClosed)?;
1412
1413 let response_bytes = message.payload;
1414 let mut parser = TokenParser::new(response_bytes);
1415
1416 while let Some(token) = parser.next_token()? {
1417 match token {
1418 Token::LoginAck(ack) => {
1419 tracing::info!(
1420 version = ack.tds_version,
1421 interface = ack.interface,
1422 prog_name = %ack.prog_name,
1423 "login acknowledged"
1424 );
1425 server_version = Some(ack.tds_version);
1426 }
1427 Token::EnvChange(env) => {
1428 Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
1429 }
1430 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1431 Token::Sspi(sspi_token) => {
1432 let neg = negotiator.ok_or_else(|| {
1433 Error::Protocol(
1434 "server sent SSPI challenge but no negotiator is configured"
1435 .to_string(),
1436 )
1437 })?;
1438
1439 tracing::debug!(
1440 challenge_len = sspi_token.data.len(),
1441 "received SSPI challenge from server"
1442 );
1443
1444 if let Some(response) = neg.step(&sspi_token.data)? {
1445 tracing::debug!(response_len = response.len(), "sending SSPI response");
1446 connection
1447 .send_message(
1448 PacketType::Sspi,
1449 bytes::Bytes::from(response),
1450 tds_protocol::packet::MAX_PACKET_SIZE,
1451 )
1452 .await?;
1453 }
1454
1455 continue 'outer;
1457 }
1458 Token::Error(err) => {
1459 return Err(Error::Server {
1460 number: err.number,
1461 state: err.state,
1462 class: err.class,
1463 message: err.message.clone(),
1464 server: if err.server.is_empty() {
1465 None
1466 } else {
1467 Some(err.server.clone())
1468 },
1469 procedure: if err.procedure.is_empty() {
1470 None
1471 } else {
1472 Some(err.procedure.clone())
1473 },
1474 line: err.line as u32,
1475 });
1476 }
1477 Token::Info(info) => {
1478 tracing::info!(
1479 number = info.number,
1480 message = %info.message,
1481 "server info message"
1482 );
1483 }
1484 Token::FeatureExtAck(ack) => {
1485 for feature in &ack.features {
1486 tracing::debug!(
1487 feature_id = feature.feature_id,
1488 data_len = feature.data.len(),
1489 "server acknowledged feature extension"
1490 );
1491 }
1492 }
1493 Token::Done(done) => {
1494 if done.status.error {
1495 return Err(Error::Protocol("login failed".to_string()));
1496 }
1497 break 'outer;
1498 }
1499 _ => {}
1500 }
1501 }
1502
1503 break;
1505 }
1506
1507 Ok((server_version, database, routing, collation))
1508 }
1509
1510 fn process_env_change(
1512 env: &EnvChange,
1513 database: &mut Option<String>,
1514 routing: &mut Option<(String, u16)>,
1515 collation: &mut Option<tds_protocol::token::Collation>,
1516 ) {
1517 use tds_protocol::token::EnvChangeValue;
1518
1519 match env.env_type {
1520 EnvChangeType::Database => {
1521 if let EnvChangeValue::String(ref new_value) = env.new_value {
1522 tracing::debug!(database = %new_value, "database changed");
1523 *database = Some(new_value.clone());
1524 }
1525 }
1526 EnvChangeType::Routing => {
1527 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1528 tracing::info!(host = %host, port = port, "routing redirect received");
1529 *routing = Some((host.clone(), port));
1530 }
1531 }
1532 EnvChangeType::SqlCollation => {
1533 if let EnvChangeValue::Binary(ref data) = env.new_value {
1534 if data.len() >= 5 {
1535 let c = tds_protocol::token::Collation::from_bytes(
1536 data[..5].try_into().unwrap(),
1537 );
1538 tracing::debug!(
1539 lcid = c.lcid,
1540 sort_id = c.sort_id,
1541 "server collation received"
1542 );
1543 *collation = Some(c);
1544 }
1545 }
1546 }
1547 _ => {
1548 if let EnvChangeValue::String(ref new_value) = env.new_value {
1549 tracing::debug!(
1550 env_type = ?env.env_type,
1551 new_value = %new_value,
1552 "environment change"
1553 );
1554 }
1555 }
1556 }
1557 }
1558}
1559
/// Derives the TLS settings for a connection from `config`.
///
/// Starts from `config.tls` and overrides its `trust_server_certificate` flag with the
/// top-level `config.trust_server_certificate`. When `strict` is set, strict mode is also
/// enabled and `tds/8.0` is offered via ALPN.
#[cfg(feature = "tls")]
fn connection_tls_config(config: &Config, strict: bool) -> TlsConfig {
    let base = config
        .tls
        .clone()
        .trust_server_certificate(config.trust_server_certificate);
    if !strict {
        return base;
    }
    base.strict_mode(true)
        .with_alpn_protocols(vec![b"tds/8.0".to_vec()])
}
1593
#[cfg(all(test, feature = "tls"))]
mod tls_config_tests {
    use super::*;
    use mssql_tls::CertificateDer;

    /// Builds a default `Config` whose TLS settings carry one extra trusted root `cert`.
    fn config_with_root(cert: Vec<u8>) -> Config {
        let mut config = Config::new();
        let tls = config
            .tls
            .clone()
            .add_root_certificate(CertificateDer::from(cert));
        config.tls = tls;
        config
    }

    /// A custom root added to `config.tls` must survive into the connector config in both
    /// strict and non-strict modes.
    #[test]
    fn custom_root_certificate_reaches_connector_config() {
        let config = config_with_root(vec![0xCA; 32]);

        for strict in [true, false] {
            let derived = connection_tls_config(&config, strict);
            assert_eq!(
                derived.root_certificates.len(),
                1,
                "custom root must reach the connector (strict={strict})"
            );
            assert_eq!(derived.root_certificates[0].as_ref(), &[0xCA; 32][..]);
        }
    }

    /// The top-level `trust_server_certificate` flag overrides the nested TLS one.
    #[test]
    fn trust_server_certificate_taken_from_top_level_field() {
        let mut config = Config::new();
        config.trust_server_certificate = true;
        // Precondition: only the top-level flag is set.
        assert!(!config.tls.trust_server_certificate);

        let derived = connection_tls_config(&config, false);
        assert!(
            derived.trust_server_certificate,
            "top-level trust flag must win"
        );
    }

    /// Strict mode turns on `strict_mode` and offers the `tds/8.0` ALPN protocol.
    #[test]
    fn strict_mode_adds_tds8_alpn() {
        let config = Config::new();

        let strict = connection_tls_config(&config, true);
        assert!(strict.strict_mode);
        assert!(strict.alpn_protocols.iter().any(|p| p == b"tds/8.0"));

        assert!(!connection_tls_config(&config, false).strict_mode);
    }
}
1653
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod fed_auth_login_tests {
    use super::*;
    use tds_protocol::prelogin::EncryptionLevel;

    /// Byte offset of the `ibExtension` field in the LOGIN7 fixed header. The offset/length
    /// table starts at byte 36, and Extension follows the HostName, UserName, Password,
    /// AppName and ServerName entries (4 bytes each).
    const EXTENSION_SLOT: usize = 36 + 5 * 4;

    /// Builds a `Config` whose credentials are `Credentials::azure_token(token)`.
    fn azure_config(token: &str) -> Config {
        Config::new().credentials(mssql_auth::Credentials::azure_token(token.to_string()))
    }

    /// Locates the first FeatureExt entry of an encoded LOGIN7 payload. It follows
    /// `ibExtension` to the little-endian DWORD offset stored there.
    fn feature_ext_offset(encoded: &[u8]) -> usize {
        let ib_extension =
            u16::from_le_bytes([encoded[EXTENSION_SLOT], encoded[EXTENSION_SLOT + 1]]) as usize;
        u32::from_le_bytes(encoded[ib_extension..ib_extension + 4].try_into().unwrap()) as usize
    }

    /// Pins the exact FEDAUTH feature bytes for a two-character token with echo enabled.
    #[test]
    fn login7_fed_auth_feature_block_wire_exact() {
        let config = azure_config("AB");
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: "AB",
                echo: true,
            }),
        );

        assert!(
            !login.option_flags2.integrated_security,
            "fIntSecurity MUST be 0 when FEDAUTH is present"
        );
        assert!(
            login.username.is_empty() && login.password.is_empty(),
            "FEDAUTH logins must not carry username/password"
        );

        let encoded = login.encode();

        // OptionFlags3 (byte 27), bit 4 = fExtension.
        assert_eq!(encoded[27] & 0x10, 0x10, "fExtension bit must be set");

        let feature_off = feature_ext_offset(&encoded);
        assert_eq!(
            encoded[feature_off], 0x02,
            "FeatureId must be FEDAUTH (0x02)"
        );
        let data_len = u32::from_le_bytes(
            encoded[feature_off + 1..feature_off + 5]
                .try_into()
                .unwrap(),
        ) as usize;
        // 1 options byte + 4-byte token length + "AB" as UTF-16LE (4 bytes).
        assert_eq!(data_len, 9, "FeatureDataLen must cover options + token");

        let data = &encoded[feature_off + 5..feature_off + 5 + data_len];
        assert_eq!(
            data,
            &[0x03, 0x04, 0x00, 0x00, 0x00, 0x41, 0x00, 0x42, 0x00],
            "options must be (SecurityToken << 1) | echo, then DWORD-LE \
             token byte length, then UTF-16LE token"
        );
        assert_eq!(
            encoded[feature_off + 5 + data_len],
            0xFF,
            "FeatureExt terminator must follow"
        );
    }

    /// With `echo: false` the options byte carries only the SecurityToken workflow bits.
    #[test]
    fn login7_fed_auth_echo_clear() {
        let config = azure_config("AB");
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: "AB",
                echo: false,
            }),
        );
        let encoded = login.encode();

        let feature_off = feature_ext_offset(&encoded);
        assert_eq!(encoded[feature_off], 0x02);
        assert_eq!(
            encoded[feature_off + 5],
            0x02,
            "options byte must have fFedAuthEcho clear"
        );
    }

    /// PRELOGIN requests FEDAUTHREQUIRED only for Azure AD credentials.
    #[test]
    fn prelogin_advertises_fed_auth_for_azure_credentials() {
        let azure = azure_config("tok");
        let prelogin = Client::<Disconnected>::build_prelogin(&azure, EncryptionLevel::On);
        assert!(prelogin.fed_auth_required);

        let sql = Config::new().credentials(mssql_auth::Credentials::sql_server("u", "p"));
        let prelogin = Client::<Disconnected>::build_prelogin(&sql, EncryptionLevel::On);
        assert!(!prelogin.fed_auth_required);
    }

    /// An oversized LOGIN7 must be split into packets no larger than the default size,
    /// with END_OF_MESSAGE only on the last one and payloads that reassemble exactly.
    #[tokio::test]
    async fn login7_large_fed_auth_token_is_split_at_default_packet_size() {
        use tds_protocol::packet::PACKET_HEADER_SIZE;
        use tokio::io::AsyncReadExt;

        // 2000 characters are 4000 bytes in UTF-16LE; the precondition below checks that
        // the whole LOGIN7 ends up larger than one default packet.
        let token = "A".repeat(2000);
        let config = azure_config(&token);
        let login = Client::<Disconnected>::build_login7(
            &config,
            None,
            Some(FedAuthLogin {
                token: &token,
                echo: true,
            }),
        );
        let encoded = login.encode();
        assert!(
            encoded.len() > DEFAULT_PACKET_SIZE,
            "precondition: LOGIN7 ({}) must exceed the default packet size to exercise splitting",
            encoded.len()
        );

        // Nothing reads the server half until `send_login7` returns, so the duplex buffer
        // must be large enough to hold the entire message.
        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
        let mut connection = Connection::new(client_io);
        Client::<Disconnected>::send_login7(&mut connection, &login)
            .await
            .unwrap();
        // Dropping the connection closes the client half, so `read_to_end` reaches EOF.
        drop(connection);

        let mut raw = Vec::new();
        server_io.read_to_end(&mut raw).await.unwrap();

        // Walk the packet headers (byte 1 = status, bytes 2..4 = big-endian length),
        // checking each packet and reassembling the payloads.
        let mut offset = 0;
        let mut packets = 0;
        let mut reassembled = Vec::new();
        let mut saw_eom = false;
        while offset < raw.len() {
            assert!(
                raw.len() - offset >= PACKET_HEADER_SIZE,
                "packet {packets} has a truncated header at offset {offset}"
            );
            let status = raw[offset + 1];
            let len = u16::from_be_bytes([raw[offset + 2], raw[offset + 3]]) as usize;
            // A zero or short length would make this loop spin forever, and an overlong one
            // would panic with an opaque slice error; fail with a clear message instead.
            assert!(
                (PACKET_HEADER_SIZE..=raw.len() - offset).contains(&len),
                "packet {packets} declares invalid length {len} at offset {offset}"
            );
            assert!(
                len <= DEFAULT_PACKET_SIZE,
                "packet {packets} length {len} exceeds the 4096-byte default"
            );
            assert!(!saw_eom, "no packet may follow the END_OF_MESSAGE packet");
            saw_eom = status & 0x01 == 0x01;
            reassembled.extend_from_slice(&raw[offset + PACKET_HEADER_SIZE..offset + len]);
            offset += len;
            packets += 1;
        }
        assert!(
            packets >= 2,
            "an oversized LOGIN7 must span multiple packets, got {packets}"
        );
        assert!(saw_eom, "the final packet must carry END_OF_MESSAGE");
        assert_eq!(
            reassembled,
            encoded.as_ref(),
            "reassembled packet payloads must equal the LOGIN7 encoding"
        );
    }

    /// Certificate-based credentials pointing at a file that does not exist.
    #[cfg(feature = "cert-auth")]
    fn cert_config() -> Config {
        Config::new().credentials(mssql_auth::Credentials::certificate(
            "tenant-1",
            "client-1",
            "/nonexistent/app.pfx",
            None,
        ))
    }

    #[cfg(feature = "cert-auth")]
    #[test]
    fn cert_auth_is_accepted_by_credential_validation_over_tls() {
        let config = cert_config();
        assert!(Client::<Disconnected>::validate_credential_support(&config).is_ok());
    }

    #[cfg(all(feature = "cert-auth", feature = "tls"))]
    #[test]
    fn cert_auth_is_rejected_over_plaintext() {
        let config = cert_config().no_tls(true);
        let err = Client::<Disconnected>::validate_credential_support(&config)
            .expect_err("certificate FEDAUTH over no_tls must be rejected");
        assert!(
            err.to_string().contains("no_tls"),
            "error should explain the plaintext rejection, got: {err}"
        );
    }

    #[cfg(feature = "cert-auth")]
    #[tokio::test]
    async fn cert_auth_token_resolution_reports_missing_cert_file() {
        let config = cert_config();
        let err = Client::<Disconnected>::resolve_fed_auth_token(&config)
            .await
            .expect_err("a missing certificate file must error");
        let msg = err.to_string();
        assert!(
            msg.contains("failed to read certificate") && msg.contains("app.pfx"),
            "error should name the unreadable certificate file, got: {msg}"
        );
    }

    #[cfg(all(feature = "azure-identity", feature = "tls"))]
    #[test]
    fn azure_default_credential_validation() {
        let ok = Config::new().credentials(mssql_auth::Credentials::azure_default());
        assert!(Client::<Disconnected>::validate_credential_support(&ok).is_ok());

        let bad = Config::new()
            .credentials(mssql_auth::Credentials::azure_default())
            .no_tls(true);
        let err = Client::<Disconnected>::validate_credential_support(&bad)
            .expect_err("default-chain FEDAUTH over no_tls must be rejected");
        assert!(
            err.to_string().contains("no_tls"),
            "error should explain the plaintext rejection, got: {err}"
        );
    }
}