1use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::MAX_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
30impl Client<Disconnected> {
31 pub async fn connect(config: Config) -> Result<Client<Ready>> {
47 match &config.credentials {
54 mssql_auth::Credentials::SqlServer { .. } => {}
55 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
56 mssql_auth::Credentials::Integrated => {}
57 _ => {
61 return Err(Error::Config(
62 "Azure AD / Entra and client certificate (FEDAUTH) authentication \
63 are not yet supported: the LOGIN7 FEDAUTH feature extension is not \
64 implemented (tracked in \
65 https://github.com/praxiomlabs/rust-mssql-driver/issues/155). \
66 Use SQL Server or integrated authentication."
67 .into(),
68 ));
69 }
70 }
71
72 let retry = config.retry.clone();
73 let max_redirects = config.redirect.max_redirects;
74 let follow_redirects = config.redirect.follow_redirects;
75 let per_attempt = config.timeouts.connect_timeout
77 + config.timeouts.tls_timeout
78 + config.timeouts.login_timeout;
79 let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
80 let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
81 let initial_host = config.host.clone();
82 let initial_port = config.port;
83
84 let result = timeout(overall, async {
85 let mut last_error: Option<Error> = None;
86
87 for retry_attempt in 0..=retry.max_retries {
88 if retry_attempt > 0 {
89 let backoff = retry.backoff_for_attempt(retry_attempt);
90 tracing::info!(
91 retry_attempt,
92 backoff_ms = backoff.as_millis() as u64,
93 "retrying connection after transient error"
94 );
95 tokio::time::sleep(backoff).await;
96 }
97
98 let mut current_config = config.clone();
100 let mut redirect_count: u8 = 0;
101
102 let attempt_result = loop {
103 redirect_count += 1;
104 if redirect_count > max_redirects + 1 {
105 break Err(Error::TooManyRedirects { max: max_redirects });
106 }
107
108 match Self::try_connect(¤t_config).await {
109 Ok(client) => break Ok(client),
110 Err(Error::Routing { host, port }) => {
111 if !follow_redirects {
112 break Err(Error::Routing { host, port });
113 }
114 tracing::info!(
115 host = %host,
116 port = port,
117 redirect = redirect_count,
118 max_redirects = max_redirects,
119 "following Azure SQL routing redirect"
120 );
121 current_config = current_config.with_host(&host).with_port(port);
122 continue;
123 }
124 Err(e) => break Err(e),
125 }
126 };
127
128 match attempt_result {
129 Ok(client) => return Ok(client),
130 Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
131 tracing::warn!(
132 retry_attempt,
133 max_retries = retry.max_retries,
134 error = %e,
135 "transient connection error, will retry"
136 );
137 last_error = Some(attempt_result.unwrap_err());
138 }
139 Err(e) => return Err(e),
140 }
141 }
142
143 Err(last_error.expect("at least one attempt was made"))
145 })
146 .await;
147
148 match result {
149 Ok(inner) => inner,
150 Err(_elapsed) => Err(Error::ConnectTimeout {
151 host: initial_host,
152 port: initial_port,
153 }),
154 }
155 }
156
157 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
158 let port = if let Some(ref instance) = config.instance {
160 let resolved = crate::browser::resolve_instance(
161 &config.host,
162 instance,
163 Some(config.timeouts.connect_timeout),
164 )
165 .await?;
166 tracing::info!(
167 host = %config.host,
168 instance = %instance,
169 resolved_port = resolved,
170 database = ?config.database,
171 "connecting to named SQL Server instance"
172 );
173 resolved
174 } else {
175 tracing::info!(
176 host = %config.host,
177 port = config.port,
178 database = ?config.database,
179 "connecting to SQL Server"
180 );
181 config.port
182 };
183
184 let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
187 "127.0.0.1"
188 } else {
189 &config.host
190 };
191
192 let tcp_stream = if config.multi_subnet_failover {
194 Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
195 } else {
196 let addr = format!("{host}:{port}");
197 tracing::debug!("establishing TCP connection to {}", addr);
198 let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
199 .await
200 .map_err(|_| Error::ConnectTimeout {
201 host: config.host.clone(),
202 port: config.port,
203 })?
204 .map_err(Error::from)?;
205 stream.set_nodelay(true).map_err(Error::from)?;
206 stream
207 };
208
209 #[cfg(feature = "tls")]
210 {
211 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
213
214 if tls_mode.is_tls_first() {
216 return Self::connect_tds_8(config, tcp_stream).await;
217 }
218
219 Self::connect_tds_7x(config, tcp_stream).await
221 }
222
223 #[cfg(not(feature = "tls"))]
224 {
225 if config.strict_mode {
227 return Err(Error::Config(
228 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
229 ));
230 }
231
232 if !config.no_tls {
233 return Err(Error::Config(
234 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
235 or use Encrypt=no_tls in your connection string for unencrypted connections."
236 .into(),
237 ));
238 }
239
240 Self::connect_no_tls(config, tcp_stream).await
242 }
243 }
244
245 async fn connect_parallel(
250 host: &str,
251 port: u16,
252 connect_timeout: std::time::Duration,
253 ) -> Result<TcpStream> {
254 let addr_str = format!("{host}:{port}");
255 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
256 .await
257 .map_err(Error::from)?
258 .collect();
259
260 if addrs.is_empty() {
261 return Err(Error::from(std::io::Error::new(
262 std::io::ErrorKind::AddrNotAvailable,
263 format!("no addresses resolved for {host}:{port}"),
264 )));
265 }
266
267 if addrs.len() == 1 {
269 tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
270 let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
271 .await
272 .map_err(|_| Error::ConnectTimeout {
273 host: host.to_string(),
274 port,
275 })?
276 .map_err(Error::from)?;
277 stream.set_nodelay(true).map_err(Error::from)?;
278 return Ok(stream);
279 }
280
281 let addr_count = addrs.len();
282 tracing::debug!(
283 host = host,
284 port = port,
285 resolved_count = addr_count,
286 "MultiSubnetFailover: racing parallel connections",
287 );
288
289 let mut join_set = tokio::task::JoinSet::new();
290
291 for addr in addrs {
292 let dur = connect_timeout;
293 join_set.spawn(async move {
294 let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
295 std::io::Error::new(
296 std::io::ErrorKind::TimedOut,
297 format!("connection to {addr} timed out"),
298 )
299 })??;
300 tcp.set_nodelay(true)?;
301 Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
302 });
303 }
304
305 let mut last_error: Option<std::io::Error> = None;
306
307 while let Some(result) = join_set.join_next().await {
308 match result {
309 Ok(Ok((stream, addr))) => {
310 tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
311 join_set.abort_all();
312 return Ok(stream);
313 }
314 Ok(Err(e)) => {
315 tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
316 last_error = Some(e);
317 }
318 Err(join_err) => {
319 tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
320 last_error = Some(std::io::Error::other(join_err.to_string()));
321 }
322 }
323 }
324
325 Err(Error::from(last_error.unwrap_or_else(|| {
327 std::io::Error::new(
328 std::io::ErrorKind::ConnectionRefused,
329 format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
330 )
331 })))
332 }
333
334 #[cfg(feature = "tls")]
338 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
339 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
340
341 let tls_config = connection_tls_config(config, true);
344
345 let tls_connector = TlsConnector::new(tls_config)?;
346
347 let tls_stream = timeout(
349 config.timeouts.tls_timeout,
350 tls_connector.connect(tcp_stream, &config.host),
351 )
352 .await
353 .map_err(|_| Error::TlsTimeout {
354 host: config.host.clone(),
355 port: config.port,
356 })??;
357
358 tracing::debug!("TLS handshake completed (strict mode)");
359
360 let mut connection = Connection::new(tls_stream);
362
363 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
365 Self::send_prelogin(&mut connection, &prelogin).await?;
366 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
367
368 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
370 let negotiator = Self::create_negotiator(config)?;
371 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
372 let sspi_token = match negotiator {
373 Some(ref neg) => Some(neg.initialize()?),
374 None => None,
375 };
376 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
377 let sspi_token: Option<Vec<u8>> = None;
378
379 let login = Self::build_login7(config, sspi_token);
381 Self::send_login7(&mut connection, &login).await?;
382
383 let (server_version, current_database, routing, server_collation) = timeout(
385 config.timeouts.login_timeout,
386 Self::process_login_response(
387 &mut connection,
388 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
389 negotiator.as_deref(),
390 ),
391 )
392 .await
393 .map_err(|_| Error::LoginTimeout {
394 host: config.host.clone(),
395 port: config.port,
396 })??;
397
398 if let Some((host, port)) = routing {
400 return Err(Error::Routing { host, port });
401 }
402
403 Ok(Client {
404 config: config.clone(),
405 _state: PhantomData,
406 connection: Some(ConnectionHandle::Tls(connection)),
407 server_version,
408 current_database: current_database.clone(),
409 server_collation,
410 statement_cache: StatementCache::with_default_size(),
411 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
415 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
416 .with_database(current_database.clone().unwrap_or_default()),
417 #[cfg(feature = "always-encrypted")]
418 encryption_context: config.column_encryption.clone().map(|cfg| {
419 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
420 }),
421 })
422 }
423
    /// TDS 7.x connection flow (PreLogin first), used when strict mode is off.
    ///
    /// Steps, as implemented below:
    /// 1. Send a hand-framed PreLogin packet in plaintext. It advertises `NotSupported`
    ///    when `config.no_tls` is set, `On` when `config.encrypt` is set, and `Off` otherwise.
    /// 2. Read one PreLogin response packet and log the server's version information.
    /// 3. Combine the client and server encryption levels:
    ///    - both `NotSupported`: no TLS; the raw socket is used (`ConnectionHandle::Plain`);
    ///    - both `Off`: TLS protects only LOGIN7, after which the TLS layers are unwrapped
    ///      and the session continues in plaintext (`ConnectionHandle::Plain`);
    ///    - client `On` with server `Off`/`NotSupported`: `Error::Protocol`;
    ///    - anything else: TLS for the whole session (`ConnectionHandle::TlsPrelogin`).
    /// 4. Send LOGIN7 (with an SSPI token when integrated auth is configured) and process
    ///    the login response under `config.timeouts.login_timeout`.
    ///
    /// # Errors
    /// I/O and PreLogin decode errors, the encryption mismatch above, `Error::TlsTimeout`,
    /// `Error::LoginTimeout`, `Error::Routing` when the server redirects the client, and
    /// anything `process_login_response` reports.
    #[cfg(feature = "tls")]
    async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
        use bytes::BufMut;
        use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        tracing::debug!("using TDS 7.x flow (PreLogin first)");

        // Encryption level the client advertises in its PreLogin.
        let client_encryption = if config.no_tls {
            tracing::warn!(
                "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
                 Credentials and data will be transmitted in plaintext. \
                 This should only be used for development/testing with legacy SQL Server."
            );
            EncryptionLevel::NotSupported
        } else if config.encrypt {
            EncryptionLevel::On
        } else {
            EncryptionLevel::Off
        };
        let prelogin = Self::build_prelogin(config, client_encryption);
        tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
        let prelogin_bytes = prelogin.encode();

        // The PreLogin is framed by hand as a single END_OF_MESSAGE packet written
        // straight to the socket (no `Connection` wrapper exists yet).
        let header = PacketHeader::new(
            PacketType::PreLogin,
            PacketStatus::END_OF_MESSAGE,
            (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
        );

        let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
        header.encode(&mut packet_buf);
        packet_buf.put_slice(&prelogin_bytes);

        tcp_stream
            .write_all(&packet_buf)
            .await
            .map_err(Error::from)?;

        // Read exactly one response packet: the header, then the payload length taken
        // from header bytes 2..4 (big-endian total length including the header).
        // NOTE(review): a PreLogin response split across several packets (status without
        // END_OF_MESSAGE) would only be partially read here — confirm this cannot happen.
        let mut header_buf = [0u8; PACKET_HEADER_SIZE];
        tcp_stream
            .read_exact(&mut header_buf)
            .await
            .map_err(Error::from)?;

        let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
        let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);

        let mut response_buf = vec![0u8; payload_length];
        tcp_stream
            .read_exact(&mut response_buf)
            .await
            .map_err(Error::from)?;

        let prelogin_response = PreLogin::decode(&response_buf[..])?;

        // Version information is only logged here; it does not change the flow below.
        let client_tds_version = config.tds_version;
        if let Some(ref server_version) = prelogin_response.server_version {
            tracing::debug!(
                requested_tds_version = %client_tds_version,
                server_product_version = %server_version,
                server_product = server_version.product_name(),
                max_tds_version = %server_version.max_tds_version(),
                "PreLogin response received"
            );

            let server_max_tds = server_version.max_tds_version();
            if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
                tracing::warn!(
                    requested_tds_version = %client_tds_version,
                    server_max_tds_version = %server_max_tds,
                    server_product = server_version.product_name(),
                    "Server supports lower TDS version than requested. \
                     Connection will use server's maximum: {}",
                    server_max_tds
                );
            }

            if server_max_tds.is_legacy() {
                tracing::warn!(
                    server_product = server_version.product_name(),
                    server_max_tds_version = %server_max_tds,
                    "Server uses legacy TDS version. Some features may not be available."
                );
            }
        } else {
            tracing::debug!(
                requested_tds_version = %client_tds_version,
                "PreLogin response received (no version info)"
            );
        }

        let server_encryption = prelogin_response.encryption;
        tracing::debug!(encryption = ?server_encryption, "server encryption level");

        // Resulting levels: NotSupported => no TLS at all, Off => TLS for LOGIN7 only,
        // On => TLS for the whole session.
        // NOTE(review): the `_` arm also covers a client that sent `NotSupported`
        // (`no_tls`) while the server answered `Off`/`On`/`Required`. In that case TLS is
        // used despite `no_tls`. Confirm this against the MS-TDS PRELOGIN encryption table.
        let negotiated_encryption = match (client_encryption, server_encryption) {
            (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
                EncryptionLevel::NotSupported
            }
            (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
            (EncryptionLevel::On, EncryptionLevel::Off)
            | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
                return Err(Error::Protocol(
                    "Server does not support requested encryption level".to_string(),
                ));
            }
            _ => EncryptionLevel::On,
        };

        let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;

        if use_tls {
            let tls_config = connection_tls_config(config, false);

            let tls_connector = TlsConnector::new(tls_config)?;

            // Presumably performs the TLS handshake wrapped in TDS PreLogin packets, as
            // the 7.x flow requires — see `mssql_tls::TlsConnector::connect_with_prelogin`.
            let mut tls_stream = timeout(
                config.timeouts.tls_timeout,
                tls_connector.connect_with_prelogin(tcp_stream, &config.host),
            )
            .await
            .map_err(|_| Error::TlsTimeout {
                host: config.host.clone(),
                port: config.port,
            })??;

            tracing::debug!("TLS handshake completed (PreLogin wrapped)");

            // Both sides said `Off`: only the LOGIN7 packet(s) are encrypted.
            let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;

            if login_only_encryption {
                // NOTE(review): redundant with the function-level import above.
                use tokio::io::AsyncWriteExt;

                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let negotiator = Self::create_negotiator(config)?;
                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let sspi_token = match negotiator {
                    Some(ref neg) => Some(neg.initialize()?),
                    None => None,
                };
                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
                let sspi_token: Option<Vec<u8>> = None;

                let login = Self::build_login7(config, sspi_token);
                let login_payload = login.encode();

                // LOGIN7 is framed by hand (split into MAX_PACKET_SIZE packets) and written
                // directly to the TLS stream, because the stream is unwrapped right after.
                let max_packet = MAX_PACKET_SIZE;
                let max_payload = max_packet - PACKET_HEADER_SIZE;
                let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
                let total_chunks = chunks.len();

                for (i, chunk) in chunks.into_iter().enumerate() {
                    let is_last = i == total_chunks - 1;
                    let status = if is_last {
                        PacketStatus::END_OF_MESSAGE
                    } else {
                        PacketStatus::NORMAL
                    };

                    let header = PacketHeader::new(
                        PacketType::Tds7Login,
                        status,
                        (PACKET_HEADER_SIZE + chunk.len()) as u16,
                    );

                    let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
                    header.encode(&mut packet_buf);
                    packet_buf.put_slice(chunk);

                    tls_stream
                        .write_all(&packet_buf)
                        .await
                        .map_err(Error::from)?;
                }

                tls_stream.flush().await.map_err(Error::from)?;

                tracing::debug!("Login7 sent through TLS, switching to plaintext for response");

                // Peel off the TLS layers to get the raw TcpStream back; everything after
                // LOGIN7 is read and written in plaintext.
                let (wrapper, _client_conn) = tls_stream.into_inner();
                let tcp_stream = wrapper.into_inner();

                let mut connection = Connection::new(tcp_stream);

                let (server_version, current_database, routing, server_collation) = timeout(
                    config.timeouts.login_timeout,
                    Self::process_login_response(
                        &mut connection,
                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                        negotiator.as_deref(),
                    ),
                )
                .await
                .map_err(|_| Error::LoginTimeout {
                    host: config.host.clone(),
                    port: config.port,
                })??;

                if let Some((host, port)) = routing {
                    return Err(Error::Routing { host, port });
                }

                Ok(Client {
                    config: config.clone(),
                    _state: PhantomData,
                    connection: Some(ConnectionHandle::Plain(connection)),
                    server_version,
                    current_database: current_database.clone(),
                    server_collation,
                    statement_cache: StatementCache::with_default_size(),
                    transaction_descriptor: 0,
                    needs_reset: false,
                    in_flight: false,
                    #[cfg(feature = "otel")]
                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                        .with_database(current_database.clone().unwrap_or_default()),
                    #[cfg(feature = "always-encrypted")]
                    encryption_context: config.column_encryption.clone().map(|cfg| {
                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                    }),
                })
            } else {
                // Full encryption: the whole session stays on the TLS stream.
                let mut connection = Connection::new(tls_stream);

                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let negotiator = Self::create_negotiator(config)?;
                #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                let sspi_token = match negotiator {
                    Some(ref neg) => Some(neg.initialize()?),
                    None => None,
                };
                #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
                let sspi_token: Option<Vec<u8>> = None;

                let login = Self::build_login7(config, sspi_token);
                Self::send_login7(&mut connection, &login).await?;

                let (server_version, current_database, routing, server_collation) = timeout(
                    config.timeouts.login_timeout,
                    Self::process_login_response(
                        &mut connection,
                        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                        negotiator.as_deref(),
                    ),
                )
                .await
                .map_err(|_| Error::LoginTimeout {
                    host: config.host.clone(),
                    port: config.port,
                })??;

                if let Some((host, port)) = routing {
                    return Err(Error::Routing { host, port });
                }

                Ok(Client {
                    config: config.clone(),
                    _state: PhantomData,
                    connection: Some(ConnectionHandle::TlsPrelogin(connection)),
                    server_version,
                    current_database: current_database.clone(),
                    server_collation,
                    statement_cache: StatementCache::with_default_size(),
                    transaction_descriptor: 0,
                    needs_reset: false,
                    in_flight: false,
                    #[cfg(feature = "otel")]
                    instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                        .with_database(current_database.clone().unwrap_or_default()),
                    #[cfg(feature = "always-encrypted")]
                    encryption_context: config.column_encryption.clone().map(|cfg| {
                        std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                    }),
                })
            }
        } else {
            // Both sides reported `NotSupported`: the whole session runs in plaintext.
            tracing::warn!(
                "Connecting without TLS encryption. This is insecure and should only be \
                 used for development/testing on trusted networks."
            );

            let mut connection = Connection::new(tcp_stream);

            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let negotiator = Self::create_negotiator(config)?;
            #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
            let sspi_token = match negotiator {
                Some(ref neg) => Some(neg.initialize()?),
                None => None,
            };
            #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
            let sspi_token: Option<Vec<u8>> = None;

            let login = Self::build_login7(config, sspi_token);
            Self::send_login7(&mut connection, &login).await?;

            let (server_version, current_database, routing, server_collation) = timeout(
                config.timeouts.login_timeout,
                Self::process_login_response(
                    &mut connection,
                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                    negotiator.as_deref(),
                ),
            )
            .await
            .map_err(|_| Error::LoginTimeout {
                host: config.host.clone(),
                port: config.port,
            })??;

            if let Some((host, port)) = routing {
                return Err(Error::Routing { host, port });
            }

            Ok(Client {
                config: config.clone(),
                _state: PhantomData,
                connection: Some(ConnectionHandle::Plain(connection)),
                server_version,
                current_database: current_database.clone(),
                server_collation,
                statement_cache: StatementCache::with_default_size(),
                transaction_descriptor: 0,
                needs_reset: false,
                in_flight: false,
                #[cfg(feature = "otel")]
                instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                    .with_database(current_database.clone().unwrap_or_default()),
                #[cfg(feature = "always-encrypted")]
                encryption_context: config.column_encryption.clone().map(|cfg| {
                    std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
                }),
            })
        }
    }
