1use std::marker::PhantomData;
7
8use bytes::BytesMut;
9use mssql_codec::connection::Connection;
10#[cfg(feature = "tls")]
11use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
12use tds_protocol::login7::Login7;
13use tds_protocol::packet::MAX_PACKET_SIZE;
14use tds_protocol::packet::PacketType;
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
17use tokio::net::TcpStream;
18use tokio::time::timeout;
19
20use crate::config::Config;
21use crate::error::{Error, Result};
22#[cfg(feature = "otel")]
23use crate::instrumentation::InstrumentationContext;
24use crate::state::{Disconnected, Ready};
25use crate::statement_cache::StatementCache;
26
27use super::{Client, ConnectionHandle};
28
29impl Client<Disconnected> {
30 pub async fn connect(config: Config) -> Result<Client<Ready>> {
41 let max_redirects = config.redirect.max_redirects;
42 let follow_redirects = config.redirect.follow_redirects;
43 let per_attempt = config.timeouts.connect_timeout
46 + config.timeouts.tls_timeout
47 + config.timeouts.login_timeout;
48 let overall = per_attempt * (max_redirects as u32 + 1);
49 let overall = overall.min(std::time::Duration::from_secs(300));
50 let mut attempts = 0;
51 let initial_host = config.host.clone();
52 let initial_port = config.port;
53 let mut current_config = config;
54
55 let result = timeout(overall, async {
56 loop {
57 attempts += 1;
58 if attempts > max_redirects + 1 {
59 return Err(Error::TooManyRedirects { max: max_redirects });
60 }
61
62 match Self::try_connect(¤t_config).await {
63 Ok(client) => return Ok(client),
64 Err(Error::Routing { host, port }) => {
65 if !follow_redirects {
66 return Err(Error::Routing { host, port });
67 }
68 tracing::info!(
69 host = %host,
70 port = port,
71 attempt = attempts,
72 max_redirects = max_redirects,
73 "following Azure SQL routing redirect"
74 );
75 current_config = current_config.with_host(&host).with_port(port);
76 continue;
77 }
78 Err(e) => return Err(e),
79 }
80 }
81 })
82 .await;
83
84 match result {
85 Ok(inner) => inner,
86 Err(_elapsed) => Err(Error::ConnectTimeout {
87 host: initial_host,
88 port: initial_port,
89 }),
90 }
91 }
92
93 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
94 let port = if let Some(ref instance) = config.instance {
96 let resolved = crate::browser::resolve_instance(
97 &config.host,
98 instance,
99 Some(config.timeouts.connect_timeout),
100 )
101 .await?;
102 tracing::info!(
103 host = %config.host,
104 instance = %instance,
105 resolved_port = resolved,
106 database = ?config.database,
107 "connecting to named SQL Server instance"
108 );
109 resolved
110 } else {
111 tracing::info!(
112 host = %config.host,
113 port = config.port,
114 database = ?config.database,
115 "connecting to SQL Server"
116 );
117 config.port
118 };
119
120 let host = if config.host == "." || config.host.eq_ignore_ascii_case("(local)") {
123 "127.0.0.1"
124 } else {
125 &config.host
126 };
127
128 let addr = format!("{host}:{port}");
129
130 tracing::debug!("establishing TCP connection to {}", addr);
132 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
133 .await
134 .map_err(|_| Error::ConnectTimeout {
135 host: config.host.clone(),
136 port: config.port,
137 })?
138 .map_err(Error::from)?;
139
140 tcp_stream.set_nodelay(true).map_err(Error::from)?;
142
143 #[cfg(feature = "tls")]
144 {
145 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
147
148 if tls_mode.is_tls_first() {
150 return Self::connect_tds_8(config, tcp_stream).await;
151 }
152
153 Self::connect_tds_7x(config, tcp_stream).await
155 }
156
157 #[cfg(not(feature = "tls"))]
158 {
159 if config.strict_mode {
161 return Err(Error::Config(
162 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
163 ));
164 }
165
166 if !config.no_tls {
167 return Err(Error::Config(
168 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
169 or use Encrypt=no_tls in your connection string for unencrypted connections."
170 .into(),
171 ));
172 }
173
174 Self::connect_no_tls(config, tcp_stream).await
176 }
177 }
178
179 #[cfg(feature = "tls")]
183 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
184 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
185
186 let tls_config = TlsConfig::new()
188 .strict_mode(true)
189 .trust_server_certificate(config.trust_server_certificate)
190 .with_alpn_protocols(vec![b"tds/8.0".to_vec()]);
191
192 let tls_connector = TlsConnector::new(tls_config)?;
193
194 let tls_stream = timeout(
196 config.timeouts.tls_timeout,
197 tls_connector.connect(tcp_stream, &config.host),
198 )
199 .await
200 .map_err(|_| Error::TlsTimeout {
201 host: config.host.clone(),
202 port: config.port,
203 })??;
204
205 tracing::debug!("TLS handshake completed (strict mode)");
206
207 let mut connection = Connection::new(tls_stream);
209
210 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
212 Self::send_prelogin(&mut connection, &prelogin).await?;
213 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
214
215 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
217 let negotiator = Self::create_negotiator(config)?;
218 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
219 let sspi_token = match negotiator {
220 Some(ref neg) => Some(neg.initialize()?),
221 None => None,
222 };
223 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
224 let sspi_token: Option<Vec<u8>> = None;
225
226 let login = Self::build_login7(config, sspi_token);
228 Self::send_login7(&mut connection, &login).await?;
229
230 let (server_version, current_database, routing) = timeout(
232 config.timeouts.login_timeout,
233 Self::process_login_response(
234 &mut connection,
235 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
236 negotiator.as_deref(),
237 ),
238 )
239 .await
240 .map_err(|_| Error::LoginTimeout {
241 host: config.host.clone(),
242 port: config.port,
243 })??;
244
245 if let Some((host, port)) = routing {
247 return Err(Error::Routing { host, port });
248 }
249
250 Ok(Client {
251 config: config.clone(),
252 _state: PhantomData,
253 connection: Some(ConnectionHandle::Tls(connection)),
254 server_version,
255 current_database: current_database.clone(),
256 statement_cache: StatementCache::with_default_size(),
257 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
260 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
261 .with_database(current_database.clone().unwrap_or_default()),
262 #[cfg(feature = "always-encrypted")]
263 encryption_context: config.column_encryption.clone().map(|cfg| {
264 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
265 }),
266 })
267 }
268
269 #[cfg(feature = "tls")]
277 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
278 use bytes::BufMut;
279 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
280 use tokio::io::{AsyncReadExt, AsyncWriteExt};
281
282 tracing::debug!("using TDS 7.x flow (PreLogin first)");
283
284 let client_encryption = if config.no_tls {
287 tracing::warn!(
289 "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
290 Credentials and data will be transmitted in plaintext. \
291 This should only be used for development/testing with legacy SQL Server."
292 );
293 EncryptionLevel::NotSupported
294 } else if config.encrypt {
295 EncryptionLevel::On
296 } else {
297 EncryptionLevel::Off
298 };
299 let prelogin = Self::build_prelogin(config, client_encryption);
300 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
301 let prelogin_bytes = prelogin.encode();
302
303 let header = PacketHeader::new(
305 PacketType::PreLogin,
306 PacketStatus::END_OF_MESSAGE,
307 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
308 );
309
310 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
311 header.encode(&mut packet_buf);
312 packet_buf.put_slice(&prelogin_bytes);
313
314 tcp_stream
315 .write_all(&packet_buf)
316 .await
317 .map_err(Error::from)?;
318
319 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
321 tcp_stream
322 .read_exact(&mut header_buf)
323 .await
324 .map_err(Error::from)?;
325
326 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
327 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
328
329 let mut response_buf = vec![0u8; payload_length];
330 tcp_stream
331 .read_exact(&mut response_buf)
332 .await
333 .map_err(Error::from)?;
334
335 let prelogin_response = PreLogin::decode(&response_buf[..])?;
336
337 let client_tds_version = config.tds_version;
342 if let Some(ref server_version) = prelogin_response.server_version {
343 tracing::debug!(
344 requested_tds_version = %client_tds_version,
345 server_product_version = %server_version,
346 server_product = server_version.product_name(),
347 max_tds_version = %server_version.max_tds_version(),
348 "PreLogin response received"
349 );
350
351 let server_max_tds = server_version.max_tds_version();
353 if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
354 tracing::warn!(
355 requested_tds_version = %client_tds_version,
356 server_max_tds_version = %server_max_tds,
357 server_product = server_version.product_name(),
358 "Server supports lower TDS version than requested. \
359 Connection will use server's maximum: {}",
360 server_max_tds
361 );
362 }
363
364 if server_max_tds.is_legacy() {
366 tracing::warn!(
367 server_product = server_version.product_name(),
368 server_max_tds_version = %server_max_tds,
369 "Server uses legacy TDS version. Some features may not be available."
370 );
371 }
372 } else {
373 tracing::debug!(
374 requested_tds_version = %client_tds_version,
375 "PreLogin response received (no version info)"
376 );
377 }
378
379 let server_encryption = prelogin_response.encryption;
381 tracing::debug!(encryption = ?server_encryption, "server encryption level");
382
383 let negotiated_encryption = match (client_encryption, server_encryption) {
389 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
390 EncryptionLevel::NotSupported
391 }
392 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
393 (EncryptionLevel::On, EncryptionLevel::Off)
394 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
395 return Err(Error::Protocol(
396 "Server does not support requested encryption level".to_string(),
397 ));
398 }
399 _ => EncryptionLevel::On,
400 };
401
402 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
405
406 if use_tls {
407 let tls_config =
410 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
411
412 let tls_connector = TlsConnector::new(tls_config)?;
413
414 let mut tls_stream = timeout(
416 config.timeouts.tls_timeout,
417 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
418 )
419 .await
420 .map_err(|_| Error::TlsTimeout {
421 host: config.host.clone(),
422 port: config.port,
423 })??;
424
425 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
426
427 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
429
430 if login_only_encryption {
431 use tokio::io::AsyncWriteExt;
439
440 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
445 let negotiator = Self::create_negotiator(config)?;
446 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
447 let sspi_token = match negotiator {
448 Some(ref neg) => Some(neg.initialize()?),
449 None => None,
450 };
451 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
452 let sspi_token: Option<Vec<u8>> = None;
453
454 let login = Self::build_login7(config, sspi_token);
456 let login_payload = login.encode();
457
458 let max_packet = MAX_PACKET_SIZE;
460 let max_payload = max_packet - PACKET_HEADER_SIZE;
461 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
462 let total_chunks = chunks.len();
463
464 for (i, chunk) in chunks.into_iter().enumerate() {
465 let is_last = i == total_chunks - 1;
466 let status = if is_last {
467 PacketStatus::END_OF_MESSAGE
468 } else {
469 PacketStatus::NORMAL
470 };
471
472 let header = PacketHeader::new(
473 PacketType::Tds7Login,
474 status,
475 (PACKET_HEADER_SIZE + chunk.len()) as u16,
476 );
477
478 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
479 header.encode(&mut packet_buf);
480 packet_buf.put_slice(chunk);
481
482 tls_stream
483 .write_all(&packet_buf)
484 .await
485 .map_err(Error::from)?;
486 }
487
488 tls_stream.flush().await.map_err(Error::from)?;
490
491 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
492
493 let (wrapper, _client_conn) = tls_stream.into_inner();
497 let tcp_stream = wrapper.into_inner();
498
499 let mut connection = Connection::new(tcp_stream);
501
502 let (server_version, current_database, routing) = timeout(
504 config.timeouts.login_timeout,
505 Self::process_login_response(
506 &mut connection,
507 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
508 negotiator.as_deref(),
509 ),
510 )
511 .await
512 .map_err(|_| Error::LoginTimeout {
513 host: config.host.clone(),
514 port: config.port,
515 })??;
516
517 if let Some((host, port)) = routing {
519 return Err(Error::Routing { host, port });
520 }
521
522 Ok(Client {
524 config: config.clone(),
525 _state: PhantomData,
526 connection: Some(ConnectionHandle::Plain(connection)),
527 server_version,
528 current_database: current_database.clone(),
529 statement_cache: StatementCache::with_default_size(),
530 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
533 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
534 .with_database(current_database.clone().unwrap_or_default()),
535 #[cfg(feature = "always-encrypted")]
536 encryption_context: config.column_encryption.clone().map(|cfg| {
537 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
538 }),
539 })
540 } else {
541 let mut connection = Connection::new(tls_stream);
544
545 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
547 let negotiator = Self::create_negotiator(config)?;
548 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
549 let sspi_token = match negotiator {
550 Some(ref neg) => Some(neg.initialize()?),
551 None => None,
552 };
553 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
554 let sspi_token: Option<Vec<u8>> = None;
555
556 let login = Self::build_login7(config, sspi_token);
558 Self::send_login7(&mut connection, &login).await?;
559
560 let (server_version, current_database, routing) = timeout(
562 config.timeouts.login_timeout,
563 Self::process_login_response(
564 &mut connection,
565 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
566 negotiator.as_deref(),
567 ),
568 )
569 .await
570 .map_err(|_| Error::LoginTimeout {
571 host: config.host.clone(),
572 port: config.port,
573 })??;
574
575 if let Some((host, port)) = routing {
577 return Err(Error::Routing { host, port });
578 }
579
580 Ok(Client {
581 config: config.clone(),
582 _state: PhantomData,
583 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
584 server_version,
585 current_database: current_database.clone(),
586 statement_cache: StatementCache::with_default_size(),
587 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
590 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
591 .with_database(current_database.clone().unwrap_or_default()),
592 #[cfg(feature = "always-encrypted")]
593 encryption_context: config.column_encryption.clone().map(|cfg| {
594 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
595 }),
596 })
597 }
598 } else {
599 tracing::warn!(
601 "Connecting without TLS encryption. This is insecure and should only be \
602 used for development/testing on trusted networks."
603 );
604
605 let mut connection = Connection::new(tcp_stream);
606
607 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
609 let negotiator = Self::create_negotiator(config)?;
610 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
611 let sspi_token = match negotiator {
612 Some(ref neg) => Some(neg.initialize()?),
613 None => None,
614 };
615 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
616 let sspi_token: Option<Vec<u8>> = None;
617
618 let login = Self::build_login7(config, sspi_token);
620 Self::send_login7(&mut connection, &login).await?;
621
622 let (server_version, current_database, routing) = timeout(
624 config.timeouts.login_timeout,
625 Self::process_login_response(
626 &mut connection,
627 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
628 negotiator.as_deref(),
629 ),
630 )
631 .await
632 .map_err(|_| Error::LoginTimeout {
633 host: config.host.clone(),
634 port: config.port,
635 })??;
636
637 if let Some((host, port)) = routing {
639 return Err(Error::Routing { host, port });
640 }
641
642 Ok(Client {
643 config: config.clone(),
644 _state: PhantomData,
645 connection: Some(ConnectionHandle::Plain(connection)),
646 server_version,
647 current_database: current_database.clone(),
648 statement_cache: StatementCache::with_default_size(),
649 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
652 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
653 .with_database(current_database.clone().unwrap_or_default()),
654 #[cfg(feature = "always-encrypted")]
655 encryption_context: config.column_encryption.clone().map(|cfg| {
656 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
657 }),
658 })
659 }
660 }
661
662 #[cfg(not(feature = "tls"))]
673 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
674 use bytes::BufMut;
675 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
676 use tokio::io::{AsyncReadExt, AsyncWriteExt};
677
678 tracing::warn!(
679 "⚠️ Connecting without TLS (tls feature disabled). \
680 Credentials and data will be transmitted in plaintext."
681 );
682
683 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
685 let prelogin_bytes = prelogin.encode();
686
687 let header = PacketHeader::new(
689 PacketType::PreLogin,
690 PacketStatus::END_OF_MESSAGE,
691 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
692 );
693
694 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
695 header.encode(&mut packet_buf);
696 packet_buf.put_slice(&prelogin_bytes);
697
698 tcp_stream
699 .write_all(&packet_buf)
700 .await
701 .map_err(Error::from)?;
702
703 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
705 tcp_stream
706 .read_exact(&mut header_buf)
707 .await
708 .map_err(Error::from)?;
709
710 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
711 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
712
713 let mut response_buf = vec![0u8; payload_length];
714 tcp_stream
715 .read_exact(&mut response_buf)
716 .await
717 .map_err(Error::from)?;
718
719 let prelogin_response = PreLogin::decode(&response_buf[..])?;
720
721 let server_encryption = prelogin_response.encryption;
723 if server_encryption != EncryptionLevel::NotSupported {
724 return Err(Error::Config(format!(
725 "Server requires encryption (level: {:?}) but TLS feature is disabled. \
726 Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
727 server_encryption
728 )));
729 }
730
731 tracing::debug!("Server accepted unencrypted connection");
732
733 let mut connection = Connection::new(tcp_stream);
734
735 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
737 let negotiator = Self::create_negotiator(config)?;
738 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
739 let sspi_token = match negotiator {
740 Some(ref neg) => Some(neg.initialize()?),
741 None => None,
742 };
743 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
744 let sspi_token: Option<Vec<u8>> = None;
745
746 let login = Self::build_login7(config, sspi_token);
748 Self::send_login7(&mut connection, &login).await?;
749
750 let (server_version, current_database, routing) = timeout(
752 config.timeouts.login_timeout,
753 Self::process_login_response(
754 &mut connection,
755 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
756 negotiator.as_deref(),
757 ),
758 )
759 .await
760 .map_err(|_| Error::LoginTimeout {
761 host: config.host.clone(),
762 port: config.port,
763 })??;
764
765 if let Some((host, port)) = routing {
767 return Err(Error::Routing { host, port });
768 }
769
770 Ok(Client {
771 config: config.clone(),
772 _state: PhantomData,
773 connection: Some(ConnectionHandle::Plain(connection)),
774 server_version,
775 current_database: current_database.clone(),
776 statement_cache: StatementCache::with_default_size(),
777 transaction_descriptor: 0,
778 needs_reset: false,
779 #[cfg(feature = "otel")]
780 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
781 .with_database(current_database.clone().unwrap_or_default()),
782 #[cfg(feature = "always-encrypted")]
783 encryption_context: config.column_encryption.clone().map(|cfg| {
784 std::sync::Arc::new(crate::encryption::EncryptionContext::from_arc(cfg))
785 }),
786 })
787 }
788
789 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
791 let version = if config.strict_mode {
793 tds_protocol::version::TdsVersion::V8_0
794 } else {
795 config.tds_version
796 };
797
798 let mut prelogin = PreLogin::new()
799 .with_version(version)
800 .with_encryption(encryption);
801
802 if config.mars {
803 prelogin = prelogin.with_mars(true);
804 }
805
806 if let Some(ref instance) = config.instance {
807 prelogin = prelogin.with_instance(instance);
808 }
809
810 prelogin
811 }
812
813 fn resolve_workstation_id(config: &Config) -> String {
821 if let Some(ref id) = config.workstation_id {
822 return id.clone();
823 }
824 std::env::var("COMPUTERNAME")
827 .or_else(|_| std::env::var("HOSTNAME"))
828 .unwrap_or_default()
829 }
830
831 fn build_login7(config: &Config, sspi_token: Option<Vec<u8>>) -> Login7 {
836 let version = if config.strict_mode {
838 tds_protocol::version::TdsVersion::V8_0
839 } else {
840 config.tds_version
841 };
842
843 let mut login = Login7::new()
844 .with_tds_version(version)
845 .with_packet_size(config.packet_size as u32)
846 .with_app_name(&config.application_name)
847 .with_server_name(&config.host)
848 .with_hostname(Self::resolve_workstation_id(config));
849
850 if let Some(ref database) = config.database {
851 login = login.with_database(database);
852 }
853
854 if config.application_intent == crate::config::ApplicationIntent::ReadOnly {
856 login = login.with_read_only_intent(true);
857 }
858
859 if let Some(ref lang) = config.language {
861 login = login.with_language(lang);
862 }
863
864 if let Some(token) = sspi_token {
866 login = login.with_integrated_auth(token);
868 } else if let mssql_auth::Credentials::SqlServer { username, password } =
869 &config.credentials
870 {
871 login = login.with_sql_auth(username.as_ref(), password.as_ref());
872 }
873
874 #[cfg(feature = "always-encrypted")]
877 if config.column_encryption.is_some() {
878 login = login.with_feature(tds_protocol::login7::FeatureExtension {
879 feature_id: tds_protocol::login7::FeatureId::ColumnEncryption,
880 data: bytes::Bytes::from_static(&[0x01]), });
882 tracing::debug!("Login7: adding ColumnEncryption feature extension (version 1)");
883 }
884
885 login
886 }
887
888 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
898 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
899 #[allow(clippy::match_like_matches_macro)]
900 let is_integrated = match &config.credentials {
901 mssql_auth::Credentials::Integrated => true,
902 _ => false,
903 };
904
905 if !is_integrated {
906 return Ok(None);
907 }
908
909 #[cfg(all(windows, feature = "sspi-auth"))]
914 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
915 Box::new(mssql_auth::NativeSspiAuth::new(&config.host, config.port)?);
916
917 #[cfg(all(not(windows), feature = "sspi-auth"))]
919 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
920 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
921
922 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
923 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
924 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
925
926 Ok(Some(negotiator))
927 }
928
929 #[cfg(feature = "tls")]
931 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
932 where
933 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
934 {
935 let payload = prelogin.encode();
936 let max_packet = MAX_PACKET_SIZE;
937
938 connection
939 .send_message(PacketType::PreLogin, payload, max_packet)
940 .await?;
941 Ok(())
942 }
943
944 #[cfg(feature = "tls")]
946 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
947 where
948 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
949 {
950 let message = connection
951 .read_message()
952 .await?
953 .ok_or(Error::ConnectionClosed)?;
954
955 Ok(PreLogin::decode(&message.payload[..])?)
956 }
957
958 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
960 where
961 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
962 {
963 let payload = login.encode();
964 let max_packet = MAX_PACKET_SIZE;
965
966 connection
967 .send_message(PacketType::Tds7Login, payload, max_packet)
968 .await?;
969 Ok(())
970 }
971
972 #[allow(clippy::never_loop)] async fn process_login_response<T>(
983 connection: &mut Connection<T>,
984 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
985 &dyn mssql_auth::SspiNegotiator,
986 >,
987 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
988 where
989 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
990 {
991 let mut server_version = None;
992 let mut database = None;
993 let mut routing = None;
994
995 'outer: loop {
996 let message = connection
997 .read_message()
998 .await?
999 .ok_or(Error::ConnectionClosed)?;
1000
1001 let response_bytes = message.payload;
1002 let mut parser = TokenParser::new(response_bytes);
1003
1004 while let Some(token) = parser.next_token()? {
1005 match token {
1006 Token::LoginAck(ack) => {
1007 tracing::info!(
1008 version = ack.tds_version,
1009 interface = ack.interface,
1010 prog_name = %ack.prog_name,
1011 "login acknowledged"
1012 );
1013 server_version = Some(ack.tds_version);
1014 }
1015 Token::EnvChange(env) => {
1016 Self::process_env_change(&env, &mut database, &mut routing);
1017 }
1018 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
1019 Token::Sspi(sspi_token) => {
1020 let neg = negotiator.ok_or_else(|| {
1021 Error::Protocol(
1022 "server sent SSPI challenge but no negotiator is configured"
1023 .to_string(),
1024 )
1025 })?;
1026
1027 tracing::debug!(
1028 challenge_len = sspi_token.data.len(),
1029 "received SSPI challenge from server"
1030 );
1031
1032 if let Some(response) = neg.step(&sspi_token.data)? {
1033 tracing::debug!(response_len = response.len(), "sending SSPI response");
1034 connection
1035 .send_message(
1036 PacketType::Sspi,
1037 bytes::Bytes::from(response),
1038 tds_protocol::packet::MAX_PACKET_SIZE,
1039 )
1040 .await?;
1041 }
1042
1043 continue 'outer;
1045 }
1046 Token::Error(err) => {
1047 return Err(Error::Server {
1048 number: err.number,
1049 state: err.state,
1050 class: err.class,
1051 message: err.message.clone(),
1052 server: if err.server.is_empty() {
1053 None
1054 } else {
1055 Some(err.server.clone())
1056 },
1057 procedure: if err.procedure.is_empty() {
1058 None
1059 } else {
1060 Some(err.procedure.clone())
1061 },
1062 line: err.line as u32,
1063 });
1064 }
1065 Token::Info(info) => {
1066 tracing::info!(
1067 number = info.number,
1068 message = %info.message,
1069 "server info message"
1070 );
1071 }
1072 Token::Done(done) => {
1073 if done.status.error {
1074 return Err(Error::Protocol("login failed".to_string()));
1075 }
1076 break 'outer;
1077 }
1078 _ => {}
1079 }
1080 }
1081
1082 break;
1084 }
1085
1086 Ok((server_version, database, routing))
1087 }
1088
1089 fn process_env_change(
1091 env: &EnvChange,
1092 database: &mut Option<String>,
1093 routing: &mut Option<(String, u16)>,
1094 ) {
1095 use tds_protocol::token::EnvChangeValue;
1096
1097 match env.env_type {
1098 EnvChangeType::Database => {
1099 if let EnvChangeValue::String(ref new_value) = env.new_value {
1100 tracing::debug!(database = %new_value, "database changed");
1101 *database = Some(new_value.clone());
1102 }
1103 }
1104 EnvChangeType::Routing => {
1105 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1106 tracing::info!(host = %host, port = port, "routing redirect received");
1107 *routing = Some((host.clone(), port));
1108 }
1109 }
1110 _ => {
1111 if let EnvChangeValue::String(ref new_value) = env.new_value {
1112 tracing::debug!(
1113 env_type = ?env.env_type,
1114 new_value = %new_value,
1115 "environment change"
1116 );
1117 }
1118 }
1119 }
1120 }
1121}