1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, Collation, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token,
19 TokenParser,
20};
21use tds_protocol::tvp::{
22 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
23 encode_tvp_decimal, encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar,
24 encode_tvp_varbinary,
25};
26use tokio::net::TcpStream;
27use tokio::time::timeout;
28
29use crate::config::Config;
30use crate::error::{Error, Result};
31#[cfg(feature = "otel")]
32use crate::instrumentation::InstrumentationContext;
33use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
34use crate::statement_cache::StatementCache;
35use crate::stream::{MultiResultStream, QueryStream};
36use crate::transaction::SavePoint;
37
/// A SQL Server client connection, parameterised by its connection state `S`.
///
/// `S` is a type-state marker carried only as [`PhantomData`]. For example,
/// [`Client::connect`] is defined on `Client<Disconnected>` and yields a `Client<Ready>`.
pub struct Client<S: ConnectionState> {
    /// Configuration the connection was established with. After Azure routing redirects,
    /// this is the redirected configuration.
    config: Config,
    /// Zero-sized type-state marker.
    _state: PhantomData<S>,
    /// The underlying transport. `send_sql_batch` / `send_rpc` treat `None` as
    /// `Error::ConnectionClosed`.
    connection: Option<ConnectionHandle>,
    /// TDS version from the server's LOGINACK token, if one was received during login.
    server_version: Option<u32>,
    /// Current database, as last reported by a database ENVCHANGE token.
    current_database: Option<String>,
    /// Statement cache, created with `StatementCache::with_default_size()` at login.
    statement_cache: StatementCache,
    /// Transaction descriptor attached to every SQL batch and RPC request.
    /// It is 0 outside a transaction; BEGIN / COMMIT / ROLLBACK ENVCHANGE tokens update it.
    transaction_descriptor: u64,
    /// Instrumentation context (only compiled with the `otel` feature).
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
62
/// The transport underneath an established [`Client`].
///
/// Each variant wraps a `Connection` over a different stream type, recording which
/// handshake produced it. Senders such as `send_sql_batch` / `send_rpc` `match` on it.
// NOTE(review): the reason for `allow(dead_code)` is not visible here (perhaps a
// variant or its payload is unused in some feature configurations) — confirm and
// document it.
#[allow(dead_code)] enum ConnectionHandle {
    /// TDS 8.0 strict mode: TLS is established on the raw TCP stream before PreLogin
    /// (produced by `connect_tds_8`).
    Tls(Connection<TlsStream<TcpStream>>),
    /// TDS 7.x with full encryption: the TLS handshake is wrapped in PreLogin packets,
    /// and all traffic after it stays encrypted (produced by `connect_tds_7x`).
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plaintext TCP. Used when no encryption was negotiated, and also after login-only
    /// encryption, where TLS protected just the Login7 packet.
    Plain(Connection<TcpStream>),
}
78
79impl Client<Disconnected> {
80 pub async fn connect(config: Config) -> Result<Client<Ready>> {
91 let max_redirects = config.redirect.max_redirects;
92 let follow_redirects = config.redirect.follow_redirects;
93 let mut attempts = 0;
94 let mut current_config = config;
95
96 loop {
97 attempts += 1;
98 if attempts > max_redirects + 1 {
99 return Err(Error::TooManyRedirects { max: max_redirects });
100 }
101
102 match Self::try_connect(¤t_config).await {
103 Ok(client) => return Ok(client),
104 Err(Error::Routing { host, port }) => {
105 if !follow_redirects {
106 return Err(Error::Routing { host, port });
107 }
108 tracing::info!(
109 host = %host,
110 port = port,
111 attempt = attempts,
112 max_redirects = max_redirects,
113 "following Azure SQL routing redirect"
114 );
115 current_config = current_config.with_host(&host).with_port(port);
116 continue;
117 }
118 Err(e) => return Err(e),
119 }
120 }
121 }
122
123 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
124 tracing::info!(
125 host = %config.host,
126 port = config.port,
127 database = ?config.database,
128 "connecting to SQL Server"
129 );
130
131 let addr = format!("{}:{}", config.host, config.port);
132
133 tracing::debug!("establishing TCP connection to {}", addr);
135 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
136 .await
137 .map_err(|_| Error::ConnectTimeout)?
138 .map_err(|e| Error::Io(Arc::new(e)))?;
139
140 tcp_stream
142 .set_nodelay(true)
143 .map_err(|e| Error::Io(Arc::new(e)))?;
144
145 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
147
148 if tls_mode.is_tls_first() {
150 return Self::connect_tds_8(config, tcp_stream).await;
151 }
152
153 Self::connect_tds_7x(config, tcp_stream).await
155 }
156
157 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
161 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
162
163 let tls_config = TlsConfig::new()
165 .strict_mode(true)
166 .trust_server_certificate(config.trust_server_certificate);
167
168 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
169
170 let tls_stream = timeout(
172 config.timeouts.tls_timeout,
173 tls_connector.connect(tcp_stream, &config.host),
174 )
175 .await
176 .map_err(|_| Error::TlsTimeout)?
177 .map_err(|e| Error::Tls(e.to_string()))?;
178
179 tracing::debug!("TLS handshake completed (strict mode)");
180
181 let mut connection = Connection::new(tls_stream);
183
184 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
186 Self::send_prelogin(&mut connection, &prelogin).await?;
187 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
188
189 let login = Self::build_login7(config);
191 Self::send_login7(&mut connection, &login).await?;
192
193 let (server_version, current_database, routing) =
195 Self::process_login_response(&mut connection).await?;
196
197 if let Some((host, port)) = routing {
199 return Err(Error::Routing { host, port });
200 }
201
202 Ok(Client {
203 config: config.clone(),
204 _state: PhantomData,
205 connection: Some(ConnectionHandle::Tls(connection)),
206 server_version,
207 current_database: current_database.clone(),
208 statement_cache: StatementCache::with_default_size(),
209 transaction_descriptor: 0, #[cfg(feature = "otel")]
211 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
212 .with_database(current_database.unwrap_or_default()),
213 })
214 }
215
216 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
224 use bytes::BufMut;
225 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
226 use tokio::io::{AsyncReadExt, AsyncWriteExt};
227
228 tracing::debug!("using TDS 7.x flow (PreLogin first)");
229
230 let client_encryption = if config.encrypt {
233 EncryptionLevel::On
234 } else {
235 EncryptionLevel::Off
236 };
237 let prelogin = Self::build_prelogin(config, client_encryption);
238 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
239 let prelogin_bytes = prelogin.encode();
240
241 let header = PacketHeader::new(
243 PacketType::PreLogin,
244 PacketStatus::END_OF_MESSAGE,
245 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
246 );
247
248 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
249 header.encode(&mut packet_buf);
250 packet_buf.put_slice(&prelogin_bytes);
251
252 tcp_stream
253 .write_all(&packet_buf)
254 .await
255 .map_err(|e| Error::Io(Arc::new(e)))?;
256
257 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
259 tcp_stream
260 .read_exact(&mut header_buf)
261 .await
262 .map_err(|e| Error::Io(Arc::new(e)))?;
263
264 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
265 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
266
267 let mut response_buf = vec![0u8; payload_length];
268 tcp_stream
269 .read_exact(&mut response_buf)
270 .await
271 .map_err(|e| Error::Io(Arc::new(e)))?;
272
273 let prelogin_response =
274 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
275
276 let server_version = prelogin_response.version;
278 let client_version = config.tds_version;
279 tracing::debug!(
280 client_version = %client_version,
281 server_version = %server_version,
282 server_sql_version = server_version.sql_server_version_name(),
283 "TDS version negotiation"
284 );
285
286 if server_version < client_version && !client_version.is_tds_8() {
288 tracing::warn!(
289 client_version = %client_version,
290 server_version = %server_version,
291 "Server supports lower TDS version than requested. \
292 Connection will use server's version: {}",
293 server_version.sql_server_version_name()
294 );
295 }
296
297 if server_version.is_legacy() {
299 tracing::warn!(
300 server_version = %server_version,
301 "Server uses legacy TDS version ({}). \
302 Some features may not be available.",
303 server_version.sql_server_version_name()
304 );
305 }
306
307 let server_encryption = prelogin_response.encryption;
309 tracing::debug!(encryption = ?server_encryption, "server encryption level");
310
311 let negotiated_encryption = match (client_encryption, server_encryption) {
317 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
318 EncryptionLevel::NotSupported
319 }
320 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
321 (EncryptionLevel::On, EncryptionLevel::Off)
322 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
323 return Err(Error::Protocol(
324 "Server does not support requested encryption level".to_string(),
325 ));
326 }
327 _ => EncryptionLevel::On,
328 };
329
330 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
333
334 if use_tls {
335 let tls_config =
338 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
339
340 let tls_connector =
341 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
342
343 let mut tls_stream = timeout(
345 config.timeouts.tls_timeout,
346 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
347 )
348 .await
349 .map_err(|_| Error::TlsTimeout)?
350 .map_err(|e| Error::Tls(e.to_string()))?;
351
352 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
353
354 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
356
357 if login_only_encryption {
358 use tokio::io::AsyncWriteExt;
366
367 let login = Self::build_login7(config);
369 let login_payload = login.encode();
370
371 let max_packet = MAX_PACKET_SIZE;
373 let max_payload = max_packet - PACKET_HEADER_SIZE;
374 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
375 let total_chunks = chunks.len();
376
377 for (i, chunk) in chunks.into_iter().enumerate() {
378 let is_last = i == total_chunks - 1;
379 let status = if is_last {
380 PacketStatus::END_OF_MESSAGE
381 } else {
382 PacketStatus::NORMAL
383 };
384
385 let header = PacketHeader::new(
386 PacketType::Tds7Login,
387 status,
388 (PACKET_HEADER_SIZE + chunk.len()) as u16,
389 );
390
391 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
392 header.encode(&mut packet_buf);
393 packet_buf.put_slice(chunk);
394
395 tls_stream
396 .write_all(&packet_buf)
397 .await
398 .map_err(|e| Error::Io(Arc::new(e)))?;
399 }
400
401 tls_stream
403 .flush()
404 .await
405 .map_err(|e| Error::Io(Arc::new(e)))?;
406
407 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
408
409 let (wrapper, _client_conn) = tls_stream.into_inner();
413 let tcp_stream = wrapper.into_inner();
414
415 let mut connection = Connection::new(tcp_stream);
417
418 let (server_version, current_database, routing) =
420 Self::process_login_response(&mut connection).await?;
421
422 if let Some((host, port)) = routing {
424 return Err(Error::Routing { host, port });
425 }
426
427 Ok(Client {
429 config: config.clone(),
430 _state: PhantomData,
431 connection: Some(ConnectionHandle::Plain(connection)),
432 server_version,
433 current_database: current_database.clone(),
434 statement_cache: StatementCache::with_default_size(),
435 transaction_descriptor: 0, #[cfg(feature = "otel")]
437 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
438 .with_database(current_database.unwrap_or_default()),
439 })
440 } else {
441 let mut connection = Connection::new(tls_stream);
444
445 let login = Self::build_login7(config);
447 Self::send_login7(&mut connection, &login).await?;
448
449 let (server_version, current_database, routing) =
451 Self::process_login_response(&mut connection).await?;
452
453 if let Some((host, port)) = routing {
455 return Err(Error::Routing { host, port });
456 }
457
458 Ok(Client {
459 config: config.clone(),
460 _state: PhantomData,
461 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
462 server_version,
463 current_database: current_database.clone(),
464 statement_cache: StatementCache::with_default_size(),
465 transaction_descriptor: 0, #[cfg(feature = "otel")]
467 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
468 .with_database(current_database.unwrap_or_default()),
469 })
470 }
471 } else {
472 tracing::warn!(
474 "Connecting without TLS encryption. This is insecure and should only be \
475 used for development/testing on trusted networks."
476 );
477
478 let login = Self::build_login7(config);
480 let login_bytes = login.encode();
481 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
482 tracing::debug!(
484 "Login7 fixed header (94 bytes): {:02X?}",
485 &login_bytes[..login_bytes.len().min(94)]
486 );
487 if login_bytes.len() > 94 {
489 tracing::debug!(
490 "Login7 variable data ({} bytes): {:02X?}",
491 login_bytes.len() - 94,
492 &login_bytes[94..]
493 );
494 }
495
496 let login_header = PacketHeader::new(
498 PacketType::Tds7Login,
499 PacketStatus::END_OF_MESSAGE,
500 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
501 )
502 .with_packet_id(1);
503 let mut login_packet_buf =
504 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
505 login_header.encode(&mut login_packet_buf);
506 login_packet_buf.put_slice(&login_bytes);
507
508 tracing::debug!(
509 "Sending Login7 packet: {} bytes total, header: {:02X?}",
510 login_packet_buf.len(),
511 &login_packet_buf[..PACKET_HEADER_SIZE]
512 );
513 tcp_stream
514 .write_all(&login_packet_buf)
515 .await
516 .map_err(|e| Error::Io(Arc::new(e)))?;
517 tcp_stream
518 .flush()
519 .await
520 .map_err(|e| Error::Io(Arc::new(e)))?;
521 tracing::debug!("Login7 sent and flushed over raw TCP");
522
523 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
525 tcp_stream
526 .read_exact(&mut response_header_buf)
527 .await
528 .map_err(|e| Error::Io(Arc::new(e)))?;
529
530 let response_type = response_header_buf[0];
531 let response_length =
532 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
533 tracing::debug!(
534 "Response header: type={:#04X}, length={}",
535 response_type,
536 response_length
537 );
538
539 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
541 let mut response_payload = vec![0u8; payload_length];
542 tcp_stream
543 .read_exact(&mut response_payload)
544 .await
545 .map_err(|e| Error::Io(Arc::new(e)))?;
546 tracing::debug!(
547 "Response payload: {} bytes, first 32: {:02X?}",
548 response_payload.len(),
549 &response_payload[..response_payload.len().min(32)]
550 );
551
552 let connection = Connection::new(tcp_stream);
554
555 let response_bytes = bytes::Bytes::from(response_payload);
557 let mut parser = TokenParser::new(response_bytes);
558 let mut server_version = None;
559 let mut current_database = None;
560 let routing = None;
561
562 while let Some(token) = parser
563 .next_token()
564 .map_err(|e| Error::Protocol(e.to_string()))?
565 {
566 match token {
567 Token::LoginAck(ack) => {
568 tracing::info!(
569 version = ack.tds_version,
570 interface = ack.interface,
571 prog_name = %ack.prog_name,
572 "login acknowledged"
573 );
574 server_version = Some(ack.tds_version);
575 }
576 Token::EnvChange(env) => {
577 Self::process_env_change(&env, &mut current_database, &mut None);
578 }
579 Token::Error(err) => {
580 return Err(Error::Server {
581 number: err.number,
582 state: err.state,
583 class: err.class,
584 message: err.message.clone(),
585 server: if err.server.is_empty() {
586 None
587 } else {
588 Some(err.server.clone())
589 },
590 procedure: if err.procedure.is_empty() {
591 None
592 } else {
593 Some(err.procedure.clone())
594 },
595 line: err.line as u32,
596 });
597 }
598 Token::Info(info) => {
599 tracing::info!(
600 number = info.number,
601 message = %info.message,
602 "server info message"
603 );
604 }
605 Token::Done(done) => {
606 if done.status.error {
607 return Err(Error::Protocol("login failed".to_string()));
608 }
609 break;
610 }
611 _ => {}
612 }
613 }
614
615 if let Some((host, port)) = routing {
617 return Err(Error::Routing { host, port });
618 }
619
620 Ok(Client {
621 config: config.clone(),
622 _state: PhantomData,
623 connection: Some(ConnectionHandle::Plain(connection)),
624 server_version,
625 current_database: current_database.clone(),
626 statement_cache: StatementCache::with_default_size(),
627 transaction_descriptor: 0, #[cfg(feature = "otel")]
629 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
630 .with_database(current_database.unwrap_or_default()),
631 })
632 }
633 }
634
635 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
637 let version = if config.strict_mode {
639 tds_protocol::version::TdsVersion::V8_0
640 } else {
641 config.tds_version
642 };
643
644 let mut prelogin = PreLogin::new()
645 .with_version(version)
646 .with_encryption(encryption);
647
648 if config.mars {
649 prelogin = prelogin.with_mars(true);
650 }
651
652 if let Some(ref instance) = config.instance {
653 prelogin = prelogin.with_instance(instance);
654 }
655
656 prelogin
657 }
658
659 fn build_login7(config: &Config) -> Login7 {
661 let version = if config.strict_mode {
663 tds_protocol::version::TdsVersion::V8_0
664 } else {
665 config.tds_version
666 };
667
668 let mut login = Login7::new()
669 .with_tds_version(version)
670 .with_packet_size(config.packet_size as u32)
671 .with_app_name(&config.application_name)
672 .with_server_name(&config.host)
673 .with_hostname(&config.host);
674
675 if let Some(ref database) = config.database {
676 login = login.with_database(database);
677 }
678
679 match &config.credentials {
681 mssql_auth::Credentials::SqlServer { username, password } => {
682 login = login.with_sql_auth(username.as_ref(), password.as_ref());
683 }
684 _ => {}
686 }
687
688 login
689 }
690
691 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
693 where
694 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
695 {
696 let payload = prelogin.encode();
697 let max_packet = MAX_PACKET_SIZE;
698
699 connection
700 .send_message(PacketType::PreLogin, payload, max_packet)
701 .await
702 .map_err(|e| Error::Protocol(e.to_string()))
703 }
704
705 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
707 where
708 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
709 {
710 let message = connection
711 .read_message()
712 .await
713 .map_err(|e| Error::Protocol(e.to_string()))?
714 .ok_or(Error::ConnectionClosed)?;
715
716 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
717 }
718
719 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
721 where
722 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
723 {
724 let payload = login.encode();
725 let max_packet = MAX_PACKET_SIZE;
726
727 connection
728 .send_message(PacketType::Tds7Login, payload, max_packet)
729 .await
730 .map_err(|e| Error::Protocol(e.to_string()))
731 }
732
733 async fn process_login_response<T>(
737 connection: &mut Connection<T>,
738 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
739 where
740 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
741 {
742 let message = connection
743 .read_message()
744 .await
745 .map_err(|e| Error::Protocol(e.to_string()))?
746 .ok_or(Error::ConnectionClosed)?;
747
748 let response_bytes = message.payload;
749
750 let mut parser = TokenParser::new(response_bytes);
751 let mut server_version = None;
752 let mut database = None;
753 let mut routing = None;
754
755 while let Some(token) = parser
756 .next_token()
757 .map_err(|e| Error::Protocol(e.to_string()))?
758 {
759 match token {
760 Token::LoginAck(ack) => {
761 tracing::info!(
762 version = ack.tds_version,
763 interface = ack.interface,
764 prog_name = %ack.prog_name,
765 "login acknowledged"
766 );
767 server_version = Some(ack.tds_version);
768 }
769 Token::EnvChange(env) => {
770 Self::process_env_change(&env, &mut database, &mut routing);
771 }
772 Token::Error(err) => {
773 return Err(Error::Server {
774 number: err.number,
775 state: err.state,
776 class: err.class,
777 message: err.message.clone(),
778 server: if err.server.is_empty() {
779 None
780 } else {
781 Some(err.server.clone())
782 },
783 procedure: if err.procedure.is_empty() {
784 None
785 } else {
786 Some(err.procedure.clone())
787 },
788 line: err.line as u32,
789 });
790 }
791 Token::Info(info) => {
792 tracing::info!(
793 number = info.number,
794 message = %info.message,
795 "server info message"
796 );
797 }
798 Token::Done(done) => {
799 if done.status.error {
800 return Err(Error::Protocol("login failed".to_string()));
801 }
802 break;
803 }
804 _ => {}
805 }
806 }
807
808 Ok((server_version, database, routing))
809 }
810
811 fn process_env_change(
813 env: &EnvChange,
814 database: &mut Option<String>,
815 routing: &mut Option<(String, u16)>,
816 ) {
817 use tds_protocol::token::EnvChangeValue;
818
819 match env.env_type {
820 EnvChangeType::Database => {
821 if let EnvChangeValue::String(ref new_value) = env.new_value {
822 tracing::debug!(database = %new_value, "database changed");
823 *database = Some(new_value.clone());
824 }
825 }
826 EnvChangeType::Routing => {
827 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
828 tracing::info!(host = %host, port = port, "routing redirect received");
829 *routing = Some((host.clone(), port));
830 }
831 }
832 _ => {
833 if let EnvChangeValue::String(ref new_value) = env.new_value {
834 tracing::debug!(
835 env_type = ?env.env_type,
836 new_value = %new_value,
837 "environment change"
838 );
839 }
840 }
841 }
842 }
843}
844
845impl<S: ConnectionState> Client<S> {
847 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
855 use tds_protocol::token::EnvChangeValue;
856
857 match env.env_type {
858 EnvChangeType::BeginTransaction => {
859 if let EnvChangeValue::Binary(ref data) = env.new_value {
860 if data.len() >= 8 {
861 let descriptor = u64::from_le_bytes([
862 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
863 ]);
864 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
865 *transaction_descriptor = descriptor;
866 }
867 }
868 }
869 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
870 tracing::debug!(
871 env_type = ?env.env_type,
872 "transaction ended via raw SQL"
873 );
874 *transaction_descriptor = 0;
875 }
876 _ => {}
877 }
878 }
879
880 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
886 let payload =
887 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
888 let max_packet = self.config.packet_size as usize;
889
890 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
891
892 match connection {
893 ConnectionHandle::Tls(conn) => {
894 conn.send_message(PacketType::SqlBatch, payload, max_packet)
895 .await
896 .map_err(|e| Error::Protocol(e.to_string()))?;
897 }
898 ConnectionHandle::TlsPrelogin(conn) => {
899 conn.send_message(PacketType::SqlBatch, payload, max_packet)
900 .await
901 .map_err(|e| Error::Protocol(e.to_string()))?;
902 }
903 ConnectionHandle::Plain(conn) => {
904 conn.send_message(PacketType::SqlBatch, payload, max_packet)
905 .await
906 .map_err(|e| Error::Protocol(e.to_string()))?;
907 }
908 }
909
910 Ok(())
911 }
912
913 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
917 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
918 let max_packet = self.config.packet_size as usize;
919
920 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
921
922 match connection {
923 ConnectionHandle::Tls(conn) => {
924 conn.send_message(PacketType::Rpc, payload, max_packet)
925 .await
926 .map_err(|e| Error::Protocol(e.to_string()))?;
927 }
928 ConnectionHandle::TlsPrelogin(conn) => {
929 conn.send_message(PacketType::Rpc, payload, max_packet)
930 .await
931 .map_err(|e| Error::Protocol(e.to_string()))?;
932 }
933 ConnectionHandle::Plain(conn) => {
934 conn.send_message(PacketType::Rpc, payload, max_packet)
935 .await
936 .map_err(|e| Error::Protocol(e.to_string()))?;
937 }
938 }
939
940 Ok(())
941 }
942
943 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
945 use bytes::{BufMut, BytesMut};
946 use mssql_types::SqlValue;
947
948 params
949 .iter()
950 .enumerate()
951 .map(|(i, p)| {
952 let sql_value = p.to_sql()?;
953 let name = format!("@p{}", i + 1);
954
955 Ok(match sql_value {
956 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
957 SqlValue::Bool(v) => {
958 let mut buf = BytesMut::with_capacity(1);
959 buf.put_u8(if v { 1 } else { 0 });
960 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
961 }
962 SqlValue::TinyInt(v) => {
963 let mut buf = BytesMut::with_capacity(1);
964 buf.put_u8(v);
965 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
966 }
967 SqlValue::SmallInt(v) => {
968 let mut buf = BytesMut::with_capacity(2);
969 buf.put_i16_le(v);
970 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
971 }
972 SqlValue::Int(v) => RpcParam::int(&name, v),
973 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
974 SqlValue::Float(v) => {
975 let mut buf = BytesMut::with_capacity(4);
976 buf.put_f32_le(v);
977 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
978 }
979 SqlValue::Double(v) => {
980 let mut buf = BytesMut::with_capacity(8);
981 buf.put_f64_le(v);
982 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
983 }
984 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
985 SqlValue::Binary(ref b) => {
986 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
987 }
988 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
989 #[cfg(feature = "uuid")]
990 SqlValue::Uuid(u) => {
991 let bytes = u.as_bytes();
993 let mut buf = BytesMut::with_capacity(16);
994 buf.put_u32_le(u32::from_be_bytes([
996 bytes[0], bytes[1], bytes[2], bytes[3],
997 ]));
998 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
999 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
1000 buf.put_slice(&bytes[8..16]);
1001 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
1002 }
1003 #[cfg(feature = "decimal")]
1004 SqlValue::Decimal(d) => {
1005 RpcParam::nvarchar(&name, &d.to_string())
1007 }
1008 #[cfg(feature = "chrono")]
1009 SqlValue::Date(_)
1010 | SqlValue::Time(_)
1011 | SqlValue::DateTime(_)
1012 | SqlValue::DateTimeOffset(_) => {
1013 let s = match &sql_value {
1016 SqlValue::Date(d) => d.to_string(),
1017 SqlValue::Time(t) => t.to_string(),
1018 SqlValue::DateTime(dt) => dt.to_string(),
1019 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
1020 _ => unreachable!(),
1021 };
1022 RpcParam::nvarchar(&name, &s)
1023 }
1024 #[cfg(feature = "json")]
1025 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
1026 SqlValue::Tvp(ref tvp_data) => {
1027 Self::encode_tvp_param(&name, tvp_data)?
1029 }
1030 _ => {
1032 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
1033 from: sql_value.type_name().to_string(),
1034 to: "RPC parameter",
1035 }));
1036 }
1037 })
1038 })
1039 .collect()
1040 }
1041
1042 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
1047 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
1049 .columns
1050 .iter()
1051 .map(|col| {
1052 let wire_type = Self::convert_tvp_column_type(&col.column_type);
1053 TvpWireColumnDef {
1054 wire_type,
1055 flags: TvpColumnFlags {
1056 nullable: col.nullable,
1057 },
1058 }
1059 })
1060 .collect();
1061
1062 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
1064
1065 let mut buf = BytesMut::with_capacity(256);
1067
1068 encoder.encode_metadata(&mut buf);
1070
1071 for row in &tvp_data.rows {
1073 encoder.encode_row(&mut buf, |row_buf| {
1074 for (col_idx, value) in row.iter().enumerate() {
1075 let wire_type = &wire_columns[col_idx].wire_type;
1076 Self::encode_tvp_value(value, wire_type, row_buf);
1077 }
1078 });
1079 }
1080
1081 encoder.encode_end(&mut buf);
1083
1084 let type_info = RpcTypeInfo {
1088 type_id: 0xF3, max_length: None,
1090 precision: None,
1091 scale: None,
1092 collation: None,
1093 };
1094
1095 Ok(RpcParam {
1096 name: name.to_string(),
1097 flags: tds_protocol::rpc::ParamFlags::default(),
1098 type_info,
1099 value: Some(buf.freeze()),
1100 })
1101 }
1102
1103 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1105 match col_type {
1106 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1107 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1108 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1109 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1110 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1111 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1112 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1113 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1114 precision: *precision,
1115 scale: *scale,
1116 },
1117 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1118 max_length: *max_length,
1119 },
1120 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1121 max_length: *max_length,
1122 },
1123 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1124 max_length: *max_length,
1125 },
1126 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1127 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1128 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1129 mssql_types::TvpColumnType::DateTime2 { scale } => {
1130 TvpWireType::DateTime2 { scale: *scale }
1131 }
1132 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1133 TvpWireType::DateTimeOffset { scale: *scale }
1134 }
1135 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1136 }
1137 }
1138
1139 fn encode_tvp_value(
1141 value: &mssql_types::SqlValue,
1142 wire_type: &TvpWireType,
1143 buf: &mut BytesMut,
1144 ) {
1145 use mssql_types::SqlValue;
1146
1147 match value {
1148 SqlValue::Null => {
1149 encode_tvp_null(wire_type, buf);
1150 }
1151 SqlValue::Bool(v) => {
1152 encode_tvp_bit(*v, buf);
1153 }
1154 SqlValue::TinyInt(v) => {
1155 encode_tvp_int(*v as i64, 1, buf);
1156 }
1157 SqlValue::SmallInt(v) => {
1158 encode_tvp_int(*v as i64, 2, buf);
1159 }
1160 SqlValue::Int(v) => {
1161 encode_tvp_int(*v as i64, 4, buf);
1162 }
1163 SqlValue::BigInt(v) => {
1164 encode_tvp_int(*v, 8, buf);
1165 }
1166 SqlValue::Float(v) => {
1167 encode_tvp_float(*v as f64, 4, buf);
1168 }
1169 SqlValue::Double(v) => {
1170 encode_tvp_float(*v, 8, buf);
1171 }
1172 SqlValue::String(s) => {
1173 let max_len = match wire_type {
1174 TvpWireType::NVarChar { max_length } => *max_length,
1175 _ => 4000,
1176 };
1177 encode_tvp_nvarchar(s, max_len, buf);
1178 }
1179 SqlValue::Binary(b) => {
1180 let max_len = match wire_type {
1181 TvpWireType::VarBinary { max_length } => *max_length,
1182 _ => 8000,
1183 };
1184 encode_tvp_varbinary(b, max_len, buf);
1185 }
1186 #[cfg(feature = "decimal")]
1187 SqlValue::Decimal(d) => {
1188 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1189 let mantissa = d.mantissa().unsigned_abs();
1190 encode_tvp_decimal(sign, mantissa, buf);
1191 }
1192 #[cfg(feature = "uuid")]
1193 SqlValue::Uuid(u) => {
1194 let bytes = u.as_bytes();
1195 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1196 }
1197 #[cfg(feature = "chrono")]
1198 SqlValue::Date(d) => {
1199 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1201 let days = d.signed_duration_since(base).num_days() as u32;
1202 tds_protocol::tvp::encode_tvp_date(days, buf);
1203 }
1204 #[cfg(feature = "chrono")]
1205 SqlValue::Time(t) => {
1206 use chrono::Timelike;
1207 let nanos =
1208 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1209 let intervals = nanos / 100;
1210 let scale = match wire_type {
1211 TvpWireType::Time { scale } => *scale,
1212 _ => 7,
1213 };
1214 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1215 }
1216 #[cfg(feature = "chrono")]
1217 SqlValue::DateTime(dt) => {
1218 use chrono::Timelike;
1219 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1221 + dt.time().nanosecond() as u64;
1222 let intervals = nanos / 100;
1223 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1225 let days = dt.date().signed_duration_since(base).num_days() as u32;
1226 let scale = match wire_type {
1227 TvpWireType::DateTime2 { scale } => *scale,
1228 _ => 7,
1229 };
1230 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1231 }
1232 #[cfg(feature = "chrono")]
1233 SqlValue::DateTimeOffset(dto) => {
1234 use chrono::{Offset, Timelike};
1235 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1237 + dto.time().nanosecond() as u64;
1238 let intervals = nanos / 100;
1239 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1241 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1242 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1244 let scale = match wire_type {
1245 TvpWireType::DateTimeOffset { scale } => *scale,
1246 _ => 7,
1247 };
1248 tds_protocol::tvp::encode_tvp_datetimeoffset(
1249 intervals,
1250 days,
1251 offset_minutes,
1252 scale,
1253 buf,
1254 );
1255 }
1256 #[cfg(feature = "json")]
1257 SqlValue::Json(j) => {
1258 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1260 }
1261 SqlValue::Xml(s) => {
1262 encode_tvp_nvarchar(s, 0xFFFF, buf);
1264 }
1265 SqlValue::Tvp(_) => {
1266 encode_tvp_null(wire_type, buf);
1268 }
1269 _ => {
1271 encode_tvp_null(wire_type, buf);
1272 }
1273 }
1274 }
1275
1276 async fn read_query_response(
1278 &mut self,
1279 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1280 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1281
1282 let message = match connection {
1283 ConnectionHandle::Tls(conn) => conn
1284 .read_message()
1285 .await
1286 .map_err(|e| Error::Protocol(e.to_string()))?,
1287 ConnectionHandle::TlsPrelogin(conn) => conn
1288 .read_message()
1289 .await
1290 .map_err(|e| Error::Protocol(e.to_string()))?,
1291 ConnectionHandle::Plain(conn) => conn
1292 .read_message()
1293 .await
1294 .map_err(|e| Error::Protocol(e.to_string()))?,
1295 }
1296 .ok_or(Error::ConnectionClosed)?;
1297
1298 let mut parser = TokenParser::new(message.payload);
1299 let mut columns: Vec<crate::row::Column> = Vec::new();
1300 let mut rows: Vec<crate::row::Row> = Vec::new();
1301 let mut protocol_metadata: Option<ColMetaData> = None;
1302
1303 loop {
1304 let token = parser
1306 .next_token_with_metadata(protocol_metadata.as_ref())
1307 .map_err(|e| Error::Protocol(e.to_string()))?;
1308
1309 let Some(token) = token else {
1310 break;
1311 };
1312
1313 match token {
1314 Token::ColMetaData(meta) => {
1315 rows.clear();
1318
1319 columns = meta
1320 .columns
1321 .iter()
1322 .enumerate()
1323 .map(|(i, col)| {
1324 let type_name = format!("{:?}", col.type_id);
1325 let mut column = crate::row::Column::new(&col.name, i, type_name)
1326 .with_nullable(col.flags & 0x01 != 0);
1327
1328 if let Some(max_len) = col.type_info.max_length {
1329 column = column.with_max_length(max_len);
1330 }
1331 if let (Some(prec), Some(scale)) =
1332 (col.type_info.precision, col.type_info.scale)
1333 {
1334 column = column.with_precision_scale(prec, scale);
1335 }
1336 if let Some(collation) = col.type_info.collation {
1339 column = column.with_collation(collation);
1340 }
1341 column
1342 })
1343 .collect();
1344
1345 tracing::debug!(columns = columns.len(), "received column metadata");
1346 protocol_metadata = Some(meta);
1347 }
1348 Token::Row(raw_row) => {
1349 if let Some(ref meta) = protocol_metadata {
1350 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1351 rows.push(row);
1352 }
1353 }
1354 Token::NbcRow(nbc_row) => {
1355 if let Some(ref meta) = protocol_metadata {
1356 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1357 rows.push(row);
1358 }
1359 }
1360 Token::Error(err) => {
1361 return Err(Error::Server {
1362 number: err.number,
1363 state: err.state,
1364 class: err.class,
1365 message: err.message.clone(),
1366 server: if err.server.is_empty() {
1367 None
1368 } else {
1369 Some(err.server.clone())
1370 },
1371 procedure: if err.procedure.is_empty() {
1372 None
1373 } else {
1374 Some(err.procedure.clone())
1375 },
1376 line: err.line as u32,
1377 });
1378 }
1379 Token::Done(done) => {
1380 if done.status.error {
1381 return Err(Error::Query("query failed".to_string()));
1382 }
1383 tracing::debug!(
1384 row_count = done.row_count,
1385 has_more = done.status.more,
1386 "query complete"
1387 );
1388 if !done.status.more {
1391 break;
1392 }
1393 }
1394 Token::DoneProc(done) => {
1395 if done.status.error {
1396 return Err(Error::Query("query failed".to_string()));
1397 }
1398 }
1399 Token::DoneInProc(done) => {
1400 if done.status.error {
1401 return Err(Error::Query("query failed".to_string()));
1402 }
1403 }
1404 Token::Info(info) => {
1405 tracing::debug!(
1406 number = info.number,
1407 message = %info.message,
1408 "server info message"
1409 );
1410 }
1411 Token::EnvChange(env) => {
1412 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
1416 }
1417 _ => {}
1418 }
1419 }
1420
1421 tracing::debug!(
1422 columns = columns.len(),
1423 rows = rows.len(),
1424 "query response parsed"
1425 );
1426 Ok((columns, rows))
1427 }
1428
1429 fn convert_raw_row(
1433 raw: &RawRow,
1434 meta: &ColMetaData,
1435 columns: &[crate::row::Column],
1436 ) -> Result<crate::row::Row> {
1437 let mut values = Vec::with_capacity(meta.columns.len());
1438 let mut buf = raw.data.as_ref();
1439
1440 for col in &meta.columns {
1441 let value = Self::parse_column_value(&mut buf, col)?;
1442 values.push(value);
1443 }
1444
1445 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1446 }
1447
1448 fn convert_nbc_row(
1452 nbc: &NbcRow,
1453 meta: &ColMetaData,
1454 columns: &[crate::row::Column],
1455 ) -> Result<crate::row::Row> {
1456 let mut values = Vec::with_capacity(meta.columns.len());
1457 let mut buf = nbc.data.as_ref();
1458
1459 for (i, col) in meta.columns.iter().enumerate() {
1460 if nbc.is_null(i) {
1461 values.push(mssql_types::SqlValue::Null);
1462 } else {
1463 let value = Self::parse_column_value(&mut buf, col)?;
1464 values.push(value);
1465 }
1466 }
1467
1468 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1469 }
1470
1471 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1473 use bytes::Buf;
1474 use mssql_types::SqlValue;
1475 use tds_protocol::types::TypeId;
1476
1477 let value = match col.type_id {
1478 TypeId::Null => SqlValue::Null,
1480
1481 TypeId::Int1 => {
1483 if buf.remaining() < 1 {
1484 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1485 }
1486 SqlValue::TinyInt(buf.get_u8())
1487 }
1488 TypeId::Bit => {
1489 if buf.remaining() < 1 {
1490 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1491 }
1492 SqlValue::Bool(buf.get_u8() != 0)
1493 }
1494
1495 TypeId::Int2 => {
1497 if buf.remaining() < 2 {
1498 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1499 }
1500 SqlValue::SmallInt(buf.get_i16_le())
1501 }
1502
1503 TypeId::Int4 => {
1505 if buf.remaining() < 4 {
1506 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1507 }
1508 SqlValue::Int(buf.get_i32_le())
1509 }
1510 TypeId::Float4 => {
1511 if buf.remaining() < 4 {
1512 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1513 }
1514 SqlValue::Float(buf.get_f32_le())
1515 }
1516
1517 TypeId::Int8 => {
1519 if buf.remaining() < 8 {
1520 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1521 }
1522 SqlValue::BigInt(buf.get_i64_le())
1523 }
1524 TypeId::Float8 => {
1525 if buf.remaining() < 8 {
1526 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1527 }
1528 SqlValue::Double(buf.get_f64_le())
1529 }
1530 TypeId::Money => {
1531 if buf.remaining() < 8 {
1532 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1533 }
1534 let high = buf.get_i32_le();
1536 let low = buf.get_u32_le();
1537 let cents = ((high as i64) << 32) | (low as i64);
1538 let value = (cents as f64) / 10000.0;
1539 SqlValue::Double(value)
1540 }
1541 TypeId::Money4 => {
1542 if buf.remaining() < 4 {
1543 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1544 }
1545 let cents = buf.get_i32_le();
1546 let value = (cents as f64) / 10000.0;
1547 SqlValue::Double(value)
1548 }
1549
1550 TypeId::IntN => {
1552 if buf.remaining() < 1 {
1553 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1554 }
1555 let len = buf.get_u8();
1556 match len {
1557 0 => SqlValue::Null,
1558 1 => SqlValue::TinyInt(buf.get_u8()),
1559 2 => SqlValue::SmallInt(buf.get_i16_le()),
1560 4 => SqlValue::Int(buf.get_i32_le()),
1561 8 => SqlValue::BigInt(buf.get_i64_le()),
1562 _ => {
1563 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1564 }
1565 }
1566 }
1567 TypeId::FloatN => {
1568 if buf.remaining() < 1 {
1569 return Err(Error::Protocol(
1570 "unexpected EOF reading FloatN length".into(),
1571 ));
1572 }
1573 let len = buf.get_u8();
1574 match len {
1575 0 => SqlValue::Null,
1576 4 => SqlValue::Float(buf.get_f32_le()),
1577 8 => SqlValue::Double(buf.get_f64_le()),
1578 _ => {
1579 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1580 }
1581 }
1582 }
1583 TypeId::BitN => {
1584 if buf.remaining() < 1 {
1585 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1586 }
1587 let len = buf.get_u8();
1588 match len {
1589 0 => SqlValue::Null,
1590 1 => SqlValue::Bool(buf.get_u8() != 0),
1591 _ => {
1592 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1593 }
1594 }
1595 }
1596 TypeId::MoneyN => {
1597 if buf.remaining() < 1 {
1598 return Err(Error::Protocol(
1599 "unexpected EOF reading MoneyN length".into(),
1600 ));
1601 }
1602 let len = buf.get_u8();
1603 match len {
1604 0 => SqlValue::Null,
1605 4 => {
1606 let cents = buf.get_i32_le();
1607 SqlValue::Double((cents as f64) / 10000.0)
1608 }
1609 8 => {
1610 let high = buf.get_i32_le();
1611 let low = buf.get_u32_le();
1612 let cents = ((high as i64) << 32) | (low as i64);
1613 SqlValue::Double((cents as f64) / 10000.0)
1614 }
1615 _ => {
1616 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1617 }
1618 }
1619 }
1620 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1622 if buf.remaining() < 1 {
1623 return Err(Error::Protocol(
1624 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1625 ));
1626 }
1627 let len = buf.get_u8() as usize;
1628 if len == 0 {
1629 SqlValue::Null
1630 } else {
1631 if buf.remaining() < len {
1632 return Err(Error::Protocol(
1633 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1634 ));
1635 }
1636
1637 let sign = buf.get_u8();
1639 let mantissa_len = len - 1;
1640
1641 let mut mantissa_bytes = [0u8; 16];
1643 for i in 0..mantissa_len.min(16) {
1644 mantissa_bytes[i] = buf.get_u8();
1645 }
1646 for _ in 16..mantissa_len {
1648 buf.get_u8();
1649 }
1650
1651 let mantissa = u128::from_le_bytes(mantissa_bytes);
1652 let scale = col.type_info.scale.unwrap_or(0) as u32;
1653
1654 #[cfg(feature = "decimal")]
1655 {
1656 use rust_decimal::Decimal;
1657 if scale > 28 {
1660 let divisor = 10f64.powi(scale as i32);
1662 let value = (mantissa as f64) / divisor;
1663 let value = if sign == 0 { -value } else { value };
1664 SqlValue::Double(value)
1665 } else {
1666 let mut decimal =
1667 Decimal::from_i128_with_scale(mantissa as i128, scale);
1668 if sign == 0 {
1669 decimal.set_sign_negative(true);
1670 }
1671 SqlValue::Decimal(decimal)
1672 }
1673 }
1674
1675 #[cfg(not(feature = "decimal"))]
1676 {
1677 let divisor = 10f64.powi(scale as i32);
1679 let value = (mantissa as f64) / divisor;
1680 let value = if sign == 0 { -value } else { value };
1681 SqlValue::Double(value)
1682 }
1683 }
1684 }
1685
1686 TypeId::DateTimeN => {
1688 if buf.remaining() < 1 {
1689 return Err(Error::Protocol(
1690 "unexpected EOF reading DateTimeN length".into(),
1691 ));
1692 }
1693 let len = buf.get_u8() as usize;
1694 if len == 0 {
1695 SqlValue::Null
1696 } else if buf.remaining() < len {
1697 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1698 } else {
1699 match len {
1700 4 => {
1701 let days = buf.get_u16_le() as i64;
1703 let minutes = buf.get_u16_le() as u32;
1704 #[cfg(feature = "chrono")]
1705 {
1706 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1707 let date = base + chrono::Duration::days(days);
1708 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1709 minutes * 60,
1710 0,
1711 )
1712 .unwrap();
1713 SqlValue::DateTime(date.and_time(time))
1714 }
1715 #[cfg(not(feature = "chrono"))]
1716 {
1717 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1718 }
1719 }
1720 8 => {
1721 let days = buf.get_i32_le() as i64;
1723 let time_300ths = buf.get_u32_le() as u64;
1724 #[cfg(feature = "chrono")]
1725 {
1726 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1727 let date = base + chrono::Duration::days(days);
1728 let total_ms = (time_300ths * 1000) / 300;
1730 let secs = (total_ms / 1000) as u32;
1731 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1732 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1733 secs, nanos,
1734 )
1735 .unwrap();
1736 SqlValue::DateTime(date.and_time(time))
1737 }
1738 #[cfg(not(feature = "chrono"))]
1739 {
1740 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1741 }
1742 }
1743 _ => {
1744 return Err(Error::Protocol(format!(
1745 "invalid DateTimeN length: {len}"
1746 )));
1747 }
1748 }
1749 }
1750 }
1751
1752 TypeId::DateTime => {
1754 if buf.remaining() < 8 {
1755 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1756 }
1757 let days = buf.get_i32_le() as i64;
1758 let time_300ths = buf.get_u32_le() as u64;
1759 #[cfg(feature = "chrono")]
1760 {
1761 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1762 let date = base + chrono::Duration::days(days);
1763 let total_ms = (time_300ths * 1000) / 300;
1764 let secs = (total_ms / 1000) as u32;
1765 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1766 let time =
1767 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1768 SqlValue::DateTime(date.and_time(time))
1769 }
1770 #[cfg(not(feature = "chrono"))]
1771 {
1772 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1773 }
1774 }
1775
1776 TypeId::DateTime4 => {
1778 if buf.remaining() < 4 {
1779 return Err(Error::Protocol(
1780 "unexpected EOF reading SMALLDATETIME".into(),
1781 ));
1782 }
1783 let days = buf.get_u16_le() as i64;
1784 let minutes = buf.get_u16_le() as u32;
1785 #[cfg(feature = "chrono")]
1786 {
1787 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1788 let date = base + chrono::Duration::days(days);
1789 let time =
1790 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1791 .unwrap();
1792 SqlValue::DateTime(date.and_time(time))
1793 }
1794 #[cfg(not(feature = "chrono"))]
1795 {
1796 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1797 }
1798 }
1799
1800 TypeId::Date => {
1802 if buf.remaining() < 1 {
1803 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1804 }
1805 let len = buf.get_u8() as usize;
1806 if len == 0 {
1807 SqlValue::Null
1808 } else if len != 3 {
1809 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1810 } else if buf.remaining() < 3 {
1811 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1812 } else {
1813 let days = buf.get_u8() as u32
1815 | ((buf.get_u8() as u32) << 8)
1816 | ((buf.get_u8() as u32) << 16);
1817 #[cfg(feature = "chrono")]
1818 {
1819 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1820 let date = base + chrono::Duration::days(days as i64);
1821 SqlValue::Date(date)
1822 }
1823 #[cfg(not(feature = "chrono"))]
1824 {
1825 SqlValue::String(format!("DATE({days})"))
1826 }
1827 }
1828 }
1829
1830 TypeId::Time => {
1832 if buf.remaining() < 1 {
1833 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1834 }
1835 let len = buf.get_u8() as usize;
1836 if len == 0 {
1837 SqlValue::Null
1838 } else if buf.remaining() < len {
1839 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1840 } else {
1841 let scale = col.type_info.scale.unwrap_or(7);
1842 let mut time_bytes = [0u8; 8];
1843 for byte in time_bytes.iter_mut().take(len) {
1844 *byte = buf.get_u8();
1845 }
1846 let intervals = u64::from_le_bytes(time_bytes);
1847 #[cfg(feature = "chrono")]
1848 {
1849 let time = Self::intervals_to_time(intervals, scale);
1850 SqlValue::Time(time)
1851 }
1852 #[cfg(not(feature = "chrono"))]
1853 {
1854 SqlValue::String(format!("TIME({intervals})"))
1855 }
1856 }
1857 }
1858
1859 TypeId::DateTime2 => {
1861 if buf.remaining() < 1 {
1862 return Err(Error::Protocol(
1863 "unexpected EOF reading DATETIME2 length".into(),
1864 ));
1865 }
1866 let len = buf.get_u8() as usize;
1867 if len == 0 {
1868 SqlValue::Null
1869 } else if buf.remaining() < len {
1870 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1871 } else {
1872 let scale = col.type_info.scale.unwrap_or(7);
1873 let time_len = Self::time_bytes_for_scale(scale);
1874
1875 let mut time_bytes = [0u8; 8];
1877 for byte in time_bytes.iter_mut().take(time_len) {
1878 *byte = buf.get_u8();
1879 }
1880 let intervals = u64::from_le_bytes(time_bytes);
1881
1882 let days = buf.get_u8() as u32
1884 | ((buf.get_u8() as u32) << 8)
1885 | ((buf.get_u8() as u32) << 16);
1886
1887 #[cfg(feature = "chrono")]
1888 {
1889 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1890 let date = base + chrono::Duration::days(days as i64);
1891 let time = Self::intervals_to_time(intervals, scale);
1892 SqlValue::DateTime(date.and_time(time))
1893 }
1894 #[cfg(not(feature = "chrono"))]
1895 {
1896 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1897 }
1898 }
1899 }
1900
1901 TypeId::DateTimeOffset => {
1903 if buf.remaining() < 1 {
1904 return Err(Error::Protocol(
1905 "unexpected EOF reading DATETIMEOFFSET length".into(),
1906 ));
1907 }
1908 let len = buf.get_u8() as usize;
1909 if len == 0 {
1910 SqlValue::Null
1911 } else if buf.remaining() < len {
1912 return Err(Error::Protocol(
1913 "unexpected EOF reading DATETIMEOFFSET".into(),
1914 ));
1915 } else {
1916 let scale = col.type_info.scale.unwrap_or(7);
1917 let time_len = Self::time_bytes_for_scale(scale);
1918
1919 let mut time_bytes = [0u8; 8];
1921 for byte in time_bytes.iter_mut().take(time_len) {
1922 *byte = buf.get_u8();
1923 }
1924 let intervals = u64::from_le_bytes(time_bytes);
1925
1926 let days = buf.get_u8() as u32
1928 | ((buf.get_u8() as u32) << 8)
1929 | ((buf.get_u8() as u32) << 16);
1930
1931 let offset_minutes = buf.get_i16_le();
1933
1934 #[cfg(feature = "chrono")]
1935 {
1936 use chrono::TimeZone;
1937 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1938 let date = base + chrono::Duration::days(days as i64);
1939 let time = Self::intervals_to_time(intervals, scale);
1940 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1941 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1942 let datetime = offset
1943 .from_local_datetime(&date.and_time(time))
1944 .single()
1945 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1946 SqlValue::DateTimeOffset(datetime)
1947 }
1948 #[cfg(not(feature = "chrono"))]
1949 {
1950 SqlValue::String(format!(
1951 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
1952 ))
1953 }
1954 }
1955 }
1956
1957 TypeId::Text => Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?,
1959
1960 TypeId::Char | TypeId::VarChar => {
1962 if buf.remaining() < 1 {
1963 return Err(Error::Protocol(
1964 "unexpected EOF reading legacy varchar length".into(),
1965 ));
1966 }
1967 let len = buf.get_u8();
1968 if len == 0xFF {
1969 SqlValue::Null
1970 } else if len == 0 {
1971 SqlValue::String(String::new())
1972 } else if buf.remaining() < len as usize {
1973 return Err(Error::Protocol(
1974 "unexpected EOF reading legacy varchar data".into(),
1975 ));
1976 } else {
1977 let data = &buf[..len as usize];
1978 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
1980 buf.advance(len as usize);
1981 SqlValue::String(s)
1982 }
1983 }
1984
1985 TypeId::BigVarChar | TypeId::BigChar => {
1987 if col.type_info.max_length == Some(0xFFFF) {
1989 Self::parse_plp_varchar(buf, col.type_info.collation.as_ref())?
1991 } else {
1992 if buf.remaining() < 2 {
1994 return Err(Error::Protocol(
1995 "unexpected EOF reading varchar length".into(),
1996 ));
1997 }
1998 let len = buf.get_u16_le();
1999 if len == 0xFFFF {
2000 SqlValue::Null
2001 } else if buf.remaining() < len as usize {
2002 return Err(Error::Protocol(
2003 "unexpected EOF reading varchar data".into(),
2004 ));
2005 } else {
2006 let data = &buf[..len as usize];
2007 let s = Self::decode_varchar_string(data, col.type_info.collation.as_ref());
2009 buf.advance(len as usize);
2010 SqlValue::String(s)
2011 }
2012 }
2013 }
2014
2015 TypeId::NText => Self::parse_plp_nvarchar(buf)?,
2017
2018 TypeId::NVarChar | TypeId::NChar => {
2020 if col.type_info.max_length == Some(0xFFFF) {
2022 Self::parse_plp_nvarchar(buf)?
2024 } else {
2025 if buf.remaining() < 2 {
2027 return Err(Error::Protocol(
2028 "unexpected EOF reading nvarchar length".into(),
2029 ));
2030 }
2031 let len = buf.get_u16_le();
2032 if len == 0xFFFF {
2033 SqlValue::Null
2034 } else if buf.remaining() < len as usize {
2035 return Err(Error::Protocol(
2036 "unexpected EOF reading nvarchar data".into(),
2037 ));
2038 } else {
2039 let data = &buf[..len as usize];
2040 let utf16: Vec<u16> = data
2042 .chunks_exact(2)
2043 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2044 .collect();
2045 let s = String::from_utf16(&utf16)
2046 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
2047 buf.advance(len as usize);
2048 SqlValue::String(s)
2049 }
2050 }
2051 }
2052
2053 TypeId::Image => Self::parse_plp_varbinary(buf)?,
2055
2056 TypeId::Binary | TypeId::VarBinary => {
2058 if buf.remaining() < 1 {
2059 return Err(Error::Protocol(
2060 "unexpected EOF reading legacy varbinary length".into(),
2061 ));
2062 }
2063 let len = buf.get_u8();
2064 if len == 0xFF {
2065 SqlValue::Null
2066 } else if len == 0 {
2067 SqlValue::Binary(bytes::Bytes::new())
2068 } else if buf.remaining() < len as usize {
2069 return Err(Error::Protocol(
2070 "unexpected EOF reading legacy varbinary data".into(),
2071 ));
2072 } else {
2073 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2074 buf.advance(len as usize);
2075 SqlValue::Binary(data)
2076 }
2077 }
2078
2079 TypeId::BigVarBinary | TypeId::BigBinary => {
2081 if col.type_info.max_length == Some(0xFFFF) {
2083 Self::parse_plp_varbinary(buf)?
2085 } else {
2086 if buf.remaining() < 2 {
2087 return Err(Error::Protocol(
2088 "unexpected EOF reading varbinary length".into(),
2089 ));
2090 }
2091 let len = buf.get_u16_le();
2092 if len == 0xFFFF {
2093 SqlValue::Null
2094 } else if buf.remaining() < len as usize {
2095 return Err(Error::Protocol(
2096 "unexpected EOF reading varbinary data".into(),
2097 ));
2098 } else {
2099 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2100 buf.advance(len as usize);
2101 SqlValue::Binary(data)
2102 }
2103 }
2104 }
2105
2106 TypeId::Xml => {
2108 match Self::parse_plp_nvarchar(buf)? {
2110 SqlValue::Null => SqlValue::Null,
2111 SqlValue::String(s) => SqlValue::Xml(s),
2112 _ => {
2113 return Err(Error::Protocol(
2114 "unexpected value type when parsing XML".into(),
2115 ));
2116 }
2117 }
2118 }
2119
2120 TypeId::Guid => {
2122 if buf.remaining() < 1 {
2123 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
2124 }
2125 let len = buf.get_u8();
2126 if len == 0 {
2127 SqlValue::Null
2128 } else if len != 16 {
2129 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
2130 } else if buf.remaining() < 16 {
2131 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
2132 } else {
2133 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
2135 buf.advance(16);
2136 SqlValue::Binary(data)
2137 }
2138 }
2139
2140 TypeId::Variant => Self::parse_sql_variant(buf)?,
2142
2143 TypeId::Udt => Self::parse_plp_varbinary(buf)?,
2145
2146 _ => {
2148 if buf.remaining() < 2 {
2150 return Err(Error::Protocol(format!(
2151 "unexpected EOF reading {:?}",
2152 col.type_id
2153 )));
2154 }
2155 let len = buf.get_u16_le();
2156 if len == 0xFFFF {
2157 SqlValue::Null
2158 } else if buf.remaining() < len as usize {
2159 return Err(Error::Protocol(format!(
2160 "unexpected EOF reading {:?} data",
2161 col.type_id
2162 )));
2163 } else {
2164 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2165 buf.advance(len as usize);
2166 SqlValue::Binary(data)
2167 }
2168 }
2169 };
2170
2171 Ok(value)
2172 }
2173
2174 fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2180 use bytes::Buf;
2181 use mssql_types::SqlValue;
2182
2183 if buf.remaining() < 8 {
2184 return Err(Error::Protocol(
2185 "unexpected EOF reading PLP total length".into(),
2186 ));
2187 }
2188
2189 let total_len = buf.get_u64_le();
2190 if total_len == 0xFFFFFFFFFFFFFFFF {
2191 return Ok(SqlValue::Null);
2192 }
2193
2194 let mut all_data = Vec::new();
2196 loop {
2197 if buf.remaining() < 4 {
2198 return Err(Error::Protocol(
2199 "unexpected EOF reading PLP chunk length".into(),
2200 ));
2201 }
2202 let chunk_len = buf.get_u32_le() as usize;
2203 if chunk_len == 0 {
2204 break; }
2206 if buf.remaining() < chunk_len {
2207 return Err(Error::Protocol(
2208 "unexpected EOF reading PLP chunk data".into(),
2209 ));
2210 }
2211 all_data.extend_from_slice(&buf[..chunk_len]);
2212 buf.advance(chunk_len);
2213 }
2214
2215 let utf16: Vec<u16> = all_data
2217 .chunks_exact(2)
2218 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2219 .collect();
2220 let s = String::from_utf16(&utf16)
2221 .map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
2222 Ok(SqlValue::String(s))
2223 }
2224
2225 #[allow(unused_variables)]
2231 fn decode_varchar_string(data: &[u8], collation: Option<&Collation>) -> String {
2232 if let Ok(s) = std::str::from_utf8(data) {
2234 return s.to_owned();
2235 }
2236
2237 #[cfg(feature = "encoding")]
2239 if let Some(coll) = collation {
2240 if let Some(encoding) = coll.encoding() {
2241 let (decoded, _, had_errors) = encoding.decode(data);
2242 if !had_errors {
2243 return decoded.into_owned();
2244 }
2245 }
2246 }
2247
2248 String::from_utf8_lossy(data).into_owned()
2250 }
2251
2252 fn parse_plp_varchar(
2254 buf: &mut &[u8],
2255 collation: Option<&Collation>,
2256 ) -> Result<mssql_types::SqlValue> {
2257 use bytes::Buf;
2258 use mssql_types::SqlValue;
2259
2260 if buf.remaining() < 8 {
2261 return Err(Error::Protocol(
2262 "unexpected EOF reading PLP total length".into(),
2263 ));
2264 }
2265
2266 let total_len = buf.get_u64_le();
2267 if total_len == 0xFFFFFFFFFFFFFFFF {
2268 return Ok(SqlValue::Null);
2269 }
2270
2271 let mut all_data = Vec::new();
2273 loop {
2274 if buf.remaining() < 4 {
2275 return Err(Error::Protocol(
2276 "unexpected EOF reading PLP chunk length".into(),
2277 ));
2278 }
2279 let chunk_len = buf.get_u32_le() as usize;
2280 if chunk_len == 0 {
2281 break; }
2283 if buf.remaining() < chunk_len {
2284 return Err(Error::Protocol(
2285 "unexpected EOF reading PLP chunk data".into(),
2286 ));
2287 }
2288 all_data.extend_from_slice(&buf[..chunk_len]);
2289 buf.advance(chunk_len);
2290 }
2291
2292 let s = Self::decode_varchar_string(&all_data, collation);
2294 Ok(SqlValue::String(s))
2295 }
2296
2297 fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2299 use bytes::Buf;
2300 use mssql_types::SqlValue;
2301
2302 if buf.remaining() < 8 {
2303 return Err(Error::Protocol(
2304 "unexpected EOF reading PLP total length".into(),
2305 ));
2306 }
2307
2308 let total_len = buf.get_u64_le();
2309 if total_len == 0xFFFFFFFFFFFFFFFF {
2310 return Ok(SqlValue::Null);
2311 }
2312
2313 let mut all_data = Vec::new();
2315 loop {
2316 if buf.remaining() < 4 {
2317 return Err(Error::Protocol(
2318 "unexpected EOF reading PLP chunk length".into(),
2319 ));
2320 }
2321 let chunk_len = buf.get_u32_le() as usize;
2322 if chunk_len == 0 {
2323 break; }
2325 if buf.remaining() < chunk_len {
2326 return Err(Error::Protocol(
2327 "unexpected EOF reading PLP chunk data".into(),
2328 ));
2329 }
2330 all_data.extend_from_slice(&buf[..chunk_len]);
2331 buf.advance(chunk_len);
2332 }
2333
2334 Ok(SqlValue::Binary(bytes::Bytes::from(all_data)))
2335 }
2336
2337 fn parse_sql_variant(buf: &mut &[u8]) -> Result<mssql_types::SqlValue> {
2346 use bytes::Buf;
2347 use mssql_types::SqlValue;
2348
2349 if buf.remaining() < 4 {
2351 return Err(Error::Protocol(
2352 "unexpected EOF reading SQL_VARIANT length".into(),
2353 ));
2354 }
2355 let total_len = buf.get_u32_le() as usize;
2356
2357 if total_len == 0 {
2358 return Ok(SqlValue::Null);
2359 }
2360
2361 if buf.remaining() < total_len {
2362 return Err(Error::Protocol(
2363 "unexpected EOF reading SQL_VARIANT data".into(),
2364 ));
2365 }
2366
2367 if total_len < 2 {
2369 return Err(Error::Protocol(
2370 "SQL_VARIANT too short for type info".into(),
2371 ));
2372 }
2373
2374 let base_type = buf.get_u8();
2375 let prop_count = buf.get_u8() as usize;
2376
2377 if buf.remaining() < prop_count {
2378 return Err(Error::Protocol(
2379 "unexpected EOF reading SQL_VARIANT properties".into(),
2380 ));
2381 }
2382
2383 let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
2385
2386 match base_type {
2389 0x30 => {
2391 buf.advance(prop_count);
2393 if data_len < 1 {
2394 return Ok(SqlValue::Null);
2395 }
2396 let v = buf.get_u8();
2397 Ok(SqlValue::TinyInt(v))
2398 }
2399 0x32 => {
2400 buf.advance(prop_count);
2402 if data_len < 1 {
2403 return Ok(SqlValue::Null);
2404 }
2405 let v = buf.get_u8();
2406 Ok(SqlValue::Bool(v != 0))
2407 }
2408 0x34 => {
2409 buf.advance(prop_count);
2411 if data_len < 2 {
2412 return Ok(SqlValue::Null);
2413 }
2414 let v = buf.get_i16_le();
2415 Ok(SqlValue::SmallInt(v))
2416 }
2417 0x38 => {
2418 buf.advance(prop_count);
2420 if data_len < 4 {
2421 return Ok(SqlValue::Null);
2422 }
2423 let v = buf.get_i32_le();
2424 Ok(SqlValue::Int(v))
2425 }
2426 0x7F => {
2427 buf.advance(prop_count);
2429 if data_len < 8 {
2430 return Ok(SqlValue::Null);
2431 }
2432 let v = buf.get_i64_le();
2433 Ok(SqlValue::BigInt(v))
2434 }
2435 0x6D => {
2436 let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2438 buf.advance(prop_count.saturating_sub(1));
2439
2440 if float_len == 4 && data_len >= 4 {
2441 let v = buf.get_f32_le();
2442 Ok(SqlValue::Float(v))
2443 } else if data_len >= 8 {
2444 let v = buf.get_f64_le();
2445 Ok(SqlValue::Double(v))
2446 } else {
2447 Ok(SqlValue::Null)
2448 }
2449 }
2450 0x6E => {
2451 let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2453 buf.advance(prop_count.saturating_sub(1));
2454
2455 if money_len == 4 && data_len >= 4 {
2456 let raw = buf.get_i32_le();
2457 let value = raw as f64 / 10000.0;
2458 Ok(SqlValue::Double(value))
2459 } else if data_len >= 8 {
2460 let high = buf.get_i32_le() as i64;
2461 let low = buf.get_u32_le() as i64;
2462 let raw = (high << 32) | low;
2463 let value = raw as f64 / 10000.0;
2464 Ok(SqlValue::Double(value))
2465 } else {
2466 Ok(SqlValue::Null)
2467 }
2468 }
2469 0x6F => {
2470 let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
2472 buf.advance(prop_count.saturating_sub(1));
2473
2474 #[cfg(feature = "chrono")]
2475 {
2476 use chrono::NaiveDate;
2477 if dt_len == 4 && data_len >= 4 {
2478 let days = buf.get_u16_le() as i64;
2480 let mins = buf.get_u16_le() as u32;
2481 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2482 .unwrap()
2483 .and_hms_opt(0, 0, 0)
2484 .unwrap();
2485 let dt = base
2486 + chrono::Duration::days(days)
2487 + chrono::Duration::minutes(mins as i64);
2488 Ok(SqlValue::DateTime(dt))
2489 } else if data_len >= 8 {
2490 let days = buf.get_i32_le() as i64;
2492 let ticks = buf.get_u32_le() as i64;
2493 let base = NaiveDate::from_ymd_opt(1900, 1, 1)
2494 .unwrap()
2495 .and_hms_opt(0, 0, 0)
2496 .unwrap();
2497 let millis = (ticks * 10) / 3;
2498 let dt = base
2499 + chrono::Duration::days(days)
2500 + chrono::Duration::milliseconds(millis);
2501 Ok(SqlValue::DateTime(dt))
2502 } else {
2503 Ok(SqlValue::Null)
2504 }
2505 }
2506 #[cfg(not(feature = "chrono"))]
2507 {
2508 buf.advance(data_len);
2509 Ok(SqlValue::Null)
2510 }
2511 }
2512 0x6A | 0x6C => {
2513 let _precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
2515 let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
2516 buf.advance(prop_count.saturating_sub(2));
2517
2518 if data_len < 1 {
2519 return Ok(SqlValue::Null);
2520 }
2521
2522 let sign = buf.get_u8();
2523 let mantissa_len = data_len - 1;
2524
2525 if mantissa_len > 16 {
2526 buf.advance(mantissa_len);
2528 return Ok(SqlValue::Null);
2529 }
2530
2531 let mut mantissa_bytes = [0u8; 16];
2532 for i in 0..mantissa_len.min(16) {
2533 mantissa_bytes[i] = buf.get_u8();
2534 }
2535 let mantissa = u128::from_le_bytes(mantissa_bytes);
2536
2537 #[cfg(feature = "decimal")]
2538 {
2539 use rust_decimal::Decimal;
2540 if scale > 28 {
2541 let divisor = 10f64.powi(scale as i32);
2543 let value = (mantissa as f64) / divisor;
2544 let value = if sign == 0 { -value } else { value };
2545 Ok(SqlValue::Double(value))
2546 } else {
2547 let mut decimal =
2548 Decimal::from_i128_with_scale(mantissa as i128, scale as u32);
2549 if sign == 0 {
2550 decimal.set_sign_negative(true);
2551 }
2552 Ok(SqlValue::Decimal(decimal))
2553 }
2554 }
2555 #[cfg(not(feature = "decimal"))]
2556 {
2557 let divisor = 10f64.powi(scale as i32);
2558 let value = (mantissa as f64) / divisor;
2559 let value = if sign == 0 { -value } else { value };
2560 Ok(SqlValue::Double(value))
2561 }
2562 }
2563 0x24 => {
2564 buf.advance(prop_count);
2566 if data_len < 16 {
2567 return Ok(SqlValue::Null);
2568 }
2569 let mut guid_bytes = [0u8; 16];
2570 for byte in &mut guid_bytes {
2571 *byte = buf.get_u8();
2572 }
2573 Ok(SqlValue::Binary(bytes::Bytes::copy_from_slice(&guid_bytes)))
2574 }
2575 0x28 => {
2576 buf.advance(prop_count);
2578 #[cfg(feature = "chrono")]
2579 {
2580 if data_len < 3 {
2581 return Ok(SqlValue::Null);
2582 }
2583 let mut date_bytes = [0u8; 4];
2584 date_bytes[0] = buf.get_u8();
2585 date_bytes[1] = buf.get_u8();
2586 date_bytes[2] = buf.get_u8();
2587 let days = u32::from_le_bytes(date_bytes);
2588 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
2589 let date = base + chrono::Duration::days(days as i64);
2590 Ok(SqlValue::Date(date))
2591 }
2592 #[cfg(not(feature = "chrono"))]
2593 {
2594 buf.advance(data_len);
2595 Ok(SqlValue::Null)
2596 }
2597 }
2598 0xA7 | 0x2F | 0x27 => {
2599 let collation = if prop_count >= 5 && buf.remaining() >= 5 {
2602 let lcid = buf.get_u32_le();
2603 let sort_id = buf.get_u8();
2604 buf.advance(prop_count.saturating_sub(5)); Some(Collation { lcid, sort_id })
2606 } else {
2607 buf.advance(prop_count);
2608 None
2609 };
2610 if data_len == 0 {
2611 return Ok(SqlValue::String(String::new()));
2612 }
2613 let data = &buf[..data_len];
2614 let s = Self::decode_varchar_string(data, collation.as_ref());
2616 buf.advance(data_len);
2617 Ok(SqlValue::String(s))
2618 }
2619 0xE7 | 0xEF => {
2620 buf.advance(prop_count);
2622 if data_len == 0 {
2623 return Ok(SqlValue::String(String::new()));
2624 }
2625 let utf16: Vec<u16> = buf[..data_len]
2627 .chunks_exact(2)
2628 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
2629 .collect();
2630 buf.advance(data_len);
2631 let s = String::from_utf16(&utf16).map_err(|_| {
2632 Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into())
2633 })?;
2634 Ok(SqlValue::String(s))
2635 }
2636 0xA5 | 0x2D | 0x25 => {
2637 buf.advance(prop_count);
2639 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2640 buf.advance(data_len);
2641 Ok(SqlValue::Binary(data))
2642 }
2643 _ => {
2644 buf.advance(prop_count);
2646 let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
2647 buf.advance(data_len);
2648 Ok(SqlValue::Binary(data))
2649 }
2650 }
2651 }
2652
2653 fn time_bytes_for_scale(scale: u8) -> usize {
2655 match scale {
2656 0..=2 => 3,
2657 3..=4 => 4,
2658 5..=7 => 5,
2659 _ => 5, }
2661 }
2662
2663 #[cfg(feature = "chrono")]
2665 fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
2666 let nanos = match scale {
2676 0 => intervals * 1_000_000_000,
2677 1 => intervals * 100_000_000,
2678 2 => intervals * 10_000_000,
2679 3 => intervals * 1_000_000,
2680 4 => intervals * 100_000,
2681 5 => intervals * 10_000,
2682 6 => intervals * 1_000,
2683 7 => intervals * 100,
2684 _ => intervals * 100,
2685 };
2686
2687 let secs = (nanos / 1_000_000_000) as u32;
2688 let nano_part = (nanos % 1_000_000_000) as u32;
2689
2690 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
2691 .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
2692 }
2693
2694 async fn read_execute_result(&mut self) -> Result<u64> {
2696 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2697
2698 let message = match connection {
2699 ConnectionHandle::Tls(conn) => conn
2700 .read_message()
2701 .await
2702 .map_err(|e| Error::Protocol(e.to_string()))?,
2703 ConnectionHandle::TlsPrelogin(conn) => conn
2704 .read_message()
2705 .await
2706 .map_err(|e| Error::Protocol(e.to_string()))?,
2707 ConnectionHandle::Plain(conn) => conn
2708 .read_message()
2709 .await
2710 .map_err(|e| Error::Protocol(e.to_string()))?,
2711 }
2712 .ok_or(Error::ConnectionClosed)?;
2713
2714 let mut parser = TokenParser::new(message.payload);
2715 let mut rows_affected = 0u64;
2716 let mut current_metadata: Option<ColMetaData> = None;
2717
2718 loop {
2719 let token = parser
2721 .next_token_with_metadata(current_metadata.as_ref())
2722 .map_err(|e| Error::Protocol(e.to_string()))?;
2723
2724 let Some(token) = token else {
2725 break;
2726 };
2727
2728 match token {
2729 Token::ColMetaData(meta) => {
2730 current_metadata = Some(meta);
2732 }
2733 Token::Row(_) | Token::NbcRow(_) => {
2734 }
2737 Token::Done(done) => {
2738 if done.status.error {
2739 return Err(Error::Query("execution failed".to_string()));
2740 }
2741 if done.status.count {
2742 rows_affected += done.row_count;
2744 }
2745 if !done.status.more {
2748 break;
2749 }
2750 }
2751 Token::DoneProc(done) => {
2752 if done.status.count {
2753 rows_affected += done.row_count;
2754 }
2755 }
2756 Token::DoneInProc(done) => {
2757 if done.status.count {
2758 rows_affected += done.row_count;
2759 }
2760 }
2761 Token::Error(err) => {
2762 return Err(Error::Server {
2763 number: err.number,
2764 state: err.state,
2765 class: err.class,
2766 message: err.message.clone(),
2767 server: if err.server.is_empty() {
2768 None
2769 } else {
2770 Some(err.server.clone())
2771 },
2772 procedure: if err.procedure.is_empty() {
2773 None
2774 } else {
2775 Some(err.procedure.clone())
2776 },
2777 line: err.line as u32,
2778 });
2779 }
2780 Token::Info(info) => {
2781 tracing::info!(
2782 number = info.number,
2783 message = %info.message,
2784 "server info message"
2785 );
2786 }
2787 Token::EnvChange(env) => {
2788 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
2792 }
2793 _ => {}
2794 }
2795 }
2796
2797 Ok(rows_affected)
2798 }
2799
2800 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
2806 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2807
2808 let message = match connection {
2809 ConnectionHandle::Tls(conn) => conn
2810 .read_message()
2811 .await
2812 .map_err(|e| Error::Protocol(e.to_string()))?,
2813 ConnectionHandle::TlsPrelogin(conn) => conn
2814 .read_message()
2815 .await
2816 .map_err(|e| Error::Protocol(e.to_string()))?,
2817 ConnectionHandle::Plain(conn) => conn
2818 .read_message()
2819 .await
2820 .map_err(|e| Error::Protocol(e.to_string()))?,
2821 }
2822 .ok_or(Error::ConnectionClosed)?;
2823
2824 let mut parser = TokenParser::new(message.payload);
2825 let mut transaction_descriptor: u64 = 0;
2826
2827 loop {
2828 let token = parser
2829 .next_token()
2830 .map_err(|e| Error::Protocol(e.to_string()))?;
2831
2832 let Some(token) = token else {
2833 break;
2834 };
2835
2836 match token {
2837 Token::EnvChange(env) => {
2838 if env.env_type == EnvChangeType::BeginTransaction {
2839 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
2842 {
2843 if data.len() >= 8 {
2844 transaction_descriptor = u64::from_le_bytes([
2845 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
2846 data[7],
2847 ]);
2848 tracing::debug!(
2849 transaction_descriptor =
2850 format!("0x{:016X}", transaction_descriptor),
2851 "transaction begun"
2852 );
2853 }
2854 }
2855 }
2856 }
2857 Token::Done(done) => {
2858 if done.status.error {
2859 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
2860 }
2861 break;
2862 }
2863 Token::Error(err) => {
2864 return Err(Error::Server {
2865 number: err.number,
2866 state: err.state,
2867 class: err.class,
2868 message: err.message.clone(),
2869 server: if err.server.is_empty() {
2870 None
2871 } else {
2872 Some(err.server.clone())
2873 },
2874 procedure: if err.procedure.is_empty() {
2875 None
2876 } else {
2877 Some(err.procedure.clone())
2878 },
2879 line: err.line as u32,
2880 });
2881 }
2882 Token::Info(info) => {
2883 tracing::info!(
2884 number = info.number,
2885 message = %info.message,
2886 "server info message"
2887 );
2888 }
2889 _ => {}
2890 }
2891 }
2892
2893 Ok(transaction_descriptor)
2894 }
2895}
2896
2897impl Client<Ready> {
2898 pub async fn query<'a>(
2923 &'a mut self,
2924 sql: &str,
2925 params: &[&(dyn crate::ToSql + Sync)],
2926 ) -> Result<QueryStream<'a>> {
2927 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
2928
2929 #[cfg(feature = "otel")]
2930 let instrumentation = self.instrumentation.clone();
2931 #[cfg(feature = "otel")]
2932 let mut span = instrumentation.query_span(sql);
2933
2934 let result = async {
2935 if params.is_empty() {
2936 self.send_sql_batch(sql).await?;
2938 } else {
2939 let rpc_params = Self::convert_params(params)?;
2941 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2942 self.send_rpc(&rpc).await?;
2943 }
2944
2945 self.read_query_response().await
2947 }
2948 .await;
2949
2950 #[cfg(feature = "otel")]
2951 match &result {
2952 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2953 Err(e) => InstrumentationContext::record_error(&mut span, e),
2954 }
2955
2956 #[cfg(feature = "otel")]
2958 drop(span);
2959
2960 let (columns, rows) = result?;
2961 Ok(QueryStream::new(columns, rows))
2962 }
2963
2964 pub async fn query_with_timeout<'a>(
2991 &'a mut self,
2992 sql: &str,
2993 params: &[&(dyn crate::ToSql + Sync)],
2994 timeout_duration: std::time::Duration,
2995 ) -> Result<QueryStream<'a>> {
2996 timeout(timeout_duration, self.query(sql, params))
2997 .await
2998 .map_err(|_| Error::CommandTimeout)?
2999 }
3000
3001 pub async fn query_multiple<'a>(
3028 &'a mut self,
3029 sql: &str,
3030 params: &[&(dyn crate::ToSql + Sync)],
3031 ) -> Result<MultiResultStream<'a>> {
3032 tracing::debug!(
3033 sql = sql,
3034 params_count = params.len(),
3035 "executing multi-result query"
3036 );
3037
3038 if params.is_empty() {
3039 self.send_sql_batch(sql).await?;
3041 } else {
3042 let rpc_params = Self::convert_params(params)?;
3044 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3045 self.send_rpc(&rpc).await?;
3046 }
3047
3048 let result_sets = self.read_multi_result_response().await?;
3050 Ok(MultiResultStream::new(result_sets))
3051 }
3052
3053 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
3055 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
3056
3057 let message = match connection {
3058 ConnectionHandle::Tls(conn) => conn
3059 .read_message()
3060 .await
3061 .map_err(|e| Error::Protocol(e.to_string()))?,
3062 ConnectionHandle::TlsPrelogin(conn) => conn
3063 .read_message()
3064 .await
3065 .map_err(|e| Error::Protocol(e.to_string()))?,
3066 ConnectionHandle::Plain(conn) => conn
3067 .read_message()
3068 .await
3069 .map_err(|e| Error::Protocol(e.to_string()))?,
3070 }
3071 .ok_or(Error::ConnectionClosed)?;
3072
3073 let mut parser = TokenParser::new(message.payload);
3074 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
3075 let mut current_columns: Vec<crate::row::Column> = Vec::new();
3076 let mut current_rows: Vec<crate::row::Row> = Vec::new();
3077 let mut protocol_metadata: Option<ColMetaData> = None;
3078
3079 loop {
3080 let token = parser
3081 .next_token_with_metadata(protocol_metadata.as_ref())
3082 .map_err(|e| Error::Protocol(e.to_string()))?;
3083
3084 let Some(token) = token else {
3085 break;
3086 };
3087
3088 match token {
3089 Token::ColMetaData(meta) => {
3090 if !current_columns.is_empty() {
3092 result_sets.push(crate::stream::ResultSet::new(
3093 std::mem::take(&mut current_columns),
3094 std::mem::take(&mut current_rows),
3095 ));
3096 }
3097
3098 current_columns = meta
3100 .columns
3101 .iter()
3102 .enumerate()
3103 .map(|(i, col)| {
3104 let type_name = format!("{:?}", col.type_id);
3105 let mut column = crate::row::Column::new(&col.name, i, type_name)
3106 .with_nullable(col.flags & 0x01 != 0);
3107
3108 if let Some(max_len) = col.type_info.max_length {
3109 column = column.with_max_length(max_len);
3110 }
3111 if let (Some(prec), Some(scale)) =
3112 (col.type_info.precision, col.type_info.scale)
3113 {
3114 column = column.with_precision_scale(prec, scale);
3115 }
3116 if let Some(collation) = col.type_info.collation {
3119 column = column.with_collation(collation);
3120 }
3121 column
3122 })
3123 .collect();
3124
3125 tracing::debug!(
3126 columns = current_columns.len(),
3127 result_set = result_sets.len(),
3128 "received column metadata for result set"
3129 );
3130 protocol_metadata = Some(meta);
3131 }
3132 Token::Row(raw_row) => {
3133 if let Some(ref meta) = protocol_metadata {
3134 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
3135 current_rows.push(row);
3136 }
3137 }
3138 Token::NbcRow(nbc_row) => {
3139 if let Some(ref meta) = protocol_metadata {
3140 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
3141 current_rows.push(row);
3142 }
3143 }
3144 Token::Error(err) => {
3145 return Err(Error::Server {
3146 number: err.number,
3147 state: err.state,
3148 class: err.class,
3149 message: err.message.clone(),
3150 server: if err.server.is_empty() {
3151 None
3152 } else {
3153 Some(err.server.clone())
3154 },
3155 procedure: if err.procedure.is_empty() {
3156 None
3157 } else {
3158 Some(err.procedure.clone())
3159 },
3160 line: err.line as u32,
3161 });
3162 }
3163 Token::Done(done) => {
3164 if done.status.error {
3165 return Err(Error::Query("query failed".to_string()));
3166 }
3167
3168 if !current_columns.is_empty() {
3170 result_sets.push(crate::stream::ResultSet::new(
3171 std::mem::take(&mut current_columns),
3172 std::mem::take(&mut current_rows),
3173 ));
3174 protocol_metadata = None;
3175 }
3176
3177 if !done.status.more {
3179 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
3180 break;
3181 }
3182 }
3183 Token::DoneInProc(done) => {
3184 if done.status.error {
3185 return Err(Error::Query("query failed".to_string()));
3186 }
3187
3188 if !current_columns.is_empty() {
3190 result_sets.push(crate::stream::ResultSet::new(
3191 std::mem::take(&mut current_columns),
3192 std::mem::take(&mut current_rows),
3193 ));
3194 protocol_metadata = None;
3195 }
3196
3197 if !done.status.more {
3199 }
3201 }
3202 Token::DoneProc(done) => {
3203 if done.status.error {
3204 return Err(Error::Query("query failed".to_string()));
3205 }
3206 }
3208 Token::Info(info) => {
3209 tracing::debug!(
3210 number = info.number,
3211 message = %info.message,
3212 "server info message"
3213 );
3214 }
3215 _ => {}
3216 }
3217 }
3218
3219 if !current_columns.is_empty() {
3221 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
3222 }
3223
3224 Ok(result_sets)
3225 }
3226
3227 pub async fn execute(
3231 &mut self,
3232 sql: &str,
3233 params: &[&(dyn crate::ToSql + Sync)],
3234 ) -> Result<u64> {
3235 tracing::debug!(
3236 sql = sql,
3237 params_count = params.len(),
3238 "executing statement"
3239 );
3240
3241 #[cfg(feature = "otel")]
3242 let instrumentation = self.instrumentation.clone();
3243 #[cfg(feature = "otel")]
3244 let mut span = instrumentation.query_span(sql);
3245
3246 let result = async {
3247 if params.is_empty() {
3248 self.send_sql_batch(sql).await?;
3250 } else {
3251 let rpc_params = Self::convert_params(params)?;
3253 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3254 self.send_rpc(&rpc).await?;
3255 }
3256
3257 self.read_execute_result().await
3259 }
3260 .await;
3261
3262 #[cfg(feature = "otel")]
3263 match &result {
3264 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3265 Err(e) => InstrumentationContext::record_error(&mut span, e),
3266 }
3267
3268 #[cfg(feature = "otel")]
3270 drop(span);
3271
3272 result
3273 }
3274
3275 pub async fn execute_with_timeout(
3302 &mut self,
3303 sql: &str,
3304 params: &[&(dyn crate::ToSql + Sync)],
3305 timeout_duration: std::time::Duration,
3306 ) -> Result<u64> {
3307 timeout(timeout_duration, self.execute(sql, params))
3308 .await
3309 .map_err(|_| Error::CommandTimeout)?
3310 }
3311
3312 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
3319 tracing::debug!("beginning transaction");
3320
3321 #[cfg(feature = "otel")]
3322 let instrumentation = self.instrumentation.clone();
3323 #[cfg(feature = "otel")]
3324 let mut span = instrumentation.transaction_span("BEGIN");
3325
3326 let result = async {
3328 self.send_sql_batch("BEGIN TRANSACTION").await?;
3329 self.read_transaction_begin_result().await
3330 }
3331 .await;
3332
3333 #[cfg(feature = "otel")]
3334 match &result {
3335 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3336 Err(e) => InstrumentationContext::record_error(&mut span, e),
3337 }
3338
3339 #[cfg(feature = "otel")]
3341 drop(span);
3342
3343 let transaction_descriptor = result?;
3344
3345 Ok(Client {
3346 config: self.config,
3347 _state: PhantomData,
3348 connection: self.connection,
3349 server_version: self.server_version,
3350 current_database: self.current_database,
3351 statement_cache: self.statement_cache,
3352 transaction_descriptor, #[cfg(feature = "otel")]
3354 instrumentation: self.instrumentation,
3355 })
3356 }
3357
3358 pub async fn begin_transaction_with_isolation(
3373 mut self,
3374 isolation_level: crate::transaction::IsolationLevel,
3375 ) -> Result<Client<InTransaction>> {
3376 tracing::debug!(
3377 isolation_level = %isolation_level.name(),
3378 "beginning transaction with isolation level"
3379 );
3380
3381 #[cfg(feature = "otel")]
3382 let instrumentation = self.instrumentation.clone();
3383 #[cfg(feature = "otel")]
3384 let mut span = instrumentation.transaction_span("BEGIN");
3385
3386 let result = async {
3388 self.send_sql_batch(isolation_level.as_sql()).await?;
3389 self.read_execute_result().await?;
3390
3391 self.send_sql_batch("BEGIN TRANSACTION").await?;
3393 self.read_transaction_begin_result().await
3394 }
3395 .await;
3396
3397 #[cfg(feature = "otel")]
3398 match &result {
3399 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3400 Err(e) => InstrumentationContext::record_error(&mut span, e),
3401 }
3402
3403 #[cfg(feature = "otel")]
3404 drop(span);
3405
3406 let transaction_descriptor = result?;
3407
3408 Ok(Client {
3409 config: self.config,
3410 _state: PhantomData,
3411 connection: self.connection,
3412 server_version: self.server_version,
3413 current_database: self.current_database,
3414 statement_cache: self.statement_cache,
3415 transaction_descriptor,
3416 #[cfg(feature = "otel")]
3417 instrumentation: self.instrumentation,
3418 })
3419 }
3420
3421 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
3426 tracing::debug!(sql = sql, "executing simple query");
3427
3428 self.send_sql_batch(sql).await?;
3430
3431 let _ = self.read_execute_result().await?;
3433
3434 Ok(())
3435 }
3436
3437 pub async fn close(self) -> Result<()> {
3439 tracing::debug!("closing connection");
3440 Ok(())
3441 }
3442
3443 #[must_use]
3445 pub fn database(&self) -> Option<&str> {
3446 self.config.database.as_deref()
3447 }
3448
3449 #[must_use]
3451 pub fn host(&self) -> &str {
3452 &self.config.host
3453 }
3454
3455 #[must_use]
3457 pub fn port(&self) -> u16 {
3458 self.config.port
3459 }
3460
3461 #[must_use]
3480 pub fn is_in_transaction(&self) -> bool {
3481 self.transaction_descriptor != 0
3482 }
3483
3484 #[must_use]
3506 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3507 let connection = self
3508 .connection
3509 .as_ref()
3510 .expect("connection should be present");
3511 match connection {
3512 ConnectionHandle::Tls(conn) => {
3513 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3514 }
3515 ConnectionHandle::TlsPrelogin(conn) => {
3516 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3517 }
3518 ConnectionHandle::Plain(conn) => {
3519 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3520 }
3521 }
3522 }
3523}
3524
3525impl Client<InTransaction> {
3526 pub async fn query<'a>(
3530 &'a mut self,
3531 sql: &str,
3532 params: &[&(dyn crate::ToSql + Sync)],
3533 ) -> Result<QueryStream<'a>> {
3534 tracing::debug!(
3535 sql = sql,
3536 params_count = params.len(),
3537 "executing query in transaction"
3538 );
3539
3540 #[cfg(feature = "otel")]
3541 let instrumentation = self.instrumentation.clone();
3542 #[cfg(feature = "otel")]
3543 let mut span = instrumentation.query_span(sql);
3544
3545 let result = async {
3546 if params.is_empty() {
3547 self.send_sql_batch(sql).await?;
3549 } else {
3550 let rpc_params = Self::convert_params(params)?;
3552 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3553 self.send_rpc(&rpc).await?;
3554 }
3555
3556 self.read_query_response().await
3558 }
3559 .await;
3560
3561 #[cfg(feature = "otel")]
3562 match &result {
3563 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3564 Err(e) => InstrumentationContext::record_error(&mut span, e),
3565 }
3566
3567 #[cfg(feature = "otel")]
3569 drop(span);
3570
3571 let (columns, rows) = result?;
3572 Ok(QueryStream::new(columns, rows))
3573 }
3574
3575 pub async fn execute(
3579 &mut self,
3580 sql: &str,
3581 params: &[&(dyn crate::ToSql + Sync)],
3582 ) -> Result<u64> {
3583 tracing::debug!(
3584 sql = sql,
3585 params_count = params.len(),
3586 "executing statement in transaction"
3587 );
3588
3589 #[cfg(feature = "otel")]
3590 let instrumentation = self.instrumentation.clone();
3591 #[cfg(feature = "otel")]
3592 let mut span = instrumentation.query_span(sql);
3593
3594 let result = async {
3595 if params.is_empty() {
3596 self.send_sql_batch(sql).await?;
3598 } else {
3599 let rpc_params = Self::convert_params(params)?;
3601 let rpc = RpcRequest::execute_sql(sql, rpc_params);
3602 self.send_rpc(&rpc).await?;
3603 }
3604
3605 self.read_execute_result().await
3607 }
3608 .await;
3609
3610 #[cfg(feature = "otel")]
3611 match &result {
3612 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
3613 Err(e) => InstrumentationContext::record_error(&mut span, e),
3614 }
3615
3616 #[cfg(feature = "otel")]
3618 drop(span);
3619
3620 result
3621 }
3622
3623 pub async fn query_with_timeout<'a>(
3627 &'a mut self,
3628 sql: &str,
3629 params: &[&(dyn crate::ToSql + Sync)],
3630 timeout_duration: std::time::Duration,
3631 ) -> Result<QueryStream<'a>> {
3632 timeout(timeout_duration, self.query(sql, params))
3633 .await
3634 .map_err(|_| Error::CommandTimeout)?
3635 }
3636
3637 pub async fn execute_with_timeout(
3641 &mut self,
3642 sql: &str,
3643 params: &[&(dyn crate::ToSql + Sync)],
3644 timeout_duration: std::time::Duration,
3645 ) -> Result<u64> {
3646 timeout(timeout_duration, self.execute(sql, params))
3647 .await
3648 .map_err(|_| Error::CommandTimeout)?
3649 }
3650
3651 pub async fn commit(mut self) -> Result<Client<Ready>> {
3655 tracing::debug!("committing transaction");
3656
3657 #[cfg(feature = "otel")]
3658 let instrumentation = self.instrumentation.clone();
3659 #[cfg(feature = "otel")]
3660 let mut span = instrumentation.transaction_span("COMMIT");
3661
3662 let result = async {
3664 self.send_sql_batch("COMMIT TRANSACTION").await?;
3665 self.read_execute_result().await
3666 }
3667 .await;
3668
3669 #[cfg(feature = "otel")]
3670 match &result {
3671 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3672 Err(e) => InstrumentationContext::record_error(&mut span, e),
3673 }
3674
3675 #[cfg(feature = "otel")]
3677 drop(span);
3678
3679 result?;
3680
3681 Ok(Client {
3682 config: self.config,
3683 _state: PhantomData,
3684 connection: self.connection,
3685 server_version: self.server_version,
3686 current_database: self.current_database,
3687 statement_cache: self.statement_cache,
3688 transaction_descriptor: 0, #[cfg(feature = "otel")]
3690 instrumentation: self.instrumentation,
3691 })
3692 }
3693
3694 pub async fn rollback(mut self) -> Result<Client<Ready>> {
3698 tracing::debug!("rolling back transaction");
3699
3700 #[cfg(feature = "otel")]
3701 let instrumentation = self.instrumentation.clone();
3702 #[cfg(feature = "otel")]
3703 let mut span = instrumentation.transaction_span("ROLLBACK");
3704
3705 let result = async {
3707 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
3708 self.read_execute_result().await
3709 }
3710 .await;
3711
3712 #[cfg(feature = "otel")]
3713 match &result {
3714 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3715 Err(e) => InstrumentationContext::record_error(&mut span, e),
3716 }
3717
3718 #[cfg(feature = "otel")]
3720 drop(span);
3721
3722 result?;
3723
3724 Ok(Client {
3725 config: self.config,
3726 _state: PhantomData,
3727 connection: self.connection,
3728 server_version: self.server_version,
3729 current_database: self.current_database,
3730 statement_cache: self.statement_cache,
3731 transaction_descriptor: 0, #[cfg(feature = "otel")]
3733 instrumentation: self.instrumentation,
3734 })
3735 }
3736
3737 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
3754 validate_identifier(name)?;
3755 tracing::debug!(name = name, "creating savepoint");
3756
3757 let sql = format!("SAVE TRANSACTION {}", name);
3760 self.send_sql_batch(&sql).await?;
3761 self.read_execute_result().await?;
3762
3763 Ok(SavePoint::new(name.to_string()))
3764 }
3765
3766 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
3781 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
3782
3783 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
3786 self.send_sql_batch(&sql).await?;
3787 self.read_execute_result().await?;
3788
3789 Ok(())
3790 }
3791
3792 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
3798 tracing::debug!(name = savepoint.name(), "releasing savepoint");
3799
3800 drop(savepoint);
3804 Ok(())
3805 }
3806
3807 #[must_use]
3811 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3812 let connection = self
3813 .connection
3814 .as_ref()
3815 .expect("connection should be present");
3816 match connection {
3817 ConnectionHandle::Tls(conn) => {
3818 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3819 }
3820 ConnectionHandle::TlsPrelogin(conn) => {
3821 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3822 }
3823 ConnectionHandle::Plain(conn) => {
3824 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3825 }
3826 }
3827 }
3828}
3829
3830fn validate_identifier(name: &str) -> Result<()> {
3832 use once_cell::sync::Lazy;
3833 use regex::Regex;
3834
3835 static IDENTIFIER_RE: Lazy<Regex> =
3836 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
3837
3838 if name.is_empty() {
3839 return Err(Error::InvalidIdentifier(
3840 "identifier cannot be empty".into(),
3841 ));
3842 }
3843
3844 if !IDENTIFIER_RE.is_match(name) {
3845 return Err(Error::InvalidIdentifier(format!(
3846 "invalid identifier '{}': must start with letter/underscore, \
3847 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
3848 name
3849 )));
3850 }
3851
3852 Ok(())
3853}
3854
3855impl<S: ConnectionState> std::fmt::Debug for Client<S> {
3856 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3857 f.debug_struct("Client")
3858 .field("host", &self.config.host)
3859 .field("port", &self.config.port)
3860 .field("database", &self.config.database)
3861 .finish()
3862 }
3863}
3864
3865#[cfg(test)]
3866#[allow(clippy::unwrap_used, clippy::panic)]
3867mod tests {
3868 use super::*;
3869
3870 #[test]
3871 fn test_validate_identifier_valid() {
3872 assert!(validate_identifier("my_table").is_ok());
3873 assert!(validate_identifier("Table123").is_ok());
3874 assert!(validate_identifier("_private").is_ok());
3875 assert!(validate_identifier("sp_test").is_ok());
3876 }
3877
3878 #[test]
3879 fn test_validate_identifier_invalid() {
3880 assert!(validate_identifier("").is_err());
3881 assert!(validate_identifier("123abc").is_err());
3882 assert!(validate_identifier("table-name").is_err());
3883 assert!(validate_identifier("table name").is_err());
3884 assert!(validate_identifier("table;DROP TABLE users").is_err());
3885 }
3886
3887 fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
3896 let mut data = Vec::new();
3897 data.extend_from_slice(&total_len.to_le_bytes());
3899 for chunk in chunks {
3901 let len = chunk.len() as u32;
3902 data.extend_from_slice(&len.to_le_bytes());
3903 data.extend_from_slice(chunk);
3904 }
3905 data.extend_from_slice(&0u32.to_le_bytes());
3907 data
3908 }
3909
3910 #[test]
3911 fn test_parse_plp_nvarchar_simple() {
3912 let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
3914 let plp = make_plp_data(10, &[&utf16_data]);
3915 let mut buf: &[u8] = &plp;
3916
3917 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3918 match result {
3919 mssql_types::SqlValue::String(s) => assert_eq!(s, "Hello"),
3920 _ => panic!("expected String, got {:?}", result),
3921 }
3922 }
3923
3924 #[test]
3925 fn test_parse_plp_nvarchar_null() {
3926 let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
3928 let mut buf: &[u8] = &plp;
3929
3930 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3931 assert!(matches!(result, mssql_types::SqlValue::Null));
3932 }
3933
3934 #[test]
3935 fn test_parse_plp_nvarchar_empty() {
3936 let plp = make_plp_data(0, &[]);
3938 let mut buf: &[u8] = &plp;
3939
3940 let result = Client::<Ready>::parse_plp_nvarchar(&mut buf).unwrap();
3941 match result {
3942 mssql_types::SqlValue::String(s) => assert_eq!(s, ""),
3943 _ => panic!("expected empty String"),
3944 }
3945 }
3946
#[test]
fn test_parse_plp_nvarchar_multi_chunk() {
    // "Hello" in UTF-16LE split across two chunks on a code-unit boundary: "Hel" + "lo".
    let head: [u8; 6] = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00];
    let tail: [u8; 4] = [0x6C, 0x00, 0x6F, 0x00];
    let encoded = make_plp_data((head.len() + tail.len()) as u64, &[&head, &tail]);
    let mut cursor: &[u8] = &encoded;

    let parsed = Client::<Ready>::parse_plp_nvarchar(&mut cursor).unwrap();
    let mssql_types::SqlValue::String(text) = parsed else {
        panic!("expected String");
    };
    assert_eq!(text, "Hello");
}
3961
#[test]
fn test_parse_plp_varchar_simple() {
    // Single-byte ASCII text in one chunk; the second parser argument is passed as `None`.
    let text: &[u8] = b"Hello World";
    let encoded = make_plp_data(text.len() as u64, &[text]);
    let mut cursor: &[u8] = &encoded;

    let parsed = Client::<Ready>::parse_plp_varchar(&mut cursor, None).unwrap();
    let mssql_types::SqlValue::String(decoded) = parsed else {
        panic!("expected String");
    };
    assert_eq!(decoded, "Hello World");
}
3974
#[test]
fn test_parse_plp_varchar_null() {
    // An all-0xFF 8-byte PLP header with nothing after it is expected to decode as NULL.
    let null_header = u64::MAX.to_le_bytes();
    let mut cursor: &[u8] = &null_header;

    assert!(matches!(
        Client::<Ready>::parse_plp_varchar(&mut cursor, None).unwrap(),
        mssql_types::SqlValue::Null
    ));
}
3983
#[test]
fn test_parse_plp_varbinary_simple() {
    // Five raw bytes delivered in a single PLP chunk must come back unchanged.
    let payload: [u8; 5] = [0x01, 0x02, 0x03, 0x04, 0x05];
    let encoded = make_plp_data(payload.len() as u64, &[&payload]);
    let mut cursor: &[u8] = &encoded;

    let parsed = Client::<Ready>::parse_plp_varbinary(&mut cursor).unwrap();
    let mssql_types::SqlValue::Binary(bytes) = parsed else {
        panic!("expected Binary");
    };
    assert_eq!(&bytes[..], &payload[..]);
}
3996
#[test]
fn test_parse_plp_varbinary_null() {
    // An all-0xFF 8-byte PLP header with nothing after it is expected to decode as NULL.
    let null_header = u64::MAX.to_le_bytes();
    let mut cursor: &[u8] = &null_header;

    assert!(matches!(
        Client::<Ready>::parse_plp_varbinary(&mut cursor).unwrap(),
        mssql_types::SqlValue::Null
    ));
}
4005
#[test]
fn test_parse_plp_varbinary_large() {
    // 255 bytes whose value equals their position, split into three uneven chunks
    // (100 + 100 + 55) so reassembly across chunk boundaries is exercised.
    let chunks: [Vec<u8>; 3] = [
        (0..100u8).collect(),
        (100..200u8).collect(),
        (200..255u8).collect(),
    ];
    let declared_len: usize = chunks.iter().map(Vec::len).sum();
    let chunk_slices: Vec<&[u8]> = chunks.iter().map(Vec::as_slice).collect();
    let encoded = make_plp_data(declared_len as u64, &chunk_slices);
    let mut cursor: &[u8] = &encoded;

    let parsed = Client::<Ready>::parse_plp_varbinary(&mut cursor).unwrap();
    let mssql_types::SqlValue::Binary(bytes) = parsed else {
        panic!("expected Binary");
    };
    assert_eq!(bytes.len(), 255);
    for (position, value) in bytes.iter().copied().enumerate() {
        assert_eq!(value, position as u8);
    }
}
4028
4029 use tds_protocol::token::{ColumnData, TypeInfo};
4037 use tds_protocol::types::TypeId;
4038
/// Builds raw row bytes for an `NVARCHAR` column followed by an `INTN(4)` column.
///
/// Encoding produced:
/// - the string: a 2-byte little-endian byte count, then its UTF-16LE code units;
/// - the integer: a 1-byte length (`4`), then the little-endian `i32`.
///
/// # Panics
///
/// Panics if the UTF-16 encoding of `nvarchar_value` is longer than `u16::MAX` bytes.
/// The 2-byte length prefix cannot represent that, and a plain `as` cast would
/// silently wrap and produce a corrupt row.
fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
    let utf16: Vec<u16> = nvarchar_value.encode_utf16().collect();
    let byte_len = u16::try_from(utf16.len() * 2)
        .expect("NVARCHAR test value must fit in a 2-byte length prefix");
    let int_bytes = int_value.to_le_bytes();

    // 2-byte prefix + string bytes + 1-byte INTN length + integer bytes.
    let mut data = Vec::with_capacity(2 + usize::from(byte_len) + 1 + int_bytes.len());

    data.extend_from_slice(&byte_len.to_le_bytes());
    for code_unit in utf16 {
        data.extend_from_slice(&code_unit.to_le_bytes());
    }

    // INTN length byte: an i32 occupies 4 bytes.
    data.push(4);
    data.extend_from_slice(&int_bytes);

    data
}
4058
#[test]
fn test_parse_row_nvarchar_then_int() {
    // Column descriptor factory: only name, type id, wire type byte and max length vary;
    // flags/user_type/precision/scale/collation are fixed for these tests.
    let column = |name: &str, type_id, col_type, max_length| ColumnData {
        name: name.to_string(),
        type_id,
        col_type,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(max_length),
            precision: None,
            scale: None,
            collation: None,
        },
    };

    // "World" is 5 UTF-16 code units = 10 bytes, matching the NVARCHAR max length below.
    let greeting = column("greeting", TypeId::NVarChar, 0xE7, 10);
    let number = column("number", TypeId::IntN, 0x26, 4);

    let row = make_nvarchar_int_row("World", 42);
    let mut cursor: &[u8] = &row;

    // Parse the columns back-to-back from the same cursor, as a row reader would.
    let first = Client::<Ready>::parse_column_value(&mut cursor, &greeting).unwrap();
    if let mssql_types::SqlValue::String(text) = first {
        assert_eq!(text, "World");
    } else {
        panic!("expected String, got {:?}", first);
    }

    let second = Client::<Ready>::parse_column_value(&mut cursor, &number).unwrap();
    if let mssql_types::SqlValue::Int(value) = second {
        assert_eq!(value, 42);
    } else {
        panic!("expected Int, got {:?}", second);
    }

    assert_eq!(cursor.len(), 0, "buffer should be fully consumed");
}
4112
#[test]
fn test_parse_row_multiple_types() {
    // Column descriptor factory: only name, type id, wire type byte and max length vary.
    let column = |name: &str, type_id, col_type, max_length| ColumnData {
        name: name.to_string(),
        type_id,
        col_type,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(max_length),
            precision: None,
            scale: None,
            collation: None,
        },
    };

    let col0 = column("col0", TypeId::NVarChar, 0xE7, 100);
    let col1 = column("col1", TypeId::IntN, 0x26, 4);
    // Columns 2 and 3 reuse the exact descriptors of columns 0 and 1.
    let col2 = col0.clone();
    let col3 = col1.clone();

    let test_units: Vec<u16> = "Test".encode_utf16().collect();
    let mut row: Vec<u8> = Vec::new();
    // col0: NVARCHAR with a 0xFFFF length prefix, expected to read back as NULL.
    row.extend_from_slice(&u16::MAX.to_le_bytes());
    // col1: INTN with length byte 4 and value 123.
    row.push(4);
    row.extend_from_slice(&123i32.to_le_bytes());
    // col2: NVARCHAR "Test" as byte count + UTF-16LE code units.
    row.extend_from_slice(&((test_units.len() * 2) as u16).to_le_bytes());
    row.extend(test_units.iter().flat_map(|unit| unit.to_le_bytes()));
    // col3: INTN with length byte 0, expected to read back as NULL.
    row.push(0);

    let mut cursor: &[u8] = &row;

    let v0 = Client::<Ready>::parse_column_value(&mut cursor, &col0).unwrap();
    assert!(
        matches!(v0, mssql_types::SqlValue::Null),
        "col0 should be Null"
    );

    let v1 = Client::<Ready>::parse_column_value(&mut cursor, &col1).unwrap();
    assert!(
        matches!(v1, mssql_types::SqlValue::Int(123)),
        "col1 should be 123"
    );

    let v2 = Client::<Ready>::parse_column_value(&mut cursor, &col2).unwrap();
    let mssql_types::SqlValue::String(text) = v2 else {
        panic!("col2 should be 'Test'");
    };
    assert_eq!(text, "Test");

    let v3 = Client::<Ready>::parse_column_value(&mut cursor, &col3).unwrap();
    assert!(
        matches!(v3, mssql_types::SqlValue::Null),
        "col3 should be Null"
    );

    assert_eq!(cursor.len(), 0, "buffer should be fully consumed");
}
4195
#[test]
fn test_parse_row_with_unicode() {
    // Non-ASCII text (accented Latin and CJK) must round-trip through UTF-16LE,
    // followed by an 8-byte INTN that should surface as a BigInt.
    let test_str = "Héllo Wörld 日本語";

    // Column descriptor factory: only name, type id, wire type byte and max length vary.
    let column = |name: &str, type_id, col_type, max_length| ColumnData {
        name: name.to_string(),
        type_id,
        col_type,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(max_length),
            precision: None,
            scale: None,
            collation: None,
        },
    };
    let text_col = column("text", TypeId::NVarChar, 0xE7, 100);
    let num_col = column("num", TypeId::IntN, 0x26, 8);

    let code_units: Vec<u16> = test_str.encode_utf16().collect();
    let mut row: Vec<u8> = Vec::new();
    // NVARCHAR: 2-byte little-endian byte count, then UTF-16LE code units.
    row.extend_from_slice(&((code_units.len() * 2) as u16).to_le_bytes());
    row.extend(code_units.iter().flat_map(|unit| unit.to_le_bytes()));
    // INTN: 1-byte length of 8, then a little-endian i64.
    row.push(8);
    row.extend_from_slice(&9999999999i64.to_le_bytes());

    let mut cursor: &[u8] = &row;

    let v0 = Client::<Ready>::parse_column_value(&mut cursor, &text_col).unwrap();
    let mssql_types::SqlValue::String(text) = v0 else {
        panic!("expected String");
    };
    assert_eq!(text, test_str);

    let v1 = Client::<Ready>::parse_column_value(&mut cursor, &num_col).unwrap();
    let mssql_types::SqlValue::BigInt(value) = v1 else {
        panic!("expected BigInt");
    };
    assert_eq!(value, 9999999999);

    assert_eq!(cursor.len(), 0, "buffer should be fully consumed");
}
4256}