1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12#[cfg(feature = "tls")]
13use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
14use tds_protocol::login7::Login7;
15#[cfg(feature = "tls")]
16use tds_protocol::packet::MAX_PACKET_SIZE;
17use tds_protocol::packet::PacketType;
18use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
19use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
20use tds_protocol::token::{
21 ColMetaData, Collation, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token,
22 TokenParser,
23};
24#[cfg(feature = "decimal")]
25use tds_protocol::tvp::encode_tvp_decimal;
26use tds_protocol::tvp::{
27 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
28 encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar, encode_tvp_varbinary,
29};
30use tokio::net::TcpStream;
31use tokio::time::timeout;
32
33use crate::config::Config;
34use crate::error::{Error, Result};
35#[cfg(feature = "otel")]
36use crate::instrumentation::InstrumentationContext;
37use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
38use crate::statement_cache::StatementCache;
39use crate::stream::{MultiResultStream, QueryStream};
40use crate::transaction::SavePoint;
41
42pub struct Client<S: ConnectionState> {
48 config: Config,
49 _state: PhantomData<S>,
50 connection: Option<ConnectionHandle>,
52 server_version: Option<u32>,
54 current_database: Option<String>,
56 statement_cache: StatementCache,
58 transaction_descriptor: u64,
62 needs_reset: bool,
66 #[cfg(feature = "otel")]
68 instrumentation: InstrumentationContext,
69}
70
71#[allow(dead_code)] enum ConnectionHandle {
79 #[cfg(feature = "tls")]
81 Tls(Connection<TlsStream<TcpStream>>),
82 #[cfg(feature = "tls")]
84 TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
85 Plain(Connection<TcpStream>),
87}
88
89impl Client<Disconnected> {
90 pub async fn connect(config: Config) -> Result<Client<Ready>> {
101 let max_redirects = config.redirect.max_redirects;
102 let follow_redirects = config.redirect.follow_redirects;
103 let mut attempts = 0;
104 let mut current_config = config;
105
106 loop {
107 attempts += 1;
108 if attempts > max_redirects + 1 {
109 return Err(Error::TooManyRedirects { max: max_redirects });
110 }
111
112 match Self::try_connect(¤t_config).await {
113 Ok(client) => return Ok(client),
114 Err(Error::Routing { host, port }) => {
115 if !follow_redirects {
116 return Err(Error::Routing { host, port });
117 }
118 tracing::info!(
119 host = %host,
120 port = port,
121 attempt = attempts,
122 max_redirects = max_redirects,
123 "following Azure SQL routing redirect"
124 );
125 current_config = current_config.with_host(&host).with_port(port);
126 continue;
127 }
128 Err(e) => return Err(e),
129 }
130 }
131 }
132
133 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
134 tracing::info!(
135 host = %config.host,
136 port = config.port,
137 database = ?config.database,
138 "connecting to SQL Server"
139 );
140
141 let addr = format!("{}:{}", config.host, config.port);
142
143 tracing::debug!("establishing TCP connection to {}", addr);
145 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
146 .await
147 .map_err(|_| Error::ConnectTimeout)?
148 .map_err(|e| Error::Io(Arc::new(e)))?;
149
150 tcp_stream
152 .set_nodelay(true)
153 .map_err(|e| Error::Io(Arc::new(e)))?;
154
155 #[cfg(feature = "tls")]
156 {
157 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
159
160 if tls_mode.is_tls_first() {
162 return Self::connect_tds_8(config, tcp_stream).await;
163 }
164
165 Self::connect_tds_7x(config, tcp_stream).await
167 }
168
169 #[cfg(not(feature = "tls"))]
170 {
171 if config.strict_mode {
173 return Err(Error::Config(
174 "TDS 8.0 strict mode requires TLS. Enable the 'tls' feature or use Encrypt=no_tls".into()
175 ));
176 }
177
178 if !config.no_tls {
179 return Err(Error::Config(
180 "TLS encryption requires the 'tls' feature. Either enable the 'tls' feature \
181 or use Encrypt=no_tls in your connection string for unencrypted connections."
182 .into(),
183 ));
184 }
185
186 Self::connect_no_tls(config, tcp_stream).await
188 }
189 }
190
/// TDS 8.0 ("strict") connection flow: TLS first, then TDS.
///
/// The TLS handshake runs directly over the TCP socket and is bounded by
/// `config.timeouts.tls_timeout`. PreLogin (advertising
/// `EncryptionLevel::Required`) and Login7 are then exchanged inside the
/// encrypted stream.
///
/// # Errors
///
/// * `Error::Tls` / `Error::TlsTimeout` when the handshake fails.
/// * `Error::Routing` when the login response carries a redirect.
/// * Any PreLogin or Login7 error.
#[cfg(feature = "tls")]
async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
    tracing::debug!("using TDS 8.0 strict mode (TLS first)");

    let connector = TlsConnector::new(
        TlsConfig::new()
            .strict_mode(true)
            .trust_server_certificate(config.trust_server_certificate),
    )
    .map_err(|e| Error::Tls(e.to_string()))?;

    let handshake = connector.connect(tcp_stream, &config.host);
    let tls_stream = match timeout(config.timeouts.tls_timeout, handshake).await {
        Err(_elapsed) => return Err(Error::TlsTimeout),
        Ok(Err(e)) => return Err(Error::Tls(e.to_string())),
        Ok(Ok(stream)) => stream,
    };

    tracing::debug!("TLS handshake completed (strict mode)");

    let mut connection = Connection::new(tls_stream);

    // The PreLogin reply is read to keep the exchange in step but is not inspected.
    Self::send_prelogin(
        &mut connection,
        &Self::build_prelogin(config, EncryptionLevel::Required),
    )
    .await?;
    let _prelogin_response = Self::receive_prelogin(&mut connection).await?;

    Self::send_login7(&mut connection, &Self::build_login7(config)).await?;

    let (server_version, current_database, routing) =
        Self::process_login_response(&mut connection).await?;

    match routing {
        Some((host, port)) => Err(Error::Routing { host, port }),
        None => Ok(Client {
            config: config.clone(),
            _state: PhantomData,
            connection: Some(ConnectionHandle::Tls(connection)),
            server_version,
            current_database: current_database.clone(),
            statement_cache: StatementCache::with_default_size(),
            transaction_descriptor: 0,
            needs_reset: false,
            #[cfg(feature = "otel")]
            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                .with_database(current_database.unwrap_or_default()),
        }),
    }
}
251
/// TDS 7.x connection flow: an unencrypted PreLogin exchange, then TLS as
/// negotiated.
///
/// The client advertises:
/// * `NotSupported` when `config.no_tls` is set;
/// * `On` when `config.encrypt` is set;
/// * `Off` otherwise.
///
/// Depending on the negotiated level, the session continues as one of:
/// * **Full TLS.** The TLS handshake is carried inside PreLogin packets and
///   the whole session stays encrypted (`ConnectionHandle::TlsPrelogin`).
/// * **Login-only TLS** (both sides `Off`). Only Login7 is sent through TLS;
///   the session then continues in plaintext (`ConnectionHandle::Plain`).
/// * **Plaintext** (both sides `NotSupported`). Everything is sent
///   unencrypted (`ConnectionHandle::Plain`).
///
/// # Errors
///
/// * I/O, protocol and server errors from the handshake.
/// * `Error::Tls` / `Error::TlsTimeout` when TLS fails.
/// * `Error::Routing` when the login response carries a routing redirect.
#[cfg(feature = "tls")]
async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
    use bytes::BufMut;
    use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Length of the fixed part of a Login7 message. Per MS-TDS, the variable
    /// section after it holds the strings, including the obfuscated password.
    const LOGIN7_FIXED_LEN: usize = 94;

    /// Turn a processed login response into a ready client. A routing
    /// redirect is surfaced as `Error::Routing` and the connection is dropped.
    fn finish_login(
        config: &Config,
        connection: ConnectionHandle,
        (server_version, current_database, routing): (
            Option<u32>,
            Option<String>,
            Option<(String, u16)>,
        ),
    ) -> Result<Client<Ready>> {
        if let Some((host, port)) = routing {
            return Err(Error::Routing { host, port });
        }

        Ok(Client {
            config: config.clone(),
            _state: PhantomData,
            connection: Some(connection),
            server_version,
            current_database: current_database.clone(),
            statement_cache: StatementCache::with_default_size(),
            transaction_descriptor: 0,
            needs_reset: false,
            #[cfg(feature = "otel")]
            instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
                .with_database(current_database.unwrap_or_default()),
        })
    }

    tracing::debug!("using TDS 7.x flow (PreLogin first)");

    let client_encryption = if config.no_tls {
        tracing::warn!(
            "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
             Credentials and data will be transmitted in plaintext. \
             This should only be used for development/testing with legacy SQL Server."
        );
        EncryptionLevel::NotSupported
    } else if config.encrypt {
        EncryptionLevel::On
    } else {
        EncryptionLevel::Off
    };

    let prelogin = Self::build_prelogin(config, client_encryption);
    tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
    let prelogin_bytes = prelogin.encode();

    let header = PacketHeader::new(
        PacketType::PreLogin,
        PacketStatus::END_OF_MESSAGE,
        (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
    );

    let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
    header.encode(&mut packet_buf);
    packet_buf.put_slice(&prelogin_bytes);

    tcp_stream
        .write_all(&packet_buf)
        .await
        .map_err(|e| Error::Io(Arc::new(e)))?;

    // The PreLogin reply is assumed to fit in one packet. Its total length
    // (header included) is the big-endian u16 at header bytes 2..4.
    let mut header_buf = [0u8; PACKET_HEADER_SIZE];
    tcp_stream
        .read_exact(&mut header_buf)
        .await
        .map_err(|e| Error::Io(Arc::new(e)))?;

    let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
    let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);

    let mut response_buf = vec![0u8; payload_length];
    tcp_stream
        .read_exact(&mut response_buf)
        .await
        .map_err(|e| Error::Io(Arc::new(e)))?;

    let prelogin_response =
        PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;

    let client_tds_version = config.tds_version;
    if let Some(ref server_version) = prelogin_response.server_version {
        tracing::debug!(
            requested_tds_version = %client_tds_version,
            server_product_version = %server_version,
            server_product = server_version.product_name(),
            max_tds_version = %server_version.max_tds_version(),
            "PreLogin response received"
        );

        let server_max_tds = server_version.max_tds_version();
        if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
            tracing::warn!(
                requested_tds_version = %client_tds_version,
                server_max_tds_version = %server_max_tds,
                server_product = server_version.product_name(),
                "Server supports lower TDS version than requested. \
                 Connection will use server's maximum: {}",
                server_max_tds
            );
        }

        if server_max_tds.is_legacy() {
            tracing::warn!(
                server_product = server_version.product_name(),
                server_max_tds_version = %server_max_tds,
                "Server uses legacy TDS version. Some features may not be available."
            );
        }
    } else {
        tracing::debug!(
            requested_tds_version = %client_tds_version,
            "PreLogin response received (no version info)"
        );
    }

    let server_encryption = prelogin_response.encryption;
    tracing::debug!(encryption = ?server_encryption, "server encryption level");

    let negotiated_encryption = match (client_encryption, server_encryption) {
        (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
            EncryptionLevel::NotSupported
        }
        (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
        (EncryptionLevel::On, EncryptionLevel::Off)
        | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
            return Err(Error::Protocol(
                "Server does not support requested encryption level".to_string(),
            ));
        }
        // NOTE(review): everything else becomes full TLS. This includes a
        // `no_tls` client facing a server that answered Off/On/Required, and an
        // `Off` client facing a `NotSupported` server. Confirm this policy
        // against the MS-TDS PRELOGIN negotiation table.
        _ => EncryptionLevel::On,
    };

    let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;

    if use_tls {
        let tls_config =
            TlsConfig::new().trust_server_certificate(config.trust_server_certificate);

        let tls_connector =
            TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;

        // The TLS handshake records travel inside TDS PreLogin packets.
        let mut tls_stream = timeout(
            config.timeouts.tls_timeout,
            tls_connector.connect_with_prelogin(tcp_stream, &config.host),
        )
        .await
        .map_err(|_| Error::TlsTimeout)?
        .map_err(|e| Error::Tls(e.to_string()))?;

        tracing::debug!("TLS handshake completed (PreLogin wrapped)");

        let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;

        if login_only_encryption {
            // Only Login7 goes through TLS. The response and the rest of the
            // session use the underlying TCP stream in plaintext.
            let login_payload = Self::build_login7(config).encode();

            let max_payload = MAX_PACKET_SIZE - PACKET_HEADER_SIZE;
            let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
            let total_chunks = chunks.len();

            for (i, chunk) in chunks.into_iter().enumerate() {
                let status = if i + 1 == total_chunks {
                    PacketStatus::END_OF_MESSAGE
                } else {
                    PacketStatus::NORMAL
                };

                let header = PacketHeader::new(
                    PacketType::Tds7Login,
                    status,
                    (PACKET_HEADER_SIZE + chunk.len()) as u16,
                );

                let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
                header.encode(&mut packet_buf);
                packet_buf.put_slice(chunk);

                tls_stream
                    .write_all(&packet_buf)
                    .await
                    .map_err(|e| Error::Io(Arc::new(e)))?;
            }

            tls_stream
                .flush()
                .await
                .map_err(|e| Error::Io(Arc::new(e)))?;

            tracing::debug!("Login7 sent through TLS, switching to plaintext for response");

            let (wrapper, _client_conn) = tls_stream.into_inner();
            let tcp_stream = wrapper.into_inner();

            let mut connection = Connection::new(tcp_stream);
            let response = Self::process_login_response(&mut connection).await?;
            finish_login(config, ConnectionHandle::Plain(connection), response)
        } else {
            let mut connection = Connection::new(tls_stream);

            Self::send_login7(&mut connection, &Self::build_login7(config)).await?;

            let response = Self::process_login_response(&mut connection).await?;
            finish_login(config, ConnectionHandle::TlsPrelogin(connection), response)
        }
    } else {
        tracing::warn!(
            "Connecting without TLS encryption. This is insecure and should only be \
             used for development/testing on trusted networks."
        );

        let login_bytes = Self::build_login7(config).encode();
        tracing::debug!("Login7 packet built: {} bytes", login_bytes.len());
        // Only the fixed-length part is logged. The variable section carries
        // the username and the trivially reversible obfuscated password, and
        // must never reach the logs.
        tracing::debug!(
            "Login7 fixed header (94 bytes): {:02X?}",
            &login_bytes[..login_bytes.len().min(LOGIN7_FIXED_LEN)]
        );

        let login_header = PacketHeader::new(
            PacketType::Tds7Login,
            PacketStatus::END_OF_MESSAGE,
            (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
        )
        .with_packet_id(1);
        let mut login_packet_buf =
            BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
        login_header.encode(&mut login_packet_buf);
        login_packet_buf.put_slice(&login_bytes);

        tracing::debug!(
            "Sending Login7 packet: {} bytes total, header: {:02X?}",
            login_packet_buf.len(),
            &login_packet_buf[..PACKET_HEADER_SIZE]
        );
        tcp_stream
            .write_all(&login_packet_buf)
            .await
            .map_err(|e| Error::Io(Arc::new(e)))?;
        tcp_stream
            .flush()
            .await
            .map_err(|e| Error::Io(Arc::new(e)))?;
        tracing::debug!("Login7 sent and flushed over raw TCP");

        // Read the login response as a whole message via the codec instead of
        // a single raw packet. Token handling goes through
        // `process_login_response`, so a routing ENVCHANGE is reported rather
        // than silently dropped.
        let mut connection = Connection::new(tcp_stream);
        let response = Self::process_login_response(&mut connection).await?;
        finish_login(config, ConnectionHandle::Plain(connection), response)
    }
}
693
694 #[cfg(not(feature = "tls"))]
705 async fn connect_no_tls(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
706 use bytes::BufMut;
707 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
708 use tokio::io::{AsyncReadExt, AsyncWriteExt};
709
710 tracing::warn!(
711 "⚠️ Connecting without TLS (tls feature disabled). \
712 Credentials and data will be transmitted in plaintext."
713 );
714
715 let prelogin = Self::build_prelogin(config, EncryptionLevel::NotSupported);
717 let prelogin_bytes = prelogin.encode();
718
719 let header = PacketHeader::new(
721 PacketType::PreLogin,
722 PacketStatus::END_OF_MESSAGE,
723 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
724 );
725
726 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
727 header.encode(&mut packet_buf);
728 packet_buf.put_slice(&prelogin_bytes);
729
730 tcp_stream
731 .write_all(&packet_buf)
732 .await
733 .map_err(|e| Error::Io(Arc::new(e)))?;
734
735 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
737 tcp_stream
738 .read_exact(&mut header_buf)
739 .await
740 .map_err(|e| Error::Io(Arc::new(e)))?;
741
742 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
743 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
744
745 let mut response_buf = vec![0u8; payload_length];
746 tcp_stream
747 .read_exact(&mut response_buf)
748 .await
749 .map_err(|e| Error::Io(Arc::new(e)))?;
750
751 let prelogin_response =
752 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
753
754 let server_encryption = prelogin_response.encryption;
756 if server_encryption != EncryptionLevel::NotSupported {
757 return Err(Error::Config(format!(
758 "Server requires encryption (level: {:?}) but TLS feature is disabled. \
759 Either enable the 'tls' feature or configure the server to allow unencrypted connections.",
760 server_encryption
761 )));
762 }
763
764 tracing::debug!("Server accepted unencrypted connection");
765
766 let login = Self::build_login7(config);
768 let login_bytes = login.encode();
769
770 let login_header = PacketHeader::new(
772 PacketType::Tds7Login,
773 PacketStatus::END_OF_MESSAGE,
774 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
775 )
776 .with_packet_id(1);
777 let mut login_packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
778 login_header.encode(&mut login_packet_buf);
779 login_packet_buf.put_slice(&login_bytes);
780
781 tcp_stream
782 .write_all(&login_packet_buf)
783 .await
784 .map_err(|e| Error::Io(Arc::new(e)))?;
785 tcp_stream
786 .flush()
787 .await
788 .map_err(|e| Error::Io(Arc::new(e)))?;
789
790 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
792 tcp_stream
793 .read_exact(&mut response_header_buf)
794 .await
795 .map_err(|e| Error::Io(Arc::new(e)))?;
796
797 let response_length =
798 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
799
800 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
802 let mut response_payload = vec![0u8; payload_length];
803 tcp_stream
804 .read_exact(&mut response_payload)
805 .await
806 .map_err(|e| Error::Io(Arc::new(e)))?;
807
808 let connection = Connection::new(tcp_stream);
810
811 let response_bytes = bytes::Bytes::from(response_payload);
813 let mut parser = TokenParser::new(response_bytes);
814 let mut server_version = None;
815 let mut current_database = None;
816
817 while let Some(token) = parser
818 .next_token()
819 .map_err(|e| Error::Protocol(e.to_string()))?
820 {
821 match token {
822 Token::LoginAck(ack) => {
823 tracing::info!(
824 version = ack.tds_version,
825 interface = ack.interface,
826 prog_name = %ack.prog_name,
827 "login acknowledged"
828 );
829 server_version = Some(ack.tds_version);
830 }
831 Token::EnvChange(env) => {
832 Self::process_env_change(&env, &mut current_database, &mut None);
833 }
834 Token::Error(err) => {
835 return Err(Error::Server {
836 number: err.number,
837 state: err.state,
838 class: err.class,
839 message: err.message.clone(),
840 server: if err.server.is_empty() {
841 None
842 } else {
843 Some(err.server.clone())
844 },
845 procedure: if err.procedure.is_empty() {
846 None
847 } else {
848 Some(err.procedure.clone())
849 },
850 line: err.line as u32,
851 });
852 }
853 Token::Info(info) => {
854 tracing::info!(
855 number = info.number,
856 message = %info.message,
857 "server info message"
858 );
859 }
860 Token::Done(done) => {
861 if done.status.error {
862 return Err(Error::Protocol("login failed".to_string()));
863 }
864 break;
865 }
866 _ => {}
867 }
868 }
869
870 Ok(Client {
871 config: config.clone(),
872 _state: PhantomData,
873 connection: Some(ConnectionHandle::Plain(connection)),
874 server_version,
875 current_database: current_database.clone(),
876 statement_cache: StatementCache::with_default_size(),
877 transaction_descriptor: 0,
878 needs_reset: false,
879 #[cfg(feature = "otel")]
880 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
881 .with_database(current_database.unwrap_or_default()),
882 })
883 }
884
885 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
887 let version = if config.strict_mode {
889 tds_protocol::version::TdsVersion::V8_0
890 } else {
891 config.tds_version
892 };
893
894 let mut prelogin = PreLogin::new()
895 .with_version(version)
896 .with_encryption(encryption);
897
898 if config.mars {
899 prelogin = prelogin.with_mars(true);
900 }
901
902 if let Some(ref instance) = config.instance {
903 prelogin = prelogin.with_instance(instance);
904 }
905
906 prelogin
907 }
908
909 fn build_login7(config: &Config) -> Login7 {
911 let version = if config.strict_mode {
913 tds_protocol::version::TdsVersion::V8_0
914 } else {
915 config.tds_version
916 };
917
918 let mut login = Login7::new()
919 .with_tds_version(version)
920 .with_packet_size(config.packet_size as u32)
921 .with_app_name(&config.application_name)
922 .with_server_name(&config.host)
923 .with_hostname(&config.host);
924
925 if let Some(ref database) = config.database {
926 login = login.with_database(database);
927 }
928
929 match &config.credentials {
931 mssql_auth::Credentials::SqlServer { username, password } => {
932 login = login.with_sql_auth(username.as_ref(), password.as_ref());
933 }
934 _ => {}
936 }
937
938 login
939 }
940
/// Encode `prelogin` and send it over `connection` as a PreLogin message,
/// with `MAX_PACKET_SIZE` as the packet-size limit.
///
/// # Errors
///
/// Codec failures are reported as `Error::Protocol`.
#[cfg(feature = "tls")]
async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    connection
        .send_message(PacketType::PreLogin, prelogin.encode(), MAX_PACKET_SIZE)
        .await
        .map_err(|e| Error::Protocol(e.to_string()))
}
955
/// Read the next message from `connection` and decode it as a PreLogin response.
///
/// # Errors
///
/// * `Error::ConnectionClosed` when the stream ends before a message arrives.
/// * `Error::Protocol` for read or decode failures.
#[cfg(feature = "tls")]
async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    match connection.read_message().await {
        Ok(Some(message)) => {
            PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
        }
        Ok(None) => Err(Error::ConnectionClosed),
        Err(e) => Err(Error::Protocol(e.to_string())),
    }
}
970
/// Encode `login` and send it over `connection` as a Tds7Login message, with
/// `MAX_PACKET_SIZE` as the packet-size limit.
///
/// # Errors
///
/// Codec failures are reported as `Error::Protocol`.
#[cfg(feature = "tls")]
async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    connection
        .send_message(PacketType::Tds7Login, login.encode(), MAX_PACKET_SIZE)
        .await
        .map_err(|e| Error::Protocol(e.to_string()))
}
985
/// Read one message from `connection` and interpret it as a login response.
///
/// Tokens are consumed until the first DONE token or the end of the message.
///
/// Returns `(server_version, database, routing)`:
/// * `server_version`: from LOGINACK;
/// * `database`: from a database ENVCHANGE;
/// * `routing`: `(host, port)` from a routing ENVCHANGE.
///
/// # Errors
///
/// * `Error::ConnectionClosed` when no message arrives.
/// * `Error::Protocol` for read/parse failures or a DONE token with the error bit.
/// * `Error::Server` for the first ERROR token.
#[cfg(feature = "tls")]
async fn process_login_response<T>(
    connection: &mut Connection<T>,
) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let payload = match connection.read_message().await {
        Ok(Some(message)) => message.payload,
        Ok(None) => return Err(Error::ConnectionClosed),
        Err(e) => return Err(Error::Protocol(e.to_string())),
    };

    let mut tokens = TokenParser::new(payload);
    let (mut server_version, mut database, mut routing) = (None, None, None);

    loop {
        let token = match tokens.next_token() {
            Ok(Some(token)) => token,
            Ok(None) => break,
            Err(e) => return Err(Error::Protocol(e.to_string())),
        };

        match token {
            Token::LoginAck(ack) => {
                tracing::info!(
                    version = ack.tds_version,
                    interface = ack.interface,
                    prog_name = %ack.prog_name,
                    "login acknowledged"
                );
                server_version = Some(ack.tds_version);
            }
            Token::EnvChange(env) => {
                Self::process_env_change(&env, &mut database, &mut routing);
            }
            Token::Error(err) => {
                return Err(Error::Server {
                    number: err.number,
                    state: err.state,
                    class: err.class,
                    message: err.message.clone(),
                    server: (!err.server.is_empty()).then(|| err.server.clone()),
                    procedure: (!err.procedure.is_empty()).then(|| err.procedure.clone()),
                    line: err.line as u32,
                });
            }
            Token::Info(info) => {
                tracing::info!(
                    number = info.number,
                    message = %info.message,
                    "server info message"
                );
            }
            Token::Done(done) if done.status.error => {
                return Err(Error::Protocol("login failed".to_string()));
            }
            Token::Done(_) => break,
            _ => {}
        }
    }

    Ok((server_version, database, routing))
}
1064
1065 fn process_env_change(
1067 env: &EnvChange,
1068 database: &mut Option<String>,
1069 routing: &mut Option<(String, u16)>,
1070 ) {
1071 use tds_protocol::token::EnvChangeValue;
1072
1073 match env.env_type {
1074 EnvChangeType::Database => {
1075 if let EnvChangeValue::String(ref new_value) = env.new_value {
1076 tracing::debug!(database = %new_value, "database changed");
1077 *database = Some(new_value.clone());
1078 }
1079 }
1080 EnvChangeType::Routing => {
1081 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
1082 tracing::info!(host = %host, port = port, "routing redirect received");
1083 *routing = Some((host.clone(), port));
1084 }
1085 }
1086 _ => {
1087 if let EnvChangeValue::String(ref new_value) = env.new_value {
1088 tracing::debug!(
1089 env_type = ?env.env_type,
1090 new_value = %new_value,
1091 "environment change"
1092 );
1093 }
1094 }
1095 }
1096 }
1097}
1098
1099impl<S: ConnectionState> Client<S> {
1101 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
1109 use tds_protocol::token::EnvChangeValue;
1110
1111 match env.env_type {
1112 EnvChangeType::BeginTransaction => {
1113 if let EnvChangeValue::Binary(ref data) = env.new_value {
1114 if data.len() >= 8 {
1115 let descriptor = u64::from_le_bytes([
1116 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
1117 ]);
1118 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
1119 *transaction_descriptor = descriptor;
1120 }
1121 }
1122 }
1123 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
1124 tracing::debug!(
1125 env_type = ?env.env_type,
1126 "transaction ended via raw SQL"
1127 );
1128 *transaction_descriptor = 0;
1129 }
1130 _ => {}
1131 }
1132 }
1133
1134 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
1143 let payload =
1144 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
1145 let max_packet = self.config.packet_size as usize;
1146
1147 let reset = self.needs_reset;
1149 if reset {
1150 self.needs_reset = false; tracing::debug!("sending SQL batch with RESETCONNECTION flag");
1152 }
1153
1154 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1155
1156 match connection {
1157 #[cfg(feature = "tls")]
1158 ConnectionHandle::Tls(conn) => {
1159 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
1160 .await
1161 .map_err(|e| Error::Protocol(e.to_string()))?;
1162 }
1163 #[cfg(feature = "tls")]
1164 ConnectionHandle::TlsPrelogin(conn) => {
1165 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
1166 .await
1167 .map_err(|e| Error::Protocol(e.to_string()))?;
1168 }
1169 ConnectionHandle::Plain(conn) => {
1170 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
1171 .await
1172 .map_err(|e| Error::Protocol(e.to_string()))?;
1173 }
1174 }
1175
1176 Ok(())
1177 }
1178
1179 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
1186 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
1187 let max_packet = self.config.packet_size as usize;
1188
1189 let reset = self.needs_reset;
1191 if reset {
1192 self.needs_reset = false; tracing::debug!("sending RPC with RESETCONNECTION flag");
1194 }
1195
1196 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1197
1198 match connection {
1199 #[cfg(feature = "tls")]
1200 ConnectionHandle::Tls(conn) => {
1201 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
1202 .await
1203 .map_err(|e| Error::Protocol(e.to_string()))?;
1204 }
1205 #[cfg(feature = "tls")]
1206 ConnectionHandle::TlsPrelogin(conn) => {
1207 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
1208 .await
1209 .map_err(|e| Error::Protocol(e.to_string()))?;
1210 }
1211 ConnectionHandle::Plain(conn) => {
1212 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
1213 .await
1214 .map_err(|e| Error::Protocol(e.to_string()))?;
1215 }
1216 }
1217
1218 Ok(())
1219 }
1220
1221 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
1223 use bytes::{BufMut, BytesMut};
1224 use mssql_types::SqlValue;
1225
1226 params
1227 .iter()
1228 .enumerate()
1229 .map(|(i, p)| {
1230 let sql_value = p.to_sql()?;
1231 let name = format!("@p{}", i + 1);
1232
1233 Ok(match sql_value {
1234 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
1235 SqlValue::Bool(v) => {
1236 let mut buf = BytesMut::with_capacity(1);
1237 buf.put_u8(if v { 1 } else { 0 });
1238 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
1239 }
1240 SqlValue::TinyInt(v) => {
1241 let mut buf = BytesMut::with_capacity(1);
1242 buf.put_u8(v);
1243 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
1244 }
1245 SqlValue::SmallInt(v) => {
1246 let mut buf = BytesMut::with_capacity(2);
1247 buf.put_i16_le(v);
1248 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
1249 }
1250 SqlValue::Int(v) => RpcParam::int(&name, v),
1251 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
1252 SqlValue::Float(v) => {
1253 let mut buf = BytesMut::with_capacity(4);
1254 buf.put_f32_le(v);
1255 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
1256 }
1257 SqlValue::Double(v) => {
1258 let mut buf = BytesMut::with_capacity(8);
1259 buf.put_f64_le(v);
1260 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
1261 }
1262 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
1263 SqlValue::Binary(ref b) => {
1264 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
1265 }
1266 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
1267 #[cfg(feature = "uuid")]
1268 SqlValue::Uuid(u) => {
1269 let bytes = u.as_bytes();
1271 let mut buf = BytesMut::with_capacity(16);
1272 buf.put_u32_le(u32::from_be_bytes([
1274 bytes[0], bytes[1], bytes[2], bytes[3],
1275 ]));
1276 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
1277 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
1278 buf.put_slice(&bytes[8..16]);
1279 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
1280 }
1281 #[cfg(feature = "decimal")]
1282 SqlValue::Decimal(d) => {
1283 RpcParam::nvarchar(&name, &d.to_string())
1285 }
1286 #[cfg(feature = "chrono")]
1287 SqlValue::Date(_)
1288 | SqlValue::Time(_)
1289 | SqlValue::DateTime(_)
1290 | SqlValue::DateTimeOffset(_) => {
1291 let s = match &sql_value {
1294 SqlValue::Date(d) => d.to_string(),
1295 SqlValue::Time(t) => t.to_string(),
1296 SqlValue::DateTime(dt) => dt.to_string(),
1297 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
1298 _ => unreachable!(),
1299 };
1300 RpcParam::nvarchar(&name, &s)
1301 }
1302 #[cfg(feature = "json")]
1303 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
1304 SqlValue::Tvp(ref tvp_data) => {
1305 Self::encode_tvp_param(&name, tvp_data)?
1307 }
1308 _ => {
1310 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
1311 from: sql_value.type_name().to_string(),
1312 to: "RPC parameter",
1313 }));
1314 }
1315 })
1316 })
1317 .collect()
1318 }
1319
1320 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
1325 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
1327 .columns
1328 .iter()
1329 .map(|col| {
1330 let wire_type = Self::convert_tvp_column_type(&col.column_type);
1331 TvpWireColumnDef {
1332 wire_type,
1333 flags: TvpColumnFlags {
1334 nullable: col.nullable,
1335 },
1336 }
1337 })
1338 .collect();
1339
1340 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
1342
1343 let mut buf = BytesMut::with_capacity(256);
1345
1346 encoder.encode_metadata(&mut buf);
1348
1349 for row in &tvp_data.rows {
1351 encoder.encode_row(&mut buf, |row_buf| {
1352 for (col_idx, value) in row.iter().enumerate() {
1353 let wire_type = &wire_columns[col_idx].wire_type;
1354 Self::encode_tvp_value(value, wire_type, row_buf);
1355 }
1356 });
1357 }
1358
1359 encoder.encode_end(&mut buf);
1361
1362 let full_type_name = if tvp_data.schema.is_empty() {
1364 tvp_data.type_name.clone()
1365 } else {
1366 format!("{}.{}", tvp_data.schema, tvp_data.type_name)
1367 };
1368
1369 let type_info = RpcTypeInfo::tvp(&full_type_name);
1372
1373 Ok(RpcParam {
1374 name: name.to_string(),
1375 flags: tds_protocol::rpc::ParamFlags::default(),
1376 type_info,
1377 value: Some(buf.freeze()),
1378 })
1379 }
1380
1381 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1383 match col_type {
1384 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1385 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1386 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1387 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1388 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1389 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1390 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1391 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1392 precision: *precision,
1393 scale: *scale,
1394 },
1395 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1396 max_length: *max_length,
1397 },
1398 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1399 max_length: *max_length,
1400 },
1401 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1402 max_length: *max_length,
1403 },
1404 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1405 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1406 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1407 mssql_types::TvpColumnType::DateTime2 { scale } => {
1408 TvpWireType::DateTime2 { scale: *scale }
1409 }
1410 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1411 TvpWireType::DateTimeOffset { scale: *scale }
1412 }
1413 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1414 }
1415 }
1416
1417 fn encode_tvp_value(
1419 value: &mssql_types::SqlValue,
1420 wire_type: &TvpWireType,
1421 buf: &mut BytesMut,
1422 ) {
1423 use mssql_types::SqlValue;
1424
1425 match value {
1426 SqlValue::Null => {
1427 encode_tvp_null(wire_type, buf);
1428 }
1429 SqlValue::Bool(v) => {
1430 encode_tvp_bit(*v, buf);
1431 }
1432 SqlValue::TinyInt(v) => {
1433 encode_tvp_int(*v as i64, 1, buf);
1434 }
1435 SqlValue::SmallInt(v) => {
1436 encode_tvp_int(*v as i64, 2, buf);
1437 }
1438 SqlValue::Int(v) => {
1439 encode_tvp_int(*v as i64, 4, buf);
1440 }
1441 SqlValue::BigInt(v) => {
1442 encode_tvp_int(*v, 8, buf);
1443 }
1444 SqlValue::Float(v) => {
1445 encode_tvp_float(*v as f64, 4, buf);
1446 }
1447 SqlValue::Double(v) => {
1448 encode_tvp_float(*v, 8, buf);
1449 }
1450 SqlValue::String(s) => {
1451 let max_len = match wire_type {
1452 TvpWireType::NVarChar { max_length } => *max_length,
1453 _ => 4000,
1454 };
1455 encode_tvp_nvarchar(s, max_len, buf);
1456 }
1457 SqlValue::Binary(b) => {
1458 let max_len = match wire_type {
1459 TvpWireType::VarBinary { max_length } => *max_length,
1460 _ => 8000,
1461 };
1462 encode_tvp_varbinary(b, max_len, buf);
1463 }
1464 #[cfg(feature = "decimal")]
1465 SqlValue::Decimal(d) => {
1466 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1467 let mantissa = d.mantissa().unsigned_abs();
1468 encode_tvp_decimal(sign, mantissa, buf);
1469 }
1470 #[cfg(feature = "uuid")]
1471 SqlValue::Uuid(u) => {
1472 let bytes = u.as_bytes();
1473 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1474 }
1475 #[cfg(feature = "chrono")]
1476 SqlValue::Date(d) => {
1477 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1479 let days = d.signed_duration_since(base).num_days() as u32;
1480 tds_protocol::tvp::encode_tvp_date(days, buf);
1481 }
1482 #[cfg(feature = "chrono")]
1483 SqlValue::Time(t) => {
1484 use chrono::Timelike;
1485 let nanos =
1486 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1487 let intervals = nanos / 100;
1488 let scale = match wire_type {
1489 TvpWireType::Time { scale } => *scale,
1490 _ => 7,
1491 };
1492 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1493 }
1494 #[cfg(feature = "chrono")]
1495 SqlValue::DateTime(dt) => {
1496 use chrono::Timelike;
1497 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1499 + dt.time().nanosecond() as u64;
1500 let intervals = nanos / 100;
1501 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1503 let days = dt.date().signed_duration_since(base).num_days() as u32;
1504 let scale = match wire_type {
1505 TvpWireType::DateTime2 { scale } => *scale,
1506 _ => 7,
1507 };
1508 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1509 }
1510 #[cfg(feature = "chrono")]
1511 SqlValue::DateTimeOffset(dto) => {
1512 use chrono::{Offset, Timelike};
1513 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1515 + dto.time().nanosecond() as u64;
1516 let intervals = nanos / 100;
1517 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1519 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1520 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1522 let scale = match wire_type {
1523 TvpWireType::DateTimeOffset { scale } => *scale,
1524 _ => 7,
1525 };
1526 tds_protocol::tvp::encode_tvp_datetimeoffset(
1527 intervals,
1528 days,
1529 offset_minutes,
1530 scale,
1531 buf,
1532 );
1533 }
1534 #[cfg(feature = "json")]
1535 SqlValue::Json(j) => {
1536 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1538 }
1539 SqlValue::Xml(s) => {
1540 encode_tvp_nvarchar(s, 0xFFFF, buf);
1542 }
1543 SqlValue::Tvp(_) => {
1544 encode_tvp_null(wire_type, buf);
1546 }
1547 _ => {
1549 encode_tvp_null(wire_type, buf);
1550 }
1551 }
1552 }
1553
1554 async fn read_query_response(
1556 &mut self,
1557 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1558 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1559
1560 let message = match connection {
1561 #[cfg(feature = "tls")]
1562 ConnectionHandle::Tls(conn) => conn
1563 .read_message()
1564 .await
1565 .map_err(|e| Error::Protocol(e.to_string()))?,
1566 #[cfg(feature = "tls")]
1567 ConnectionHandle::TlsPrelogin(conn) => conn
1568 .read_message()
1569 .await
1570 .map_err(|e| Error::Protocol(e.to_string()))?,
1571 ConnectionHandle::Plain(conn) => conn
1572 .read_message()
1573 .await
1574 .map_err(|e| Error::Protocol(e.to_string()))?,
1575 }
1576 .ok_or(Error::ConnectionClosed)?;
1577
1578 let mut parser = TokenParser::new(message.payload);
1579 let mut columns: Vec<crate::row::Column> = Vec::new();
1580 let mut rows: Vec<crate::row::Row> = Vec::new();
1581 let mut protocol_metadata: Option<ColMetaData> = None;
1582
1583 loop {
1584 let token = parser
1586 .next_token_with_metadata(protocol_metadata.as_ref())
1587 .map_err(|e| Error::Protocol(e.to_string()))?;
1588
1589 let Some(token) = token else {
1590 break;
1591 };
1592
1593 match token {
1594 Token::ColMetaData(meta) => {
1595 rows.clear();
1598
1599 columns = meta
1600 .columns
1601 .iter()
1602 .enumerate()
1603 .map(|(i, col)| {
1604 let type_name = format!("{:?}", col.type_id);
1605 let mut column = crate::row::Column::new(&col.name, i, type_name)
1606 .with_nullable(col.flags & 0x01 != 0);
1607
1608 if let Some(max_len) = col.type_info.max_length {
1609 column = column.with_max_length(max_len);
1610 }
1611 if let (Some(prec), Some(scale)) =
1612 (col.type_info.precision, col.type_info.scale)
1613 {
1614 column = column.with_precision_scale(prec, scale);
1615 }
1616 if let Some(collation) = col.type_info.collation {
1619 column = column.with_collation(collation);
1620 }
1621 column
1622 })
1623 .collect();
1624
1625 tracing::debug!(columns = columns.len(), "received column metadata");
1626 protocol_metadata = Some(meta);
1627 }
1628 Token::Row(raw_row) => {
1629 if let Some(ref meta) = protocol_metadata {
1630 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1631 rows.push(row);
1632 }
1633 }
1634 Token::NbcRow(nbc_row) => {
1635 if let Some(ref meta) = protocol_metadata {
1636 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1637 rows.push(row);
1638 }
1639 }
1640 Token::Error(err) => {
1641 return Err(Error::Server {
1642 number: err.number,
1643 state: err.state,
1644 class: err.class,
1645 message: err.message.clone(),
1646 server: if err.server.is_empty() {
1647 None
1648 } else {
1649 Some(err.server.clone())
1650 },
1651 procedure: if err.procedure.is_empty() {
1652 None
1653 } else {
1654 Some(err.procedure.clone())
1655 },
1656 line: err.line as u32,
1657 });
1658 }
1659 Token::Done(done) => {
1660 if done.status.error {
1661 return Err(Error::Query("query failed".to_string()));
1662 }
1663 tracing::debug!(
1664 row_count = done.row_count,
1665 has_more = done.status.more,
1666 "query complete"
1667 );
1668 if !done.status.more {
1671 break;
1672 }
1673 }
1674 Token::DoneProc(done) => {
1675 if done.status.error {
1676 return Err(Error::Query("query failed".to_string()));
1677 }
1678 }
1679 Token::DoneInProc(done) => {
1680 if done.status.error {
1681 return Err(Error::Query("query failed".to_string()));
1682 }
1683 }
1684 Token::Info(info) => {
1685 tracing::debug!(
1686 number = info.number,
1687 message = %info.message,
1688 "server info message"
1689 );
1690 }
1691 Token::EnvChange(env) => {
1692 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
1696 }
1697 _ => {}
1698 }
1699 }
1700
1701 tracing::debug!(
1702 columns = columns.len(),
1703 rows = rows.len(),
1704 "query response parsed"
1705 );
1706 Ok((columns, rows))
1707 }
1708
1709 fn convert_raw_row(
1713 raw: &RawRow,
1714 meta: &ColMetaData,
1715 columns: &[crate::row::Column],
1716 ) -> Result<crate::row::Row> {
1717 let mut values = Vec::with_capacity(meta.columns.len());
1718 let mut buf = raw.data.as_ref();
1719
1720 for col in &meta.columns {
1721 let value = Self::parse_column_value(&mut buf, col)?;
1722 values.push(value);
1723 }
1724
1725 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1726 }
1727
1728 fn convert_nbc_row(
1732 nbc: &NbcRow,
1733 meta: &ColMetaData,
1734 columns: &[crate::row::Column],
1735 ) -> Result<crate::row::Row> {
1736 let mut values = Vec::with_capacity(meta.columns.len());
1737 let mut buf = nbc.data.as_ref();
1738
1739 for (i, col) in meta.columns.iter().enumerate() {
1740 if nbc.is_null(i) {
1741 values.push(mssql_types::SqlValue::Null);
1742 } else {
1743 let value = Self::parse_column_value(&mut buf, col)?;
1744 values.push(value);
1745 }
1746 }
1747
1748 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1749 }
1750
1751 fn parse_money_value(buf: &mut &[u8], bytes: usize) -> Result<mssql_types::SqlValue> {
1757 use bytes::Buf;
1758 use mssql_types::SqlValue;
1759
1760 if bytes == 0 {
1761 return Ok(SqlValue::Null);
1762 }
1763
1764 let cents = match bytes {
1765 4 => buf.get_i32_le() as i64,
1766 8 => {
1767 let high = buf.get_i32_le();
1768 let low = buf.get_u32_le();
1769 ((high as i64) << 32) | (low as i64)
1770 }
1771 _ => return Err(Error::Protocol(format!("invalid money length: {bytes}"))),
1772 };
1773
1774 #[cfg(feature = "decimal")]
1775 {
1776 use rust_decimal::Decimal;
1777 Ok(SqlValue::Decimal(Decimal::from_i128_with_scale(
1778 cents as i128,
1779 4,
1780 )))
1781 }
1782
1783 #[cfg(not(feature = "decimal"))]
1784 {
1785 Ok(SqlValue::Double((cents as f64) / 10000.0))
1786 }
1787 }
1788
1789 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1791 use bytes::Buf;
1792 use mssql_types::SqlValue;
1793 use tds_protocol::types::TypeId;
1794
1795 let value = match col.type_id {
1796 TypeId::Null => SqlValue::Null,
1798
1799 TypeId::Int1 => {
1801 if buf.remaining() < 1 {
1802 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1803 }
1804 SqlValue::TinyInt(buf.get_u8())
1805 }
1806 TypeId::Bit => {
1807 if buf.remaining() < 1 {
1808 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1809 }
1810 SqlValue::Bool(buf.get_u8() != 0)
1811 }
1812
1813 TypeId::Int2 => {
1815 if buf.remaining() < 2 {
1816 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1817 }
1818 SqlValue::SmallInt(buf.get_i16_le())
1819 }
1820
1821 TypeId::Int4 => {
1823 if buf.remaining() < 4 {
1824 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1825 }
1826 SqlValue::Int(buf.get_i32_le())
1827 }
1828 TypeId::Float4 => {
1829 if buf.remaining() < 4 {
1830 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1831 }
1832 SqlValue::Float(buf.get_f32_le())
1833 }
1834
1835 TypeId::Int8 => {
1837 if buf.remaining() < 8 {
1838 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1839 }
1840 SqlValue::BigInt(buf.get_i64_le())
1841 }
1842 TypeId::Float8 => {
1843 if buf.remaining() < 8 {
1844 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1845 }
1846 SqlValue::Double(buf.get_f64_le())
1847 }
1848
1849 TypeId::Money | TypeId::Money4 | TypeId::MoneyN => {
1851 let bytes = match col.type_id {
1852 TypeId::Money => 8,
1853 TypeId::Money4 => 4,
1854 TypeId::MoneyN => {
1855 if buf.remaining() < 1 {
1856 return Err(Error::Protocol(
1857 "unexpected EOF reading MoneyN length".into(),
1858 ));
1859 }
1860 buf.get_u8() as usize
1861 }
1862 _ => unreachable!(),
1863 };
1864
1865 if buf.remaining() < bytes {
1866 return Err(Error::Protocol(format!(
1867 "unexpected EOF reading money data ({bytes} bytes)"
1868 )));
1869 }
1870
1871 Self::parse_money_value(buf, bytes)?
1872 }
1873
1874 TypeId::IntN => {
1876 if buf.remaining() < 1 {
1877 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1878 }
1879 let len = buf.get_u8();
1880 match len {
1881 0 => SqlValue::Null,
1882 1 => SqlValue::TinyInt(buf.get_u8()),
1883 2 => SqlValue::SmallInt(buf.get_i16_le()),
1884 4 => SqlValue::Int(buf.get_i32_le()),
1885 8 => SqlValue::BigInt(buf.get_i64_le()),
1886 _ => {
1887 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1888 }
1889 }
1890 }
1891 TypeId::FloatN => {
1892 if buf.remaining() < 1 {
1893 return Err(Error::Protocol(
1894 "unexpected EOF reading FloatN length".into(),
1895 ));
1896 }
1897 let len = buf.get_u8();
1898 match len {
1899 0 => SqlValue::Null,
1900 4 => SqlValue::Float(buf.get_f32_le()),
1901 8 => SqlValue::Double(buf.get_f64_le()),
1902 _ => {
1903 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1904 }
1905 }
1906 }
1907 TypeId::BitN => {
1908 if buf.remaining() < 1 {
1909 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1910 }
1911 let len = buf.get_u8();
1912 match len {
1913 0 => SqlValue::Null,
1914 1 => SqlValue::Bool(buf.get_u8() != 0),
1915 _ => {
1916 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1917 }
1918 }
1919 }
1920
1921 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1923 if buf.remaining() < 1 {
1924 return Err(Error::Protocol(
1925 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1926 ));
1927 }
1928 let len = buf.get_u8() as usize;
1929 if len == 0 {
1930 SqlValue::Null
1931 } else {
1932 if buf.remaining() < len {
1933 return Err(Error::Protocol(
1934 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1935 ));
1936 }
1937
1938 let sign = buf.get_u8();
1940 let mantissa_len = len - 1;
1941
1942 let mut mantissa_bytes = [0u8; 16];
1944 for i in 0..mantissa_len.min(16) {
1945 mantissa_bytes[i] = buf.get_u8();
1946 }
1947 for _ in 16..mantissa_len {
1949 buf.get_u8();
1950 }
1951
1952 let mantissa = u128::from_le_bytes(mantissa_bytes);
1953 let scale = col.type_info.scale.unwrap_or(0) as u32;
1954
1955 #[cfg(feature = "decimal")]
1956 {
1957 use rust_decimal::Decimal;
1958 if scale > 28 {
1961 let divisor = 10f64.powi(scale as i32);
1963 let value = (mantissa as f64) / divisor;
1964 let value = if sign == 0 { -value } else { value };
1965 SqlValue::Double(value)
1966 } else {
1967 let mut decimal =
1968 Decimal::from_i128_with_scale(mantissa as i128, scale);
1969 if sign == 0 {
1970 decimal.set_sign_negative(true);
1971 }
1972 SqlValue::Decimal(decimal)
1973 }
1974 }
1975
1976 #[cfg(not(feature = "decimal"))]
1977 {
1978 let divisor = 10f64.powi(scale as i32);
1980 let value = (mantissa as f64) / divisor;
1981 let value = if sign == 0 { -value } else { value };
1982 SqlValue::Double(value)
1983 }
1984 }
1985 }
1986
1987 TypeId::DateTimeN => {
1989 if buf.remaining() < 1 {
1990 return Err(Error::Protocol(
1991 "unexpected EOF reading DateTimeN length".into(),
1992 ));
1993 }
1994 let len = buf.get_u8() as usize;
1995 if len == 0 {
1996 SqlValue::Null
1997 } else if buf.remaining() < len {
1998 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1999 } else {
2000 match len {
2001 4 => {
2002 let days = buf.get_u16_le() as i64;
2004 let minutes = buf.get_u16_le() as u32;
2005 #[cfg(feature = "chrono")]
2006 {
2007 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
2008 let date = base + chrono::Duration::days(days);
2009 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
2010 minutes * 60,
2011 0,
2012 )
2013 .unwrap();
2014 SqlValue::DateTime(date.and_time(time))
2015 }
2016 #[cfg(not(feature = "chrono"))]
2017 {
2018 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
2019 }
2020 }
2021 8 => {
2022 let days = buf.get_i32_le() as i64;
2024 let time_300ths = buf.get_u32_le() as u64;
2025 #[cfg(feature = "chrono")]
2026 {
2027 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
2028 let date = base + chrono::Duration::days(days);
2029 let total_ms = (time_300ths * 1000) / 300;
2031 let secs = (total_ms / 1000) as u32;
2032 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
2033 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
2034 secs, nanos,
2035 )
2036 .unwrap();
2037 SqlValue::DateTime(date.and_time(time))
2038 }
2039 #[cfg(not(feature = "chrono"))]
2040 {
2041 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
2042 }
2043 }
2044 _ => {
2045 return Err(Error::Protocol(format!(
2046 "invalid DateTimeN length: {len}"
2047 )));
2048 }
2049 }
2050 }
2051 }
2052
2053 TypeId::DateTime => {
2055 if buf.remaining() < 8 {
2056 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
2057 }
2058 let days = buf.get_i32_le() as i64;
2059 let time_300ths = buf.get_u32_le() as u64;
2060 #[cfg(feature = "chrono")]
2061 {
2062 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
2063 let date = base + chrono::Duration::days(days);
2064 let total_ms = (time_300ths * 1000) / 300;
2065 let secs = (total_ms / 1000) as u32;
2066 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
2067 let time =
2068 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
2069 SqlValue::DateTime(date.and_time(time))
2070 }
2071 #[cfg(not(feature = "chrono"))]
2072 {
2073 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
2074 }
2075 }
2076
2077 TypeId::DateTime4 => {
2079 if buf.remaining() < 4 {
2080 return Err(Error::Protocol(
2081 "unexpected EOF reading SMALLDATETIME".into(),
2082 ));
2083 }
2084 let days = buf.get_u16_le() as i64;
2085 let minutes = buf.get_u16_le() as u32;
2086 #[cfg(feature = "chrono")]
2087 {
2088 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
2089 let date = base + chrono::Duration::days(days);
2090 let time =
2091 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
2092 .unwrap();
2093 SqlValue::DateTime(date.and_time(time))
2094 }
2095 #[cfg(not(feature = "chrono"))]
2096 {
2097 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
2098 }
2099 }
2100
2101 TypeId::Date => {
2103 if buf.remaining() < 1 {
2104 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
2105 }
2106 let len = buf.get_u8() as usize;
2107 if len == 0 {
2108 SqlValue::Null
2109 } else if len != 3 {
2110 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
2111 } else if buf.remaining() < 3 {
2112 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
2113 } else {
2114 let days = buf.get_u8() as u32
2116 | ((buf.get_u8() as u32) << 8)
2117 | ((buf.get_u8() as u32) << 16);
2118 #[cfg(feature = "chrono")]
2119 {
2120 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2121 let date = base + chrono::Duration::days(days as i64);
2122 SqlValue::Date(date)
2123 }
2124 #[cfg(not(feature = "chrono"))]
2125 {
2126 SqlValue::String(format!("DATE({days})"))
2127 }
2128 }
2129 }
2130
2131 TypeId::Time => {
2133 if buf.remaining() < 1 {
2134 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
2135 }
2136 let len = buf.get_u8() as usize;
2137 if len == 0 {
2138 SqlValue::Null
2139 } else if buf.remaining() < len {
2140 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
2141 } else {
2142 let mut time_bytes = [0u8; 8];
2143 for byte in time_bytes.iter_mut().take(len) {
2144 *byte = buf.get_u8();
2145 }
2146 let intervals = u64::from_le_bytes(time_bytes);
2147 #[cfg(feature = "chrono")]
2148 {
2149 let scale = col.type_info.scale.unwrap_or(7);
2150 let time = Self::intervals_to_time(intervals, scale);
2151 SqlValue::Time(time)
2152 }
2153 #[cfg(not(feature = "chrono"))]
2154 {
2155 SqlValue::String(format!("TIME({intervals})"))
2156 }
2157 }
2158 }
2159
2160 TypeId::DateTime2 => {
2162 if buf.remaining() < 1 {
2163 return Err(Error::Protocol(
2164 "unexpected EOF reading DATETIME2 length".into(),
2165 ));
2166 }
2167 let len = buf.get_u8() as usize;
2168 if len == 0 {
2169 SqlValue::Null
2170 } else if buf.remaining() < len {
2171 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
2172 } else {
2173 let scale = col.type_info.scale.unwrap_or(7);
2174 let time_len = Self::time_bytes_for_scale(scale);
2175
2176 let mut time_bytes = [0u8; 8];
2178 for byte in time_bytes.iter_mut().take(time_len) {
2179 *byte = buf.get_u8();
2180 }
2181 let intervals = u64::from_le_bytes(time_bytes);
2182
2183 let days = buf.get_u8() as u32
2185 | ((buf.get_u8() as u32) << 8)
2186 | ((buf.get_u8() as u32) << 16);
2187
2188 #[cfg(feature = "chrono")]
2189 {
2190 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2191 let date = base + chrono::Duration::days(days as i64);
2192 let time = Self::intervals_to_time(intervals, scale);
2193 SqlValue::DateTime(date.and_time(time))
2194 }
2195 #[cfg(not(feature = "chrono"))]
2196 {
2197 SqlValue::String(format!("DATETIME2({days},{intervals})"))
2198 }
2199 }
2200 }
2201
2202 TypeId::DateTimeOffset => {
2204 if buf.remaining() < 1 {
2205 return Err(Error::Protocol(
2206 "unexpected EOF reading DATETIMEOFFSET length".into(),
2207 ));
2208 }
2209 let len = buf.get_u8() as usize;
2210 if len == 0 {
2211 SqlValue::Null
2212 } else if buf.remaining() < len {
2213 return Err(Error::Protocol(
2214 "unexpected EOF reading DATETIMEOFFSET".into(),
2215 ));
2216 } else {
2217 let scale = col.type_info.scale.unwrap_or(7);
2218 let time_len = Self::time_bytes_for_scale(scale);
2219
2220 let mut time_bytes = [0u8; 8];
2222 for byte in time_bytes.iter_mut().take(time_len) {
2223 *byte = buf.get_u8();
2224 }
2225 let intervals = u64::from_le_bytes(time_bytes);
2226
2227 let days = buf.get_u8() as u32
2229 | ((buf.get_u8() as u32) << 8)
2230 | ((buf.get_u8() as u32) << 16);
2231
2232 let offset_minutes = buf.get_i16_le();
2234
2235 #[cfg(feature = "chrono")]
2236 {
2237 use chrono::TimeZone;
2238 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2239 let date = base + chrono::Duration::days(days as i64);
2240 let time = Self::intervals_to_time(intervals, scale);
2241 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
2242 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
2243 let datetime = offset
2244 .from_local_datetime(&date.and_time(time))
2245 .single()
2246 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
2247 SqlValue::DateTimeOffset(datetime)
2248 }
2249 #[cfg(not(feature = "chrono"))]
2250 {
2251 SqlValue::String(format!(
2252 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
2253 ))
2254 }
2255 }
2256 }
2257
2258 TypeId::Text => Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?,
2260
2261 TypeId::Char | TypeId::VarChar => {
2263 if buf.remaining() < 1 {
2264 return Err(Error::Protocol(
2265 "unexpected EOF reading legacy varchar length".into(),
2266 ));
2267 }
2268 let len = buf.get_u8();
2269 if len == 0xFF {
2270 SqlValue::Null
2271 } else if len == 0 {
2272 SqlValue::String(String::new())
2273 } else if buf.remaining() < len as usize {
2274 return Err(Error::Protocol(
2275 "unexpected EOF reading legacy varchar data".into(),
2276 ));
2277 } else {
2278 let data = &buf[..len as usize];
2279 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
2281 buf.advance(len as usize);
2282 SqlValue::String(s)
2283 }
2284 }
2285
2286 TypeId::BigVarChar | TypeId::BigChar => {
2288 if col.type_info.max_length == Some(0xFFFF) {
2290 Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?
2292 } else {
2293 if buf.remaining() < 2 {
2295 return Err(Error::Protocol(
2296 "unexpected EOF reading varchar length".into(),
2297 ));
2298 }
2299 let len = buf.get_u16_le();
2300 if len == 0xFFFF {
2301 SqlValue::Null
2302 } else if buf.remaining() < len as usize {
2303 return Err(Error::Protocol(
2304 "unexpected EOF reading varchar data".into(),
2305 ));
2306 } else {
2307 let data = &buf[..len as usize];
2308 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
2310 buf.advance(len as usize);
2311 SqlValue::String(s)
2312 }
2313 }
2314 }
2315
2316 TypeId::NText => Self::parse_plp_nvarchar(buf)?,
2318
2319 TypeId::NVarChar | TypeId::NChar => {
2321 if col.type_info.max_length == Some(0xFFFF) {
2323 Self::parse_plp_nvarchar(buf)?
2325 } else {
2326 if buf.remaining() < 2 {
2328 return Err(Error::Protocol(
2329 "unexpected EOF reading nvarchar length".into(),
2330 ));
2331 }
2332 let len = buf.get_u16_le();
2333 if len == 0xFFFF {
2334 SqlValue::Null
2335 } else if buf.remaining() < len as usize {
2336 return Err(Error::Protocol(
2337 "unexpected EOF reading nvarchar data".into(),
2338 ));
2339 } else {
2340 let data = &buf[..len as usize];
2341 let utf16: Vec<u16> = data
2343 .chunks_exact(2)
2344 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2345 .collect();
2346 let s = String::from_utf16(&utf16)
2347 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
2348 buf.advance(len as usize);
2349 SqlValue::String(s)
2350 }
2351 }
2352 }
2353
2354 TypeId::Image => Self::parse_plp_varbinary(buf)?,
2356
2357 TypeId::Binary | TypeId::VarBinary => {
2359 if buf.remaining() < 1 {
2360 return Err(Error::Protocol(
2361 "unexpected EOF reading legacy varbinary length".into(),
2362 ));
2363 }
2364 let len = buf.get_u8();
2365 if len == 0xFF {
2366 SqlValue::Null
2367 } else if len == 0 {
2368 SqlValue::Binary(bytes::Bytes::new())
2369 } else if buf.remaining() < len as usize {
2370 return Err(Error::Protocol(
2371 "unexpected EOF reading legacy varbinary data".into(),
2372 ));
2373 } else {
2374 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2375 buf.advance(len as usize);
2376 SqlValue::Binary(data)
2377 }
2378 }
2379
2380 TypeId::BigVarBinary | TypeId::BigBinary => {
2382 if col.type_info.max_length == Some(0xFFFF) {
2384 Self::parse_plp_varbinary(buf)?
2386 } else {
2387 if buf.remaining() < 2 {
2388 return Err(Error::Protocol(
2389 "unexpected EOF reading varbinary length".into(),
2390 ));
2391 }
2392 let len = buf.get_u16_le();
2393 if len == 0xFFFF {
2394 SqlValue::Null
2395 } else if buf.remaining() < len as usize {
2396 return Err(Error::Protocol(
2397 "unexpected EOF reading varbinary data".into(),
2398 ));
2399 } else {
2400 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2401 buf.advance(len as usize);
2402 SqlValue::Binary(data)
2403 }
2404 }
2405 }
2406
2407 TypeId::Xml => {
2409 match Self::parse_plp_nvarchar(buf)? {
2411 SqlValue::Null => SqlValue::Null,
2412 SqlValue::String(s) => SqlValue::Xml(s),
2413 _ => {
2414 return Err(Error::Protocol(
2415 "unexpected value type when parsing XML".into(),
2416 ));
2417 }
2418 }
2419 }
2420
2421 TypeId::Guid => {
2423 if buf.remaining() < 1 {
2424 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
2425 }
2426 let len = buf.get_u8();
2427 if len == 0 {
2428 SqlValue::Null
2429 } else if len != 16 {
2430 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
2431 } else if buf.remaining() < 16 {
2432 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
2433 } else {
2434 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
2436 buf.advance(16);
2437 SqlValue::Binary(data)
2438 }
2439 }
2440
2441 TypeId::Variant => Self::parse_sql_variant(buf)?,
2443
2444 TypeId::Udt => Self::parse_plp_varbinary(buf)?,
2446
2447 _ => {
2449 if buf.remaining() < 2 {
2451 return Err(Error::Protocol(format!(
2452 "unexpected EOF reading {:?}",
2453 col.type_id
2454 )));
2455 }
2456 let len = buf.get_u16_le();
2457 if len == 0xFFFF {
2458 SqlValue::Null
2459 } else if buf.remaining() < len as usize {
2460 return Err(Error::Protocol(format!(
2461 "unexpected EOF reading {:?} data",
2462 col.type_id
2463 )));
2464 } else {
2465 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2466 buf.advance(len as usize);
2467 SqlValue::Binary(data)
2468 }
2469 }
2470 };
2471
2472 Ok(value)
2473 }
2474
2475 fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2481 use bytes::Buf;
2482 use mssql_types::SqlValue;
2483
2484 if buf.remaining() < 8 {
2485 return Err(Error::Protocol(
2486 "unexpected EOF reading PLP total length".into(),
2487 ));
2488 }
2489
2490 let total_len = buf.get_u64_le();
2491 if total_len == 0xFFFFFFFFFFFFFFFF {
2492 return Ok(SqlValue::Null);
2493 }
2494
2495 let mut all_data = Vec::new();
2497 loop {
2498 if buf.remaining() < 4 {
2499 return Err(Error::Protocol(
2500 "unexpected EOF reading PLP chunk length".into(),
2501 ));
2502 }
2503 let chunk_len = buf.get_u32_le() as usize;
2504 if chunk_len == 0 {
2505 break; }
2507 if buf.remaining() < chunk_len {
2508 return Err(Error::Protocol(
2509 "unexpected EOF reading PLP chunk data".into(),
2510 ));
2511 }
2512 all_data.extend_from_slice(&buf[..chunk_len]);
2513 buf.advance(chunk_len);
2514 }
2515
2516 let utf16: Vec<u16> = all_data
2518 .chunks_exact(2)
2519 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2520 .collect();
2521 let s = String::from_utf16(&utf16)
2522 .map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
2523 Ok(SqlValue::String(s))
2524 }
2525
2526 #[allow(unused_variables)]
2532 fn decode_varchar_string(data: &[u8], collation: Option<&Collation>) -> String {
2533 #[cfg(feature = "encoding")]
2535 if let Some(coll) = collation {
2536 if let Some(encoding) = coll.encoding() {
2537 let (decoded, _, had_errors) = encoding.decode(data);
2538 if !had_errors {
2539 return decoded.into_owned();
2540 }
2541 }
2542 }
2543
2544 String::from_utf8_lossy(data).into_owned()
2546 }
2547
2548 fn parse_plp_varchar(
2550 buf: &mut &[u8],
2551 collation: Option<&Collation>,
2552 ) -> Result<mssql_types::SqlValue> {
2553 use bytes::Buf;
2554 use mssql_types::SqlValue;
2555
2556 if buf.remaining() < 8 {
2557 return Err(Error::Protocol(
2558 "unexpected EOF reading PLP total length".into(),
2559 ));
2560 }
2561
2562 let total_len = buf.get_u64_le();
2563 if total_len == 0xFFFFFFFFFFFFFFFF {
2564 return Ok(SqlValue::Null);
2565 }
2566
2567 let mut all_data = Vec::new();
2569 loop {
2570 if buf.remaining() < 4 {
2571 return Err(Error::Protocol(
2572 "unexpected EOF reading PLP chunk length".into(),
2573 ));
2574 }
2575 let chunk_len = buf.get_u32_le() as usize;
2576 if chunk_len == 0 {
2577 break; }
2579 if buf.remaining() < chunk_len {
2580 return Err(Error::Protocol(
2581 "unexpected EOF reading PLP chunk data".into(),
2582 ));
2583 }
2584 all_data.extend_from_slice(&buf[..chunk_len]);
2585 buf.advance(chunk_len);
2586 }
2587
2588 let s = Self::decode_varchar_string(&all_data, collation);
2590 Ok(SqlValue::String(s))
2591 }
2592
2593 fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2595 use bytes::Buf;
2596 use mssql_types::SqlValue;
2597
2598 if buf.remaining() < 8 {
2599 return Err(Error::Protocol(
2600 "unexpected EOF reading PLP total length".into(),
2601 ));
2602 }
2603
2604 let total_len = buf.get_u64_le();
2605 if total_len == 0xFFFFFFFFFFFFFFFF {
2606 return Ok(SqlValue::Null);
2607 }
2608
2609 let mut all_data = Vec::new();
2611 loop {
2612 if buf.remaining() < 4 {
2613 return Err(Error::Protocol(
2614 "unexpected EOF reading PLP chunk length".into(),
2615 ));
2616 }
2617 let chunk_len = buf.get_u32_le() as usize;
2618 if chunk_len == 0 {
2619 break; }
2621 if buf.remaining() < chunk_len {
2622 return Err(Error::Protocol(
2623 "unexpected EOF reading PLP chunk data".into(),
2624 ));
2625 }
2626 all_data.extend_from_slice(&buf[..chunk_len]);
2627 buf.advance(chunk_len);
2628 }
2629
2630 Ok(SqlValue::Binary(bytes::Bytes::from(all_data)))
2631 }
2632
2633 fn parse_sql_variant(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2642 use bytes::Buf;
2643 use mssql_types::SqlValue;
2644
2645 if buf.remaining() < 4 {
2647 return Err(Error::Protocol(
2648 "unexpected EOF reading SQL_VARIANT length".into(),
2649 ));
2650 }
2651 let total_len = buf.get_u32_le() as usize;
2652
2653 if total_len == 0 {
2654 return Ok(SqlValue::Null);
2655 }
2656
2657 if buf.remaining() < total_len {
2658 return Err(Error::Protocol(
2659 "unexpected EOF reading SQL_VARIANT data".into(),
2660 ));
2661 }
2662
2663 if total_len < 2 {
2665 return Err(Error::Protocol(
2666 "SQL_VARIANT too short for type info".into(),
2667 ));
2668 }
2669
2670 let base_type = buf.get_u8();
2671 let prop_count = buf.get_u8() as usize;
2672
2673 if buf.remaining() < prop_count {
2674 return Err(Error::Protocol(
2675 "unexpected EOF reading SQL_VARIANT properties".into(),
2676 ));
2677 }
2678
2679 let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
2681
2682 match base_type {
2685 0x30 => {
2687 buf.advance(prop_count);
2689 if data_len < 1 {
2690 return Ok(SqlValue::Null);
2691 }
2692 let v = buf.get_u8();
2693 Ok(SqlValue::TinyInt(v))
2694 }
2695 0x32 => {
2696 buf.advance(prop_count);
2698 if data_len < 1 {
2699 return Ok(SqlValue::Null);
2700 }
2701 let v = buf.get_u8();
2702 Ok(SqlValue::Bool(v != 0))
2703 }
2704 0x34 => {
2705 buf.advance(prop_count);
2707 if data_len < 2 {
2708 return Ok(SqlValue::Null);
2709 }
2710 let v = buf.get_i16_le();
2711 Ok(SqlValue::SmallInt(v))
2712 }
2713 0x38 => {
2714 buf.advance(prop_count);
2716 if data_len < 4 {
2717 return Ok(SqlValue::Null);
2718 }
2719 let v = buf.get_i32_le();
2720 Ok(SqlValue::Int(v))
2721 }
2722 0x7F => {
2723 buf.advance(prop_count);
2725 if data_len < 8 {
2726 return Ok(SqlValue::Null);
2727 }
2728 let v = buf.get_i64_le();
2729 Ok(SqlValue::BigInt(v))
2730 }
2731 0x6D => {
2732 let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2734 buf.advance(prop_count.saturating_sub(1));
2735
2736 if float_len == 4 && data_len >= 4 {
2737 let v = buf.get_f32_le();
2738 Ok(SqlValue::Float(v))
2739 } else if data_len >= 8 {
2740 let v = buf.get_f64_le();
2741 Ok(SqlValue::Double(v))
2742 } else {
2743 Ok(SqlValue::Null)
2744 }
2745 }
2746 0x6E => {
2747 let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2749 buf.advance(prop_count.saturating_sub(1));
2750
2751 if money_len == 0 || data_len == 0 {
2752 Ok(SqlValue::Null)
2753 } else if (money_len == 4 && data_len >= 4) || (money_len == 8 && data_len >= 8) {
2754 Self::parse_money_value(buf, money_len as usize)
2755 } else {
2756 buf.advance(data_len);
2757 Ok(SqlValue::Null)
2758 }
2759 }
2760 0x6F => {
2761 #[cfg(feature = "chrono")]
2763 let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2764 #[cfg(not(feature = "chrono"))]
2765 if prop_count >= 1 {
2766 buf.get_u8();
2767 }
2768 buf.advance(prop_count.saturating_sub(1));
2769
2770 #[cfg(feature = "chrono")]
2771 {
2772 use chrono::NaiveDate;
2773 if dt_len == 4 && data_len >= 4 {
2774 let days = buf.get_u16_le() as i64;
2776 let mins = buf.get_u16_le() as u32;
2777 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2778 .unwrap()
2779 .and_hms_opt(0, 0, 0)
2780 .unwrap();
2781 let dt = base
2782 + chrono::Duration::days(days)
2783 + chrono::Duration::minutes(mins as i64);
2784 Ok(SqlValue::DateTime(dt))
2785 } else if data_len >= 8 {
2786 let days = buf.get_i32_le() as i64;
2788 let ticks = buf.get_u32_le() as i64;
2789 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2790 .unwrap()
2791 .and_hms_opt(0, 0, 0)
2792 .unwrap();
2793 let millis = (ticks * 10) / 3;
2794 let dt = base
2795 + chrono::Duration::days(days)
2796 + chrono::Duration::milliseconds(millis);
2797 Ok(SqlValue::DateTime(dt))
2798 } else {
2799 Ok(SqlValue::Null)
2800 }
2801 }
2802 #[cfg(not(feature = "chrono"))]
2803 {
2804 buf.advance(data_len);
2805 Ok(SqlValue::Null)
2806 }
2807 }
2808 0x6A | 0x6C => {
2809 let _precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
2811 let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
2812 buf.advance(prop_count.saturating_sub(2));
2813
2814 if data_len < 1 {
2815 return Ok(SqlValue::Null);
2816 }
2817
2818 let sign = buf.get_u8();
2819 let mantissa_len = data_len - 1;
2820
2821 if mantissa_len > 16 {
2822 buf.advance(mantissa_len);
2824 return Ok(SqlValue::Null);
2825 }
2826
2827 let mut mantissa_bytes = [0u8; 16];
2828 for i in 0..mantissa_len.min(16) {
2829 mantissa_bytes[i] = buf.get_u8();
2830 }
2831 let mantissa = u128::from_le_bytes(mantissa_bytes);
2832
2833 #[cfg(feature = "decimal")]
2834 {
2835 use rust_decimal::Decimal;
2836 if scale > 28 {
2837 let divisor = 10f64.powi(scale as i32);
2839 let value = (mantissa as f64) / divisor;
2840 let value = if sign == 0 { -value } else { value };
2841 Ok(SqlValue::Double(value))
2842 } else {
2843 let mut decimal =
2844 Decimal::from_i128_with_scale(mantissa as i128, scale as u32);
2845 if sign == 0 {
2846 decimal.set_sign_negative(true);
2847 }
2848 Ok(SqlValue::Decimal(decimal))
2849 }
2850 }
2851 #[cfg(not(feature = "decimal"))]
2852 {
2853 let divisor = 10f64.powi(scale as i32);
2854 let value = (mantissa as f64) / divisor;
2855 let value = if sign == 0 { -value } else { value };
2856 Ok(SqlValue::Double(value))
2857 }
2858 }
2859 0x24 => {
2860 buf.advance(prop_count);
2862 if data_len < 16 {
2863 return Ok(SqlValue::Null);
2864 }
2865 let mut guid_bytes = [0u8; 16];
2866 for byte in &mut guid_bytes {
2867 *byte = buf.get_u8();
2868 }
2869 Ok(SqlValue::Binary(bytes::Bytes::copy_from_slice(&guid_bytes)))
2870 }
2871 0x28 => {
2872 buf.advance(prop_count);
2874 #[cfg(feature = "chrono")]
2875 {
2876 if data_len < 3 {
2877 return Ok(SqlValue::Null);
2878 }
2879 let mut date_bytes = [0u8; 4];
2880 date_bytes[0] = buf.get_u8();
2881 date_bytes[1] = buf.get_u8();
2882 date_bytes[2] = buf.get_u8();
2883 let days = u32::from_le_bytes(date_bytes);
2884 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2885 let date = base + chrono::Duration::days(days as i64);
2886 Ok(SqlValue::Date(date))
2887 }
2888 #[cfg(not(feature = "chrono"))]
2889 {
2890 buf.advance(data_len);
2891 Ok(SqlValue::Null)
2892 }
2893 }
2894 0xA7 | 0x2F | 0x27 => {
2895 let collation = if prop_count >= 5 && buf.remaining() >= 5 {
2898 let lcid = buf.get_u32_le();
2899 let sort_id = buf.get_u8();
2900 buf.advance(prop_count.saturating_sub(5)); Some(Collation { lcid, sort_id })
2902 } else {
2903 buf.advance(prop_count);
2904 None
2905 };
2906 if data_len == 0 {
2907 return Ok(SqlValue::String(String::new()));
2908 }
2909 let data = &buf[..data_len];
2910 let s = Self::decode_varchar_string(data, collation.as_ref());
2912 buf.advance(data_len);
2913 Ok(SqlValue::String(s))
2914 }
2915 0xE7 | 0xEF => {
2916 buf.advance(prop_count);
2918 if data_len == 0 {
2919 return Ok(SqlValue::String(String::new()));
2920 }
2921 let utf16: Vec<u16> = buf[..data_len]
2923 .chunks_exact(2)
2924 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2925 .collect();
2926 buf.advance(data_len);
2927 let s = String::from_utf16(&utf16).map_err(|_| {
2928 Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into())
2929 })?;
2930 Ok(SqlValue::String(s))
2931 }
2932 0xA5 | 0x2D | 0x25 => {
2933 buf.advance(prop_count);
2935 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2936 buf.advance(data_len);
2937 Ok(SqlValue::Binary(data))
2938 }
2939 _ => {
2940 buf.advance(prop_count);
2942 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2943 buf.advance(data_len);
2944 Ok(SqlValue::Binary(data))
2945 }
2946 }
2947 }
2948
2949 fn time_bytes_for_scale(scale: u8) -> usize {
2951 match scale {
2952 0..=2 => 3,
2953 3..=4 => 4,
2954 5..=7 => 5,
2955 _ => 5, }
2957 }
2958
2959 #[cfg(feature = "chrono")]
2961 fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
2962 let nanos = match scale {
2972 0 => intervals * 1_000_000_000,
2973 1 => intervals * 100_000_000,
2974 2 => intervals * 10_000_000,
2975 3 => intervals * 1_000_000,
2976 4 => intervals * 100_000,
2977 5 => intervals * 10_000,
2978 6 => intervals * 1_000,
2979 7 => intervals * 100,
2980 _ => intervals * 100,
2981 };
2982
2983 let secs = (nanos / 1_000_000_000) as u32;
2984 let nano_part = (nanos % 1_000_000_000) as u32;
2985
2986 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
2987 .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
2988 }
2989
2990 async fn read_execute_result(&mut self) -> Result<u64> {
2992 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2993
2994 let message = match connection {
2995 #[cfg(feature = "tls")]
2996 ConnectionHandle::Tls(conn) => conn
2997 .read_message()
2998 .await
2999 .map_err(|e| Error::Protocol(e.to_string()))?,
3000 #[cfg(feature = "tls")]
3001 ConnectionHandle::TlsPrelogin(conn) => conn
3002 .read_message()
3003 .await
3004 .map_err(|e| Error::Protocol(e.to_string()))?,
3005 ConnectionHandle::Plain(conn) => conn
3006 .read_message()
3007 .await
3008 .map_err(|e| Error::Protocol(e.to_string()))?,
3009 }
3010 .ok_or(Error::ConnectionClosed)?;
3011
3012 let mut parser = TokenParser::new(message.payload);
3013 let mut rows_affected = 0u64;
3014 let mut current_metadata: Option<ColMetaData> = None;
3015
3016 loop {
3017 let token = parser
3019 .next_token_with_metadata(current_metadata.as_ref())
3020 .map_err(|e| Error::Protocol(e.to_string()))?;
3021
3022 let Some(token) = token else {
3023 break;
3024 };
3025
3026 match token {
3027 Token::ColMetaData(meta) => {
3028 current_metadata = Some(meta);
3030 }
3031 Token::Row(_) | Token::NbcRow(_) => {
3032 }
3035 Token::Done(done) => {
3036 if done.status.error {
3037 return Err(Error::Query("execution failed".to_string()));
3038 }
3039 if done.status.count {
3040 rows_affected += done.row_count;
3042 }
3043 if !done.status.more {
3046 break;
3047 }
3048 }
3049 Token::DoneProc(done) => {
3050 if done.status.count {
3051 rows_affected += done.row_count;
3052 }
3053 }
3054 Token::DoneInProc(done) => {
3055 if done.status.count {
3056 rows_affected += done.row_count;
3057 }
3058 }
3059 Token::Error(err) => {
3060 return Err(Error::Server {
3061 number: err.number,
3062 state: err.state,
3063 class: err.class,
3064 message: err.message.clone(),
3065 server: if err.server.is_empty() {
3066 None
3067 } else {
3068 Some(err.server.clone())
3069 },
3070 procedure: if err.procedure.is_empty() {
3071 None
3072 } else {
3073 Some(err.procedure.clone())
3074 },
3075 line: err.line as u32,
3076 });
3077 }
3078 Token::Info(info) => {
3079 tracing::info!(
3080 number = info.number,
3081 message = %info.message,
3082 "server info message"
3083 );
3084 }
3085 Token::EnvChange(env) => {
3086 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
3090 }
3091 _ => {}
3092 }
3093 }
3094
3095 Ok(rows_affected)
3096 }
3097
3098 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
3104 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
3105
3106 let message = match connection {
3107 #[cfg(feature = "tls")]
3108 ConnectionHandle::Tls(conn) => conn
3109 .read_message()
3110 .await
3111 .map_err(|e| Error::Protocol(e.to_string()))?,
3112 #[cfg(feature = "tls")]
3113 ConnectionHandle::TlsPrelogin(conn) => conn
3114 .read_message()
3115 .await
3116 .map_err(|e| Error::Protocol(e.to_string()))?,
3117 ConnectionHandle::Plain(conn) => conn
3118 .read_message()
3119 .await
3120 .map_err(|e| Error::Protocol(e.to_string()))?,
3121 }
3122 .ok_or(Error::ConnectionClosed)?;
3123
3124 let mut parser = TokenParser::new(message.payload);
3125 let mut transaction_descriptor: u64 = 0;
3126
3127 loop {
3128 let token = parser
3129 .next_token()
3130 .map_err(|e| Error::Protocol(e.to_string()))?;
3131
3132 let Some(token) = token else {
3133 break;
3134 };
3135
3136 match token {
3137 Token::EnvChange(env) => {
3138 if env.env_type == EnvChangeType::BeginTransaction {
3139 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
3142 {
3143 if data.len() >= 8 {
3144 transaction_descriptor = u64::from_le_bytes([
3145 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
3146 data[7],
3147 ]);
3148 tracing::debug!(
3149 transaction_descriptor =
3150 format!("0x{:016X}", transaction_descriptor),
3151 "transaction begun"
3152 );
3153 }
3154 }
3155 }
3156 }
3157 Token::Done(done) => {
3158 if done.status.error {
3159 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
3160 }
3161 break;
3162 }
3163 Token::Error(err) => {
3164 return Err(Error::Server {
3165 number: err.number,
3166 state: err.state,
3167 class: err.class,
3168 message: err.message.clone(),
3169 server: if err.server.is_empty() {
3170 None
3171 } else {
3172 Some(err.server.clone())
3173 },
3174 procedure: if err.procedure.is_empty() {
3175 None
3176 } else {
3177 Some(err.procedure.clone())
3178 },
3179 line: err.line as u32,
3180 });
3181 }
3182 Token::Info(info) => {
3183 tracing::info!(
3184 number = info.number,
3185 message = %info.message,
3186 "server info message"
3187 );
3188 }
3189 _ => {}
3190 }
3191 }
3192
3193 Ok(transaction_descriptor)
3194 }
3195}
3196
3197impl Client<Ready> {
3198 pub fn mark_needs_reset(&mut self) {
3209 self.needs_reset = true;
3210 }
3211
3212 #[must_use]
3217 pub fn needs_reset(&self) -> bool {
3218 self.needs_reset
3219 }
3220
3221 pub async fn query<'a>(
3246 &'a mut self,
3247 sql: &str,
3248 params: &[&(dyn crate::ToSql + Sync)],
3249 ) -> Result<QueryStream<'a>> {
3250 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
3251
3252 #[cfg(feature = "otel")]
3253 let instrumentation = self.instrumentation.clone();
3254 #[cfg(feature = "otel")]
3255 let mut span = instrumentation.query_span(sql);
3256
3257 let result = async {
3258 if params.is_empty() {
3259 self.send_sql_batch(sql).await?;
3261 } else {
3262 let rpc_params = Self::convert_params(params)?;
3264 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3265 self.send_rpc(&rpc).await?;
3266 }
3267
3268 self.read_query_response().await
3270 }
3271 .await;
3272
3273 #[cfg(feature = "otel")]
3274 match &result {
3275 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3276 Err(e) => InstrumentationContext::record_error(&mut span, e),
3277 }
3278
3279 #[cfg(feature = "otel")]
3281 drop(span);
3282
3283 let (columns, rows) = result?;
3284 Ok(QueryStream::new(columns, rows))
3285 }
3286
3287 pub async fn query_with_timeout<'a>(
3314 &'a mut self,
3315 sql: &str,
3316 params: &[&(dyn crate::ToSql + Sync)],
3317 timeout_duration: std::time::Duration,
3318 ) -> Result<QueryStream<'a>> {
3319 timeout(timeout_duration, self.query(sql, params))
3320 .await
3321 .map_err(|_| Error::CommandTimeout)?
3322 }
3323
3324 pub async fn query_multiple<'a>(
3351 &'a mut self,
3352 sql: &str,
3353 params: &[&(dyn crate::ToSql + Sync)],
3354 ) -> Result<MultiResultStream<'a>> {
3355 tracing::debug!(
3356 sql = sql,
3357 params_count = params.len(),
3358 "executing multi-result query"
3359 );
3360
3361 if params.is_empty() {
3362 self.send_sql_batch(sql).await?;
3364 } else {
3365 let rpc_params = Self::convert_params(params)?;
3367 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3368 self.send_rpc(&rpc).await?;
3369 }
3370
3371 let result_sets = self.read_multi_result_response().await?;
3373 Ok(MultiResultStream::new(result_sets))
3374 }
3375
3376 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
3378 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
3379
3380 let message = match connection {
3381 #[cfg(feature = "tls")]
3382 ConnectionHandle::Tls(conn) => conn
3383 .read_message()
3384 .await
3385 .map_err(|e| Error::Protocol(e.to_string()))?,
3386 #[cfg(feature = "tls")]
3387 ConnectionHandle::TlsPrelogin(conn) => conn
3388 .read_message()
3389 .await
3390 .map_err(|e| Error::Protocol(e.to_string()))?,
3391 ConnectionHandle::Plain(conn) => conn
3392 .read_message()
3393 .await
3394 .map_err(|e| Error::Protocol(e.to_string()))?,
3395 }
3396 .ok_or(Error::ConnectionClosed)?;
3397
3398 let mut parser = TokenParser::new(message.payload);
3399 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
3400 let mut current_columns: Vec<crate::row::Column> = Vec::new();
3401 let mut current_rows: Vec<crate::row::Row> = Vec::new();
3402 let mut protocol_metadata: Option<ColMetaData> = None;
3403
3404 loop {
3405 let token = parser
3406 .next_token_with_metadata(protocol_metadata.as_ref())
3407 .map_err(|e| Error::Protocol(e.to_string()))?;
3408
3409 let Some(token) = token else {
3410 break;
3411 };
3412
3413 match token {
3414 Token::ColMetaData(meta) => {
3415 if !current_columns.is_empty() {
3417 result_sets.push(crate::stream::ResultSet::new(
3418 std::mem::take(&mut current_columns),
3419 std::mem::take(&mut current_rows),
3420 ));
3421 }
3422
3423 current_columns = meta
3425 .columns
3426 .iter()
3427 .enumerate()
3428 .map(|(i, col)| {
3429 let type_name = format!("{:?}", col.type_id);
3430 let mut column = crate::row::Column::new(&col.name, i, type_name)
3431 .with_nullable(col.flags & 0x01 != 0);
3432
3433 if let Some(max_len) = col.type_info.max_length {
3434 column = column.with_max_length(max_len);
3435 }
3436 if let (Some(prec), Some(scale)) =
3437 (col.type_info.precision, col.type_info.scale)
3438 {
3439 column = column.with_precision_scale(prec, scale);
3440 }
3441 if let Some(collation) = col.type_info.collation {
3444 column = column.with_collation(collation);
3445 }
3446 column
3447 })
3448 .collect();
3449
3450 tracing::debug!(
3451 columns = current_columns.len(),
3452 result_set = result_sets.len(),
3453 "received column metadata for result set"
3454 );
3455 protocol_metadata = Some(meta);
3456 }
3457 Token::Row(raw_row) => {
3458 if let Some(ref meta) = protocol_metadata {
3459 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
3460 current_rows.push(row);
3461 }
3462 }
3463 Token::NbcRow(nbc_row) => {
3464 if let Some(ref meta) = protocol_metadata {
3465 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
3466 current_rows.push(row);
3467 }
3468 }
3469 Token::Error(err) => {
3470 return Err(Error::Server {
3471 number: err.number,
3472 state: err.state,
3473 class: err.class,
3474 message: err.message.clone(),
3475 server: if err.server.is_empty() {
3476 None
3477 } else {
3478 Some(err.server.clone())
3479 },
3480 procedure: if err.procedure.is_empty() {
3481 None
3482 } else {
3483 Some(err.procedure.clone())
3484 },
3485 line: err.line as u32,
3486 });
3487 }
3488 Token::Done(done) => {
3489 if done.status.error {
3490 return Err(Error::Query("query failed".to_string()));
3491 }
3492
3493 if !current_columns.is_empty() {
3495 result_sets.push(crate::stream::ResultSet::new(
3496 std::mem::take(&mut current_columns),
3497 std::mem::take(&mut current_rows),
3498 ));
3499 protocol_metadata = None;
3500 }
3501
3502 if !done.status.more {
3504 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
3505 break;
3506 }
3507 }
3508 Token::DoneInProc(done) => {
3509 if done.status.error {
3510 return Err(Error::Query("query failed".to_string()));
3511 }
3512
3513 if !current_columns.is_empty() {
3515 result_sets.push(crate::stream::ResultSet::new(
3516 std::mem::take(&mut current_columns),
3517 std::mem::take(&mut current_rows),
3518 ));
3519 protocol_metadata = None;
3520 }
3521
3522 if !done.status.more {
3524 }
3526 }
3527 Token::DoneProc(done) => {
3528 if done.status.error {
3529 return Err(Error::Query("query failed".to_string()));
3530 }
3531 }
3533 Token::Info(info) => {
3534 tracing::debug!(
3535 number = info.number,
3536 message = %info.message,
3537 "server info message"
3538 );
3539 }
3540 _ => {}
3541 }
3542 }
3543
3544 if !current_columns.is_empty() {
3546 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
3547 }
3548
3549 Ok(result_sets)
3550 }
3551
3552 pub async fn execute(
3556 &mut self,
3557 sql: &str,
3558 params: &[&(dyn crate::ToSql + Sync)],
3559 ) -> Result<u64> {
3560 tracing::debug!(
3561 sql = sql,
3562 params_count = params.len(),
3563 "executing statement"
3564 );
3565
3566 #[cfg(feature = "otel")]
3567 let instrumentation = self.instrumentation.clone();
3568 #[cfg(feature = "otel")]
3569 let mut span = instrumentation.query_span(sql);
3570
3571 let result = async {
3572 if params.is_empty() {
3573 self.send_sql_batch(sql).await?;
3575 } else {
3576 let rpc_params = Self::convert_params(params)?;
3578 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3579 self.send_rpc(&rpc).await?;
3580 }
3581
3582 self.read_execute_result().await
3584 }
3585 .await;
3586
3587 #[cfg(feature = "otel")]
3588 match &result {
3589 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3590 Err(e) => InstrumentationContext::record_error(&mut span, e),
3591 }
3592
3593 #[cfg(feature = "otel")]
3595 drop(span);
3596
3597 result
3598 }
3599
3600 pub async fn execute_with_timeout(
3627 &mut self,
3628 sql: &str,
3629 params: &[&(dyn crate::ToSql + Sync)],
3630 timeout_duration: std::time::Duration,
3631 ) -> Result<u64> {
3632 timeout(timeout_duration, self.execute(sql, params))
3633 .await
3634 .map_err(|_| Error::CommandTimeout)?
3635 }
3636
3637 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
3644 tracing::debug!("beginning transaction");
3645
3646 #[cfg(feature = "otel")]
3647 let instrumentation = self.instrumentation.clone();
3648 #[cfg(feature = "otel")]
3649 let mut span = instrumentation.transaction_span("BEGIN");
3650
3651 let result = async {
3653 self.send_sql_batch("BEGIN TRANSACTION").await?;
3654 self.read_transaction_begin_result().await
3655 }
3656 .await;
3657
3658 #[cfg(feature = "otel")]
3659 match &result {
3660 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3661 Err(e) => InstrumentationContext::record_error(&mut span, e),
3662 }
3663
3664 #[cfg(feature = "otel")]
3666 drop(span);
3667
3668 let transaction_descriptor = result?;
3669
3670 Ok(Client {
3671 config: self.config,
3672 _state: PhantomData,
3673 connection: self.connection,
3674 server_version: self.server_version,
3675 current_database: self.current_database,
3676 statement_cache: self.statement_cache,
3677 transaction_descriptor, needs_reset: self.needs_reset,
3679 #[cfg(feature = "otel")]
3680 instrumentation: self.instrumentation,
3681 })
3682 }
3683
3684 pub async fn begin_transaction_with_isolation(
3699 mut self,
3700 isolation_level: crate::transaction::IsolationLevel,
3701 ) -> Result<Client<InTransaction>> {
3702 tracing::debug!(
3703 isolation_level = %isolation_level.name(),
3704 "beginning transaction with isolation level"
3705 );
3706
3707 #[cfg(feature = "otel")]
3708 let instrumentation = self.instrumentation.clone();
3709 #[cfg(feature = "otel")]
3710 let mut span = instrumentation.transaction_span("BEGIN");
3711
3712 let result = async {
3714 self.send_sql_batch(isolation_level.as_sql()).await?;
3715 self.read_execute_result().await?;
3716
3717 self.send_sql_batch("BEGIN TRANSACTION").await?;
3719 self.read_transaction_begin_result().await
3720 }
3721 .await;
3722
3723 #[cfg(feature = "otel")]
3724 match &result {
3725 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3726 Err(e) => InstrumentationContext::record_error(&mut span, e),
3727 }
3728
3729 #[cfg(feature = "otel")]
3730 drop(span);
3731
3732 let transaction_descriptor = result?;
3733
3734 Ok(Client {
3735 config: self.config,
3736 _state: PhantomData,
3737 connection: self.connection,
3738 server_version: self.server_version,
3739 current_database: self.current_database,
3740 statement_cache: self.statement_cache,
3741 transaction_descriptor,
3742 needs_reset: self.needs_reset,
3743 #[cfg(feature = "otel")]
3744 instrumentation: self.instrumentation,
3745 })
3746 }
3747
3748 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
3753 tracing::debug!(sql = sql, "executing simple query");
3754
3755 self.send_sql_batch(sql).await?;
3757
3758 let _ = self.read_execute_result().await?;
3760
3761 Ok(())
3762 }
3763
3764 pub async fn close(self) -> Result<()> {
3766 tracing::debug!("closing connection");
3767 Ok(())
3768 }
3769
3770 #[must_use]
3772 pub fn database(&self) -> Option<&str> {
3773 self.config.database.as_deref()
3774 }
3775
3776 #[must_use]
3778 pub fn host(&self) -> &str {
3779 &self.config.host
3780 }
3781
3782 #[must_use]
3784 pub fn port(&self) -> u16 {
3785 self.config.port
3786 }
3787
3788 #[must_use]
3807 pub fn is_in_transaction(&self) -> bool {
3808 self.transaction_descriptor != 0
3809 }
3810
3811 #[must_use]
3833 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3834 let connection = self
3835 .connection
3836 .as_ref()
3837 .expect("connection should be present");
3838 match connection {
3839 #[cfg(feature = "tls")]
3840 ConnectionHandle::Tls(conn) => {
3841 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3842 }
3843 #[cfg(feature = "tls")]
3844 ConnectionHandle::TlsPrelogin(conn) => {
3845 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3846 }
3847 ConnectionHandle::Plain(conn) => {
3848 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3849 }
3850 }
3851 }
3852}
3853
3854impl Client<InTransaction> {
3855 pub async fn query<'a>(
3859 &'a mut self,
3860 sql: &str,
3861 params: &[&(dyn crate::ToSql + Sync)],
3862 ) -> Result<QueryStream<'a>> {
3863 tracing::debug!(
3864 sql = sql,
3865 params_count = params.len(),
3866 "executing query in transaction"
3867 );
3868
3869 #[cfg(feature = "otel")]
3870 let instrumentation = self.instrumentation.clone();
3871 #[cfg(feature = "otel")]
3872 let mut span = instrumentation.query_span(sql);
3873
3874 let result = async {
3875 if params.is_empty() {
3876 self.send_sql_batch(sql).await?;
3878 } else {
3879 let rpc_params = Self::convert_params(params)?;
3881 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3882 self.send_rpc(&rpc).await?;
3883 }
3884
3885 self.read_query_response().await
3887 }
3888 .await;
3889
3890 #[cfg(feature = "otel")]
3891 match &result {
3892 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3893 Err(e) => InstrumentationContext::record_error(&mut span, e),
3894 }
3895
3896 #[cfg(feature = "otel")]
3898 drop(span);
3899
3900 let (columns, rows) = result?;
3901 Ok(QueryStream::new(columns, rows))
3902 }
3903
3904 pub async fn execute(
3908 &mut self,
3909 sql: &str,
3910 params: &[&(dyn crate::ToSql + Sync)],
3911 ) -> Result<u64> {
3912 tracing::debug!(
3913 sql = sql,
3914 params_count = params.len(),
3915 "executing statement in transaction"
3916 );
3917
3918 #[cfg(feature = "otel")]
3919 let instrumentation = self.instrumentation.clone();
3920 #[cfg(feature = "otel")]
3921 let mut span = instrumentation.query_span(sql);
3922
3923 let result = async {
3924 if params.is_empty() {
3925 self.send_sql_batch(sql).await?;
3927 } else {
3928 let rpc_params = Self::convert_params(params)?;
3930 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3931 self.send_rpc(&rpc).await?;
3932 }
3933
3934 self.read_execute_result().await
3936 }
3937 .await;
3938
3939 #[cfg(feature = "otel")]
3940 match &result {
3941 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3942 Err(e) => InstrumentationContext::record_error(&mut span, e),
3943 }
3944
3945 #[cfg(feature = "otel")]
3947 drop(span);
3948
3949 result
3950 }
3951
3952 pub async fn query_with_timeout<'a>(
3956 &'a mut self,
3957 sql: &str,
3958 params: &[&(dyn crate::ToSql + Sync)],
3959 timeout_duration: std::time::Duration,
3960 ) -> Result<QueryStream<'a>> {
3961 timeout(timeout_duration, self.query(sql, params))
3962 .await
3963 .map_err(|_| Error::CommandTimeout)?
3964 }
3965
3966 pub async fn execute_with_timeout(
3970 &mut self,
3971 sql: &str,
3972 params: &[&(dyn crate::ToSql + Sync)],
3973 timeout_duration: std::time::Duration,
3974 ) -> Result<u64> {
3975 timeout(timeout_duration, self.execute(sql, params))
3976 .await
3977 .map_err(|_| Error::CommandTimeout)?
3978 }
3979
3980 pub async fn commit(mut self) -> Result<Client<Ready>> {
3984 tracing::debug!("committing transaction");
3985
3986 #[cfg(feature = "otel")]
3987 let instrumentation = self.instrumentation.clone();
3988 #[cfg(feature = "otel")]
3989 let mut span = instrumentation.transaction_span("COMMIT");
3990
3991 let result = async {
3993 self.send_sql_batch("COMMIT TRANSACTION").await?;
3994 self.read_execute_result().await
3995 }
3996 .await;
3997
3998 #[cfg(feature = "otel")]
3999 match &result {
4000 Ok(_) => InstrumentationContext::record_success(&mut span, None),
4001 Err(e) => InstrumentationContext::record_error(&mut span, e),
4002 }
4003
4004 #[cfg(feature = "otel")]
4006 drop(span);
4007
4008 result?;
4009
4010 Ok(Client {
4011 config: self.config,
4012 _state: PhantomData,
4013 connection: self.connection,
4014 server_version: self.server_version,
4015 current_database: self.current_database,
4016 statement_cache: self.statement_cache,
4017 transaction_descriptor: 0, needs_reset: self.needs_reset,
4019 #[cfg(feature = "otel")]
4020 instrumentation: self.instrumentation,
4021 })
4022 }
4023
4024 pub async fn rollback(mut self) -> Result<Client<Ready>> {
4028 tracing::debug!("rolling back transaction");
4029
4030 #[cfg(feature = "otel")]
4031 let instrumentation = self.instrumentation.clone();
4032 #[cfg(feature = "otel")]
4033 let mut span = instrumentation.transaction_span("ROLLBACK");
4034
4035 let result = async {
4037 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
4038 self.read_execute_result().await
4039 }
4040 .await;
4041
4042 #[cfg(feature = "otel")]
4043 match &result {
4044 Ok(_) => InstrumentationContext::record_success(&mut span, None),
4045 Err(e) => InstrumentationContext::record_error(&mut span, e),
4046 }
4047
4048 #[cfg(feature = "otel")]
4050 drop(span);
4051
4052 result?;
4053
4054 Ok(Client {
4055 config: self.config,
4056 _state: PhantomData,
4057 connection: self.connection,
4058 server_version: self.server_version,
4059 current_database: self.current_database,
4060 statement_cache: self.statement_cache,
4061 transaction_descriptor: 0, needs_reset: self.needs_reset,
4063 #[cfg(feature = "otel")]
4064 instrumentation: self.instrumentation,
4065 })
4066 }
4067
4068 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
4085 validate_identifier(name)?;
4086 tracing::debug!(name = name, "creating savepoint");
4087
4088 let sql = format!("SAVE TRANSACTION {}", name);
4091 self.send_sql_batch(&sql).await?;
4092 self.read_execute_result().await?;
4093
4094 Ok(SavePoint::new(name.to_string()))
4095 }
4096
4097 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
4112 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
4113
4114 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
4117 self.send_sql_batch(&sql).await?;
4118 self.read_execute_result().await?;
4119
4120 Ok(())
4121 }
4122
4123 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
4129 tracing::debug!(name = savepoint.name(), "releasing savepoint");
4130
4131 drop(savepoint);
4135 Ok(())
4136 }
4137
4138 #[must_use]
4142 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
4143 let connection = self
4144 .connection
4145 .as_ref()
4146 .expect("connection should be present");
4147 match connection {
4148 #[cfg(feature = "tls")]
4149 ConnectionHandle::Tls(conn) => {
4150 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
4151 }
4152 #[cfg(feature = "tls")]
4153 ConnectionHandle::TlsPrelogin(conn) => {
4154 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
4155 }
4156 ConnectionHandle::Plain(conn) => {
4157 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
4158 }
4159 }
4160 }
4161}
4162
4163fn validate_identifier(name: &str) -> Result<()> {
4165 use once_cell::sync::Lazy;
4166 use regex::Regex;
4167
4168 static IDENTIFIER_RE: Lazy<Regex> =
4169 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
4170
4171 if name.is_empty() {
4172 return Err(Error::InvalidIdentifier(
4173 "identifier cannot be empty".into(),
4174 ));
4175 }
4176
4177 if !IDENTIFIER_RE.is_match(name) {
4178 return Err(Error::InvalidIdentifier(format!(
4179 "invalid identifier '{}': must start with letter/underscore, \
4180 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
4181 name
4182 )));
4183 }
4184
4185 Ok(())
4186}
4187
4188impl<S: ConnectionState> std::fmt::Debug for Client<S> {
4189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
4190 f.debug_struct("Client")
4191 .field("host", &self.config.host)
4192 .field("port", &self.config.port)
4193 .field("database", &self.config.database)
4194 .finish()
4195 }
4196}
4197
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use tds_protocol::token::{ColumnData, TypeInfo};
    use tds_protocol::types::TypeId;

    // ------------------------------------------------------------------
    // validate_identifier
    // ------------------------------------------------------------------

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
    }

    #[test]
    fn test_validate_identifier_allows_special_chars_after_first() {
        assert!(validate_identifier("_").is_ok());
        assert!(validate_identifier("a@b").is_ok());
        assert!(validate_identifier("sp#1").is_ok());
        assert!(validate_identifier("cost$").is_ok());
    }

    #[test]
    fn test_validate_identifier_rejects_special_first_char() {
        assert!(validate_identifier("@var").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$x").is_err());
    }

    #[test]
    fn test_validate_identifier_length_boundary() {
        // The error message promises 1-128 characters: 128 passes, 129 fails.
        assert!(validate_identifier(&"a".repeat(128)).is_ok());
        assert!(validate_identifier(&"a".repeat(129)).is_err());
    }

    #[test]
    fn test_validate_identifier_rejects_quoting_and_non_ascii() {
        // Bracket/quote characters could break out of the SQL the name is
        // spliced into; a trailing newline must not slip past the end anchor;
        // non-ASCII letters are outside the accepted set.
        assert!(validate_identifier("name]").is_err());
        assert!(validate_identifier("x'--").is_err());
        assert!(validate_identifier("name\n").is_err());
        assert!(validate_identifier("tablé").is_err());
    }

    // ------------------------------------------------------------------
    // PLP (partially length-prefixed) value parsing
    // ------------------------------------------------------------------

    /// Builds a PLP-encoded value. The layout is an 8-byte little-endian total
    /// length, then each chunk as a 4-byte little-endian length followed by
    /// its bytes, then a zero-length terminator chunk.
    fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
        let mut data = Vec::new();
        data.extend_from_slice(&total_len.to_le_bytes());

        for chunk in chunks {
            let len = chunk.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(chunk);
        }

        // Terminator chunk.
        data.extend_from_slice(&0u32.to_le_bytes());
        data
    }

    #[test]
    fn test_parse_plp_nvarchar_simple() {
        // "Hello" in UTF-16LE.
        let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&utf16_data]);
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
            _ => panic!("expected String, got {:?}", result),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_null() {
        // An all-ones total length encodes PLP NULL.
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
        assert!(matches!(result, mssql_types::SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_nvarchar_empty() {
        // Zero total length and only the terminator chunk.
        let plp = make_plp_data(0, &[]);
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            mssql_types::SqlValue::String(s) => assert_eq!(s, ""),
            _ => panic!("expected empty String"),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_multi_chunk() {
        // "Hel" + "lo" in UTF-16LE, split across two chunks.
        let chunk1 = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00];
        let chunk2 = [0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&chunk1, &chunk2]);
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_parse_plp_varchar_simple() {
        let data = b"Hello World";
        let plp = make_plp_data(11, &[data]);
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_varchar(&mut buf, None).unwrap();
        match result {
            mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello World"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_parse_plp_varchar_null() {
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_varchar(&mut buf, None).unwrap();
        assert!(matches!(result, mssql_types::SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_varbinary_simple() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let plp = make_plp_data(5, &[&data]);
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
        match result {
            mssql_types::SqlValue::Binary(b) => assert_eq!(&b[..], &[0x01, 0x02, 0x03, 0x04, 0x05]),
            _ => panic!("expected Binary"),
        }
    }

    #[test]
    fn test_parse_plp_varbinary_null() {
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
        assert!(matches!(result, mssql_types::SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_varbinary_large() {
        // Bytes 0..=254 split over three chunks; byte i must equal i.
        let chunk1: Vec<u8> = (0..100u8).collect();
        let chunk2: Vec<u8> = (100..200u8).collect();
        let chunk3: Vec<u8> = (200..255u8).collect();
        let total_len = chunk1.len() + chunk2.len() + chunk3.len();
        let plp = make_plp_data(total_len as u64, &[&chunk1, &chunk2, &chunk3]);
        let mut buf: &[u8] = &plp;

        let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
        match result {
            mssql_types::SqlValue::Binary(b) => {
                assert_eq!(b.len(), 255);
                for (i, &byte) in b.iter().enumerate() {
                    assert_eq!(byte, i as u8);
                }
            }
            _ => panic!("expected Binary"),
        }
    }

    // ------------------------------------------------------------------
    // Row / column value parsing
    // ------------------------------------------------------------------

    /// Encodes one row of (NVARCHAR, INTN(4)) values. The NVARCHAR is a
    /// 2-byte little-endian byte length followed by UTF-16LE code units. The
    /// INTN is a 1-byte length (4) followed by the little-endian `i32`.
    fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
        let mut data = Vec::new();

        let utf16: Vec<u16> = nvarchar_value.encode_utf16().collect();
        let byte_len = (utf16.len() * 2) as u16;
        data.extend_from_slice(&byte_len.to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }

        data.push(4);
        data.extend_from_slice(&int_value.to_le_bytes());

        data
    }

    #[test]
    fn test_parse_row_nvarchar_then_int() {
        let raw_data = make_nvarchar_int_row("World", 42);

        let col0 = ColumnData {
            name: "greeting".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                // 10 bytes = 5 UTF-16 code units.
                max_length: Some(10),
                precision: None,
                scale: None,
                collation: None,
            },
        };

        let col1 = ColumnData {
            name: "number".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
        };

        let mut buf: &[u8] = &raw_data;

        // Columns must be consumed in order from the same buffer.
        let value0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
        match value0 {
            mssql_types::SqlValue::String(s) => assert_eq!(s, "World"),
            _ => panic!("expected String, got {:?}", value0),
        }

        let value1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
        match value1 {
            mssql_types::SqlValue::Int(i) => assert_eq!(i, 42),
            _ => panic!("expected Int, got {:?}", value1),
        }

        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_multiple_types() {
        let mut data = Vec::new();

        // col0: NVARCHAR NULL (length 0xFFFF).
        data.extend_from_slice(&0xFFFFu16.to_le_bytes());

        // col1: INTN(4) = 123.
        data.push(4);
        data.extend_from_slice(&123i32.to_le_bytes());

        // col2: NVARCHAR "Test".
        let utf16: Vec<u16> = "Test".encode_utf16().collect();
        data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }

        // col3: INTN NULL (length 0).
        data.push(0);

        let col0 = ColumnData {
            name: "col0".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        };
        let col1 = ColumnData {
            name: "col1".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
        };
        let col2 = col0.clone();
        let col3 = col1.clone();

        let mut buf: &[u8] = &data;

        let v0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
        assert!(
            matches!(v0, mssql_types::SqlValue::Null),
            "col0 should be Null"
        );

        let v1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
        assert!(
            matches!(v1, mssql_types::SqlValue::Int(123)),
            "col1 should be 123"
        );

        let v2 = Client::<Ready>::parse_column_value(&mut buf, &col2).unwrap();
        match v2 {
            mssql_types::SqlValue::String(s) => assert_eq!(s, "Test"),
            _ => panic!("col2 should be 'Test'"),
        }

        let v3 = Client::<Ready>::parse_column_value(&mut buf, &col3).unwrap();
        assert!(
            matches!(v3, mssql_types::SqlValue::Null),
            "col3 should be Null"
        );

        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_with_unicode() {
        let test_str = "Héllo Wörld 日本語";
        let mut data = Vec::new();

        // NVARCHAR: byte length, then UTF-16LE code units.
        let utf16: Vec<u16> = test_str.encode_utf16().collect();
        data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }

        // INTN(8): length byte 8, then the little-endian i64.
        data.push(8);
        data.extend_from_slice(&9999999999i64.to_le_bytes());

        let col0 = ColumnData {
            name: "text".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        };
        let col1 = ColumnData {
            name: "num".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(8),
                precision: None,
                scale: None,
                collation: None,
            },
        };

        let mut buf: &[u8] = &data;

        let v0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
        match v0 {
            mssql_types::SqlValue::String(s) => assert_eq!(s, test_str),
            _ => panic!("expected String"),
        }

        let v1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
        match v1 {
            mssql_types::SqlValue::BigInt(i) => assert_eq!(i, 9999999999),
            _ => panic!("expected BigInt"),
        }

        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }
}