823
824 #[cfg(not(feature = "tls"))]
835 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
836 use bytes::BufMut;
837 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
838 use tokio::io::{AsyncReadExt, AsyncWriteExt};
839
840 tracing::warn!(
841 "⚠️ Connecting without TLS (tls feature disabled). \
842 Credentials and data will be transmitted in plaintext."
843 );
844
845 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
847 let prelogin_bytes = prelogin.encode();
848
849 let header = PacketHeader::new(
851 PacketType::PreLogin,
852 PacketStatus::END_OF_MESSAGE,
853 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
854 );
855
856 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
857 header.encode(&mut packet_buf);
858 packet_buf.put_slice(&prelogin_bytes);
859
860 tcp_stream
861 .write_all(&packet_buf)
862 .await
863 .map_err(Error::from)?;
864
865 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
867 tcp_stream
868 .read_exact(&mut header_buf)
869 .await
870 .map_err(Error::from)?;
871
872 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
873 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
874
875 let mut response_buf = vec![0u8; payload_length];
876 tcp_stream
877 .read_exact(&mut response_buf)
878 .await
879 .map_err(Error::from)?;
880
881 let prelogin_response = PreLogin::decode(&response_buf[..])?;
882
883 let server_encryption = prelogin_response.encryption;
885 if server_encryption != EncryptionLevel::NotSupported {
886 return Err(Error::Config(format!(
887 "Server requires encryption (level: {:?}) but TLS feature is disabled. \
888 Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
889 server_encryption
890 )));
891 }
892
893 tracing::debug!("Server accepted unencrypted connection");
894
895 let mut connection = Connection::new(tcp_stream);
896
897 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
899 let negotiator = Self::create_negotiator(config)?;
900 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
901 let sspi_token = match negotiator {
902 Some(ref neg) => Some(neg.initialize()?),
903 None => None,
904 };
905 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
906 let sspi_token: Option<Vec<u8>> = None;
907
908 let login = Self::build_login7(config, sspi_token);
910 Self::send_login7(&mut connection, &login).await?;
911
912 let (server_version, current_database, routing, server_collation) = timeout(
914 config.timeouts.login_timeout,
915 Self::process_login_response(
916 &mut connection,
917 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
918 negotiator.as_deref(),
919 ),
920 )
921 .await
922 .map_err(|_| Error::LoginTimeout {
923 host: config.host.clone(),
924 port: config.port,
925 })??;
926
927 if let Some((host, port)) = routing {
929 return Err(Error::Routing { host, port });
930 }
931
932 Ok(Client {
933 config: config.clone(),
934 _state: PhantomData,
935 connection: Some(ConnectionHandle::Plain(connection)),
936 server_version,
937 current_database: current_database.clone(),
938 server_collation,
939 statement_cache: StatementCache::with_default_size(),
940 transaction_descriptor: 0,
941 needs_reset: false,
942 in_flight: false,
943 #[cfg(feature = "otel")]
944 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
945 .with_database(current_database.clone().unwrap_or_default()),
946 #[cfg(feature = "always-encrypted")]
947 encryption_context: config.column_encryption.clone().map(|cfg| {
948 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
949 }),
950 })
951 }
952
953 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
955 let version = if config.strict_mode {
957 tds_protocol::version::TdsVersion::V8_0
958 } else {
959 config.tds_version
960 };
961
962 let mut prelogin = PreLogin::new()
963 .with_version(version)
964 .with_encryption(encryption);
965
966 if config.mars {
967 prelogin = prelogin.with_mars(true);
968 }
969
970 if let Some(ref instance) = config.instance {
971 prelogin = prelogin.with_instance(instance);
972 }
973
974 prelogin
975 }
976
977 fn resolve_workstation_id(config: &Config) -> String {
985 if let Some(ref id) = config.workstation_id {
986 return id.clone();
987 }
988 std::env::var("COMPUTERNAME")
991 .or_else(|_| std::env::var("HOSTNAME"))
992 .unwrap_or_default()
993 }
994
995 fn build_login7(config: &Config, sspi_token: Option<Vec<u8>>) -> Login7 {
1000 let version = if config.strict_mode {
1002 tds_protocol::version::TdsVersion::V8_0
1003 } else {
1004 config.tds_version
1005 };
1006
1007 let mut login = Login7::new()
1008 .with_tds_version(version)
1009 .with_packet_size(config.packet_size as u32)
1010 .with_app_name(&config.application_name)
1011 .with_server_name(&config.host)
1012 .with_hostname(Self::resolve_workstation_id(config));
1013
1014 if let Some(ref database) = config.database {
1015 login = login.with_database(database);
1016 }
1017
1018 if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
1020 login = login.with_read_only_intent(true);
1021 }
1022
1023 if let Some(ref lang) = config.language {
1025 login = login.with_language(lang);
1026 }
1027
1028 if let Some(token) = sspi_token {
1030 login = login.with_integrated_auth(token);
1032 } else if let mssql_auth::Credentials::SqlServer { username, password } =
1033 &config.credentials
1034 {
1035 login = login.with_sql_auth(username.as_ref(), password.as_ref());
1036 }
1037
1038 #[cfg(feature = "always-encrypted")]
1041 if config.column_encryption.is_some() {
1042 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1043 feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1044 data: bytes::Bytes::from_static(&[0x01]), });
1046 tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1047 }
1048
1049 login
1050 }
1051
1052 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1062 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1063 #[allow(clippy::match_like_matches_macro)]
1064 let is_integrated = match &config.credentials {
1065 mssql_auth::Credentials::Integrated => true,
1066 _ => false,
1067 };
1068
1069 if !is_integrated {
1070 return Ok(None);
1071 }
1072
1073 #[cfg(all(windows, feature = "sspi-auth"))]
1078 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1079 Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1080
1081 #[cfg(all(not(windows), feature = "sspi-auth"))]
1083 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1084 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1085
1086 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1087 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1088 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1089
1090 Ok(Some(negotiator))
1091 }
1092
1093 #[cfg(feature = "tls")]
1095 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1096 where
1097 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1098 {
1099 let payload = prelogin.encode();
1100 let max_packet = MAX_PACKET_SIZE;
1101
1102 connection
1103 .send_message(PacketType::PreLogin, payload, max_packet)
1104 .await?;
1105 Ok(())
1106 }
1107
1108 #[cfg(feature = "tls")]
1110 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1111 where
1112 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1113 {
1114 let message = connection
1115 .read_message()
1116 .await?
1117 .ok_or(Error::ConnectionClosed)?;
1118
1119 Ok(PreLogin::decode(&message.payload[..])?)
1120 }
1121
1122 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1124 where
1125 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1126 {
1127 let payload = login.encode();
1128 let max_packet = MAX_PACKET_SIZE;
1129
1130 connection
1131 .send_message(PacketType::Tds7Login, payload, max_packet)
1132 .await?;
1133 Ok(())
1134 }
1135
    /// Consume the server's response to LOGIN7.
    ///
    /// Reads TDS messages from `connection` and walks their tokens:
    /// - `LOGINACK` records the server's TDS version;
    /// - `ENVCHANGE` is handed to `process_env_change`, which may fill in the database,
    ///   routing target and collation;
    /// - `SSPI` (integrated-auth builds only) passes the challenge to `negotiator`. It sends
    ///   the response, if any, as an `Sspi` packet and then reads the next message. Any
    ///   tokens left in the current message are skipped;
    /// - `ERROR` aborts with `Error::Server`;
    /// - `INFO` is logged;
    /// - `DONE` ends processing, or fails with `Error::Protocol("login failed")` when its
    ///   error flag is set.
    ///
    /// If a message runs out of tokens without a `DONE`, processing also stops after that
    /// message.
    ///
    /// Returns `(server_tds_version, database, routing, collation)`. Each element is `None`
    /// when the corresponding token was not seen.
    ///
    /// # Errors
    /// `Error::ConnectionClosed` if the stream ends, token-parse errors, `Error::Protocol`
    /// for an SSPI challenge without a negotiator, and the `ERROR`/`DONE` cases above.
    // NOTE(review): `continue 'outer` after an SSPI token makes this loop repeat, so the
    // `never_loop` allow presumably silences a clippy false positive — confirm.
    #[allow(clippy::never_loop)]
    async fn process_login_response<T>(
        connection: &mut Connection<T>,
        #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
            &dyn mssql_auth::SspiNegotiator,
        >,
    ) -> Result<(
        Option<u32>,
        Option<String>,
        Option<(String, u16)>,
        Option<tds_protocol::token::Collation>,
    )>
    where
        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        let mut server_version = None;
        let mut database = None;
        let mut routing = None;
        let mut collation = None;

        // One iteration per TDS message read from the server.
        'outer: loop {
            let message = connection
                .read_message()
                .await?
                .ok_or(Error::ConnectionClosed)?;

            let response_bytes = message.payload;
            let mut parser = TokenParser::new(response_bytes);

            while let Some(token) = parser.next_token()? {
                match token {
                    Token::LoginAck(ack) => {
                        tracing::info!(
                            version = ack.tds_version,
                            interface = ack.interface,
                            prog_name = %ack.prog_name,
                            "login acknowledged"
                        );
                        server_version = Some(ack.tds_version);
                    }
                    Token::EnvChange(env) => {
                        Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
                    }
                    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
                    Token::Sspi(sspi_token) => {
                        let neg = negotiator.ok_or_else(|| {
                            Error::Protocol(
                                "server sent SSPI challenge but no negotiator is configured"
                                    .to_string(),
                            )
                        })?;

                        tracing::debug!(
                            challenge_len = sspi_token.data.len(),
                            "received SSPI challenge from server"
                        );

                        if let Some(response) = neg.step(&sspi_token.data)? {
                            tracing::debug!(response_len = response.len(), "sending SSPI response");
                            connection
                                .send_message(
                                    PacketType::Sspi,
                                    bytes::Bytes::from(response),
                                    tds_protocol::packet::MAX_PACKET_SIZE,
                                )
                                .await?;
                        }

                        // Wait for the server's next message in the handshake; any
                        // remaining tokens of this message are not examined.
                        continue 'outer;
                    }
                    Token::Error(err) => {
                        return Err(Error::Server {
                            number: err.number,
                            state: err.state,
                            class: err.class,
                            message: err.message.clone(),
                            server: if err.server.is_empty() {
                                None
                            } else {
                                Some(err.server.clone())
                            },
                            procedure: if err.procedure.is_empty() {
                                None
                            } else {
                                Some(err.procedure.clone())
                            },
                            line: err.line as u32,
                        });
                    }
                    Token::Info(info) => {
                        tracing::info!(
                            number = info.number,
                            message = %info.message,
                            "server info message"
                        );
                    }
                    Token::Done(done) => {
                        if done.status.error {
                            return Err(Error::Protocol("login failed".to_string()));
                        }
                        break 'outer;
                    }
                    _ => {}
                }
            }

            // Message exhausted without DONE and without an SSPI continuation: stop.
            break;
        }

        Ok((server_version, database, routing, collation))
    }
