1use std::marker::PhantomData;
7
8use bytes::BytesMut;
9use mssql_codec::connection::Connection;
10#[cfg(feature = "tls")]
11use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode};
12use tds_protocol::login7::Login7;
13use tds_protocol::packet::MAX_PACKET_SIZE;
14use tds_protocol::packet::PacketType;
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::token::{EnvChange, EnvChangeType, Token, TokenParser};
17use tokio::net::TcpStream;
18use tokio::time::timeout;
19
20use crate::config::Config;
21use crate::error::{Error, Result};
22#[cfg(feature = "otel")]
23use crate::instrumentation::InstrumentationContext;
24use crate::state::{Disconnected, Ready};
25use crate::statement_cache::StatementCache;
26
27use super::{Client, ConnectionHandle};
28
29impl Client<Disconnected> {
30 pub async fn connect(config: Config) -> Result<Client<Ready>> {
41 let max_redirects = config.redirect.max_redirects;
42 let follow_redirects = config.redirect.follow_redirects;
43 let per_attempt = config.timeouts.connect_timeout
46 + config.timeouts.tls_timeout
47 + config.timeouts.login_timeout;
48 let overall = per_attempt * (max_redirects as u32 + 1);
49 let overall = overall.min(std::time::Duration::from_secs(300));
50 let mut attempts = 0;
51 let initial_host = config.host.clone();
52 let initial_port = config.port;
53 let mut current_config = config;
54
55 let result = timeout(overall, async {
56 loop {
57 attempts += 1;
58 if attempts > max_redirects + 1 {
59 return Err(Error::TooManyRedirects { max: max_redirects });
60 }
61
62 match Self::try_connect(¤t_config).await {
63 Ok(client) => return Ok(client),
64 Err(Error::Routing { host, port }) => {
65 if !follow_redirects {
66 return Err(Error::Routing { host, port });
67 }
68 tracing::info!(
69 host = %host,
70 port = port,
71 attempt = attempts,
72 max_redirects = max_redirects,
73 "following Azure SQL routing redirect"
74 );
75 current_config = current_config.with_host(&host).with_port(port);
76 continue;
77 }
78 Err(e) => return Err(e),
79 }
80 }
81 })
82 .await;
83
84 match result {
85 Ok(inner) => inner,
86 Err(_elapsed) => Err(Error::ConnectTimeout {
87 host: initial_host,
88 port: initial_port,
89 }),
90 }
91 }
92
93 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
94 tracing::info!(
95 host = %config.host,
96 port = config.port,
97 database = ?config.database,
98 "connecting to SQL Server"
99 );
100
101 let addr = format!("{}:{}", config.host, config.port);
102
103 tracing::debug!("establishing TCP connection to {}", addr);
105 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
106 .await
107 .map_err(|_| Error::ConnectTimeout {
108 host: config.host.clone(),
109 port: config.port,
110 })?
111 .map_err(Error::from)?;
112
113 tcp_stream.set_nodelay(true).map_err(Error::from)?;
115
116 #[cfg(feature = "tls")]
117 {
118 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
120
121 if tls_mode.is_tls_first() {
123 return Self::connect_tds_8(config, tcp_stream).await;
124 }
125
126 Self::connect_tds_7x(config, tcp_stream).await
128 }
129
130 #[cfg(not(feature = "tls"))]
131 {
132 if config.strict_mode {
134 return Err(Error::Config(
135 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
136 ));
137 }
138
139 if !config.no_tls {
140 return Err(Error::Config(
141 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
142 or use Encrypt=no_tls in your connection string for unencrypted connections."
143 .into(),
144 ));
145 }
146
147 Self::connect_no_tls(config, tcp_stream).await
149 }
150 }
151
152 #[cfg(feature = "tls")]
156 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
157 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
158
159 let tls_config = TlsConfig::new()
161 .strict_mode(true)
162 .trust_server_certificate(config.trust_server_certificate)
163 .with_alpn_protocols(vec![b"tds/8.0".to_vec()]);
164
165 let tls_connector = TlsConnector::new(tls_config)?;
166
167 let tls_stream = timeout(
169 config.timeouts.tls_timeout,
170 tls_connector.connect(tcp_stream, &config.host),
171 )
172 .await
173 .map_err(|_| Error::TlsTimeout {
174 host: config.host.clone(),
175 port: config.port,
176 })??;
177
178 tracing::debug!("TLS handshake completed (strict mode)");
179
180 let mut connection = Connection::new(tls_stream);
182
183 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
185 Self::send_prelogin(&mut connection, &prelogin).await?;
186 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
187
188 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
190 let negotiator = Self::create_negotiator(config)?;
191 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
192 let sspi_token = match negotiator {
193 Some(ref neg) => Some(neg.initialize()?),
194 None => None,
195 };
196 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
197 let sspi_token: Option<Vec<u8>> = None;
198
199 let login = Self::build_login7(config, sspi_token);
201 Self::send_login7(&mut connection, &login).await?;
202
203 let (server_version, current_database, routing) = timeout(
205 config.timeouts.login_timeout,
206 Self::process_login_response(
207 &mut connection,
208 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
209 negotiator.as_deref(),
210 ),
211 )
212 .await
213 .map_err(|_| Error::LoginTimeout {
214 host: config.host.clone(),
215 port: config.port,
216 })??;
217
218 if let Some((host, port)) = routing {
220 return Err(Error::Routing { host, port });
221 }
222
223 Ok(Client {
224 config: config.clone(),
225 _state: PhantomData,
226 connection: Some(ConnectionHandle::Tls(connection)),
227 server_version,
228 current_database: current_database.clone(),
229 statement_cache: StatementCache::with_default_size(),
230 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
233 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
234 .with_database(current_database.unwrap_or_default()),
235 })
236 }
237
238 #[cfg(feature = "tls")]
246 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
247 use bytes::BufMut;
248 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
249 use tokio::io::{AsyncReadExt, AsyncWriteExt};
250
251 tracing::debug!("using TDS 7.x flow (PreLogin first)");
252
253 let client_encryption = if config.no_tls {
256 tracing::warn!(
258 "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
259 Credentials and data will be transmitted in plaintext. \
260 This should only be used for development/testing with legacy SQL Server."
261 );
262 EncryptionLevel::NotSupported
263 } else if config.encrypt {
264 EncryptionLevel::On
265 } else {
266 EncryptionLevel::Off
267 };
268 let prelogin = Self::build_prelogin(config, client_encryption);
269 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
270 let prelogin_bytes = prelogin.encode();
271
272 let header = PacketHeader::new(
274 PacketType::PreLogin,
275 PacketStatus::END_OF_MESSAGE,
276 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
277 );
278
279 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
280 header.encode(&mut packet_buf);
281 packet_buf.put_slice(&prelogin_bytes);
282
283 tcp_stream
284 .write_all(&packet_buf)
285 .await
286 .map_err(Error::from)?;
287
288 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
290 tcp_stream
291 .read_exact(&mut header_buf)
292 .await
293 .map_err(Error::from)?;
294
295 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
296 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
297
298 let mut response_buf = vec![0u8; payload_length];
299 tcp_stream
300 .read_exact(&mut response_buf)
301 .await
302 .map_err(Error::from)?;
303
304 let prelogin_response = PreLogin::decode(&response_buf[..])?;
305
306 let client_tds_version = config.tds_version;
311 if let Some(ref server_version) = prelogin_response.server_version {
312 tracing::debug!(
313 requested_tds_version = %client_tds_version,
314 server_product_version = %server_version,
315 server_product = server_version.product_name(),
316 max_tds_version = %server_version.max_tds_version(),
317 "PreLogin response received"
318 );
319
320 let server_max_tds = server_version.max_tds_version();
322 if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
323 tracing::warn!(
324 requested_tds_version = %client_tds_version,
325 server_max_tds_version = %server_max_tds,
326 server_product = server_version.product_name(),
327 "Server supports lower TDS version than requested. \
328 Connection will use server's maximum: {}",
329 server_max_tds
330 );
331 }
332
333 if server_max_tds.is_legacy() {
335 tracing::warn!(
336 server_product = server_version.product_name(),
337 server_max_tds_version = %server_max_tds,
338 "Server uses legacy TDS version. Some features may not be available."
339 );
340 }
341 } else {
342 tracing::debug!(
343 requested_tds_version = %client_tds_version,
344 "PreLogin response received (no version info)"
345 );
346 }
347
348 let server_encryption = prelogin_response.encryption;
350 tracing::debug!(encryption = ?server_encryption, "server encryption level");
351
352 let negotiated_encryption = match (client_encryption, server_encryption) {
358 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
359 EncryptionLevel::NotSupported
360 }
361 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
362 (EncryptionLevel::On, EncryptionLevel::Off)
363 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
364 return Err(Error::Protocol(
365 "Server does not support requested encryption level".to_string(),
366 ));
367 }
368 _ => EncryptionLevel::On,
369 };
370
371 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
374
375 if use_tls {
376 let tls_config =
379 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
380
381 let tls_connector = TlsConnector::new(tls_config)?;
382
383 let mut tls_stream = timeout(
385 config.timeouts.tls_timeout,
386 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
387 )
388 .await
389 .map_err(|_| Error::TlsTimeout {
390 host: config.host.clone(),
391 port: config.port,
392 })??;
393
394 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
395
396 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
398
399 if login_only_encryption {
400 use tokio::io::AsyncWriteExt;
408
409 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
414 let negotiator = Self::create_negotiator(config)?;
415 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
416 let sspi_token = match negotiator {
417 Some(ref neg) => Some(neg.initialize()?),
418 None => None,
419 };
420 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
421 let sspi_token: Option<Vec<u8>> = None;
422
423 let login = Self::build_login7(config, sspi_token);
425 let login_payload = login.encode();
426
427 let max_packet = MAX_PACKET_SIZE;
429 let max_payload = max_packet - PACKET_HEADER_SIZE;
430 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
431 let total_chunks = chunks.len();
432
433 for (i, chunk) in chunks.into_iter().enumerate() {
434 let is_last = i == total_chunks - 1;
435 let status = if is_last {
436 PacketStatus::END_OF_MESSAGE
437 } else {
438 PacketStatus::NORMAL
439 };
440
441 let header = PacketHeader::new(
442 PacketType::Tds7Login,
443 status,
444 (PACKET_HEADER_SIZE + chunk.len()) as u16,
445 );
446
447 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
448 header.encode(&mut packet_buf);
449 packet_buf.put_slice(chunk);
450
451 tls_stream
452 .write_all(&packet_buf)
453 .await
454 .map_err(Error::from)?;
455 }
456
457 tls_stream.flush().await.map_err(Error::from)?;
459
460 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
461
462 let (wrapper, _client_conn) = tls_stream.into_inner();
466 let tcp_stream = wrapper.into_inner();
467
468 let mut connection = Connection::new(tcp_stream);
470
471 let (server_version, current_database, routing) = timeout(
473 config.timeouts.login_timeout,
474 Self::process_login_response(
475 &mut connection,
476 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
477 negotiator.as_deref(),
478 ),
479 )
480 .await
481 .map_err(|_| Error::LoginTimeout {
482 host: config.host.clone(),
483 port: config.port,
484 })??;
485
486 if let Some((host, port)) = routing {
488 return Err(Error::Routing { host, port });
489 }
490
491 Ok(Client {
493 config: config.clone(),
494 _state: PhantomData,
495 connection: Some(ConnectionHandle::Plain(connection)),
496 server_version,
497 current_database: current_database.clone(),
498 statement_cache: StatementCache::with_default_size(),
499 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
502 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
503 .with_database(current_database.unwrap_or_default()),
504 })
505 } else {
506 let mut connection = Connection::new(tls_stream);
509
510 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
512 let negotiator = Self::create_negotiator(config)?;
513 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
514 let sspi_token = match negotiator {
515 Some(ref neg) => Some(neg.initialize()?),
516 None => None,
517 };
518 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
519 let sspi_token: Option<Vec<u8>> = None;
520
521 let login = Self::build_login7(config, sspi_token);
523 Self::send_login7(&mut connection, &login).await?;
524
525 let (server_version, current_database, routing) = timeout(
527 config.timeouts.login_timeout,
528 Self::process_login_response(
529 &mut connection,
530 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
531 negotiator.as_deref(),
532 ),
533 )
534 .await
535 .map_err(|_| Error::LoginTimeout {
536 host: config.host.clone(),
537 port: config.port,
538 })??;
539
540 if let Some((host, port)) = routing {
542 return Err(Error::Routing { host, port });
543 }
544
545 Ok(Client {
546 config: config.clone(),
547 _state: PhantomData,
548 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
549 server_version,
550 current_database: current_database.clone(),
551 statement_cache: StatementCache::with_default_size(),
552 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
555 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
556 .with_database(current_database.unwrap_or_default()),
557 })
558 }
559 } else {
560 tracing::warn!(
562 "Connecting without TLS encryption. This is insecure and should only be \
563 used for development/testing on trusted networks."
564 );
565
566 let mut connection = Connection::new(tcp_stream);
567
568 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
570 let negotiator = Self::create_negotiator(config)?;
571 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
572 let sspi_token = match negotiator {
573 Some(ref neg) => Some(neg.initialize()?),
574 None => None,
575 };
576 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
577 let sspi_token: Option<Vec<u8>> = None;
578
579 let login = Self::build_login7(config, sspi_token);
581 Self::send_login7(&mut connection, &login).await?;
582
583 let (server_version, current_database, routing) = timeout(
585 config.timeouts.login_timeout,
586 Self::process_login_response(
587 &mut connection,
588 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
589 negotiator.as_deref(),
590 ),
591 )
592 .await
593 .map_err(|_| Error::LoginTimeout {
594 host: config.host.clone(),
595 port: config.port,
596 })??;
597
598 if let Some((host, port)) = routing {
600 return Err(Error::Routing { host, port });
601 }
602
603 Ok(Client {
604 config: config.clone(),
605 _state: PhantomData,
606 connection: Some(ConnectionHandle::Plain(connection)),
607 server_version,
608 current_database: current_database.clone(),
609 statement_cache: StatementCache::with_default_size(),
610 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
613 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
614 .with_database(current_database.unwrap_or_default()),
615 })
616 }
617 }
618
619 #[cfg(not(feature = "tls"))]
630 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
631 use bytes::BufMut;
632 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
633 use tokio::io::{AsyncReadExt, AsyncWriteExt};
634
635 tracing::warn!(
636 "⚠️ Connecting without TLS (tls feature disabled). \
637 Credentials and data will be transmitted in plaintext."
638 );
639
640 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
642 let prelogin_bytes = prelogin.encode();
643
644 let header = PacketHeader::new(
646 PacketType::PreLogin,
647 PacketStatus::END_OF_MESSAGE,
648 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
649 );
650
651 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
652 header.encode(&mut packet_buf);
653 packet_buf.put_slice(&prelogin_bytes);
654
655 tcp_stream
656 .write_all(&packet_buf)
657 .await
658 .map_err(Error::from)?;
659
660 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
662 tcp_stream
663 .read_exact(&mut header_buf)
664 .await
665 .map_err(Error::from)?;
666
667 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
668 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
669
670 let mut response_buf = vec![0u8; payload_length];
671 tcp_stream
672 .read_exact(&mut response_buf)
673 .await
674 .map_err(Error::from)?;
675
676 let prelogin_response = PreLogin::decode(&response_buf[..])?;
677
678 let server_encryption = prelogin_response.encryption;
680 if server_encryption != EncryptionLevel::NotSupported {
681 return Err(Error::Config(format!(
682 "Server requires encryption (level: {:?}) but TLS feature is disabled. \
683 Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
684 server_encryption
685 )));
686 }
687
688 tracing::debug!("Server accepted unencrypted connection");
689
690 let mut connection = Connection::new(tcp_stream);
691
692 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
694 let negotiator = Self::create_negotiator(config)?;
695 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
696 let sspi_token = match negotiator {
697 Some(ref neg) => Some(neg.initialize()?),
698 None => None,
699 };
700 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
701 let sspi_token: Option<Vec<u8>> = None;
702
703 let login = Self::build_login7(config, sspi_token);
705 Self::send_login7(&mut connection, &login).await?;
706
707 let (server_version, current_database, routing) = timeout(
709 config.timeouts.login_timeout,
710 Self::process_login_response(
711 &mut connection,
712 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
713 negotiator.as_deref(),
714 ),
715 )
716 .await
717 .map_err(|_| Error::LoginTimeout {
718 host: config.host.clone(),
719 port: config.port,
720 })??;
721
722 if let Some((host, port)) = routing {
724 return Err(Error::Routing { host, port });
725 }
726
727 Ok(Client {
728 config: config.clone(),
729 _state: PhantomData,
730 connection: Some(ConnectionHandle::Plain(connection)),
731 server_version,
732 current_database: current_database.clone(),
733 statement_cache: StatementCache::with_default_size(),
734 transaction_descriptor: 0,
735 needs_reset: false,
736 #[cfg(feature = "otel")]
737 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
738 .with_database(current_database.unwrap_or_default()),
739 })
740 }
741
742 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
744 let version = if config.strict_mode {
746 tds_protocol::version::TdsVersion::V8_0
747 } else {
748 config.tds_version
749 };
750
751 let mut prelogin = PreLogin::new()
752 .with_version(version)
753 .with_encryption(encryption);
754
755 if config.mars {
756 prelogin = prelogin.with_mars(true);
757 }
758
759 if let Some(ref instance) = config.instance {
760 prelogin = prelogin.with_instance(instance);
761 }
762
763 prelogin
764 }
765
766 fn build_login7(config: &Config, sspi_token: Option<Vec<u8>>) -> Login7 {
771 let version = if config.strict_mode {
773 tds_protocol::version::TdsVersion::V8_0
774 } else {
775 config.tds_version
776 };
777
778 let mut login = Login7::new()
779 .with_tds_version(version)
780 .with_packet_size(config.packet_size as u32)
781 .with_app_name(&config.application_name)
782 .with_server_name(&config.host)
783 .with_hostname(&config.host);
784
785 if let Some(ref database) = config.database {
786 login = login.with_database(database);
787 }
788
789 if let Some(token) = sspi_token {
791 login = login.with_integrated_auth(token);
793 } else if let mssql_auth::Credentials::SqlServer { username, password } =
794 &config.credentials
795 {
796 login = login.with_sql_auth(username.as_ref(), password.as_ref());
797 }
798
799 login
800 }
801
802 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
807 fn create_negotiator(config: &Config) -> Result<Option<Box<dyn mssql_auth::SspiNegotiator>>> {
808 #[allow(clippy::match_like_matches_macro)]
809 let is_integrated = match &config.credentials {
810 mssql_auth::Credentials::Integrated => true,
811 _ => false,
812 };
813
814 if !is_integrated {
815 return Ok(None);
816 }
817
818 #[cfg(feature = "sspi-auth")]
820 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
821 Box::new(mssql_auth::SspiAuth::new(&config.host, config.port)?);
822
823 #[cfg(all(feature = "integrated-auth", not(feature = "sspi-auth")))]
824 let negotiator: Box<dyn mssql_auth::SspiNegotiator> =
825 Box::new(mssql_auth::IntegratedAuth::new(&config.host, config.port));
826
827 Ok(Some(negotiator))
828 }
829
830 #[cfg(feature = "tls")]
832 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
833 where
834 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
835 {
836 let payload = prelogin.encode();
837 let max_packet = MAX_PACKET_SIZE;
838
839 connection
840 .send_message(PacketType::PreLogin, payload, max_packet)
841 .await?;
842 Ok(())
843 }
844
845 #[cfg(feature = "tls")]
847 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
848 where
849 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
850 {
851 let message = connection
852 .read_message()
853 .await?
854 .ok_or(Error::ConnectionClosed)?;
855
856 Ok(PreLogin::decode(&message.payload[..])?)
857 }
858
859 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
861 where
862 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
863 {
864 let payload = login.encode();
865 let max_packet = MAX_PACKET_SIZE;
866
867 connection
868 .send_message(PacketType::Tds7Login, payload, max_packet)
869 .await?;
870 Ok(())
871 }
872
873 #[allow(clippy::never_loop)] async fn process_login_response<T>(
884 connection: &mut Connection<T>,
885 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))] negotiator: Option<
886 &dyn mssql_auth::SspiNegotiator,
887 >,
888 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
889 where
890 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
891 {
892 let mut server_version = None;
893 let mut database = None;
894 let mut routing = None;
895
896 'outer: loop {
897 let message = connection
898 .read_message()
899 .await?
900 .ok_or(Error::ConnectionClosed)?;
901
902 let response_bytes = message.payload;
903 let mut parser = TokenParser::new(response_bytes);
904
905 while let Some(token) = parser.next_token()? {
906 match token {
907 Token::LoginAck(ack) => {
908 tracing::info!(
909 version = ack.tds_version,
910 interface = ack.interface,
911 prog_name = %ack.prog_name,
912 "login acknowledged"
913 );
914 server_version = Some(ack.tds_version);
915 }
916 Token::EnvChange(env) => {
917 Self::process_env_change(&env, &mut database, &mut routing);
918 }
919 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
920 Token::Sspi(sspi_token) => {
921 let neg = negotiator.ok_or_else(|| {
922 Error::Protocol(
923 "server sent SSPI challenge but no negotiator is configured"
924 .to_string(),
925 )
926 })?;
927
928 tracing::debug!(
929 challenge_len = sspi_token.data.len(),
930 "received SSPI challenge from server"
931 );
932
933 if let Some(response) = neg.step(&sspi_token.data)? {
934 tracing::debug!(response_len = response.len(), "sending SSPI response");
935 connection
936 .send_message(
937 PacketType::Sspi,
938 bytes::Bytes::from(response),
939 tds_protocol::packet::MAX_PACKET_SIZE,
940 )
941 .await?;
942 }
943
944 continue 'outer;
946 }
947 Token::Error(err) => {
948 return Err(Error::Server {
949 number: err.number,
950 state: err.state,
951 class: err.class,
952 message: err.message.clone(),
953 server: if err.server.is_empty() {
954 None
955 } else {
956 Some(err.server.clone())
957 },
958 procedure: if err.procedure.is_empty() {
959 None
960 } else {
961 Some(err.procedure.clone())
962 },
963 line: err.line as u32,
964 });
965 }
966 Token::Info(info) => {
967 tracing::info!(
968 number = info.number,
969 message = %info.message,
970 "server info message"
971 );
972 }
973 Token::Done(done) => {
974 if done.status.error {
975 return Err(Error::Protocol("login failed".to_string()));
976 }
977 break 'outer;
978 }
979 _ => {}
980 }
981 }
982
983 break;
985 }
986
987 Ok((server_version, database, routing))
988 }
989
990 fn process_env_change(
992 env: &EnvChange,
993 database: &mut Option<String>,
994 routing: &mut Option<(String, u16)>,
995 ) {
996 use tds_protocol::token::EnvChangeValue;
997
998 match env.env_type {
999 EnvChangeType::Database => {
1000 if let EnvChangeValue::String(ref new_value) = env.new_value {
1001 tracing::debug!(database = %new_value, "database changed");
1002 *database = Some(new_value.clone());
1003 }
1004 }
1005 EnvChangeType::Routing => {
1006 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1007 tracing::info!(host = %host, port = port, "routing redirect received");
1008 *routing = Some((host.clone(), port));
1009 }
1010 }
1011 _ => {
1012 if let EnvChangeValue::String(ref new_value) = env.new_value {
1013 tracing::debug!(
1014 env_type = ?env.env_type,
1015 new_value = %new_value,
1016 "environment change"
1017 );
1018 }
1019 }
1020 }
1021 }
1022}