1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, Collation, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token,
19 TokenParser,
20};
21#[cfg(feature = "decimal")]
22use tds_protocol::tvp::encode_tvp_decimal;
23use tds_protocol::tvp::{
24 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
25 encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar, encode_tvp_varbinary,
26};
27use tokio::net::TcpStream;
28use tokio::time::timeout;
29
30use crate::config::Config;
31use crate::error::{Error, Result};
32#[cfg(feature = "otel")]
33use crate::instrumentation::InstrumentationContext;
34use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
35use crate::statement_cache::StatementCache;
36use crate::stream::{MultiResultStream, QueryStream};
37use crate::transaction::SavePoint;
38
/// A SQL Server client whose connection state is tracked in its type.
///
/// `S` is a zero-sized marker implementing `ConnectionState` (this file imports
/// `Disconnected`, `Ready` and `InTransaction`); `Client::<Disconnected>::connect`
/// returns a `Client<Ready>`.
pub struct Client<S: ConnectionState> {
    /// Configuration this client connected with. After a followed routing
    /// redirect it holds the redirected host and port.
    config: Config,
    /// Type-level marker for the connection state `S`; stores no data.
    _state: PhantomData<S>,
    /// Underlying TDS transport. `None` makes request methods fail with
    /// `Error::ConnectionClosed`.
    connection: Option<ConnectionHandle>,
    /// TDS version reported in the server's LOGINACK token, if one was received.
    server_version: Option<u32>,
    /// Database reported by a database ENVCHANGE token during login.
    current_database: Option<String>,
    /// Statement cache, created with its default size when login completes.
    statement_cache: StatementCache,
    /// Transaction descriptor sent with every SQL batch and RPC request.
    /// Starts at `0` and is reset to `0` when a commit/rollback ENVCHANGE arrives.
    transaction_descriptor: u64,
    /// When `true`, the next SQL batch or RPC is sent with the RESETCONNECTION
    /// flag and this field is cleared.
    needs_reset: bool,
    /// Instrumentation context built from host, port and database; only present
    /// with the `otel` feature.
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
67
/// The concrete transport behind a `Client`, chosen during connection setup.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this part of
// the file — confirm it is still needed.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TDS 8.0 strict mode: TLS is negotiated directly on the TCP stream before
    /// PreLogin is sent.
    Tls(Connection<TlsStream<TcpStream>>),
    /// TDS 7.x with full encryption: the TLS handshake is carried inside PreLogin
    /// packets (`connect_with_prelogin`) and TLS stays active afterwards.
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP. Used when no encryption was negotiated, or when TLS was torn down
    /// after a login-only-encrypted LOGIN7.
    Plain(Connection<TcpStream>),
}
83
84impl Client<Disconnected> {
85 pub async fn connect(config: Config) -> Result<Client<Ready>> {
96 let max_redirects = config.redirect.max_redirects;
97 let follow_redirects = config.redirect.follow_redirects;
98 let mut attempts = 0;
99 let mut current_config = config;
100
101 loop {
102 attempts += 1;
103 if attempts > max_redirects + 1 {
104 return Err(Error::TooManyRedirects { max: max_redirects });
105 }
106
107 match Self::try_connect(¤t_config).await {
108 Ok(client) => return Ok(client),
109 Err(Error::Routing { host, port }) => {
110 if !follow_redirects {
111 return Err(Error::Routing { host, port });
112 }
113 tracing::info!(
114 host = %host,
115 port = port,
116 attempt = attempts,
117 max_redirects = max_redirects,
118 "following Azure SQL routing redirect"
119 );
120 current_config = current_config.with_host(&host).with_port(port);
121 continue;
122 }
123 Err(e) => return Err(e),
124 }
125 }
126 }
127
128 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
129 tracing::info!(
130 host = %config.host,
131 port = config.port,
132 database = ?config.database,
133 "connecting to SQL Server"
134 );
135
136 let addr = format!("{}:{}", config.host, config.port);
137
138 tracing::debug!("establishing TCP connection to {}", addr);
140 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
141 .await
142 .map_err(|_| Error::ConnectTimeout)?
143 .map_err(|e| Error::Io(Arc::new(e)))?;
144
145 tcp_stream
147 .set_nodelay(true)
148 .map_err(|e| Error::Io(Arc::new(e)))?;
149
150 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
152
153 if tls_mode.is_tls_first() {
155 return Self::connect_tds_8(config, tcp_stream).await;
156 }
157
158 Self::connect_tds_7x(config, tcp_stream).await
160 }
161
162 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
166 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
167
168 let tls_config = TlsConfig::new()
170 .strict_mode(true)
171 .trust_server_certificate(config.trust_server_certificate);
172
173 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
174
175 let tls_stream = timeout(
177 config.timeouts.tls_timeout,
178 tls_connector.connect(tcp_stream, &config.host),
179 )
180 .await
181 .map_err(|_| Error::TlsTimeout)?
182 .map_err(|e| Error::Tls(e.to_string()))?;
183
184 tracing::debug!("TLS handshake completed (strict mode)");
185
186 let mut connection = Connection::new(tls_stream);
188
189 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
191 Self::send_prelogin(&mut connection, &prelogin).await?;
192 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
193
194 let login = Self::build_login7(config);
196 Self::send_login7(&mut connection, &login).await?;
197
198 let (server_version, current_database, routing) =
200 Self::process_login_response(&mut connection).await?;
201
202 if let Some((host, port)) = routing {
204 return Err(Error::Routing { host, port });
205 }
206
207 Ok(Client {
208 config: config.clone(),
209 _state: PhantomData,
210 connection: Some(ConnectionHandle::Tls(connection)),
211 server_version,
212 current_database: current_database.clone(),
213 statement_cache: StatementCache::with_default_size(),
214 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
217 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
218 .with_database(current_database.unwrap_or_default()),
219 })
220 }
221
222 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
230 use bytes::BufMut;
231 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
232 use tokio::io::{AsyncReadExt, AsyncWriteExt};
233
234 tracing::debug!("using TDS 7.x flow (PreLogin first)");
235
236 let client_encryption = if config.no_tls {
239 tracing::warn!(
241 "⚠️ no_tls mode enabled. Connection will be UNENCRYPTED. \
242 Credentials and data will be transmitted in plaintext. \
243 This should only be used for development/testing with legacy SQL Server."
244 );
245 EncryptionLevel::NotSupported
246 } else if config.encrypt {
247 EncryptionLevel::On
248 } else {
249 EncryptionLevel::Off
250 };
251 let prelogin = Self::build_prelogin(config, client_encryption);
252 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
253 let prelogin_bytes = prelogin.encode();
254
255 let header = PacketHeader::new(
257 PacketType::PreLogin,
258 PacketStatus::END_OF_MESSAGE,
259 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
260 );
261
262 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
263 header.encode(&mut packet_buf);
264 packet_buf.put_slice(&prelogin_bytes);
265
266 tcp_stream
267 .write_all(&packet_buf)
268 .await
269 .map_err(|e| Error::Io(Arc::new(e)))?;
270
271 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
273 tcp_stream
274 .read_exact(&mut header_buf)
275 .await
276 .map_err(|e| Error::Io(Arc::new(e)))?;
277
278 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
279 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
280
281 let mut response_buf = vec![0u8; payload_length];
282 tcp_stream
283 .read_exact(&mut response_buf)
284 .await
285 .map_err(|e| Error::Io(Arc::new(e)))?;
286
287 let prelogin_response =
288 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
289
290 let client_tds_version = config.tds_version;
295 if let Some(ref server_version) = prelogin_response.server_version {
296 tracing::debug!(
297 requested_tds_version = %client_tds_version,
298 server_product_version = %server_version,
299 server_product = server_version.product_name(),
300 max_tds_version = %server_version.max_tds_version(),
301 "PreLogin response received"
302 );
303
304 let server_max_tds = server_version.max_tds_version();
306 if server_max_tds < client_tds_version && !client_tds_version.is_tds_8() {
307 tracing::warn!(
308 requested_tds_version = %client_tds_version,
309 server_max_tds_version = %server_max_tds,
310 server_product = server_version.product_name(),
311 "Server supports lower TDS version than requested. \
312 Connection will use server's maximum: {}",
313 server_max_tds
314 );
315 }
316
317 if server_max_tds.is_legacy() {
319 tracing::warn!(
320 server_product = server_version.product_name(),
321 server_max_tds_version = %server_max_tds,
322 "Server uses legacy TDS version. Some features may not be available."
323 );
324 }
325 } else {
326 tracing::debug!(
327 requested_tds_version = %client_tds_version,
328 "PreLogin response received (no version info)"
329 );
330 }
331
332 let server_encryption = prelogin_response.encryption;
334 tracing::debug!(encryption = ?server_encryption, "server encryption level");
335
336 let negotiated_encryption = match (client_encryption, server_encryption) {
342 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
343 EncryptionLevel::NotSupported
344 }
345 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
346 (EncryptionLevel::On, EncryptionLevel::Off)
347 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
348 return Err(Error::Protocol(
349 "Server does not support requested encryption level".to_string(),
350 ));
351 }
352 _ => EncryptionLevel::On,
353 };
354
355 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
358
359 if use_tls {
360 let tls_config =
363 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
364
365 let tls_connector =
366 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
367
368 let mut tls_stream = timeout(
370 config.timeouts.tls_timeout,
371 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
372 )
373 .await
374 .map_err(|_| Error::TlsTimeout)?
375 .map_err(|e| Error::Tls(e.to_string()))?;
376
377 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
378
379 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
381
382 if login_only_encryption {
383 use tokio::io::AsyncWriteExt;
391
392 let login = Self::build_login7(config);
394 let login_payload = login.encode();
395
396 let max_packet = MAX_PACKET_SIZE;
398 let max_payload = max_packet - PACKET_HEADER_SIZE;
399 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
400 let total_chunks = chunks.len();
401
402 for (i, chunk) in chunks.into_iter().enumerate() {
403 let is_last = i == total_chunks - 1;
404 let status = if is_last {
405 PacketStatus::END_OF_MESSAGE
406 } else {
407 PacketStatus::NORMAL
408 };
409
410 let header = PacketHeader::new(
411 PacketType::Tds7Login,
412 status,
413 (PACKET_HEADER_SIZE + chunk.len()) as u16,
414 );
415
416 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
417 header.encode(&mut packet_buf);
418 packet_buf.put_slice(chunk);
419
420 tls_stream
421 .write_all(&packet_buf)
422 .await
423 .map_err(|e| Error::Io(Arc::new(e)))?;
424 }
425
426 tls_stream
428 .flush()
429 .await
430 .map_err(|e| Error::Io(Arc::new(e)))?;
431
432 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
433
434 let (wrapper, _client_conn) = tls_stream.into_inner();
438 let tcp_stream = wrapper.into_inner();
439
440 let mut connection = Connection::new(tcp_stream);
442
443 let (server_version, current_database, routing) =
445 Self::process_login_response(&mut connection).await?;
446
447 if let Some((host, port)) = routing {
449 return Err(Error::Routing { host, port });
450 }
451
452 Ok(Client {
454 config: config.clone(),
455 _state: PhantomData,
456 connection: Some(ConnectionHandle::Plain(connection)),
457 server_version,
458 current_database: current_database.clone(),
459 statement_cache: StatementCache::with_default_size(),
460 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
463 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
464 .with_database(current_database.unwrap_or_default()),
465 })
466 } else {
467 let mut connection = Connection::new(tls_stream);
470
471 let login = Self::build_login7(config);
473 Self::send_login7(&mut connection, &login).await?;
474
475 let (server_version, current_database, routing) =
477 Self::process_login_response(&mut connection).await?;
478
479 if let Some((host, port)) = routing {
481 return Err(Error::Routing { host, port });
482 }
483
484 Ok(Client {
485 config: config.clone(),
486 _state: PhantomData,
487 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
488 server_version,
489 current_database: current_database.clone(),
490 statement_cache: StatementCache::with_default_size(),
491 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
494 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
495 .with_database(current_database.unwrap_or_default()),
496 })
497 }
498 } else {
499 tracing::warn!(
501 "Connecting without TLS encryption. This is insecure and should only be \
502 used for development/testing on trusted networks."
503 );
504
505 let login = Self::build_login7(config);
507 let login_bytes = login.encode();
508 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
509 tracing::debug!(
511 "Login7 fixed header (94 bytes): {:02X?}",
512 &login_bytes[..login_bytes.len().min(94)]
513 );
514 if login_bytes.len() > 94 {
516 tracing::debug!(
517 "Login7 variable data ({} bytes): {:02X?}",
518 login_bytes.len() - 94,
519 &login_bytes[94..]
520 );
521 }
522
523 let login_header = PacketHeader::new(
525 PacketType::Tds7Login,
526 PacketStatus::END_OF_MESSAGE,
527 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
528 )
529 .with_packet_id(1);
530 let mut login_packet_buf =
531 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
532 login_header.encode(&mut login_packet_buf);
533 login_packet_buf.put_slice(&login_bytes);
534
535 tracing::debug!(
536 "Sending Login7 packet: {} bytes total, header: {:02X?}",
537 login_packet_buf.len(),
538 &login_packet_buf[..PACKET_HEADER_SIZE]
539 );
540 tcp_stream
541 .write_all(&login_packet_buf)
542 .await
543 .map_err(|e| Error::Io(Arc::new(e)))?;
544 tcp_stream
545 .flush()
546 .await
547 .map_err(|e| Error::Io(Arc::new(e)))?;
548 tracing::debug!("Login7 sent and flushed over raw TCP");
549
550 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
552 tcp_stream
553 .read_exact(&mut response_header_buf)
554 .await
555 .map_err(|e| Error::Io(Arc::new(e)))?;
556
557 let response_type = response_header_buf[0];
558 let response_length =
559 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
560 tracing::debug!(
561 "Response header: type={:#04X}, length={}",
562 response_type,
563 response_length
564 );
565
566 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
568 let mut response_payload = vec![0u8; payload_length];
569 tcp_stream
570 .read_exact(&mut response_payload)
571 .await
572 .map_err(|e| Error::Io(Arc::new(e)))?;
573 tracing::debug!(
574 "Response payload: {} bytes, first 32: {:02X?}",
575 response_payload.len(),
576 &response_payload[..response_payload.len().min(32)]
577 );
578
579 let connection = Connection::new(tcp_stream);
581
582 let response_bytes = bytes::Bytes::from(response_payload);
584 let mut parser = TokenParser::new(response_bytes);
585 let mut server_version = None;
586 let mut current_database = None;
587 let routing = None;
588
589 while let Some(token) = parser
590 .next_token()
591 .map_err(|e| Error::Protocol(e.to_string()))?
592 {
593 match token {
594 Token::LoginAck(ack) => {
595 tracing::info!(
596 version = ack.tds_version,
597 interface = ack.interface,
598 prog_name = %ack.prog_name,
599 "login acknowledged"
600 );
601 server_version = Some(ack.tds_version);
602 }
603 Token::EnvChange(env) => {
604 Self::process_env_change(&env, &mut current_database, &mut None);
605 }
606 Token::Error(err) => {
607 return Err(Error::Server {
608 number: err.number,
609 state: err.state,
610 class: err.class,
611 message: err.message.clone(),
612 server: if err.server.is_empty() {
613 None
614 } else {
615 Some(err.server.clone())
616 },
617 procedure: if err.procedure.is_empty() {
618 None
619 } else {
620 Some(err.procedure.clone())
621 },
622 line: err.line as u32,
623 });
624 }
625 Token::Info(info) => {
626 tracing::info!(
627 number = info.number,
628 message = %info.message,
629 "server info message"
630 );
631 }
632 Token::Done(done) => {
633 if done.status.error {
634 return Err(Error::Protocol("login failed".to_string()));
635 }
636 break;
637 }
638 _ => {}
639 }
640 }
641
642 if let Some((host, port)) = routing {
644 return Err(Error::Routing { host, port });
645 }
646
647 Ok(Client {
648 config: config.clone(),
649 _state: PhantomData,
650 connection: Some(ConnectionHandle::Plain(connection)),
651 server_version,
652 current_database: current_database.clone(),
653 statement_cache: StatementCache::with_default_size(),
654 transaction_descriptor: 0, needs_reset: false, #[cfg(feature = "otel")]
657 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
658 .with_database(current_database.unwrap_or_default()),
659 })
660 }
661 }
662
663 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
665 let version = if config.strict_mode {
667 tds_protocol::version::TdsVersion::V8_0
668 } else {
669 config.tds_version
670 };
671
672 let mut prelogin = PreLogin::new()
673 .with_version(version)
674 .with_encryption(encryption);
675
676 if config.mars {
677 prelogin = prelogin.with_mars(true);
678 }
679
680 if let Some(ref instance) = config.instance {
681 prelogin = prelogin.with_instance(instance);
682 }
683
684 prelogin
685 }
686
687 fn build_login7(config: &Config) -> Login7 {
689 let version = if config.strict_mode {
691 tds_protocol::version::TdsVersion::V8_0
692 } else {
693 config.tds_version
694 };
695
696 let mut login = Login7::new()
697 .with_tds_version(version)
698 .with_packet_size(config.packet_size as u32)
699 .with_app_name(&config.application_name)
700 .with_server_name(&config.host)
701 .with_hostname(&config.host);
702
703 if let Some(ref database) = config.database {
704 login = login.with_database(database);
705 }
706
707 match &config.credentials {
709 mssql_auth::Credentials::SqlServer { username, password } => {
710 login = login.with_sql_auth(username.as_ref(), password.as_ref());
711 }
712 _ => {}
714 }
715
716 login
717 }
718
719 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
721 where
722 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
723 {
724 let payload = prelogin.encode();
725 let max_packet = MAX_PACKET_SIZE;
726
727 connection
728 .send_message(PacketType::PreLogin, payload, max_packet)
729 .await
730 .map_err(|e| Error::Protocol(e.to_string()))
731 }
732
733 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
735 where
736 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
737 {
738 let message = connection
739 .read_message()
740 .await
741 .map_err(|e| Error::Protocol(e.to_string()))?
742 .ok_or(Error::ConnectionClosed)?;
743
744 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
745 }
746
747 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
749 where
750 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
751 {
752 let payload = login.encode();
753 let max_packet = MAX_PACKET_SIZE;
754
755 connection
756 .send_message(PacketType::Tds7Login, payload, max_packet)
757 .await
758 .map_err(|e| Error::Protocol(e.to_string()))
759 }
760
761 async fn process_login_response<T>(
765 connection: &mut Connection<T>,
766 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
767 where
768 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
769 {
770 let message = connection
771 .read_message()
772 .await
773 .map_err(|e| Error::Protocol(e.to_string()))?
774 .ok_or(Error::ConnectionClosed)?;
775
776 let response_bytes = message.payload;
777
778 let mut parser = TokenParser::new(response_bytes);
779 let mut server_version = None;
780 let mut database = None;
781 let mut routing = None;
782
783 while let Some(token) = parser
784 .next_token()
785 .map_err(|e| Error::Protocol(e.to_string()))?
786 {
787 match token {
788 Token::LoginAck(ack) => {
789 tracing::info!(
790 version = ack.tds_version,
791 interface = ack.interface,
792 prog_name = %ack.prog_name,
793 "login acknowledged"
794 );
795 server_version = Some(ack.tds_version);
796 }
797 Token::EnvChange(env) => {
798 Self::process_env_change(&env, &mut database, &mut routing);
799 }
800 Token::Error(err) => {
801 return Err(Error::Server {
802 number: err.number,
803 state: err.state,
804 class: err.class,
805 message: err.message.clone(),
806 server: if err.server.is_empty() {
807 None
808 } else {
809 Some(err.server.clone())
810 },
811 procedure: if err.procedure.is_empty() {
812 None
813 } else {
814 Some(err.procedure.clone())
815 },
816 line: err.line as u32,
817 });
818 }
819 Token::Info(info) => {
820 tracing::info!(
821 number = info.number,
822 message = %info.message,
823 "server info message"
824 );
825 }
826 Token::Done(done) => {
827 if done.status.error {
828 return Err(Error::Protocol("login failed".to_string()));
829 }
830 break;
831 }
832 _ => {}
833 }
834 }
835
836 Ok((server_version, database, routing))
837 }
838
839 fn process_env_change(
841 env: &EnvChange,
842 database: &mut Option<String>,
843 routing: &mut Option<(String, u16)>,
844 ) {
845 use tds_protocol::token::EnvChangeValue;
846
847 match env.env_type {
848 EnvChangeType::Database => {
849 if let EnvChangeValue::String(ref new_value) = env.new_value {
850 tracing::debug!(database = %new_value, "database changed");
851 *database = Some(new_value.clone());
852 }
853 }
854 EnvChangeType::Routing => {
855 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
856 tracing::info!(host = %host, port = port, "routing redirect received");
857 *routing = Some((host.clone(), port));
858 }
859 }
860 _ => {
861 if let EnvChangeValue::String(ref new_value) = env.new_value {
862 tracing::debug!(
863 env_type = ?env.env_type,
864 new_value = %new_value,
865 "environment change"
866 );
867 }
868 }
869 }
870 }
871}
872
873impl<S: ConnectionState> Client<S> {
875 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
883 use tds_protocol::token::EnvChangeValue;
884
885 match env.env_type {
886 EnvChangeType::BeginTransaction => {
887 if let EnvChangeValue::Binary(ref data) = env.new_value {
888 if data.len() >= 8 {
889 let descriptor = u64::from_le_bytes([
890 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
891 ]);
892 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
893 *transaction_descriptor = descriptor;
894 }
895 }
896 }
897 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
898 tracing::debug!(
899 env_type = ?env.env_type,
900 "transaction ended via raw SQL"
901 );
902 *transaction_descriptor = 0;
903 }
904 _ => {}
905 }
906 }
907
908 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
917 let payload =
918 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
919 let max_packet = self.config.packet_size as usize;
920
921 let reset = self.needs_reset;
923 if reset {
924 self.needs_reset = false; tracing::debug!("sending SQL batch with RESETCONNECTION flag");
926 }
927
928 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
929
930 match connection {
931 ConnectionHandle::Tls(conn) => {
932 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
933 .await
934 .map_err(|e| Error::Protocol(e.to_string()))?;
935 }
936 ConnectionHandle::TlsPrelogin(conn) => {
937 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
938 .await
939 .map_err(|e| Error::Protocol(e.to_string()))?;
940 }
941 ConnectionHandle::Plain(conn) => {
942 conn.send_message_with_reset(PacketType::SqlBatch, payload, max_packet, reset)
943 .await
944 .map_err(|e| Error::Protocol(e.to_string()))?;
945 }
946 }
947
948 Ok(())
949 }
950
951 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
958 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
959 let max_packet = self.config.packet_size as usize;
960
961 let reset = self.needs_reset;
963 if reset {
964 self.needs_reset = false; tracing::debug!("sending RPC with RESETCONNECTION flag");
966 }
967
968 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
969
970 match connection {
971 ConnectionHandle::Tls(conn) => {
972 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
973 .await
974 .map_err(|e| Error::Protocol(e.to_string()))?;
975 }
976 ConnectionHandle::TlsPrelogin(conn) => {
977 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
978 .await
979 .map_err(|e| Error::Protocol(e.to_string()))?;
980 }
981 ConnectionHandle::Plain(conn) => {
982 conn.send_message_with_reset(PacketType::Rpc, payload, max_packet, reset)
983 .await
984 .map_err(|e| Error::Protocol(e.to_string()))?;
985 }
986 }
987
988 Ok(())
989 }
990
991 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
993 use bytes::{BufMut, BytesMut};
994 use mssql_types::SqlValue;
995
996 params
997 .iter()
998 .enumerate()
999 .map(|(i, p)| {
1000 let sql_value = p.to_sql()?;
1001 let name = format!("@p{}", i + 1);
1002
1003 Ok(match sql_value {
1004 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
1005 SqlValue::Bool(v) => {
1006 let mut buf = BytesMut::with_capacity(1);
1007 buf.put_u8(if v { 1 } else { 0 });
1008 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
1009 }
1010 SqlValue::TinyInt(v) => {
1011 let mut buf = BytesMut::with_capacity(1);
1012 buf.put_u8(v);
1013 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
1014 }
1015 SqlValue::SmallInt(v) => {
1016 let mut buf = BytesMut::with_capacity(2);
1017 buf.put_i16_le(v);
1018 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
1019 }
1020 SqlValue::Int(v) => RpcParam::int(&name, v),
1021 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
1022 SqlValue::Float(v) => {
1023 let mut buf = BytesMut::with_capacity(4);
1024 buf.put_f32_le(v);
1025 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
1026 }
1027 SqlValue::Double(v) => {
1028 let mut buf = BytesMut::with_capacity(8);
1029 buf.put_f64_le(v);
1030 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
1031 }
1032 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
1033 SqlValue::Binary(ref b) => {
1034 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
1035 }
1036 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
1037 #[cfg(feature = "uuid")]
1038 SqlValue::Uuid(u) => {
1039 let bytes = u.as_bytes();
1041 let mut buf = BytesMut::with_capacity(16);
1042 buf.put_u32_le(u32::from_be_bytes([
1044 bytes[0], bytes[1], bytes[2], bytes[3],
1045 ]));
1046 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
1047 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
1048 buf.put_slice(&bytes[8..16]);
1049 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
1050 }
1051 #[cfg(feature = "decimal")]
1052 SqlValue::Decimal(d) => {
1053 RpcParam::nvarchar(&name, &d.to_string())
1055 }
1056 #[cfg(feature = "chrono")]
1057 SqlValue::Date(_)
1058 | SqlValue::Time(_)
1059 | SqlValue::DateTime(_)
1060 | SqlValue::DateTimeOffset(_) => {
1061 let s = match &sql_value {
1064 SqlValue::Date(d) => d.to_string(),
1065 SqlValue::Time(t) => t.to_string(),
1066 SqlValue::DateTime(dt) => dt.to_string(),
1067 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
1068 _ => unreachable!(),
1069 };
1070 RpcParam::nvarchar(&name, &s)
1071 }
1072 #[cfg(feature = "json")]
1073 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
1074 SqlValue::Tvp(ref tvp_data) => {
1075 Self::encode_tvp_param(&name, tvp_data)?
1077 }
1078 _ => {
1080 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
1081 from: sql_value.type_name().to_string(),
1082 to: "RPC parameter",
1083 }));
1084 }
1085 })
1086 })
1087 .collect()
1088 }
1089
1090 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
1095 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
1097 .columns
1098 .iter()
1099 .map(|col| {
1100 let wire_type = Self::convert_tvp_column_type(&col.column_type);
1101 TvpWireColumnDef {
1102 wire_type,
1103 flags: TvpColumnFlags {
1104 nullable: col.nullable,
1105 },
1106 }
1107 })
1108 .collect();
1109
1110 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
1112
1113 let mut buf = BytesMut::with_capacity(256);
1115
1116 encoder.encode_metadata(&mut buf);
1118
1119 for row in &tvp_data.rows {
1121 encoder.encode_row(&mut buf, |row_buf| {
1122 for (col_idx, value) in row.iter().enumerate() {
1123 let wire_type = &wire_columns[col_idx].wire_type;
1124 Self::encode_tvp_value(value, wire_type, row_buf);
1125 }
1126 });
1127 }
1128
1129 encoder.encode_end(&mut buf);
1131
1132 let full_type_name = if tvp_data.schema.is_empty() {
1134 tvp_data.type_name.clone()
1135 } else {
1136 format!("{}.{}", tvp_data.schema, tvp_data.type_name)
1137 };
1138
1139 let type_info = RpcTypeInfo::tvp(&full_type_name);
1142
1143 Ok(RpcParam {
1144 name: name.to_string(),
1145 flags: tds_protocol::rpc::ParamFlags::default(),
1146 type_info,
1147 value: Some(buf.freeze()),
1148 })
1149 }
1150
1151 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1153 match col_type {
1154 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1155 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1156 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1157 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1158 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1159 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1160 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1161 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1162 precision: *precision,
1163 scale: *scale,
1164 },
1165 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1166 max_length: *max_length,
1167 },
1168 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1169 max_length: *max_length,
1170 },
1171 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1172 max_length: *max_length,
1173 },
1174 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1175 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1176 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1177 mssql_types::TvpColumnType::DateTime2 { scale } => {
1178 TvpWireType::DateTime2 { scale: *scale }
1179 }
1180 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1181 TvpWireType::DateTimeOffset { scale: *scale }
1182 }
1183 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1184 }
1185 }
1186
1187 fn encode_tvp_value(
1189 value: &mssql_types::SqlValue,
1190 wire_type: &TvpWireType,
1191 buf: &mut BytesMut,
1192 ) {
1193 use mssql_types::SqlValue;
1194
1195 match value {
1196 SqlValue::Null => {
1197 encode_tvp_null(wire_type, buf);
1198 }
1199 SqlValue::Bool(v) => {
1200 encode_tvp_bit(*v, buf);
1201 }
1202 SqlValue::TinyInt(v) => {
1203 encode_tvp_int(*v as i64, 1, buf);
1204 }
1205 SqlValue::SmallInt(v) => {
1206 encode_tvp_int(*v as i64, 2, buf);
1207 }
1208 SqlValue::Int(v) => {
1209 encode_tvp_int(*v as i64, 4, buf);
1210 }
1211 SqlValue::BigInt(v) => {
1212 encode_tvp_int(*v, 8, buf);
1213 }
1214 SqlValue::Float(v) => {
1215 encode_tvp_float(*v as f64, 4, buf);
1216 }
1217 SqlValue::Double(v) => {
1218 encode_tvp_float(*v, 8, buf);
1219 }
1220 SqlValue::String(s) => {
1221 let max_len = match wire_type {
1222 TvpWireType::NVarChar { max_length } => *max_length,
1223 _ => 4000,
1224 };
1225 encode_tvp_nvarchar(s, max_len, buf);
1226 }
1227 SqlValue::Binary(b) => {
1228 let max_len = match wire_type {
1229 TvpWireType::VarBinary { max_length } => *max_length,
1230 _ => 8000,
1231 };
1232 encode_tvp_varbinary(b, max_len, buf);
1233 }
1234 #[cfg(feature = "decimal")]
1235 SqlValue::Decimal(d) => {
1236 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1237 let mantissa = d.mantissa().unsigned_abs();
1238 encode_tvp_decimal(sign, mantissa, buf);
1239 }
1240 #[cfg(feature = "uuid")]
1241 SqlValue::Uuid(u) => {
1242 let bytes = u.as_bytes();
1243 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1244 }
1245 #[cfg(feature = "chrono")]
1246 SqlValue::Date(d) => {
1247 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1249 let days = d.signed_duration_since(base).num_days() as u32;
1250 tds_protocol::tvp::encode_tvp_date(days, buf);
1251 }
1252 #[cfg(feature = "chrono")]
1253 SqlValue::Time(t) => {
1254 use chrono::Timelike;
1255 let nanos =
1256 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1257 let intervals = nanos / 100;
1258 let scale = match wire_type {
1259 TvpWireType::Time { scale } => *scale,
1260 _ => 7,
1261 };
1262 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1263 }
1264 #[cfg(feature = "chrono")]
1265 SqlValue::DateTime(dt) => {
1266 use chrono::Timelike;
1267 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1269 + dt.time().nanosecond() as u64;
1270 let intervals = nanos / 100;
1271 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1273 let days = dt.date().signed_duration_since(base).num_days() as u32;
1274 let scale = match wire_type {
1275 TvpWireType::DateTime2 { scale } => *scale,
1276 _ => 7,
1277 };
1278 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1279 }
1280 #[cfg(feature = "chrono")]
1281 SqlValue::DateTimeOffset(dto) => {
1282 use chrono::{Offset, Timelike};
1283 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1285 + dto.time().nanosecond() as u64;
1286 let intervals = nanos / 100;
1287 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1289 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1290 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1292 let scale = match wire_type {
1293 TvpWireType::DateTimeOffset { scale } => *scale,
1294 _ => 7,
1295 };
1296 tds_protocol::tvp::encode_tvp_datetimeoffset(
1297 intervals,
1298 days,
1299 offset_minutes,
1300 scale,
1301 buf,
1302 );
1303 }
1304 #[cfg(feature = "json")]
1305 SqlValue::Json(j) => {
1306 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1308 }
1309 SqlValue::Xml(s) => {
1310 encode_tvp_nvarchar(s, 0xFFFF, buf);
1312 }
1313 SqlValue::Tvp(_) => {
1314 encode_tvp_null(wire_type, buf);
1316 }
1317 _ => {
1319 encode_tvp_null(wire_type, buf);
1320 }
1321 }
1322 }
1323
1324 async fn read_query_response(
1326 &mut self,
1327 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1328 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1329
1330 let message = match connection {
1331 ConnectionHandle::Tls(conn) => conn
1332 .read_message()
1333 .await
1334 .map_err(|e| Error::Protocol(e.to_string()))?,
1335 ConnectionHandle::TlsPrelogin(conn) => conn
1336 .read_message()
1337 .await
1338 .map_err(|e| Error::Protocol(e.to_string()))?,
1339 ConnectionHandle::Plain(conn) => conn
1340 .read_message()
1341 .await
1342 .map_err(|e| Error::Protocol(e.to_string()))?,
1343 }
1344 .ok_or(Error::ConnectionClosed)?;
1345
1346 let mut parser = TokenParser::new(message.payload);
1347 let mut columns: Vec<crate::row::Column> = Vec::new();
1348 let mut rows: Vec<crate::row::Row> = Vec::new();
1349 let mut protocol_metadata: Option<ColMetaData> = None;
1350
1351 loop {
1352 let token = parser
1354 .next_token_with_metadata(protocol_metadata.as_ref())
1355 .map_err(|e| Error::Protocol(e.to_string()))?;
1356
1357 let Some(token) = token else {
1358 break;
1359 };
1360
1361 match token {
1362 Token::ColMetaData(meta) => {
1363 rows.clear();
1366
1367 columns = meta
1368 .columns
1369 .iter()
1370 .enumerate()
1371 .map(|(i, col)| {
1372 let type_name = format!("{:?}", col.type_id);
1373 let mut column = crate::row::Column::new(&col.name, i, type_name)
1374 .with_nullable(col.flags & 0x01 != 0);
1375
1376 if let Some(max_len) = col.type_info.max_length {
1377 column = column.with_max_length(max_len);
1378 }
1379 if let (Some(prec), Some(scale)) =
1380 (col.type_info.precision, col.type_info.scale)
1381 {
1382 column = column.with_precision_scale(prec, scale);
1383 }
1384 if let Some(collation) = col.type_info.collation {
1387 column = column.with_collation(collation);
1388 }
1389 column
1390 })
1391 .collect();
1392
1393 tracing::debug!(columns = columns.len(), "received column metadata");
1394 protocol_metadata = Some(meta);
1395 }
1396 Token::Row(raw_row) => {
1397 if let Some(ref meta) = protocol_metadata {
1398 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1399 rows.push(row);
1400 }
1401 }
1402 Token::NbcRow(nbc_row) => {
1403 if let Some(ref meta) = protocol_metadata {
1404 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1405 rows.push(row);
1406 }
1407 }
1408 Token::Error(err) => {
1409 return Err(Error::Server {
1410 number: err.number,
1411 state: err.state,
1412 class: err.class,
1413 message: err.message.clone(),
1414 server: if err.server.is_empty() {
1415 None
1416 } else {
1417 Some(err.server.clone())
1418 },
1419 procedure: if err.procedure.is_empty() {
1420 None
1421 } else {
1422 Some(err.procedure.clone())
1423 },
1424 line: err.line as u32,
1425 });
1426 }
1427 Token::Done(done) => {
1428 if done.status.error {
1429 return Err(Error::Query("query failed".to_string()));
1430 }
1431 tracing::debug!(
1432 row_count = done.row_count,
1433 has_more = done.status.more,
1434 "query complete"
1435 );
1436 if !done.status.more {
1439 break;
1440 }
1441 }
1442 Token::DoneProc(done) => {
1443 if done.status.error {
1444 return Err(Error::Query("query failed".to_string()));
1445 }
1446 }
1447 Token::DoneInProc(done) => {
1448 if done.status.error {
1449 return Err(Error::Query("query failed".to_string()));
1450 }
1451 }
1452 Token::Info(info) => {
1453 tracing::debug!(
1454 number = info.number,
1455 message = %info.message,
1456 "server info message"
1457 );
1458 }
1459 Token::EnvChange(env) => {
1460 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
1464 }
1465 _ => {}
1466 }
1467 }
1468
1469 tracing::debug!(
1470 columns = columns.len(),
1471 rows = rows.len(),
1472 "query response parsed"
1473 );
1474 Ok((columns, rows))
1475 }
1476
1477 fn convert_raw_row(
1481 raw: &RawRow,
1482 meta: &ColMetaData,
1483 columns: &[crate::row::Column],
1484 ) -> Result<crate::row::Row> {
1485 let mut values = Vec::with_capacity(meta.columns.len());
1486 let mut buf = raw.data.as_ref();
1487
1488 for col in &meta.columns {
1489 let value = Self::parse_column_value(&mut buf, col)?;
1490 values.push(value);
1491 }
1492
1493 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1494 }
1495
1496 fn convert_nbc_row(
1500 nbc: &NbcRow,
1501 meta: &ColMetaData,
1502 columns: &[crate::row::Column],
1503 ) -> Result<crate::row::Row> {
1504 let mut values = Vec::with_capacity(meta.columns.len());
1505 let mut buf = nbc.data.as_ref();
1506
1507 for (i, col) in meta.columns.iter().enumerate() {
1508 if nbc.is_null(i) {
1509 values.push(mssql_types::SqlValue::Null);
1510 } else {
1511 let value = Self::parse_column_value(&mut buf, col)?;
1512 values.push(value);
1513 }
1514 }
1515
1516 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1517 }
1518
1519 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1521 use bytes::Buf;
1522 use mssql_types::SqlValue;
1523 use tds_protocol::types::TypeId;
1524
1525 let value = match col.type_id {
1526 TypeId::Null => SqlValue::Null,
1528
1529 TypeId::Int1 => {
1531 if buf.remaining() < 1 {
1532 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1533 }
1534 SqlValue::TinyInt(buf.get_u8())
1535 }
1536 TypeId::Bit => {
1537 if buf.remaining() < 1 {
1538 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1539 }
1540 SqlValue::Bool(buf.get_u8() != 0)
1541 }
1542
1543 TypeId::Int2 => {
1545 if buf.remaining() < 2 {
1546 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1547 }
1548 SqlValue::SmallInt(buf.get_i16_le())
1549 }
1550
1551 TypeId::Int4 => {
1553 if buf.remaining() < 4 {
1554 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1555 }
1556 SqlValue::Int(buf.get_i32_le())
1557 }
1558 TypeId::Float4 => {
1559 if buf.remaining() < 4 {
1560 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1561 }
1562 SqlValue::Float(buf.get_f32_le())
1563 }
1564
1565 TypeId::Int8 => {
1567 if buf.remaining() < 8 {
1568 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1569 }
1570 SqlValue::BigInt(buf.get_i64_le())
1571 }
1572 TypeId::Float8 => {
1573 if buf.remaining() < 8 {
1574 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1575 }
1576 SqlValue::Double(buf.get_f64_le())
1577 }
1578 TypeId::Money => {
1579 if buf.remaining() < 8 {
1580 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1581 }
1582 let high = buf.get_i32_le();
1584 let low = buf.get_u32_le();
1585 let cents = ((high as i64) << 32) | (low as i64);
1586 let value = (cents as f64) / 10000.0;
1587 SqlValue::Double(value)
1588 }
1589 TypeId::Money4 => {
1590 if buf.remaining() < 4 {
1591 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1592 }
1593 let cents = buf.get_i32_le();
1594 let value = (cents as f64) / 10000.0;
1595 SqlValue::Double(value)
1596 }
1597
1598 TypeId::IntN => {
1600 if buf.remaining() < 1 {
1601 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1602 }
1603 let len = buf.get_u8();
1604 match len {
1605 0 => SqlValue::Null,
1606 1 => SqlValue::TinyInt(buf.get_u8()),
1607 2 => SqlValue::SmallInt(buf.get_i16_le()),
1608 4 => SqlValue::Int(buf.get_i32_le()),
1609 8 => SqlValue::BigInt(buf.get_i64_le()),
1610 _ => {
1611 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1612 }
1613 }
1614 }
1615 TypeId::FloatN => {
1616 if buf.remaining() < 1 {
1617 return Err(Error::Protocol(
1618 "unexpected EOF reading FloatN length".into(),
1619 ));
1620 }
1621 let len = buf.get_u8();
1622 match len {
1623 0 => SqlValue::Null,
1624 4 => SqlValue::Float(buf.get_f32_le()),
1625 8 => SqlValue::Double(buf.get_f64_le()),
1626 _ => {
1627 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1628 }
1629 }
1630 }
1631 TypeId::BitN => {
1632 if buf.remaining() < 1 {
1633 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1634 }
1635 let len = buf.get_u8();
1636 match len {
1637 0 => SqlValue::Null,
1638 1 => SqlValue::Bool(buf.get_u8() != 0),
1639 _ => {
1640 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1641 }
1642 }
1643 }
1644 TypeId::MoneyN => {
1645 if buf.remaining() < 1 {
1646 return Err(Error::Protocol(
1647 "unexpected EOF reading MoneyN length".into(),
1648 ));
1649 }
1650 let len = buf.get_u8();
1651 match len {
1652 0 => SqlValue::Null,
1653 4 => {
1654 let cents = buf.get_i32_le();
1655 SqlValue::Double((cents as f64) / 10000.0)
1656 }
1657 8 => {
1658 let high = buf.get_i32_le();
1659 let low = buf.get_u32_le();
1660 let cents = ((high as i64) << 32) | (low as i64);
1661 SqlValue::Double((cents as f64) / 10000.0)
1662 }
1663 _ => {
1664 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1665 }
1666 }
1667 }
1668 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1670 if buf.remaining() < 1 {
1671 return Err(Error::Protocol(
1672 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1673 ));
1674 }
1675 let len = buf.get_u8() as usize;
1676 if len == 0 {
1677 SqlValue::Null
1678 } else {
1679 if buf.remaining() < len {
1680 return Err(Error::Protocol(
1681 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1682 ));
1683 }
1684
1685 let sign = buf.get_u8();
1687 let mantissa_len = len - 1;
1688
1689 let mut mantissa_bytes = [0u8; 16];
1691 for i in 0..mantissa_len.min(16) {
1692 mantissa_bytes[i] = buf.get_u8();
1693 }
1694 for _ in 16..mantissa_len {
1696 buf.get_u8();
1697 }
1698
1699 let mantissa = u128::from_le_bytes(mantissa_bytes);
1700 let scale = col.type_info.scale.unwrap_or(0) as u32;
1701
1702 #[cfg(feature = "decimal")]
1703 {
1704 use rust_decimal::Decimal;
1705 if scale > 28 {
1708 let divisor = 10f64.powi(scale as i32);
1710 let value = (mantissa as f64) / divisor;
1711 let value = if sign == 0 { -value } else { value };
1712 SqlValue::Double(value)
1713 } else {
1714 let mut decimal =
1715 Decimal::from_i128_with_scale(mantissa as i128, scale);
1716 if sign == 0 {
1717 decimal.set_sign_negative(true);
1718 }
1719 SqlValue::Decimal(decimal)
1720 }
1721 }
1722
1723 #[cfg(not(feature = "decimal"))]
1724 {
1725 let divisor = 10f64.powi(scale as i32);
1727 let value = (mantissa as f64) / divisor;
1728 let value = if sign == 0 { -value } else { value };
1729 SqlValue::Double(value)
1730 }
1731 }
1732 }
1733
1734 TypeId::DateTimeN => {
1736 if buf.remaining() < 1 {
1737 return Err(Error::Protocol(
1738 "unexpected EOF reading DateTimeN length".into(),
1739 ));
1740 }
1741 let len = buf.get_u8() as usize;
1742 if len == 0 {
1743 SqlValue::Null
1744 } else if buf.remaining() < len {
1745 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1746 } else {
1747 match len {
1748 4 => {
1749 let days = buf.get_u16_le() as i64;
1751 let minutes = buf.get_u16_le() as u32;
1752 #[cfg(feature = "chrono")]
1753 {
1754 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1755 let date = base + chrono::Duration::days(days);
1756 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1757 minutes * 60,
1758 0,
1759 )
1760 .unwrap();
1761 SqlValue::DateTime(date.and_time(time))
1762 }
1763 #[cfg(not(feature = "chrono"))]
1764 {
1765 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1766 }
1767 }
1768 8 => {
1769 let days = buf.get_i32_le() as i64;
1771 let time_300ths = buf.get_u32_le() as u64;
1772 #[cfg(feature = "chrono")]
1773 {
1774 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1775 let date = base + chrono::Duration::days(days);
1776 let total_ms = (time_300ths * 1000) / 300;
1778 let secs = (total_ms / 1000) as u32;
1779 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1780 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1781 secs, nanos,
1782 )
1783 .unwrap();
1784 SqlValue::DateTime(date.and_time(time))
1785 }
1786 #[cfg(not(feature = "chrono"))]
1787 {
1788 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1789 }
1790 }
1791 _ => {
1792 return Err(Error::Protocol(format!(
1793 "invalid DateTimeN length: {len}"
1794 )));
1795 }
1796 }
1797 }
1798 }
1799
1800 TypeId::DateTime => {
1802 if buf.remaining() < 8 {
1803 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1804 }
1805 let days = buf.get_i32_le() as i64;
1806 let time_300ths = buf.get_u32_le() as u64;
1807 #[cfg(feature = "chrono")]
1808 {
1809 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1810 let date = base + chrono::Duration::days(days);
1811 let total_ms = (time_300ths * 1000) / 300;
1812 let secs = (total_ms / 1000) as u32;
1813 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1814 let time =
1815 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1816 SqlValue::DateTime(date.and_time(time))
1817 }
1818 #[cfg(not(feature = "chrono"))]
1819 {
1820 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1821 }
1822 }
1823
1824 TypeId::DateTime4 => {
1826 if buf.remaining() < 4 {
1827 return Err(Error::Protocol(
1828 "unexpected EOF reading SMALLDATETIME".into(),
1829 ));
1830 }
1831 let days = buf.get_u16_le() as i64;
1832 let minutes = buf.get_u16_le() as u32;
1833 #[cfg(feature = "chrono")]
1834 {
1835 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1836 let date = base + chrono::Duration::days(days);
1837 let time =
1838 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1839 .unwrap();
1840 SqlValue::DateTime(date.and_time(time))
1841 }
1842 #[cfg(not(feature = "chrono"))]
1843 {
1844 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1845 }
1846 }
1847
1848 TypeId::Date => {
1850 if buf.remaining() < 1 {
1851 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1852 }
1853 let len = buf.get_u8() as usize;
1854 if len == 0 {
1855 SqlValue::Null
1856 } else if len != 3 {
1857 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1858 } else if buf.remaining() < 3 {
1859 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1860 } else {
1861 let days = buf.get_u8() as u32
1863 | ((buf.get_u8() as u32) << 8)
1864 | ((buf.get_u8() as u32) << 16);
1865 #[cfg(feature = "chrono")]
1866 {
1867 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1868 let date = base + chrono::Duration::days(days as i64);
1869 SqlValue::Date(date)
1870 }
1871 #[cfg(not(feature = "chrono"))]
1872 {
1873 SqlValue::String(format!("DATE({days})"))
1874 }
1875 }
1876 }
1877
1878 TypeId::Time => {
1880 if buf.remaining() < 1 {
1881 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1882 }
1883 let len = buf.get_u8() as usize;
1884 if len == 0 {
1885 SqlValue::Null
1886 } else if buf.remaining() < len {
1887 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1888 } else {
1889 let mut time_bytes = [0u8; 8];
1890 for byte in time_bytes.iter_mut().take(len) {
1891 *byte = buf.get_u8();
1892 }
1893 let intervals = u64::from_le_bytes(time_bytes);
1894 #[cfg(feature = "chrono")]
1895 {
1896 let scale = col.type_info.scale.unwrap_or(7);
1897 let time = Self::intervals_to_time(intervals, scale);
1898 SqlValue::Time(time)
1899 }
1900 #[cfg(not(feature = "chrono"))]
1901 {
1902 SqlValue::String(format!("TIME({intervals})"))
1903 }
1904 }
1905 }
1906
1907 TypeId::DateTime2 => {
1909 if buf.remaining() < 1 {
1910 return Err(Error::Protocol(
1911 "unexpected EOF reading DATETIME2 length".into(),
1912 ));
1913 }
1914 let len = buf.get_u8() as usize;
1915 if len == 0 {
1916 SqlValue::Null
1917 } else if buf.remaining() < len {
1918 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1919 } else {
1920 let scale = col.type_info.scale.unwrap_or(7);
1921 let time_len = Self::time_bytes_for_scale(scale);
1922
1923 let mut time_bytes = [0u8; 8];
1925 for byte in time_bytes.iter_mut().take(time_len) {
1926 *byte = buf.get_u8();
1927 }
1928 let intervals = u64::from_le_bytes(time_bytes);
1929
1930 let days = buf.get_u8() as u32
1932 | ((buf.get_u8() as u32) << 8)
1933 | ((buf.get_u8() as u32) << 16);
1934
1935 #[cfg(feature = "chrono")]
1936 {
1937 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1938 let date = base + chrono::Duration::days(days as i64);
1939 let time = Self::intervals_to_time(intervals, scale);
1940 SqlValue::DateTime(date.and_time(time))
1941 }
1942 #[cfg(not(feature = "chrono"))]
1943 {
1944 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1945 }
1946 }
1947 }
1948
1949 TypeId::DateTimeOffset => {
1951 if buf.remaining() < 1 {
1952 return Err(Error::Protocol(
1953 "unexpected EOF reading DATETIMEOFFSET length".into(),
1954 ));
1955 }
1956 let len = buf.get_u8() as usize;
1957 if len == 0 {
1958 SqlValue::Null
1959 } else if buf.remaining() < len {
1960 return Err(Error::Protocol(
1961 "unexpected EOF reading DATETIMEOFFSET".into(),
1962 ));
1963 } else {
1964 let scale = col.type_info.scale.unwrap_or(7);
1965 let time_len = Self::time_bytes_for_scale(scale);
1966
1967 let mut time_bytes = [0u8; 8];
1969 for byte in time_bytes.iter_mut().take(time_len) {
1970 *byte = buf.get_u8();
1971 }
1972 let intervals = u64::from_le_bytes(time_bytes);
1973
1974 let days = buf.get_u8() as u32
1976 | ((buf.get_u8() as u32) << 8)
1977 | ((buf.get_u8() as u32) << 16);
1978
1979 let offset_minutes = buf.get_i16_le();
1981
1982 #[cfg(feature = "chrono")]
1983 {
1984 use chrono::TimeZone;
1985 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1986 let date = base + chrono::Duration::days(days as i64);
1987 let time = Self::intervals_to_time(intervals, scale);
1988 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1989 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1990 let datetime = offset
1991 .from_local_datetime(&date.and_time(time))
1992 .single()
1993 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1994 SqlValue::DateTimeOffset(datetime)
1995 }
1996 #[cfg(not(feature = "chrono"))]
1997 {
1998 SqlValue::String(format!(
1999 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
2000 ))
2001 }
2002 }
2003 }
2004
2005 TypeId::Text => Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?,
2007
2008 TypeId::Char | TypeId::VarChar => {
2010 if buf.remaining() < 1 {
2011 return Err(Error::Protocol(
2012 "unexpected EOF reading legacy varchar length".into(),
2013 ));
2014 }
2015 let len = buf.get_u8();
2016 if len == 0xFF {
2017 SqlValue::Null
2018 } else if len == 0 {
2019 SqlValue::String(String::new())
2020 } else if buf.remaining() < len as usize {
2021 return Err(Error::Protocol(
2022 "unexpected EOF reading legacy varchar data".into(),
2023 ));
2024 } else {
2025 let data = &buf[..len as usize];
2026 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
2028 buf.advance(len as usize);
2029 SqlValue::String(s)
2030 }
2031 }
2032
2033 TypeId::BigVarChar | TypeId::BigChar => {
2035 if col.type_info.max_length == Some(0xFFFF) {
2037 Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?
2039 } else {
2040 if buf.remaining() < 2 {
2042 return Err(Error::Protocol(
2043 "unexpected EOF reading varchar length".into(),
2044 ));
2045 }
2046 let len = buf.get_u16_le();
2047 if len == 0xFFFF {
2048 SqlValue::Null
2049 } else if buf.remaining() < len as usize {
2050 return Err(Error::Protocol(
2051 "unexpected EOF reading varchar data".into(),
2052 ));
2053 } else {
2054 let data = &buf[..len as usize];
2055 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
2057 buf.advance(len as usize);
2058 SqlValue::String(s)
2059 }
2060 }
2061 }
2062
2063 TypeId::NText => Self::parse_plp_nvarchar(buf)?,
2065
2066 TypeId::NVarChar | TypeId::NChar => {
2068 if col.type_info.max_length == Some(0xFFFF) {
2070 Self::parse_plp_nvarchar(buf)?
2072 } else {
2073 if buf.remaining() < 2 {
2075 return Err(Error::Protocol(
2076 "unexpected EOF reading nvarchar length".into(),
2077 ));
2078 }
2079 let len = buf.get_u16_le();
2080 if len == 0xFFFF {
2081 SqlValue::Null
2082 } else if buf.remaining() < len as usize {
2083 return Err(Error::Protocol(
2084 "unexpected EOF reading nvarchar data".into(),
2085 ));
2086 } else {
2087 let data = &buf[..len as usize];
2088 let utf16: Vec<u16> = data
2090 .chunks_exact(2)
2091 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2092 .collect();
2093 let s = String::from_utf16(&utf16)
2094 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
2095 buf.advance(len as usize);
2096 SqlValue::String(s)
2097 }
2098 }
2099 }
2100
2101 TypeId::Image => Self::parse_plp_varbinary(buf)?,
2103
2104 TypeId::Binary | TypeId::VarBinary => {
2106 if buf.remaining() < 1 {
2107 return Err(Error::Protocol(
2108 "unexpected EOF reading legacy varbinary length".into(),
2109 ));
2110 }
2111 let len = buf.get_u8();
2112 if len == 0xFF {
2113 SqlValue::Null
2114 } else if len == 0 {
2115 SqlValue::Binary(bytes::Bytes::new())
2116 } else if buf.remaining() < len as usize {
2117 return Err(Error::Protocol(
2118 "unexpected EOF reading legacy varbinary data".into(),
2119 ));
2120 } else {
2121 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2122 buf.advance(len as usize);
2123 SqlValue::Binary(data)
2124 }
2125 }
2126
2127 TypeId::BigVarBinary | TypeId::BigBinary => {
2129 if col.type_info.max_length == Some(0xFFFF) {
2131 Self::parse_plp_varbinary(buf)?
2133 } else {
2134 if buf.remaining() < 2 {
2135 return Err(Error::Protocol(
2136 "unexpected EOF reading varbinary length".into(),
2137 ));
2138 }
2139 let len = buf.get_u16_le();
2140 if len == 0xFFFF {
2141 SqlValue::Null
2142 } else if buf.remaining() < len as usize {
2143 return Err(Error::Protocol(
2144 "unexpected EOF reading varbinary data".into(),
2145 ));
2146 } else {
2147 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2148 buf.advance(len as usize);
2149 SqlValue::Binary(data)
2150 }
2151 }
2152 }
2153
2154 TypeId::Xml => {
2156 match Self::parse_plp_nvarchar(buf)? {
2158 SqlValue::Null => SqlValue::Null,
2159 SqlValue::String(s) => SqlValue::Xml(s),
2160 _ => {
2161 return Err(Error::Protocol(
2162 "unexpected value type when parsing XML".into(),
2163 ));
2164 }
2165 }
2166 }
2167
2168 TypeId::Guid => {
2170 if buf.remaining() < 1 {
2171 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
2172 }
2173 let len = buf.get_u8();
2174 if len == 0 {
2175 SqlValue::Null
2176 } else if len != 16 {
2177 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
2178 } else if buf.remaining() < 16 {
2179 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
2180 } else {
2181 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
2183 buf.advance(16);
2184 SqlValue::Binary(data)
2185 }
2186 }
2187
2188 TypeId::Variant => Self::parse_sql_variant(buf)?,
2190
2191 TypeId::Udt => Self::parse_plp_varbinary(buf)?,
2193
2194 _ => {
2196 if buf.remaining() < 2 {
2198 return Err(Error::Protocol(format!(
2199 "unexpected EOF reading {:?}",
2200 col.type_id
2201 )));
2202 }
2203 let len = buf.get_u16_le();
2204 if len == 0xFFFF {
2205 SqlValue::Null
2206 } else if buf.remaining() < len as usize {
2207 return Err(Error::Protocol(format!(
2208 "unexpected EOF reading {:?} data",
2209 col.type_id
2210 )));
2211 } else {
2212 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2213 buf.advance(len as usize);
2214 SqlValue::Binary(data)
2215 }
2216 }
2217 };
2218
2219 Ok(value)
2220 }
2221
2222 fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2228 use bytes::Buf;
2229 use mssql_types::SqlValue;
2230
2231 if buf.remaining() < 8 {
2232 return Err(Error::Protocol(
2233 "unexpected EOF reading PLP total length".into(),
2234 ));
2235 }
2236
2237 let total_len = buf.get_u64_le();
2238 if total_len == 0xFFFFFFFFFFFFFFFF {
2239 return Ok(SqlValue::Null);
2240 }
2241
2242 let mut all_data = Vec::new();
2244 loop {
2245 if buf.remaining() < 4 {
2246 return Err(Error::Protocol(
2247 "unexpected EOF reading PLP chunk length".into(),
2248 ));
2249 }
2250 let chunk_len = buf.get_u32_le() as usize;
2251 if chunk_len == 0 {
2252 break; }
2254 if buf.remaining() < chunk_len {
2255 return Err(Error::Protocol(
2256 "unexpected EOF reading PLP chunk data".into(),
2257 ));
2258 }
2259 all_data.extend_from_slice(&buf[..chunk_len]);
2260 buf.advance(chunk_len);
2261 }
2262
2263 let utf16: Vec<u16> = all_data
2265 .chunks_exact(2)
2266 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2267 .collect();
2268 let s = String::from_utf16(&utf16)
2269 .map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
2270 Ok(SqlValue::String(s))
2271 }
2272
2273 #[allow(unused_variables)]
2279 fn decode_varchar_string(data: &[u8], collation: Option<&Collation>) -> String {
2280 if let Ok(s) = std::str::from_utf8(data) {
2282 return s.to_owned();
2283 }
2284
2285 #[cfg(feature = "encoding")]
2287 if let Some(coll) = collation {
2288 if let Some(encoding) = coll.encoding() {
2289 let (decoded, _, had_errors) = encoding.decode(data);
2290 if !had_errors {
2291 return decoded.into_owned();
2292 }
2293 }
2294 }
2295
2296 String::from_utf8_lossy(data).into_owned()
2298 }
2299
2300 fn parse_plp_varchar(
2302 buf: &mut &[u8],
2303 collation: Option<&Collation>,
2304 ) -> Result<mssql_types::SqlValue> {
2305 use bytes::Buf;
2306 use mssql_types::SqlValue;
2307
2308 if buf.remaining() < 8 {
2309 return Err(Error::Protocol(
2310 "unexpected EOF reading PLP total length".into(),
2311 ));
2312 }
2313
2314 let total_len = buf.get_u64_le();
2315 if total_len == 0xFFFFFFFFFFFFFFFF {
2316 return Ok(SqlValue::Null);
2317 }
2318
2319 let mut all_data = Vec::new();
2321 loop {
2322 if buf.remaining() < 4 {
2323 return Err(Error::Protocol(
2324 "unexpected EOF reading PLP chunk length".into(),
2325 ));
2326 }
2327 let chunk_len = buf.get_u32_le() as usize;
2328 if chunk_len == 0 {
2329 break; }
2331 if buf.remaining() < chunk_len {
2332 return Err(Error::Protocol(
2333 "unexpected EOF reading PLP chunk data".into(),
2334 ));
2335 }
2336 all_data.extend_from_slice(&buf[..chunk_len]);
2337 buf.advance(chunk_len);
2338 }
2339
2340 let s = Self::decode_varchar_string(&all_data, collation);
2342 Ok(SqlValue::String(s))
2343 }
2344
2345 fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2347 use bytes::Buf;
2348 use mssql_types::SqlValue;
2349
2350 if buf.remaining() < 8 {
2351 return Err(Error::Protocol(
2352 "unexpected EOF reading PLP total length".into(),
2353 ));
2354 }
2355
2356 let total_len = buf.get_u64_le();
2357 if total_len == 0xFFFFFFFFFFFFFFFF {
2358 return Ok(SqlValue::Null);
2359 }
2360
2361 let mut all_data = Vec::new();
2363 loop {
2364 if buf.remaining() < 4 {
2365 return Err(Error::Protocol(
2366 "unexpected EOF reading PLP chunk length".into(),
2367 ));
2368 }
2369 let chunk_len = buf.get_u32_le() as usize;
2370 if chunk_len == 0 {
2371 break; }
2373 if buf.remaining() < chunk_len {
2374 return Err(Error::Protocol(
2375 "unexpected EOF reading PLP chunk data".into(),
2376 ));
2377 }
2378 all_data.extend_from_slice(&buf[..chunk_len]);
2379 buf.advance(chunk_len);
2380 }
2381
2382 Ok(SqlValue::Binary(bytes::Bytes::from(all_data)))
2383 }
2384
2385 fn parse_sql_variant(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2394 use bytes::Buf;
2395 use mssql_types::SqlValue;
2396
2397 if buf.remaining() < 4 {
2399 return Err(Error::Protocol(
2400 "unexpected EOF reading SQL_VARIANT length".into(),
2401 ));
2402 }
2403 let total_len = buf.get_u32_le() as usize;
2404
2405 if total_len == 0 {
2406 return Ok(SqlValue::Null);
2407 }
2408
2409 if buf.remaining() < total_len {
2410 return Err(Error::Protocol(
2411 "unexpected EOF reading SQL_VARIANT data".into(),
2412 ));
2413 }
2414
2415 if total_len < 2 {
2417 return Err(Error::Protocol(
2418 "SQL_VARIANT too short for type info".into(),
2419 ));
2420 }
2421
2422 let base_type = buf.get_u8();
2423 let prop_count = buf.get_u8() as usize;
2424
2425 if buf.remaining() < prop_count {
2426 return Err(Error::Protocol(
2427 "unexpected EOF reading SQL_VARIANT properties".into(),
2428 ));
2429 }
2430
2431 let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
2433
2434 match base_type {
2437 0x30 => {
2439 buf.advance(prop_count);
2441 if data_len < 1 {
2442 return Ok(SqlValue::Null);
2443 }
2444 let v = buf.get_u8();
2445 Ok(SqlValue::TinyInt(v))
2446 }
2447 0x32 => {
2448 buf.advance(prop_count);
2450 if data_len < 1 {
2451 return Ok(SqlValue::Null);
2452 }
2453 let v = buf.get_u8();
2454 Ok(SqlValue::Bool(v != 0))
2455 }
2456 0x34 => {
2457 buf.advance(prop_count);
2459 if data_len < 2 {
2460 return Ok(SqlValue::Null);
2461 }
2462 let v = buf.get_i16_le();
2463 Ok(SqlValue::SmallInt(v))
2464 }
2465 0x38 => {
2466 buf.advance(prop_count);
2468 if data_len < 4 {
2469 return Ok(SqlValue::Null);
2470 }
2471 let v = buf.get_i32_le();
2472 Ok(SqlValue::Int(v))
2473 }
2474 0x7F => {
2475 buf.advance(prop_count);
2477 if data_len < 8 {
2478 return Ok(SqlValue::Null);
2479 }
2480 let v = buf.get_i64_le();
2481 Ok(SqlValue::BigInt(v))
2482 }
2483 0x6D => {
2484 let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2486 buf.advance(prop_count.saturating_sub(1));
2487
2488 if float_len == 4 && data_len >= 4 {
2489 let v = buf.get_f32_le();
2490 Ok(SqlValue::Float(v))
2491 } else if data_len >= 8 {
2492 let v = buf.get_f64_le();
2493 Ok(SqlValue::Double(v))
2494 } else {
2495 Ok(SqlValue::Null)
2496 }
2497 }
2498 0x6E => {
2499 let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2501 buf.advance(prop_count.saturating_sub(1));
2502
2503 if money_len == 4 && data_len >= 4 {
2504 let raw = buf.get_i32_le();
2505 let value = raw as f64 / 10000.0;
2506 Ok(SqlValue::Double(value))
2507 } else if data_len >= 8 {
2508 let high = buf.get_i32_le() as i64;
2509 let low = buf.get_u32_le() as i64;
2510 let raw = (high << 32) | low;
2511 let value = raw as f64 / 10000.0;
2512 Ok(SqlValue::Double(value))
2513 } else {
2514 Ok(SqlValue::Null)
2515 }
2516 }
2517 0x6F => {
2518 #[cfg(feature = "chrono")]
2520 let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2521 #[cfg(not(feature = "chrono"))]
2522 if prop_count >= 1 {
2523 buf.get_u8();
2524 }
2525 buf.advance(prop_count.saturating_sub(1));
2526
2527 #[cfg(feature = "chrono")]
2528 {
2529 use chrono::NaiveDate;
2530 if dt_len == 4 && data_len >= 4 {
2531 let days = buf.get_u16_le() as i64;
2533 let mins = buf.get_u16_le() as u32;
2534 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2535 .unwrap()
2536 .and_hms_opt(0, 0, 0)
2537 .unwrap();
2538 let dt = base
2539 + chrono::Duration::days(days)
2540 + chrono::Duration::minutes(mins as i64);
2541 Ok(SqlValue::DateTime(dt))
2542 } else if data_len >= 8 {
2543 let days = buf.get_i32_le() as i64;
2545 let ticks = buf.get_u32_le() as i64;
2546 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2547 .unwrap()
2548 .and_hms_opt(0, 0, 0)
2549 .unwrap();
2550 let millis = (ticks * 10) / 3;
2551 let dt = base
2552 + chrono::Duration::days(days)
2553 + chrono::Duration::milliseconds(millis);
2554 Ok(SqlValue::DateTime(dt))
2555 } else {
2556 Ok(SqlValue::Null)
2557 }
2558 }
2559 #[cfg(not(feature = "chrono"))]
2560 {
2561 buf.advance(data_len);
2562 Ok(SqlValue::Null)
2563 }
2564 }
2565 0x6A | 0x6C => {
2566 let _precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
2568 let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
2569 buf.advance(prop_count.saturating_sub(2));
2570
2571 if data_len < 1 {
2572 return Ok(SqlValue::Null);
2573 }
2574
2575 let sign = buf.get_u8();
2576 let mantissa_len = data_len - 1;
2577
2578 if mantissa_len > 16 {
2579 buf.advance(mantissa_len);
2581 return Ok(SqlValue::Null);
2582 }
2583
2584 let mut mantissa_bytes = [0u8; 16];
2585 for i in 0..mantissa_len.min(16) {
2586 mantissa_bytes[i] = buf.get_u8();
2587 }
2588 let mantissa = u128::from_le_bytes(mantissa_bytes);
2589
2590 #[cfg(feature = "decimal")]
2591 {
2592 use rust_decimal::Decimal;
2593 if scale > 28 {
2594 let divisor = 10f64.powi(scale as i32);
2596 let value = (mantissa as f64) / divisor;
2597 let value = if sign == 0 { -value } else { value };
2598 Ok(SqlValue::Double(value))
2599 } else {
2600 let mut decimal =
2601 Decimal::from_i128_with_scale(mantissa as i128, scale as u32);
2602 if sign == 0 {
2603 decimal.set_sign_negative(true);
2604 }
2605 Ok(SqlValue::Decimal(decimal))
2606 }
2607 }
2608 #[cfg(not(feature = "decimal"))]
2609 {
2610 let divisor = 10f64.powi(scale as i32);
2611 let value = (mantissa as f64) / divisor;
2612 let value = if sign == 0 { -value } else { value };
2613 Ok(SqlValue::Double(value))
2614 }
2615 }
2616 0x24 => {
2617 buf.advance(prop_count);
2619 if data_len < 16 {
2620 return Ok(SqlValue::Null);
2621 }
2622 let mut guid_bytes = [0u8; 16];
2623 for byte in &mut guid_bytes {
2624 *byte = buf.get_u8();
2625 }
2626 Ok(SqlValue::Binary(bytes::Bytes::copy_from_slice(&guid_bytes)))
2627 }
2628 0x28 => {
2629 buf.advance(prop_count);
2631 #[cfg(feature = "chrono")]
2632 {
2633 if data_len < 3 {
2634 return Ok(SqlValue::Null);
2635 }
2636 let mut date_bytes = [0u8; 4];
2637 date_bytes[0] = buf.get_u8();
2638 date_bytes[1] = buf.get_u8();
2639 date_bytes[2] = buf.get_u8();
2640 let days = u32::from_le_bytes(date_bytes);
2641 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2642 let date = base + chrono::Duration::days(days as i64);
2643 Ok(SqlValue::Date(date))
2644 }
2645 #[cfg(not(feature = "chrono"))]
2646 {
2647 buf.advance(data_len);
2648 Ok(SqlValue::Null)
2649 }
2650 }
2651 0xA7 | 0x2F | 0x27 => {
2652 let collation = if prop_count >= 5 && buf.remaining() >= 5 {
2655 let lcid = buf.get_u32_le();
2656 let sort_id = buf.get_u8();
2657 buf.advance(prop_count.saturating_sub(5)); Some(Collation { lcid, sort_id })
2659 } else {
2660 buf.advance(prop_count);
2661 None
2662 };
2663 if data_len == 0 {
2664 return Ok(SqlValue::String(String::new()));
2665 }
2666 let data = &buf[..data_len];
2667 let s = Self::decode_varchar_string(data, collation.as_ref());
2669 buf.advance(data_len);
2670 Ok(SqlValue::String(s))
2671 }
2672 0xE7 | 0xEF => {
2673 buf.advance(prop_count);
2675 if data_len == 0 {
2676 return Ok(SqlValue::String(String::new()));
2677 }
2678 let utf16: Vec<u16> = buf[..data_len]
2680 .chunks_exact(2)
2681 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2682 .collect();
2683 buf.advance(data_len);
2684 let s = String::from_utf16(&utf16).map_err(|_| {
2685 Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into())
2686 })?;
2687 Ok(SqlValue::String(s))
2688 }
2689 0xA5 | 0x2D | 0x25 => {
2690 buf.advance(prop_count);
2692 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2693 buf.advance(data_len);
2694 Ok(SqlValue::Binary(data))
2695 }
2696 _ => {
2697 buf.advance(prop_count);
2699 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2700 buf.advance(data_len);
2701 Ok(SqlValue::Binary(data))
2702 }
2703 }
2704 }
2705
/// Returns how many bytes a TIME value occupies on the wire for the given
/// fractional-second `scale`.
///
/// Scales 0-2 use 3 bytes, 3-4 use 4 bytes, and 5-7 use 5 bytes. Any larger
/// scale is treated like 5-7.
fn time_bytes_for_scale(scale: u8) -> usize {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
2715
/// Converts a count of `10^-scale`-second intervals since midnight into a
/// `chrono::NaiveTime`.
///
/// Scales above 7 are treated as 7 (100 ns units). If the resulting time is
/// not representable (24:00:00 or later), midnight is returned instead.
#[cfg(feature = "chrono")]
fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
    // One interval at scale `s` is 10^(9 - s) nanoseconds.
    let nanos_per_interval = 10u64.pow(9 - u32::from(scale.min(7)));
    let total_nanos = intervals * nanos_per_interval;

    let whole_seconds = (total_nanos / 1_000_000_000) as u32;
    let sub_second_nanos = (total_nanos % 1_000_000_000) as u32;

    match chrono::NaiveTime::from_num_seconds_from_midnight_opt(whole_seconds, sub_second_nanos) {
        Some(time) => time,
        None => chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap(),
    }
}
2746
2747 async fn read_execute_result(&mut self) -> Result<u64> {
2749 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2750
2751 let message = match connection {
2752 ConnectionHandle::Tls(conn) => conn
2753 .read_message()
2754 .await
2755 .map_err(|e| Error::Protocol(e.to_string()))?,
2756 ConnectionHandle::TlsPrelogin(conn) => conn
2757 .read_message()
2758 .await
2759 .map_err(|e| Error::Protocol(e.to_string()))?,
2760 ConnectionHandle::Plain(conn) => conn
2761 .read_message()
2762 .await
2763 .map_err(|e| Error::Protocol(e.to_string()))?,
2764 }
2765 .ok_or(Error::ConnectionClosed)?;
2766
2767 let mut parser = TokenParser::new(message.payload);
2768 let mut rows_affected = 0u64;
2769 let mut current_metadata: Option<ColMetaData> = None;
2770
2771 loop {
2772 let token = parser
2774 .next_token_with_metadata(current_metadata.as_ref())
2775 .map_err(|e| Error::Protocol(e.to_string()))?;
2776
2777 let Some(token) = token else {
2778 break;
2779 };
2780
2781 match token {
2782 Token::ColMetaData(meta) => {
2783 current_metadata = Some(meta);
2785 }
2786 Token::Row(_) | Token::NbcRow(_) => {
2787 }
2790 Token::Done(done) => {
2791 if done.status.error {
2792 return Err(Error::Query("execution failed".to_string()));
2793 }
2794 if done.status.count {
2795 rows_affected += done.row_count;
2797 }
2798 if !done.status.more {
2801 break;
2802 }
2803 }
2804 Token::DoneProc(done) => {
2805 if done.status.count {
2806 rows_affected += done.row_count;
2807 }
2808 }
2809 Token::DoneInProc(done) => {
2810 if done.status.count {
2811 rows_affected += done.row_count;
2812 }
2813 }
2814 Token::Error(err) => {
2815 return Err(Error::Server {
2816 number: err.number,
2817 state: err.state,
2818 class: err.class,
2819 message: err.message.clone(),
2820 server: if err.server.is_empty() {
2821 None
2822 } else {
2823 Some(err.server.clone())
2824 },
2825 procedure: if err.procedure.is_empty() {
2826 None
2827 } else {
2828 Some(err.procedure.clone())
2829 },
2830 line: err.line as u32,
2831 });
2832 }
2833 Token::Info(info) => {
2834 tracing::info!(
2835 number = info.number,
2836 message = %info.message,
2837 "server info message"
2838 );
2839 }
2840 Token::EnvChange(env) => {
2841 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
2845 }
2846 _ => {}
2847 }
2848 }
2849
2850 Ok(rows_affected)
2851 }
2852
2853 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
2859 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2860
2861 let message = match connection {
2862 ConnectionHandle::Tls(conn) => conn
2863 .read_message()
2864 .await
2865 .map_err(|e| Error::Protocol(e.to_string()))?,
2866 ConnectionHandle::TlsPrelogin(conn) => conn
2867 .read_message()
2868 .await
2869 .map_err(|e| Error::Protocol(e.to_string()))?,
2870 ConnectionHandle::Plain(conn) => conn
2871 .read_message()
2872 .await
2873 .map_err(|e| Error::Protocol(e.to_string()))?,
2874 }
2875 .ok_or(Error::ConnectionClosed)?;
2876
2877 let mut parser = TokenParser::new(message.payload);
2878 let mut transaction_descriptor: u64 = 0;
2879
2880 loop {
2881 let token = parser
2882 .next_token()
2883 .map_err(|e| Error::Protocol(e.to_string()))?;
2884
2885 let Some(token) = token else {
2886 break;
2887 };
2888
2889 match token {
2890 Token::EnvChange(env) => {
2891 if env.env_type == EnvChangeType::BeginTransaction {
2892 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
2895 {
2896 if data.len() >= 8 {
2897 transaction_descriptor = u64::from_le_bytes([
2898 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
2899 data[7],
2900 ]);
2901 tracing::debug!(
2902 transaction_descriptor =
2903 format!("0x{:016X}", transaction_descriptor),
2904 "transaction begun"
2905 );
2906 }
2907 }
2908 }
2909 }
2910 Token::Done(done) => {
2911 if done.status.error {
2912 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
2913 }
2914 break;
2915 }
2916 Token::Error(err) => {
2917 return Err(Error::Server {
2918 number: err.number,
2919 state: err.state,
2920 class: err.class,
2921 message: err.message.clone(),
2922 server: if err.server.is_empty() {
2923 None
2924 } else {
2925 Some(err.server.clone())
2926 },
2927 procedure: if err.procedure.is_empty() {
2928 None
2929 } else {
2930 Some(err.procedure.clone())
2931 },
2932 line: err.line as u32,
2933 });
2934 }
2935 Token::Info(info) => {
2936 tracing::info!(
2937 number = info.number,
2938 message = %info.message,
2939 "server info message"
2940 );
2941 }
2942 _ => {}
2943 }
2944 }
2945
2946 Ok(transaction_descriptor)
2947 }
2948}
2949
2950impl Client<Ready> {
2951 pub fn mark_needs_reset(&mut self) {
2962 self.needs_reset = true;
2963 }
2964
2965 #[must_use]
2970 pub fn needs_reset(&self) -> bool {
2971 self.needs_reset
2972 }
2973
2974 pub async fn query<'a>(
2999 &'a mut self,
3000 sql: &str,
3001 params: &[&(dyn crate::ToSql + Sync)],
3002 ) -> Result<QueryStream<'a>> {
3003 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
3004
3005 #[cfg(feature = "otel")]
3006 let instrumentation = self.instrumentation.clone();
3007 #[cfg(feature = "otel")]
3008 let mut span = instrumentation.query_span(sql);
3009
3010 let result = async {
3011 if params.is_empty() {
3012 self.send_sql_batch(sql).await?;
3014 } else {
3015 let rpc_params = Self::convert_params(params)?;
3017 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3018 self.send_rpc(&rpc).await?;
3019 }
3020
3021 self.read_query_response().await
3023 }
3024 .await;
3025
3026 #[cfg(feature = "otel")]
3027 match &result {
3028 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3029 Err(e) => InstrumentationContext::record_error(&mut span, e),
3030 }
3031
3032 #[cfg(feature = "otel")]
3034 drop(span);
3035
3036 let (columns, rows) = result?;
3037 Ok(QueryStream::new(columns, rows))
3038 }
3039
3040 pub async fn query_with_timeout<'a>(
3067 &'a mut self,
3068 sql: &str,
3069 params: &[&(dyn crate::ToSql + Sync)],
3070 timeout_duration: std::time::Duration,
3071 ) -> Result<QueryStream<'a>> {
3072 timeout(timeout_duration, self.query(sql, params))
3073 .await
3074 .map_err(|_| Error::CommandTimeout)?
3075 }
3076
3077 pub async fn query_multiple<'a>(
3104 &'a mut self,
3105 sql: &str,
3106 params: &[&(dyn crate::ToSql + Sync)],
3107 ) -> Result<MultiResultStream<'a>> {
3108 tracing::debug!(
3109 sql = sql,
3110 params_count = params.len(),
3111 "executing multi-result query"
3112 );
3113
3114 if params.is_empty() {
3115 self.send_sql_batch(sql).await?;
3117 } else {
3118 let rpc_params = Self::convert_params(params)?;
3120 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3121 self.send_rpc(&rpc).await?;
3122 }
3123
3124 let result_sets = self.read_multi_result_response().await?;
3126 Ok(MultiResultStream::new(result_sets))
3127 }
3128
3129 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
3131 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
3132
3133 let message = match connection {
3134 ConnectionHandle::Tls(conn) => conn
3135 .read_message()
3136 .await
3137 .map_err(|e| Error::Protocol(e.to_string()))?,
3138 ConnectionHandle::TlsPrelogin(conn) => conn
3139 .read_message()
3140 .await
3141 .map_err(|e| Error::Protocol(e.to_string()))?,
3142 ConnectionHandle::Plain(conn) => conn
3143 .read_message()
3144 .await
3145 .map_err(|e| Error::Protocol(e.to_string()))?,
3146 }
3147 .ok_or(Error::ConnectionClosed)?;
3148
3149 let mut parser = TokenParser::new(message.payload);
3150 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
3151 let mut current_columns: Vec<crate::row::Column> = Vec::new();
3152 let mut current_rows: Vec<crate::row::Row> = Vec::new();
3153 let mut protocol_metadata: Option<ColMetaData> = None;
3154
3155 loop {
3156 let token = parser
3157 .next_token_with_metadata(protocol_metadata.as_ref())
3158 .map_err(|e| Error::Protocol(e.to_string()))?;
3159
3160 let Some(token) = token else {
3161 break;
3162 };
3163
3164 match token {
3165 Token::ColMetaData(meta) => {
3166 if !current_columns.is_empty() {
3168 result_sets.push(crate::stream::ResultSet::new(
3169 std::mem::take(&mut current_columns),
3170 std::mem::take(&mut current_rows),
3171 ));
3172 }
3173
3174 current_columns = meta
3176 .columns
3177 .iter()
3178 .enumerate()
3179 .map(|(i, col)| {
3180 let type_name = format!("{:?}", col.type_id);
3181 let mut column = crate::row::Column::new(&col.name, i, type_name)
3182 .with_nullable(col.flags & 0x01 != 0);
3183
3184 if let Some(max_len) = col.type_info.max_length {
3185 column = column.with_max_length(max_len);
3186 }
3187 if let (Some(prec), Some(scale)) =
3188 (col.type_info.precision, col.type_info.scale)
3189 {
3190 column = column.with_precision_scale(prec, scale);
3191 }
3192 if let Some(collation) = col.type_info.collation {
3195 column = column.with_collation(collation);
3196 }
3197 column
3198 })
3199 .collect();
3200
3201 tracing::debug!(
3202 columns = current_columns.len(),
3203 result_set = result_sets.len(),
3204 "received column metadata for result set"
3205 );
3206 protocol_metadata = Some(meta);
3207 }
3208 Token::Row(raw_row) => {
3209 if let Some(ref meta) = protocol_metadata {
3210 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
3211 current_rows.push(row);
3212 }
3213 }
3214 Token::NbcRow(nbc_row) => {
3215 if let Some(ref meta) = protocol_metadata {
3216 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
3217 current_rows.push(row);
3218 }
3219 }
3220 Token::Error(err) => {
3221 return Err(Error::Server {
3222 number: err.number,
3223 state: err.state,
3224 class: err.class,
3225 message: err.message.clone(),
3226 server: if err.server.is_empty() {
3227 None
3228 } else {
3229 Some(err.server.clone())
3230 },
3231 procedure: if err.procedure.is_empty() {
3232 None
3233 } else {
3234 Some(err.procedure.clone())
3235 },
3236 line: err.line as u32,
3237 });
3238 }
3239 Token::Done(done) => {
3240 if done.status.error {
3241 return Err(Error::Query("query failed".to_string()));
3242 }
3243
3244 if !current_columns.is_empty() {
3246 result_sets.push(crate::stream::ResultSet::new(
3247 std::mem::take(&mut current_columns),
3248 std::mem::take(&mut current_rows),
3249 ));
3250 protocol_metadata = None;
3251 }
3252
3253 if !done.status.more {
3255 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
3256 break;
3257 }
3258 }
3259 Token::DoneInProc(done) => {
3260 if done.status.error {
3261 return Err(Error::Query("query failed".to_string()));
3262 }
3263
3264 if !current_columns.is_empty() {
3266 result_sets.push(crate::stream::ResultSet::new(
3267 std::mem::take(&mut current_columns),
3268 std::mem::take(&mut current_rows),
3269 ));
3270 protocol_metadata = None;
3271 }
3272
3273 if !done.status.more {
3275 }
3277 }
3278 Token::DoneProc(done) => {
3279 if done.status.error {
3280 return Err(Error::Query("query failed".to_string()));
3281 }
3282 }
3284 Token::Info(info) => {
3285 tracing::debug!(
3286 number = info.number,
3287 message = %info.message,
3288 "server info message"
3289 );
3290 }
3291 _ => {}
3292 }
3293 }
3294
3295 if !current_columns.is_empty() {
3297 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
3298 }
3299
3300 Ok(result_sets)
3301 }
3302
3303 pub async fn execute(
3307 &mut self,
3308 sql: &str,
3309 params: &[&(dyn crate::ToSql + Sync)],
3310 ) -> Result<u64> {
3311 tracing::debug!(
3312 sql = sql,
3313 params_count = params.len(),
3314 "executing statement"
3315 );
3316
3317 #[cfg(feature = "otel")]
3318 let instrumentation = self.instrumentation.clone();
3319 #[cfg(feature = "otel")]
3320 let mut span = instrumentation.query_span(sql);
3321
3322 let result = async {
3323 if params.is_empty() {
3324 self.send_sql_batch(sql).await?;
3326 } else {
3327 let rpc_params = Self::convert_params(params)?;
3329 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3330 self.send_rpc(&rpc).await?;
3331 }
3332
3333 self.read_execute_result().await
3335 }
3336 .await;
3337
3338 #[cfg(feature = "otel")]
3339 match &result {
3340 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3341 Err(e) => InstrumentationContext::record_error(&mut span, e),
3342 }
3343
3344 #[cfg(feature = "otel")]
3346 drop(span);
3347
3348 result
3349 }
3350
3351 pub async fn execute_with_timeout(
3378 &mut self,
3379 sql: &str,
3380 params: &[&(dyn crate::ToSql + Sync)],
3381 timeout_duration: std::time::Duration,
3382 ) -> Result<u64> {
3383 timeout(timeout_duration, self.execute(sql, params))
3384 .await
3385 .map_err(|_| Error::CommandTimeout)?
3386 }
3387
3388 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
3395 tracing::debug!("beginning transaction");
3396
3397 #[cfg(feature = "otel")]
3398 let instrumentation = self.instrumentation.clone();
3399 #[cfg(feature = "otel")]
3400 let mut span = instrumentation.transaction_span("BEGIN");
3401
3402 let result = async {
3404 self.send_sql_batch("BEGIN TRANSACTION").await?;
3405 self.read_transaction_begin_result().await
3406 }
3407 .await;
3408
3409 #[cfg(feature = "otel")]
3410 match &result {
3411 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3412 Err(e) => InstrumentationContext::record_error(&mut span, e),
3413 }
3414
3415 #[cfg(feature = "otel")]
3417 drop(span);
3418
3419 let transaction_descriptor = result?;
3420
3421 Ok(Client {
3422 config: self.config,
3423 _state: PhantomData,
3424 connection: self.connection,
3425 server_version: self.server_version,
3426 current_database: self.current_database,
3427 statement_cache: self.statement_cache,
3428 transaction_descriptor, needs_reset: self.needs_reset,
3430 #[cfg(feature = "otel")]
3431 instrumentation: self.instrumentation,
3432 })
3433 }
3434
3435 pub async fn begin_transaction_with_isolation(
3450 mut self,
3451 isolation_level: crate::transaction::IsolationLevel,
3452 ) -> Result<Client<InTransaction>> {
3453 tracing::debug!(
3454 isolation_level = %isolation_level.name(),
3455 "beginning transaction with isolation level"
3456 );
3457
3458 #[cfg(feature = "otel")]
3459 let instrumentation = self.instrumentation.clone();
3460 #[cfg(feature = "otel")]
3461 let mut span = instrumentation.transaction_span("BEGIN");
3462
3463 let result = async {
3465 self.send_sql_batch(isolation_level.as_sql()).await?;
3466 self.read_execute_result().await?;
3467
3468 self.send_sql_batch("BEGIN TRANSACTION").await?;
3470 self.read_transaction_begin_result().await
3471 }
3472 .await;
3473
3474 #[cfg(feature = "otel")]
3475 match &result {
3476 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3477 Err(e) => InstrumentationContext::record_error(&mut span, e),
3478 }
3479
3480 #[cfg(feature = "otel")]
3481 drop(span);
3482
3483 let transaction_descriptor = result?;
3484
3485 Ok(Client {
3486 config: self.config,
3487 _state: PhantomData,
3488 connection: self.connection,
3489 server_version: self.server_version,
3490 current_database: self.current_database,
3491 statement_cache: self.statement_cache,
3492 transaction_descriptor,
3493 needs_reset: self.needs_reset,
3494 #[cfg(feature = "otel")]
3495 instrumentation: self.instrumentation,
3496 })
3497 }
3498
3499 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
3504 tracing::debug!(sql = sql, "executing simple query");
3505
3506 self.send_sql_batch(sql).await?;
3508
3509 let _ = self.read_execute_result().await?;
3511
3512 Ok(())
3513 }
3514
3515 pub async fn close(self) -> Result<()> {
3517 tracing::debug!("closing connection");
3518 Ok(())
3519 }
3520
3521 #[must_use]
3523 pub fn database(&self) -> Option<&str> {
3524 self.config.database.as_deref()
3525 }
3526
3527 #[must_use]
3529 pub fn host(&self) -> &str {
3530 &self.config.host
3531 }
3532
3533 #[must_use]
3535 pub fn port(&self) -> u16 {
3536 self.config.port
3537 }
3538
3539 #[must_use]
3558 pub fn is_in_transaction(&self) -> bool {
3559 self.transaction_descriptor != 0
3560 }
3561
3562 #[must_use]
3584 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3585 let connection = self
3586 .connection
3587 .as_ref()
3588 .expect("connection should be present");
3589 match connection {
3590 ConnectionHandle::Tls(conn) => {
3591 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3592 }
3593 ConnectionHandle::TlsPrelogin(conn) => {
3594 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3595 }
3596 ConnectionHandle::Plain(conn) => {
3597 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3598 }
3599 }
3600 }
3601}
3602
3603impl Client<InTransaction> {
3604 pub async fn query<'a>(
3608 &'a mut self,
3609 sql: &str,
3610 params: &[&(dyn crate::ToSql + Sync)],
3611 ) -> Result<QueryStream<'a>> {
3612 tracing::debug!(
3613 sql = sql,
3614 params_count = params.len(),
3615 "executing query in transaction"
3616 );
3617
3618 #[cfg(feature = "otel")]
3619 let instrumentation = self.instrumentation.clone();
3620 #[cfg(feature = "otel")]
3621 let mut span = instrumentation.query_span(sql);
3622
3623 let result = async {
3624 if params.is_empty() {
3625 self.send_sql_batch(sql).await?;
3627 } else {
3628 let rpc_params = Self::convert_params(params)?;
3630 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3631 self.send_rpc(&rpc).await?;
3632 }
3633
3634 self.read_query_response().await
3636 }
3637 .await;
3638
3639 #[cfg(feature = "otel")]
3640 match &result {
3641 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3642 Err(e) => InstrumentationContext::record_error(&mut span, e),
3643 }
3644
3645 #[cfg(feature = "otel")]
3647 drop(span);
3648
3649 let (columns, rows) = result?;
3650 Ok(QueryStream::new(columns, rows))
3651 }
3652
3653 pub async fn execute(
3657 &mut self,
3658 sql: &str,
3659 params: &[&(dyn crate::ToSql + Sync)],
3660 ) -> Result<u64> {
3661 tracing::debug!(
3662 sql = sql,
3663 params_count = params.len(),
3664 "executing statement in transaction"
3665 );
3666
3667 #[cfg(feature = "otel")]
3668 let instrumentation = self.instrumentation.clone();
3669 #[cfg(feature = "otel")]
3670 let mut span = instrumentation.query_span(sql);
3671
3672 let result = async {
3673 if params.is_empty() {
3674 self.send_sql_batch(sql).await?;
3676 } else {
3677 let rpc_params = Self::convert_params(params)?;
3679 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3680 self.send_rpc(&rpc).await?;
3681 }
3682
3683 self.read_execute_result().await
3685 }
3686 .await;
3687
3688 #[cfg(feature = "otel")]
3689 match &result {
3690 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3691 Err(e) => InstrumentationContext::record_error(&mut span, e),
3692 }
3693
3694 #[cfg(feature = "otel")]
3696 drop(span);
3697
3698 result
3699 }
3700
3701 pub async fn query_with_timeout<'a>(
3705 &'a mut self,
3706 sql: &str,
3707 params: &[&(dyn crate::ToSql + Sync)],
3708 timeout_duration: std::time::Duration,
3709 ) -> Result<QueryStream<'a>> {
3710 timeout(timeout_duration, self.query(sql, params))
3711 .await
3712 .map_err(|_| Error::CommandTimeout)?
3713 }
3714
3715 pub async fn execute_with_timeout(
3719 &mut self,
3720 sql: &str,
3721 params: &[&(dyn crate::ToSql + Sync)],
3722 timeout_duration: std::time::Duration,
3723 ) -> Result<u64> {
3724 timeout(timeout_duration, self.execute(sql, params))
3725 .await
3726 .map_err(|_| Error::CommandTimeout)?
3727 }
3728
3729 pub async fn commit(mut self) -> Result<Client<Ready>> {
3733 tracing::debug!("committing transaction");
3734
3735 #[cfg(feature = "otel")]
3736 let instrumentation = self.instrumentation.clone();
3737 #[cfg(feature = "otel")]
3738 let mut span = instrumentation.transaction_span("COMMIT");
3739
3740 let result = async {
3742 self.send_sql_batch("COMMIT TRANSACTION").await?;
3743 self.read_execute_result().await
3744 }
3745 .await;
3746
3747 #[cfg(feature = "otel")]
3748 match &result {
3749 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3750 Err(e) => InstrumentationContext::record_error(&mut span, e),
3751 }
3752
3753 #[cfg(feature = "otel")]
3755 drop(span);
3756
3757 result?;
3758
3759 Ok(Client {
3760 config: self.config,
3761 _state: PhantomData,
3762 connection: self.connection,
3763 server_version: self.server_version,
3764 current_database: self.current_database,
3765 statement_cache: self.statement_cache,
3766 transaction_descriptor: 0, needs_reset: self.needs_reset,
3768 #[cfg(feature = "otel")]
3769 instrumentation: self.instrumentation,
3770 })
3771 }
3772
3773 pub async fn rollback(mut self) -> Result<Client<Ready>> {
3777 tracing::debug!("rolling back transaction");
3778
3779 #[cfg(feature = "otel")]
3780 let instrumentation = self.instrumentation.clone();
3781 #[cfg(feature = "otel")]
3782 let mut span = instrumentation.transaction_span("ROLLBACK");
3783
3784 let result = async {
3786 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
3787 self.read_execute_result().await
3788 }
3789 .await;
3790
3791 #[cfg(feature = "otel")]
3792 match &result {
3793 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3794 Err(e) => InstrumentationContext::record_error(&mut span, e),
3795 }
3796
3797 #[cfg(feature = "otel")]
3799 drop(span);
3800
3801 result?;
3802
3803 Ok(Client {
3804 config: self.config,
3805 _state: PhantomData,
3806 connection: self.connection,
3807 server_version: self.server_version,
3808 current_database: self.current_database,
3809 statement_cache: self.statement_cache,
3810 transaction_descriptor: 0, needs_reset: self.needs_reset,
3812 #[cfg(feature = "otel")]
3813 instrumentation: self.instrumentation,
3814 })
3815 }
3816
3817 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
3834 validate_identifier(name)?;
3835 tracing::debug!(name = name, "creating savepoint");
3836
3837 let sql = format!("SAVE TRANSACTION {}", name);
3840 self.send_sql_batch(&sql).await?;
3841 self.read_execute_result().await?;
3842
3843 Ok(SavePoint::new(name.to_string()))
3844 }
3845
3846 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
3861 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
3862
3863 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
3866 self.send_sql_batch(&sql).await?;
3867 self.read_execute_result().await?;
3868
3869 Ok(())
3870 }
3871
3872 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
3878 tracing::debug!(name = savepoint.name(), "releasing savepoint");
3879
3880 drop(savepoint);
3884 Ok(())
3885 }
3886
3887 #[must_use]
3891 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3892 let connection = self
3893 .connection
3894 .as_ref()
3895 .expect("connection should be present");
3896 match connection {
3897 ConnectionHandle::Tls(conn) => {
3898 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3899 }
3900 ConnectionHandle::TlsPrelogin(conn) => {
3901 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3902 }
3903 ConnectionHandle::Plain(conn) => {
3904 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3905 }
3906 }
3907 }
3908}
3909
3910fn validate_identifier(name: &str) -> Result<()> {
3912 use once_cell::sync::Lazy;
3913 use regex::Regex;
3914
3915 static IDENTIFIER_RE: Lazy<Regex> =
3916 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
3917
3918 if name.is_empty() {
3919 return Err(Error::InvalidIdentifier(
3920 "identifier cannot be empty".into(),
3921 ));
3922 }
3923
3924 if !IDENTIFIER_RE.is_match(name) {
3925 return Err(Error::InvalidIdentifier(format!(
3926 "invalid identifier '{}': must start with letter/underscore, \
3927 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
3928 name
3929 )));
3930 }
3931
3932 Ok(())
3933}
3934
3935impl<S: ConnectionState> std::fmt::Debug for Client<S> {
3936 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3937 f.debug_struct("Client")
3938 .field("host", &self.config.host)
3939 .field("port", &self.config.port)
3940 .field("database", &self.config.database)
3941 .finish()
3942 }
3943}
3944
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use mssql_types::SqlValue;
    use tds_protocol::token::{ColumnData, TypeInfo};
    use tds_protocol::types::TypeId;

    // ------------------------------------------------------------------
    // Identifier validation
    // ------------------------------------------------------------------

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
    }

    #[test]
    fn test_validate_identifier_length_bounds() {
        // One character is the shortest accepted name and 128 the longest.
        assert!(validate_identifier("_").is_ok());
        assert!(validate_identifier(&"a".repeat(128)).is_ok());
        assert!(validate_identifier(&"a".repeat(129)).is_err());
    }

    #[test]
    fn test_validate_identifier_special_characters() {
        // `@`, `#` and `$` are accepted only after the first character.
        assert!(validate_identifier("a@b#c$").is_ok());
        assert!(validate_identifier("@var").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$name").is_err());
        // Non-ASCII letters and a trailing newline are rejected.
        assert!(validate_identifier("café").is_err());
        assert!(validate_identifier("name\n").is_err());
    }

    // ------------------------------------------------------------------
    // PLP (partially length-prefixed) value parsing
    // ------------------------------------------------------------------

    /// Encodes a PLP value in the layout the `parse_plp_*` parsers consume:
    /// - a little-endian `u64` total length;
    /// - each chunk as a little-endian `u32` length followed by its bytes;
    /// - a zero-length `u32` terminator.
    fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
        let mut data = Vec::new();
        data.extend_from_slice(&total_len.to_le_bytes());
        for chunk in chunks {
            let len = u32::try_from(chunk.len()).expect("PLP test chunk exceeds u32::MAX");
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(chunk);
        }
        // Zero-length chunk terminates the value.
        data.extend_from_slice(&0u32.to_le_bytes());
        data
    }

    /// PLP total-length marker (all bits set) that the parsers decode as NULL.
    const PLP_NULL: [u8; 8] = u64::MAX.to_le_bytes();

    #[test]
    fn test_parse_plp_nvarchar_simple() {
        // "Hello" encoded as UTF-16LE.
        let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&utf16_data]);
        let mut buf: &[u8] = &plp;

        match Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap() {
            SqlValue::String(s) => assert_eq!(s, "Hello"),
            other => panic!("expected String, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_null() {
        let mut buf: &[u8] = &PLP_NULL;

        let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
        assert!(
            matches!(result, SqlValue::Null),
            "expected Null, got {:?}",
            result
        );
    }

    #[test]
    fn test_parse_plp_nvarchar_empty() {
        let plp = make_plp_data(0, &[]);
        let mut buf: &[u8] = &plp;

        match Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap() {
            SqlValue::String(s) => assert_eq!(s, ""),
            other => panic!("expected empty String, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_multi_chunk() {
        // "Hel" and "lo" as UTF-16LE, split across two PLP chunks.
        let chunk1 = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00];
        let chunk2 = [0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&chunk1, &chunk2]);
        let mut buf: &[u8] = &plp;

        match Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap() {
            SqlValue::String(s) => assert_eq!(s, "Hello"),
            other => panic!("expected String, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_plp_varchar_simple() {
        let data = b"Hello World";
        let plp = make_plp_data(11, &[data]);
        let mut buf: &[u8] = &plp;

        match Client::<Ready>::parse_plp_varchar(&mut buf, None).unwrap() {
            SqlValue::String(s) => assert_eq!(s, "Hello World"),
            other => panic!("expected String, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_plp_varchar_null() {
        let mut buf: &[u8] = &PLP_NULL;

        let result = Client::<Ready>::parse_plp_varchar(&mut buf, None).unwrap();
        assert!(
            matches!(result, SqlValue::Null),
            "expected Null, got {:?}",
            result
        );
    }

    #[test]
    fn test_parse_plp_varbinary_simple() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let plp = make_plp_data(5, &[&data]);
        let mut buf: &[u8] = &plp;

        match Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap() {
            SqlValue::Binary(b) => assert_eq!(&b[..], &[0x01, 0x02, 0x03, 0x04, 0x05]),
            other => panic!("expected Binary, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_plp_varbinary_null() {
        let mut buf: &[u8] = &PLP_NULL;

        let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
        assert!(
            matches!(result, SqlValue::Null),
            "expected Null, got {:?}",
            result
        );
    }

    #[test]
    fn test_parse_plp_varbinary_large() {
        // Bytes 0..=254 spread over three unevenly sized chunks.
        let chunk1: Vec<u8> = (0..100u8).collect();
        let chunk2: Vec<u8> = (100..200u8).collect();
        let chunk3: Vec<u8> = (200..255u8).collect();
        let total_len = chunk1.len() + chunk2.len() + chunk3.len();
        let plp = make_plp_data(total_len as u64, &[&chunk1, &chunk2, &chunk3]);
        let mut buf: &[u8] = &plp;

        let expected: Vec<u8> = (0..255u8).collect();
        match Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap() {
            SqlValue::Binary(b) => assert_eq!(&b[..], &expected[..]),
            other => panic!("expected Binary, got {:?}", other),
        }
    }

    // ------------------------------------------------------------------
    // Row / column value parsing
    // ------------------------------------------------------------------

    /// Appends `value` as an NVARCHAR cell: a little-endian `u16` byte count
    /// followed by the UTF-16LE code units.
    fn push_nvarchar(data: &mut Vec<u8>, value: &str) {
        let utf16: Vec<u16> = value.encode_utf16().collect();
        let byte_len = u16::try_from(utf16.len() * 2).expect("NVARCHAR test value too long");
        data.extend_from_slice(&byte_len.to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }
    }

    /// Appends an INTN cell: a one-byte length followed by `value_bytes`.
    /// An empty slice produces a zero length byte, which these tests use as NULL.
    fn push_intn(data: &mut Vec<u8>, value_bytes: &[u8]) {
        data.push(u8::try_from(value_bytes.len()).expect("INTN test value too long"));
        data.extend_from_slice(value_bytes);
    }

    /// Builds a row containing an NVARCHAR cell followed by a 4-byte INTN cell.
    fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
        let mut data = Vec::new();
        push_nvarchar(&mut data, nvarchar_value);
        push_intn(&mut data, &int_value.to_le_bytes());
        data
    }

    #[test]
    fn test_parse_row_nvarchar_then_int() {
        let raw_data = make_nvarchar_int_row("World", 42);

        let col0 = ColumnData {
            name: "greeting".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                // "World" is 5 UTF-16 code units = 10 bytes.
                max_length: Some(10),
                precision: None,
                scale: None,
                collation: None,
            },
        };

        let col1 = ColumnData {
            name: "number".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
        };

        let mut buf: &[u8] = &raw_data;

        // Columns are parsed sequentially from the same buffer; the first
        // parse must consume exactly the NVARCHAR bytes.
        match Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap() {
            SqlValue::String(s) => assert_eq!(s, "World"),
            other => panic!("expected String, got {:?}", other),
        }

        match Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap() {
            SqlValue::Int(i) => assert_eq!(i, 42),
            other => panic!("expected Int, got {:?}", other),
        }

        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_multiple_types() {
        let mut data = Vec::new();
        // col0: NULL NVARCHAR (0xFFFF length marker).
        data.extend_from_slice(&0xFFFFu16.to_le_bytes());
        // col1: INTN holding 123.
        push_intn(&mut data, &123i32.to_le_bytes());
        // col2: NVARCHAR "Test".
        push_nvarchar(&mut data, "Test");
        // col3: NULL INTN (zero length byte).
        push_intn(&mut data, &[]);

        let col0 = ColumnData {
            name: "col0".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        };
        let col1 = ColumnData {
            name: "col1".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
        };
        let col2 = col0.clone();
        let col3 = col1.clone();

        let mut buf: &[u8] = &data;

        let v0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
        assert!(
            matches!(v0, SqlValue::Null),
            "col0 should be Null, got {:?}",
            v0
        );

        let v1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
        assert!(
            matches!(v1, SqlValue::Int(123)),
            "col1 should be 123, got {:?}",
            v1
        );

        match Client::<Ready>::parse_column_value(&mut buf, &col2).unwrap() {
            SqlValue::String(s) => assert_eq!(s, "Test"),
            other => panic!("col2 should be 'Test', got {:?}", other),
        }

        let v3 = Client::<Ready>::parse_column_value(&mut buf, &col3).unwrap();
        assert!(
            matches!(v3, SqlValue::Null),
            "col3 should be Null, got {:?}",
            v3
        );

        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_with_unicode() {
        // Mixes Latin-1 accents and CJK characters.
        let test_str = "Héllo Wörld 日本語";
        let mut data = Vec::new();
        push_nvarchar(&mut data, test_str);
        // 8-byte INTN.
        push_intn(&mut data, &9999999999i64.to_le_bytes());

        let col0 = ColumnData {
            name: "text".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        };
        let col1 = ColumnData {
            name: "num".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(8),
                precision: None,
                scale: None,
                collation: None,
            },
        };

        let mut buf: &[u8] = &data;

        match Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap() {
            SqlValue::String(s) => assert_eq!(s, test_str),
            other => panic!("expected String, got {:?}", other),
        }

        match Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap() {
            SqlValue::BigInt(i) => assert_eq!(i, 9999999999),
            other => panic!("expected BigInt, got {:?}", other),
        }

        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }
}