1258
1259 fn process_env_change(
1261 env: &EnvChange,
1262 database: &mut Option<String>,
1263 routing: &mut Option<(String, u16)>,
1264 collation: &mut Option<tds_protocol::token::Collation>,
1265 ) {
1266 use tds_protocol::token::EnvChangeValue;
1267
1268 match env.env_type {
1269 EnvChangeType::Database => {
1270 if let EnvChangeValue::String(ref new_value) = env.new_value {
1271 tracing::debug!(database = %new_value, "database changed");
1272 *database = Some(new_value.clone());
1273 }
1274 }
1275 EnvChangeType::Routing => {
1276 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1277 tracing::info!(host = %host, port = port, "routing redirect received");
1278 *routing = Some((host.clone(), port));
1279 }
1280 }
1281 EnvChangeType::SqlCollation => {
1282 if let EnvChangeValue::Binary(ref data) = env.new_value {
1283 if data.len() >= 5 {
1284 let c = tds_protocol::token::Collation::from_bytes(
1285 data[..5].try_into().unwrap(),
1286 );
1287 tracing::debug!(
1288 lcid = c.lcid,
1289 sort_id = c.sort_id,
1290 "server collation received"
1291 );
1292 *collation = Some(c);
1293 }
1294 }
1295 }
1296 _ => {
1297 if let EnvChangeValue::String(ref new_value) = env.new_value {
1298 tracing::debug!(
1299 env_type = ?env.env_type,
1300 new_value = %new_value,
1301 "environment change"
1302 );
1303 }
1304 }
1305 }
1306 }
1307}
1308
/// Builds the TLS settings for a connection attempt from `config`.
///
/// Starts from a clone of `config.tls` and always overrides its
/// `trust_server_certificate` setting with the top-level
/// `config.trust_server_certificate` flag. When `strict` is true, it also enables strict
/// mode and configures `tds/8.0` through `with_alpn_protocols`.
#[cfg(feature = "tls")]
fn connection_tls_config(config: &Config, strict: bool) -> TlsConfig {
    let trust_flag = config.trust_server_certificate;
    let base = config.tls.clone().trust_server_certificate(trust_flag);

    if !strict {
        return base;
    }

    let tds8_alpn = vec![b"tds/8.0".to_vec()];
    base.strict_mode(true).with_alpn_protocols(tds8_alpn)
}
1342
#[cfg(all(test, feature = "tls"))]
mod tls_config_tests {
    // Unit tests for `connection_tls_config`. Custom roots, the top-level trust flag and
    // strict-mode ALPN must all survive the translation from `Config` to `TlsConfig`.
    use super::*;
    use mssql_tls::CertificateDer;

