1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token, TokenParser,
19};
20use tds_protocol::tvp::{
21 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
22 encode_tvp_decimal, encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar,
23 encode_tvp_varbinary,
24};
25use tokio::net::TcpStream;
26use tokio::time::timeout;
27
28use crate::config::Config;
29use crate::error::{Error, Result};
30#[cfg(feature = "otel")]
31use crate::instrumentation::InstrumentationContext;
32use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
33use crate::statement_cache::StatementCache;
34use crate::stream::{MultiResultStream, QueryStream};
35use crate::transaction::SavePoint;
36
/// SQL Server client handle, parameterized by a type-state marker `S`.
///
/// `S` is never stored as data (see `_state`); it only selects which `impl`
/// blocks apply. For example, `Client<Disconnected>` exposes `connect`, which
/// produces a `Client<Ready>`.
pub struct Client<S: ConnectionState> {
    /// Configuration the connection was established with (cloned at connect time).
    config: Config,
    /// Zero-sized marker carrying the connection state type `S`.
    _state: PhantomData<S>,
    /// Active transport. When this is `None`, the send helpers fail with
    /// `Error::ConnectionClosed`.
    connection: Option<ConnectionHandle>,
    /// TDS version reported by the server's LoginAck token, if one was received.
    server_version: Option<u32>,
    /// Database name from the last Database ENVCHANGE seen during login, if any.
    current_database: Option<String>,
    /// Statement cache, created with its default size at connect time.
    statement_cache: StatementCache,
    /// Transaction descriptor attached to outgoing SQL batches and RPC requests.
    /// Starts at 0, is set from a BeginTransaction ENVCHANGE and reset to 0 on
    /// commit or rollback.
    transaction_descriptor: u64,
    /// Instrumentation context; only present with the `otel` feature.
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
61
/// Transport underneath an established connection.
///
/// There is one variant per stream type the connect paths can end up with.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this part of
// the file — confirm it is still required.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TLS directly over TCP; produced by the TDS 8.0 strict (TLS-first) path.
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS over a `TlsPreloginWrapper`; produced by the TDS 7.x path when the
    /// connection stays encrypted after login.
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Raw TCP; produced by the TDS 7.x path when TLS is not used at all or only
    /// covers the Login7 message.
    Plain(Connection<TcpStream>),
}
77
78impl Client<Disconnected> {
79 pub async fn connect(config: Config) -> Result<Client<Ready>> {
90 let max_redirects = config.redirect.max_redirects;
91 let follow_redirects = config.redirect.follow_redirects;
92 let mut attempts = 0;
93 let mut current_config = config;
94
95 loop {
96 attempts += 1;
97 if attempts > max_redirects + 1 {
98 return Err(Error::TooManyRedirects { max: max_redirects });
99 }
100
101 match Self::try_connect(¤t_config).await {
102 Ok(client) => return Ok(client),
103 Err(Error::Routing { host, port }) => {
104 if !follow_redirects {
105 return Err(Error::Routing { host, port });
106 }
107 tracing::info!(
108 host = %host,
109 port = port,
110 attempt = attempts,
111 max_redirects = max_redirects,
112 "following Azure SQL routing redirect"
113 );
114 current_config = current_config.with_host(&host).with_port(port);
115 continue;
116 }
117 Err(e) => return Err(e),
118 }
119 }
120 }
121
122 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
123 tracing::info!(
124 host = %config.host,
125 port = config.port,
126 database = ?config.database,
127 "connecting to SQL Server"
128 );
129
130 let addr = format!("{}:{}", config.host, config.port);
131
132 tracing::debug!("establishing TCP connection to {}", addr);
134 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
135 .await
136 .map_err(|_| Error::ConnectTimeout)?
137 .map_err(|e| Error::Io(Arc::new(e)))?;
138
139 tcp_stream
141 .set_nodelay(true)
142 .map_err(|e| Error::Io(Arc::new(e)))?;
143
144 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
146
147 if tls_mode.is_tls_first() {
149 return Self::connect_tds_8(config, tcp_stream).await;
150 }
151
152 Self::connect_tds_7x(config, tcp_stream).await
154 }
155
156 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
160 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
161
162 let tls_config = TlsConfig::new()
164 .strict_mode(true)
165 .trust_server_certificate(config.trust_server_certificate);
166
167 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
168
169 let tls_stream = timeout(
171 config.timeouts.tls_timeout,
172 tls_connector.connect(tcp_stream, &config.host),
173 )
174 .await
175 .map_err(|_| Error::TlsTimeout)?
176 .map_err(|e| Error::Tls(e.to_string()))?;
177
178 tracing::debug!("TLS handshake completed (strict mode)");
179
180 let mut connection = Connection::new(tls_stream);
182
183 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
185 Self::send_prelogin(&mut connection, &prelogin).await?;
186 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
187
188 let login = Self::build_login7(config);
190 Self::send_login7(&mut connection, &login).await?;
191
192 let (server_version, current_database, routing) =
194 Self::process_login_response(&mut connection).await?;
195
196 if let Some((host, port)) = routing {
198 return Err(Error::Routing { host, port });
199 }
200
201 Ok(Client {
202 config: config.clone(),
203 _state: PhantomData,
204 connection: Some(ConnectionHandle::Tls(connection)),
205 server_version,
206 current_database: current_database.clone(),
207 statement_cache: StatementCache::with_default_size(),
208 transaction_descriptor: 0, #[cfg(feature = "otel")]
210 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
211 .with_database(current_database.unwrap_or_default()),
212 })
213 }
214
215 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
223 use bytes::BufMut;
224 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
225 use tokio::io::{AsyncReadExt, AsyncWriteExt};
226
227 tracing::debug!("using TDS 7.x flow (PreLogin first)");
228
229 let client_encryption = if config.encrypt {
232 EncryptionLevel::On
233 } else {
234 EncryptionLevel::Off
235 };
236 let prelogin = Self::build_prelogin(config, client_encryption);
237 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
238 let prelogin_bytes = prelogin.encode();
239
240 let header = PacketHeader::new(
242 PacketType::PreLogin,
243 PacketStatus::END_OF_MESSAGE,
244 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
245 );
246
247 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
248 header.encode(&mut packet_buf);
249 packet_buf.put_slice(&prelogin_bytes);
250
251 tcp_stream
252 .write_all(&packet_buf)
253 .await
254 .map_err(|e| Error::Io(Arc::new(e)))?;
255
256 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
258 tcp_stream
259 .read_exact(&mut header_buf)
260 .await
261 .map_err(|e| Error::Io(Arc::new(e)))?;
262
263 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
264 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
265
266 let mut response_buf = vec![0u8; payload_length];
267 tcp_stream
268 .read_exact(&mut response_buf)
269 .await
270 .map_err(|e| Error::Io(Arc::new(e)))?;
271
272 let prelogin_response =
273 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
274
275 let server_version = prelogin_response.version;
277 let client_version = config.tds_version;
278 tracing::debug!(
279 client_version = %client_version,
280 server_version = %server_version,
281 server_sql_version = server_version.sql_server_version_name(),
282 "TDS version negotiation"
283 );
284
285 if server_version < client_version && !client_version.is_tds_8() {
287 tracing::warn!(
288 client_version = %client_version,
289 server_version = %server_version,
290 "Server supports lower TDS version than requested. \
291 Connection will use server's version: {}",
292 server_version.sql_server_version_name()
293 );
294 }
295
296 if server_version.is_legacy() {
298 tracing::warn!(
299 server_version = %server_version,
300 "Server uses legacy TDS version ({}). \
301 Some features may not be available.",
302 server_version.sql_server_version_name()
303 );
304 }
305
306 let server_encryption = prelogin_response.encryption;
308 tracing::debug!(encryption = ?server_encryption, "server encryption level");
309
310 let negotiated_encryption = match (client_encryption, server_encryption) {
316 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
317 EncryptionLevel::NotSupported
318 }
319 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
320 (EncryptionLevel::On, EncryptionLevel::Off)
321 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
322 return Err(Error::Protocol(
323 "Server does not support requested encryption level".to_string(),
324 ));
325 }
326 _ => EncryptionLevel::On,
327 };
328
329 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
332
333 if use_tls {
334 let tls_config =
337 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
338
339 let tls_connector =
340 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
341
342 let mut tls_stream = timeout(
344 config.timeouts.tls_timeout,
345 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
346 )
347 .await
348 .map_err(|_| Error::TlsTimeout)?
349 .map_err(|e| Error::Tls(e.to_string()))?;
350
351 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
352
353 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
355
356 if login_only_encryption {
357 use tokio::io::AsyncWriteExt;
365
366 let login = Self::build_login7(config);
368 let login_payload = login.encode();
369
370 let max_packet = MAX_PACKET_SIZE;
372 let max_payload = max_packet - PACKET_HEADER_SIZE;
373 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
374 let total_chunks = chunks.len();
375
376 for (i, chunk) in chunks.into_iter().enumerate() {
377 let is_last = i == total_chunks - 1;
378 let status = if is_last {
379 PacketStatus::END_OF_MESSAGE
380 } else {
381 PacketStatus::NORMAL
382 };
383
384 let header = PacketHeader::new(
385 PacketType::Tds7Login,
386 status,
387 (PACKET_HEADER_SIZE + chunk.len()) as u16,
388 );
389
390 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
391 header.encode(&mut packet_buf);
392 packet_buf.put_slice(chunk);
393
394 tls_stream
395 .write_all(&packet_buf)
396 .await
397 .map_err(|e| Error::Io(Arc::new(e)))?;
398 }
399
400 tls_stream
402 .flush()
403 .await
404 .map_err(|e| Error::Io(Arc::new(e)))?;
405
406 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
407
408 let (wrapper, _client_conn) = tls_stream.into_inner();
412 let tcp_stream = wrapper.into_inner();
413
414 let mut connection = Connection::new(tcp_stream);
416
417 let (server_version, current_database, routing) =
419 Self::process_login_response(&mut connection).await?;
420
421 if let Some((host, port)) = routing {
423 return Err(Error::Routing { host, port });
424 }
425
426 Ok(Client {
428 config: config.clone(),
429 _state: PhantomData,
430 connection: Some(ConnectionHandle::Plain(connection)),
431 server_version,
432 current_database: current_database.clone(),
433 statement_cache: StatementCache::with_default_size(),
434 transaction_descriptor: 0, #[cfg(feature = "otel")]
436 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
437 .with_database(current_database.unwrap_or_default()),
438 })
439 } else {
440 let mut connection = Connection::new(tls_stream);
443
444 let login = Self::build_login7(config);
446 Self::send_login7(&mut connection, &login).await?;
447
448 let (server_version, current_database, routing) =
450 Self::process_login_response(&mut connection).await?;
451
452 if let Some((host, port)) = routing {
454 return Err(Error::Routing { host, port });
455 }
456
457 Ok(Client {
458 config: config.clone(),
459 _state: PhantomData,
460 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
461 server_version,
462 current_database: current_database.clone(),
463 statement_cache: StatementCache::with_default_size(),
464 transaction_descriptor: 0, #[cfg(feature = "otel")]
466 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
467 .with_database(current_database.unwrap_or_default()),
468 })
469 }
470 } else {
471 tracing::warn!(
473 "Connecting without TLS encryption. This is insecure and should only be \
474 used for development/testing on trusted networks."
475 );
476
477 let login = Self::build_login7(config);
479 let login_bytes = login.encode();
480 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
481 tracing::debug!(
483 "Login7 fixed header (94 bytes): {:02X?}",
484 &login_bytes[..login_bytes.len().min(94)]
485 );
486 if login_bytes.len() > 94 {
488 tracing::debug!(
489 "Login7 variable data ({} bytes): {:02X?}",
490 login_bytes.len() - 94,
491 &login_bytes[94..]
492 );
493 }
494
495 let login_header = PacketHeader::new(
497 PacketType::Tds7Login,
498 PacketStatus::END_OF_MESSAGE,
499 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
500 )
501 .with_packet_id(1);
502 let mut login_packet_buf =
503 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
504 login_header.encode(&mut login_packet_buf);
505 login_packet_buf.put_slice(&login_bytes);
506
507 tracing::debug!(
508 "Sending Login7 packet: {} bytes total, header: {:02X?}",
509 login_packet_buf.len(),
510 &login_packet_buf[..PACKET_HEADER_SIZE]
511 );
512 tcp_stream
513 .write_all(&login_packet_buf)
514 .await
515 .map_err(|e| Error::Io(Arc::new(e)))?;
516 tcp_stream
517 .flush()
518 .await
519 .map_err(|e| Error::Io(Arc::new(e)))?;
520 tracing::debug!("Login7 sent and flushed over raw TCP");
521
522 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
524 tcp_stream
525 .read_exact(&mut response_header_buf)
526 .await
527 .map_err(|e| Error::Io(Arc::new(e)))?;
528
529 let response_type = response_header_buf[0];
530 let response_length =
531 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
532 tracing::debug!(
533 "Response header: type={:#04X}, length={}",
534 response_type,
535 response_length
536 );
537
538 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
540 let mut response_payload = vec![0u8; payload_length];
541 tcp_stream
542 .read_exact(&mut response_payload)
543 .await
544 .map_err(|e| Error::Io(Arc::new(e)))?;
545 tracing::debug!(
546 "Response payload: {} bytes, first 32: {:02X?}",
547 response_payload.len(),
548 &response_payload[..response_payload.len().min(32)]
549 );
550
551 let connection = Connection::new(tcp_stream);
553
554 let response_bytes = bytes::Bytes::from(response_payload);
556 let mut parser = TokenParser::new(response_bytes);
557 let mut server_version = None;
558 let mut current_database = None;
559 let routing = None;
560
561 while let Some(token) = parser
562 .next_token()
563 .map_err(|e| Error::Protocol(e.to_string()))?
564 {
565 match token {
566 Token::LoginAck(ack) => {
567 tracing::info!(
568 version = ack.tds_version,
569 interface = ack.interface,
570 prog_name = %ack.prog_name,
571 "login acknowledged"
572 );
573 server_version = Some(ack.tds_version);
574 }
575 Token::EnvChange(env) => {
576 Self::process_env_change(&env, &mut current_database, &mut None);
577 }
578 Token::Error(err) => {
579 return Err(Error::Server {
580 number: err.number,
581 state: err.state,
582 class: err.class,
583 message: err.message.clone(),
584 server: if err.server.is_empty() {
585 None
586 } else {
587 Some(err.server.clone())
588 },
589 procedure: if err.procedure.is_empty() {
590 None
591 } else {
592 Some(err.procedure.clone())
593 },
594 line: err.line as u32,
595 });
596 }
597 Token::Info(info) => {
598 tracing::info!(
599 number = info.number,
600 message = %info.message,
601 "server info message"
602 );
603 }
604 Token::Done(done) => {
605 if done.status.error {
606 return Err(Error::Protocol("login failed".to_string()));
607 }
608 break;
609 }
610 _ => {}
611 }
612 }
613
614 if let Some((host, port)) = routing {
616 return Err(Error::Routing { host, port });
617 }
618
619 Ok(Client {
620 config: config.clone(),
621 _state: PhantomData,
622 connection: Some(ConnectionHandle::Plain(connection)),
623 server_version,
624 current_database: current_database.clone(),
625 statement_cache: StatementCache::with_default_size(),
626 transaction_descriptor: 0, #[cfg(feature = "otel")]
628 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
629 .with_database(current_database.unwrap_or_default()),
630 })
631 }
632 }
633
634 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
636 let version = if config.strict_mode {
638 tds_protocol::version::TdsVersion::V8_0
639 } else {
640 config.tds_version
641 };
642
643 let mut prelogin = PreLogin::new()
644 .with_version(version)
645 .with_encryption(encryption);
646
647 if config.mars {
648 prelogin = prelogin.with_mars(true);
649 }
650
651 if let Some(ref instance) = config.instance {
652 prelogin = prelogin.with_instance(instance);
653 }
654
655 prelogin
656 }
657
658 fn build_login7(config: &Config) -> Login7 {
660 let version = if config.strict_mode {
662 tds_protocol::version::TdsVersion::V8_0
663 } else {
664 config.tds_version
665 };
666
667 let mut login = Login7::new()
668 .with_tds_version(version)
669 .with_packet_size(config.packet_size as u32)
670 .with_app_name(&config.application_name)
671 .with_server_name(&config.host)
672 .with_hostname(&config.host);
673
674 if let Some(ref database) = config.database {
675 login = login.with_database(database);
676 }
677
678 match &config.credentials {
680 mssql_auth::Credentials::SqlServer { username, password } => {
681 login = login.with_sql_auth(username.as_ref(), password.as_ref());
682 }
683 _ => {}
685 }
686
687 login
688 }
689
690 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
692 where
693 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
694 {
695 let payload = prelogin.encode();
696 let max_packet = MAX_PACKET_SIZE;
697
698 connection
699 .send_message(PacketType::PreLogin, payload, max_packet)
700 .await
701 .map_err(|e| Error::Protocol(e.to_string()))
702 }
703
704 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
706 where
707 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
708 {
709 let message = connection
710 .read_message()
711 .await
712 .map_err(|e| Error::Protocol(e.to_string()))?
713 .ok_or(Error::ConnectionClosed)?;
714
715 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
716 }
717
718 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
720 where
721 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
722 {
723 let payload = login.encode();
724 let max_packet = MAX_PACKET_SIZE;
725
726 connection
727 .send_message(PacketType::Tds7Login, payload, max_packet)
728 .await
729 .map_err(|e| Error::Protocol(e.to_string()))
730 }
731
732 async fn process_login_response<T>(
736 connection: &mut Connection<T>,
737 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
738 where
739 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
740 {
741 let message = connection
742 .read_message()
743 .await
744 .map_err(|e| Error::Protocol(e.to_string()))?
745 .ok_or(Error::ConnectionClosed)?;
746
747 let response_bytes = message.payload;
748
749 let mut parser = TokenParser::new(response_bytes);
750 let mut server_version = None;
751 let mut database = None;
752 let mut routing = None;
753
754 while let Some(token) = parser
755 .next_token()
756 .map_err(|e| Error::Protocol(e.to_string()))?
757 {
758 match token {
759 Token::LoginAck(ack) => {
760 tracing::info!(
761 version = ack.tds_version,
762 interface = ack.interface,
763 prog_name = %ack.prog_name,
764 "login acknowledged"
765 );
766 server_version = Some(ack.tds_version);
767 }
768 Token::EnvChange(env) => {
769 Self::process_env_change(&env, &mut database, &mut routing);
770 }
771 Token::Error(err) => {
772 return Err(Error::Server {
773 number: err.number,
774 state: err.state,
775 class: err.class,
776 message: err.message.clone(),
777 server: if err.server.is_empty() {
778 None
779 } else {
780 Some(err.server.clone())
781 },
782 procedure: if err.procedure.is_empty() {
783 None
784 } else {
785 Some(err.procedure.clone())
786 },
787 line: err.line as u32,
788 });
789 }
790 Token::Info(info) => {
791 tracing::info!(
792 number = info.number,
793 message = %info.message,
794 "server info message"
795 );
796 }
797 Token::Done(done) => {
798 if done.status.error {
799 return Err(Error::Protocol("login failed".to_string()));
800 }
801 break;
802 }
803 _ => {}
804 }
805 }
806
807 Ok((server_version, database, routing))
808 }
809
810 fn process_env_change(
812 env: &EnvChange,
813 database: &mut Option<String>,
814 routing: &mut Option<(String, u16)>,
815 ) {
816 use tds_protocol::token::EnvChangeValue;
817
818 match env.env_type {
819 EnvChangeType::Database => {
820 if let EnvChangeValue::String(ref new_value) = env.new_value {
821 tracing::debug!(database = %new_value, "database changed");
822 *database = Some(new_value.clone());
823 }
824 }
825 EnvChangeType::Routing => {
826 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
827 tracing::info!(host = %host, port = port, "routing redirect received");
828 *routing = Some((host.clone(), port));
829 }
830 }
831 _ => {
832 if let EnvChangeValue::String(ref new_value) = env.new_value {
833 tracing::debug!(
834 env_type = ?env.env_type,
835 new_value = %new_value,
836 "environment change"
837 );
838 }
839 }
840 }
841 }
842}
843
844impl<S: ConnectionState> Client<S> {
846 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
854 use tds_protocol::token::EnvChangeValue;
855
856 match env.env_type {
857 EnvChangeType::BeginTransaction => {
858 if let EnvChangeValue::Binary(ref data) = env.new_value {
859 if data.len() >= 8 {
860 let descriptor = u64::from_le_bytes([
861 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
862 ]);
863 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
864 *transaction_descriptor = descriptor;
865 }
866 }
867 }
868 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
869 tracing::debug!(
870 env_type = ?env.env_type,
871 "transaction ended via raw SQL"
872 );
873 *transaction_descriptor = 0;
874 }
875 _ => {}
876 }
877 }
878
879 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
885 let payload =
886 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
887 let max_packet = self.config.packet_size as usize;
888
889 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
890
891 match connection {
892 ConnectionHandle::Tls(conn) => {
893 conn.send_message(PacketType::SqlBatch, payload, max_packet)
894 .await
895 .map_err(|e| Error::Protocol(e.to_string()))?;
896 }
897 ConnectionHandle::TlsPrelogin(conn) => {
898 conn.send_message(PacketType::SqlBatch, payload, max_packet)
899 .await
900 .map_err(|e| Error::Protocol(e.to_string()))?;
901 }
902 ConnectionHandle::Plain(conn) => {
903 conn.send_message(PacketType::SqlBatch, payload, max_packet)
904 .await
905 .map_err(|e| Error::Protocol(e.to_string()))?;
906 }
907 }
908
909 Ok(())
910 }
911
912 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
916 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
917 let max_packet = self.config.packet_size as usize;
918
919 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
920
921 match connection {
922 ConnectionHandle::Tls(conn) => {
923 conn.send_message(PacketType::Rpc, payload, max_packet)
924 .await
925 .map_err(|e| Error::Protocol(e.to_string()))?;
926 }
927 ConnectionHandle::TlsPrelogin(conn) => {
928 conn.send_message(PacketType::Rpc, payload, max_packet)
929 .await
930 .map_err(|e| Error::Protocol(e.to_string()))?;
931 }
932 ConnectionHandle::Plain(conn) => {
933 conn.send_message(PacketType::Rpc, payload, max_packet)
934 .await
935 .map_err(|e| Error::Protocol(e.to_string()))?;
936 }
937 }
938
939 Ok(())
940 }
941
942 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
944 use bytes::{BufMut, BytesMut};
945 use mssql_types::SqlValue;
946
947 params
948 .iter()
949 .enumerate()
950 .map(|(i, p)| {
951 let sql_value = p.to_sql()?;
952 let name = format!("@p{}", i + 1);
953
954 Ok(match sql_value {
955 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
956 SqlValue::Bool(v) => {
957 let mut buf = BytesMut::with_capacity(1);
958 buf.put_u8(if v { 1 } else { 0 });
959 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
960 }
961 SqlValue::TinyInt(v) => {
962 let mut buf = BytesMut::with_capacity(1);
963 buf.put_u8(v);
964 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
965 }
966 SqlValue::SmallInt(v) => {
967 let mut buf = BytesMut::with_capacity(2);
968 buf.put_i16_le(v);
969 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
970 }
971 SqlValue::Int(v) => RpcParam::int(&name, v),
972 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
973 SqlValue::Float(v) => {
974 let mut buf = BytesMut::with_capacity(4);
975 buf.put_f32_le(v);
976 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
977 }
978 SqlValue::Double(v) => {
979 let mut buf = BytesMut::with_capacity(8);
980 buf.put_f64_le(v);
981 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
982 }
983 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
984 SqlValue::Binary(ref b) => {
985 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
986 }
987 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
988 #[cfg(feature = "uuid")]
989 SqlValue::Uuid(u) => {
990 let bytes = u.as_bytes();
992 let mut buf = BytesMut::with_capacity(16);
993 buf.put_u32_le(u32::from_be_bytes([
995 bytes[0], bytes[1], bytes[2], bytes[3],
996 ]));
997 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
998 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
999 buf.put_slice(&bytes[8..16]);
1000 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
1001 }
1002 #[cfg(feature = "decimal")]
1003 SqlValue::Decimal(d) => {
1004 RpcParam::nvarchar(&name, &d.to_string())
1006 }
1007 #[cfg(feature = "chrono")]
1008 SqlValue::Date(_)
1009 | SqlValue::Time(_)
1010 | SqlValue::DateTime(_)
1011 | SqlValue::DateTimeOffset(_) => {
1012 let s = match &sql_value {
1015 SqlValue::Date(d) => d.to_string(),
1016 SqlValue::Time(t) => t.to_string(),
1017 SqlValue::DateTime(dt) => dt.to_string(),
1018 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
1019 _ => unreachable!(),
1020 };
1021 RpcParam::nvarchar(&name, &s)
1022 }
1023 #[cfg(feature = "json")]
1024 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
1025 SqlValue::Tvp(ref tvp_data) => {
1026 Self::encode_tvp_param(&name, tvp_data)?
1028 }
1029 _ => {
1031 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
1032 from: sql_value.type_name().to_string(),
1033 to: "RPC parameter",
1034 }));
1035 }
1036 })
1037 })
1038 .collect()
1039 }
1040
1041 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
1046 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
1048 .columns
1049 .iter()
1050 .map(|col| {
1051 let wire_type = Self::convert_tvp_column_type(&col.column_type);
1052 TvpWireColumnDef {
1053 wire_type,
1054 flags: TvpColumnFlags {
1055 nullable: col.nullable,
1056 },
1057 }
1058 })
1059 .collect();
1060
1061 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
1063
1064 let mut buf = BytesMut::with_capacity(256);
1066
1067 encoder.encode_metadata(&mut buf);
1069
1070 for row in &tvp_data.rows {
1072 encoder.encode_row(&mut buf, |row_buf| {
1073 for (col_idx, value) in row.iter().enumerate() {
1074 let wire_type = &wire_columns[col_idx].wire_type;
1075 Self::encode_tvp_value(value, wire_type, row_buf);
1076 }
1077 });
1078 }
1079
1080 encoder.encode_end(&mut buf);
1082
1083 let type_info = RpcTypeInfo {
1087 type_id: 0xF3, max_length: None,
1089 precision: None,
1090 scale: None,
1091 collation: None,
1092 };
1093
1094 Ok(RpcParam {
1095 name: name.to_string(),
1096 flags: tds_protocol::rpc::ParamFlags::default(),
1097 type_info,
1098 value: Some(buf.freeze()),
1099 })
1100 }
1101
1102 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1104 match col_type {
1105 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1106 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1107 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1108 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1109 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1110 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1111 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1112 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1113 precision: *precision,
1114 scale: *scale,
1115 },
1116 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1117 max_length: *max_length,
1118 },
1119 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1120 max_length: *max_length,
1121 },
1122 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1123 max_length: *max_length,
1124 },
1125 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1126 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1127 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1128 mssql_types::TvpColumnType::DateTime2 { scale } => {
1129 TvpWireType::DateTime2 { scale: *scale }
1130 }
1131 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1132 TvpWireType::DateTimeOffset { scale: *scale }
1133 }
1134 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1135 }
1136 }
1137
1138 fn encode_tvp_value(
1140 value: &mssql_types::SqlValue,
1141 wire_type: &TvpWireType,
1142 buf: &mut BytesMut,
1143 ) {
1144 use mssql_types::SqlValue;
1145
1146 match value {
1147 SqlValue::Null => {
1148 encode_tvp_null(wire_type, buf);
1149 }
1150 SqlValue::Bool(v) => {
1151 encode_tvp_bit(*v, buf);
1152 }
1153 SqlValue::TinyInt(v) => {
1154 encode_tvp_int(*v as i64, 1, buf);
1155 }
1156 SqlValue::SmallInt(v) => {
1157 encode_tvp_int(*v as i64, 2, buf);
1158 }
1159 SqlValue::Int(v) => {
1160 encode_tvp_int(*v as i64, 4, buf);
1161 }
1162 SqlValue::BigInt(v) => {
1163 encode_tvp_int(*v, 8, buf);
1164 }
1165 SqlValue::Float(v) => {
1166 encode_tvp_float(*v as f64, 4, buf);
1167 }
1168 SqlValue::Double(v) => {
1169 encode_tvp_float(*v, 8, buf);
1170 }
1171 SqlValue::String(s) => {
1172 let max_len = match wire_type {
1173 TvpWireType::NVarChar { max_length } => *max_length,
1174 _ => 4000,
1175 };
1176 encode_tvp_nvarchar(s, max_len, buf);
1177 }
1178 SqlValue::Binary(b) => {
1179 let max_len = match wire_type {
1180 TvpWireType::VarBinary { max_length } => *max_length,
1181 _ => 8000,
1182 };
1183 encode_tvp_varbinary(b, max_len, buf);
1184 }
1185 #[cfg(feature = "decimal")]
1186 SqlValue::Decimal(d) => {
1187 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1188 let mantissa = d.mantissa().unsigned_abs();
1189 encode_tvp_decimal(sign, mantissa, buf);
1190 }
1191 #[cfg(feature = "uuid")]
1192 SqlValue::Uuid(u) => {
1193 let bytes = u.as_bytes();
1194 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1195 }
1196 #[cfg(feature = "chrono")]
1197 SqlValue::Date(d) => {
1198 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1200 let days = d.signed_duration_since(base).num_days() as u32;
1201 tds_protocol::tvp::encode_tvp_date(days, buf);
1202 }
1203 #[cfg(feature = "chrono")]
1204 SqlValue::Time(t) => {
1205 use chrono::Timelike;
1206 let nanos =
1207 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1208 let intervals = nanos / 100;
1209 let scale = match wire_type {
1210 TvpWireType::Time { scale } => *scale,
1211 _ => 7,
1212 };
1213 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1214 }
1215 #[cfg(feature = "chrono")]
1216 SqlValue::DateTime(dt) => {
1217 use chrono::Timelike;
1218 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1220 + dt.time().nanosecond() as u64;
1221 let intervals = nanos / 100;
1222 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1224 let days = dt.date().signed_duration_since(base).num_days() as u32;
1225 let scale = match wire_type {
1226 TvpWireType::DateTime2 { scale } => *scale,
1227 _ => 7,
1228 };
1229 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1230 }
1231 #[cfg(feature = "chrono")]
1232 SqlValue::DateTimeOffset(dto) => {
1233 use chrono::{Offset, Timelike};
1234 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1236 + dto.time().nanosecond() as u64;
1237 let intervals = nanos / 100;
1238 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1240 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1241 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1243 let scale = match wire_type {
1244 TvpWireType::DateTimeOffset { scale } => *scale,
1245 _ => 7,
1246 };
1247 tds_protocol::tvp::encode_tvp_datetimeoffset(
1248 intervals,
1249 days,
1250 offset_minutes,
1251 scale,
1252 buf,
1253 );
1254 }
1255 #[cfg(feature = "json")]
1256 SqlValue::Json(j) => {
1257 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1259 }
1260 SqlValue::Xml(s) => {
1261 encode_tvp_nvarchar(s, 0xFFFF, buf);
1263 }
1264 SqlValue::Tvp(_) => {
1265 encode_tvp_null(wire_type, buf);
1267 }
1268 _ => {
1270 encode_tvp_null(wire_type, buf);
1271 }
1272 }
1273 }
1274
1275 async fn read_query_response(
1277 &mut self,
1278 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1279 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1280
1281 let message = match connection {
1282 ConnectionHandle::Tls(conn) => conn
1283 .read_message()
1284 .await
1285 .map_err(|e| Error::Protocol(e.to_string()))?,
1286 ConnectionHandle::TlsPrelogin(conn) => conn
1287 .read_message()
1288 .await
1289 .map_err(|e| Error::Protocol(e.to_string()))?,
1290 ConnectionHandle::Plain(conn) => conn
1291 .read_message()
1292 .await
1293 .map_err(|e| Error::Protocol(e.to_string()))?,
1294 }
1295 .ok_or(Error::ConnectionClosed)?;
1296
1297 let mut parser = TokenParser::new(message.payload);
1298 let mut columns: Vec<crate::row::Column> = Vec::new();
1299 let mut rows: Vec<crate::row::Row> = Vec::new();
1300 let mut protocol_metadata: Option<ColMetaData> = None;
1301
1302 loop {
1303 let token = parser
1305 .next_token_with_metadata(protocol_metadata.as_ref())
1306 .map_err(|e| Error::Protocol(e.to_string()))?;
1307
1308 let Some(token) = token else {
1309 break;
1310 };
1311
1312 match token {
1313 Token::ColMetaData(meta) => {
1314 rows.clear();
1317
1318 columns = meta
1319 .columns
1320 .iter()
1321 .enumerate()
1322 .map(|(i, col)| {
1323 let type_name = format!("{:?}", col.type_id);
1324 let mut column = crate::row::Column::new(&col.name, i, type_name)
1325 .with_nullable(col.flags & 0x01 != 0);
1326
1327 if let Some(max_len) = col.type_info.max_length {
1328 column = column.with_max_length(max_len);
1329 }
1330 if let (Some(prec), Some(scale)) =
1331 (col.type_info.precision, col.type_info.scale)
1332 {
1333 column = column.with_precision_scale(prec, scale);
1334 }
1335 column
1336 })
1337 .collect();
1338
1339 tracing::debug!(columns = columns.len(), "received column metadata");
1340 protocol_metadata = Some(meta);
1341 }
1342 Token::Row(raw_row) => {
1343 if let Some(ref meta) = protocol_metadata {
1344 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1345 rows.push(row);
1346 }
1347 }
1348 Token::NbcRow(nbc_row) => {
1349 if let Some(ref meta) = protocol_metadata {
1350 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1351 rows.push(row);
1352 }
1353 }
1354 Token::Error(err) => {
1355 return Err(Error::Server {
1356 number: err.number,
1357 state: err.state,
1358 class: err.class,
1359 message: err.message.clone(),
1360 server: if err.server.is_empty() {
1361 None
1362 } else {
1363 Some(err.server.clone())
1364 },
1365 procedure: if err.procedure.is_empty() {
1366 None
1367 } else {
1368 Some(err.procedure.clone())
1369 },
1370 line: err.line as u32,
1371 });
1372 }
1373 Token::Done(done) => {
1374 if done.status.error {
1375 return Err(Error::Query("query failed".to_string()));
1376 }
1377 tracing::debug!(
1378 row_count = done.row_count,
1379 has_more = done.status.more,
1380 "query complete"
1381 );
1382 if !done.status.more {
1385 break;
1386 }
1387 }
1388 Token::DoneProc(done) => {
1389 if done.status.error {
1390 return Err(Error::Query("query failed".to_string()));
1391 }
1392 }
1393 Token::DoneInProc(done) => {
1394 if done.status.error {
1395 return Err(Error::Query("query failed".to_string()));
1396 }
1397 }
1398 Token::Info(info) => {
1399 tracing::debug!(
1400 number = info.number,
1401 message = %info.message,
1402 "server info message"
1403 );
1404 }
1405 Token::EnvChange(env) => {
1406 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
1410 }
1411 _ => {}
1412 }
1413 }
1414
1415 tracing::debug!(
1416 columns = columns.len(),
1417 rows = rows.len(),
1418 "query response parsed"
1419 );
1420 Ok((columns, rows))
1421 }
1422
1423 fn convert_raw_row(
1427 raw: &RawRow,
1428 meta: &ColMetaData,
1429 columns: &[crate::row::Column],
1430 ) -> Result<crate::row::Row> {
1431 let mut values = Vec::with_capacity(meta.columns.len());
1432 let mut buf = raw.data.as_ref();
1433
1434 for col in &meta.columns {
1435 let value = Self::parse_column_value(&mut buf, col)?;
1436 values.push(value);
1437 }
1438
1439 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1440 }
1441
1442 fn convert_nbc_row(
1446 nbc: &NbcRow,
1447 meta: &ColMetaData,
1448 columns: &[crate::row::Column],
1449 ) -> Result<crate::row::Row> {
1450 let mut values = Vec::with_capacity(meta.columns.len());
1451 let mut buf = nbc.data.as_ref();
1452
1453 for (i, col) in meta.columns.iter().enumerate() {
1454 if nbc.is_null(i) {
1455 values.push(mssql_types::SqlValue::Null);
1456 } else {
1457 let value = Self::parse_column_value(&mut buf, col)?;
1458 values.push(value);
1459 }
1460 }
1461
1462 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1463 }
1464
1465 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1467 use bytes::Buf;
1468 use mssql_types::SqlValue;
1469 use tds_protocol::types::TypeId;
1470
1471 let value = match col.type_id {
1472 TypeId::Null => SqlValue::Null,
1474
1475 TypeId::Int1 => {
1477 if buf.remaining() < 1 {
1478 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1479 }
1480 SqlValue::TinyInt(buf.get_u8())
1481 }
1482 TypeId::Bit => {
1483 if buf.remaining() < 1 {
1484 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1485 }
1486 SqlValue::Bool(buf.get_u8() != 0)
1487 }
1488
1489 TypeId::Int2 => {
1491 if buf.remaining() < 2 {
1492 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1493 }
1494 SqlValue::SmallInt(buf.get_i16_le())
1495 }
1496
1497 TypeId::Int4 => {
1499 if buf.remaining() < 4 {
1500 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1501 }
1502 SqlValue::Int(buf.get_i32_le())
1503 }
1504 TypeId::Float4 => {
1505 if buf.remaining() < 4 {
1506 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1507 }
1508 SqlValue::Float(buf.get_f32_le())
1509 }
1510
1511 TypeId::Int8 => {
1513 if buf.remaining() < 8 {
1514 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1515 }
1516 SqlValue::BigInt(buf.get_i64_le())
1517 }
1518 TypeId::Float8 => {
1519 if buf.remaining() < 8 {
1520 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1521 }
1522 SqlValue::Double(buf.get_f64_le())
1523 }
1524 TypeId::Money => {
1525 if buf.remaining() < 8 {
1526 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1527 }
1528 let high = buf.get_i32_le();
1530 let low = buf.get_u32_le();
1531 let cents = ((high as i64) << 32) | (low as i64);
1532 let value = (cents as f64) / 10000.0;
1533 SqlValue::Double(value)
1534 }
1535 TypeId::Money4 => {
1536 if buf.remaining() < 4 {
1537 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1538 }
1539 let cents = buf.get_i32_le();
1540 let value = (cents as f64) / 10000.0;
1541 SqlValue::Double(value)
1542 }
1543
1544 TypeId::IntN => {
1546 if buf.remaining() < 1 {
1547 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1548 }
1549 let len = buf.get_u8();
1550 match len {
1551 0 => SqlValue::Null,
1552 1 => SqlValue::TinyInt(buf.get_u8()),
1553 2 => SqlValue::SmallInt(buf.get_i16_le()),
1554 4 => SqlValue::Int(buf.get_i32_le()),
1555 8 => SqlValue::BigInt(buf.get_i64_le()),
1556 _ => {
1557 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1558 }
1559 }
1560 }
1561 TypeId::FloatN => {
1562 if buf.remaining() < 1 {
1563 return Err(Error::Protocol(
1564 "unexpected EOF reading FloatN length".into(),
1565 ));
1566 }
1567 let len = buf.get_u8();
1568 match len {
1569 0 => SqlValue::Null,
1570 4 => SqlValue::Float(buf.get_f32_le()),
1571 8 => SqlValue::Double(buf.get_f64_le()),
1572 _ => {
1573 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1574 }
1575 }
1576 }
1577 TypeId::BitN => {
1578 if buf.remaining() < 1 {
1579 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1580 }
1581 let len = buf.get_u8();
1582 match len {
1583 0 => SqlValue::Null,
1584 1 => SqlValue::Bool(buf.get_u8() != 0),
1585 _ => {
1586 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1587 }
1588 }
1589 }
1590 TypeId::MoneyN => {
1591 if buf.remaining() < 1 {
1592 return Err(Error::Protocol(
1593 "unexpected EOF reading MoneyN length".into(),
1594 ));
1595 }
1596 let len = buf.get_u8();
1597 match len {
1598 0 => SqlValue::Null,
1599 4 => {
1600 let cents = buf.get_i32_le();
1601 SqlValue::Double((cents as f64) / 10000.0)
1602 }
1603 8 => {
1604 let high = buf.get_i32_le();
1605 let low = buf.get_u32_le();
1606 let cents = ((high as i64) << 32) | (low as i64);
1607 SqlValue::Double((cents as f64) / 10000.0)
1608 }
1609 _ => {
1610 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1611 }
1612 }
1613 }
1614 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1616 if buf.remaining() < 1 {
1617 return Err(Error::Protocol(
1618 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1619 ));
1620 }
1621 let len = buf.get_u8() as usize;
1622 if len == 0 {
1623 SqlValue::Null
1624 } else {
1625 if buf.remaining() < len {
1626 return Err(Error::Protocol(
1627 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1628 ));
1629 }
1630
1631 let sign = buf.get_u8();
1633 let mantissa_len = len - 1;
1634
1635 let mut mantissa_bytes = [0u8; 16];
1637 for i in 0..mantissa_len.min(16) {
1638 mantissa_bytes[i] = buf.get_u8();
1639 }
1640 for _ in 16..mantissa_len {
1642 buf.get_u8();
1643 }
1644
1645 let mantissa = u128::from_le_bytes(mantissa_bytes);
1646 let scale = col.type_info.scale.unwrap_or(0) as u32;
1647
1648 #[cfg(feature = "decimal")]
1649 {
1650 use rust_decimal::Decimal;
1651 if scale > 28 {
1654 let divisor = 10f64.powi(scale as i32);
1656 let value = (mantissa as f64) / divisor;
1657 let value = if sign == 0 { -value } else { value };
1658 SqlValue::Double(value)
1659 } else {
1660 let mut decimal =
1661 Decimal::from_i128_with_scale(mantissa as i128, scale);
1662 if sign == 0 {
1663 decimal.set_sign_negative(true);
1664 }
1665 SqlValue::Decimal(decimal)
1666 }
1667 }
1668
1669 #[cfg(not(feature = "decimal"))]
1670 {
1671 let divisor = 10f64.powi(scale as i32);
1673 let value = (mantissa as f64) / divisor;
1674 let value = if sign == 0 { -value } else { value };
1675 SqlValue::Double(value)
1676 }
1677 }
1678 }
1679
1680 TypeId::DateTimeN => {
1682 if buf.remaining() < 1 {
1683 return Err(Error::Protocol(
1684 "unexpected EOF reading DateTimeN length".into(),
1685 ));
1686 }
1687 let len = buf.get_u8() as usize;
1688 if len == 0 {
1689 SqlValue::Null
1690 } else if buf.remaining() < len {
1691 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1692 } else {
1693 match len {
1694 4 => {
1695 let days = buf.get_u16_le() as i64;
1697 let minutes = buf.get_u16_le() as u32;
1698 #[cfg(feature = "chrono")]
1699 {
1700 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1701 let date = base + chrono::Duration::days(days);
1702 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1703 minutes * 60,
1704 0,
1705 )
1706 .unwrap();
1707 SqlValue::DateTime(date.and_time(time))
1708 }
1709 #[cfg(not(feature = "chrono"))]
1710 {
1711 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1712 }
1713 }
1714 8 => {
1715 let days = buf.get_i32_le() as i64;
1717 let time_300ths = buf.get_u32_le() as u64;
1718 #[cfg(feature = "chrono")]
1719 {
1720 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1721 let date = base + chrono::Duration::days(days);
1722 let total_ms = (time_300ths * 1000) / 300;
1724 let secs = (total_ms / 1000) as u32;
1725 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1726 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1727 secs, nanos,
1728 )
1729 .unwrap();
1730 SqlValue::DateTime(date.and_time(time))
1731 }
1732 #[cfg(not(feature = "chrono"))]
1733 {
1734 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1735 }
1736 }
1737 _ => {
1738 return Err(Error::Protocol(format!(
1739 "invalid DateTimeN length: {len}"
1740 )));
1741 }
1742 }
1743 }
1744 }
1745
1746 TypeId::DateTime => {
1748 if buf.remaining() < 8 {
1749 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1750 }
1751 let days = buf.get_i32_le() as i64;
1752 let time_300ths = buf.get_u32_le() as u64;
1753 #[cfg(feature = "chrono")]
1754 {
1755 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1756 let date = base + chrono::Duration::days(days);
1757 let total_ms = (time_300ths * 1000) / 300;
1758 let secs = (total_ms / 1000) as u32;
1759 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1760 let time =
1761 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1762 SqlValue::DateTime(date.and_time(time))
1763 }
1764 #[cfg(not(feature = "chrono"))]
1765 {
1766 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1767 }
1768 }
1769
1770 TypeId::DateTime4 => {
1772 if buf.remaining() < 4 {
1773 return Err(Error::Protocol(
1774 "unexpected EOF reading SMALLDATETIME".into(),
1775 ));
1776 }
1777 let days = buf.get_u16_le() as i64;
1778 let minutes = buf.get_u16_le() as u32;
1779 #[cfg(feature = "chrono")]
1780 {
1781 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1782 let date = base + chrono::Duration::days(days);
1783 let time =
1784 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1785 .unwrap();
1786 SqlValue::DateTime(date.and_time(time))
1787 }
1788 #[cfg(not(feature = "chrono"))]
1789 {
1790 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1791 }
1792 }
1793
1794 TypeId::Date => {
1796 if buf.remaining() < 1 {
1797 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1798 }
1799 let len = buf.get_u8() as usize;
1800 if len == 0 {
1801 SqlValue::Null
1802 } else if len != 3 {
1803 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1804 } else if buf.remaining() < 3 {
1805 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1806 } else {
1807 let days = buf.get_u8() as u32
1809 | ((buf.get_u8() as u32) << 8)
1810 | ((buf.get_u8() as u32) << 16);
1811 #[cfg(feature = "chrono")]
1812 {
1813 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1814 let date = base + chrono::Duration::days(days as i64);
1815 SqlValue::Date(date)
1816 }
1817 #[cfg(not(feature = "chrono"))]
1818 {
1819 SqlValue::String(format!("DATE({days})"))
1820 }
1821 }
1822 }
1823
1824 TypeId::Time => {
1826 if buf.remaining() < 1 {
1827 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1828 }
1829 let len = buf.get_u8() as usize;
1830 if len == 0 {
1831 SqlValue::Null
1832 } else if buf.remaining() < len {
1833 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1834 } else {
1835 let scale = col.type_info.scale.unwrap_or(7);
1836 let mut time_bytes = [0u8; 8];
1837 for byte in time_bytes.iter_mut().take(len) {
1838 *byte = buf.get_u8();
1839 }
1840 let intervals = u64::from_le_bytes(time_bytes);
1841 #[cfg(feature = "chrono")]
1842 {
1843 let time = Self::intervals_to_time(intervals, scale);
1844 SqlValue::Time(time)
1845 }
1846 #[cfg(not(feature = "chrono"))]
1847 {
1848 SqlValue::String(format!("TIME({intervals})"))
1849 }
1850 }
1851 }
1852
1853 TypeId::DateTime2 => {
1855 if buf.remaining() < 1 {
1856 return Err(Error::Protocol(
1857 "unexpected EOF reading DATETIME2 length".into(),
1858 ));
1859 }
1860 let len = buf.get_u8() as usize;
1861 if len == 0 {
1862 SqlValue::Null
1863 } else if buf.remaining() < len {
1864 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1865 } else {
1866 let scale = col.type_info.scale.unwrap_or(7);
1867 let time_len = Self::time_bytes_for_scale(scale);
1868
1869 let mut time_bytes = [0u8; 8];
1871 for byte in time_bytes.iter_mut().take(time_len) {
1872 *byte = buf.get_u8();
1873 }
1874 let intervals = u64::from_le_bytes(time_bytes);
1875
1876 let days = buf.get_u8() as u32
1878 | ((buf.get_u8() as u32) << 8)
1879 | ((buf.get_u8() as u32) << 16);
1880
1881 #[cfg(feature = "chrono")]
1882 {
1883 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1884 let date = base + chrono::Duration::days(days as i64);
1885 let time = Self::intervals_to_time(intervals, scale);
1886 SqlValue::DateTime(date.and_time(time))
1887 }
1888 #[cfg(not(feature = "chrono"))]
1889 {
1890 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1891 }
1892 }
1893 }
1894
1895 TypeId::DateTimeOffset => {
1897 if buf.remaining() < 1 {
1898 return Err(Error::Protocol(
1899 "unexpected EOF reading DATETIMEOFFSET length".into(),
1900 ));
1901 }
1902 let len = buf.get_u8() as usize;
1903 if len == 0 {
1904 SqlValue::Null
1905 } else if buf.remaining() < len {
1906 return Err(Error::Protocol(
1907 "unexpected EOF reading DATETIMEOFFSET".into(),
1908 ));
1909 } else {
1910 let scale = col.type_info.scale.unwrap_or(7);
1911 let time_len = Self::time_bytes_for_scale(scale);
1912
1913 let mut time_bytes = [0u8; 8];
1915 for byte in time_bytes.iter_mut().take(time_len) {
1916 *byte = buf.get_u8();
1917 }
1918 let intervals = u64::from_le_bytes(time_bytes);
1919
1920 let days = buf.get_u8() as u32
1922 | ((buf.get_u8() as u32) << 8)
1923 | ((buf.get_u8() as u32) << 16);
1924
1925 let offset_minutes = buf.get_i16_le();
1927
1928 #[cfg(feature = "chrono")]
1929 {
1930 use chrono::TimeZone;
1931 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1932 let date = base + chrono::Duration::days(days as i64);
1933 let time = Self::intervals_to_time(intervals, scale);
1934 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1935 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1936 let datetime = offset
1937 .from_local_datetime(&date.and_time(time))
1938 .single()
1939 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1940 SqlValue::DateTimeOffset(datetime)
1941 }
1942 #[cfg(not(feature = "chrono"))]
1943 {
1944 SqlValue::String(format!(
1945 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
1946 ))
1947 }
1948 }
1949 }
1950
1951 TypeId::Text => Self::parse_plp_varchar(buf)?,
1953
1954 TypeId::Char | TypeId::VarChar => {
1956 if buf.remaining() < 1 {
1957 return Err(Error::Protocol(
1958 "unexpected EOF reading legacy varchar length".into(),
1959 ));
1960 }
1961 let len = buf.get_u8();
1962 if len == 0xFF {
1963 SqlValue::Null
1964 } else if len == 0 {
1965 SqlValue::String(String::new())
1966 } else if buf.remaining() < len as usize {
1967 return Err(Error::Protocol(
1968 "unexpected EOF reading legacy varchar data".into(),
1969 ));
1970 } else {
1971 let data = &buf[..len as usize];
1972 let s = String::from_utf8_lossy(data).into_owned();
1973 buf.advance(len as usize);
1974 SqlValue::String(s)
1975 }
1976 }
1977
1978 TypeId::BigVarChar | TypeId::BigChar => {
1980 if col.type_info.max_length == Some(0xFFFF) {
1982 Self::parse_plp_varchar(buf)?
1984 } else {
1985 if buf.remaining() < 2 {
1987 return Err(Error::Protocol(
1988 "unexpected EOF reading varchar length".into(),
1989 ));
1990 }
1991 let len = buf.get_u16_le();
1992 if len == 0xFFFF {
1993 SqlValue::Null
1994 } else if buf.remaining() < len as usize {
1995 return Err(Error::Protocol(
1996 "unexpected EOF reading varchar data".into(),
1997 ));
1998 } else {
1999 let data = &buf[..len as usize];
2000 let s = String::from_utf8_lossy(data).into_owned();
2001 buf.advance(len as usize);
2002 SqlValue::String(s)
2003 }
2004 }
2005 }
2006
2007 TypeId::NText => Self::parse_plp_nvarchar(buf)?,
2009
2010 TypeId::NVarChar | TypeId::NChar => {
2012 if col.type_info.max_length == Some(0xFFFF) {
2014 Self::parse_plp_nvarchar(buf)?
2016 } else {
2017 if buf.remaining() < 2 {
2019 return Err(Error::Protocol(
2020 "unexpected EOF reading nvarchar length".into(),
2021 ));
2022 }
2023 let len = buf.get_u16_le();
2024 if len == 0xFFFF {
2025 SqlValue::Null
2026 } else if buf.remaining() < len as usize {
2027 return Err(Error::Protocol(
2028 "unexpected EOF reading nvarchar data".into(),
2029 ));
2030 } else {
2031 let data = &buf[..len as usize];
2032 let utf16: Vec<u16> = data
2034 .chunks_exact(2)
2035 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2036 .collect();
2037 let s = String::from_utf16(&utf16)
2038 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
2039 buf.advance(len as usize);
2040 SqlValue::String(s)
2041 }
2042 }
2043 }
2044
2045 TypeId::Image => Self::parse_plp_varbinary(buf)?,
2047
2048 TypeId::Binary | TypeId::VarBinary => {
2050 if buf.remaining() < 1 {
2051 return Err(Error::Protocol(
2052 "unexpected EOF reading legacy varbinary length".into(),
2053 ));
2054 }
2055 let len = buf.get_u8();
2056 if len == 0xFF {
2057 SqlValue::Null
2058 } else if len == 0 {
2059 SqlValue::Binary(bytes::Bytes::new())
2060 } else if buf.remaining() < len as usize {
2061 return Err(Error::Protocol(
2062 "unexpected EOF reading legacy varbinary data".into(),
2063 ));
2064 } else {
2065 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2066 buf.advance(len as usize);
2067 SqlValue::Binary(data)
2068 }
2069 }
2070
2071 TypeId::BigVarBinary | TypeId::BigBinary => {
2073 if col.type_info.max_length == Some(0xFFFF) {
2075 Self::parse_plp_varbinary(buf)?
2077 } else {
2078 if buf.remaining() < 2 {
2079 return Err(Error::Protocol(
2080 "unexpected EOF reading varbinary length".into(),
2081 ));
2082 }
2083 let len = buf.get_u16_le();
2084 if len == 0xFFFF {
2085 SqlValue::Null
2086 } else if buf.remaining() < len as usize {
2087 return Err(Error::Protocol(
2088 "unexpected EOF reading varbinary data".into(),
2089 ));
2090 } else {
2091 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2092 buf.advance(len as usize);
2093 SqlValue::Binary(data)
2094 }
2095 }
2096 }
2097
2098 TypeId::Xml => {
2100 match Self::parse_plp_nvarchar(buf)? {
2102 SqlValue::Null => SqlValue::Null,
2103 SqlValue::String(s) => SqlValue::Xml(s),
2104 _ => {
2105 return Err(Error::Protocol(
2106 "unexpected value type when parsing XML".into(),
2107 ));
2108 }
2109 }
2110 }
2111
2112 TypeId::Guid => {
2114 if buf.remaining() < 1 {
2115 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
2116 }
2117 let len = buf.get_u8();
2118 if len == 0 {
2119 SqlValue::Null
2120 } else if len != 16 {
2121 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
2122 } else if buf.remaining() < 16 {
2123 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
2124 } else {
2125 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
2127 buf.advance(16);
2128 SqlValue::Binary(data)
2129 }
2130 }
2131
2132 TypeId::Variant => Self::parse_sql_variant(buf)?,
2134
2135 TypeId::Udt => Self::parse_plp_varbinary(buf)?,
2137
2138 _ => {
2140 if buf.remaining() < 2 {
2142 return Err(Error::Protocol(format!(
2143 "unexpected EOF reading {:?}",
2144 col.type_id
2145 )));
2146 }
2147 let len = buf.get_u16_le();
2148 if len == 0xFFFF {
2149 SqlValue::Null
2150 } else if buf.remaining() < len as usize {
2151 return Err(Error::Protocol(format!(
2152 "unexpected EOF reading {:?} data",
2153 col.type_id
2154 )));
2155 } else {
2156 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2157 buf.advance(len as usize);
2158 SqlValue::Binary(data)
2159 }
2160 }
2161 };
2162
2163 Ok(value)
2164 }
2165
2166 fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2172 use bytes::Buf;
2173 use mssql_types::SqlValue;
2174
2175 if buf.remaining() < 8 {
2176 return Err(Error::Protocol(
2177 "unexpected EOF reading PLP total length".into(),
2178 ));
2179 }
2180
2181 let total_len = buf.get_u64_le();
2182 if total_len == 0xFFFFFFFFFFFFFFFF {
2183 return Ok(SqlValue::Null);
2184 }
2185
2186 let mut all_data = Vec::new();
2188 loop {
2189 if buf.remaining() < 4 {
2190 return Err(Error::Protocol(
2191 "unexpected EOF reading PLP chunk length".into(),
2192 ));
2193 }
2194 let chunk_len = buf.get_u32_le() as usize;
2195 if chunk_len == 0 {
2196 break; }
2198 if buf.remaining() < chunk_len {
2199 return Err(Error::Protocol(
2200 "unexpected EOF reading PLP chunk data".into(),
2201 ));
2202 }
2203 all_data.extend_from_slice(&buf[..chunk_len]);
2204 buf.advance(chunk_len);
2205 }
2206
2207 let utf16: Vec<u16> = all_data
2209 .chunks_exact(2)
2210 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2211 .collect();
2212 let s = String::from_utf16(&utf16)
2213 .map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
2214 Ok(SqlValue::String(s))
2215 }
2216
2217 fn parse_plp_varchar(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2219 use bytes::Buf;
2220 use mssql_types::SqlValue;
2221
2222 if buf.remaining() < 8 {
2223 return Err(Error::Protocol(
2224 "unexpected EOF reading PLP total length".into(),
2225 ));
2226 }
2227
2228 let total_len = buf.get_u64_le();
2229 if total_len == 0xFFFFFFFFFFFFFFFF {
2230 return Ok(SqlValue::Null);
2231 }
2232
2233 let mut all_data = Vec::new();
2235 loop {
2236 if buf.remaining() < 4 {
2237 return Err(Error::Protocol(
2238 "unexpected EOF reading PLP chunk length".into(),
2239 ));
2240 }
2241 let chunk_len = buf.get_u32_le() as usize;
2242 if chunk_len == 0 {
2243 break; }
2245 if buf.remaining() < chunk_len {
2246 return Err(Error::Protocol(
2247 "unexpected EOF reading PLP chunk data".into(),
2248 ));
2249 }
2250 all_data.extend_from_slice(&buf[..chunk_len]);
2251 buf.advance(chunk_len);
2252 }
2253
2254 let s = String::from_utf8_lossy(&all_data).into_owned();
2256 Ok(SqlValue::String(s))
2257 }
2258
2259 fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2261 use bytes::Buf;
2262 use mssql_types::SqlValue;
2263
2264 if buf.remaining() < 8 {
2265 return Err(Error::Protocol(
2266 "unexpected EOF reading PLP total length".into(),
2267 ));
2268 }
2269
2270 let total_len = buf.get_u64_le();
2271 if total_len == 0xFFFFFFFFFFFFFFFF {
2272 return Ok(SqlValue::Null);
2273 }
2274
2275 let mut all_data = Vec::new();
2277 loop {
2278 if buf.remaining() < 4 {
2279 return Err(Error::Protocol(
2280 "unexpected EOF reading PLP chunk length".into(),
2281 ));
2282 }
2283 let chunk_len = buf.get_u32_le() as usize;
2284 if chunk_len == 0 {
2285 break; }
2287 if buf.remaining() < chunk_len {
2288 return Err(Error::Protocol(
2289 "unexpected EOF reading PLP chunk data".into(),
2290 ));
2291 }
2292 all_data.extend_from_slice(&buf[..chunk_len]);
2293 buf.advance(chunk_len);
2294 }
2295
2296 Ok(SqlValue::Binary(bytes::Bytes::from(all_data)))
2297 }
2298
2299 fn parse_sql_variant(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2308 use bytes::Buf;
2309 use mssql_types::SqlValue;
2310
2311 if buf.remaining() < 4 {
2313 return Err(Error::Protocol(
2314 "unexpected EOF reading SQL_VARIANT length".into(),
2315 ));
2316 }
2317 let total_len = buf.get_u32_le() as usize;
2318
2319 if total_len == 0 {
2320 return Ok(SqlValue::Null);
2321 }
2322
2323 if buf.remaining() < total_len {
2324 return Err(Error::Protocol(
2325 "unexpected EOF reading SQL_VARIANT data".into(),
2326 ));
2327 }
2328
2329 if total_len < 2 {
2331 return Err(Error::Protocol(
2332 "SQL_VARIANT too short for type info".into(),
2333 ));
2334 }
2335
2336 let base_type = buf.get_u8();
2337 let prop_count = buf.get_u8() as usize;
2338
2339 if buf.remaining() < prop_count {
2340 return Err(Error::Protocol(
2341 "unexpected EOF reading SQL_VARIANT properties".into(),
2342 ));
2343 }
2344
2345 let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
2347
2348 match base_type {
2351 0x30 => {
2353 buf.advance(prop_count);
2355 if data_len < 1 {
2356 return Ok(SqlValue::Null);
2357 }
2358 let v = buf.get_u8();
2359 Ok(SqlValue::TinyInt(v))
2360 }
2361 0x32 => {
2362 buf.advance(prop_count);
2364 if data_len < 1 {
2365 return Ok(SqlValue::Null);
2366 }
2367 let v = buf.get_u8();
2368 Ok(SqlValue::Bool(v != 0))
2369 }
2370 0x34 => {
2371 buf.advance(prop_count);
2373 if data_len < 2 {
2374 return Ok(SqlValue::Null);
2375 }
2376 let v = buf.get_i16_le();
2377 Ok(SqlValue::SmallInt(v))
2378 }
2379 0x38 => {
2380 buf.advance(prop_count);
2382 if data_len < 4 {
2383 return Ok(SqlValue::Null);
2384 }
2385 let v = buf.get_i32_le();
2386 Ok(SqlValue::Int(v))
2387 }
2388 0x7F => {
2389 buf.advance(prop_count);
2391 if data_len < 8 {
2392 return Ok(SqlValue::Null);
2393 }
2394 let v = buf.get_i64_le();
2395 Ok(SqlValue::BigInt(v))
2396 }
2397 0x6D => {
2398 let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2400 buf.advance(prop_count.saturating_sub(1));
2401
2402 if float_len == 4 && data_len >= 4 {
2403 let v = buf.get_f32_le();
2404 Ok(SqlValue::Float(v))
2405 } else if data_len >= 8 {
2406 let v = buf.get_f64_le();
2407 Ok(SqlValue::Double(v))
2408 } else {
2409 Ok(SqlValue::Null)
2410 }
2411 }
2412 0x6E => {
2413 let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2415 buf.advance(prop_count.saturating_sub(1));
2416
2417 if money_len == 4 && data_len >= 4 {
2418 let raw = buf.get_i32_le();
2419 let value = raw as f64 / 10000.0;
2420 Ok(SqlValue::Double(value))
2421 } else if data_len >= 8 {
2422 let high = buf.get_i32_le() as i64;
2423 let low = buf.get_u32_le() as i64;
2424 let raw = (high << 32) | low;
2425 let value = raw as f64 / 10000.0;
2426 Ok(SqlValue::Double(value))
2427 } else {
2428 Ok(SqlValue::Null)
2429 }
2430 }
2431 0x6F => {
2432 let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2434 buf.advance(prop_count.saturating_sub(1));
2435
2436 #[cfg(feature = "chrono")]
2437 {
2438 use chrono::NaiveDate;
2439 if dt_len == 4 && data_len >= 4 {
2440 let days = buf.get_u16_le() as i64;
2442 let mins = buf.get_u16_le() as u32;
2443 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2444 .unwrap()
2445 .and_hms_opt(0, 0, 0)
2446 .unwrap();
2447 let dt = base
2448 + chrono::Duration::days(days)
2449 + chrono::Duration::minutes(mins as i64);
2450 Ok(SqlValue::DateTime(dt))
2451 } else if data_len >= 8 {
2452 let days = buf.get_i32_le() as i64;
2454 let ticks = buf.get_u32_le() as i64;
2455 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2456 .unwrap()
2457 .and_hms_opt(0, 0, 0)
2458 .unwrap();
2459 let millis = (ticks * 10) / 3;
2460 let dt = base
2461 + chrono::Duration::days(days)
2462 + chrono::Duration::milliseconds(millis);
2463 Ok(SqlValue::DateTime(dt))
2464 } else {
2465 Ok(SqlValue::Null)
2466 }
2467 }
2468 #[cfg(not(feature = "chrono"))]
2469 {
2470 buf.advance(data_len);
2471 Ok(SqlValue::Null)
2472 }
2473 }
2474 0x6A | 0x6C => {
2475 let _precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
2477 let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
2478 buf.advance(prop_count.saturating_sub(2));
2479
2480 if data_len < 1 {
2481 return Ok(SqlValue::Null);
2482 }
2483
2484 let sign = buf.get_u8();
2485 let mantissa_len = data_len - 1;
2486
2487 if mantissa_len > 16 {
2488 buf.advance(mantissa_len);
2490 return Ok(SqlValue::Null);
2491 }
2492
2493 let mut mantissa_bytes = [0u8; 16];
2494 for i in 0..mantissa_len.min(16) {
2495 mantissa_bytes[i] = buf.get_u8();
2496 }
2497 let mantissa = u128::from_le_bytes(mantissa_bytes);
2498
2499 #[cfg(feature = "decimal")]
2500 {
2501 use rust_decimal::Decimal;
2502 if scale > 28 {
2503 let divisor = 10f64.powi(scale as i32);
2505 let value = (mantissa as f64) / divisor;
2506 let value = if sign == 0 { -value } else { value };
2507 Ok(SqlValue::Double(value))
2508 } else {
2509 let mut decimal =
2510 Decimal::from_i128_with_scale(mantissa as i128, scale as u32);
2511 if sign == 0 {
2512 decimal.set_sign_negative(true);
2513 }
2514 Ok(SqlValue::Decimal(decimal))
2515 }
2516 }
2517 #[cfg(not(feature = "decimal"))]
2518 {
2519 let divisor = 10f64.powi(scale as i32);
2520 let value = (mantissa as f64) / divisor;
2521 let value = if sign == 0 { -value } else { value };
2522 Ok(SqlValue::Double(value))
2523 }
2524 }
2525 0x24 => {
2526 buf.advance(prop_count);
2528 if data_len < 16 {
2529 return Ok(SqlValue::Null);
2530 }
2531 let mut guid_bytes = [0u8; 16];
2532 for byte in &mut guid_bytes {
2533 *byte = buf.get_u8();
2534 }
2535 Ok(SqlValue::Binary(bytes::Bytes::copy_from_slice(&guid_bytes)))
2536 }
2537 0x28 => {
2538 buf.advance(prop_count);
2540 #[cfg(feature = "chrono")]
2541 {
2542 if data_len < 3 {
2543 return Ok(SqlValue::Null);
2544 }
2545 let mut date_bytes = [0u8; 4];
2546 date_bytes[0] = buf.get_u8();
2547 date_bytes[1] = buf.get_u8();
2548 date_bytes[2] = buf.get_u8();
2549 let days = u32::from_le_bytes(date_bytes);
2550 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2551 let date = base + chrono::Duration::days(days as i64);
2552 Ok(SqlValue::Date(date))
2553 }
2554 #[cfg(not(feature = "chrono"))]
2555 {
2556 buf.advance(data_len);
2557 Ok(SqlValue::Null)
2558 }
2559 }
2560 0xA7 | 0x2F | 0x27 => {
2561 buf.advance(prop_count);
2563 if data_len == 0 {
2564 return Ok(SqlValue::String(String::new()));
2565 }
2566 let data = &buf[..data_len];
2567 let s = String::from_utf8_lossy(data).into_owned();
2568 buf.advance(data_len);
2569 Ok(SqlValue::String(s))
2570 }
2571 0xE7 | 0xEF => {
2572 buf.advance(prop_count);
2574 if data_len == 0 {
2575 return Ok(SqlValue::String(String::new()));
2576 }
2577 let utf16: Vec<u16> = buf[..data_len]
2579 .chunks_exact(2)
2580 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2581 .collect();
2582 buf.advance(data_len);
2583 let s = String::from_utf16(&utf16).map_err(|_| {
2584 Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into())
2585 })?;
2586 Ok(SqlValue::String(s))
2587 }
2588 0xA5 | 0x2D | 0x25 => {
2589 buf.advance(prop_count);
2591 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2592 buf.advance(data_len);
2593 Ok(SqlValue::Binary(data))
2594 }
2595 _ => {
2596 buf.advance(prop_count);
2598 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2599 buf.advance(data_len);
2600 Ok(SqlValue::Binary(data))
2601 }
2602 }
2603 }
2604
2605 fn time_bytes_for_scale(scale: u8) -> usize {
2607 match scale {
2608 0..=2 => 3,
2609 3..=4 => 4,
2610 5..=7 => 5,
2611 _ => 5, }
2613 }
2614
2615 #[cfg(feature = "chrono")]
2617 fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
2618 let nanos = match scale {
2628 0 => intervals * 1_000_000_000,
2629 1 => intervals * 100_000_000,
2630 2 => intervals * 10_000_000,
2631 3 => intervals * 1_000_000,
2632 4 => intervals * 100_000,
2633 5 => intervals * 10_000,
2634 6 => intervals * 1_000,
2635 7 => intervals * 100,
2636 _ => intervals * 100,
2637 };
2638
2639 let secs = (nanos / 1_000_000_000) as u32;
2640 let nano_part = (nanos % 1_000_000_000) as u32;
2641
2642 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
2643 .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
2644 }
2645
2646 async fn read_execute_result(&mut self) -> Result<u64> {
2648 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2649
2650 let message = match connection {
2651 ConnectionHandle::Tls(conn) => conn
2652 .read_message()
2653 .await
2654 .map_err(|e| Error::Protocol(e.to_string()))?,
2655 ConnectionHandle::TlsPrelogin(conn) => conn
2656 .read_message()
2657 .await
2658 .map_err(|e| Error::Protocol(e.to_string()))?,
2659 ConnectionHandle::Plain(conn) => conn
2660 .read_message()
2661 .await
2662 .map_err(|e| Error::Protocol(e.to_string()))?,
2663 }
2664 .ok_or(Error::ConnectionClosed)?;
2665
2666 let mut parser = TokenParser::new(message.payload);
2667 let mut rows_affected = 0u64;
2668 let mut current_metadata: Option<ColMetaData> = None;
2669
2670 loop {
2671 let token = parser
2673 .next_token_with_metadata(current_metadata.as_ref())
2674 .map_err(|e| Error::Protocol(e.to_string()))?;
2675
2676 let Some(token) = token else {
2677 break;
2678 };
2679
2680 match token {
2681 Token::ColMetaData(meta) => {
2682 current_metadata = Some(meta);
2684 }
2685 Token::Row(_) | Token::NbcRow(_) => {
2686 }
2689 Token::Done(done) => {
2690 if done.status.error {
2691 return Err(Error::Query("execution failed".to_string()));
2692 }
2693 if done.status.count {
2694 rows_affected += done.row_count;
2696 }
2697 if !done.status.more {
2700 break;
2701 }
2702 }
2703 Token::DoneProc(done) => {
2704 if done.status.count {
2705 rows_affected += done.row_count;
2706 }
2707 }
2708 Token::DoneInProc(done) => {
2709 if done.status.count {
2710 rows_affected += done.row_count;
2711 }
2712 }
2713 Token::Error(err) => {
2714 return Err(Error::Server {
2715 number: err.number,
2716 state: err.state,
2717 class: err.class,
2718 message: err.message.clone(),
2719 server: if err.server.is_empty() {
2720 None
2721 } else {
2722 Some(err.server.clone())
2723 },
2724 procedure: if err.procedure.is_empty() {
2725 None
2726 } else {
2727 Some(err.procedure.clone())
2728 },
2729 line: err.line as u32,
2730 });
2731 }
2732 Token::Info(info) => {
2733 tracing::info!(
2734 number = info.number,
2735 message = %info.message,
2736 "server info message"
2737 );
2738 }
2739 Token::EnvChange(env) => {
2740 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
2744 }
2745 _ => {}
2746 }
2747 }
2748
2749 Ok(rows_affected)
2750 }
2751
2752 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
2758 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2759
2760 let message = match connection {
2761 ConnectionHandle::Tls(conn) => conn
2762 .read_message()
2763 .await
2764 .map_err(|e| Error::Protocol(e.to_string()))?,
2765 ConnectionHandle::TlsPrelogin(conn) => conn
2766 .read_message()
2767 .await
2768 .map_err(|e| Error::Protocol(e.to_string()))?,
2769 ConnectionHandle::Plain(conn) => conn
2770 .read_message()
2771 .await
2772 .map_err(|e| Error::Protocol(e.to_string()))?,
2773 }
2774 .ok_or(Error::ConnectionClosed)?;
2775
2776 let mut parser = TokenParser::new(message.payload);
2777 let mut transaction_descriptor: u64 = 0;
2778
2779 loop {
2780 let token = parser
2781 .next_token()
2782 .map_err(|e| Error::Protocol(e.to_string()))?;
2783
2784 let Some(token) = token else {
2785 break;
2786 };
2787
2788 match token {
2789 Token::EnvChange(env) => {
2790 if env.env_type == EnvChangeType::BeginTransaction {
2791 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
2794 {
2795 if data.len() >= 8 {
2796 transaction_descriptor = u64::from_le_bytes([
2797 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
2798 data[7],
2799 ]);
2800 tracing::debug!(
2801 transaction_descriptor =
2802 format!("0x{:016X}", transaction_descriptor),
2803 "transaction begun"
2804 );
2805 }
2806 }
2807 }
2808 }
2809 Token::Done(done) => {
2810 if done.status.error {
2811 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
2812 }
2813 break;
2814 }
2815 Token::Error(err) => {
2816 return Err(Error::Server {
2817 number: err.number,
2818 state: err.state,
2819 class: err.class,
2820 message: err.message.clone(),
2821 server: if err.server.is_empty() {
2822 None
2823 } else {
2824 Some(err.server.clone())
2825 },
2826 procedure: if err.procedure.is_empty() {
2827 None
2828 } else {
2829 Some(err.procedure.clone())
2830 },
2831 line: err.line as u32,
2832 });
2833 }
2834 Token::Info(info) => {
2835 tracing::info!(
2836 number = info.number,
2837 message = %info.message,
2838 "server info message"
2839 );
2840 }
2841 _ => {}
2842 }
2843 }
2844
2845 Ok(transaction_descriptor)
2846 }
2847}
2848
2849impl Client<Ready> {
2850 pub async fn query<'a>(
2875 &'a mut self,
2876 sql: &str,
2877 params: &[&(dyn crate::ToSql + Sync)],
2878 ) -> Result<QueryStream<'a>> {
2879 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
2880
2881 #[cfg(feature = "otel")]
2882 let instrumentation = self.instrumentation.clone();
2883 #[cfg(feature = "otel")]
2884 let mut span = instrumentation.query_span(sql);
2885
2886 let result = async {
2887 if params.is_empty() {
2888 self.send_sql_batch(sql).await?;
2890 } else {
2891 let rpc_params = Self::convert_params(params)?;
2893 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2894 self.send_rpc(&rpc).await?;
2895 }
2896
2897 self.read_query_response().await
2899 }
2900 .await;
2901
2902 #[cfg(feature = "otel")]
2903 match &result {
2904 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2905 Err(e) => InstrumentationContext::record_error(&mut span, e),
2906 }
2907
2908 #[cfg(feature = "otel")]
2910 drop(span);
2911
2912 let (columns, rows) = result?;
2913 Ok(QueryStream::new(columns, rows))
2914 }
2915
2916 pub async fn query_with_timeout<'a>(
2943 &'a mut self,
2944 sql: &str,
2945 params: &[&(dyn crate::ToSql + Sync)],
2946 timeout_duration: std::time::Duration,
2947 ) -> Result<QueryStream<'a>> {
2948 timeout(timeout_duration, self.query(sql, params))
2949 .await
2950 .map_err(|_| Error::CommandTimeout)?
2951 }
2952
2953 pub async fn query_multiple<'a>(
2980 &'a mut self,
2981 sql: &str,
2982 params: &[&(dyn crate::ToSql + Sync)],
2983 ) -> Result<MultiResultStream<'a>> {
2984 tracing::debug!(
2985 sql = sql,
2986 params_count = params.len(),
2987 "executing multi-result query"
2988 );
2989
2990 if params.is_empty() {
2991 self.send_sql_batch(sql).await?;
2993 } else {
2994 let rpc_params = Self::convert_params(params)?;
2996 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2997 self.send_rpc(&rpc).await?;
2998 }
2999
3000 let result_sets = self.read_multi_result_response().await?;
3002 Ok(MultiResultStream::new(result_sets))
3003 }
3004
3005 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
3007 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
3008
3009 let message = match connection {
3010 ConnectionHandle::Tls(conn) => conn
3011 .read_message()
3012 .await
3013 .map_err(|e| Error::Protocol(e.to_string()))?,
3014 ConnectionHandle::TlsPrelogin(conn) => conn
3015 .read_message()
3016 .await
3017 .map_err(|e| Error::Protocol(e.to_string()))?,
3018 ConnectionHandle::Plain(conn) => conn
3019 .read_message()
3020 .await
3021 .map_err(|e| Error::Protocol(e.to_string()))?,
3022 }
3023 .ok_or(Error::ConnectionClosed)?;
3024
3025 let mut parser = TokenParser::new(message.payload);
3026 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
3027 let mut current_columns: Vec<crate::row::Column> = Vec::new();
3028 let mut current_rows: Vec<crate::row::Row> = Vec::new();
3029 let mut protocol_metadata: Option<ColMetaData> = None;
3030
3031 loop {
3032 let token = parser
3033 .next_token_with_metadata(protocol_metadata.as_ref())
3034 .map_err(|e| Error::Protocol(e.to_string()))?;
3035
3036 let Some(token) = token else {
3037 break;
3038 };
3039
3040 match token {
3041 Token::ColMetaData(meta) => {
3042 if !current_columns.is_empty() {
3044 result_sets.push(crate::stream::ResultSet::new(
3045 std::mem::take(&mut current_columns),
3046 std::mem::take(&mut current_rows),
3047 ));
3048 }
3049
3050 current_columns = meta
3052 .columns
3053 .iter()
3054 .enumerate()
3055 .map(|(i, col)| {
3056 let type_name = format!("{:?}", col.type_id);
3057 let mut column = crate::row::Column::new(&col.name, i, type_name)
3058 .with_nullable(col.flags & 0x01 != 0);
3059
3060 if let Some(max_len) = col.type_info.max_length {
3061 column = column.with_max_length(max_len);
3062 }
3063 if let (Some(prec), Some(scale)) =
3064 (col.type_info.precision, col.type_info.scale)
3065 {
3066 column = column.with_precision_scale(prec, scale);
3067 }
3068 column
3069 })
3070 .collect();
3071
3072 tracing::debug!(
3073 columns = current_columns.len(),
3074 result_set = result_sets.len(),
3075 "received column metadata for result set"
3076 );
3077 protocol_metadata = Some(meta);
3078 }
3079 Token::Row(raw_row) => {
3080 if let Some(ref meta) = protocol_metadata {
3081 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
3082 current_rows.push(row);
3083 }
3084 }
3085 Token::NbcRow(nbc_row) => {
3086 if let Some(ref meta) = protocol_metadata {
3087 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
3088 current_rows.push(row);
3089 }
3090 }
3091 Token::Error(err) => {
3092 return Err(Error::Server {
3093 number: err.number,
3094 state: err.state,
3095 class: err.class,
3096 message: err.message.clone(),
3097 server: if err.server.is_empty() {
3098 None
3099 } else {
3100 Some(err.server.clone())
3101 },
3102 procedure: if err.procedure.is_empty() {
3103 None
3104 } else {
3105 Some(err.procedure.clone())
3106 },
3107 line: err.line as u32,
3108 });
3109 }
3110 Token::Done(done) => {
3111 if done.status.error {
3112 return Err(Error::Query("query failed".to_string()));
3113 }
3114
3115 if !current_columns.is_empty() {
3117 result_sets.push(crate::stream::ResultSet::new(
3118 std::mem::take(&mut current_columns),
3119 std::mem::take(&mut current_rows),
3120 ));
3121 protocol_metadata = None;
3122 }
3123
3124 if !done.status.more {
3126 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
3127 break;
3128 }
3129 }
3130 Token::DoneInProc(done) => {
3131 if done.status.error {
3132 return Err(Error::Query("query failed".to_string()));
3133 }
3134
3135 if !current_columns.is_empty() {
3137 result_sets.push(crate::stream::ResultSet::new(
3138 std::mem::take(&mut current_columns),
3139 std::mem::take(&mut current_rows),
3140 ));
3141 protocol_metadata = None;
3142 }
3143
3144 if !done.status.more {
3146 }
3148 }
3149 Token::DoneProc(done) => {
3150 if done.status.error {
3151 return Err(Error::Query("query failed".to_string()));
3152 }
3153 }
3155 Token::Info(info) => {
3156 tracing::debug!(
3157 number = info.number,
3158 message = %info.message,
3159 "server info message"
3160 );
3161 }
3162 _ => {}
3163 }
3164 }
3165
3166 if !current_columns.is_empty() {
3168 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
3169 }
3170
3171 Ok(result_sets)
3172 }
3173
3174 pub async fn execute(
3178 &mut self,
3179 sql: &str,
3180 params: &[&(dyn crate::ToSql + Sync)],
3181 ) -> Result<u64> {
3182 tracing::debug!(
3183 sql = sql,
3184 params_count = params.len(),
3185 "executing statement"
3186 );
3187
3188 #[cfg(feature = "otel")]
3189 let instrumentation = self.instrumentation.clone();
3190 #[cfg(feature = "otel")]
3191 let mut span = instrumentation.query_span(sql);
3192
3193 let result = async {
3194 if params.is_empty() {
3195 self.send_sql_batch(sql).await?;
3197 } else {
3198 let rpc_params = Self::convert_params(params)?;
3200 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3201 self.send_rpc(&rpc).await?;
3202 }
3203
3204 self.read_execute_result().await
3206 }
3207 .await;
3208
3209 #[cfg(feature = "otel")]
3210 match &result {
3211 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3212 Err(e) => InstrumentationContext::record_error(&mut span, e),
3213 }
3214
3215 #[cfg(feature = "otel")]
3217 drop(span);
3218
3219 result
3220 }
3221
3222 pub async fn execute_with_timeout(
3249 &mut self,
3250 sql: &str,
3251 params: &[&(dyn crate::ToSql + Sync)],
3252 timeout_duration: std::time::Duration,
3253 ) -> Result<u64> {
3254 timeout(timeout_duration, self.execute(sql, params))
3255 .await
3256 .map_err(|_| Error::CommandTimeout)?
3257 }
3258
3259 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
3266 tracing::debug!("beginning transaction");
3267
3268 #[cfg(feature = "otel")]
3269 let instrumentation = self.instrumentation.clone();
3270 #[cfg(feature = "otel")]
3271 let mut span = instrumentation.transaction_span("BEGIN");
3272
3273 let result = async {
3275 self.send_sql_batch("BEGIN TRANSACTION").await?;
3276 self.read_transaction_begin_result().await
3277 }
3278 .await;
3279
3280 #[cfg(feature = "otel")]
3281 match &result {
3282 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3283 Err(e) => InstrumentationContext::record_error(&mut span, e),
3284 }
3285
3286 #[cfg(feature = "otel")]
3288 drop(span);
3289
3290 let transaction_descriptor = result?;
3291
3292 Ok(Client {
3293 config: self.config,
3294 _state: PhantomData,
3295 connection: self.connection,
3296 server_version: self.server_version,
3297 current_database: self.current_database,
3298 statement_cache: self.statement_cache,
3299 transaction_descriptor, #[cfg(feature = "otel")]
3301 instrumentation: self.instrumentation,
3302 })
3303 }
3304
3305 pub async fn begin_transaction_with_isolation(
3320 mut self,
3321 isolation_level: crate::transaction::IsolationLevel,
3322 ) -> Result<Client<InTransaction>> {
3323 tracing::debug!(
3324 isolation_level = %isolation_level.name(),
3325 "beginning transaction with isolation level"
3326 );
3327
3328 #[cfg(feature = "otel")]
3329 let instrumentation = self.instrumentation.clone();
3330 #[cfg(feature = "otel")]
3331 let mut span = instrumentation.transaction_span("BEGIN");
3332
3333 let result = async {
3335 self.send_sql_batch(isolation_level.as_sql()).await?;
3336 self.read_execute_result().await?;
3337
3338 self.send_sql_batch("BEGIN TRANSACTION").await?;
3340 self.read_transaction_begin_result().await
3341 }
3342 .await;
3343
3344 #[cfg(feature = "otel")]
3345 match &result {
3346 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3347 Err(e) => InstrumentationContext::record_error(&mut span, e),
3348 }
3349
3350 #[cfg(feature = "otel")]
3351 drop(span);
3352
3353 let transaction_descriptor = result?;
3354
3355 Ok(Client {
3356 config: self.config,
3357 _state: PhantomData,
3358 connection: self.connection,
3359 server_version: self.server_version,
3360 current_database: self.current_database,
3361 statement_cache: self.statement_cache,
3362 transaction_descriptor,
3363 #[cfg(feature = "otel")]
3364 instrumentation: self.instrumentation,
3365 })
3366 }
3367
3368 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
3373 tracing::debug!(sql = sql, "executing simple query");
3374
3375 self.send_sql_batch(sql).await?;
3377
3378 let _ = self.read_execute_result().await?;
3380
3381 Ok(())
3382 }
3383
3384 pub async fn close(self) -> Result<()> {
3386 tracing::debug!("closing connection");
3387 Ok(())
3388 }
3389
3390 #[must_use]
3392 pub fn database(&self) -> Option<&str> {
3393 self.config.database.as_deref()
3394 }
3395
3396 #[must_use]
3398 pub fn host(&self) -> &str {
3399 &self.config.host
3400 }
3401
3402 #[must_use]
3404 pub fn port(&self) -> u16 {
3405 self.config.port
3406 }
3407
3408 #[must_use]
3427 pub fn is_in_transaction(&self) -> bool {
3428 self.transaction_descriptor != 0
3429 }
3430
3431 #[must_use]
3453 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3454 let connection = self
3455 .connection
3456 .as_ref()
3457 .expect("connection should be present");
3458 match connection {
3459 ConnectionHandle::Tls(conn) => {
3460 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3461 }
3462 ConnectionHandle::TlsPrelogin(conn) => {
3463 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3464 }
3465 ConnectionHandle::Plain(conn) => {
3466 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3467 }
3468 }
3469 }
3470}
3471
3472impl Client<InTransaction> {
3473 pub async fn query<'a>(
3477 &'a mut self,
3478 sql: &str,
3479 params: &[&(dyn crate::ToSql + Sync)],
3480 ) -> Result<QueryStream<'a>> {
3481 tracing::debug!(
3482 sql = sql,
3483 params_count = params.len(),
3484 "executing query in transaction"
3485 );
3486
3487 #[cfg(feature = "otel")]
3488 let instrumentation = self.instrumentation.clone();
3489 #[cfg(feature = "otel")]
3490 let mut span = instrumentation.query_span(sql);
3491
3492 let result = async {
3493 if params.is_empty() {
3494 self.send_sql_batch(sql).await?;
3496 } else {
3497 let rpc_params = Self::convert_params(params)?;
3499 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3500 self.send_rpc(&rpc).await?;
3501 }
3502
3503 self.read_query_response().await
3505 }
3506 .await;
3507
3508 #[cfg(feature = "otel")]
3509 match &result {
3510 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3511 Err(e) => InstrumentationContext::record_error(&mut span, e),
3512 }
3513
3514 #[cfg(feature = "otel")]
3516 drop(span);
3517
3518 let (columns, rows) = result?;
3519 Ok(QueryStream::new(columns, rows))
3520 }
3521
3522 pub async fn execute(
3526 &mut self,
3527 sql: &str,
3528 params: &[&(dyn crate::ToSql + Sync)],
3529 ) -> Result<u64> {
3530 tracing::debug!(
3531 sql = sql,
3532 params_count = params.len(),
3533 "executing statement in transaction"
3534 );
3535
3536 #[cfg(feature = "otel")]
3537 let instrumentation = self.instrumentation.clone();
3538 #[cfg(feature = "otel")]
3539 let mut span = instrumentation.query_span(sql);
3540
3541 let result = async {
3542 if params.is_empty() {
3543 self.send_sql_batch(sql).await?;
3545 } else {
3546 let rpc_params = Self::convert_params(params)?;
3548 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3549 self.send_rpc(&rpc).await?;
3550 }
3551
3552 self.read_execute_result().await
3554 }
3555 .await;
3556
3557 #[cfg(feature = "otel")]
3558 match &result {
3559 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3560 Err(e) => InstrumentationContext::record_error(&mut span, e),
3561 }
3562
3563 #[cfg(feature = "otel")]
3565 drop(span);
3566
3567 result
3568 }
3569
3570 pub async fn query_with_timeout<'a>(
3574 &'a mut self,
3575 sql: &str,
3576 params: &[&(dyn crate::ToSql + Sync)],
3577 timeout_duration: std::time::Duration,
3578 ) -> Result<QueryStream<'a>> {
3579 timeout(timeout_duration, self.query(sql, params))
3580 .await
3581 .map_err(|_| Error::CommandTimeout)?
3582 }
3583
3584 pub async fn execute_with_timeout(
3588 &mut self,
3589 sql: &str,
3590 params: &[&(dyn crate::ToSql + Sync)],
3591 timeout_duration: std::time::Duration,
3592 ) -> Result<u64> {
3593 timeout(timeout_duration, self.execute(sql, params))
3594 .await
3595 .map_err(|_| Error::CommandTimeout)?
3596 }
3597
3598 pub async fn commit(mut self) -> Result<Client<Ready>> {
3602 tracing::debug!("committing transaction");
3603
3604 #[cfg(feature = "otel")]
3605 let instrumentation = self.instrumentation.clone();
3606 #[cfg(feature = "otel")]
3607 let mut span = instrumentation.transaction_span("COMMIT");
3608
3609 let result = async {
3611 self.send_sql_batch("COMMIT TRANSACTION").await?;
3612 self.read_execute_result().await
3613 }
3614 .await;
3615
3616 #[cfg(feature = "otel")]
3617 match &result {
3618 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3619 Err(e) => InstrumentationContext::record_error(&mut span, e),
3620 }
3621
3622 #[cfg(feature = "otel")]
3624 drop(span);
3625
3626 result?;
3627
3628 Ok(Client {
3629 config: self.config,
3630 _state: PhantomData,
3631 connection: self.connection,
3632 server_version: self.server_version,
3633 current_database: self.current_database,
3634 statement_cache: self.statement_cache,
3635 transaction_descriptor: 0, #[cfg(feature = "otel")]
3637 instrumentation: self.instrumentation,
3638 })
3639 }
3640
3641 pub async fn rollback(mut self) -> Result<Client<Ready>> {
3645 tracing::debug!("rolling back transaction");
3646
3647 #[cfg(feature = "otel")]
3648 let instrumentation = self.instrumentation.clone();
3649 #[cfg(feature = "otel")]
3650 let mut span = instrumentation.transaction_span("ROLLBACK");
3651
3652 let result = async {
3654 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
3655 self.read_execute_result().await
3656 }
3657 .await;
3658
3659 #[cfg(feature = "otel")]
3660 match &result {
3661 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3662 Err(e) => InstrumentationContext::record_error(&mut span, e),
3663 }
3664
3665 #[cfg(feature = "otel")]
3667 drop(span);
3668
3669 result?;
3670
3671 Ok(Client {
3672 config: self.config,
3673 _state: PhantomData,
3674 connection: self.connection,
3675 server_version: self.server_version,
3676 current_database: self.current_database,
3677 statement_cache: self.statement_cache,
3678 transaction_descriptor: 0, #[cfg(feature = "otel")]
3680 instrumentation: self.instrumentation,
3681 })
3682 }
3683
3684 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
3701 validate_identifier(name)?;
3702 tracing::debug!(name = name, "creating savepoint");
3703
3704 let sql = format!("SAVE TRANSACTION {}", name);
3707 self.send_sql_batch(&sql).await?;
3708 self.read_execute_result().await?;
3709
3710 Ok(SavePoint::new(name.to_string()))
3711 }
3712
3713 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
3728 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
3729
3730 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
3733 self.send_sql_batch(&sql).await?;
3734 self.read_execute_result().await?;
3735
3736 Ok(())
3737 }
3738
3739 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
3745 tracing::debug!(name = savepoint.name(), "releasing savepoint");
3746
3747 drop(savepoint);
3751 Ok(())
3752 }
3753
3754 #[must_use]
3758 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3759 let connection = self
3760 .connection
3761 .as_ref()
3762 .expect("connection should be present");
3763 match connection {
3764 ConnectionHandle::Tls(conn) => {
3765 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3766 }
3767 ConnectionHandle::TlsPrelogin(conn) => {
3768 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3769 }
3770 ConnectionHandle::Plain(conn) => {
3771 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3772 }
3773 }
3774 }
3775}
3776
3777fn validate_identifier(name: &str) -> Result<()> {
3779 use once_cell::sync::Lazy;
3780 use regex::Regex;
3781
3782 static IDENTIFIER_RE: Lazy<Regex> =
3783 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
3784
3785 if name.is_empty() {
3786 return Err(Error::InvalidIdentifier(
3787 "identifier cannot be empty".into(),
3788 ));
3789 }
3790
3791 if !IDENTIFIER_RE.is_match(name) {
3792 return Err(Error::InvalidIdentifier(format!(
3793 "invalid identifier '{}': must start with letter/underscore, \
3794 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
3795 name
3796 )));
3797 }
3798
3799 Ok(())
3800}
3801
3802impl<S: ConnectionState> std::fmt::Debug for Client<S> {
3803 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3804 f.debug_struct("Client")
3805 .field("host", &self.config.host)
3806 .field("port", &self.config.port)
3807 .field("database", &self.config.database)
3808 .finish()
3809 }
3810}
3811
3812#[cfg(test)]
3813#[allow(clippy::unwrap_used, clippy::panic)]
3814mod tests {
3815 use super::*;
3816
3817 #[test]
3818 fn test_validate_identifier_valid() {
3819 assert!(validate_identifier("my_table").is_ok());
3820 assert!(validate_identifier("Table123").is_ok());
3821 assert!(validate_identifier("_private").is_ok());
3822 assert!(validate_identifier("sp_test").is_ok());
3823 }
3824
3825 #[test]
3826 fn test_validate_identifier_invalid() {
3827 assert!(validate_identifier("").is_err());
3828 assert!(validate_identifier("123abc").is_err());
3829 assert!(validate_identifier("table-name").is_err());
3830 assert!(validate_identifier("table name").is_err());
3831 assert!(validate_identifier("table;DROP TABLE users").is_err());
3832 }
3833
3834 fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
3843 let mut data = Vec::new();
3844 data.extend_from_slice(&total_len.to_le_bytes());
3846 for chunk in chunks {
3848 let len = chunk.len() as u32;
3849 data.extend_from_slice(&len.to_le_bytes());
3850 data.extend_from_slice(chunk);
3851 }
3852 data.extend_from_slice(&0u32.to_le_bytes());
3854 data
3855 }
3856
3857 #[test]
3858 fn test_parse_plp_nvarchar_simple() {
3859 let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
3861 let plp = make_plp_data(10, &[&utf16_data]);
3862 let mut buf: &[u8] = &plp;
3863
3864 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3865 match result {
3866 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
3867 _ => panic!("expected String, got {:?}", result),
3868 }
3869 }
3870
3871 #[test]
3872 fn test_parse_plp_nvarchar_null() {
3873 let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
3875 let mut buf: &[u8] = &plp;
3876
3877 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3878 assert!(matches!(result, mssql_types::SqlValue::Null));
3879 }
3880
3881 #[test]
3882 fn test_parse_plp_nvarchar_empty() {
3883 let plp = make_plp_data(0, &[]);
3885 let mut buf: &[u8] = &plp;
3886
3887 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3888 match result {
3889 mssql_types::SqlValue::String(s) => assert_eq!(s, ""),
3890 _ => panic!("expected empty String"),
3891 }
3892 }
3893
3894 #[test]
3895 fn test_parse_plp_nvarchar_multi_chunk() {
3896 let chunk1 = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00]; let chunk2 = [0x6C, 0x00, 0x6F, 0x00]; let plp = make_plp_data(10, &[&chunk1, &chunk2]);
3900 let mut buf: &[u8] = &plp;
3901
3902 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3903 match result {
3904 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
3905 _ => panic!("expected String"),
3906 }
3907 }
3908
3909 #[test]
3910 fn test_parse_plp_varchar_simple() {
3911 let data = b"Hello World";
3912 let plp = make_plp_data(11, &[data]);
3913 let mut buf: &[u8] = &plp;
3914
3915 let result = Client::<Ready>::parse_plp_varchar(&mut buf).unwrap();
3916 match result {
3917 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello World"),
3918 _ => panic!("expected String"),
3919 }
3920 }
3921
3922 #[test]
3923 fn test_parse_plp_varchar_null() {
3924 let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
3925 let mut buf: &[u8] = &plp;
3926
3927 let result = Client::<Ready>::parse_plp_varchar(&mut buf).unwrap();
3928 assert!(matches!(result, mssql_types::SqlValue::Null));
3929 }
3930
/// A single-chunk PLP VARBINARY payload must decode to a `SqlValue::Binary`
/// holding exactly the bytes that were sent.
#[test]
fn test_parse_plp_varbinary_simple() {
    let payload: [u8; 5] = [0x01, 0x02, 0x03, 0x04, 0x05];
    let encoded = make_plp_data(5, &[&payload]);
    let mut cursor: &[u8] = &encoded;

    let decoded = Client::<Ready>::parse_plp_varbinary(&mut cursor).unwrap();
    let mssql_types::SqlValue::Binary(bytes) = decoded else {
        panic!("expected Binary");
    };
    assert_eq!(&bytes[..], &payload[..]);
}
3943
/// An 8-byte header of all ones (0xFFFFFFFFFFFFFFFF) must make the PLP VARBINARY
/// parser return `SqlValue::Null`.
#[test]
fn test_parse_plp_varbinary_null() {
    let null_marker = u64::MAX.to_le_bytes();
    let mut cursor = &null_marker[..];

    let result = Client::<Ready>::parse_plp_varbinary(&mut cursor).unwrap();
    assert!(matches!(result, mssql_types::SqlValue::Null));
}
3952
/// A 255-byte VARBINARY value delivered as three PLP chunks (100 + 100 + 55
/// bytes) must be reassembled in order, with byte `i` equal to `i`.
#[test]
fn test_parse_plp_varbinary_large() {
    // The bytes 0, 1, ..., 254: each byte's value equals its position.
    let expected: Vec<u8> = (0..255u8).collect();
    let chunks: [&[u8]; 3] = [&expected[..100], &expected[100..200], &expected[200..]];
    let payload = make_plp_data(expected.len() as u64, &chunks);
    let mut cursor: &[u8] = &payload;

    let decoded = Client::<Ready>::parse_plp_varbinary(&mut cursor).unwrap();
    let mssql_types::SqlValue::Binary(bytes) = decoded else {
        panic!("expected Binary");
    };
    assert_eq!(bytes.len(), 255);
    // Checking byte-by-byte catches reordered or duplicated chunks.
    bytes
        .iter()
        .enumerate()
        .for_each(|(idx, &byte)| assert_eq!(byte, idx as u8));
}
3975
3976 use tds_protocol::token::{ColumnData, TypeInfo};
3984 use tds_protocol::types::TypeId;
3985
/// Builds the raw bytes of a two-column row: an NVARCHAR value followed by a
/// nullable 4-byte integer.
///
/// Layout produced, in order:
/// - a little-endian `u16` holding the byte length of the string payload,
/// - the string as UTF-16LE code units,
/// - a single length byte `4`,
/// - `int_value` as 4 little-endian bytes.
///
/// # Panics
///
/// Panics if the UTF-16 encoding of `nvarchar_value` is longer than
/// `u16::MAX` bytes. A plain `as u16` cast would silently wrap the length
/// prefix and yield a corrupt row, so the conversion is checked instead.
fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
    let utf16: Vec<u16> = nvarchar_value.encode_utf16().collect();
    let byte_len = u16::try_from(utf16.len() * 2)
        .expect("NVARCHAR test value does not fit in a u16 byte-length prefix");

    // 2-byte length prefix + string payload + 1-byte int length + 4-byte int.
    let mut data = Vec::with_capacity(2 + usize::from(byte_len) + 1 + 4);

    data.extend_from_slice(&byte_len.to_le_bytes());
    data.extend(utf16.iter().flat_map(|unit| unit.to_le_bytes()));

    data.push(4);
    data.extend_from_slice(&int_value.to_le_bytes());

    data
}
4005
/// Decodes an NVARCHAR column followed by an INTN column from one row buffer
/// and checks both values, then verifies that no bytes are left over.
#[test]
fn test_parse_row_nvarchar_then_int() {
    let greeting_col = ColumnData {
        name: String::from("greeting"),
        type_id: TypeId::NVarChar,
        col_type: 0xE7,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(10),
            precision: None,
            scale: None,
            collation: None,
        },
    };
    let number_col = ColumnData {
        name: String::from("number"),
        type_id: TypeId::IntN,
        col_type: 0x26,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(4),
            precision: None,
            scale: None,
            collation: None,
        },
    };

    let row = make_nvarchar_int_row("World", 42);
    let mut cursor: &[u8] = &row;

    // The string column comes first; the parser must stop exactly at its end.
    let greeting = Client::<Ready>::parse_column_value(&mut cursor, &greeting_col).unwrap();
    match greeting {
        mssql_types::SqlValue::String(text) => assert_eq!(text, "World"),
        other => panic!("expected String, got {:?}", other),
    }

    // The integer only decodes correctly if the previous column consumed the right bytes.
    let number = Client::<Ready>::parse_column_value(&mut cursor, &number_col).unwrap();
    match number {
        mssql_types::SqlValue::Int(n) => assert_eq!(n, 42),
        other => panic!("expected Int, got {:?}", other),
    }

    assert_eq!(cursor.len(), 0, "buffer should be fully consumed");
}
4059
/// Decodes a four-column row alternating NVARCHAR and INTN columns, where the
/// first and last values are NULL, and verifies the buffer is fully consumed.
#[test]
fn test_parse_row_multiple_types() {
    let text_units: Vec<u16> = "Test".encode_utf16().collect();

    let mut row: Vec<u8> = Vec::new();
    // Column 0: NVARCHAR with the 0xFFFF length marker (asserted below to decode as NULL).
    row.extend_from_slice(&0xFFFFu16.to_le_bytes());
    // Column 1: INTN with length byte 4, value 123.
    row.push(4);
    row.extend_from_slice(&123i32.to_le_bytes());
    // Column 2: NVARCHAR "Test", with a u16 byte-length prefix.
    row.extend_from_slice(&((text_units.len() * 2) as u16).to_le_bytes());
    row.extend(text_units.iter().flat_map(|unit| unit.to_le_bytes()));
    // Column 3: INTN with length byte 0 (asserted below to decode as NULL).
    row.push(0);

    let nvarchar_col = ColumnData {
        name: String::from("col0"),
        type_id: TypeId::NVarChar,
        col_type: 0xE7,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(100),
            precision: None,
            scale: None,
            collation: None,
        },
    };
    let intn_col = ColumnData {
        name: String::from("col1"),
        type_id: TypeId::IntN,
        col_type: 0x26,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(4),
            precision: None,
            scale: None,
            collation: None,
        },
    };

    let mut cursor: &[u8] = &row;

    // Columns 2 and 3 share the descriptors of columns 0 and 1.
    let first = Client::<Ready>::parse_column_value(&mut cursor, &nvarchar_col).unwrap();
    assert!(
        matches!(first, mssql_types::SqlValue::Null),
        "col0 should be Null"
    );

    let second = Client::<Ready>::parse_column_value(&mut cursor, &intn_col).unwrap();
    assert!(
        matches!(second, mssql_types::SqlValue::Int(123)),
        "col1 should be 123"
    );

    let third = Client::<Ready>::parse_column_value(&mut cursor, &nvarchar_col).unwrap();
    let mssql_types::SqlValue::String(text) = third else {
        panic!("col2 should be 'Test'");
    };
    assert_eq!(text, "Test");

    let fourth = Client::<Ready>::parse_column_value(&mut cursor, &intn_col).unwrap();
    assert!(
        matches!(fourth, mssql_types::SqlValue::Null),
        "col3 should be Null"
    );

    assert_eq!(cursor.len(), 0, "buffer should be fully consumed");
}
4142
/// Decodes a row whose NVARCHAR column contains accented Latin and CJK
/// characters, followed by an 8-byte INTN column, and verifies both values and
/// that the buffer is fully consumed.
#[test]
fn test_parse_row_with_unicode() {
    let test_str = "Héllo Wörld 日本語";
    let text_units: Vec<u16> = test_str.encode_utf16().collect();

    let mut row: Vec<u8> = Vec::new();
    // NVARCHAR: u16 byte-length prefix, then the UTF-16LE code units.
    row.extend_from_slice(&((text_units.len() * 2) as u16).to_le_bytes());
    row.extend(text_units.iter().flat_map(|unit| unit.to_le_bytes()));
    // INTN: length byte 8, then a 64-bit value that does not fit in an i32.
    row.push(8);
    row.extend_from_slice(&9999999999i64.to_le_bytes());

    let text_col = ColumnData {
        name: String::from("text"),
        type_id: TypeId::NVarChar,
        col_type: 0xE7,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(100),
            precision: None,
            scale: None,
            collation: None,
        },
    };
    let num_col = ColumnData {
        name: String::from("num"),
        type_id: TypeId::IntN,
        col_type: 0x26,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(8),
            precision: None,
            scale: None,
            collation: None,
        },
    };

    let mut cursor: &[u8] = &row;

    let text_value = Client::<Ready>::parse_column_value(&mut cursor, &text_col).unwrap();
    let mssql_types::SqlValue::String(text) = text_value else {
        panic!("expected String");
    };
    assert_eq!(text, test_str);

    let num_value = Client::<Ready>::parse_column_value(&mut cursor, &num_col).unwrap();
    let mssql_types::SqlValue::BigInt(number) = num_value else {
        panic!("expected BigInt");
    };
    assert_eq!(number, 9999999999);

    assert_eq!(cursor.len(), 0, "buffer should be fully consumed");
}
4203}