1use std::marker::PhantomData;
7use std::net::SocketAddr;
8
9use bytes::BytesMut;
10use mssql_codec::connection::Connection;
11#[cfg(feature = "tls")]
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::MAX_PACKET_SIZE;
15use tds_protocol::packet::PacketType;
16use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
17use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20
21use crate::config::Config;
22use crate::error::{Error, Result};
23#[cfg(feature = "otel")]
24use crate::instrumentation::InstrumentationContext;
25use crate::state::{Disconnected, Ready};
26use crate::statement_cache::StatementCache;
27
28use super::{Client, ConnectionHandle};
29
30impl Client<Disconnected> {
31 pub async fn connect(config: Config) -> Result<Client<Ready>> {
42 let retry = config.retry.clone();
43 let max_redirects = config.redirect.max_redirects;
44 let follow_redirects = config.redirect.follow_redirects;
45 let per_attempt = config.timeouts.connect_timeout
47 + config.timeouts.tls_timeout
48 + config.timeouts.login_timeout;
49 let total_attempts = (retry.max_retries + 1) * (max_redirects as u32 + 1);
50 let overall = (per_attempt * total_attempts).min(std::time::Duration::from_secs(300));
51 let initial_host = config.host.clone();
52 let initial_port = config.port;
53
54 let result = timeout(overall, async {
55 let mut last_error: Option<Error> = None;
56
57 for retry_attempt in 0..=retry.max_retries {
58 if retry_attempt > 0 {
59 let backoff = retry.backoff_for_attempt(retry_attempt);
60 tracing::info!(
61 retry_attempt,
62 backoff_ms = backoff.as_millis() as u64,
63 "retrying connection after transient error"
64 );
65 tokio::time::sleep(backoff).await;
66 }
67
68 let mut current_config = config.clone();
70 let mut redirect_count: u8 = 0;
71
72 let attempt_result = loop {
73 redirect_count += 1;
74 if redirect_count > max_redirects + 1 {
75 break Err(Error::TooManyRedirects { max: max_redirects });
76 }
77
78 match Self::try_connect(¤t_config).await {
79 Ok(client) => break Ok(client),
80 Err(Error::Routing { host, port }) => {
81 if !follow_redirects {
82 break Err(Error::Routing { host, port });
83 }
84 tracing::info!(
85 host = %host,
86 port = port,
87 redirect = redirect_count,
88 max_redirects = max_redirects,
89 "following Azure SQL routing redirect"
90 );
91 current_config = current_config.with_host(&host).with_port(port);
92 continue;
93 }
94 Err(e) => break Err(e),
95 }
96 };
97
98 match attempt_result {
99 Ok(client) => return Ok(client),
100 Err(ref e) if e.is_transient() && retry.should_retry(retry_attempt) => {
101 tracing::warn!(
102 retry_attempt,
103 max_retries = retry.max_retries,
104 error = %e,
105 "transient connection error, will retry"
106 );
107 last_error = Some(attempt_result.unwrap_err());
108 }
109 Err(e) => return Err(e),
110 }
111 }
112
113 Err(last_error.expect("at least one attempt was made"))
115 })
116 .await;
117
118 match result {
119 Ok(inner) => inner,
120 Err(_elapsed) => Err(Error::ConnectTimeout {
121 host: initial_host,
122 port: initial_port,
123 }),
124 }
125 }
126
127 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
128 let port = if let Some(ref instance) = config.instance {
130 let resolved = crate::browser::resolve_instance(
131 &config.host,
132 instance,
133 Some(config.timeouts.connect_timeout),
134 )
135 .await?;
136 tracing::info!(
137 host = %config.host,
138 instance = %instance,
139 resolved_port = resolved,
140 database = ?config.database,
141 "connecting to named SQL Server instance"
142 );
143 resolved
144 } else {
145 tracing::info!(
146 host = %config.host,
147 port = config.port,
148 database = ?config.database,
149 "connecting to SQL Server"
150 );
151 config.port
152 };
153
154 let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
157 "127.0.0.1"
158 } else {
159 &config.host
160 };
161
162 let tcp_stream = if config.multi_subnet_failover {
164 Self::connect_parallel(host, port, config.timeouts.connect_timeout).await?
165 } else {
166 let addr = format!("{host}:{port}");
167 tracing::debug!("establishing TCP connection to {}", addr);
168 let stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
169 .await
170 .map_err(|_| Error::ConnectTimeout {
171 host: config.host.clone(),
172 port: config.port,
173 })?
174 .map_err(Error::from)?;
175 stream.set_nodelay(true).map_err(Error::from)?;
176 stream
177 };
178
179 #[cfg(feature = "tls")]
180 {
181 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
183
184 if tls_mode.is_tls_first() {
186 return Self::connect_tds_8(config, tcp_stream).await;
187 }
188
189 Self::connect_tds_7x(config, tcp_stream).await
191 }
192
193 #[cfg(not(feature = "tls"))]
194 {
195 if config.strict_mode {
197 return Err(Error::Config(
198 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
199 ));
200 }
201
202 if !config.no_tls {
203 return Err(Error::Config(
204 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
205 or use Encrypt=no_tls in your connection string for unencrypted connections."
206 .into(),
207 ));
208 }
209
210 Self::connect_no_tls(config, tcp_stream).await
212 }
213 }
214
215 async fn connect_parallel(
220 host: &str,
221 port: u16,
222 connect_timeout: std::time::Duration,
223 ) -> Result<TcpStream> {
224 let addr_str = format!("{host}:{port}");
225 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
226 .await
227 .map_err(Error::from)?
228 .collect();
229
230 if addrs.is_empty() {
231 return Err(Error::from(std::io::Error::new(
232 std::io::ErrorKind::AddrNotAvailable,
233 format!("no addresses resolved for {host}:{port}"),
234 )));
235 }
236
237 if addrs.len() == 1 {
239 tracing::debug!(addr = %addrs[0], "MultiSubnetFailover: single address resolved");
240 let stream = timeout(connect_timeout, TcpStream::connect(addrs[0]))
241 .await
242 .map_err(|_| Error::ConnectTimeout {
243 host: host.to_string(),
244 port,
245 })?
246 .map_err(Error::from)?;
247 stream.set_nodelay(true).map_err(Error::from)?;
248 return Ok(stream);
249 }
250
251 let addr_count = addrs.len();
252 tracing::debug!(
253 host = host,
254 port = port,
255 resolved_count = addr_count,
256 "MultiSubnetFailover: racing parallel connections",
257 );
258
259 let mut join_set = tokio::task::JoinSet::new();
260
261 for addr in addrs {
262 let dur = connect_timeout;
263 join_set.spawn(async move {
264 let tcp = timeout(dur, TcpStream::connect(addr)).await.map_err(|_| {
265 std::io::Error::new(
266 std::io::ErrorKind::TimedOut,
267 format!("connection to {addr} timed out"),
268 )
269 })??;
270 tcp.set_nodelay(true)?;
271 Ok::<(TcpStream, SocketAddr), std::io::Error>((tcp, addr))
272 });
273 }
274
275 let mut last_error: Option<std::io::Error> = None;
276
277 while let Some(result) = join_set.join_next().await {
278 match result {
279 Ok(Ok((stream, addr))) => {
280 tracing::debug!(addr = %addr, "MultiSubnetFailover: connected");
281 join_set.abort_all();
282 return Ok(stream);
283 }
284 Ok(Err(e)) => {
285 tracing::debug!(error = %e, "MultiSubnetFailover: attempt failed");
286 last_error = Some(e);
287 }
288 Err(join_err) => {
289 tracing::debug!(error = %join_err, "MultiSubnetFailover: task failed");
290 last_error = Some(std::io::Error::other(join_err.to_string()));
291 }
292 }
293 }
294
295 Err(Error::from(last_error.unwrap_or_else(|| {
297 std::io::Error::new(
298 std::io::ErrorKind::ConnectionRefused,
299 format!("all {addr_count} parallel connection attempts failed for {host}:{port}"),
300 )
301 })))
302 }
303
304 #[cfg(feature = "tls")]
308 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
309 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
310
311 let tls_config = TlsConfig::new()
313 .strict_mode(true)
314 .trust_server_certificate(config.trust_server_certificate)
315 .with_alpn_protocols(vec![b"tds/8.0".to_vec()]);
316
317 let tls_connector = TlsConnector::new(tls_config)?;
318
319 let tls_stream = timeout(
321 config.timeouts.tls_timeout,
322 tls_connector.connect(tcp_stream, &config.host),
323 )
324 .await
325 .map_err(|_| Error::TlsTimeout {
326 host: config.host.clone(),
327 port: config.port,
328 })??;
329
330 tracing::debug!("TLS handshake completed (strict mode)");
331
332 let mut connection = Connection::new(tls_stream);
334
335 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
337 Self::send_prelogin(&mut connection, &prelogin).await?;
338 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
339
340 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
342 let negotiator = Self::create_negotiator(config)?;
343 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
344 let sspi_token = match negotiator {
345 Some(ref neg) => Some(neg.initialize()?),
346 None => None,
347 };
348 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
349 let sspi_token: Option<Vec<u8>> = None;
350
351 let login = Self::build_login7(config, sspi_token);
353 Self::send_login7(&mut connection, &login).await?;
354
355 let (server_version, current_database, routing, server_collation) = timeout(
357 config.timeouts.login_timeout,
358 Self::process_login_response(
359 &mut connection,
360 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
361 negotiator.as_deref(),
362 ),
363 )
364 .await
365 .map_err(|_| Error::LoginTimeout {
366 host: config.host.clone(),
367 port: config.port,
368 })??;
369
370 if let Some((host, port)) = routing {
372 return Err(Error::Routing { host, port });
373 }
374
375 Ok(Client {
376 config: config.clone(),
377 _state: PhantomData,
378 connection: Some(ConnectionHandle::Tls(connection)),
379 server_version,
380 current_database: current_database.clone(),
381 server_collation,
382 statement_cache: StatementCache::with_default_size(),
383 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
387 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
388 .with_database(current_database.clone().unwrap_or_default()),
389 #[cfg(feature = "always-encrypted")]
390 encryption_context: config.column_encryption.clone().map(|cfg| {
391 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
392 }),
393 })
394 }
395
396 #[cfg(feature = "tls")]
404 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
405 use bytes::BufMut;
406 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
407 use tokio::io::{AsyncReadExt, AsyncWriteExt};
408
409 tracing::debug!("using TDS 7.x flow (PreLogin first)");
410
411 let client_encryption = if config.no_tls {
414 tracing::warn!(
416 "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
417 Credentials and data will be transmitted in plaintext. \
418 This should only be used for development/testing with legacy SQL Server."
419 );
420 EncryptionLevel::NotSupported
421 } else if config.encrypt {
422 EncryptionLevel::On
423 } else {
424 EncryptionLevel::Off
425 };
426 let prelogin = Self::build_prelogin(config, client_encryption);
427 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
428 let prelogin_bytes = prelogin.encode();
429
430 let header = PacketHeader::new(
432 PacketType::PreLogin,
433 PacketStatus::END_OF_MESSAGE,
434 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
435 );
436
437 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
438 header.encode(&mut packet_buf);
439 packet_buf.put_slice(&prelogin_bytes);
440
441 tcp_stream
442 .write_all(&packet_buf)
443 .await
444 .map_err(Error::from)?;
445
446 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
448 tcp_stream
449 .read_exact(&mut header_buf)
450 .await
451 .map_err(Error::from)?;
452
453 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
454 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
455
456 let mut response_buf = vec![0u8; payload_length];
457 tcp_stream
458 .read_exact(&mut response_buf)
459 .await
460 .map_err(Error::from)?;
461
462 let prelogin_response = PreLogin::decode(&response_buf[..])?;
463
464 let client_tds_version = config.tds_version;
469 if let Some(ref server_version) = prelogin_response.server_version {
470 tracing::debug!(
471 requested_tds_version = %client_tds_version,
472 server_product_version = %server_version,
473 server_product = server_version.product_name(),
474 max_tds_version = %server_version.max_tds_version(),
475 "PreLogin response received"
476 );
477
478 let server_max_tds = server_version.max_tds_version();
480 if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
481 tracing::warn!(
482 requested_tds_version = %client_tds_version,
483 server_max_tds_version = %server_max_tds,
484 server_product = server_version.product_name(),
485 "Server supports lower TDS version than requested. \
486 Connection will use server's maximum: {}",
487 server_max_tds
488 );
489 }
490
491 if server_max_tds.is_legacy() {
493 tracing::warn!(
494 server_product = server_version.product_name(),
495 server_max_tds_version = %server_max_tds,
496 "Server uses legacy TDS version. Some features may not be available."
497 );
498 }
499 } else {
500 tracing::debug!(
501 requested_tds_version = %client_tds_version,
502 "PreLogin response received (no version info)"
503 );
504 }
505
506 let server_encryption = prelogin_response.encryption;
508 tracing::debug!(encryption = ?server_encryption, "server encryption level");
509
510 let negotiated_encryption = match (client_encryption, server_encryption) {
516 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
517 EncryptionLevel::NotSupported
518 }
519 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
520 (EncryptionLevel::On, EncryptionLevel::Off)
521 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
522 return Err(Error::Protocol(
523 "Server does not support requested encryption level".to_string(),
524 ));
525 }
526 _ => EncryptionLevel::On,
527 };
528
529 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
532
533 if use_tls {
534 let tls_config =
537 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
538
539 let tls_connector = TlsConnector::new(tls_config)?;
540
541 let mut tls_stream = timeout(
543 config.timeouts.tls_timeout,
544 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
545 )
546 .await
547 .map_err(|_| Error::TlsTimeout {
548 host: config.host.clone(),
549 port: config.port,
550 })??;
551
552 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
553
554 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
556
557 if login_only_encryption {
558 use tokio::io::AsyncWriteExt;
566
567 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
572 let negotiator = Self::create_negotiator(config)?;
573 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
574 let sspi_token = match negotiator {
575 Some(ref neg) => Some(neg.initialize()?),
576 None => None,
577 };
578 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
579 let sspi_token: Option<Vec<u8>> = None;
580
581 let login = Self::build_login7(config, sspi_token);
583 let login_payload = login.encode();
584
585 let max_packet = MAX_PACKET_SIZE;
587 let max_payload = max_packet - PACKET_HEADER_SIZE;
588 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
589 let total_chunks = chunks.len();
590
591 for (i, chunk) in chunks.into_iter().enumerate() {
592 let is_last = i == total_chunks - 1;
593 let status = if is_last {
594 PacketStatus::END_OF_MESSAGE
595 } else {
596 PacketStatus::NORMAL
597 };
598
599 let header = PacketHeader::new(
600 PacketType::Tds7Login,
601 status,
602 (PACKET_HEADER_SIZE + chunk.len()) as u16,
603 );
604
605 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
606 header.encode(&mut packet_buf);
607 packet_buf.put_slice(chunk);
608
609 tls_stream
610 .write_all(&packet_buf)
611 .await
612 .map_err(Error::from)?;
613 }
614
615 tls_stream.flush().await.map_err(Error::from)?;
617
618 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
619
620 let (wrapper, _client_conn) = tls_stream.into_inner();
624 let tcp_stream = wrapper.into_inner();
625
626 let mut connection = Connection::new(tcp_stream);
628
629 let (server_version, current_database, routing, server_collation) = timeout(
631 config.timeouts.login_timeout,
632 Self::process_login_response(
633 &mut connection,
634 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
635 negotiator.as_deref(),
636 ),
637 )
638 .await
639 .map_err(|_| Error::LoginTimeout {
640 host: config.host.clone(),
641 port: config.port,
642 })??;
643
644 if let Some((host, port)) = routing {
646 return Err(Error::Routing { host, port });
647 }
648
649 Ok(Client {
651 config: config.clone(),
652 _state: PhantomData,
653 connection: Some(ConnectionHandle::Plain(connection)),
654 server_version,
655 current_database: current_database.clone(),
656 server_collation,
657 statement_cache: StatementCache::with_default_size(),
658 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
662 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
663 .with_database(current_database.clone().unwrap_or_default()),
664 #[cfg(feature = "always-encrypted")]
665 encryption_context: config.column_encryption.clone().map(|cfg| {
666 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
667 }),
668 })
669 } else {
670 let mut connection = Connection::new(tls_stream);
673
674 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
676 let negotiator = Self::create_negotiator(config)?;
677 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
678 let sspi_token = match negotiator {
679 Some(ref neg) => Some(neg.initialize()?),
680 None => None,
681 };
682 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
683 let sspi_token: Option<Vec<u8>> = None;
684
685 let login = Self::build_login7(config, sspi_token);
687 Self::send_login7(&mut connection, &login).await?;
688
689 let (server_version, current_database, routing, server_collation) = timeout(
691 config.timeouts.login_timeout,
692 Self::process_login_response(
693 &mut connection,
694 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
695 negotiator.as_deref(),
696 ),
697 )
698 .await
699 .map_err(|_| Error::LoginTimeout {
700 host: config.host.clone(),
701 port: config.port,
702 })??;
703
704 if let Some((host, port)) = routing {
706 return Err(Error::Routing { host, port });
707 }
708
709 Ok(Client {
710 config: config.clone(),
711 _state: PhantomData,
712 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
713 server_version,
714 current_database: current_database.clone(),
715 server_collation,
716 statement_cache: StatementCache::with_default_size(),
717 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
721 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
722 .with_database(current_database.clone().unwrap_or_default()),
723 #[cfg(feature = "always-encrypted")]
724 encryption_context: config.column_encryption.clone().map(|cfg| {
725 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
726 }),
727 })
728 }
729 } else {
730 tracing::warn!(
732 "Connecting without TLS encryption. This is insecure and should only be \
733 used for development/testing on trusted networks."
734 );
735
736 let mut connection = Connection::new(tcp_stream);
737
738 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
740 let negotiator = Self::create_negotiator(config)?;
741 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
742 let sspi_token = match negotiator {
743 Some(ref neg) => Some(neg.initialize()?),
744 None => None,
745 };
746 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
747 let sspi_token: Option<Vec<u8>> = None;
748
749 let login = Self::build_login7(config, sspi_token);
751 Self::send_login7(&mut connection, &login).await?;
752
753 let (server_version, current_database, routing, server_collation) = timeout(
755 config.timeouts.login_timeout,
756 Self::process_login_response(
757 &mut connection,
758 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
759 negotiator.as_deref(),
760 ),
761 )
762 .await
763 .map_err(|_| Error::LoginTimeout {
764 host: config.host.clone(),
765 port: config.port,
766 })??;
767
768 if let Some((host, port)) = routing {
770 return Err(Error::Routing { host, port });
771 }
772
773 Ok(Client {
774 config: config.clone(),
775 _state: PhantomData,
776 connection: Some(ConnectionHandle::Plain(connection)),
777 server_version,
778 current_database: current_database.clone(),
779 server_collation,
780 statement_cache: StatementCache::with_default_size(),
781 transaction_descriptor: 0, needs_reset: false, in_flight: false, #[cfg(feature = "otel")]
785 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
786 .with_database(current_database.clone().unwrap_or_default()),
787 #[cfg(feature = "always-encrypted")]
788 encryption_context: config.column_encryption.clone().map(|cfg| {
789 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
790 }),
791 })
792 }
793 }
794
795 #[cfg(not(feature = "tls"))]
806 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
807 use bytes::BufMut;
808 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
809 use tokio::io::{AsyncReadExt, AsyncWriteExt};
810
811 tracing::warn!(
812 "⚠️ Connecting without TLS (tls feature disabled). \
813 Credentials and data will be transmitted in plaintext."
814 );
815
816 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
818 let prelogin_bytes = prelogin.encode();
819
820 let header = PacketHeader::new(
822 PacketType::PreLogin,
823 PacketStatus::END_OF_MESSAGE,
824 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
825 );
826
827 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
828 header.encode(&mut packet_buf);
829 packet_buf.put_slice(&prelogin_bytes);
830
831 tcp_stream
832 .write_all(&packet_buf)
833 .await
834 .map_err(Error::from)?;
835
836 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
838 tcp_stream
839 .read_exact(&mut header_buf)
840 .await
841 .map_err(Error::from)?;
842
843 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
844 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
845
846 let mut response_buf = vec![0u8; payload_length];
847 tcp_stream
848 .read_exact(&mut response_buf)
849 .await
850 .map_err(Error::from)?;
851
852 let prelogin_response = PreLogin::decode(&response_buf[..])?;
853
854 let server_encryption = prelogin_response.encryption;
856 if server_encryption != EncryptionLevel::NotSupported {
857 return Err(Error::Config(format!(
858 "Server requires encryption (level: {:?}) but TLS feature is disabled. \
859 Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
860 server_encryption
861 )));
862 }
863
864 tracing::debug!("Server accepted unencrypted connection");
865
866 let mut connection = Connection::new(tcp_stream);
867
868 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
870 let negotiator = Self::create_negotiator(config)?;
871 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
872 let sspi_token = match negotiator {
873 Some(ref neg) => Some(neg.initialize()?),
874 None => None,
875 };
876 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
877 let sspi_token: Option<Vec<u8>> = None;
878
879 let login = Self::build_login7(config, sspi_token);
881 Self::send_login7(&mut connection, &login).await?;
882
883 let (server_version, current_database, routing, server_collation) = timeout(
885 config.timeouts.login_timeout,
886 Self::process_login_response(
887 &mut connection,
888 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
889 negotiator.as_deref(),
890 ),
891 )
892 .await
893 .map_err(|_| Error::LoginTimeout {
894 host: config.host.clone(),
895 port: config.port,
896 })??;
897
898 if let Some((host, port)) = routing {
900 return Err(Error::Routing { host, port });
901 }
902
903 Ok(Client {
904 config: config.clone(),
905 _state: PhantomData,
906 connection: Some(ConnectionHandle::Plain(connection)),
907 server_version,
908 current_database: current_database.clone(),
909 server_collation,
910 statement_cache: StatementCache::with_default_size(),
911 transaction_descriptor: 0,
912 needs_reset: false,
913 in_flight: false,
914 #[cfg(feature = "otel")]
915 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
916 .with_database(current_database.clone().unwrap_or_default()),
917 #[cfg(feature = "always-encrypted")]
918 encryption_context: config.column_encryption.clone().map(|cfg| {
919 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
920 }),
921 })
922 }
923
924 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
926 let version = if config.strict_mode {
928 tds_protocol::version::TdsVersion::V8_0
929 } else {
930 config.tds_version
931 };
932
933 let mut prelogin = PreLogin::new()
934 .with_version(version)
935 .with_encryption(encryption);
936
937 if config.mars {
938 prelogin = prelogin.with_mars(true);
939 }
940
941 if let Some(ref instance) = config.instance {
942 prelogin = prelogin.with_instance(instance);
943 }
944
945 prelogin
946 }
947
948 fn resolve_workstation_id(config: &Config) -> String {
956 if let Some(ref id) = config.workstation_id {
957 return id.clone();
958 }
959 std::env::var("COMPUTERNAME")
962 .or_else(|_| std::env::var("HOSTNAME"))
963 .unwrap_or_default()
964 }
965
966 fn build_login7(config: &Config, sspi_token: Option<Vec<u8>>) -> Login7 {
971 let version = if config.strict_mode {
973 tds_protocol::version::TdsVersion::V8_0
974 } else {
975 config.tds_version
976 };
977
978 let mut login = Login7::new()
979 .with_tds_version(version)
980 .with_packet_size(config.packet_size as u32)
981 .with_app_name(&config.application_name)
982 .with_server_name(&config.host)
983 .with_hostname(Self::resolve_workstation_id(config));
984
985 if let Some(ref database) = config.database {
986 login = login.with_database(database);
987 }
988
989 if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
991 login = login.with_read_only_intent(true);
992 }
993
994 if let Some(ref lang) = config.language {
996 login = login.with_language(lang);
997 }
998
999 if let Some(token) = sspi_token {
1001 login = login.with_integrated_auth(token);
1003 } else if let mssql_auth::Credentials::SqlServer { username, password } =
1004 &config.credentials
1005 {
1006 login = login.with_sql_auth(username.as_ref(), password.as_ref());
1007 }
1008
1009 #[cfg(feature = "always-encrypted")]
1012 if config.column_encryption.is_some() {
1013 login = login.with_feature(tds_protocol::login7::FeatureExtension {
1014 feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
1015 data: bytes::Bytes::from_static(&[0x01]), });
1017 tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
1018 }
1019
1020 login
1021 }
1022
1023 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1033 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
1034 #[allow(clippy::match_like_matches_macro)]
1035 let is_integrated = match &config.credentials {
1036 mssql_auth::Credentials::Integrated => true,
1037 _ => false,
1038 };
1039
1040 if !is_integrated {
1041 return Ok(None);
1042 }
1043
1044 #[cfg(all(windows, feature = "sspi-auth"))]
1049 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1050 Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
1051
1052 #[cfg(all(not(windows), feature = "sspi-auth"))]
1054 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1055 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
1056
1057 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
1058 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
1059 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
1060
1061 Ok(Some(negotiator))
1062 }
1063
1064 #[cfg(feature = "tls")]
1066 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
1067 where
1068 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1069 {
1070 let payload = prelogin.encode();
1071 let max_packet = MAX_PACKET_SIZE;
1072
1073 connection
1074 .send_message(PacketType::PreLogin, payload, max_packet)
1075 .await?;
1076 Ok(())
1077 }
1078
1079 #[cfg(feature = "tls")]
1081 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
1082 where
1083 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1084 {
1085 let message = connection
1086 .read_message()
1087 .await?
1088 .ok_or(Error::ConnectionClosed)?;
1089
1090 Ok(PreLogin::decode(&message.payload[..])?)
1091 }
1092
1093 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
1095 where
1096 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1097 {
1098 let payload = login.encode();
1099 let max_packet = MAX_PACKET_SIZE;
1100
1101 connection
1102 .send_message(PacketType::Tds7Login, payload, max_packet)
1103 .await?;
1104 Ok(())
1105 }
1106
1107 #[allow(clippy::never_loop)] async fn process_login_response<T>(
1118 connection: &mut Connection<T>,
1119 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
1120 &dyn mssql_auth::SspiNegotiator,
1121 >,
1122 ) -> Result<(
1123 Option<u32>,
1124 Option<String>,
1125 Option<(String, u16)>,
1126 Option<tds_protocol::token::Collation>,
1127 )>
1128 where
1129 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
1130 {
1131 let mut server_version = None;
1132 let mut database = None;
1133 let mut routing = None;
1134 let mut collation = None;
1135
1136 'outer: loop {
1137 let message = connection
1138 .read_message()
1139 .await?
1140 .ok_or(Error::ConnectionClosed)?;
1141
1142 let response_bytes = message.payload;
1143 let mut parser = TokenParser::new(response_bytes);
1144
1145 while let Some(token) = parser.next_token()? {
1146 match token {
1147 Token::LoginAck(ack) => {
1148 tracing::info!(
1149 version = ack.tds_version,
1150 interface = ack.interface,
1151 prog_name = %ack.prog_name,
1152 "login acknowledged"
1153 );
1154 server_version = Some(ack.tds_version);
1155 }
1156 Token::EnvChange(env) => {
1157 Self::process_env_change(&env, &mut database, &mut routing, &mut collation);
1158 }
1159 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1160 Token::Sspi(sspi_token) => {
1161 let neg = negotiator.ok_or_else(|| {
1162 Error::Protocol(
1163 "server sent SSPI challenge but no negotiator is configured"
1164 .to_string(),
1165 )
1166 })?;
1167
1168 tracing::debug!(
1169 challenge_len = sspi_token.data.len(),
1170 "received SSPI challenge from server"
1171 );
1172
1173 if let Some(response) = neg.step(&sspi_token.data)? {
1174 tracing::debug!(response_len = response.len(), "sending SSPI response");
1175 connection
1176 .send_message(
1177 PacketType::Sspi,
1178 bytes::Bytes::from(response),
1179 tds_protocol::packet::MAX_PACKET_SIZE,
1180 )
1181 .await?;
1182 }
1183
1184 continue 'outer;
1186 }
1187 Token::Error(err) => {
1188 return Err(Error::Server {
1189 number: err.number,
1190 state: err.state,
1191 class: err.class,
1192 message: err.message.clone(),
1193 server: if err.server.is_empty() {
1194 None
1195 } else {
1196 Some(err.server.clone())
1197 },
1198 procedure: if err.procedure.is_empty() {
1199 None
1200 } else {
1201 Some(err.procedure.clone())
1202 },
1203 line: err.line as u32,
1204 });
1205 }
1206 Token::Info(info) => {
1207 tracing::info!(
1208 number = info.number,
1209 message = %info.message,
1210 "server info message"
1211 );
1212 }
1213 Token::Done(done) => {
1214 if done.status.error {
1215 return Err(Error::Protocol("login failed".to_string()));
1216 }
1217 break 'outer;
1218 }
1219 _ => {}
1220 }
1221 }
1222
1223 break;
1225 }
1226
1227 Ok((server_version, database, routing, collation))
1228 }
1229
1230 fn process_env_change(
1232 env: &EnvChange,
1233 database: &mut Option<String>,
1234 routing: &mut Option<(String, u16)>,
1235 collation: &mut Option<tds_protocol::token::Collation>,
1236 ) {
1237 use tds_protocol::token::EnvChangeValue;
1238
1239 match env.env_type {
1240 EnvChangeType::Database => {
1241 if let EnvChangeValue::String(ref new_value) = env.new_value {
1242 tracing::debug!(database = %new_value, "database changed");
1243 *database = Some(new_value.clone());
1244 }
1245 }
1246 EnvChangeType::Routing => {
1247 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1248 tracing::info!(host = %host, port = port, "routing redirect received");
1249 *routing = Some((host.clone(), port));
1250 }
1251 }
1252 EnvChangeType::SqlCollation => {
1253 if let EnvChangeValue::Binary(ref data) = env.new_value {
1254 if data.len() >= 5 {
1255 let c = tds_protocol::token::Collation::from_bytes(
1256 data[..5].try_into().unwrap(),
1257 );
1258 tracing::debug!(
1259 lcid = c.lcid,
1260 sort_id = c.sort_id,
1261 "server collation received"
1262 );
1263 *collation = Some(c);
1264 }
1265 }
1266 }
1267 _ => {
1268 if let EnvChangeValue::String(ref new_value) = env.new_value {
1269 tracing::debug!(
1270 env_type = ?env.env_type,
1271 new_value = %new_value,
1272 "environment change"
1273 );
1274 }
1275 }
1276 }
1277 }
1278}