    /// Returns a default `Config` whose TLS settings carry `cert` as an added root certificate.
    fn config_with_root(cert: Vec<u8>) -> Config {
        let mut config = Config::new();
        let root = CertificateDer::from(cert);
        config.tls = config.tls.clone().add_root_certificate(root);
        config
    }

    #[test]
    fn custom_root_certificate_reaches_connector_config() {
        const ROOT_BYTES: [u8; 32] = [0xCA; 32];
        let config = config_with_root(ROOT_BYTES.to_vec());

        // The custom root must be forwarded in both strict and non-strict mode.
        for strict in [true, false] {
            let tls = connection_tls_config(&config, strict);
            assert_eq!(
                tls.root_certificates.len(),
                1,
                "custom root must reach the connector (strict={strict})"
            );
            assert_eq!(tls.root_certificates[0].as_ref(), &ROOT_BYTES[..]);
        }
    }

    #[test]
    fn trust_server_certificate_taken_from_top_level_field() {
        let mut config = Config::new();
        config.trust_server_certificate = true;
        // Precondition: only the top-level flag is set, not the one inside `config.tls`.
        assert!(!config.tls.trust_server_certificate);

        let derived = connection_tls_config(&config, false);
        assert!(
            derived.trust_server_certificate,
            "top-level trust flag must win"
        );
    }

    #[test]
    fn strict_mode_adds_tds8_alpn() {
        let config = Config::new();

        let strict = connection_tls_config(&config, true);
        assert!(strict.strict_mode);
        let advertises_tds8 = strict.alpn_protocols.iter().any(|proto| proto == b"tds/8.0");
        assert!(advertises_tds8);

        let relaxed = connection_tls_config(&config, false);
        assert!(!relaxed.strict_mode);
    }
}