1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, Collation, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token,
19 TokenParser,
20};
21#[cfg(feature = "decimal")]
22use tds_protocol::tvp::encode_tvp_decimal;
23use tds_protocol::tvp::{
24 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
25 encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar, encode_tvp_varbinary,
26};
27use tokio::net::TcpStream;
28use tokio::time::timeout;
29
30use crate::config::Config;
31use crate::error::{Error, Result};
32#[cfg(feature = "otel")]
33use crate::instrumentation::InstrumentationContext;
34use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
35use crate::statement_cache::StatementCache;
36use crate::stream::{MultiResultStream, QueryStream};
37use crate::transaction::SavePoint;
38
/// A SQL Server client whose connection lifecycle stage is tracked by the
/// type parameter `S` (a [`ConnectionState`] such as `Disconnected` or `Ready`).
pub struct Client<S: ConnectionState> {
    /// Configuration this client was connected with.
    config: Config,
    /// Zero-sized marker carrying the connection state `S`.
    _state: PhantomData<S>,
    /// The live transport. `None` means there is no connection; the send
    /// helpers then fail with `Error::ConnectionClosed`.
    connection: Option<ConnectionHandle>,
    /// TDS version reported by the server's LOGINACK token, if one was seen.
    server_version: Option<u32>,
    /// Database name taken from the database ENVCHANGE received during login.
    current_database: Option<String>,
    /// Cache of prepared statements for this connection.
    statement_cache: StatementCache,
    /// Transaction descriptor encoded into every SQL batch / RPC request.
    /// Starts at 0 and is reset to 0 when a commit/rollback ENVCHANGE arrives.
    transaction_descriptor: u64,
    /// OpenTelemetry instrumentation context (only with the `otel` feature).
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
63
/// The transport underneath a [`Client`]; one variant per way the session
/// can be established.
// NOTE(review): `dead_code` is allowed here presumably because not every
// variant's payload is read in every build configuration — confirm.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TLS negotiated before any TDS traffic (TDS 8.0 strict mode).
    Tls(Connection<TlsStream<TcpStream>>),
    /// TLS handshake carried inside TDS PreLogin packets (TDS 7.x flow) and
    /// kept for the whole session.
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP: either no encryption at all, or the session after
    /// login-only encryption has dropped back to plaintext.
    Plain(Connection<TcpStream>),
}
79
80impl Client<Disconnected> {
81 pub async fn connect(config: Config) -> Result<Client<Ready>> {
92 let max_redirects = config.redirect.max_redirects;
93 let follow_redirects = config.redirect.follow_redirects;
94 let mut attempts = 0;
95 let mut current_config = config;
96
97 loop {
98 attempts += 1;
99 if attempts > max_redirects + 1 {
100 return Err(Error::TooManyRedirects { max: max_redirects });
101 }
102
103 match Self::try_connect(¤t_config).await {
104 Ok(client) => return Ok(client),
105 Err(Error::Routing { host, port }) => {
106 if !follow_redirects {
107 return Err(Error::Routing { host, port });
108 }
109 tracing::info!(
110 host = %host,
111 port = port,
112 attempt = attempts,
113 max_redirects = max_redirects,
114 "following Azure SQL routing redirect"
115 );
116 current_config = current_config.with_host(&host).with_port(port);
117 continue;
118 }
119 Err(e) => return Err(e),
120 }
121 }
122 }
123
124 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
125 tracing::info!(
126 host = %config.host,
127 port = config.port,
128 database = ?config.database,
129 "connecting to SQL Server"
130 );
131
132 let addr = format!("{}:{}", config.host, config.port);
133
134 tracing::debug!("establishing TCP connection to {}", addr);
136 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
137 .await
138 .map_err(|_| Error::ConnectTimeout)?
139 .map_err(|e| Error::Io(Arc::new(e)))?;
140
141 tcp_stream
143 .set_nodelay(true)
144 .map_err(|e| Error::Io(Arc::new(e)))?;
145
146 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
148
149 if tls_mode.is_tls_first() {
151 return Self::connect_tds_8(config, tcp_stream).await;
152 }
153
154 Self::connect_tds_7x(config, tcp_stream).await
156 }
157
158 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
162 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
163
164 let tls_config = TlsConfig::new()
166 .strict_mode(true)
167 .trust_server_certificate(config.trust_server_certificate);
168
169 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
170
171 let tls_stream = timeout(
173 config.timeouts.tls_timeout,
174 tls_connector.connect(tcp_stream, &config.host),
175 )
176 .await
177 .map_err(|_| Error::TlsTimeout)?
178 .map_err(|e| Error::Tls(e.to_string()))?;
179
180 tracing::debug!("TLS handshake completed (strict mode)");
181
182 let mut connection = Connection::new(tls_stream);
184
185 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
187 Self::send_prelogin(&mut connection, &prelogin).await?;
188 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
189
190 let login = Self::build_login7(config);
192 Self::send_login7(&mut connection, &login).await?;
193
194 let (server_version, current_database, routing) =
196 Self::process_login_response(&mut connection).await?;
197
198 if let Some((host, port)) = routing {
200 return Err(Error::Routing { host, port });
201 }
202
203 Ok(Client {
204 config: config.clone(),
205 _state: PhantomData,
206 connection: Some(ConnectionHandle::Tls(connection)),
207 server_version,
208 current_database: current_database.clone(),
209 statement_cache: StatementCache::with_default_size(),
210 transaction_descriptor: 0, #[cfg(feature = "otel")]
212 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
213 .with_database(current_database.unwrap_or_default()),
214 })
215 }
216
217 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
225 use bytes::BufMut;
226 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
227 use tokio::io::{AsyncReadExt, AsyncWriteExt};
228
229 tracing::debug!("using TDS 7.x flow (PreLogin first)");
230
231 let client_encryption = if config.encrypt {
234 EncryptionLevel::On
235 } else {
236 EncryptionLevel::Off
237 };
238 let prelogin = Self::build_prelogin(config, client_encryption);
239 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
240 let prelogin_bytes = prelogin.encode();
241
242 let header = PacketHeader::new(
244 PacketType::PreLogin,
245 PacketStatus::END_OF_MESSAGE,
246 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
247 );
248
249 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
250 header.encode(&mut packet_buf);
251 packet_buf.put_slice(&prelogin_bytes);
252
253 tcp_stream
254 .write_all(&packet_buf)
255 .await
256 .map_err(|e| Error::Io(Arc::new(e)))?;
257
258 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
260 tcp_stream
261 .read_exact(&mut header_buf)
262 .await
263 .map_err(|e| Error::Io(Arc::new(e)))?;
264
265 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
266 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
267
268 let mut response_buf = vec![0u8; payload_length];
269 tcp_stream
270 .read_exact(&mut response_buf)
271 .await
272 .map_err(|e| Error::Io(Arc::new(e)))?;
273
274 let prelogin_response =
275 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
276
277 let server_version = prelogin_response.version;
279 let client_version = config.tds_version;
280 tracing::debug!(
281 client_version = %client_version,
282 server_version = %server_version,
283 server_sql_version = server_version.sql_server_version_name(),
284 "TDS version negotiation"
285 );
286
287 if server_version < client_version && !client_version.is_tds_8() {
289 tracing::warn!(
290 client_version = %client_version,
291 server_version = %server_version,
292 "Server supports lower TDS version than requested. \
293 Connection will use server's version: {}",
294 server_version.sql_server_version_name()
295 );
296 }
297
298 if server_version.is_legacy() {
300 tracing::warn!(
301 server_version = %server_version,
302 "Server uses legacy TDS version ({}). \
303 Some features may not be available.",
304 server_version.sql_server_version_name()
305 );
306 }
307
308 let server_encryption = prelogin_response.encryption;
310 tracing::debug!(encryption = ?server_encryption, "server encryption level");
311
312 let negotiated_encryption = match (client_encryption, server_encryption) {
318 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
319 EncryptionLevel::NotSupported
320 }
321 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
322 (EncryptionLevel::On, EncryptionLevel::Off)
323 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
324 return Err(Error::Protocol(
325 "Server does not support requested encryption level".to_string(),
326 ));
327 }
328 _ => EncryptionLevel::On,
329 };
330
331 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
334
335 if use_tls {
336 let tls_config =
339 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
340
341 let tls_connector =
342 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
343
344 let mut tls_stream = timeout(
346 config.timeouts.tls_timeout,
347 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
348 )
349 .await
350 .map_err(|_| Error::TlsTimeout)?
351 .map_err(|e| Error::Tls(e.to_string()))?;
352
353 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
354
355 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
357
358 if login_only_encryption {
359 use tokio::io::AsyncWriteExt;
367
368 let login = Self::build_login7(config);
370 let login_payload = login.encode();
371
372 let max_packet = MAX_PACKET_SIZE;
374 let max_payload = max_packet - PACKET_HEADER_SIZE;
375 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
376 let total_chunks = chunks.len();
377
378 for (i, chunk) in chunks.into_iter().enumerate() {
379 let is_last = i == total_chunks - 1;
380 let status = if is_last {
381 PacketStatus::END_OF_MESSAGE
382 } else {
383 PacketStatus::NORMAL
384 };
385
386 let header = PacketHeader::new(
387 PacketType::Tds7Login,
388 status,
389 (PACKET_HEADER_SIZE + chunk.len()) as u16,
390 );
391
392 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
393 header.encode(&mut packet_buf);
394 packet_buf.put_slice(chunk);
395
396 tls_stream
397 .write_all(&packet_buf)
398 .await
399 .map_err(|e| Error::Io(Arc::new(e)))?;
400 }
401
402 tls_stream
404 .flush()
405 .await
406 .map_err(|e| Error::Io(Arc::new(e)))?;
407
408 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
409
410 let (wrapper, _client_conn) = tls_stream.into_inner();
414 let tcp_stream = wrapper.into_inner();
415
416 let mut connection = Connection::new(tcp_stream);
418
419 let (server_version, current_database, routing) =
421 Self::process_login_response(&mut connection).await?;
422
423 if let Some((host, port)) = routing {
425 return Err(Error::Routing { host, port });
426 }
427
428 Ok(Client {
430 config: config.clone(),
431 _state: PhantomData,
432 connection: Some(ConnectionHandle::Plain(connection)),
433 server_version,
434 current_database: current_database.clone(),
435 statement_cache: StatementCache::with_default_size(),
436 transaction_descriptor: 0, #[cfg(feature = "otel")]
438 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
439 .with_database(current_database.unwrap_or_default()),
440 })
441 } else {
442 let mut connection = Connection::new(tls_stream);
445
446 let login = Self::build_login7(config);
448 Self::send_login7(&mut connection, &login).await?;
449
450 let (server_version, current_database, routing) =
452 Self::process_login_response(&mut connection).await?;
453
454 if let Some((host, port)) = routing {
456 return Err(Error::Routing { host, port });
457 }
458
459 Ok(Client {
460 config: config.clone(),
461 _state: PhantomData,
462 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
463 server_version,
464 current_database: current_database.clone(),
465 statement_cache: StatementCache::with_default_size(),
466 transaction_descriptor: 0, #[cfg(feature = "otel")]
468 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
469 .with_database(current_database.unwrap_or_default()),
470 })
471 }
472 } else {
473 tracing::warn!(
475 "Connecting without TLS encryption. This is insecure and should only be \
476 used for development/testing on trusted networks."
477 );
478
479 let login = Self::build_login7(config);
481 let login_bytes = login.encode();
482 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
483 tracing::debug!(
485 "Login7 fixed header (94 bytes): {:02X?}",
486 &login_bytes[..login_bytes.len().min(94)]
487 );
488 if login_bytes.len() > 94 {
490 tracing::debug!(
491 "Login7 variable data ({} bytes): {:02X?}",
492 login_bytes.len() - 94,
493 &login_bytes[94..]
494 );
495 }
496
497 let login_header = PacketHeader::new(
499 PacketType::Tds7Login,
500 PacketStatus::END_OF_MESSAGE,
501 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
502 )
503 .with_packet_id(1);
504 let mut login_packet_buf =
505 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
506 login_header.encode(&mut login_packet_buf);
507 login_packet_buf.put_slice(&login_bytes);
508
509 tracing::debug!(
510 "Sending Login7 packet: {} bytes total, header: {:02X?}",
511 login_packet_buf.len(),
512 &login_packet_buf[..PACKET_HEADER_SIZE]
513 );
514 tcp_stream
515 .write_all(&login_packet_buf)
516 .await
517 .map_err(|e| Error::Io(Arc::new(e)))?;
518 tcp_stream
519 .flush()
520 .await
521 .map_err(|e| Error::Io(Arc::new(e)))?;
522 tracing::debug!("Login7 sent and flushed over raw TCP");
523
524 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
526 tcp_stream
527 .read_exact(&mut response_header_buf)
528 .await
529 .map_err(|e| Error::Io(Arc::new(e)))?;
530
531 let response_type = response_header_buf[0];
532 let response_length =
533 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
534 tracing::debug!(
535 "Response header: type={:#04X}, length={}",
536 response_type,
537 response_length
538 );
539
540 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
542 let mut response_payload = vec![0u8; payload_length];
543 tcp_stream
544 .read_exact(&mut response_payload)
545 .await
546 .map_err(|e| Error::Io(Arc::new(e)))?;
547 tracing::debug!(
548 "Response payload: {} bytes, first 32: {:02X?}",
549 response_payload.len(),
550 &response_payload[..response_payload.len().min(32)]
551 );
552
553 let connection = Connection::new(tcp_stream);
555
556 let response_bytes = bytes::Bytes::from(response_payload);
558 let mut parser = TokenParser::new(response_bytes);
559 let mut server_version = None;
560 let mut current_database = None;
561 let routing = None;
562
563 while let Some(token) = parser
564 .next_token()
565 .map_err(|e| Error::Protocol(e.to_string()))?
566 {
567 match token {
568 Token::LoginAck(ack) => {
569 tracing::info!(
570 version = ack.tds_version,
571 interface = ack.interface,
572 prog_name = %ack.prog_name,
573 "login acknowledged"
574 );
575 server_version = Some(ack.tds_version);
576 }
577 Token::EnvChange(env) => {
578 Self::process_env_change(&env, &mut current_database, &mut None);
579 }
580 Token::Error(err) => {
581 return Err(Error::Server {
582 number: err.number,
583 state: err.state,
584 class: err.class,
585 message: err.message.clone(),
586 server: if err.server.is_empty() {
587 None
588 } else {
589 Some(err.server.clone())
590 },
591 procedure: if err.procedure.is_empty() {
592 None
593 } else {
594 Some(err.procedure.clone())
595 },
596 line: err.line as u32,
597 });
598 }
599 Token::Info(info) => {
600 tracing::info!(
601 number = info.number,
602 message = %info.message,
603 "server info message"
604 );
605 }
606 Token::Done(done) => {
607 if done.status.error {
608 return Err(Error::Protocol("login failed".to_string()));
609 }
610 break;
611 }
612 _ => {}
613 }
614 }
615
616 if let Some((host, port)) = routing {
618 return Err(Error::Routing { host, port });
619 }
620
621 Ok(Client {
622 config: config.clone(),
623 _state: PhantomData,
624 connection: Some(ConnectionHandle::Plain(connection)),
625 server_version,
626 current_database: current_database.clone(),
627 statement_cache: StatementCache::with_default_size(),
628 transaction_descriptor: 0, #[cfg(feature = "otel")]
630 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
631 .with_database(current_database.unwrap_or_default()),
632 })
633 }
634 }
635
636 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
638 let version = if config.strict_mode {
640 tds_protocol::version::TdsVersion::V8_0
641 } else {
642 config.tds_version
643 };
644
645 let mut prelogin = PreLogin::new()
646 .with_version(version)
647 .with_encryption(encryption);
648
649 if config.mars {
650 prelogin = prelogin.with_mars(true);
651 }
652
653 if let Some(ref instance) = config.instance {
654 prelogin = prelogin.with_instance(instance);
655 }
656
657 prelogin
658 }
659
660 fn build_login7(config: &Config) -> Login7 {
662 let version = if config.strict_mode {
664 tds_protocol::version::TdsVersion::V8_0
665 } else {
666 config.tds_version
667 };
668
669 let mut login = Login7::new()
670 .with_tds_version(version)
671 .with_packet_size(config.packet_size as u32)
672 .with_app_name(&config.application_name)
673 .with_server_name(&config.host)
674 .with_hostname(&config.host);
675
676 if let Some(ref database) = config.database {
677 login = login.with_database(database);
678 }
679
680 match &config.credentials {
682 mssql_auth::Credentials::SqlServer { username, password } => {
683 login = login.with_sql_auth(username.as_ref(), password.as_ref());
684 }
685 _ => {}
687 }
688
689 login
690 }
691
692 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
694 where
695 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
696 {
697 let payload = prelogin.encode();
698 let max_packet = MAX_PACKET_SIZE;
699
700 connection
701 .send_message(PacketType::PreLogin, payload, max_packet)
702 .await
703 .map_err(|e| Error::Protocol(e.to_string()))
704 }
705
706 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
708 where
709 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
710 {
711 let message = connection
712 .read_message()
713 .await
714 .map_err(|e| Error::Protocol(e.to_string()))?
715 .ok_or(Error::ConnectionClosed)?;
716
717 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
718 }
719
720 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
722 where
723 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
724 {
725 let payload = login.encode();
726 let max_packet = MAX_PACKET_SIZE;
727
728 connection
729 .send_message(PacketType::Tds7Login, payload, max_packet)
730 .await
731 .map_err(|e| Error::Protocol(e.to_string()))
732 }
733
734 async fn process_login_response<T>(
738 connection: &mut Connection<T>,
739 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
740 where
741 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
742 {
743 let message = connection
744 .read_message()
745 .await
746 .map_err(|e| Error::Protocol(e.to_string()))?
747 .ok_or(Error::ConnectionClosed)?;
748
749 let response_bytes = message.payload;
750
751 let mut parser = TokenParser::new(response_bytes);
752 let mut server_version = None;
753 let mut database = None;
754 let mut routing = None;
755
756 while let Some(token) = parser
757 .next_token()
758 .map_err(|e| Error::Protocol(e.to_string()))?
759 {
760 match token {
761 Token::LoginAck(ack) => {
762 tracing::info!(
763 version = ack.tds_version,
764 interface = ack.interface,
765 prog_name = %ack.prog_name,
766 "login acknowledged"
767 );
768 server_version = Some(ack.tds_version);
769 }
770 Token::EnvChange(env) => {
771 Self::process_env_change(&env, &mut database, &mut routing);
772 }
773 Token::Error(err) => {
774 return Err(Error::Server {
775 number: err.number,
776 state: err.state,
777 class: err.class,
778 message: err.message.clone(),
779 server: if err.server.is_empty() {
780 None
781 } else {
782 Some(err.server.clone())
783 },
784 procedure: if err.procedure.is_empty() {
785 None
786 } else {
787 Some(err.procedure.clone())
788 },
789 line: err.line as u32,
790 });
791 }
792 Token::Info(info) => {
793 tracing::info!(
794 number = info.number,
795 message = %info.message,
796 "server info message"
797 );
798 }
799 Token::Done(done) => {
800 if done.status.error {
801 return Err(Error::Protocol("login failed".to_string()));
802 }
803 break;
804 }
805 _ => {}
806 }
807 }
808
809 Ok((server_version, database, routing))
810 }
811
812 fn process_env_change(
814 env: &EnvChange,
815 database: &mut Option<String>,
816 routing: &mut Option<(String, u16)>,
817 ) {
818 use tds_protocol::token::EnvChangeValue;
819
820 match env.env_type {
821 EnvChangeType::Database => {
822 if let EnvChangeValue::String(ref new_value) = env.new_value {
823 tracing::debug!(database = %new_value, "database changed");
824 *database = Some(new_value.clone());
825 }
826 }
827 EnvChangeType::Routing => {
828 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
829 tracing::info!(host = %host, port = port, "routing redirect received");
830 *routing = Some((host.clone(), port));
831 }
832 }
833 _ => {
834 if let EnvChangeValue::String(ref new_value) = env.new_value {
835 tracing::debug!(
836 env_type = ?env.env_type,
837 new_value = %new_value,
838 "environment change"
839 );
840 }
841 }
842 }
843 }
844}
845
846impl<S: ConnectionState> Client<S> {
848 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
856 use tds_protocol::token::EnvChangeValue;
857
858 match env.env_type {
859 EnvChangeType::BeginTransaction => {
860 if let EnvChangeValue::Binary(ref data) = env.new_value {
861 if data.len() >= 8 {
862 let descriptor = u64::from_le_bytes([
863 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
864 ]);
865 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
866 *transaction_descriptor = descriptor;
867 }
868 }
869 }
870 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
871 tracing::debug!(
872 env_type = ?env.env_type,
873 "transaction ended via raw SQL"
874 );
875 *transaction_descriptor = 0;
876 }
877 _ => {}
878 }
879 }
880
881 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
887 let payload =
888 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
889 let max_packet = self.config.packet_size as usize;
890
891 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
892
893 match connection {
894 ConnectionHandle::Tls(conn) => {
895 conn.send_message(PacketType::SqlBatch, payload, max_packet)
896 .await
897 .map_err(|e| Error::Protocol(e.to_string()))?;
898 }
899 ConnectionHandle::TlsPrelogin(conn) => {
900 conn.send_message(PacketType::SqlBatch, payload, max_packet)
901 .await
902 .map_err(|e| Error::Protocol(e.to_string()))?;
903 }
904 ConnectionHandle::Plain(conn) => {
905 conn.send_message(PacketType::SqlBatch, payload, max_packet)
906 .await
907 .map_err(|e| Error::Protocol(e.to_string()))?;
908 }
909 }
910
911 Ok(())
912 }
913
914 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
918 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
919 let max_packet = self.config.packet_size as usize;
920
921 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
922
923 match connection {
924 ConnectionHandle::Tls(conn) => {
925 conn.send_message(PacketType::Rpc, payload, max_packet)
926 .await
927 .map_err(|e| Error::Protocol(e.to_string()))?;
928 }
929 ConnectionHandle::TlsPrelogin(conn) => {
930 conn.send_message(PacketType::Rpc, payload, max_packet)
931 .await
932 .map_err(|e| Error::Protocol(e.to_string()))?;
933 }
934 ConnectionHandle::Plain(conn) => {
935 conn.send_message(PacketType::Rpc, payload, max_packet)
936 .await
937 .map_err(|e| Error::Protocol(e.to_string()))?;
938 }
939 }
940
941 Ok(())
942 }
943
944 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
946 use bytes::{BufMut, BytesMut};
947 use mssql_types::SqlValue;
948
949 params
950 .iter()
951 .enumerate()
952 .map(|(i, p)| {
953 let sql_value = p.to_sql()?;
954 let name = format!("@p{}", i + 1);
955
956 Ok(match sql_value {
957 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
958 SqlValue::Bool(v) => {
959 let mut buf = BytesMut::with_capacity(1);
960 buf.put_u8(if v { 1 } else { 0 });
961 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
962 }
963 SqlValue::TinyInt(v) => {
964 let mut buf = BytesMut::with_capacity(1);
965 buf.put_u8(v);
966 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
967 }
968 SqlValue::SmallInt(v) => {
969 let mut buf = BytesMut::with_capacity(2);
970 buf.put_i16_le(v);
971 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
972 }
973 SqlValue::Int(v) => RpcParam::int(&name, v),
974 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
975 SqlValue::Float(v) => {
976 let mut buf = BytesMut::with_capacity(4);
977 buf.put_f32_le(v);
978 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
979 }
980 SqlValue::Double(v) => {
981 let mut buf = BytesMut::with_capacity(8);
982 buf.put_f64_le(v);
983 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
984 }
985 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
986 SqlValue::Binary(ref b) => {
987 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
988 }
989 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
990 #[cfg(feature = "uuid")]
991 SqlValue::Uuid(u) => {
992 let bytes = u.as_bytes();
994 let mut buf = BytesMut::with_capacity(16);
995 buf.put_u32_le(u32::from_be_bytes([
997 bytes[0], bytes[1], bytes[2], bytes[3],
998 ]));
999 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
1000 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
1001 buf.put_slice(&bytes[8..16]);
1002 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
1003 }
1004 #[cfg(feature = "decimal")]
1005 SqlValue::Decimal(d) => {
1006 RpcParam::nvarchar(&name, &d.to_string())
1008 }
1009 #[cfg(feature = "chrono")]
1010 SqlValue::Date(_)
1011 | SqlValue::Time(_)
1012 | SqlValue::DateTime(_)
1013 | SqlValue::DateTimeOffset(_) => {
1014 let s = match &sql_value {
1017 SqlValue::Date(d) => d.to_string(),
1018 SqlValue::Time(t) => t.to_string(),
1019 SqlValue::DateTime(dt) => dt.to_string(),
1020 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
1021 _ => unreachable!(),
1022 };
1023 RpcParam::nvarchar(&name, &s)
1024 }
1025 #[cfg(feature = "json")]
1026 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
1027 SqlValue::Tvp(ref tvp_data) => {
1028 Self::encode_tvp_param(&name, tvp_data)?
1030 }
1031 _ => {
1033 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
1034 from: sql_value.type_name().to_string(),
1035 to: "RPC parameter",
1036 }));
1037 }
1038 })
1039 })
1040 .collect()
1041 }
1042
1043 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
1048 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
1050 .columns
1051 .iter()
1052 .map(|col| {
1053 let wire_type = Self::convert_tvp_column_type(&col.column_type);
1054 TvpWireColumnDef {
1055 wire_type,
1056 flags: TvpColumnFlags {
1057 nullable: col.nullable,
1058 },
1059 }
1060 })
1061 .collect();
1062
1063 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
1065
1066 let mut buf = BytesMut::with_capacity(256);
1068
1069 encoder.encode_metadata(&mut buf);
1071
1072 for row in &tvp_data.rows {
1074 encoder.encode_row(&mut buf, |row_buf| {
1075 for (col_idx, value) in row.iter().enumerate() {
1076 let wire_type = &wire_columns[col_idx].wire_type;
1077 Self::encode_tvp_value(value, wire_type, row_buf);
1078 }
1079 });
1080 }
1081
1082 encoder.encode_end(&mut buf);
1084
1085 let type_info = RpcTypeInfo {
1089 type_id: 0xF3, max_length: None,
1091 precision: None,
1092 scale: None,
1093 collation: None,
1094 };
1095
1096 Ok(RpcParam {
1097 name: name.to_string(),
1098 flags: tds_protocol::rpc::ParamFlags::default(),
1099 type_info,
1100 value: Some(buf.freeze()),
1101 })
1102 }
1103
1104 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1106 match col_type {
1107 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1108 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1109 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1110 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1111 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1112 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1113 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1114 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1115 precision: *precision,
1116 scale: *scale,
1117 },
1118 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1119 max_length: *max_length,
1120 },
1121 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1122 max_length: *max_length,
1123 },
1124 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1125 max_length: *max_length,
1126 },
1127 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1128 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1129 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1130 mssql_types::TvpColumnType::DateTime2 { scale } => {
1131 TvpWireType::DateTime2 { scale: *scale }
1132 }
1133 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1134 TvpWireType::DateTimeOffset { scale: *scale }
1135 }
1136 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1137 }
1138 }
1139
1140 fn encode_tvp_value(
1142 value: &mssql_types::SqlValue,
1143 wire_type: &TvpWireType,
1144 buf: &mut BytesMut,
1145 ) {
1146 use mssql_types::SqlValue;
1147
1148 match value {
1149 SqlValue::Null => {
1150 encode_tvp_null(wire_type, buf);
1151 }
1152 SqlValue::Bool(v) => {
1153 encode_tvp_bit(*v, buf);
1154 }
1155 SqlValue::TinyInt(v) => {
1156 encode_tvp_int(*v as i64, 1, buf);
1157 }
1158 SqlValue::SmallInt(v) => {
1159 encode_tvp_int(*v as i64, 2, buf);
1160 }
1161 SqlValue::Int(v) => {
1162 encode_tvp_int(*v as i64, 4, buf);
1163 }
1164 SqlValue::BigInt(v) => {
1165 encode_tvp_int(*v, 8, buf);
1166 }
1167 SqlValue::Float(v) => {
1168 encode_tvp_float(*v as f64, 4, buf);
1169 }
1170 SqlValue::Double(v) => {
1171 encode_tvp_float(*v, 8, buf);
1172 }
1173 SqlValue::String(s) => {
1174 let max_len = match wire_type {
1175 TvpWireType::NVarChar { max_length } => *max_length,
1176 _ => 4000,
1177 };
1178 encode_tvp_nvarchar(s, max_len, buf);
1179 }
1180 SqlValue::Binary(b) => {
1181 let max_len = match wire_type {
1182 TvpWireType::VarBinary { max_length } => *max_length,
1183 _ => 8000,
1184 };
1185 encode_tvp_varbinary(b, max_len, buf);
1186 }
1187 #[cfg(feature = "decimal")]
1188 SqlValue::Decimal(d) => {
1189 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1190 let mantissa = d.mantissa().unsigned_abs();
1191 encode_tvp_decimal(sign, mantissa, buf);
1192 }
1193 #[cfg(feature = "uuid")]
1194 SqlValue::Uuid(u) => {
1195 let bytes = u.as_bytes();
1196 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1197 }
1198 #[cfg(feature = "chrono")]
1199 SqlValue::Date(d) => {
1200 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1202 let days = d.signed_duration_since(base).num_days() as u32;
1203 tds_protocol::tvp::encode_tvp_date(days, buf);
1204 }
1205 #[cfg(feature = "chrono")]
1206 SqlValue::Time(t) => {
1207 use chrono::Timelike;
1208 let nanos =
1209 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1210 let intervals = nanos / 100;
1211 let scale = match wire_type {
1212 TvpWireType::Time { scale } => *scale,
1213 _ => 7,
1214 };
1215 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1216 }
1217 #[cfg(feature = "chrono")]
1218 SqlValue::DateTime(dt) => {
1219 use chrono::Timelike;
1220 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1222 + dt.time().nanosecond() as u64;
1223 let intervals = nanos / 100;
1224 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1226 let days = dt.date().signed_duration_since(base).num_days() as u32;
1227 let scale = match wire_type {
1228 TvpWireType::DateTime2 { scale } => *scale,
1229 _ => 7,
1230 };
1231 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1232 }
1233 #[cfg(feature = "chrono")]
1234 SqlValue::DateTimeOffset(dto) => {
1235 use chrono::{Offset, Timelike};
1236 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1238 + dto.time().nanosecond() as u64;
1239 let intervals = nanos / 100;
1240 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1242 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1243 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1245 let scale = match wire_type {
1246 TvpWireType::DateTimeOffset { scale } => *scale,
1247 _ => 7,
1248 };
1249 tds_protocol::tvp::encode_tvp_datetimeoffset(
1250 intervals,
1251 days,
1252 offset_minutes,
1253 scale,
1254 buf,
1255 );
1256 }
1257 #[cfg(feature = "json")]
1258 SqlValue::Json(j) => {
1259 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1261 }
1262 SqlValue::Xml(s) => {
1263 encode_tvp_nvarchar(s, 0xFFFF, buf);
1265 }
1266 SqlValue::Tvp(_) => {
1267 encode_tvp_null(wire_type, buf);
1269 }
1270 _ => {
1272 encode_tvp_null(wire_type, buf);
1273 }
1274 }
1275 }
1276
1277 async fn read_query_response(
1279 &mut self,
1280 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1281 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1282
1283 let message = match connection {
1284 ConnectionHandle::Tls(conn) => conn
1285 .read_message()
1286 .await
1287 .map_err(|e| Error::Protocol(e.to_string()))?,
1288 ConnectionHandle::TlsPrelogin(conn) => conn
1289 .read_message()
1290 .await
1291 .map_err(|e| Error::Protocol(e.to_string()))?,
1292 ConnectionHandle::Plain(conn) => conn
1293 .read_message()
1294 .await
1295 .map_err(|e| Error::Protocol(e.to_string()))?,
1296 }
1297 .ok_or(Error::ConnectionClosed)?;
1298
1299 let mut parser = TokenParser::new(message.payload);
1300 let mut columns: Vec<crate::row::Column> = Vec::new();
1301 let mut rows: Vec<crate::row::Row> = Vec::new();
1302 let mut protocol_metadata: Option<ColMetaData> = None;
1303
1304 loop {
1305 let token = parser
1307 .next_token_with_metadata(protocol_metadata.as_ref())
1308 .map_err(|e| Error::Protocol(e.to_string()))?;
1309
1310 let Some(token) = token else {
1311 break;
1312 };
1313
1314 match token {
1315 Token::ColMetaData(meta) => {
1316 rows.clear();
1319
1320 columns = meta
1321 .columns
1322 .iter()
1323 .enumerate()
1324 .map(|(i, col)| {
1325 let type_name = format!("{:?}", col.type_id);
1326 let mut column = crate::row::Column::new(&col.name, i, type_name)
1327 .with_nullable(col.flags & 0x01 != 0);
1328
1329 if let Some(max_len) = col.type_info.max_length {
1330 column = column.with_max_length(max_len);
1331 }
1332 if let (Some(prec), Some(scale)) =
1333 (col.type_info.precision, col.type_info.scale)
1334 {
1335 column = column.with_precision_scale(prec, scale);
1336 }
1337 if let Some(collation) = col.type_info.collation {
1340 column = column.with_collation(collation);
1341 }
1342 column
1343 })
1344 .collect();
1345
1346 tracing::debug!(columns = columns.len(), "received column metadata");
1347 protocol_metadata = Some(meta);
1348 }
1349 Token::Row(raw_row) => {
1350 if let Some(ref meta) = protocol_metadata {
1351 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1352 rows.push(row);
1353 }
1354 }
1355 Token::NbcRow(nbc_row) => {
1356 if let Some(ref meta) = protocol_metadata {
1357 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1358 rows.push(row);
1359 }
1360 }
1361 Token::Error(err) => {
1362 return Err(Error::Server {
1363 number: err.number,
1364 state: err.state,
1365 class: err.class,
1366 message: err.message.clone(),
1367 server: if err.server.is_empty() {
1368 None
1369 } else {
1370 Some(err.server.clone())
1371 },
1372 procedure: if err.procedure.is_empty() {
1373 None
1374 } else {
1375 Some(err.procedure.clone())
1376 },
1377 line: err.line as u32,
1378 });
1379 }
1380 Token::Done(done) => {
1381 if done.status.error {
1382 return Err(Error::Query("query failed".to_string()));
1383 }
1384 tracing::debug!(
1385 row_count = done.row_count,
1386 has_more = done.status.more,
1387 "query complete"
1388 );
1389 if !done.status.more {
1392 break;
1393 }
1394 }
1395 Token::DoneProc(done) => {
1396 if done.status.error {
1397 return Err(Error::Query("query failed".to_string()));
1398 }
1399 }
1400 Token::DoneInProc(done) => {
1401 if done.status.error {
1402 return Err(Error::Query("query failed".to_string()));
1403 }
1404 }
1405 Token::Info(info) => {
1406 tracing::debug!(
1407 number = info.number,
1408 message = %info.message,
1409 "server info message"
1410 );
1411 }
1412 Token::EnvChange(env) => {
1413 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
1417 }
1418 _ => {}
1419 }
1420 }
1421
1422 tracing::debug!(
1423 columns = columns.len(),
1424 rows = rows.len(),
1425 "query response parsed"
1426 );
1427 Ok((columns, rows))
1428 }
1429
1430 fn convert_raw_row(
1434 raw: &RawRow,
1435 meta: &ColMetaData,
1436 columns: &[crate::row::Column],
1437 ) -> Result<crate::row::Row> {
1438 let mut values = Vec::with_capacity(meta.columns.len());
1439 let mut buf = raw.data.as_ref();
1440
1441 for col in &meta.columns {
1442 let value = Self::parse_column_value(&mut buf, col)?;
1443 values.push(value);
1444 }
1445
1446 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1447 }
1448
1449 fn convert_nbc_row(
1453 nbc: &NbcRow,
1454 meta: &ColMetaData,
1455 columns: &[crate::row::Column],
1456 ) -> Result<crate::row::Row> {
1457 let mut values = Vec::with_capacity(meta.columns.len());
1458 let mut buf = nbc.data.as_ref();
1459
1460 for (i, col) in meta.columns.iter().enumerate() {
1461 if nbc.is_null(i) {
1462 values.push(mssql_types::SqlValue::Null);
1463 } else {
1464 let value = Self::parse_column_value(&mut buf, col)?;
1465 values.push(value);
1466 }
1467 }
1468
1469 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1470 }
1471
1472 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1474 use bytes::Buf;
1475 use mssql_types::SqlValue;
1476 use tds_protocol::types::TypeId;
1477
1478 let value = match col.type_id {
1479 TypeId::Null => SqlValue::Null,
1481
1482 TypeId::Int1 => {
1484 if buf.remaining() < 1 {
1485 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1486 }
1487 SqlValue::TinyInt(buf.get_u8())
1488 }
1489 TypeId::Bit => {
1490 if buf.remaining() < 1 {
1491 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1492 }
1493 SqlValue::Bool(buf.get_u8() != 0)
1494 }
1495
1496 TypeId::Int2 => {
1498 if buf.remaining() < 2 {
1499 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1500 }
1501 SqlValue::SmallInt(buf.get_i16_le())
1502 }
1503
1504 TypeId::Int4 => {
1506 if buf.remaining() < 4 {
1507 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1508 }
1509 SqlValue::Int(buf.get_i32_le())
1510 }
1511 TypeId::Float4 => {
1512 if buf.remaining() < 4 {
1513 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1514 }
1515 SqlValue::Float(buf.get_f32_le())
1516 }
1517
1518 TypeId::Int8 => {
1520 if buf.remaining() < 8 {
1521 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1522 }
1523 SqlValue::BigInt(buf.get_i64_le())
1524 }
1525 TypeId::Float8 => {
1526 if buf.remaining() < 8 {
1527 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1528 }
1529 SqlValue::Double(buf.get_f64_le())
1530 }
1531 TypeId::Money => {
1532 if buf.remaining() < 8 {
1533 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1534 }
1535 let high = buf.get_i32_le();
1537 let low = buf.get_u32_le();
1538 let cents = ((high as i64) << 32) | (low as i64);
1539 let value = (cents as f64) / 10000.0;
1540 SqlValue::Double(value)
1541 }
1542 TypeId::Money4 => {
1543 if buf.remaining() < 4 {
1544 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1545 }
1546 let cents = buf.get_i32_le();
1547 let value = (cents as f64) / 10000.0;
1548 SqlValue::Double(value)
1549 }
1550
1551 TypeId::IntN => {
1553 if buf.remaining() < 1 {
1554 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1555 }
1556 let len = buf.get_u8();
1557 match len {
1558 0 => SqlValue::Null,
1559 1 => SqlValue::TinyInt(buf.get_u8()),
1560 2 => SqlValue::SmallInt(buf.get_i16_le()),
1561 4 => SqlValue::Int(buf.get_i32_le()),
1562 8 => SqlValue::BigInt(buf.get_i64_le()),
1563 _ => {
1564 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1565 }
1566 }
1567 }
1568 TypeId::FloatN => {
1569 if buf.remaining() < 1 {
1570 return Err(Error::Protocol(
1571 "unexpected EOF reading FloatN length".into(),
1572 ));
1573 }
1574 let len = buf.get_u8();
1575 match len {
1576 0 => SqlValue::Null,
1577 4 => SqlValue::Float(buf.get_f32_le()),
1578 8 => SqlValue::Double(buf.get_f64_le()),
1579 _ => {
1580 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1581 }
1582 }
1583 }
1584 TypeId::BitN => {
1585 if buf.remaining() < 1 {
1586 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1587 }
1588 let len = buf.get_u8();
1589 match len {
1590 0 => SqlValue::Null,
1591 1 => SqlValue::Bool(buf.get_u8() != 0),
1592 _ => {
1593 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1594 }
1595 }
1596 }
1597 TypeId::MoneyN => {
1598 if buf.remaining() < 1 {
1599 return Err(Error::Protocol(
1600 "unexpected EOF reading MoneyN length".into(),
1601 ));
1602 }
1603 let len = buf.get_u8();
1604 match len {
1605 0 => SqlValue::Null,
1606 4 => {
1607 let cents = buf.get_i32_le();
1608 SqlValue::Double((cents as f64) / 10000.0)
1609 }
1610 8 => {
1611 let high = buf.get_i32_le();
1612 let low = buf.get_u32_le();
1613 let cents = ((high as i64) << 32) | (low as i64);
1614 SqlValue::Double((cents as f64) / 10000.0)
1615 }
1616 _ => {
1617 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1618 }
1619 }
1620 }
1621 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1623 if buf.remaining() < 1 {
1624 return Err(Error::Protocol(
1625 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1626 ));
1627 }
1628 let len = buf.get_u8() as usize;
1629 if len == 0 {
1630 SqlValue::Null
1631 } else {
1632 if buf.remaining() < len {
1633 return Err(Error::Protocol(
1634 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1635 ));
1636 }
1637
1638 let sign = buf.get_u8();
1640 let mantissa_len = len - 1;
1641
1642 let mut mantissa_bytes = [0u8; 16];
1644 for i in 0..mantissa_len.min(16) {
1645 mantissa_bytes[i] = buf.get_u8();
1646 }
1647 for _ in 16..mantissa_len {
1649 buf.get_u8();
1650 }
1651
1652 let mantissa = u128::from_le_bytes(mantissa_bytes);
1653 let scale = col.type_info.scale.unwrap_or(0) as u32;
1654
1655 #[cfg(feature = "decimal")]
1656 {
1657 use rust_decimal::Decimal;
1658 if scale > 28 {
1661 let divisor = 10f64.powi(scale as i32);
1663 let value = (mantissa as f64) / divisor;
1664 let value = if sign == 0 { -value } else { value };
1665 SqlValue::Double(value)
1666 } else {
1667 let mut decimal =
1668 Decimal::from_i128_with_scale(mantissa as i128, scale);
1669 if sign == 0 {
1670 decimal.set_sign_negative(true);
1671 }
1672 SqlValue::Decimal(decimal)
1673 }
1674 }
1675
1676 #[cfg(not(feature = "decimal"))]
1677 {
1678 let divisor = 10f64.powi(scale as i32);
1680 let value = (mantissa as f64) / divisor;
1681 let value = if sign == 0 { -value } else { value };
1682 SqlValue::Double(value)
1683 }
1684 }
1685 }
1686
1687 TypeId::DateTimeN => {
1689 if buf.remaining() < 1 {
1690 return Err(Error::Protocol(
1691 "unexpected EOF reading DateTimeN length".into(),
1692 ));
1693 }
1694 let len = buf.get_u8() as usize;
1695 if len == 0 {
1696 SqlValue::Null
1697 } else if buf.remaining() < len {
1698 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1699 } else {
1700 match len {
1701 4 => {
1702 let days = buf.get_u16_le() as i64;
1704 let minutes = buf.get_u16_le() as u32;
1705 #[cfg(feature = "chrono")]
1706 {
1707 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1708 let date = base + chrono::Duration::days(days);
1709 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1710 minutes * 60,
1711 0,
1712 )
1713 .unwrap();
1714 SqlValue::DateTime(date.and_time(time))
1715 }
1716 #[cfg(not(feature = "chrono"))]
1717 {
1718 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1719 }
1720 }
1721 8 => {
1722 let days = buf.get_i32_le() as i64;
1724 let time_300ths = buf.get_u32_le() as u64;
1725 #[cfg(feature = "chrono")]
1726 {
1727 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1728 let date = base + chrono::Duration::days(days);
1729 let total_ms = (time_300ths * 1000) / 300;
1731 let secs = (total_ms / 1000) as u32;
1732 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1733 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1734 secs, nanos,
1735 )
1736 .unwrap();
1737 SqlValue::DateTime(date.and_time(time))
1738 }
1739 #[cfg(not(feature = "chrono"))]
1740 {
1741 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1742 }
1743 }
1744 _ => {
1745 return Err(Error::Protocol(format!(
1746 "invalid DateTimeN length: {len}"
1747 )));
1748 }
1749 }
1750 }
1751 }
1752
1753 TypeId::DateTime => {
1755 if buf.remaining() < 8 {
1756 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1757 }
1758 let days = buf.get_i32_le() as i64;
1759 let time_300ths = buf.get_u32_le() as u64;
1760 #[cfg(feature = "chrono")]
1761 {
1762 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1763 let date = base + chrono::Duration::days(days);
1764 let total_ms = (time_300ths * 1000) / 300;
1765 let secs = (total_ms / 1000) as u32;
1766 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1767 let time =
1768 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1769 SqlValue::DateTime(date.and_time(time))
1770 }
1771 #[cfg(not(feature = "chrono"))]
1772 {
1773 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1774 }
1775 }
1776
1777 TypeId::DateTime4 => {
1779 if buf.remaining() < 4 {
1780 return Err(Error::Protocol(
1781 "unexpected EOF reading SMALLDATETIME".into(),
1782 ));
1783 }
1784 let days = buf.get_u16_le() as i64;
1785 let minutes = buf.get_u16_le() as u32;
1786 #[cfg(feature = "chrono")]
1787 {
1788 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1789 let date = base + chrono::Duration::days(days);
1790 let time =
1791 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1792 .unwrap();
1793 SqlValue::DateTime(date.and_time(time))
1794 }
1795 #[cfg(not(feature = "chrono"))]
1796 {
1797 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1798 }
1799 }
1800
1801 TypeId::Date => {
1803 if buf.remaining() < 1 {
1804 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1805 }
1806 let len = buf.get_u8() as usize;
1807 if len == 0 {
1808 SqlValue::Null
1809 } else if len != 3 {
1810 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1811 } else if buf.remaining() < 3 {
1812 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1813 } else {
1814 let days = buf.get_u8() as u32
1816 | ((buf.get_u8() as u32) << 8)
1817 | ((buf.get_u8() as u32) << 16);
1818 #[cfg(feature = "chrono")]
1819 {
1820 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1821 let date = base + chrono::Duration::days(days as i64);
1822 SqlValue::Date(date)
1823 }
1824 #[cfg(not(feature = "chrono"))]
1825 {
1826 SqlValue::String(format!("DATE({days})"))
1827 }
1828 }
1829 }
1830
1831 TypeId::Time => {
1833 if buf.remaining() < 1 {
1834 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1835 }
1836 let len = buf.get_u8() as usize;
1837 if len == 0 {
1838 SqlValue::Null
1839 } else if buf.remaining() < len {
1840 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1841 } else {
1842 let mut time_bytes = [0u8; 8];
1843 for byte in time_bytes.iter_mut().take(len) {
1844 *byte = buf.get_u8();
1845 }
1846 let intervals = u64::from_le_bytes(time_bytes);
1847 #[cfg(feature = "chrono")]
1848 {
1849 let scale = col.type_info.scale.unwrap_or(7);
1850 let time = Self::intervals_to_time(intervals, scale);
1851 SqlValue::Time(time)
1852 }
1853 #[cfg(not(feature = "chrono"))]
1854 {
1855 SqlValue::String(format!("TIME({intervals})"))
1856 }
1857 }
1858 }
1859
1860 TypeId::DateTime2 => {
1862 if buf.remaining() < 1 {
1863 return Err(Error::Protocol(
1864 "unexpected EOF reading DATETIME2 length".into(),
1865 ));
1866 }
1867 let len = buf.get_u8() as usize;
1868 if len == 0 {
1869 SqlValue::Null
1870 } else if buf.remaining() < len {
1871 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1872 } else {
1873 let scale = col.type_info.scale.unwrap_or(7);
1874 let time_len = Self::time_bytes_for_scale(scale);
1875
1876 let mut time_bytes = [0u8; 8];
1878 for byte in time_bytes.iter_mut().take(time_len) {
1879 *byte = buf.get_u8();
1880 }
1881 let intervals = u64::from_le_bytes(time_bytes);
1882
1883 let days = buf.get_u8() as u32
1885 | ((buf.get_u8() as u32) << 8)
1886 | ((buf.get_u8() as u32) << 16);
1887
1888 #[cfg(feature = "chrono")]
1889 {
1890 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1891 let date = base + chrono::Duration::days(days as i64);
1892 let time = Self::intervals_to_time(intervals, scale);
1893 SqlValue::DateTime(date.and_time(time))
1894 }
1895 #[cfg(not(feature = "chrono"))]
1896 {
1897 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1898 }
1899 }
1900 }
1901
1902 TypeId::DateTimeOffset => {
1904 if buf.remaining() < 1 {
1905 return Err(Error::Protocol(
1906 "unexpected EOF reading DATETIMEOFFSET length".into(),
1907 ));
1908 }
1909 let len = buf.get_u8() as usize;
1910 if len == 0 {
1911 SqlValue::Null
1912 } else if buf.remaining() < len {
1913 return Err(Error::Protocol(
1914 "unexpected EOF reading DATETIMEOFFSET".into(),
1915 ));
1916 } else {
1917 let scale = col.type_info.scale.unwrap_or(7);
1918 let time_len = Self::time_bytes_for_scale(scale);
1919
1920 let mut time_bytes = [0u8; 8];
1922 for byte in time_bytes.iter_mut().take(time_len) {
1923 *byte = buf.get_u8();
1924 }
1925 let intervals = u64::from_le_bytes(time_bytes);
1926
1927 let days = buf.get_u8() as u32
1929 | ((buf.get_u8() as u32) << 8)
1930 | ((buf.get_u8() as u32) << 16);
1931
1932 let offset_minutes = buf.get_i16_le();
1934
1935 #[cfg(feature = "chrono")]
1936 {
1937 use chrono::TimeZone;
1938 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1939 let date = base + chrono::Duration::days(days as i64);
1940 let time = Self::intervals_to_time(intervals, scale);
1941 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1942 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1943 let datetime = offset
1944 .from_local_datetime(&date.and_time(time))
1945 .single()
1946 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1947 SqlValue::DateTimeOffset(datetime)
1948 }
1949 #[cfg(not(feature = "chrono"))]
1950 {
1951 SqlValue::String(format!(
1952 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
1953 ))
1954 }
1955 }
1956 }
1957
1958 TypeId::Text => Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?,
1960
1961 TypeId::Char | TypeId::VarChar => {
1963 if buf.remaining() < 1 {
1964 return Err(Error::Protocol(
1965 "unexpected EOF reading legacy varchar length".into(),
1966 ));
1967 }
1968 let len = buf.get_u8();
1969 if len == 0xFF {
1970 SqlValue::Null
1971 } else if len == 0 {
1972 SqlValue::String(String::new())
1973 } else if buf.remaining() < len as usize {
1974 return Err(Error::Protocol(
1975 "unexpected EOF reading legacy varchar data".into(),
1976 ));
1977 } else {
1978 let data = &buf[..len as usize];
1979 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
1981 buf.advance(len as usize);
1982 SqlValue::String(s)
1983 }
1984 }
1985
1986 TypeId::BigVarChar | TypeId::BigChar => {
1988 if col.type_info.max_length == Some(0xFFFF) {
1990 Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?
1992 } else {
1993 if buf.remaining() < 2 {
1995 return Err(Error::Protocol(
1996 "unexpected EOF reading varchar length".into(),
1997 ));
1998 }
1999 let len = buf.get_u16_le();
2000 if len == 0xFFFF {
2001 SqlValue::Null
2002 } else if buf.remaining() < len as usize {
2003 return Err(Error::Protocol(
2004 "unexpected EOF reading varchar data".into(),
2005 ));
2006 } else {
2007 let data = &buf[..len as usize];
2008 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
2010 buf.advance(len as usize);
2011 SqlValue::String(s)
2012 }
2013 }
2014 }
2015
2016 TypeId::NText => Self::parse_plp_nvarchar(buf)?,
2018
2019 TypeId::NVarChar | TypeId::NChar => {
2021 if col.type_info.max_length == Some(0xFFFF) {
2023 Self::parse_plp_nvarchar(buf)?
2025 } else {
2026 if buf.remaining() < 2 {
2028 return Err(Error::Protocol(
2029 "unexpected EOF reading nvarchar length".into(),
2030 ));
2031 }
2032 let len = buf.get_u16_le();
2033 if len == 0xFFFF {
2034 SqlValue::Null
2035 } else if buf.remaining() < len as usize {
2036 return Err(Error::Protocol(
2037 "unexpected EOF reading nvarchar data".into(),
2038 ));
2039 } else {
2040 let data = &buf[..len as usize];
2041 let utf16: Vec<u16> = data
2043 .chunks_exact(2)
2044 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2045 .collect();
2046 let s = String::from_utf16(&utf16)
2047 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
2048 buf.advance(len as usize);
2049 SqlValue::String(s)
2050 }
2051 }
2052 }
2053
2054 TypeId::Image => Self::parse_plp_varbinary(buf)?,
2056
2057 TypeId::Binary | TypeId::VarBinary => {
2059 if buf.remaining() < 1 {
2060 return Err(Error::Protocol(
2061 "unexpected EOF reading legacy varbinary length".into(),
2062 ));
2063 }
2064 let len = buf.get_u8();
2065 if len == 0xFF {
2066 SqlValue::Null
2067 } else if len == 0 {
2068 SqlValue::Binary(bytes::Bytes::new())
2069 } else if buf.remaining() < len as usize {
2070 return Err(Error::Protocol(
2071 "unexpected EOF reading legacy varbinary data".into(),
2072 ));
2073 } else {
2074 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2075 buf.advance(len as usize);
2076 SqlValue::Binary(data)
2077 }
2078 }
2079
2080 TypeId::BigVarBinary | TypeId::BigBinary => {
2082 if col.type_info.max_length == Some(0xFFFF) {
2084 Self::parse_plp_varbinary(buf)?
2086 } else {
2087 if buf.remaining() < 2 {
2088 return Err(Error::Protocol(
2089 "unexpected EOF reading varbinary length".into(),
2090 ));
2091 }
2092 let len = buf.get_u16_le();
2093 if len == 0xFFFF {
2094 SqlValue::Null
2095 } else if buf.remaining() < len as usize {
2096 return Err(Error::Protocol(
2097 "unexpected EOF reading varbinary data".into(),
2098 ));
2099 } else {
2100 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2101 buf.advance(len as usize);
2102 SqlValue::Binary(data)
2103 }
2104 }
2105 }
2106
2107 TypeId::Xml => {
2109 match Self::parse_plp_nvarchar(buf)? {
2111 SqlValue::Null => SqlValue::Null,
2112 SqlValue::String(s) => SqlValue::Xml(s),
2113 _ => {
2114 return Err(Error::Protocol(
2115 "unexpected value type when parsing XML".into(),
2116 ));
2117 }
2118 }
2119 }
2120
2121 TypeId::Guid => {
2123 if buf.remaining() < 1 {
2124 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
2125 }
2126 let len = buf.get_u8();
2127 if len == 0 {
2128 SqlValue::Null
2129 } else if len != 16 {
2130 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
2131 } else if buf.remaining() < 16 {
2132 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
2133 } else {
2134 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
2136 buf.advance(16);
2137 SqlValue::Binary(data)
2138 }
2139 }
2140
2141 TypeId::Variant => Self::parse_sql_variant(buf)?,
2143
2144 TypeId::Udt => Self::parse_plp_varbinary(buf)?,
2146
2147 _ => {
2149 if buf.remaining() < 2 {
2151 return Err(Error::Protocol(format!(
2152 "unexpected EOF reading {:?}",
2153 col.type_id
2154 )));
2155 }
2156 let len = buf.get_u16_le();
2157 if len == 0xFFFF {
2158 SqlValue::Null
2159 } else if buf.remaining() < len as usize {
2160 return Err(Error::Protocol(format!(
2161 "unexpected EOF reading {:?} data",
2162 col.type_id
2163 )));
2164 } else {
2165 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2166 buf.advance(len as usize);
2167 SqlValue::Binary(data)
2168 }
2169 }
2170 };
2171
2172 Ok(value)
2173 }
2174
2175 fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2181 use bytes::Buf;
2182 use mssql_types::SqlValue;
2183
2184 if buf.remaining() < 8 {
2185 return Err(Error::Protocol(
2186 "unexpected EOF reading PLP total length".into(),
2187 ));
2188 }
2189
2190 let total_len = buf.get_u64_le();
2191 if total_len == 0xFFFFFFFFFFFFFFFF {
2192 return Ok(SqlValue::Null);
2193 }
2194
2195 let mut all_data = Vec::new();
2197 loop {
2198 if buf.remaining() < 4 {
2199 return Err(Error::Protocol(
2200 "unexpected EOF reading PLP chunk length".into(),
2201 ));
2202 }
2203 let chunk_len = buf.get_u32_le() as usize;
2204 if chunk_len == 0 {
2205 break; }
2207 if buf.remaining() < chunk_len {
2208 return Err(Error::Protocol(
2209 "unexpected EOF reading PLP chunk data".into(),
2210 ));
2211 }
2212 all_data.extend_from_slice(&buf[..chunk_len]);
2213 buf.advance(chunk_len);
2214 }
2215
2216 let utf16: Vec<u16> = all_data
2218 .chunks_exact(2)
2219 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2220 .collect();
2221 let s = String::from_utf16(&utf16)
2222 .map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
2223 Ok(SqlValue::String(s))
2224 }
2225
2226 #[allow(unused_variables)]
2232 fn decode_varchar_string(data: &[u8], collation: Option<&Collation>) -> String {
2233 if let Ok(s) = std::str::from_utf8(data) {
2235 return s.to_owned();
2236 }
2237
2238 #[cfg(feature = "encoding")]
2240 if let Some(coll) = collation {
2241 if let Some(encoding) = coll.encoding() {
2242 let (decoded, _, had_errors) = encoding.decode(data);
2243 if !had_errors {
2244 return decoded.into_owned();
2245 }
2246 }
2247 }
2248
2249 String::from_utf8_lossy(data).into_owned()
2251 }
2252
2253 fn parse_plp_varchar(
2255 buf: &mut &[u8],
2256 collation: Option<&Collation>,
2257 ) -> Result<mssql_types::SqlValue> {
2258 use bytes::Buf;
2259 use mssql_types::SqlValue;
2260
2261 if buf.remaining() < 8 {
2262 return Err(Error::Protocol(
2263 "unexpected EOF reading PLP total length".into(),
2264 ));
2265 }
2266
2267 let total_len = buf.get_u64_le();
2268 if total_len == 0xFFFFFFFFFFFFFFFF {
2269 return Ok(SqlValue::Null);
2270 }
2271
2272 let mut all_data = Vec::new();
2274 loop {
2275 if buf.remaining() < 4 {
2276 return Err(Error::Protocol(
2277 "unexpected EOF reading PLP chunk length".into(),
2278 ));
2279 }
2280 let chunk_len = buf.get_u32_le() as usize;
2281 if chunk_len == 0 {
2282 break; }
2284 if buf.remaining() < chunk_len {
2285 return Err(Error::Protocol(
2286 "unexpected EOF reading PLP chunk data".into(),
2287 ));
2288 }
2289 all_data.extend_from_slice(&buf[..chunk_len]);
2290 buf.advance(chunk_len);
2291 }
2292
2293 let s = Self::decode_varchar_string(&all_data, collation);
2295 Ok(SqlValue::String(s))
2296 }
2297
2298 fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2300 use bytes::Buf;
2301 use mssql_types::SqlValue;
2302
2303 if buf.remaining() < 8 {
2304 return Err(Error::Protocol(
2305 "unexpected EOF reading PLP total length".into(),
2306 ));
2307 }
2308
2309 let total_len = buf.get_u64_le();
2310 if total_len == 0xFFFFFFFFFFFFFFFF {
2311 return Ok(SqlValue::Null);
2312 }
2313
2314 let mut all_data = Vec::new();
2316 loop {
2317 if buf.remaining() < 4 {
2318 return Err(Error::Protocol(
2319 "unexpected EOF reading PLP chunk length".into(),
2320 ));
2321 }
2322 let chunk_len = buf.get_u32_le() as usize;
2323 if chunk_len == 0 {
2324 break; }
2326 if buf.remaining() < chunk_len {
2327 return Err(Error::Protocol(
2328 "unexpected EOF reading PLP chunk data".into(),
2329 ));
2330 }
2331 all_data.extend_from_slice(&buf[..chunk_len]);
2332 buf.advance(chunk_len);
2333 }
2334
2335 Ok(SqlValue::Binary(bytes::Bytes::from(all_data)))
2336 }
2337
2338 fn parse_sql_variant(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2347 use bytes::Buf;
2348 use mssql_types::SqlValue;
2349
2350 if buf.remaining() < 4 {
2352 return Err(Error::Protocol(
2353 "unexpected EOF reading SQL_VARIANT length".into(),
2354 ));
2355 }
2356 let total_len = buf.get_u32_le() as usize;
2357
2358 if total_len == 0 {
2359 return Ok(SqlValue::Null);
2360 }
2361
2362 if buf.remaining() < total_len {
2363 return Err(Error::Protocol(
2364 "unexpected EOF reading SQL_VARIANT data".into(),
2365 ));
2366 }
2367
2368 if total_len < 2 {
2370 return Err(Error::Protocol(
2371 "SQL_VARIANT too short for type info".into(),
2372 ));
2373 }
2374
2375 let base_type = buf.get_u8();
2376 let prop_count = buf.get_u8() as usize;
2377
2378 if buf.remaining() < prop_count {
2379 return Err(Error::Protocol(
2380 "unexpected EOF reading SQL_VARIANT properties".into(),
2381 ));
2382 }
2383
2384 let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
2386
2387 match base_type {
2390 0x30 => {
2392 buf.advance(prop_count);
2394 if data_len < 1 {
2395 return Ok(SqlValue::Null);
2396 }
2397 let v = buf.get_u8();
2398 Ok(SqlValue::TinyInt(v))
2399 }
2400 0x32 => {
2401 buf.advance(prop_count);
2403 if data_len < 1 {
2404 return Ok(SqlValue::Null);
2405 }
2406 let v = buf.get_u8();
2407 Ok(SqlValue::Bool(v != 0))
2408 }
2409 0x34 => {
2410 buf.advance(prop_count);
2412 if data_len < 2 {
2413 return Ok(SqlValue::Null);
2414 }
2415 let v = buf.get_i16_le();
2416 Ok(SqlValue::SmallInt(v))
2417 }
2418 0x38 => {
2419 buf.advance(prop_count);
2421 if data_len < 4 {
2422 return Ok(SqlValue::Null);
2423 }
2424 let v = buf.get_i32_le();
2425 Ok(SqlValue::Int(v))
2426 }
2427 0x7F => {
2428 buf.advance(prop_count);
2430 if data_len < 8 {
2431 return Ok(SqlValue::Null);
2432 }
2433 let v = buf.get_i64_le();
2434 Ok(SqlValue::BigInt(v))
2435 }
2436 0x6D => {
2437 let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2439 buf.advance(prop_count.saturating_sub(1));
2440
2441 if float_len == 4 && data_len >= 4 {
2442 let v = buf.get_f32_le();
2443 Ok(SqlValue::Float(v))
2444 } else if data_len >= 8 {
2445 let v = buf.get_f64_le();
2446 Ok(SqlValue::Double(v))
2447 } else {
2448 Ok(SqlValue::Null)
2449 }
2450 }
2451 0x6E => {
2452 let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2454 buf.advance(prop_count.saturating_sub(1));
2455
2456 if money_len == 4 && data_len >= 4 {
2457 let raw = buf.get_i32_le();
2458 let value = raw as f64 / 10000.0;
2459 Ok(SqlValue::Double(value))
2460 } else if data_len >= 8 {
2461 let high = buf.get_i32_le() as i64;
2462 let low = buf.get_u32_le() as i64;
2463 let raw = (high << 32) | low;
2464 let value = raw as f64 / 10000.0;
2465 Ok(SqlValue::Double(value))
2466 } else {
2467 Ok(SqlValue::Null)
2468 }
2469 }
2470 0x6F => {
2471 #[cfg(feature = "chrono")]
2473 let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2474 #[cfg(not(feature = "chrono"))]
2475 if prop_count >= 1 {
2476 buf.get_u8();
2477 }
2478 buf.advance(prop_count.saturating_sub(1));
2479
2480 #[cfg(feature = "chrono")]
2481 {
2482 use chrono::NaiveDate;
2483 if dt_len == 4 && data_len >= 4 {
2484 let days = buf.get_u16_le() as i64;
2486 let mins = buf.get_u16_le() as u32;
2487 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2488 .unwrap()
2489 .and_hms_opt(0, 0, 0)
2490 .unwrap();
2491 let dt = base
2492 + chrono::Duration::days(days)
2493 + chrono::Duration::minutes(mins as i64);
2494 Ok(SqlValue::DateTime(dt))
2495 } else if data_len >= 8 {
2496 let days = buf.get_i32_le() as i64;
2498 let ticks = buf.get_u32_le() as i64;
2499 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2500 .unwrap()
2501 .and_hms_opt(0, 0, 0)
2502 .unwrap();
2503 let millis = (ticks * 10) / 3;
2504 let dt = base
2505 + chrono::Duration::days(days)
2506 + chrono::Duration::milliseconds(millis);
2507 Ok(SqlValue::DateTime(dt))
2508 } else {
2509 Ok(SqlValue::Null)
2510 }
2511 }
2512 #[cfg(not(feature = "chrono"))]
2513 {
2514 buf.advance(data_len);
2515 Ok(SqlValue::Null)
2516 }
2517 }
2518 0x6A | 0x6C => {
2519 let _precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
2521 let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
2522 buf.advance(prop_count.saturating_sub(2));
2523
2524 if data_len < 1 {
2525 return Ok(SqlValue::Null);
2526 }
2527
2528 let sign = buf.get_u8();
2529 let mantissa_len = data_len - 1;
2530
2531 if mantissa_len > 16 {
2532 buf.advance(mantissa_len);
2534 return Ok(SqlValue::Null);
2535 }
2536
2537 let mut mantissa_bytes = [0u8; 16];
2538 for i in 0..mantissa_len.min(16) {
2539 mantissa_bytes[i] = buf.get_u8();
2540 }
2541 let mantissa = u128::from_le_bytes(mantissa_bytes);
2542
2543 #[cfg(feature = "decimal")]
2544 {
2545 use rust_decimal::Decimal;
2546 if scale > 28 {
2547 let divisor = 10f64.powi(scale as i32);
2549 let value = (mantissa as f64) / divisor;
2550 let value = if sign == 0 { -value } else { value };
2551 Ok(SqlValue::Double(value))
2552 } else {
2553 let mut decimal =
2554 Decimal::from_i128_with_scale(mantissa as i128, scale as u32);
2555 if sign == 0 {
2556 decimal.set_sign_negative(true);
2557 }
2558 Ok(SqlValue::Decimal(decimal))
2559 }
2560 }
2561 #[cfg(not(feature = "decimal"))]
2562 {
2563 let divisor = 10f64.powi(scale as i32);
2564 let value = (mantissa as f64) / divisor;
2565 let value = if sign == 0 { -value } else { value };
2566 Ok(SqlValue::Double(value))
2567 }
2568 }
2569 0x24 => {
2570 buf.advance(prop_count);
2572 if data_len < 16 {
2573 return Ok(SqlValue::Null);
2574 }
2575 let mut guid_bytes = [0u8; 16];
2576 for byte in &mut guid_bytes {
2577 *byte = buf.get_u8();
2578 }
2579 Ok(SqlValue::Binary(bytes::Bytes::copy_from_slice(&guid_bytes)))
2580 }
2581 0x28 => {
2582 buf.advance(prop_count);
2584 #[cfg(feature = "chrono")]
2585 {
2586 if data_len < 3 {
2587 return Ok(SqlValue::Null);
2588 }
2589 let mut date_bytes = [0u8; 4];
2590 date_bytes[0] = buf.get_u8();
2591 date_bytes[1] = buf.get_u8();
2592 date_bytes[2] = buf.get_u8();
2593 let days = u32::from_le_bytes(date_bytes);
2594 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2595 let date = base + chrono::Duration::days(days as i64);
2596 Ok(SqlValue::Date(date))
2597 }
2598 #[cfg(not(feature = "chrono"))]
2599 {
2600 buf.advance(data_len);
2601 Ok(SqlValue::Null)
2602 }
2603 }
2604 0xA7 | 0x2F | 0x27 => {
2605 let collation = if prop_count >= 5 && buf.remaining() >= 5 {
2608 let lcid = buf.get_u32_le();
2609 let sort_id = buf.get_u8();
2610 buf.advance(prop_count.saturating_sub(5)); Some(Collation { lcid, sort_id })
2612 } else {
2613 buf.advance(prop_count);
2614 None
2615 };
2616 if data_len == 0 {
2617 return Ok(SqlValue::String(String::new()));
2618 }
2619 let data = &buf[..data_len];
2620 let s = Self::decode_varchar_string(data, collation.as_ref());
2622 buf.advance(data_len);
2623 Ok(SqlValue::String(s))
2624 }
2625 0xE7 | 0xEF => {
2626 buf.advance(prop_count);
2628 if data_len == 0 {
2629 return Ok(SqlValue::String(String::new()));
2630 }
2631 let utf16: Vec<u16> = buf[..data_len]
2633 .chunks_exact(2)
2634 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2635 .collect();
2636 buf.advance(data_len);
2637 let s = String::from_utf16(&utf16).map_err(|_| {
2638 Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into())
2639 })?;
2640 Ok(SqlValue::String(s))
2641 }
2642 0xA5 | 0x2D | 0x25 => {
2643 buf.advance(prop_count);
2645 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2646 buf.advance(data_len);
2647 Ok(SqlValue::Binary(data))
2648 }
2649 _ => {
2650 buf.advance(prop_count);
2652 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2653 buf.advance(data_len);
2654 Ok(SqlValue::Binary(data))
2655 }
2656 }
2657 }
2658
2659 fn time_bytes_for_scale(scale: u8) -> usize {
2661 match scale {
2662 0..=2 => 3,
2663 3..=4 => 4,
2664 5..=7 => 5,
2665 _ => 5, }
2667 }
2668
2669 #[cfg(feature = "chrono")]
2671 fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
2672 let nanos = match scale {
2682 0 => intervals * 1_000_000_000,
2683 1 => intervals * 100_000_000,
2684 2 => intervals * 10_000_000,
2685 3 => intervals * 1_000_000,
2686 4 => intervals * 100_000,
2687 5 => intervals * 10_000,
2688 6 => intervals * 1_000,
2689 7 => intervals * 100,
2690 _ => intervals * 100,
2691 };
2692
2693 let secs = (nanos / 1_000_000_000) as u32;
2694 let nano_part = (nanos % 1_000_000_000) as u32;
2695
2696 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
2697 .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
2698 }
2699
2700 async fn read_execute_result(&mut self) -> Result<u64> {
2702 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2703
2704 let message = match connection {
2705 ConnectionHandle::Tls(conn) => conn
2706 .read_message()
2707 .await
2708 .map_err(|e| Error::Protocol(e.to_string()))?,
2709 ConnectionHandle::TlsPrelogin(conn) => conn
2710 .read_message()
2711 .await
2712 .map_err(|e| Error::Protocol(e.to_string()))?,
2713 ConnectionHandle::Plain(conn) => conn
2714 .read_message()
2715 .await
2716 .map_err(|e| Error::Protocol(e.to_string()))?,
2717 }
2718 .ok_or(Error::ConnectionClosed)?;
2719
2720 let mut parser = TokenParser::new(message.payload);
2721 let mut rows_affected = 0u64;
2722 let mut current_metadata: Option<ColMetaData> = None;
2723
2724 loop {
2725 let token = parser
2727 .next_token_with_metadata(current_metadata.as_ref())
2728 .map_err(|e| Error::Protocol(e.to_string()))?;
2729
2730 let Some(token) = token else {
2731 break;
2732 };
2733
2734 match token {
2735 Token::ColMetaData(meta) => {
2736 current_metadata = Some(meta);
2738 }
2739 Token::Row(_) | Token::NbcRow(_) => {
2740 }
2743 Token::Done(done) => {
2744 if done.status.error {
2745 return Err(Error::Query("execution failed".to_string()));
2746 }
2747 if done.status.count {
2748 rows_affected += done.row_count;
2750 }
2751 if !done.status.more {
2754 break;
2755 }
2756 }
2757 Token::DoneProc(done) => {
2758 if done.status.count {
2759 rows_affected += done.row_count;
2760 }
2761 }
2762 Token::DoneInProc(done) => {
2763 if done.status.count {
2764 rows_affected += done.row_count;
2765 }
2766 }
2767 Token::Error(err) => {
2768 return Err(Error::Server {
2769 number: err.number,
2770 state: err.state,
2771 class: err.class,
2772 message: err.message.clone(),
2773 server: if err.server.is_empty() {
2774 None
2775 } else {
2776 Some(err.server.clone())
2777 },
2778 procedure: if err.procedure.is_empty() {
2779 None
2780 } else {
2781 Some(err.procedure.clone())
2782 },
2783 line: err.line as u32,
2784 });
2785 }
2786 Token::Info(info) => {
2787 tracing::info!(
2788 number = info.number,
2789 message = %info.message,
2790 "server info message"
2791 );
2792 }
2793 Token::EnvChange(env) => {
2794 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
2798 }
2799 _ => {}
2800 }
2801 }
2802
2803 Ok(rows_affected)
2804 }
2805
2806 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
2812 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2813
2814 let message = match connection {
2815 ConnectionHandle::Tls(conn) => conn
2816 .read_message()
2817 .await
2818 .map_err(|e| Error::Protocol(e.to_string()))?,
2819 ConnectionHandle::TlsPrelogin(conn) => conn
2820 .read_message()
2821 .await
2822 .map_err(|e| Error::Protocol(e.to_string()))?,
2823 ConnectionHandle::Plain(conn) => conn
2824 .read_message()
2825 .await
2826 .map_err(|e| Error::Protocol(e.to_string()))?,
2827 }
2828 .ok_or(Error::ConnectionClosed)?;
2829
2830 let mut parser = TokenParser::new(message.payload);
2831 let mut transaction_descriptor: u64 = 0;
2832
2833 loop {
2834 let token = parser
2835 .next_token()
2836 .map_err(|e| Error::Protocol(e.to_string()))?;
2837
2838 let Some(token) = token else {
2839 break;
2840 };
2841
2842 match token {
2843 Token::EnvChange(env) => {
2844 if env.env_type == EnvChangeType::BeginTransaction {
2845 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
2848 {
2849 if data.len() >= 8 {
2850 transaction_descriptor = u64::from_le_bytes([
2851 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
2852 data[7],
2853 ]);
2854 tracing::debug!(
2855 transaction_descriptor =
2856 format!("0x{:016X}", transaction_descriptor),
2857 "transaction begun"
2858 );
2859 }
2860 }
2861 }
2862 }
2863 Token::Done(done) => {
2864 if done.status.error {
2865 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
2866 }
2867 break;
2868 }
2869 Token::Error(err) => {
2870 return Err(Error::Server {
2871 number: err.number,
2872 state: err.state,
2873 class: err.class,
2874 message: err.message.clone(),
2875 server: if err.server.is_empty() {
2876 None
2877 } else {
2878 Some(err.server.clone())
2879 },
2880 procedure: if err.procedure.is_empty() {
2881 None
2882 } else {
2883 Some(err.procedure.clone())
2884 },
2885 line: err.line as u32,
2886 });
2887 }
2888 Token::Info(info) => {
2889 tracing::info!(
2890 number = info.number,
2891 message = %info.message,
2892 "server info message"
2893 );
2894 }
2895 _ => {}
2896 }
2897 }
2898
2899 Ok(transaction_descriptor)
2900 }
2901}
2902
2903impl Client<Ready> {
2904 pub async fn query<'a>(
2929 &'a mut self,
2930 sql: &str,
2931 params: &[&(dyn crate::ToSql + Sync)],
2932 ) -> Result<QueryStream<'a>> {
2933 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
2934
2935 #[cfg(feature = "otel")]
2936 let instrumentation = self.instrumentation.clone();
2937 #[cfg(feature = "otel")]
2938 let mut span = instrumentation.query_span(sql);
2939
2940 let result = async {
2941 if params.is_empty() {
2942 self.send_sql_batch(sql).await?;
2944 } else {
2945 let rpc_params = Self::convert_params(params)?;
2947 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2948 self.send_rpc(&rpc).await?;
2949 }
2950
2951 self.read_query_response().await
2953 }
2954 .await;
2955
2956 #[cfg(feature = "otel")]
2957 match &result {
2958 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2959 Err(e) => InstrumentationContext::record_error(&mut span, e),
2960 }
2961
2962 #[cfg(feature = "otel")]
2964 drop(span);
2965
2966 let (columns, rows) = result?;
2967 Ok(QueryStream::new(columns, rows))
2968 }
2969
2970 pub async fn query_with_timeout<'a>(
2997 &'a mut self,
2998 sql: &str,
2999 params: &[&(dyn crate::ToSql + Sync)],
3000 timeout_duration: std::time::Duration,
3001 ) -> Result<QueryStream<'a>> {
3002 timeout(timeout_duration, self.query(sql, params))
3003 .await
3004 .map_err(|_| Error::CommandTimeout)?
3005 }
3006
3007 pub async fn query_multiple<'a>(
3034 &'a mut self,
3035 sql: &str,
3036 params: &[&(dyn crate::ToSql + Sync)],
3037 ) -> Result<MultiResultStream<'a>> {
3038 tracing::debug!(
3039 sql = sql,
3040 params_count = params.len(),
3041 "executing multi-result query"
3042 );
3043
3044 if params.is_empty() {
3045 self.send_sql_batch(sql).await?;
3047 } else {
3048 let rpc_params = Self::convert_params(params)?;
3050 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3051 self.send_rpc(&rpc).await?;
3052 }
3053
3054 let result_sets = self.read_multi_result_response().await?;
3056 Ok(MultiResultStream::new(result_sets))
3057 }
3058
3059 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
3061 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
3062
3063 let message = match connection {
3064 ConnectionHandle::Tls(conn) => conn
3065 .read_message()
3066 .await
3067 .map_err(|e| Error::Protocol(e.to_string()))?,
3068 ConnectionHandle::TlsPrelogin(conn) => conn
3069 .read_message()
3070 .await
3071 .map_err(|e| Error::Protocol(e.to_string()))?,
3072 ConnectionHandle::Plain(conn) => conn
3073 .read_message()
3074 .await
3075 .map_err(|e| Error::Protocol(e.to_string()))?,
3076 }
3077 .ok_or(Error::ConnectionClosed)?;
3078
3079 let mut parser = TokenParser::new(message.payload);
3080 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
3081 let mut current_columns: Vec<crate::row::Column> = Vec::new();
3082 let mut current_rows: Vec<crate::row::Row> = Vec::new();
3083 let mut protocol_metadata: Option<ColMetaData> = None;
3084
3085 loop {
3086 let token = parser
3087 .next_token_with_metadata(protocol_metadata.as_ref())
3088 .map_err(|e| Error::Protocol(e.to_string()))?;
3089
3090 let Some(token) = token else {
3091 break;
3092 };
3093
3094 match token {
3095 Token::ColMetaData(meta) => {
3096 if !current_columns.is_empty() {
3098 result_sets.push(crate::stream::ResultSet::new(
3099 std::mem::take(&mut current_columns),
3100 std::mem::take(&mut current_rows),
3101 ));
3102 }
3103
3104 current_columns = meta
3106 .columns
3107 .iter()
3108 .enumerate()
3109 .map(|(i, col)| {
3110 let type_name = format!("{:?}", col.type_id);
3111 let mut column = crate::row::Column::new(&col.name, i, type_name)
3112 .with_nullable(col.flags & 0x01 != 0);
3113
3114 if let Some(max_len) = col.type_info.max_length {
3115 column = column.with_max_length(max_len);
3116 }
3117 if let (Some(prec), Some(scale)) =
3118 (col.type_info.precision, col.type_info.scale)
3119 {
3120 column = column.with_precision_scale(prec, scale);
3121 }
3122 if let Some(collation) = col.type_info.collation {
3125 column = column.with_collation(collation);
3126 }
3127 column
3128 })
3129 .collect();
3130
3131 tracing::debug!(
3132 columns = current_columns.len(),
3133 result_set = result_sets.len(),
3134 "received column metadata for result set"
3135 );
3136 protocol_metadata = Some(meta);
3137 }
3138 Token::Row(raw_row) => {
3139 if let Some(ref meta) = protocol_metadata {
3140 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
3141 current_rows.push(row);
3142 }
3143 }
3144 Token::NbcRow(nbc_row) => {
3145 if let Some(ref meta) = protocol_metadata {
3146 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
3147 current_rows.push(row);
3148 }
3149 }
3150 Token::Error(err) => {
3151 return Err(Error::Server {
3152 number: err.number,
3153 state: err.state,
3154 class: err.class,
3155 message: err.message.clone(),
3156 server: if err.server.is_empty() {
3157 None
3158 } else {
3159 Some(err.server.clone())
3160 },
3161 procedure: if err.procedure.is_empty() {
3162 None
3163 } else {
3164 Some(err.procedure.clone())
3165 },
3166 line: err.line as u32,
3167 });
3168 }
3169 Token::Done(done) => {
3170 if done.status.error {
3171 return Err(Error::Query("query failed".to_string()));
3172 }
3173
3174 if !current_columns.is_empty() {
3176 result_sets.push(crate::stream::ResultSet::new(
3177 std::mem::take(&mut current_columns),
3178 std::mem::take(&mut current_rows),
3179 ));
3180 protocol_metadata = None;
3181 }
3182
3183 if !done.status.more {
3185 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
3186 break;
3187 }
3188 }
3189 Token::DoneInProc(done) => {
3190 if done.status.error {
3191 return Err(Error::Query("query failed".to_string()));
3192 }
3193
3194 if !current_columns.is_empty() {
3196 result_sets.push(crate::stream::ResultSet::new(
3197 std::mem::take(&mut current_columns),
3198 std::mem::take(&mut current_rows),
3199 ));
3200 protocol_metadata = None;
3201 }
3202
3203 if !done.status.more {
3205 }
3207 }
3208 Token::DoneProc(done) => {
3209 if done.status.error {
3210 return Err(Error::Query("query failed".to_string()));
3211 }
3212 }
3214 Token::Info(info) => {
3215 tracing::debug!(
3216 number = info.number,
3217 message = %info.message,
3218 "server info message"
3219 );
3220 }
3221 _ => {}
3222 }
3223 }
3224
3225 if !current_columns.is_empty() {
3227 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
3228 }
3229
3230 Ok(result_sets)
3231 }
3232
3233 pub async fn execute(
3237 &mut self,
3238 sql: &str,
3239 params: &[&(dyn crate::ToSql + Sync)],
3240 ) -> Result<u64> {
3241 tracing::debug!(
3242 sql = sql,
3243 params_count = params.len(),
3244 "executing statement"
3245 );
3246
3247 #[cfg(feature = "otel")]
3248 let instrumentation = self.instrumentation.clone();
3249 #[cfg(feature = "otel")]
3250 let mut span = instrumentation.query_span(sql);
3251
3252 let result = async {
3253 if params.is_empty() {
3254 self.send_sql_batch(sql).await?;
3256 } else {
3257 let rpc_params = Self::convert_params(params)?;
3259 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3260 self.send_rpc(&rpc).await?;
3261 }
3262
3263 self.read_execute_result().await
3265 }
3266 .await;
3267
3268 #[cfg(feature = "otel")]
3269 match &result {
3270 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3271 Err(e) => InstrumentationContext::record_error(&mut span, e),
3272 }
3273
3274 #[cfg(feature = "otel")]
3276 drop(span);
3277
3278 result
3279 }
3280
3281 pub async fn execute_with_timeout(
3308 &mut self,
3309 sql: &str,
3310 params: &[&(dyn crate::ToSql + Sync)],
3311 timeout_duration: std::time::Duration,
3312 ) -> Result<u64> {
3313 timeout(timeout_duration, self.execute(sql, params))
3314 .await
3315 .map_err(|_| Error::CommandTimeout)?
3316 }
3317
3318 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
3325 tracing::debug!("beginning transaction");
3326
3327 #[cfg(feature = "otel")]
3328 let instrumentation = self.instrumentation.clone();
3329 #[cfg(feature = "otel")]
3330 let mut span = instrumentation.transaction_span("BEGIN");
3331
3332 let result = async {
3334 self.send_sql_batch("BEGIN TRANSACTION").await?;
3335 self.read_transaction_begin_result().await
3336 }
3337 .await;
3338
3339 #[cfg(feature = "otel")]
3340 match &result {
3341 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3342 Err(e) => InstrumentationContext::record_error(&mut span, e),
3343 }
3344
3345 #[cfg(feature = "otel")]
3347 drop(span);
3348
3349 let transaction_descriptor = result?;
3350
3351 Ok(Client {
3352 config: self.config,
3353 _state: PhantomData,
3354 connection: self.connection,
3355 server_version: self.server_version,
3356 current_database: self.current_database,
3357 statement_cache: self.statement_cache,
3358 transaction_descriptor, #[cfg(feature = "otel")]
3360 instrumentation: self.instrumentation,
3361 })
3362 }
3363
3364 pub async fn begin_transaction_with_isolation(
3379 mut self,
3380 isolation_level: crate::transaction::IsolationLevel,
3381 ) -> Result<Client<InTransaction>> {
3382 tracing::debug!(
3383 isolation_level = %isolation_level.name(),
3384 "beginning transaction with isolation level"
3385 );
3386
3387 #[cfg(feature = "otel")]
3388 let instrumentation = self.instrumentation.clone();
3389 #[cfg(feature = "otel")]
3390 let mut span = instrumentation.transaction_span("BEGIN");
3391
3392 let result = async {
3394 self.send_sql_batch(isolation_level.as_sql()).await?;
3395 self.read_execute_result().await?;
3396
3397 self.send_sql_batch("BEGIN TRANSACTION").await?;
3399 self.read_transaction_begin_result().await
3400 }
3401 .await;
3402
3403 #[cfg(feature = "otel")]
3404 match &result {
3405 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3406 Err(e) => InstrumentationContext::record_error(&mut span, e),
3407 }
3408
3409 #[cfg(feature = "otel")]
3410 drop(span);
3411
3412 let transaction_descriptor = result?;
3413
3414 Ok(Client {
3415 config: self.config,
3416 _state: PhantomData,
3417 connection: self.connection,
3418 server_version: self.server_version,
3419 current_database: self.current_database,
3420 statement_cache: self.statement_cache,
3421 transaction_descriptor,
3422 #[cfg(feature = "otel")]
3423 instrumentation: self.instrumentation,
3424 })
3425 }
3426
3427 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
3432 tracing::debug!(sql = sql, "executing simple query");
3433
3434 self.send_sql_batch(sql).await?;
3436
3437 let _ = self.read_execute_result().await?;
3439
3440 Ok(())
3441 }
3442
3443 pub async fn close(self) -> Result<()> {
3445 tracing::debug!("closing connection");
3446 Ok(())
3447 }
3448
3449 #[must_use]
3451 pub fn database(&self) -> Option<&str> {
3452 self.config.database.as_deref()
3453 }
3454
3455 #[must_use]
3457 pub fn host(&self) -> &str {
3458 &self.config.host
3459 }
3460
3461 #[must_use]
3463 pub fn port(&self) -> u16 {
3464 self.config.port
3465 }
3466
3467 #[must_use]
3486 pub fn is_in_transaction(&self) -> bool {
3487 self.transaction_descriptor != 0
3488 }
3489
3490 #[must_use]
3512 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3513 let connection = self
3514 .connection
3515 .as_ref()
3516 .expect("connection should be present");
3517 match connection {
3518 ConnectionHandle::Tls(conn) => {
3519 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3520 }
3521 ConnectionHandle::TlsPrelogin(conn) => {
3522 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3523 }
3524 ConnectionHandle::Plain(conn) => {
3525 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3526 }
3527 }
3528 }
3529}
3530
3531impl Client<InTransaction> {
3532 pub async fn query<'a>(
3536 &'a mut self,
3537 sql: &str,
3538 params: &[&(dyn crate::ToSql + Sync)],
3539 ) -> Result<QueryStream<'a>> {
3540 tracing::debug!(
3541 sql = sql,
3542 params_count = params.len(),
3543 "executing query in transaction"
3544 );
3545
3546 #[cfg(feature = "otel")]
3547 let instrumentation = self.instrumentation.clone();
3548 #[cfg(feature = "otel")]
3549 let mut span = instrumentation.query_span(sql);
3550
3551 let result = async {
3552 if params.is_empty() {
3553 self.send_sql_batch(sql).await?;
3555 } else {
3556 let rpc_params = Self::convert_params(params)?;
3558 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3559 self.send_rpc(&rpc).await?;
3560 }
3561
3562 self.read_query_response().await
3564 }
3565 .await;
3566
3567 #[cfg(feature = "otel")]
3568 match &result {
3569 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3570 Err(e) => InstrumentationContext::record_error(&mut span, e),
3571 }
3572
3573 #[cfg(feature = "otel")]
3575 drop(span);
3576
3577 let (columns, rows) = result?;
3578 Ok(QueryStream::new(columns, rows))
3579 }
3580
3581 pub async fn execute(
3585 &mut self,
3586 sql: &str,
3587 params: &[&(dyn crate::ToSql + Sync)],
3588 ) -> Result<u64> {
3589 tracing::debug!(
3590 sql = sql,
3591 params_count = params.len(),
3592 "executing statement in transaction"
3593 );
3594
3595 #[cfg(feature = "otel")]
3596 let instrumentation = self.instrumentation.clone();
3597 #[cfg(feature = "otel")]
3598 let mut span = instrumentation.query_span(sql);
3599
3600 let result = async {
3601 if params.is_empty() {
3602 self.send_sql_batch(sql).await?;
3604 } else {
3605 let rpc_params = Self::convert_params(params)?;
3607 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3608 self.send_rpc(&rpc).await?;
3609 }
3610
3611 self.read_execute_result().await
3613 }
3614 .await;
3615
3616 #[cfg(feature = "otel")]
3617 match &result {
3618 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3619 Err(e) => InstrumentationContext::record_error(&mut span, e),
3620 }
3621
3622 #[cfg(feature = "otel")]
3624 drop(span);
3625
3626 result
3627 }
3628
3629 pub async fn query_with_timeout<'a>(
3633 &'a mut self,
3634 sql: &str,
3635 params: &[&(dyn crate::ToSql + Sync)],
3636 timeout_duration: std::time::Duration,
3637 ) -> Result<QueryStream<'a>> {
3638 timeout(timeout_duration, self.query(sql, params))
3639 .await
3640 .map_err(|_| Error::CommandTimeout)?
3641 }
3642
3643 pub async fn execute_with_timeout(
3647 &mut self,
3648 sql: &str,
3649 params: &[&(dyn crate::ToSql + Sync)],
3650 timeout_duration: std::time::Duration,
3651 ) -> Result<u64> {
3652 timeout(timeout_duration, self.execute(sql, params))
3653 .await
3654 .map_err(|_| Error::CommandTimeout)?
3655 }
3656
3657 pub async fn commit(mut self) -> Result<Client<Ready>> {
3661 tracing::debug!("committing transaction");
3662
3663 #[cfg(feature = "otel")]
3664 let instrumentation = self.instrumentation.clone();
3665 #[cfg(feature = "otel")]
3666 let mut span = instrumentation.transaction_span("COMMIT");
3667
3668 let result = async {
3670 self.send_sql_batch("COMMIT TRANSACTION").await?;
3671 self.read_execute_result().await
3672 }
3673 .await;
3674
3675 #[cfg(feature = "otel")]
3676 match &result {
3677 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3678 Err(e) => InstrumentationContext::record_error(&mut span, e),
3679 }
3680
3681 #[cfg(feature = "otel")]
3683 drop(span);
3684
3685 result?;
3686
3687 Ok(Client {
3688 config: self.config,
3689 _state: PhantomData,
3690 connection: self.connection,
3691 server_version: self.server_version,
3692 current_database: self.current_database,
3693 statement_cache: self.statement_cache,
3694 transaction_descriptor: 0, #[cfg(feature = "otel")]
3696 instrumentation: self.instrumentation,
3697 })
3698 }
3699
3700 pub async fn rollback(mut self) -> Result<Client<Ready>> {
3704 tracing::debug!("rolling back transaction");
3705
3706 #[cfg(feature = "otel")]
3707 let instrumentation = self.instrumentation.clone();
3708 #[cfg(feature = "otel")]
3709 let mut span = instrumentation.transaction_span("ROLLBACK");
3710
3711 let result = async {
3713 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
3714 self.read_execute_result().await
3715 }
3716 .await;
3717
3718 #[cfg(feature = "otel")]
3719 match &result {
3720 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3721 Err(e) => InstrumentationContext::record_error(&mut span, e),
3722 }
3723
3724 #[cfg(feature = "otel")]
3726 drop(span);
3727
3728 result?;
3729
3730 Ok(Client {
3731 config: self.config,
3732 _state: PhantomData,
3733 connection: self.connection,
3734 server_version: self.server_version,
3735 current_database: self.current_database,
3736 statement_cache: self.statement_cache,
3737 transaction_descriptor: 0, #[cfg(feature = "otel")]
3739 instrumentation: self.instrumentation,
3740 })
3741 }
3742
3743 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
3760 validate_identifier(name)?;
3761 tracing::debug!(name = name, "creating savepoint");
3762
3763 let sql = format!("SAVE TRANSACTION {}", name);
3766 self.send_sql_batch(&sql).await?;
3767 self.read_execute_result().await?;
3768
3769 Ok(SavePoint::new(name.to_string()))
3770 }
3771
3772 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
3787 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
3788
3789 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
3792 self.send_sql_batch(&sql).await?;
3793 self.read_execute_result().await?;
3794
3795 Ok(())
3796 }
3797
3798 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
3804 tracing::debug!(name = savepoint.name(), "releasing savepoint");
3805
3806 drop(savepoint);
3810 Ok(())
3811 }
3812
3813 #[must_use]
3817 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3818 let connection = self
3819 .connection
3820 .as_ref()
3821 .expect("connection should be present");
3822 match connection {
3823 ConnectionHandle::Tls(conn) => {
3824 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3825 }
3826 ConnectionHandle::TlsPrelogin(conn) => {
3827 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3828 }
3829 ConnectionHandle::Plain(conn) => {
3830 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3831 }
3832 }
3833 }
3834}
3835
3836fn validate_identifier(name: &str) -> Result<()> {
3838 use once_cell::sync::Lazy;
3839 use regex::Regex;
3840
3841 static IDENTIFIER_RE: Lazy<Regex> =
3842 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
3843
3844 if name.is_empty() {
3845 return Err(Error::InvalidIdentifier(
3846 "identifier cannot be empty".into(),
3847 ));
3848 }
3849
3850 if !IDENTIFIER_RE.is_match(name) {
3851 return Err(Error::InvalidIdentifier(format!(
3852 "invalid identifier '{}': must start with letter/underscore, \
3853 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
3854 name
3855 )));
3856 }
3857
3858 Ok(())
3859}
3860
3861impl<S: ConnectionState> std::fmt::Debug for Client<S> {
3862 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3863 f.debug_struct("Client")
3864 .field("host", &self.config.host)
3865 .field("port", &self.config.port)
3866 .field("database", &self.config.database)
3867 .finish()
3868 }
3869}
3870
3871#[cfg(test)]
3872#[allow(clippy::unwrap_used, clippy::panic)]
3873mod tests {
3874 use super::*;
3875
3876 #[test]
3877 fn test_validate_identifier_valid() {
3878 assert!(validate_identifier("my_table").is_ok());
3879 assert!(validate_identifier("Table123").is_ok());
3880 assert!(validate_identifier("_private").is_ok());
3881 assert!(validate_identifier("sp_test").is_ok());
3882 }
3883
3884 #[test]
3885 fn test_validate_identifier_invalid() {
3886 assert!(validate_identifier("").is_err());
3887 assert!(validate_identifier("123abc").is_err());
3888 assert!(validate_identifier("table-name").is_err());
3889 assert!(validate_identifier("table name").is_err());
3890 assert!(validate_identifier("table;DROP TABLE users").is_err());
3891 }
3892
3893 fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
3902 let mut data = Vec::new();
3903 data.extend_from_slice(&total_len.to_le_bytes());
3905 for chunk in chunks {
3907 let len = chunk.len() as u32;
3908 data.extend_from_slice(&len.to_le_bytes());
3909 data.extend_from_slice(chunk);
3910 }
3911 data.extend_from_slice(&0u32.to_le_bytes());
3913 data
3914 }
3915
3916 #[test]
3917 fn test_parse_plp_nvarchar_simple() {
3918 let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
3920 let plp = make_plp_data(10, &[&utf16_data]);
3921 let mut buf: &[u8] = &plp;
3922
3923 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3924 match result {
3925 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
3926 _ => panic!("expected String, got {:?}", result),
3927 }
3928 }
3929
3930 #[test]
3931 fn test_parse_plp_nvarchar_null() {
3932 let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
3934 let mut buf: &[u8] = &plp;
3935
3936 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3937 assert!(matches!(result, mssql_types::SqlValue::Null));
3938 }
3939
3940 #[test]
3941 fn test_parse_plp_nvarchar_empty() {
3942 let plp = make_plp_data(0, &[]);
3944 let mut buf: &[u8] = &plp;
3945
3946 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3947 match result {
3948 mssql_types::SqlValue::String(s) => assert_eq!(s, ""),
3949 _ => panic!("expected empty String"),
3950 }
3951 }
3952
3953 #[test]
3954 fn test_parse_plp_nvarchar_multi_chunk() {
3955 let chunk1 = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00]; let chunk2 = [0x6C, 0x00, 0x6F, 0x00]; let plp = make_plp_data(10, &[&chunk1, &chunk2]);
3959 let mut buf: &[u8] = &plp;
3960
3961 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3962 match result {
3963 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
3964 _ => panic!("expected String"),
3965 }
3966 }
3967
3968 #[test]
3969 fn test_parse_plp_varchar_simple() {
3970 let data = b"Hello World";
3971 let plp = make_plp_data(11, &[data]);
3972 let mut buf: &[u8] = &plp;
3973
3974 let result = Client::<Ready>::parse_plp_varchar(&mut buf, None).unwrap();
3975 match result {
3976 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello World"),
3977 _ => panic!("expected String"),
3978 }
3979 }
3980
3981 #[test]
3982 fn test_parse_plp_varchar_null() {
3983 let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
3984 let mut buf: &[u8] = &plp;
3985
3986 let result = Client::<Ready>::parse_plp_varchar(&mut buf, None).unwrap();
3987 assert!(matches!(result, mssql_types::SqlValue::Null));
3988 }
3989
3990 #[test]
3991 fn test_parse_plp_varbinary_simple() {
3992 let data = [0x01, 0x02, 0x03, 0x04, 0x05];
3993 let plp = make_plp_data(5, &[&data]);
3994 let mut buf: &[u8] = &plp;
3995
3996 let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
3997 match result {
3998 mssql_types::SqlValue::Binary(b) => assert_eq!(&b[..], &[0x01, 0x02, 0x03, 0x04, 0x05]),
3999 _ => panic!("expected Binary"),
4000 }
4001 }
4002
4003 #[test]
4004 fn test_parse_plp_varbinary_null() {
4005 let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
4006 let mut buf: &[u8] = &plp;
4007
4008 let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
4009 assert!(matches!(result, mssql_types::SqlValue::Null));
4010 }
4011
4012 #[test]
4013 fn test_parse_plp_varbinary_large() {
4014 let chunk1: Vec<u8> = (0..100u8).collect();
4016 let chunk2: Vec<u8> = (100..200u8).collect();
4017 let chunk3: Vec<u8> = (200..255u8).collect();
4018 let total_len = chunk1.len() + chunk2.len() + chunk3.len();
4019 let plp = make_plp_data(total_len as u64, &[&chunk1, &chunk2, &chunk3]);
4020 let mut buf: &[u8] = &plp;
4021
4022 let result = Client::<Ready>::parse_plp_varbinary(&mut buf).unwrap();
4023 match result {
4024 mssql_types::SqlValue::Binary(b) => {
4025 assert_eq!(b.len(), 255);
4026 for (i, &byte) in b.iter().enumerate() {
4028 assert_eq!(byte, i as u8);
4029 }
4030 }
4031 _ => panic!("expected Binary"),
4032 }
4033 }
4034
4035 use tds_protocol::token::{ColumnData, TypeInfo};
4043 use tds_protocol::types::TypeId;
4044
4045 fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
4048 let mut data = Vec::new();
4049
4050 let utf16: Vec<u16> = nvarchar_value.encode_utf16().collect();
4052 let byte_len = (utf16.len() * 2) as u16;
4053 data.extend_from_slice(&byte_len.to_le_bytes());
4054 for code_unit in utf16 {
4055 data.extend_from_slice(&code_unit.to_le_bytes());
4056 }
4057
4058 data.push(4); data.extend_from_slice(&int_value.to_le_bytes());
4061
4062 data
4063 }
4064
4065 #[test]
4066 fn test_parse_row_nvarchar_then_int() {
4067 let raw_data = make_nvarchar_int_row("World", 42);
4069
4070 let col0 = ColumnData {
4072 name: "greeting".to_string(),
4073 type_id: TypeId::NVarChar,
4074 col_type: 0xE7,
4075 flags: 0x01,
4076 user_type: 0,
4077 type_info: TypeInfo {
4078 max_length: Some(10), precision: None,
4080 scale: None,
4081 collation: None,
4082 },
4083 };
4084
4085 let col1 = ColumnData {
4086 name: "number".to_string(),
4087 type_id: TypeId::IntN,
4088 col_type: 0x26,
4089 flags: 0x01,
4090 user_type: 0,
4091 type_info: TypeInfo {
4092 max_length: Some(4),
4093 precision: None,
4094 scale: None,
4095 collation: None,
4096 },
4097 };
4098
4099 let mut buf: &[u8] = &raw_data;
4100
4101 let value0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
4103 match value0 {
4104 mssql_types::SqlValue::String(s) => assert_eq!(s, "World"),
4105 _ => panic!("expected String, got {:?}", value0),
4106 }
4107
4108 let value1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
4110 match value1 {
4111 mssql_types::SqlValue::Int(i) => assert_eq!(i, 42),
4112 _ => panic!("expected Int, got {:?}", value1),
4113 }
4114
4115 assert_eq!(buf.len(), 0, "buffer should be fully consumed");
4117 }
4118
4119 #[test]
4120 fn test_parse_row_multiple_types() {
4121 let mut data = Vec::new();
4123
4124 data.extend_from_slice(&0xFFFFu16.to_le_bytes());
4126
4127 data.push(4); data.extend_from_slice(&123i32.to_le_bytes());
4130
4131 let utf16: Vec<u16> = "Test".encode_utf16().collect();
4133 data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
4134 for code_unit in utf16 {
4135 data.extend_from_slice(&code_unit.to_le_bytes());
4136 }
4137
4138 data.push(0);
4140
4141 let col0 = ColumnData {
4143 name: "col0".to_string(),
4144 type_id: TypeId::NVarChar,
4145 col_type: 0xE7,
4146 flags: 0x01,
4147 user_type: 0,
4148 type_info: TypeInfo {
4149 max_length: Some(100),
4150 precision: None,
4151 scale: None,
4152 collation: None,
4153 },
4154 };
4155 let col1 = ColumnData {
4156 name: "col1".to_string(),
4157 type_id: TypeId::IntN,
4158 col_type: 0x26,
4159 flags: 0x01,
4160 user_type: 0,
4161 type_info: TypeInfo {
4162 max_length: Some(4),
4163 precision: None,
4164 scale: None,
4165 collation: None,
4166 },
4167 };
4168 let col2 = col0.clone();
4169 let col3 = col1.clone();
4170
4171 let mut buf: &[u8] = &data;
4172
4173 let v0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
4175 assert!(
4176 matches!(v0, mssql_types::SqlValue::Null),
4177 "col0 should be Null"
4178 );
4179
4180 let v1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
4181 assert!(
4182 matches!(v1, mssql_types::SqlValue::Int(123)),
4183 "col1 should be 123"
4184 );
4185
4186 let v2 = Client::<Ready>::parse_column_value(&mut buf, &col2).unwrap();
4187 match v2 {
4188 mssql_types::SqlValue::String(s) => assert_eq!(s, "Test"),
4189 _ => panic!("col2 should be 'Test'"),
4190 }
4191
4192 let v3 = Client::<Ready>::parse_column_value(&mut buf, &col3).unwrap();
4193 assert!(
4194 matches!(v3, mssql_types::SqlValue::Null),
4195 "col3 should be Null"
4196 );
4197
4198 assert_eq!(buf.len(), 0, "buffer should be fully consumed");
4200 }
4201
4202 #[test]
4203 fn test_parse_row_with_unicode() {
4204 let test_str = "Héllo Wörld 日本語";
4206 let mut data = Vec::new();
4207
4208 let utf16: Vec<u16> = test_str.encode_utf16().collect();
4210 data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
4211 for code_unit in utf16 {
4212 data.extend_from_slice(&code_unit.to_le_bytes());
4213 }
4214
4215 data.push(8); data.extend_from_slice(&9999999999i64.to_le_bytes());
4218
4219 let col0 = ColumnData {
4220 name: "text".to_string(),
4221 type_id: TypeId::NVarChar,
4222 col_type: 0xE7,
4223 flags: 0x01,
4224 user_type: 0,
4225 type_info: TypeInfo {
4226 max_length: Some(100),
4227 precision: None,
4228 scale: None,
4229 collation: None,
4230 },
4231 };
4232 let col1 = ColumnData {
4233 name: "num".to_string(),
4234 type_id: TypeId::IntN,
4235 col_type: 0x26,
4236 flags: 0x01,
4237 user_type: 0,
4238 type_info: TypeInfo {
4239 max_length: Some(8),
4240 precision: None,
4241 scale: None,
4242 collation: None,
4243 },
4244 };
4245
4246 let mut buf: &[u8] = &data;
4247
4248 let v0 = Client::<Ready>::parse_column_value(&mut buf, &col0).unwrap();
4249 match v0 {
4250 mssql_types::SqlValue::String(s) => assert_eq!(s, test_str),
4251 _ => panic!("expected String"),
4252 }
4253
4254 let v1 = Client::<Ready>::parse_column_value(&mut buf, &col1).unwrap();
4255 match v1 {
4256 mssql_types::SqlValue::BigInt(i) => assert_eq!(i, 9999999999),
4257 _ => panic!("expected BigInt"),
4258 }
4259
4260 assert_eq!(buf.len(), 0, "buffer should be fully consumed");
4261 }
4262}